# Package Overview

The RobustNeuralNetworks.jl package is divided into Recurrent Equilibrium Network (REN) and Lipschitz-Bounded Deep Network (LBDN) models.
## REN Overview
The REN models are defined by two fundamental types:

- Any subtype of `AbstractRENParams` holds all the information required to directly parameterise a REN satisfying some user-defined behavioural constraints.
- Any subtype of `AbstractREN` represents the REN in its explicit form so that it can be called and evaluated.
When working with most models (e.g. RNNs and LSTMs), the typical workflow is to create a single instance of a model. Its parameters are updated during training, but the model object itself is only created once. For example:
```julia
using Flux

# Define a model
model = Flux.RNNCell(2, 5)

# Train the model
for k in 1:num_training_epochs
    ...                 # Run some code and compute gradients
    Flux.update!(...)   # Update model parameters
end
```
When working with RENs, it is much more efficient to split the model parameterisation and the model implementation into subtypes of `AbstractRENParams` and `AbstractREN`, respectively. Converting our direct parameterisation to an explicit model for evaluation can be slow, so we only do it when the model parameters are updated:
```julia
using Flux
using RobustNeuralNetworks

# Define a model parameterisation
params = ContractingRENParams{Float64}(2, 5, 10, 1)

# Train the model
for k in 1:num_training_epochs
    model = REN(params)   # Create explicit model for evaluation
    ...                   # Run some code and compute gradients
    Flux.update!(...)     # Update model parameters
end
```
See the section on REN Wrappers for more details.
### (Direct) Parameter Types
Subtypes of `AbstractRENParams` define direct parameterisations of a REN. They are not callable models. There are four REN parameter types currently in this package:

- `ContractingRENParams` parameterises a REN with a user-defined upper bound on the contraction rate.
- `LipschitzRENParams` parameterises a REN with a user-defined Lipschitz constant $\gamma \in (0,\infty)$.
- `PassiveRENParams` parameterises an input/output passive REN with user-tunable passivity parameter $\nu \ge 0$.
- `GeneralRENParams` parameterises a REN satisfying some general behavioural constraints defined by an Integral Quadratic Constraint (IQC).
For more information on these four parameterisations, please see Revay et al. (2021).
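For instance, the two parameterisations used elsewhere on this page can be constructed as follows. This is a minimal sketch: the positional arguments `(nu, nx, nv, ny)` follow the model-size convention described below, and `LipschitzRENParams` takes the Lipschitz bound as an extra argument.

```julia
using RobustNeuralNetworks

# Contracting REN with 2 inputs, 5 states, 10 neurons, 1 output
contracting_ps = ContractingRENParams{Float64}(2, 5, 10, 1)

# Lipschitz-bounded REN of the same size, with Lipschitz bound γ = 1.0
lipschitz_ps = LipschitzRENParams{Float64}(2, 5, 10, 1, 1.0)
```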
Each of these parameter types has the following collection of attributes:

- A static nonlinearity `nl`. Common choices are `Flux.relu` or `Flux.tanh` (see `Flux.jl` for more information).
- Model sizes `nu`, `nx`, `nv`, `ny` defining the number of inputs, states, neurons, and outputs (respectively).
- An instance of `DirectParams` containing the direct parameters of the REN, including all trainable parameters.
- Other attributes used to define how the direct parameterisation should be converted to the explicit model. These parameters encode the user-tunable behavioural constraints, e.g. $\gamma$ for a Lipschitz-bounded REN.
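As a quick sanity check, these attributes can be read straight off a parameterisation. The sketch below assumes plain property access on the struct fields listed above:

```julia
using RobustNeuralNetworks

params = ContractingRENParams{Float64}(2, 5, 10, 1)

println(params.nl)                                     # Static nonlinearity
println((params.nu, params.nx, params.nv, params.ny))  # Model sizes
```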
The typical workflow is to create an instance of a REN parameterisation only once. This defines all dimensions and desired properties of a REN. It is then converted to an explicit model for the REN to be evaluated.
### Explicit REN Models
An explicit REN model must be created to call and use the network for computation. The explicit parameterisation contains all information required to evaluate a REN. We encode RENs in explicit form as subtypes of the `AbstractREN` type. Each subtype of `AbstractREN` is callable and includes the following attributes:

- A static nonlinearity `nl` and model sizes `nu`, `nx`, `nv`, `ny` (same as `AbstractRENParams`).
- An instance of `ExplicitParams` containing all REN parameters in explicit form for model evaluation (see the `ExplicitParams` docs for more detail).
Each subtype of `AbstractRENParams` has an associated method `direct_to_explicit` that converts the `DirectParams` struct to an instance of `ExplicitParams` satisfying the specified behavioural constraints.
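This conversion can also be performed by hand. The sketch below assumes `direct_to_explicit` takes the parameterisation as its only required argument, which is what the `REN` wrapper is assumed to do internally:

```julia
using RobustNeuralNetworks

params = ContractingRENParams{Float64}(2, 5, 10, 1)

# Convert the direct parameterisation to its explicit form
explicit = direct_to_explicit(params)
```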
### REN Wrappers
There are three explicit REN wrappers currently implemented in this package. Each of them constructs a REN from a direct parameterisation `params::AbstractRENParams` and can be used to evaluate REN models.
- `REN` is the basic and most commonly-used wrapper. A new instance of `REN` must be created whenever the parameters `params` are changed.
We strongly recommend using `REN` to train your models with `Flux.jl`. It is the most efficient subtype of `AbstractREN` that is compatible with automatic differentiation.
- `WrapREN` includes both the `DirectParams` and `ExplicitParams` as part of the REN wrapper. When any of the direct parameters are changed, the explicit model can be updated by calling `update_explicit!`. This can be useful when not using automatic differentiation to train the model. For example:
```julia
using RobustNeuralNetworks

# Define a model parameterisation AND a model
params = ContractingRENParams{Float64}(2, 5, 10, 1)
model = WrapREN(params)

# Train the model
for k in 1:num_training_epochs
    ...                       # Run some code and compute gradients
    ...                       # Update model parameters
    update_explicit!(model)   # Update explicit model parameters
end
```
- `DiffREN` also includes `DirectParams`, but never stores the `ExplicitParams`. Instead, the explicit parameters are computed every time the model is evaluated. This is slow, but does not require creating a new object when the parameters are updated, and is still compatible with `Flux.jl`. For example:
```julia
using Flux
using RobustNeuralNetworks

# Define a model parameterisation AND a model
params = ContractingRENParams{Float64}(2, 5, 10, 1)
model = DiffREN(params)

# Train the model
for k in 1:num_training_epochs
    ...                   # Run some code and compute gradients
    Flux.update!(...)     # Update model parameters
end
```
See the docstring of each wrapper and the examples (e.g. PDE Observer Design with REN) for more details.
## LBDN Overview
[To be written once LBDN has been properly added to the package.]
## Walkthrough
Let's step through the example from Getting Started, which constructs and evaluates a Lipschitz-bounded REN. Start by importing the required packages.
```julia
using Random
using RobustNeuralNetworks
```
Next, let's set a random seed and define our batch size and some hyperparameters. For this example, we'll build a Lipschitz-bounded REN with 4 inputs, 2 outputs, 10 states, 20 neurons, and a Lipschitz bound of `γ = 1`.
```julia
rng = MersenneTwister(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 2
γ = 1
```
Let's construct the REN parameters. The variable `lipschitz_ren_ps` contains all the parameters required to build a Lipschitz-bounded REN.
```julia
lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng)
```
Once the parameters are defined, we can create a REN object in its explicit form.
```julia
ren = REN(lipschitz_ren_ps)
```
Now we can evaluate the REN. Note that we can use the `init_states` function to create a batch of initial states, all zeros, of the correct dimensions.
```julia
# Some random inputs
x0 = init_states(ren, batches; rng=rng)
u0 = randn(rng, ren.nu, batches)

# Evaluate the REN over one timestep
x1, y1 = ren(x0, u0)
```
Having evaluated the REN, we can check that the outputs are the same as in the original example.
```julia
# Print results for testing
yout = round.(y1; digits=2)
println(yout[1,:])
println(yout[2,:])
```

```
[0.73, 0.72, -0.53, 0.25, 0.84, 0.97, 0.96, 1.13, 0.87, 1.07]
[1.13, 1.07, 1.44, 0.83, 0.94, 1.26, 0.86, 0.8, 0.96, 0.86]
```
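Since calling the REN returns both the next state and the output, longer simulations simply feed the state back in at each step. Below is a minimal sketch continuing this walkthrough; the `rollout` helper and the horizon `T = 5` are hypothetical, not part of the package API.

```julia
using Random
using RobustNeuralNetworks

# Roll a REN out over T timesteps, collecting the outputs
# (hypothetical helper, not part of RobustNeuralNetworks.jl)
function rollout(ren, x0, T, rng)
    x, ys = x0, []
    for t in 1:T
        u = randn(rng, ren.nu, size(x0, 2))   # Random input at each step
        x, y = ren(x, u)                      # State feeds back in
        push!(ys, y)
    end
    return ys
end

rng = MersenneTwister(42)
ren = REN(LipschitzRENParams{Float64}(4, 10, 20, 2, 1; rng=rng))
x0 = init_states(ren, 10; rng=rng)
ys = rollout(ren, x0, 5, rng)
```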