DeepEquilibriumNetwork (Method)

    DeepEquilibriumNetwork(model, solver; init = missing, jacobian_regularization=nothing,
        problem_type::Type=SteadyStateProblem{false}, kwargs...)

Deep Equilibrium Network as proposed in baideep2019 and pal2022mixing.

Arguments
  - `model`: Neural network.
  - `solver`: Solver for the root-finding problem. Both ODE solvers and nonlinear solvers are supported.
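Both kinds of solver can be used because a fixed point $z^* = f(z^*, x)$ is exactly a steady state of the ODE $dz/dt = f(z, x) - z$. The plain-Julia sketch below illustrates this with a toy layer $f(z, x) = \tanh(Wz + Ux)$, which is an assumption chosen for illustration and not code from this package (we also assume the package's steady-state formulation has this form). It reaches the same point by direct fixed-point iteration and by explicit Euler integration of that ODE.

```julia
using LinearAlgebra, Random

# Toy equilibrium layer f(z, x) = tanh.(W*z .+ U*x). W is rescaled to spectral norm 0.5,
# so z ↦ f(z, x) is a contraction and has a unique fixed point.
rng = Xoshiro(0)
n = 4
A = randn(rng, n, n)
W = 0.5 .* A ./ opnorm(A)
U = randn(rng, n, n)
x = randn(rng, n)
f(z) = tanh.(W * z .+ U * x)

# Nonlinear-solver view: fixed-point iteration z ← f(z).
function fixed_point(f, z0; tol = 1e-10, maxiters = 1_000)
    z = z0
    for _ in 1:maxiters
        znew = f(z)
        norm(znew - z) < tol && return znew
        z = znew
    end
    return z
end

# ODE view: explicit Euler on dz/dt = f(z) - z until the derivative (almost) vanishes.
function euler_steady_state(f, z0; dt = 0.1, tol = 1e-10, maxiters = 100_000)
    z = z0
    for _ in 1:maxiters
        dz = f(z) - z
        norm(dz) < tol && return z
        z = z + dt * dz
    end
    return z
end

z_iter = fixed_point(f, zeros(n))
z_ode = euler_steady_state(f, zeros(n))
```

The Euler step `dt` only trades speed for stability here: for any `dt` in (0, 1] the update `z + dt * (f(z) - z)` is itself a contraction, so both routes converge to the same $z^*$.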

Keyword Arguments
  - `init`: Initial condition for the root-finding problem. If `nothing`, the initial condition is set to `zero(x)`. If `missing`, it is set to `WrappedFunction(zero)`. In all other cases it is set to `init(x, ps, st)`.
  - `jacobian_regularization`: Must be one of `nothing`, …
  - `problem_type`: Provides a way to simulate a vanilla Neural ODE by setting `problem_type` to `ODEProblem{false}`. By default, the problem type is set to `SteadyStateProblem{false}`.
  - `kwargs`: Additional parameters passed directly to `SciMLBase.solve`.
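The three cases for `init` amount to a small dispatch on its type. The helper below is hypothetical (`initial_condition` is not part of DeepEquilibriumNetworks.jl) and simplifies one point: it treats `init(x, ps, st)` as returning the initial condition directly, whereas a Lux layer returns an `(output, state)` pair.

```julia
# Hypothetical helper (not part of DeepEquilibriumNetworks.jl) mirroring the
# documented rules for the `init` keyword argument.
initial_condition(::Nothing, x, ps, st) = zero(x)
# The package wraps `zero` in a Lux `WrappedFunction` layer here; its output for an
# array input is likewise a zero array shaped like `x`.
initial_condition(::Missing, x, ps, st) = zero(x)
# Any other `init` is called with the input, parameters and state
# (simplified: assumed to return the initial condition itself).
initial_condition(init, x, ps, st) = init(x, ps, st)

x = ones(Float32, 2, 1)
u0_default = initial_condition(nothing, x, (;), (;))
u0_custom = initial_condition((x, ps, st) -> 0.5f0 .* x, x, (;), (;))
```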

Example

    julia> using DeepEquilibriumNetworks, Lux, OrdinaryDiffEq, Random
    julia> model = DeepEquilibriumNetwork(
               Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)),
               VCABM3(); verbose=false);
    julia> rng = Xoshiro(0);
    julia> ps, st = Lux.setup(rng, model);
    julia> size(first(model(ones(Float32, 2, 1), ps, st)))
    (2, 1)

See also: `SkipDeepEquilibriumNetwork`, `MultiScaleDeepEquilibriumNetwork`, `MultiScaleSkipDeepEquilibriumNetwork`.

DeepEquilibriumSolution (Type)

    DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe, solution)

Stores the solution of a `DeepEquilibriumNetwork` and its variants.

Fields
  - `z_star`: Steady state, or the value reached when `maxiters` is exhausted.
  - `u0`: Initial condition.
  - `residual`: Difference between $z^*$ and $f(z^*, x)$.
  - `jacobian_loss`: Jacobian stabilization loss (see the individual networks for how it is computed).
  - `nfe`: Number of function evaluations.
  - `original`: Original internal solution.
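To make these fields concrete, the plain-Julia toy below fills in analogous values. `ToyDEQSolution` and `toy_deq` are hypothetical names, and the layer $f(z, x) = \tanh(Wz + Ux)$, the stopping rule, and the choice of `jacobian_loss` (the squared Frobenius norm of $\partial f / \partial z$ at the solution, one common stabilization penalty) are assumptions for illustration, not the package's definitions.

```julia
using LinearAlgebra, Random

# Hypothetical container mirroring the documented fields (not the package's type;
# the `original` internal solver object is omitted).
struct ToyDEQSolution{T}
    z_star::Vector{T}
    u0::Vector{T}
    residual::Vector{T}
    jacobian_loss::T
    nfe::Int
end

# Solve z = f(z) with f(z) = tanh.(W*z .+ U*x) by fixed-point iteration from u0,
# counting every evaluation of f.
function toy_deq(W, U, x, u0; tol = 1e-8, maxiters = 100)
    f(z) = tanh.(W * z .+ U * x)
    z, nfe = u0, 0
    for _ in 1:maxiters
        znew = f(z)
        nfe += 1
        converged = norm(znew - z) < tol
        z = znew
        converged && break
    end
    fz = f(z)
    nfe += 1
    # Assumed stabilization loss: squared Frobenius norm of ∂f/∂z at z_star, which for
    # this layer is exactly Diagonal(1 .- fz .^ 2) * W.
    J = Diagonal(1 .- fz .^ 2) * W
    return ToyDEQSolution(z, u0, z .- fz, sum(abs2, J), nfe)
end

rng = Xoshiro(0)
A = randn(rng, 3, 3)
W = 0.5 .* A ./ opnorm(A)   # spectral norm 0.5, so the iteration contracts
U = randn(rng, 3, 3)
x = randn(rng, 3)
sol = toy_deq(W, U, x, zeros(3))
```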

MultiScaleDeepEquilibriumNetwork (Method)

    MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
        post_fuse_layer::Union{Nothing, Tuple}, solver,
        scales::NTuple{N, NTuple{L, Int64}}; kwargs...)

Multi Scale Deep Equilibrium Network as proposed in baimultiscale2020.

Arguments
  - `main_layers`: Tuple of neural networks. Each network is applied to the corresponding scale.
  - `mapping_layers`: Matrix of neural networks connecting the scales. Judging by the layer sizes in the example below, entry `(i, j)` maps the state at scale `i` to scale `j`, and the diagonal entries are `NoOpLayer()`.
  - `post_fuse_layer`: Neural network applied to the fused output of the main layers. Pass `nothing` to omit it.
  - `solver`: Solver for the root-finding problem. Both ODE solvers and nonlinear solvers are supported.
  - `scales`: Scales of the Multi Scale DEQ. Each scale is a tuple of integers giving the size of the state at that scale, excluding the batch dimension; in the example below, `((4,), (3,), (2,), (1,))` matches the feature sizes of the four main layers.

For keyword arguments, see `DeepEquilibriumNetwork`.

Example

    julia> using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random
    julia> main_layers = (
               Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)),
               Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh));
    julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
                             Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
                             Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
                             Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()];
    julia> model = MultiScaleDeepEquilibriumNetwork(
               main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)));
    julia> rng = Xoshiro(0);
    julia> ps, st = Lux.setup(rng, model);
    julia> x = rand(rng, Float32, 4, 12);
    julia> size.(first(model(x, ps, st)))
    ((4, 12), (3, 12), (2, 12), (1, 12))
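The layer sizes in the example above suggest that entry `(i, j)` of `mapping_layers` takes the state at scale `i` to the size of scale `j`, and that each scale sums everything mapped into it. The plain-Julia sketch below assumes exactly that fusion rule; `dense_map` is a hypothetical stand-in for `Dense(nin => nout, tanh)`, `identity` plays the role of `NoOpLayer()`, and the package's actual wiring (for example, where `post_fuse_layer` is applied) may differ.

```julia
using Random

# Hypothetical stand-in for Dense(nin => nout, tanh): a fixed random tanh layer.
function dense_map(rng, nin, nout)
    W = randn(rng, nout, nin)
    return h -> tanh.(W * h)
end

rng = Xoshiro(0)
sizes = (4, 3, 2, 1)   # per-sample state size of each scale, as in `scales`
batch = 12
z = Tuple(randn(rng, s, batch) for s in sizes)   # one state matrix per scale

# mapping[i, j] takes scale i to scale j; `identity` plays the role of NoOpLayer().
nscales = length(sizes)
mapping = [i == j ? identity : dense_map(rng, sizes[i], sizes[j]) for i in 1:nscales, j in 1:nscales]

# Assumed fusion rule: scale j receives the sum of mapping[i, j](z[i]) over all scales i.
fused = Tuple(sum(mapping[i, j](z[i]) for i in 1:nscales) for j in 1:nscales)
```

Under this reading, `size.(fused)` reproduces the per-scale shapes printed by the example.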

MultiScaleNeuralODE (Method)

    MultiScaleNeuralODE(args...; kwargs...)

Same arguments as `MultiScaleDeepEquilibriumNetwork`, but sets `problem_type` to `ODEProblem{false}`.

MultiScaleSkipDeepEquilibriumNetwork (Method)

    MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
        post_fuse_layer::Union{Nothing, Tuple}, [init = nothing,] solver,
        scales::NTuple{N, NTuple{L, Int64}}; kwargs...)

Skip Multi Scale Deep Equilibrium Network as proposed in pal2022mixing. This is an alias that creates a `MultiScaleDeepEquilibriumNetwork` with the `init` keyword argument set to the passed value. If `init` is not passed, it creates a Multi Scale Regularized Deep Equilibrium Network.

SkipDeepEquilibriumNetwork (Method)

    SkipDeepEquilibriumNetwork(model, [init=nothing,] solver; kwargs...)

Skip Deep Equilibrium Network as proposed in pal2022mixing. This is an alias that creates a `DeepEquilibriumNetwork` with the `init` keyword argument set to the passed value.
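The point of a Skip DEQ is that an explicit model predicts an initial condition close to $z^*$, so the solver needs fewer iterations. The plain-Julia sketch below imitates that with a toy layer $f(z, x) = \tanh(Wz + Ux)$ (an assumption, not package code): instead of training a skip network, it uses a hypothetical prediction within about $10^{-3}$ of the true fixed point, and compares the iteration count against a cold start from zeros.

```julia
using LinearAlgebra, Random

# Count fixed-point iterations z ← f(z) until successive iterates differ by less than tol.
function iterations_to_converge(f, z0; tol = 1e-8, maxiters = 10_000)
    z = z0
    for k in 1:maxiters
        znew = f(z)
        norm(znew - z) < tol && return k
        z = znew
    end
    return maxiters
end

rng = Xoshiro(0)
A = randn(rng, 8, 8)
W = 0.9 .* A ./ opnorm(A)   # spectral norm 0.9, so z ← f(z) still contracts
U = randn(rng, 8, 8)
x = randn(rng, 8)
f(z) = tanh.(W * z .+ U * x)

# Reference fixed point, iterated far past convergence.
z_star = foldl((z, _) -> f(z), 1:2_000; init = zeros(8))

# Hypothetical skip-network output: a prediction within about 1e-3 of z_star.
skip_guess = z_star .+ 1e-3 .* randn(rng, 8)

n_cold = iterations_to_converge(f, zeros(8))    # start from zero(x)
n_skip = iterations_to_converge(f, skip_guess)  # start from the predicted initial condition
```

The cold start corresponds to the `zero(x)` initial condition; in a real Skip DEQ the guess comes from a learned model, so the savings depend on how accurate that prediction is.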