Getting Started

Installation

RobustNeuralNetworks.jl is written in Julia and can be installed with Julia's built-in package manager. To add the package, type the following into the Julia REPL, where pressing ] switches to package mode.

] add RobustNeuralNetworks
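If you prefer to install from a script rather than the REPL's package mode, Julia's standard `Pkg` API should do the same job. A minimal sketch:

```julia
# Programmatic alternative to `] add RobustNeuralNetworks`
using Pkg
Pkg.add("RobustNeuralNetworks")

# Confirm that the package can now be loaded
using RobustNeuralNetworks
```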

Basic Usage

You should now be able to construct robust neural network models. The following example constructs a Lipschitz-bounded REN and evaluates it on a batch of zero initial states and random inputs.

using Random
using RobustNeuralNetworks

# Setup
rng = MersenneTwister(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 1
γ = 1

# Construct a REN
lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng)
ren = REN(lipschitz_ren_ps)

# Zero initial states and random inputs
x0 = init_states(ren, batches; rng=rng)
u0 = randn(rng, ren.nu, batches)

# Evaluate the REN over one timestep
x1, y1 = ren(x0, u0)

# Print results for testing
println(round.(y1; digits=2))

# output

[0.23 -0.01 -0.06 0.15 -0.03 -0.11 0.0 0.42 0.24 0.22]
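The example only evaluates the REN over a single timestep. Since `ren(x, u)` returns the next state alongside the output, a longer simulation can be sketched by feeding each new state back in at the next step. The helper `rollout` and the sequence length `T` below are our own illustrative choices, not part of the package API.

```julia
using Random
using RobustNeuralNetworks

# Same setup as the example above
rng = MersenneTwister(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 1
γ = 1
ren = REN(LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng))

# Hypothetical helper (not part of RobustNeuralNetworks.jl): simulate the REN
# over a sequence of input matrices, feeding each new state back in.
function rollout(ren, x0, us)
    x = x0
    ys = zeros(ren.ny, size(x0, 2), length(us))
    for (t, u) in enumerate(us)
        x, y = ren(x, u)
        ys[:, :, t] .= y
    end
    return x, ys
end

# Illustrative rollout length and a random input sequence
T = 20
us = [randn(rng, ren.nu, batches) for _ in 1:T]
x0 = init_states(ren, batches)
xT, ys = rollout(ren, x0, us)

println(size(ys))   # (1, 10, 20), i.e. (ny, batches, T)
```

Storing the outputs in an `ny × batches × T` array keeps the loop type-stable and makes it easy to slice out a single timestep or batch element afterwards.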

For detailed examples of training models from RobustNeuralNetworks.jl, we recommend starting with Fitting a Curve with LBDN and working through the subsequent examples.

Walkthrough

Let's step through the example above, which constructs and evaluates a Lipschitz-bounded REN. We start by importing the required packages.

using Random
using RobustNeuralNetworks

Next, we set a random seed and define the batch size and model hyperparameters. For this example, we'll build a Lipschitz-bounded REN with 4 inputs, 1 output, 10 states, 20 neurons, and a Lipschitz bound of γ = 1.

rng = MersenneTwister(42)
batches = 10

γ = 1
nu, nx, nv, ny = 4, 10, 20, 1
(4, 10, 20, 1)

Now we can construct the REN parameters. The variable lipschitz_ren_ps contains all the parameters required to build a Lipschitz-bounded REN. Note that we separate the model parameterisation and its "explicit" (callable) form in RobustNeuralNetworks.jl. See the Package Overview for more details.

lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng)
RobustNeuralNetworks.LipschitzRENParams{Float64}(NNlib.relu, 4, 10, 20, 1, RobustNeuralNetworks.DirectRENParams{Float64}([-0.055291873027992544 -0.013593743890657511 … 0.2311391829990709 -0.0718455773131941; 0.19138686252703796 -0.13271994488794459 … -0.007305408748129384 0.01654525479881978; … ; 0.299742414221673 -0.07809317389792667 … 0.029240605543225565 -0.2925394173379708; -0.09088275363047098 0.23681083486791726 … -0.20800363096068009 0.06250220218829722], [-0.4027217316502711 -0.05989682990932095 … 0.3667023851248275 0.24020479834174271; -0.029669638899475635 0.45602977715411475 … -0.28922165908454506 0.23866725891928675; … ; 0.24386504581456347 -0.257910039003683 … -0.30226665718964324 -0.14108400498798115; 0.4239280896991361 -0.348644225412096 … -0.23073711185928106 0.3443438279810179], [1.0;;], [0.0;;], [0.0; 0.0; 0.0;;], [-0.14860443345556892 0.28927854523357593 -0.2808882807058518 0.09862214692777999; -0.11876644790634974 0.04998531072479769 0.13418637938318348 -0.0020346545106518935; … ; -0.7061017979400626 0.10975173297691977 -0.3800612302452505 0.2724350958323865; 0.2681458502870939 -0.22886911922490294 -0.09952832055021918 0.1875562767703701], [-0.5934562187063348 0.2754975526047605 … -0.012726196152459896 -0.6009373820078587], [0.37137295690849725 -0.1669853779501278 -0.012057172879416636 -0.12902298731585712; -0.0749666114986592 -0.3369205918841214 -0.07207188081049729 -0.20829222859076407; … ; -0.03983008458781778 -0.23442820907916068 0.050461061693234346 -0.14086913535658063; 0.13352303128113077 0.27997176791177697 -0.005238334612274115 -0.28212632798233844], [0.1359598637213978 -0.6023915796564709 … 0.05063436980831905 -0.19732915272019172], [0.0 0.0 0.0 0.0], [0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, 0.0, -0.0], [0.049919701188874614, -0.016614088781157416, 0.19226555343123963, 0.18100419623378342, 0.17398164193189308, -0.22094225572615017, -0.03919620329098578, -0.3202239121417771, 0.11139989526928168, 0.3254046764833112, -0.19333577189995443, 0.14029292458284798, 0.16133218027132115, 0.22918344641649552, 0.35892658718976567, -0.09336242666176917, -0.26557914527550575, 0.4397810675941355, -0.18499921489689392, -0.4046346788072461], [-0.19506959830062723], 1.0e-12, [6.324555320336759], true, false, false, true), 1.0, 1.0)
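The γ stored in these parameters is the Lipschitz bound the REN is designed to satisfy. As an informal sanity check (our own, not a test shipped with the package), we can simulate two random input sequences from the same zero initial state and compare the ℓ2 norm of the output difference with that of the input difference. For a Lipschitz-bounded REN we would expect this ratio to stay at or below γ, up to numerical error. The helper `simulate` and the horizon `T` are illustrative assumptions.

```julia
using Random
using LinearAlgebra
using RobustNeuralNetworks

rng = MersenneTwister(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 1
γ = 1
ren = REN(LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng))

# Hypothetical helper: simulate from zero initial states.
# `us` is an nu × batches × T array of inputs.
function simulate(ren, us)
    nbatch, T = size(us, 2), size(us, 3)
    x = init_states(ren, nbatch)
    ys = zeros(ren.ny, nbatch, T)
    for t in 1:T
        x, y = ren(x, us[:, :, t])
        ys[:, :, t] .= y
    end
    return ys
end

# Two random input sequences over an illustrative horizon
T = 50
ua = randn(rng, nu, batches, T)
ub = randn(rng, nu, batches, T)
ya, yb = simulate(ren, ua), simulate(ren, ub)

# Empirical gain for each batch element: ‖Δy‖ / ‖Δu‖ over the whole sequence
gains = [norm(ya[:, i, :] - yb[:, i, :]) / norm(ua[:, i, :] - ub[:, i, :]) for i in 1:batches]
println(maximum(gains))   # expected not to exceed γ
```

A check like this can only fail to find a violation; it does not prove the bound, which comes from the parameterisation itself.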

Once the parameters are defined, we can create a REN object in its explicit form.

ren = REN(lipschitz_ren_ps)
RobustNeuralNetworks.REN{Float64}(NNlib.relu, 4, 10, 20, 1, RobustNeuralNetworks.ExplicitRENParams{Float64}([2.4833550409616386e-14 -1.1538658310043907e-14 … 5.924769201335763e-16 2.5089187745339034e-14; 3.0810107430801096e-14 -1.410574673658543e-14 … 6.572131044213426e-16 3.093718569025709e-14; … ; 1.4882416583916145e-13 -6.898091493262644e-14 … 3.122131952310639e-15 1.5042936848333657e-13; -2.6653877568673407e-14 1.2411876741132481e-14 … -6.01691731636436e-16 -2.701741543375337e-14], [0.047732413114233777 0.016143253397889447 … 0.024814132112393673 0.03233602967053757; 0.05305369331894097 0.028189925126010394 … 0.0022682164577717506 0.016273046329176217; … ; 0.22120113605213662 0.018035757674375683 … 0.037802646379185345 0.06348752984146105; -0.04358059061417341 -0.01154535284902278 … -0.024312998190262638 0.055401688686320616], [-0.08308836277171436 0.028501001047320694 -0.12908806360592934 0.10597203722520973; -0.10345437105758014 0.039786454951766574 0.18328589439314596 0.04479722429277596; … ; -0.5006132343897761 0.13118112006658547 -0.19456657995708593 0.121896983114105; 0.09007053561495977 -0.14367520794278504 0.14699264608509 0.09369271269797161], [0.1343561931214768 -0.06237158060782834 … 0.0028811615988942286 0.1360498995645635; -0.46713067731836044 0.21685400589170492 … -0.010017245486691174 -0.4730193693060996; … ; 0.055569776932339256 -0.02579691148409351 … 0.0011916496265401498 0.05627029461630003; -0.19280327716496964 0.08950421163150811 … -0.0041345127857744714 -0.19523377288828428], [-0.5934562187063348 0.2754975526047605 … -0.012726196152459896 -0.6009373820078587], [-0.0 -0.0 … -0.0 -0.0; 0.03363062565698647 -0.0 … -0.0 -0.0; … ; -0.0902553304459757 -0.1427151459173325 … -0.0 -0.0; -0.020531797627658013 -0.12129605968113616 … 0.0682671641416631 -0.0], [0.6183986884601169 -0.278058853762524 -0.0200772289862917 -0.21484506249869795; -0.09795772549560815 -0.4402489881004275 -0.09417522514690625 -0.2721722716722692; … ; -0.07365729876788123 -0.4335252815667975 0.09331703750435563 -0.26050760618535873; 0.21983187219038328 0.4609445824436962 -0.008624376588270578 -0.4644918429388607], [0.1359598637213978 -0.6023915796564709 … 0.05063436980831905 -0.19732915272019172], [-5.000444502909205e-13 -0.0 -0.0 -0.0], [0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, 0.0, -0.0], [0.049919701188874614, -0.016614088781157416, 0.19226555343123963, 0.18100419623378342, 0.17398164193189308, -0.22094225572615017, -0.03919620329098578, -0.3202239121417771, 0.11139989526928168, 0.3254046764833112, -0.19333577189995443, 0.14029292458284798, 0.16133218027132115, 0.22918344641649552, 0.35892658718976567, -0.09336242666176917, -0.26557914527550575, 0.4397810675941355, -0.18499921489689392, -0.4046346788072461], [-0.19506959830062723]))
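Because the explicit model is computed from its parameterisation, building it twice from the same `lipschitz_ren_ps` should give models that behave identically, with the same dimensions. A quick check, sketched under the assumption that `REN` exposes the fields `nx`, `nv` and `ny` alongside the `nu` field used in this guide:

```julia
using Random
using RobustNeuralNetworks

rng = MersenneTwister(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 1
γ = 1
lipschitz_ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng)

# Build the explicit (callable) form twice from the same parameters
ren_a = REN(lipschitz_ren_ps)
ren_b = REN(lipschitz_ren_ps)

# Dimensions are carried over from the parameterisation
println((ren_a.nu, ren_a.nx, ren_a.nv, ren_a.ny))

# Both explicit models should respond identically to the same states and inputs
x0 = init_states(ren_a, batches)
u0 = randn(rng, nu, batches)
_, ya = ren_a(x0, u0)
_, yb = ren_b(x0, u0)
println(ya ≈ yb)
```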

Now we can evaluate the REN. The init_states function creates a batch of initial states with the correct dimensions, all set to zero.

# Zero initial states and random inputs
x0 = init_states(ren, batches)
u0 = randn(rng, ren.nu, batches)

# Evaluate the REN over one timestep
x1, y1 = ren(x0, u0)
([0.2639017210647274 0.08165554601458805 … 0.1402176802832335 0.04948676878665539; -0.08942436080341934 -0.10419998264870195 … -0.14452300501515303 0.1166531837996839; … ; 0.6488589431765428 0.23670412413831232 … 0.18021663934950782 0.18911625175968177; -0.3209071212530465 -0.12728679354546207 … -0.2499478157662913 -0.08235411584378027], [0.228497072875445 -0.01427934716148907 … 0.24162794642230076 0.22107359506647603])
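The returned tuple holds the next state `x1` and the output `y1`. Each column corresponds to one element of the batch, so we would expect `x0` and `x1` to be `nx × batches` and `y1` to be `ny × batches`. A short sketch that checks these shapes, and that `init_states` returns zeros as described above:

```julia
using Random
using RobustNeuralNetworks

rng = MersenneTwister(42)
batches = 10
nu, nx, nv, ny = 4, 10, 20, 1
γ = 1
ren = REN(LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng))

x0 = init_states(ren, batches)
u0 = randn(rng, ren.nu, batches)
x1, y1 = ren(x0, u0)

# Rows index states, inputs or outputs; columns index batch elements
@show size(x0) size(u0) size(x1) size(y1)

# init_states should return all zeros
@show all(iszero, x0)
```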

Having evaluated the REN, we can check that the outputs are the same as in the original example.

# Print results for testing
println(round.(y1; digits=2))
[0.23 -0.01 -0.06 0.15 -0.03 -0.11 0.0 0.42 0.24 0.22]
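The outputs match because every random draw comes from the seeded `rng`. Wrapping the example in a function makes this explicit: running it twice with the same seed should give identical outputs, while a different seed should not. The function `run_example` is our own wrapper, not part of the package. Note that Julia does not guarantee the same random stream for a given seed across Julia releases, so the exact printed values may differ between versions.

```julia
using Random
using RobustNeuralNetworks

# Hypothetical wrapper around the example: returns y1 for a given seed
function run_example(seed)
    rng = MersenneTwister(seed)
    batches = 10
    nu, nx, nv, ny = 4, 10, 20, 1
    γ = 1
    ren = REN(LipschitzRENParams{Float64}(nu, nx, nv, ny, γ; rng=rng))
    x0 = init_states(ren, batches)
    u0 = randn(rng, ren.nu, batches)
    _, y1 = ren(x0, u0)
    return y1
end

y_first = run_example(42)
y_second = run_example(42)
y_other = run_example(1)

println(y_first == y_second)   # same seed: same model and inputs
println(y_first == y_other)    # different seed: different model and inputs
```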