DeepEquilibriumNetworks.DeepEquilibriumNetwork — Method
DeepEquilibriumNetwork(model, solver; init = missing, jacobian_regularization=nothing,
    problem_type::Type{pType}=SteadyStateProblem{false}, kwargs...)

Deep Equilibrium Network as proposed in baideep2019 and pal2022mixing.

Arguments

  • model: Neural Network.
  • solver: Solver for the rootfinding problem. ODE Solvers and Nonlinear Solvers are both supported.

Keyword Arguments

  • init: Initial Condition for the rootfinding problem. If nothing, the initial condition is set to zero(x). If missing, the initial condition is set to WrappedFunction{:direct_call}(zero). Otherwise, the initial condition is set to init(x, ps, st) (see the sketch after the example below).
  • jacobian_regularization: Must be one of nothing, AutoForwardDiff, AutoFiniteDiff or AutoZygote.
  • problem_type: Provides a way to simulate a Vanilla Neural ODE by setting the problem_type to ODEProblem. By default, the problem type is set to SteadyStateProblem.
  • kwargs: Additional Parameters that are directly passed to SciMLBase.solve.

Example

julia> model = DeepEquilibriumNetwork(
           Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)),
           VCABM3(); verbose=false);

julia> rng = Xoshiro(0);

julia> ps, st = Lux.setup(rng, model);

julia> size(first(model(ones(Float32, 2, 1), ps, st)))
(2, 1)
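
The init keyword can also be given a layer that produces a learned initial guess, so the initial condition becomes init(x, ps, st) instead of zeros. A hedged sketch of such a construction follows; the particular Dense layer used as init is an illustrative assumption, not taken from the package documentation:

julia> skip_model = DeepEquilibriumNetwork(
           Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
           VCABM3(); init=Dense(2 => 2; use_bias=false), verbose=false);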

See also: SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork.

DeepEquilibriumNetworks.DeepEquilibriumSolution — Type
DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe, solution)

Stores the solution of a DeepEquilibriumNetwork and its variants.

Fields

  • z_star: Steady-State or the value reached due to maxiters
  • u0: Initial Condition
  • residual: Difference between $z^*$ and $f(z^*, x)$
  • jacobian_loss: Jacobian Stabilization Loss (see individual networks to see how it can be computed)
  • nfe: Number of Function Evaluations
  • original: Original Internal Solution
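
A minimal sketch of inspecting these fields, assuming sol is a DeepEquilibriumSolution obtained after a forward pass (how the solution is surfaced, e.g. through the layer's updated state, depends on the package version):

    sol.z_star         # fixed point estimate (or the last iterate if maxiters was reached)
    sol.residual       # difference between z_star and f(z_star, x)
    sol.jacobian_loss  # Jacobian stabilization loss accumulated during the solve
    sol.nfe            # number of function evaluations used by the solver
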
DeepEquilibriumNetworks.MultiScaleDeepEquilibriumNetwork — Method
MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
    post_fuse_layer::Union{Nothing, Tuple}, solver,
    scales::NTuple{N, NTuple{L, Int64}}; kwargs...)

Multi Scale Deep Equilibrium Network as proposed in baimultiscale2020.

Arguments

  • main_layers: Tuple of Neural Networks. Each Neural Network is applied to the corresponding scale.
  • mapping_layers: Matrix of Neural Networks. Each Neural Network is applied to the corresponding scale and the corresponding layer.
  • post_fuse_layer: Neural Network applied to the fused output of the main layers.
  • solver: Solver for the rootfinding problem. ODE Solvers and Nonlinear Solvers are both supported.
  • scales: Scales of the Multi Scale DEQ. Each scale is a tuple of integers giving the feature dimensions of the corresponding branch's output (the batch dimension is appended automatically, as in the example below).

For keyword arguments, see DeepEquilibriumNetwork.

Example

julia> main_layers = (
           Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)),
           Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh));

julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
                         Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
                         Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
                         Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()];

julia> model = MultiScaleDeepEquilibriumNetwork(
           main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)));

julia> rng = Xoshiro(0);

julia> ps, st = Lux.setup(rng, model);

julia> x = rand(rng, Float32, 4, 12);

julia> size.(first(model(x, ps, st)))
((4, 12), (3, 12), (2, 12), (1, 12))
DeepEquilibriumNetworks.MultiScaleSkipDeepEquilibriumNetwork — Method
MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
    post_fuse_layer::Union{Nothing, Tuple}, [init = nothing,] solver,
    scales::NTuple{N, NTuple{L, Int64}}; kwargs...)

Skip Multi Scale Deep Equilibrium Network as proposed in pal2022mixing. This is an alias that constructs a MultiScaleDeepEquilibriumNetwork with the init keyword argument set to the passed value.

If init is not passed, it creates a MultiScale Regularized Deep Equilibrium Network.
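
A hedged construction sketch (an assumption based on the signature above, not a verbatim doctest from the package), reusing main_layers and mapping_layers from the MultiScaleDeepEquilibriumNetwork example above and omitting init, which yields the regularized variant:

julia> model = MultiScaleSkipDeepEquilibriumNetwork(
           main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)));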