Getting Started

Installation

RobustNeuralNetworks.jl is written in Julia and can be installed with Julia's built-in package manager. To add the package, type the following into the Julia REPL (the ] key switches the REPL into package mode):

] add RobustNeuralNetworks
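
If you would rather install from a script or notebook than type into package mode, Julia's standard Pkg API does the same job. Here is a minimal sketch that installs the package and checks that it loads:

```julia
# Equivalent to typing `] add RobustNeuralNetworks` in the REPL
using Pkg
Pkg.add("RobustNeuralNetworks")

# Confirm that the package loads after installation
using RobustNeuralNetworks
```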

Basic Usage

You should now be able to construct robust neural network models. The following example constructs a Lipschitz-bounded REN and evaluates it on a batch of zero initial states and random inputs.

using Random
using RobustNeuralNetworks

# Setup
rng = Xoshiro(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 1
γ = 1

# Construct a REN
lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng)
ren = REN(lipschitz_ren_ps)

# Zero initial states and random inputs
x0 = init_states(ren, batches; rng)
u0 = randn(rng, ren.nu, batches)

# Evaluate the REN over one timestep
x1, y1 = ren(x0, u0)

# Print results for testing
println(round.(y1; digits=2))

# output

[1.06 1.13 0.95 0.93 1.03 0.78 0.75 1.42 0.89 1.44]
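
The model works on batches stored column-wise. States have size (nx, batches), inputs have size (nu, batches), and outputs have size (ny, batches). The quick sketch below repeats the setup above and prints these sizes. The column layout is inferred from the randn(rng, ren.nu, batches) call, and the all-zero initial states follow the description of init_states in the walkthrough below.

```julia
using Random
using RobustNeuralNetworks

# Same setup as the example above
rng = Xoshiro(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 1
γ = 1
ren = REN(LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng))

x0 = init_states(ren, batches; rng)   # zero initial states
u0 = randn(rng, ren.nu, batches)      # random inputs
x1, y1 = ren(x0, u0)

# Each column is one element of the batch
println("states: ", size(x0), "  inputs: ", size(u0))
println("next states: ", size(x1), "  outputs: ", size(y1))
```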

For detailed examples of training models with RobustNeuralNetworks.jl, we recommend starting with Fitting a Curve with LBDN and working through the subsequent examples.

Walkthrough

Let's step through the example above, which constructs and evaluates a Lipschitz-bounded REN. We start by importing the required packages.

using Random
using RobustNeuralNetworks

Next, we set a random seed and define the batch size and model hyperparameters. For this example, we'll build a Lipschitz-bounded REN with 4 inputs (nu), 10 states (nx), 20 neurons (nv), 1 output (ny), and a Lipschitz bound of γ = 1.

rng = Xoshiro(42)
batches = 10

γ = 1
nu, nx, nv, ny = 4, 10, 20, 1
(4, 10, 20, 1)

Now we can construct the REN parameters. The variable lipschitz_ren_ps contains all the parameters required to build a Lipschitz-bounded REN. Note that RobustNeuralNetworks.jl separates a model's parameterisation from its "explicit" (callable) form. See the Package Overview for more details.

lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng)
RobustNeuralNetworks.LipschitzRENParams{Float64}(NNlib.relu, 4, 10, 20, 1, RobustNeuralNetworks.DirectRENParams{Float64}([-0.06712865084409714 0.05348382890224457 … -0.10525631904602051 0.07951543480157852; -0.14378681778907776 -0.07980822771787643 … -0.005727448966354132 -0.09460792690515518; … ; -0.02156129479408264 0.1843864619731903 … -0.04695108160376549 0.3686339259147644; 0.04774993658065796 -0.10604676604270935 … -0.0002614356635604054 0.2120383381843567], [0.1391763687133789 0.03398861736059189 … -0.0188425425440073 0.355768084526062; 0.05736823007464409 -0.11152595281600952 … -0.07333791255950928 0.02312125451862812; … ; 0.03470980376005173 -0.1601126790046692 … -0.007050040178000927 -0.562127947807312; -0.17466461658477783 0.014204648323357105 … 0.46678298711776733 0.4063137173652649], [1.0;;], [0.0;;], [0.0; 0.0; 0.0;;], [0.29797041416168213 0.3009076714515686 0.07594253867864609 0.30558115243911743; -0.33255529403686523 -0.4442485272884369 0.42857375741004944 -0.14390836656093597; … ; 0.5437711477279663 0.44203391671180725 -0.25169432163238525 0.4955179989337921; -0.13386361300945282 0.09707959741353989 -0.38008028268814087 0.15276393294334412], [0.13988827168941498 -0.029209962114691734 … -0.2967631220817566 0.5497584342956543], [0.34582117199897766 -0.33130916953086853 0.21589359641075134 -0.19357502460479736; -0.3548073172569275 -0.3922453224658966 -0.37108856439590454 0.20604968070983887; … ; -0.06690354645252228 -0.6899434328079224 0.6516411304473877 -0.23289446532726288; 0.30476993322372437 0.5113816261291504 0.14442609250545502 -0.08943861722946167], [0.5269575715065002 0.09066978842020035 … -0.03968740254640579 0.177832692861557], [0.0 0.0 0.0 0.0], [-0.0, -0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0, 0.0], [-0.8271766304969788, 0.13810677826404572, -0.28148776292800903, -0.12242458760738373, -0.4490170180797577, -0.5567880868911743, 0.025562886148691177, -0.026959802955389023, -0.2853533923625946, 0.527955174446106, -0.2083873748779297, 0.020896757021546364, 0.2908017039299011, 0.06982783228158951, 0.6342710852622986, 0.06637370586395264, -0.11119255423545837, -0.11263858526945114, -0.11802570521831512, -0.23057030141353607], [0.15656277537345886], 1.0e-12, [6.395632842977319], true, false, false, true), 1.0, [1.0], false)

Once the parameters are defined, we can create a REN object in its explicit form.

ren = REN(lipschitz_ren_ps)
RobustNeuralNetworks.REN{Float64}(NNlib.relu, 4, 10, 20, 1, RobustNeuralNetworks.ExplicitRENParams{Float64}([-0.013530075647082196 -0.0844773991448444 … -0.13669422708891066 -0.10291401272852563; -0.06259329876993296 0.10298993876353763 … 0.061047439909097806 -0.06895545526812626; … ; 0.013565870729768396 0.140306111007941 … 0.11172587563678857 0.09422294500278613; -0.03509408374331969 0.1484130350126683 … 0.0757283786226681 -0.025023752366458627], [-0.1617530496552856 0.1863359318782896 … 0.2726756221400413 0.0116621543850988; -0.06983220865616446 0.1136084132772561 … -0.35321611235441386 0.1911812922744098; … ; 0.0945350460806226 -0.02841494508454463 … 0.25438193044713 -0.22142468315220812; 0.03998661393235471 -0.2375267453441095 … 0.023986939755703998 0.058012964100629935], [0.2618882490270226 0.19063902926450554 0.20242752178533482 0.2410739222030993; -0.1223453534632878 -0.23546211883329585 0.14746179811130855 -0.034636748313121385; … ; 0.30623632139319057 0.15118739867065792 -0.21577962884384347 0.3543825096918331; -0.13466712276289344 -0.03834574179584228 -0.2677731704763374 0.037154220768812016], [0.020458277087454943 -0.17331254942511387 … 0.08885194913441921 -0.43437686553256405; 0.17306245795479736 0.1236544215903414 … -0.4596335847061486 -0.013933512767059262; … ; -0.03445537723145436 0.19940875378441306 … 0.10892289594474876 0.24026134012574749; -0.03529554919591985 -0.12526604185335674 … 0.15527488657996147 -0.19477906923540608], [0.13988827168941498 -0.029209962114691734 … -0.2967631220817566 0.5497584342956543], [-0.0 -0.0 … -0.0 -0.0; 0.07411589563231619 -0.0 … -0.0 -0.0; … ; -0.43437658445977356 0.00873114843659628 … -0.0 -0.0; 0.19038509036481827 0.5523879111128286 … -0.04300865282430539 -0.0], [0.49129255761708096 -0.470676009568589 0.30671030504199853 -0.2750033156707983; -0.5399206163615482 -0.5968911179962035 -0.5646962637703237 0.3135517933227863; … ; -0.07674324096023266 -0.7914153721355127 0.7474798414312898 -0.2671469155016799; 0.41793692831096263 0.7012675553601659 0.1980543055255232 -0.12264891277782837], [0.5269575715065002 0.09066978842020035 … -0.03968740254640579 0.177832692861557], [-5.000444502909205e-13 -0.0 -0.0 -0.0], [-0.0, -0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0, 0.0], [-0.8271766304969788, 0.13810677826404572, -0.28148776292800903, -0.12242458760738373, -0.4490170180797577, -0.5567880868911743, 0.025562886148691177, -0.026959802955389023, -0.2853533923625946, 0.527955174446106, -0.2083873748779297, 0.020896757021546364, 0.2908017039299011, 0.06982783228158951, 0.6342710852622986, 0.06637370586395264, -0.11119255423545837, -0.11263858526945114, -0.11802570521831512, -0.23057030141353607], [0.15656277537345886]))
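
Because the explicit model is computed from the parameterisation rather than stored independently of it, building it twice from the same parameters should give the same callable model. The rough sanity check below illustrates this. It does not describe the package internals, and the names ren_a and ren_b are illustrative only.

```julia
using Random
using RobustNeuralNetworks

rng = Xoshiro(42)
lipschitz_ren_ps = LipschitzRENParams{Float64}(4, 10, 20, 1, 1; rng)

# Build the explicit (callable) form twice from the same parameterisation
ren_a = REN(lipschitz_ren_ps)
ren_b = REN(lipschitz_ren_ps)

# Evaluate both on identical zero initial states and random inputs
x0 = init_states(ren_a, 5)
u0 = randn(rng, ren_a.nu, 5)
xa, ya = ren_a(x0, u0)
xb, yb = ren_b(x0, u0)

# The explicit model is a deterministic function of the parameters
println("identical outputs: ", ya ≈ yb && xa ≈ xb)
```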

Now we can evaluate the REN. The init_states function creates a batch of initial states with the correct dimensions, all set to zero. The inputs are drawn at random, with one column per batch element.

# Zero initial states and random inputs
x0 = init_states(ren, batches)
u0 = randn(rng, ren.nu, batches)

# Evaluate the REN over one timestep
x1, y1 = ren(x0, u0)
([-0.011217305668004496 0.2109562758364567 … 0.16450272255323128 0.8922002462715966; 0.19057624039614884 -0.2211223532442873 … -0.01394907003389612 -0.325078246832505; … ; -0.5195445211656348 0.09444376139104804 … -0.09810399115149572 -0.4156912325824145; -0.8869819188521283 -0.34396300079797376 … -0.8506379810129112 -1.6660480267543558], [1.0592088843679517 1.1260166302387575 … 0.8912662543023053 1.4445156180304226])
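
The call ren(x0, u0) advances the model by a single timestep and returns the next state alongside the output. A longer simulation is therefore just a loop that feeds each new state back in. Below is a minimal sketch. The helper simulate and the horizon T are hypothetical names introduced here and are not part of RobustNeuralNetworks.jl.

```julia
using Random
using RobustNeuralNetworks

# Hypothetical helper (not a package function): run `model` over a sequence
# of input batches, feeding each new state back in at the next step
function simulate(model, x0, us)
    x = x0
    ys = similar(us)              # one output batch per timestep
    for (t, u) in enumerate(us)
        x, y = model(x, u)        # one timestep: next state and output
        ys[t] = y
    end
    return x, ys                  # final state and all outputs
end

# Same model as in the walkthrough
rng = Xoshiro(42)
batches = 10
ren = REN(LipschitzRENParams{Float64}(4, 10, 20, 1, 1; rng))

# Zero initial states and T batches of random inputs
T = 20
x0 = init_states(ren, batches)
us = [randn(rng, ren.nu, batches) for _ in 1:T]

xT, ys = simulate(ren, x0, us)
println(length(ys), " output batches of size ", size(ys[1]))
```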

Having evaluated the REN, we can check that the outputs are the same as in the original example.

# Print results for testing
println(round.(y1; digits=2))
[1.06 1.13 0.95 0.93 1.03 0.78 0.75 1.42 0.89 1.44]
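
The REN was built with a Lipschitz bound of γ = 1. Two input batches applied from the same initial state should therefore give outputs whose difference is at most γ times the input difference, column by column. The sketch below spot-checks this empirically over a single timestep and is not a proof. The tolerance 1e-8 is an arbitrary allowance for floating-point error.

```julia
using Random
using LinearAlgebra
using RobustNeuralNetworks

rng = Xoshiro(42)
batches = 10
γ = 1
ren = REN(LipschitzRENParams{Float64}(4, 10, 20, 1, γ; rng))

# Two batches of random inputs, both applied from the same zero initial states
x0 = init_states(ren, batches)
u_a = randn(rng, ren.nu, batches)
u_b = randn(rng, ren.nu, batches)
_, y_a = ren(x0, u_a)
_, y_b = ren(x0, u_b)

# Each column is an independent batch element, so compare gains per column:
# ‖y_a - y_b‖ / ‖u_a - u_b‖ should not exceed γ (up to floating-point error)
gains = norm.(eachcol(y_a - y_b)) ./ norm.(eachcol(u_a - u_b))
println("largest observed gain: ", maximum(gains))
println("within bound: ", all(gains .<= γ + 1e-8))
```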