LuxTestUtils.jl

Join the chat at https://julialang.zulipchat.com (#machine-learning).

This package follows ColPrac (Contributor's Guide on Collaborative Practices for Community Packages) and the SciML Code Style.

Utilities for testing Lux.jl.

Installation

] add LuxTestUtils

Warning: This is a testing package, so it does not use features like weak dependencies to reduce load times. Use it exclusively for testing, and do not add it as a dependency in your main package's Project.toml.
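One way to keep LuxTestUtils out of the main Project.toml is a separate test environment. The sketch below is an assumption about your layout, not a prescribed workflow: it assumes it is run from the package root and that the tests live in `test/`, whose `test/Project.toml` is picked up by `Pkg.test` (Julia 1.2 and later).

```julia
using Pkg

# Activate (and create, if needed) a separate environment in `test/`,
# so LuxTestUtils never ends up in the package's main Project.toml.
Pkg.activate("test")
Pkg.add(["LuxTestUtils", "Test"])

# Switch back to the package's own environment afterwards.
Pkg.activate(".")
```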

Exported Functions

Testing using JET.jl

We export a simple macro, @jet, for testing your code with JET.jl:

help?> @jet

  @jet f(args...) call_broken=false opt_broken=false


  Run JET tests on the function `f` with the arguments `args`. If JET fails to compile or
  the Julia version is < 1.7, the macro is a no-op.

  Keyword Arguments
  ===================

  • `call_broken`: Marks the test_call as broken.
  • `opt_broken`: Marks the test_opt as broken.

  All additional arguments will be forwarded to @JET.test_call and @JET.test_opt.

   Note
  
    Instead of specifying target_modules with every call, you can set preferences for
    target_modules using Preferences.jl. For example, to set `target_modules` to
    (Lux, LuxLib) we can run:
  
    using Preferences
    
    set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"),
                     "target_modules" => ["Lux", "LuxLib"])

  Example
  =========

  @jet sum([1, 2, 3]) target_modules=(Base, Core)
  
  @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true
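A minimal sketch of `@jet` inside a `Test.@testset`, so that the `test_call` and `test_opt` results are recorded in the test set. `add_one` is a hypothetical function used only for illustration, and `target_modules=(Main,)` assumes the code runs as a script, where `add_one` is defined in `Main`.

```julia
using LuxTestUtils, Test

# Hypothetical function, defined here only for illustration
add_one(x::AbstractVector) = x .+ one(eltype(x))

@testset "JET checks for add_one" begin
    # Only keep JET reports that originate in `Main`, where `add_one` lives
    @jet add_one([1.0, 2.0, 3.0]) target_modules=(Main,)
end
```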

Gradient Correctness

help?> @test_gradients
  @test_gradients f args... [kwargs...]


  Compare the gradients computed by `Zygote.jl` (Reverse Mode AD) against:

  • `Tracker.jl` (Reverse Mode AD)
  • `ReverseDiff.jl` (Reverse Mode AD)
  • `ForwardDiff.jl` (Forward Mode AD)
  • `FiniteDifferences.jl` (Finite Differences)

   Tip
  
    This macro is fully compatible with `Test.jl`.

  Arguments
  ===========

  • `f`: The function to test.
  • `args...`: Inputs to `f` with respect to which the gradients are computed.

  Keyword Arguments
  ===================

  • `gpu_testing`: Disables the ForwardDiff, ReverseDiff and FiniteDifferences tests.
    (Default: `false`)
  • `soft_fail`: If `true`, the test does not fail when any of the gradients are
    incorrect; instead, it shows up as broken. (Default: `false`)
  • `skip_(tracker|reverse_diff|forward_diff|finite_differences)`: Skip the
    corresponding gradient computation and check. (Default: `false`)
  • `large_arrays_skip_(forward_diff|finite_differences)`: Skip the corresponding
    gradient computation and check for large arrays. (Forward Mode and Finite
    Differences are not efficient for large arrays.) (Default: `true`)
  • `large_array_length`: The length of the array above which the gradient
    computation is considered large. (Default: `25`)
  • `max_total_array_size`: Treat the inputs as large arrays if the total size of all
    arrays is greater than this value. (Default: `100`)
  • `(tracker|reverse_diff|forward_diff|finite_differences)_broken`: Mark the
    corresponding gradient test as broken. (Default: `false`)

  Keyword Arguments for check_approx
  ====================================

  • `atol`: Absolute tolerance for gradient comparisons. (Default: `0.0`)
  • `rtol`: Relative tolerance for gradient comparisons.
    (Default: `atol > 0 ? 0.0 : √eps(typeof(atol))`)
  • `nans`: Whether or not NaNs are considered equal. (Default: `false`)

  Example
  =========

  using LuxTestUtils, Test
  
  x = randn(10)
  
  @testset "Showcase Gradient Testing" begin
      @test_gradients sum abs2 x
  
      @test_gradients prod x
  end
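The sketch below combines `@test_gradients` with some of the keyword arguments listed above. The `loss` function, the array sizes and the tolerance values are hypothetical choices; the arrays are kept small so they stay below the default `large_array_length` and `max_total_array_size`, and the ForwardDiff and FiniteDifferences checks therefore still run.

```julia
using LuxTestUtils, Test

# Hypothetical loss: squared norm of a linear map applied to `x`
loss(W, x) = sum(abs2, W * x)

W = randn(4, 3)  # 12 elements
x = randn(3)     # 3 elements; 15 in total, below the default thresholds

@testset "Gradients of a small linear model" begin
    # Gradients w.r.t. both `W` and `x`; Tracker.jl is skipped and the
    # tolerances are loosened (illustrative values)
    @test_gradients loss W x atol=1.0e-5 rtol=1.0e-5 skip_tracker=true
end
```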

Internally, `@test_gradients` compares gradients with `check_approx`, which extends `Base.isapprox` to more common cases. It accepts exactly the same arguments as `isapprox`.
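Because `check_approx` takes the same arguments as `isapprox`, it can also be called directly. A small sketch, assuming that for plain arrays it behaves exactly like `isapprox` (the values are illustrative):

```julia
using LuxTestUtils, Test

a = [1.0, 2.0, 3.0]
b = a .+ 1.0e-10

@testset "check_approx on plain arrays" begin
    # Default tolerances, as with isapprox
    @test LuxTestUtils.check_approx(a, b)
    # Keyword arguments are passed just like for isapprox
    @test LuxTestUtils.check_approx(a, b; atol=1.0e-8)
    @test !LuxTestUtils.check_approx(a, a .+ 1.0; atol=1.0e-3)
    # Assumption: for plain arrays the result matches isapprox
    @test LuxTestUtils.check_approx(a, b) == isapprox(a, b)
end
```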

Passing Runtime Variables to Macro

Macros operate on syntax, so they cannot directly read the runtime value of a variable. Because this package is only used in tests (and not in your core package), it is acceptable to work around this with `@eval` and interpolation.

For example, say we want to mark the JET tests of `sum` on `Float16` inputs as broken:

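using LuxTestUtils, Test

@testset "JET checks for sum" begin
    for T in (Float16, Float32, Float64)
        x = rand(T, 10, 1)
        # Use `@eval` to interpolate the runtime variable `T` into the macro call.
        # A test marked as broken is expected to fail: if the Float16 call
        # actually passes JET, it is reported as an "Unexpected Pass" error.
        @eval @jet sum($x) call_broken=$(T == Float16)
    end
end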
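The same `@eval` pattern works for `@test_gradients` keyword arguments. In this sketch the absolute tolerance depends on the element type; the tolerance values are illustrative assumptions, not recommended defaults.

```julia
using LuxTestUtils, Test

@testset "sum gradients with eltype $T" for T in (Float32, Float64)
    x = randn(T, 5)
    # Looser absolute tolerance for Float32 (illustrative values)
    tol = T == Float32 ? 1.0f-3 : 1.0e-6
    # `@eval` splices the runtime values of `x` and `tol` into the macro call
    @eval @test_gradients sum $x atol=$tol
end
```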