Package Overview

The RobustNeuralNetwork.jl package is divided into Recurrent Equilibrium Network (REN) and Lipschitz-Bounded Deep Network (LBDN) models.

REN Overview

The REN models are defined by two fundamental types:

  • Any subtype of AbstractRENParams holds all the information required to directly parameterise a REN satisfying some user-defined behavioural constraints.

  • Any subtype of AbstractREN represents the REN in its explicit form so that it can be called and evaluated.

Separate Objects for Parameters and Model

When working with most models (e.g., RNN and LSTM), the typical workflow is to create a single instance of a model. Its parameters are updated during training, but the model object is only created once. For example:

using Flux

# Define a model
model = Flux.RNNCell(2,5)

# Train the model
for k in 1:num_training_epochs
    ...                     # Run some code and compute gradients
    Flux.update!(...)       # Update model parameters
end

When working with RENs, it is much more efficient to split up the model parameterisation and the model implementation into subtypes of AbstractRENParams and AbstractREN. Converting our direct parameterisation to an explicit model for evaluation can be slow, so we only do it when the model parameters are updated:

using Flux
using RobustNeuralNetworks

# Define a model parameterisation
params = ContractingRENParams{Float64}(2, 5, 10, 1)

# Train the model
for k in 1:num_training_epochs
    model = REN(params)     # Create explicit model for evaluation
    ...                     # Run some code and compute gradients
    Flux.update!(...)       # Update model parameters
end

See the section on REN Wrappers for more details.
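To make this split concrete without relying on the package itself, here is a minimal sketch in plain Julia. The types ToyParams and ToyModel, and the rescaling rule used to convert one into the other, are hypothetical stand-ins for AbstractRENParams, AbstractREN and the real conversion. The point is only that the comparatively expensive conversion runs once per parameter update, while the resulting explicit model is evaluated many times.

```julia
using LinearAlgebra, Random

# Hypothetical toy types mirroring the REN design (not part of the package):
# an abstract "direct parameters" type and an abstract "explicit model" type.
abstract type AbstractToyParams end
abstract type AbstractToyModel end

# Direct parameterisation: an unconstrained trainable matrix X plus a fixed,
# user-chosen gain bound γ.
struct ToyParams <: AbstractToyParams
    X::Matrix{Float64}
    γ::Float64
end

# Explicit model: the weight matrix actually used for evaluation.
struct ToyModel <: AbstractToyModel
    W::Matrix{Float64}
end

# Direct-to-explicit conversion. opnorm computes singular values, so this is
# the comparatively expensive step we only run once per parameter update.
ToyModel(p::ToyParams) = ToyModel(p.γ * p.X / opnorm(p.X))

# Evaluating the explicit model is just a cheap matrix product.
(m::ToyModel)(u) = m.W * u

rng = MersenneTwister(0)
params = ToyParams(randn(rng, 3, 4), 1.0)

for epoch in 1:3
    model = ToyModel(params)                      # convert once per update
    ys = [model(randn(rng, 4)) for _ in 1:100]    # then evaluate many times
    params.X .-= 0.01 .* randn(rng, 3, 4)         # stand-in for a gradient step
end

final_model = ToyModel(params)
println(opnorm(final_model.W))   # ≈ 1.0: the bound γ holds for any X
```

Because the constraint is built into the conversion, the trainable matrix X can be updated freely without ever violating the bound.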

(Direct) Parameter Types

Subtypes of AbstractRENParams define direct parameterisations of a REN. They are not callable models. There are four REN parameter types currently in this package:

  • ContractingRENParams parameterises a REN with a user-defined upper bound on the contraction rate.

  • LipschitzRENParams parameterises a REN with a user-defined Lipschitz bound $\gamma \in (0,\infty)$.

  • PassiveRENParams parameterises an input/output passive REN with user-tunable passivity parameter $\nu \ge 0$.

  • GeneralRENParams parameterises a REN satisfying some general behavioural constraints defined by an Integral Quadratic Constraint (IQC).

For more information on these four parameterisations, please see Revay et al. (2021).
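As an illustration of what a contraction rate means, and not of how ContractingRENParams actually enforces it, the sketch below uses a hypothetical linear system x⁺ = A x + B u. Its matrix A is rescaled to have spectral norm λ < 1. Two trajectories started from different initial states and driven by the same inputs then approach each other at least as fast as λ^t, whatever the inputs are.

```julia
using LinearAlgebra, Random

# Simulate x⁺ = A x + B u from two initial states under the same inputs and
# return the distance between the two trajectories at each time step.
function trajectory_gaps(A, B, xa, xb, us)
    gaps = Float64[]
    for u in us
        xa = A * xa + B * u
        xb = A * xb + B * u
        push!(gaps, norm(xa - xb))
    end
    return gaps
end

rng = MersenneTwister(1)
nx, nu, T = 5, 2, 30

# Rescale a random A so its spectral norm is λ < 1. The toy system is then
# contracting with rate λ in the Euclidean norm.
λ = 0.8
A = randn(rng, nx, nx)
A .*= λ / opnorm(A)
B = randn(rng, nx, nu)

xa, xb = randn(rng, nx), randn(rng, nx)
us = [randn(rng, nu) for _ in 1:T]
gaps = trajectory_gaps(A, B, xa, xb, us)

# The gap shrinks at least as fast as λ^t, independently of the inputs.
println(all(gaps[t] <= λ^t * norm(xa - xb) + 1e-12 for t in 1:T))
```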

Each of these parameter types has the following collection of attributes:

  • A static nonlinearity nl. Common choices are Flux.relu or Flux.tanh (see Flux.jl for more information).

  • Model sizes nu, nx, nv, ny defining the number of inputs, states, neurons, and outputs (respectively).

  • An instance of DirectParams containing the direct parameters of the REN, including all trainable parameters.

  • Other attributes used to define how the direct parameterisation should be converted to the implicit model. These parameters encode the user-tunable behavioural constraints, e.g., $\gamma$ for a Lipschitz-bounded REN.

The typical workflow is to create an instance of a REN parameterisation only once. This defines all dimensions and desired properties of a REN. It is then converted to an explicit model for the REN to be evaluated.

Explicit REN Models

An explicit REN model must be created to call and use the network for computation. The explicit parameterisation contains all information required to evaluate a REN. We encode RENs in explicit form as subtypes of the AbstractREN type. Each subtype of AbstractREN is callable and includes the following attributes:

  • A static nonlinearity nl and model sizes nu, nx, nv, ny (same as AbstractRENParams).

  • An instance of ExplicitParams containing all REN parameters in explicit form for model evaluation (see the ExplicitParams docs for more detail).

Each subtype of AbstractRENParams has a method direct_to_explicit associated with it that converts the DirectParams struct to an instance of ExplicitParams satisfying the specified behavioural constraints.
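As we understand the REN of Revay et al. (2021), the explicit form evaluates x⁺ = A x + B1 w + B2 u + bx, v = C1 x + D11 w + D12 u + bv, w = σ(v) and y = C2 x + D21 w + D22 u + by. Here D11 is strictly lower triangular, so the equilibrium w can be computed one neuron at a time. The sketch below evaluates one step of these equations for a single, unbatched sample. The ToyExplicit container, toy_explicit and toy_step are hypothetical. Their random matrices satisfy no behavioural constraint and only loosely mirror what ExplicitParams holds.

```julia
using LinearAlgebra, Random

# Hypothetical container for explicit REN-style matrices (not the package's
# ExplicitParams type). Values are random, so no constraint is satisfied.
struct ToyExplicit
    A; B1; B2; C1; C2; D11; D12; D21; D22; bx; bv; by
end

function toy_explicit(rng, nu, nx, nv, ny)
    r(m, n) = randn(rng, m, n) / sqrt(n)
    D11 = tril(r(nv, nv), -1)   # strictly lower triangular
    return ToyExplicit(r(nx, nx), r(nx, nv), r(nx, nu),
                       r(nv, nx), r(ny, nx), D11, r(nv, nu),
                       r(ny, nv), r(ny, nu),
                       randn(rng, nx), randn(rng, nv), randn(rng, ny))
end

# One time step for a single state/input vector. Because D11 is strictly
# lower triangular, neuron i only depends on neurons 1:i-1, so w = σ.(v)
# is solved by a single forward sweep instead of an iterative solver.
function toy_step(e::ToyExplicit, σ, x, u)
    nv = length(e.bv)
    b = e.C1 * x + e.D12 * u + e.bv
    w = zeros(nv)
    for i in 1:nv
        w[i] = σ(b[i] + dot(e.D11[i, 1:i-1], w[1:i-1]))
    end
    x_next = e.A * x + e.B1 * w + e.B2 * u + e.bx
    y = e.C2 * x + e.D21 * w + e.D22 * u + e.by
    return x_next, y, w
end

rng = MersenneTwister(3)
e = toy_explicit(rng, 4, 10, 20, 2)          # nu, nx, nv, ny
x0, u0 = zeros(10), randn(rng, 4)
x1, y1, w = toy_step(e, tanh, x0, u0)

# w is a genuine equilibrium of w = σ.(C1 x + D11 w + D12 u + bv).
println(w ≈ tanh.(e.C1 * x0 + e.D11 * w + e.D12 * u0 + e.bv))
```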

REN Wrappers

There are three explicit REN wrappers currently implemented in this package. Each of them constructs a REN from a direct parameterisation params::AbstractRENParams and can be used to evaluate REN models.

  • REN is the basic and most commonly-used wrapper. A new instance of REN must be created whenever the parameters params are changed.
REN is recommended

We strongly recommend using REN to train your models with Flux.jl. It is the most efficient subtype of AbstractREN that is compatible with automatic differentiation.

  • WrapREN includes both the DirectParams and ExplicitParams as part of the REN wrapper. When any of the direct parameters are changed, the explicit model can be updated by calling update_explicit!. This can be useful when not using automatic differentiation to train the model. For example:
using RobustNeuralNetworks

# Define a model parameterisation AND a model
params = ContractingRENParams{Float64}(2, 5, 10, 1)
model  = WrapREN(params)

# Train the model
for k in 1:num_training_epochs
    ...                     # Run some code and compute gradients
    ...                     # Update model parameters
    update_explicit!(model) # Update explicit model parameters
end
WrapREN incompatible with Flux.jl

Since the explicit parameters are stored in an instance of WrapREN, changing them with update_explicit! directly mutates the model. This will cause errors if the model is to be trained with Flux.jl. Use REN or DiffREN to avoid this issue.

  • DiffREN also includes DirectParams, but never stores the ExplicitParams. Instead, the explicit parameters are computed every time the model is evaluated. This is slow, but does not require creating a new object when the parameters are updated, and is still compatible with Flux.jl. For example:
using Flux
using RobustNeuralNetworks

# Define a model parameterisation AND a model
params = ContractingRENParams{Float64}(2, 5, 10, 1)
model  = DiffREN(params)

# Train the model
for k in 1:num_training_epochs
    ...                     # Run some code and compute gradients
    Flux.update!(...)       # Update model parameters
end

See the docstring of each wrapper and the examples (e.g., PDE Observer Design with REN) for more details.
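The difference between the caching approach of WrapREN and the recompute-on-every-call approach of DiffREN can be sketched with two hypothetical toy types, CachedToy and OnTheFlyToy. They share a made-up conversion rule, to_explicit. After the direct parameters change, the cached model keeps returning stale outputs until it is refreshed, which is the role update_explicit! plays for WrapREN. The on-the-fly model is always up to date, at the cost of repeating the conversion on every call.

```julia
using LinearAlgebra, Random

# Hypothetical direct-to-explicit conversion: rescale the direct matrix X
# to spectral norm γ (a stand-in for the package's direct_to_explicit).
to_explicit(X, γ) = γ * X / opnorm(X)

# WrapREN-style toy: cache the explicit weight and refresh it in place.
mutable struct CachedToy
    X::Matrix{Float64}
    γ::Float64
    W::Matrix{Float64}
end
CachedToy(X, γ) = CachedToy(X, γ, to_explicit(X, γ))
update_cached!(m::CachedToy) = (m.W .= to_explicit(m.X, m.γ); m)
(m::CachedToy)(u) = m.W * u

# DiffREN-style toy: store only the direct parameters, convert on every call.
struct OnTheFlyToy
    X::Matrix{Float64}
    γ::Float64
end
(m::OnTheFlyToy)(u) = to_explicit(m.X, m.γ) * u

rng = MersenneTwister(4)
X = randn(rng, 2, 3)
cached, onthefly = CachedToy(copy(X), 1.0), OnTheFlyToy(copy(X), 1.0)
u = randn(rng, 3)

# Change the direct parameters of both models in the same way.
cached.X .+= 0.1
onthefly.X .+= 0.1

stale = cached(u)        # still uses the old explicit weight
update_cached!(cached)   # refresh the cache (cf. update_explicit!)
println(cached(u) ≈ onthefly(u), " ", stale ≈ onthefly(u))
```

The in-place refresh in update_cached! mutates the model object. That kind of mutation is what the document warns makes WrapREN unsuitable for training with Flux.jl.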

LBDN Overview

[To be written once LBDN has been properly added to the package.]

Walkthrough

Let's step through the example from Getting Started, which constructs and evaluates a Lipschitz-bounded REN. Start by importing the required packages.

using Random
using RobustNeuralNetworks

Let's set a random seed and define our batch size and some hyperparameters. For this example, we'll build a Lipschitz-bounded REN with 4 inputs, 2 outputs, 10 states, 20 neurons, and a Lipschitz bound of γ = 1.

rng = MersenneTwister(42)
batches = 10

nu, nx, nv, ny = 4, 10, 20, 2
γ = 1

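Roughly speaking, a Lipschitz bound γ means that the difference between two outputs is at most γ times the difference between the corresponding inputs. For a REN this applies to whole input and output sequences. The sketch below checks the idea empirically on a hypothetical static map f(u) = tanh.(W u) rather than on a REN. Since tanh is 1-Lipschitz, rescaling W to spectral norm γ is enough to guarantee the bound for this toy case.

```julia
using LinearAlgebra, Random

rng = MersenneTwister(2)
nu, ny, γ = 4, 2, 1.0

# Toy map f(u) = tanh.(W u). tanh is 1-Lipschitz elementwise, so scaling W
# to spectral norm γ guarantees ‖f(u₁) - f(u₂)‖ ≤ γ ‖u₁ - u₂‖.
W = randn(rng, ny, nu)
W .*= γ / opnorm(W)
f(u) = tanh.(W * u)

# Largest observed ratio ‖Δy‖ / ‖Δu‖ over many random input pairs.
ratios = map(1:1000) do k
    u1, u2 = randn(rng, nu), randn(rng, nu)
    norm(f(u1) - f(u2)) / norm(u1 - u2)
end
println(maximum(ratios) <= γ + 1e-12)
```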
Let's construct the REN parameters. The variable lipschitz_ren_ps contains all the parameters required to build a Lipschitz-bounded REN.

lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng)

Once the parameters are defined, we can create a REN object in its explicit form.

ren = REN(lipschitz_ren_ps)

Now we can evaluate the REN. Note that we can use the init_states function to create a batch of initial states, all zeros, of the correct dimensions.

# Zero initial states and some random inputs
x0 = init_states(ren, batches; rng=rng)
u0 = randn(rng, ren.nu, batches)

# Evaluate the REN over one timestep
x1, y1 = ren(x0, u0)
([1.3225230416430145 1.7365888445922524 … 1.2542453174886439 1.1192153120141475; -0.27619744403624946 -0.3963830379822417 … -0.3172818435367979 -0.2099603671753979; … ; 0.7778348295195779 0.99238300933187 … 0.7072649572997154 0.7430339437501547; 0.2222631810414445 0.2822190092329658 … 0.20442957810790582 -0.04824019596846418], [0.733376548927789 0.721360965401085 … 0.8665300300190658 1.0660831555905794; 1.131773325331808 1.0681734375127818 … 0.9574528917968447 0.8560089918951828])

Having evaluated the REN, we can check that the outputs are the same as in the original example.

# Print results for testing
yout = round.(y1; digits=2)
println(yout[1,:])
println(yout[2,:])
[0.73, 0.72, -0.53, 0.25, 0.84, 0.97, 0.96, 1.13, 0.87, 1.07]
[1.13, 1.07, 1.44, 0.83, 0.94, 1.26, 0.86, 0.8, 0.96, 0.86]
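The batch convention used in this walkthrough stores one sample per column. The initial states are therefore nx × batches, the inputs nu × batches and the outputs ny × batches. The sketch below reproduces that layout with a hypothetical linear toy model and checks that evaluating the whole batch at once matches evaluating each column on its own. toy_init_states and toy_batch_step are not package functions.

```julia
using Random

# Hypothetical toy with the same batch layout as the walkthrough: one column
# per sample, so states are nx × batches and inputs are nu × batches.
nu, nx, ny, batches = 4, 10, 2, 10
rng = MersenneTwister(42)
A, B = 0.5 * randn(rng, nx, nx) / sqrt(nx), randn(rng, nx, nu)
C, D = randn(rng, ny, nx), randn(rng, ny, nu)

toy_init_states(nx, batches) = zeros(nx, batches)      # all-zero initial states
toy_batch_step(x, u) = (A * x + B * u, C * x + D * u)  # returns (x1, y1)

x0 = toy_init_states(nx, batches)
u0 = randn(rng, nu, batches)
x1, y1 = toy_batch_step(x0, u0)

println(size(x1), " ", size(y1))   # prints (10, 10) (2, 10)

# Each column of the batched result matches evaluating that sample alone.
println(all(toy_batch_step(x0[:, k], u0[:, k])[2] ≈ y1[:, k] for k in 1:batches))
```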