Getting Started

Installation

RobustNeuralNetworks.jl is written in Julia and can be installed with Julia's built-in package manager. To add the package, press ] in the Julia REPL to enter Pkg mode, then type:

] add RobustNeuralNetworks

Basic Usage

You should now be able to construct robust neural network models. The following example constructs a Lipschitz-bounded recurrent equilibrium network (REN) and evaluates it on a batch of initial states (all zeros) and random inputs.

using Random
using RobustNeuralNetworks

# Setup
rng = MersenneTwister(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 1
γ = 1

# Construct a REN
lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng)
ren = REN(lipschitz_ren_ps)

# Initial states (zeros) and random inputs
x0 = init_states(ren, batches; rng)
u0 = randn(rng, ren.nu, batches)

# Evaluate the REN over one timestep
x1, y1 = ren(x0, u0)

# Print results for testing
println(round.(y1; digits=2))

# output

[0.98 1.24 0.86 1.93 1.08 1.19 1.23 1.4 0.95 0.65]

For detailed examples of training models with RobustNeuralNetworks.jl, we recommend starting with Fitting a Curve with LBDN and working through the subsequent examples.

Walkthrough

Let's step through the example above, which constructs and evaluates a Lipschitz-bounded REN. We start by importing the required packages.

using Random
using RobustNeuralNetworks

Next, we set a random seed and define the batch size and some hyperparameters. For this example, we'll build a Lipschitz-bounded REN with 4 inputs (nu), 10 states (nx), 20 neurons (nv), 1 output (ny), and a Lipschitz bound of γ = 1.

rng = MersenneTwister(42)
batches = 10

γ = 1
nu, nx, nv, ny = 4, 10, 20, 1
(4, 10, 20, 1)
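For reference, the Lipschitz bound γ limits how much the model's outputs can change in response to a change in its inputs. For a dynamic model such as a REN, a simplified statement of the bound (a sketch, not the package's formal definition) is: for any two input sequences u_a and u_b applied from the same initial state, producing outputs y_a and y_b, and for every horizon T,

$$
\left( \sum_{t=0}^{T} \| y_{a,t} - y_{b,t} \|^2 \right)^{1/2} \le \gamma \left( \sum_{t=0}^{T} \| u_{a,t} - u_{b,t} \|^2 \right)^{1/2}
$$

With γ = 1, as here, the two output sequences can never be further apart than the two input sequences.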

Now we can construct the REN parameters. The variable lipschitz_ren_ps contains all the parameters required to build a Lipschitz-bounded REN. Note that we separate the model parameterisation and its "explicit" (callable) form in RobustNeuralNetworks.jl. See the Package Overview for more details.

lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng)
RobustNeuralNetworks.LipschitzRENParams{Float64}(NNlib.relu, 4, 10, 20, 1, RobustNeuralNetworks.DirectRENParams{Float64}([-0.012744845822453499 -0.37026506662368774 … 0.07789403945207596 0.22788318991661072; 0.012581326998770237 -0.025255916640162468 … -0.12215472012758255 0.059514738619327545; … ; -0.02097960375249386 0.08394220471382141 … -0.08244176208972931 0.1286536306142807; 0.09965844452381134 0.1498170644044876 … -0.327943354845047 0.359980970621109], [-0.06890452653169632 0.3572097420692444 … 0.17479808628559113 -0.06819283217191696; 0.12506835162639618 0.38542118668556213 … -0.46472954750061035 0.17522890865802765; … ; -0.3479945957660675 -0.590043842792511 … -0.2983097434043884 -0.04920057952404022; 0.05702903866767883 -0.08506370335817337 … -0.20291493833065033 -0.18650390207767487], [1.0;;], [0.0;;], [0.0; 0.0; 0.0;;], [-0.21015840768814087 0.409101665019989 -0.39723601937294006 0.21266880631446838; -0.16796112060546875 0.07068990170955658 0.1897682100534439 0.04039272665977478; … ; -0.998578667640686 0.1552123874425888 0.13947278261184692 -0.0548844113945961; 0.3792154788970947 -0.32366979122161865 -0.0028774358797818422 0.24299202859401703], [0.876518189907074 -0.48295989632606506 … -0.1869172900915146 -0.3707870841026306], [0.5252006649971008 -0.47647765278816223 -0.10192503035068512 0.19833898544311523; -0.1060187965631485 -0.11922039836645126 -0.41825592517852783 -0.0273088738322258; … ; 0.18883007764816284 0.3959398567676544 -0.2945697009563446 0.7964305281639099; -0.2361529916524887 -0.017051417380571365 -0.37859970331192017 0.1349947601556778], [-0.1292954534292221 0.2907565236091614 … -0.19500981271266937 -0.18349771201610565], [0.0 0.0 0.0 0.0], [0.0, -0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, 0.0, -0.0], [-0.4639650881290436, -0.11212684214115143, -0.3258433938026428, -0.2355683594942093, 0.1515263020992279, 0.2405761480331421, -0.26404595375061035, 0.1207185909152031, 0.989287793636322, 0.2850775122642517, 0.17870092391967773, 0.20526191592216492, -0.42397892475128174, -0.26564258337020874, 0.06363289058208466, 0.03090890869498253, -0.40099260210990906, -0.4432661533355713, -0.31549209356307983, -0.010103775188326836], [0.7340183258056641], 1.0e-12, [6.356351744112863], true, false, false, true), 1.0, [1.0], false)
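The idea behind separating the two forms is that the parameters can be updated freely (for example, by a gradient-based optimiser) while the explicit model always satisfies its robustness constraint. The snippet below is a deliberately simplified, hypothetical analogy in plain Julia, not how RobustNeuralNetworks.jl builds a REN: ScaledLinearParams and explicit_weight are names made up for this sketch. An unconstrained matrix X is mapped to a weight W = γX/‖X‖₂, whose spectral norm equals γ, so the linear map u ↦ Wu is γ-Lipschitz for any nonzero X.

```julia
using LinearAlgebra
using Random

# Hypothetical "direct" parameterisation: X is unconstrained, and the
# explicit weight is derived from it so the Lipschitz bound always holds.
# This is a simplified analogy, not the RobustNeuralNetworks.jl implementation.
struct ScaledLinearParams
    X::Matrix{Float64}  # free parameters (any nonzero matrix)
    γ::Float64          # desired Lipschitz bound
end

# Map the direct parameters to the explicit weight matrix.
# opnorm(X) is the spectral norm (largest singular value) of X,
# so opnorm(W) equals γ up to floating-point rounding.
explicit_weight(ps::ScaledLinearParams) = ps.γ * ps.X / opnorm(ps.X)

rng = MersenneTwister(42)
ps = ScaledLinearParams(randn(rng, 1, 4), 1.0)
W = explicit_weight(ps)

# The bound holds for any pair of inputs, whatever the value of ps.X.
u_a, u_b = randn(rng, 4), randn(rng, 4)
println(norm(W * u_a - W * u_b) <= ps.γ * norm(u_a - u_b) + 1e-12)
```

Training would only ever update X; the constraint on W never has to be enforced separately.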

Once the parameters are defined, we can create a REN object in its explicit form.

ren = REN(lipschitz_ren_ps)
RobustNeuralNetworks.REN{Float64}(NNlib.relu, 4, 10, 20, 1, RobustNeuralNetworks.ExplicitRENParams{Float64}([0.11529062535517377 -0.13692508441716722 … 0.00626154859753389 -0.20688766393671792; 0.03190387109128367 -0.17416660148745822 … -0.022454514008590144 0.06677999886097519; … ; -0.02831014387584309 -0.02982788180493909 … 0.24553864190718516 0.050164333086822395; 0.08444508547860556 -0.0952522378498945 … 0.05303896565108798 -0.034662288080881724], [-0.018204023209158506 -0.08475042968154453 … -0.5042043792507099 -0.34635826529963376; -0.11594735811200008 -0.10489649250120442 … -0.15746549274227528 0.03573840149908618; … ; 0.6361200723215923 -0.04200697909710808 … 0.0014021818471178589 -0.1263503748367844; -0.1728057260909245 0.06359879726909863 … -0.14155313266493116 -0.2014555045051909], [-0.0644924149923658 0.053406737046181645 -0.23320237987631834 0.3870329671571183; -0.021921698630574016 0.18697193302866608 -0.13115778649197338 0.03495556459172818; … ; -0.6715634149930647 0.31728831164402294 0.04806129667493473 -0.02587128274197365; 0.09891673026710299 -0.09548404017049833 0.13177506959431437 0.1448031198393146], [-0.08891588798978942 -0.4452264519562949 … 0.10935144433167061 -0.12481559228167986; -0.4779250670330235 -0.18037884508950433 … 0.38833898436797404 0.17341771938176773; … ; 0.3585064350567448 -0.6604193611290036 … -0.017269200674190403 -0.3132833674820376; 0.11813916306626371 0.3475526817689998 … -0.12150018931420384 -0.38524677286252396], [0.876518189907074 -0.48295989632606506 … -0.1869172900915146 -0.3707870841026306], [-0.0 -0.0 … -0.0 -0.0; -0.19462673648218362 -0.0 … -0.0 -0.0; … ; -0.06167188416570796 -0.37356014922554776 … -0.0 -0.0; 0.07874567690103058 -0.3941621225930908 … -0.4527360332937664 -0.0], [0.6871864252125833 -0.6234359488386922 -0.13336140243988076 0.25951196841629537; -0.1789428224029922 -0.2012250210649657 -0.705949304882374 -0.04609302424287562; … ; 0.2548709305484705 0.5344146493633665 -0.397591605792885 1.0749717013227127; -0.3405418566224053 -0.024588811228647117 -0.5459555899772756 0.1946677276287063], [-0.1292954534292221 0.2907565236091614 … -0.19500981271266937 -0.18349771201610565], [-5.000444502909205e-13 -0.0 -0.0 -0.0], [0.0, -0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, 0.0, -0.0], [-0.4639650881290436, -0.11212684214115143, -0.3258433938026428, -0.2355683594942093, 0.1515263020992279, 0.2405761480331421, -0.26404595375061035, 0.1207185909152031, 0.989287793636322, 0.2850775122642517, 0.17870092391967773, 0.20526191592216492, -0.42397892475128174, -0.26564258337020874, 0.06363289058208466, 0.03090890869498253, -0.40099260210990906, -0.4432661533355713, -0.31549209356307983, -0.010103775188326836], [0.7340183258056641]))
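For intuition about what the explicit form computes, the standard explicit REN model can be written as

$$
\begin{aligned}
x_{t+1} &= A x_t + B_1 w_t + B_2 u_t + b_x, \\
v_t &= C_1 x_t + D_{11} w_t + D_{12} u_t + b_v, \qquad w_t = \sigma(v_t), \\
y_t &= C_2 x_t + D_{21} w_t + D_{22} u_t + b_y,
\end{aligned}
$$

where σ is the activation (relu in the printout above) and D₁₁ is strictly lower triangular, so the equilibrium w_t can be solved one neuron at a time. The sketch below evaluates one timestep of these equations in plain Julia for a batch stored column-wise. ToyExplicitREN is a hypothetical type and its matrices are random, so unlike a REN built by the package it carries no Lipschitz guarantee; it only illustrates the computation.

```julia
using LinearAlgebra
using Random

relu(z) = max(z, zero(z))

# Hypothetical container for the explicit matrices of a toy REN-like model.
struct ToyExplicitREN
    A::Matrix{Float64}
    B1::Matrix{Float64}
    B2::Matrix{Float64}
    C1::Matrix{Float64}
    C2::Matrix{Float64}
    D11::Matrix{Float64}  # strictly lower triangular
    D12::Matrix{Float64}
    D21::Matrix{Float64}
    D22::Matrix{Float64}
    bx::Vector{Float64}
    bv::Vector{Float64}
    by::Vector{Float64}
end

# One timestep for a batch: x is nx × batches and u is nu × batches.
function (m::ToyExplicitREN)(x::AbstractMatrix, u::AbstractMatrix)
    b = m.C1 * x .+ m.D12 * u .+ m.bv   # part of v that does not depend on w
    w = zeros(size(b))
    # Forward substitution: neuron i only depends on neurons j < i.
    for i in axes(b, 1)
        vi = b[i:i, :] .+ m.D11[i:i, 1:i-1] * w[1:i-1, :]
        w[i:i, :] .= relu.(vi)
    end
    x_next = m.A * x .+ m.B1 * w .+ m.B2 * u .+ m.bx
    y = m.C2 * x .+ m.D21 * w .+ m.D22 * u .+ m.by
    return x_next, y
end

# Random toy matrices with the dimensions used in this example.
rng = MersenneTwister(42)
nu, nx, nv, ny, batches = 4, 10, 20, 1, 10
r(p, q) = 0.1 .* randn(rng, p, q)
model = ToyExplicitREN(
    r(nx, nx), r(nx, nv), r(nx, nu),
    r(nv, nx), r(ny, nx), tril(r(nv, nv), -1),
    r(nv, nu), r(ny, nv), r(ny, nu),
    zeros(nx), zeros(nv), zeros(ny),
)

x0 = zeros(nx, batches)            # zero initial states
u0 = randn(rng, nu, batches)
x1, y1 = model(x0, u0)
println(size(x1), " ", size(y1))   # (10, 10) (1, 10)
```

Because D₁₁ is strictly lower triangular, no iterative solver is needed: a single pass over the neurons gives the exact equilibrium.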

Now we can evaluate the REN. The init_states function creates a batch of initial states (all zeros) with the correct dimensions. As with the inputs u0, each column corresponds to one sample in the batch.

# Initial states (zeros) and random inputs
x0 = init_states(ren, batches)
u0 = randn(rng, ren.nu, batches)

# Evaluate the REN over one timestep
x1, y1 = ren(x0, u0)
([0.09258521726467764 0.010789690941942154 … 0.06837414882738868 -0.22720357857515916; -0.18121326995674358 -0.08939315602941635 … -0.14665608161855848 -0.17820982584154027; … ; -0.324657504311093 0.4171367087013206 … -0.27515129575051694 -0.4272889649693754; 0.10966842602477844 -0.4757230160432385 … -0.028002479209890394 -0.1605014804886037], [0.9805281372525552 1.235213188931588 … 0.9521630754386898 0.6520008212352842])
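The call above advances the REN by a single timestep. To simulate an input sequence, the state returned at each step is passed back in at the next. The sketch below assumes only the (x, u) -> (x_next, y) calling convention used by ren above; rollout and toy are hypothetical names, and toy is a stand-in linear system so the snippet runs without the package. Under that assumption, rollout(ren, init_states(ren, batches), us) should behave the same way for a vector us of nu × batches input matrices.

```julia
using Random

# Simulate a recurrent model over a sequence of input batches.
# Assumes `model(x, u)` returns `(x_next, y)`, as `ren(x0, u0)` does above.
function rollout(model, x0, us)
    x = x0
    ys = Any[]                 # one output batch per timestep
    for u in us
        x, y = model(x, u)     # feed the new state into the next step
        push!(ys, y)
    end
    return x, ys               # final state and the output sequence
end

# Hypothetical stand-in model: x_next = 0.5x + u, y = column sums of x.
toy(x, u) = (0.5 .* x .+ u, sum(x; dims=1))

rng = MersenneTwister(42)
nx, batches, T = 3, 10, 5
us = [randn(rng, nx, batches) for _ in 1:T]
xT, ys = rollout(toy, zeros(nx, batches), us)
println(length(ys), " ", size(xT))   # 5 (3, 10)
```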

Having evaluated the REN, we can check that the outputs are the same as in the original example.

# Print results for testing
println(round.(y1; digits=2))
[0.98 1.24 0.86 1.93 1.08 1.19 1.23 1.4 0.95 0.65]