API reference

Entry points

DifferentiationInterfaceTest.test_differentiation - Function
test_differentiation(
    backends::Vector{<:ADTypes.AbstractADType};
    ...
)
test_differentiation(
    backends::Vector{<:ADTypes.AbstractADType},
    scenarios::Vector{<:Scenario};
    correctness,
    type_stability,
    call_count,
    sparsity,
    detailed,
    input_type,
    output_type,
    first_order,
    second_order,
    onearg,
    twoarg,
    inplace,
    outofplace,
    excluded,
    logging,
    isequal,
    isapprox,
    atol,
    rtol
)

Cross-test a list of backends on a list of scenarios, running a variety of tests.

Default arguments

Keyword arguments

Testing:

  • correctness=true: whether to compare the differentiation results with the theoretical values specified in each scenario
  • type_stability=false: whether to check type stability with JET.jl (using JET.@test_opt)
  • sparsity: whether to check the sparsity of the Jacobian / Hessian
  • detailed=false: whether to print a detailed or condensed test log

Filtering:

  • input_type=Any, output_type=Any: restrict scenarios to those whose inputs / outputs are subtypes of the given type
  • first_order=true, second_order=true: include first-order / second-order operators
  • onearg=true, twoarg=true: include one-argument / two-argument functions
  • inplace=true, outofplace=true: include in-place / out-of-place operators
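
As an illustration of the filtering keywords above, here is a minimal sketch that restricts a test run to first-order operators on one-argument functions with vector inputs. It assumes that ADTypes.jl and ForwardDiff.jl are installed and that AutoForwardDiff is the backend of interest; the choice of backend and of filter values is ours, not something prescribed by this reference.

```julia
using ADTypes: AutoForwardDiff
using DifferentiationInterfaceTest
import ForwardDiff  # assumption: the package behind AutoForwardDiff must be loaded

backends = [AutoForwardDiff()]

# Keep only first-order operators, one-argument functions and vector inputs.
# All other keywords keep the defaults listed in this docstring.
test_differentiation(
    backends;
    input_type=AbstractVector,
    second_order=false,
    twoarg=false,
    logging=true,
)
```

Narrowing the run this way is mostly useful for backends that do not support every operator or function signature.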

Options:

  • logging=false: whether to log progress
  • isequal=isequal: function used to compare objects exactly, with the standard signature isequal(x, y)
  • isapprox=isapprox: function used to compare objects approximately, with the standard signature isapprox(x, y; atol, rtol)
  • atol=0: absolute precision for correctness testing (when comparing to the reference outputs)
  • rtol=1e-3: relative precision for correctness testing (when comparing to the reference outputs)

test_differentiation(
    backend::ADTypes.AbstractADType,
    args...;
    kwargs...
)

Shortcut for a single backend.
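
A minimal sketch of the single-backend shortcut, under the same assumptions as any other call (ADTypes.jl and ForwardDiff.jl installed, AutoForwardDiff chosen purely for illustration). The keyword arguments are forwarded unchanged, so the testing and option keywords documented above apply here too.

```julia
using ADTypes: AutoForwardDiff
using DifferentiationInterfaceTest
import ForwardDiff  # assumption: the package behind AutoForwardDiff must be loaded

backend = AutoForwardDiff()

# No vector wrapper is needed for a single backend.
# rtol is tightened from its documented default of 1e-3 as an example.
test_differentiation(
    backend;
    correctness=true,
    rtol=1e-4,
    detailed=true,
)
```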

DifferentiationInterfaceTest.benchmark_differentiation - Function
benchmark_differentiation(
    backends::Vector{<:ADTypes.AbstractADType},
    scenarios::Vector{<:Scenario};
    input_type,
    output_type,
    first_order,
    second_order,
    onearg,
    twoarg,
    inplace,
    outofplace,
    excluded,
    logging
) -> DataFrames.DataFrame

Benchmark a list of backends for a list of operators on a list of scenarios.

The object returned is a DataFrames.DataFrame where each column corresponds to a field of DifferentiationBenchmarkDataRow.

The keyword arguments available here have the same meaning as those in test_differentiation.
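
The following sketch shows what a benchmarking run might look like. It assumes that ADTypes.jl, DataFrames.jl and ForwardDiff.jl are installed, and that the package provides a pre-made scenario list through a function called default_scenarios(); that name is an assumption here, since the scenario lists are not spelled out in this reference.

```julia
using ADTypes: AutoForwardDiff
using DataFrames
using DifferentiationInterfaceTest
import ForwardDiff  # assumption: the package behind AutoForwardDiff must be loaded

backends = [AutoForwardDiff()]
scenarios = default_scenarios()  # assumption: name of a pre-made scenario list

# Benchmark first-order operators only; filtering keywords behave as in test_differentiation.
df = benchmark_differentiation(backends, scenarios; second_order=false, logging=true)

# Each row is one (backend, scenario, operator) measurement; show the fastest first.
sort(df, :time)
```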

DifferentiationInterfaceTest.DifferentiationBenchmarkDataRow - Type
DifferentiationBenchmarkDataRow

Ad-hoc storage type for differentiation benchmarking results.

If you have a vector rows::Vector{DifferentiationBenchmarkDataRow}, you can turn it into a DataFrame as follows:

using DataFrames

df = DataFrame(rows)

The resulting DataFrame will have one column for each of the following fields.

Fields

  • backend::ADTypes.AbstractADType: backend used for benchmarking

  • scenario::Scenario: scenario used for benchmarking

  • operator::Symbol: differentiation operator used for benchmarking, e.g. :gradient or :hessian

  • calls::Int64: number of calls to the differentiated function for one call to the operator

  • samples::Int64: number of benchmarking samples taken

  • evals::Int64: number of evaluations used for averaging in each sample

  • time::Float64: minimum runtime over all samples, in seconds

  • allocs::Float64: minimum number of allocations over all samples

  • bytes::Float64: minimum memory allocated over all samples, in bytes

  • gc_fraction::Float64: minimum fraction of time spent in garbage collection over all samples, between 0.0 and 1.0

  • compile_fraction::Float64: minimum fraction of time spent compiling over all samples, between 0.0 and 1.0

See the documentation of Chairmarks.jl for more details on the measurement fields.
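
To give an idea of how the resulting table can be analyzed, here is a small sketch that uses hand-written stand-in rows instead of real benchmark output. The NamedTuples below are hypothetical, carry only a subset of the fields listed above, and use plain strings for the backend column; the numbers are made up for illustration.

```julia
using DataFrames

# Hypothetical stand-in rows (not real measurements), mirroring some documented fields.
rows = [
    (backend="A", operator=:gradient, calls=1, time=2.0e-6, allocs=4.0),
    (backend="B", operator=:gradient, calls=3, time=5.0e-6, allocs=0.0),
    (backend="A", operator=:hessian, calls=3, time=9.0e-5, allocs=12.0),
]
df = DataFrame(rows)

# Best (smallest) runtime per operator across backends.
best = combine(groupby(df, :operator), :time => minimum => :best_time)

# All measurements, fastest first.
fastest = sort(df, :time)
```

The same groupby / combine pattern should carry over to a real table, since the column names match the field names.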

Pre-made scenario lists

The precise contents of the scenario lists are not part of the API, only their existence.

Scenario types

DifferentiationInterfaceTest.Scenario - Type
Scenario{op,args,pl}

Store a testing scenario composed of a function, its input and its output for a given operator.

This generic type should never be used directly: use the specific constructor corresponding to the operator you want to test, or a predefined list of scenarios.

Constructors

Type parameters

  • op: one of :pushforward, :pullback, :derivative, :gradient, :jacobian, :second_derivative, :hvp, :hessian
  • args: either 1 (for f(x) = y) or 2 (for f!(y, x) = nothing)
  • pl: either :inplace or :outofplace

Fields

  • f::Any: function f (if args==1) or f! (if args==2) to apply

  • x::Any: primal input

  • y::Any: primal output

  • seed::Any: seed for pushforward, pullback or HVP

  • res1::Any: first-order result of the operator

  • res2::Any: second-order result of the operator (when it makes sense)

Note that the res1 and res2 fields are given more meaningful names in the keyword arguments of each specialized constructor. For example:

  • the keyword grad of GradientScenario becomes res1
  • the keyword hess of HessianScenario becomes res2, and the keyword grad becomes res1
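
The sketch below illustrates, in plain Julia and without calling any constructor from this package, the two function conventions and the reference values that a gradient or Hessian scenario would need. The functions f and f! and the variable names are hypothetical examples; only the roles (x, y, grad stored as res1, hess stored as res2) come from the description above.

```julia
using LinearAlgebra: I

# One-argument convention (args == 1): f(x) = y
f(x) = sum(abs2, x)

# Two-argument convention (args == 2): f!(y, x) = nothing, writing into y
function f!(y, x)
    y .= abs2.(x)
    return nothing
end

x = [1.0, 2.0, 3.0]        # primal input
y = f(x)                   # primal output of the one-argument function

# Analytic reference values for f(x) = sum(abs2, x):
grad = 2 .* x              # gradient, the role played by res1
hess = Matrix(2.0I, 3, 3)  # Hessian, the role played by res2

# The two-argument function fills a preallocated output instead of returning it.
y_buffer = similar(x)
f!(y_buffer, x)
```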

Internals

This is not part of the public API.

DifferentiationInterfaceTest.flux_scenarios - Function
flux_scenarios(rng=Random.default_rng())

Create a vector of Scenarios with neural networks from Flux.jl.

Warning

This function requires FiniteDifferences.jl and Flux.jl to be loaded (it is implemented in a package extension).

Danger

These scenarios are still experimental and not part of the public API. Their ground truth values are computed with finite differences, and thus subject to imprecision.
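
A minimal sketch of generating these scenarios reproducibly, assuming FiniteDifferences.jl and Flux.jl are installed. The function is called with its module prefix because it may not be exported, and the seeded MersenneTwister is our choice for illustration. Running the scenarios through test_differentiation is left out on purpose: the reference values come from finite differences, so a loose rtol and a comparison function suited to network gradients (passed through the isapprox keyword) would likely be needed.

```julia
using DifferentiationInterfaceTest
using Random: MersenneTwister
import FiniteDifferences, Flux  # required so that the package extension is loaded

# Pass an explicit RNG so that repeated calls build the same networks.
rng = MersenneTwister(0)
scenarios = DifferentiationInterfaceTest.flux_scenarios(rng)

# Number of generated scenarios (not part of the public API, may change).
length(scenarios)
```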

DifferentiationInterfaceTest.lux_scenarios - Function
lux_scenarios(rng=Random.default_rng())

Create a vector of Scenarios with neural networks from Lux.jl.

Warning

This function requires ComponentArrays.jl, FiniteDiff.jl, Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension).

Danger

These scenarios are still experimental and not part of the public API. Their ground truth values are computed with finite differences, and thus subject to imprecision.