DeepEquilibriumNetworks.DeepEquilibriumNetwork (Method)
DeepEquilibriumNetwork(model, solver; init = missing, jacobian_regularization=nothing,
    problem_type::Type{pType}=SteadyStateProblem{false}, kwargs...)

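Deep Equilibrium Network as proposed in baideep2019 and pal2022mixing. The network uses solver to compute an equilibrium $z^*$ satisfying $z^* = f(z^*, x)$, where $f$ is model and $x$ is the input.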

Arguments

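  • model: Neural network $f$ defining the fixed-point map.
  • solver: Solver for the root-finding problem. Both ODE solvers (e.g. VCABM3 from OrdinaryDiffEq) and nonlinear solvers (e.g. NewtonRaphson from NonlinearSolve) are supported.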
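
Both kinds of solver target the same equilibrium: a nonlinear solver looks for a root of $g(z) = f(z, x) - z$, while an ODE solver integrates $dz/dt = f(z, x) - z$ until it reaches a steady state. The minimal sketch below illustrates this in plain Julia, assuming a toy map f(z, x) = tanh.(W * z .+ x) with a hand-picked contractive W. It uses simple fixed-point iteration and explicit Euler in place of real solvers; fixed_point and euler_steady_state are hypothetical helpers, not part of DeepEquilibriumNetworks.

```julia
using LinearAlgebra

# Toy stand-in for `model` (an assumption for illustration): a contractive map.
W = [0.3 -0.2; 0.1 0.4]
f(z, x) = tanh.(W * z .+ x)

# Root-finding view: iterate z <- f(z, x) until successive iterates agree.
function fixed_point(f, x; z0 = zero(x), tol = 1e-10, maxiters = 1_000)
    z = z0
    for _ in 1:maxiters
        znew = f(z, x)
        norm(znew - z) < tol && return znew
        z = znew
    end
    return z
end

# ODE view: integrate dz/dt = f(z, x) - z with explicit Euler until dz/dt ≈ 0.
function euler_steady_state(f, x; z0 = zero(x), dt = 0.5, tol = 1e-10, maxiters = 10_000)
    z = z0
    for _ in 1:maxiters
        dz = f(z, x) - z
        norm(dz) < tol && return z
        z = z + dt * dz
    end
    return z
end

x = [0.5, -1.0]
z_root = fixed_point(f, x)        # equilibrium via iteration
z_ode = euler_steady_state(f, x)  # equilibrium via steady state of the ODE
```

Both calls land on the same $z^*$ (up to the tolerance), which is why either solver family can be passed as solver.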

Keyword Arguments

  • init: Initial condition for the root-finding problem. If nothing, the initial condition is set to zero(x). If missing, it is set to WrappedFunction(zero). Otherwise, the initial condition is init(x, ps, st).
  • jacobian_regularization: Backend used to compute the Jacobian stabilization loss; must be one of nothing, AutoFiniteDiff, or AutoZygote.
  • problem_type: Set to ODEProblem to simulate a vanilla Neural ODE instead of solving for the steady state. Defaults to SteadyStateProblem.
  • kwargs: Additional keyword arguments passed directly to SciMLBase.solve.
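
A common form of Jacobian stabilization penalizes the squared Frobenius norm of $J = \partial f / \partial z$ at the equilibrium. It can be estimated without forming $J$: for a random probe $\epsilon \sim \mathcal{N}(0, I)$, $\mathbb{E}\|J\epsilon\|^2 = \|J\|_F^2$, and $J\epsilon$ can be approximated with a forward finite difference. The sketch below assumes this form of the loss; it is close in spirit to the AutoFiniteDiff option but is not the library's implementation, and jacobian_frobenius_sq is a hypothetical helper.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not the library's code): Hutchinson estimate of
# ||∂f/∂z||_F^2 at z, with J*ε approximated by a forward finite difference.
function jacobian_frobenius_sq(f, z, x; nsamples = 1, h = 1e-6, rng = Random.default_rng())
    fz = f(z, x)
    acc = 0.0
    for _ in 1:nsamples
        ε = randn(rng, size(z)...)            # random probe vector
        Jε = (f(z .+ h .* ε, x) .- fz) ./ h   # ≈ J * ε
        acc += sum(abs2, Jε)                  # ||J * ε||^2
    end
    return acc / nsamples                     # average over probes
end

# Usage: for the linear map f(z, x) = A * z .+ x the Jacobian is A, so the
# estimate should approach ||A||_F^2 = 5 as nsamples grows.
A = [2.0 0.0; 0.0 1.0]
est = jacobian_frobenius_sq((z, x) -> A * z .+ x, [0.3, -0.7], [1.0, 1.0];
                            nsamples = 10_000, rng = MersenneTwister(0))
```

In training, a single probe per step (nsamples = 1) is the usual trade-off: the estimate is noisy but unbiased, and it costs only one extra evaluation of $f$.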

Example

julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq

julia> model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false),
               Dense(2, 2; use_bias=false)), VCABM3(); verbose=false)
DeepEquilibriumNetwork(
    model = Parallel(
        +
        Dense(2 => 2, bias=false),      # 4 parameters
        Dense(2 => 2, bias=false),      # 4 parameters
    ),
    init = WrappedFunction(Base.Fix1{typeof(DeepEquilibriumNetworks.__zeros_init), Nothing}(DeepEquilibriumNetworks.__zeros_init, nothing)),
)         # Total: 8 parameters,
          #        plus 0 states.

julia> rng = Random.default_rng()
TaskLocalRNG()

julia> ps, st = Lux.setup(rng, model);

julia> model(ones(Float32, 2, 1), ps, st);

See also: SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork.

DeepEquilibriumNetworks.DeepEquilibriumSolution (Type)
DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)

Stores the solution of a DeepEquilibriumNetwork and its variants.

Fields

  • z_star: The steady state, or the value reached when the solver stopped at maxiters
  • u0: Initial condition
  • residual: Difference between $z^*$ and $f(z^*, x)$
  • jacobian_loss: Jacobian stabilization loss (see the individual networks for how it is computed)
  • nfe: Number of function evaluations
  • original: The original solution object returned by the internal solver
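
To make the fields concrete, the sketch below fills an analogous record from a plain fixed-point iteration. ToyDEQSolution and toy_solve are hypothetical, defined only for illustration; jacobian_loss is left at 0.0 because no regularization is applied, and original is nothing because there is no underlying solver object.

```julia
using LinearAlgebra

# Hypothetical record mirroring the documented fields (not the library type).
struct ToyDEQSolution{T}
    z_star::T
    u0::T
    residual::T
    jacobian_loss::Float64
    nfe::Int
    original::Any
end

# Fixed-point iteration that records the quantities listed above.
function toy_solve(f, x, u0; tol = 1e-8, maxiters = 100)
    z, nfe = u0, 0
    while true
        fz = f(z, x)
        nfe += 1
        if norm(z - fz) < tol || nfe >= maxiters
            # z is the steady state, or the last iterate if maxiters was hit;
            # the residual is z* - f(z*, x).
            return ToyDEQSolution(z, u0, z - fz, 0.0, nfe, nothing)
        end
        z = fz
    end
end

# Usage: f(z, x) = 0.5z + x has the equilibrium z* = 2x.
x = [1.0, 2.0]
sol = toy_solve((z, x) -> 0.5 .* z .+ x, x, zero(x))
```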

DeepEquilibriumNetworks.MultiScaleDeepEquilibriumNetwork (Method)
MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
    post_fuse_layer::Union{Nothing, Tuple}, solver,
    scales::NTuple{N, NTuple{L, Int64}}; kwargs...)

Multiscale Deep Equilibrium Network as proposed in baimultiscale2020.

Arguments

  • main_layers: Tuple of neural networks; the i-th network is applied to the i-th scale.
  • mapping_layers: Matrix of neural networks; entry [i, j] maps the output of scale i to scale j, and the mapped outputs arriving at each scale are summed. Use NoOpLayer() on the diagonal to pass a scale through unchanged.
  • post_fuse_layer: Tuple of neural networks applied to the fused outputs, or nothing to skip this step.
  • solver: Solver for the root-finding problem. Both ODE solvers and nonlinear solvers are supported.
  • scales: Size of the state at each scale, excluding the batch dimension. Each entry is a tuple of integers, e.g. ((4,), (3,), (2,), (1,)) in the example.

For keyword arguments, see DeepEquilibriumNetwork.
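
The fusion performed by mapping_layers can be sketched in plain Julia. Matching the structure printed in the example below, each target scale $j$ receives $\sum_i$ mapping_layers[i, j] applied to the output of scale $i$, with NoOpLayer() on the diagonal passing a scale through unchanged. In this sketch the layers are ordinary closures with fixed weights; make_map, identity_map, and fuse are hypothetical names, not Lux layers.

```julia
# Hypothetical stand-ins for layers: identity_map plays NoOpLayer(), and
# make_map(nin, nout) plays a Dense(nin => nout, tanh) with fixed weights.
identity_map = v -> v
make_map(nin, nout) = (W = fill(0.1, nout, nin); v -> tanh.(W * v))

sizes = (4, 3, 2, 1)  # per-scale state sizes, like scales ((4,), (3,), (2,), (1,))

# mapping[i, j] maps scale i to scale j; diagonal entries pass through.
mapping = [i == j ? identity_map : make_map(sizes[i], sizes[j]) for i in 1:4, j in 1:4]

# Fuse: target scale j gets the sum over source scales i of mapping[i, j](h[i]).
fuse(mapping, h) = ntuple(j -> sum(mapping[i, j](h[i]) for i in eachindex(h)), length(h))

h = ntuple(i -> ones(sizes[i]), 4)  # stand-in for the main layers' per-scale outputs
y = fuse(mapping, h)                # tuple of vectors with lengths 4, 3, 2, 1
```

Summing (rather than concatenating) keeps every scale at a fixed size, so the fused tuple can be fed back in as the next state of the equilibrium iteration.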

Example

julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve

julia> main_layers = (Parallel(+, Dense(4 => 4, tanh; use_bias=false),
               Dense(4 => 4, tanh; use_bias=false)), Dense(3 => 3, tanh), Dense(2 => 2, tanh),
           Dense(1 => 1, tanh))
(Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast))

julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
           Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
           Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
           Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]
4×4 Matrix{LuxCore.AbstractExplicitLayer}:
 NoOpLayer()               …  Dense(4 => 1, tanh_fast)
 Dense(3 => 4, tanh_fast)     Dense(3 => 1, tanh_fast)
 Dense(2 => 4, tanh_fast)     Dense(2 => 1, tanh_fast)
 Dense(1 => 4, tanh_fast)     NoOpLayer()

julia> model = MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, nothing,
           NewtonRaphson(), ((4,), (3,), (2,), (1,)))
DeepEquilibriumNetwork(
    model = MultiScaleInputLayer{scales = 4}(
        model = Chain(
            layer_1 = Parallel(
                layer_1 = Parallel(
                    +
                    Dense(4 => 4, tanh_fast, bias=false),  # 16 parameters
                    Dense(4 => 4, tanh_fast, bias=false),  # 16 parameters
                ),
                layer_2 = Dense(3 => 3, tanh_fast),  # 12 parameters
                layer_3 = Dense(2 => 2, tanh_fast),  # 6 parameters
                layer_4 = Dense(1 => 1, tanh_fast),  # 2 parameters
            ),
            layer_2 = BranchLayer(
                layer_1 = Parallel(
                    +
                    NoOpLayer(),
                    Dense(3 => 4, tanh_fast),  # 16 parameters
                    Dense(2 => 4, tanh_fast),  # 12 parameters
                    Dense(1 => 4, tanh_fast),  # 8 parameters
                ),
                layer_2 = Parallel(
                    +
                    Dense(4 => 3, tanh_fast),  # 15 parameters
                    NoOpLayer(),
                    Dense(2 => 3, tanh_fast),  # 9 parameters
                    Dense(1 => 3, tanh_fast),  # 6 parameters
                ),
                layer_3 = Parallel(
                    +
                    Dense(4 => 2, tanh_fast),  # 10 parameters
                    Dense(3 => 2, tanh_fast),  # 8 parameters
                    NoOpLayer(),
                    Dense(1 => 2, tanh_fast),  # 4 parameters
                ),
                layer_4 = Parallel(
                    +
                    Dense(4 => 1, tanh_fast),  # 5 parameters
                    Dense(3 => 1, tanh_fast),  # 4 parameters
                    Dense(2 => 1, tanh_fast),  # 3 parameters
                    NoOpLayer(),
                ),
            ),
        ),
    ),
    init = WrappedFunction(Base.Fix1{typeof(DeepEquilibriumNetworks.__zeros_init), Val{((4,), (3,), (2,), (1,))}}(DeepEquilibriumNetworks.__zeros_init, Val{((4,), (3,), (2,), (1,))}())),
)         # Total: 152 parameters,
          #        plus 0 states.

julia> rng = Random.default_rng()
TaskLocalRNG()

julia> ps, st = Lux.setup(rng, model);

julia> x = rand(rng, Float32, 4, 12);

julia> model(x, ps, st);
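
As a sanity check on the reported total, the 152 parameters can be recomputed by hand: a Dense(n_in => n_out) layer has n_in * n_out weights plus n_out biases when use_bias is true, and NoOpLayer() has none. The arithmetic below (dense_params is a hypothetical helper) gives 52 parameters in the main layers and 100 in the mapping layers.

```julia
# Parameters of a Dense(n_in => n_out) layer: weights plus optional bias.
dense_params(n_in, n_out; bias = true) = n_in * n_out + (bias ? n_out : 0)

# main_layers: Parallel of two bias-free Dense(4 => 4), then Dense(3 => 3),
# Dense(2 => 2), and Dense(1 => 1) with biases.
main_params = 2 * dense_params(4, 4; bias = false) +
              dense_params(3, 3) + dense_params(2, 2) + dense_params(1, 1)

# mapping_layers: off-diagonal entries are Dense(sizes[i] => sizes[j]);
# diagonal entries are NoOpLayer() and contribute nothing.
sizes = (4, 3, 2, 1)
mapping_params = sum(dense_params(sizes[i], sizes[j]) for i in 1:4 for j in 1:4 if i != j)

total = main_params + mapping_params
```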

DeepEquilibriumNetworks.MultiScaleSkipDeepEquilibriumNetwork (Method)
MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
    post_fuse_layer::Union{Nothing, Tuple}, [init = nothing,] solver,
    scales::NTuple{N, NTuple{L, Int64}}; kwargs...)

Skip Multiscale Deep Equilibrium Network as proposed in pal2022mixing. This is an alias that creates a MultiScaleDeepEquilibriumNetwork with the init keyword argument set to the passed value.

If init is not passed, it creates a Multiscale Regularized Deep Equilibrium Network.