Fitting a Curve with LBDN

For our first example, let's fit a Lipschitz-bounded Deep Network (LBDN) to a curve in one dimension. Consider the step function below.

\[f(x) = \begin{cases} 1 \ \text{if} \ x \ge 0 \\ 0 \ \text{if} \ x < 0 \end{cases}\]

Our aim is to demonstrate how to train a model in RobustNeuralNetworks.jl, and how to ensure the model naturally satisfies some user-defined robustness certificate (the Lipschitz bound). We'll follow the steps below to fit an LBDN model to our function $f(x)$:

  1. Generate training data
  2. Define a model with a Lipschitz bound (maximum slope) of 10.0
  3. Define a loss function
  4. Train the model to minimise the loss function
  5. Examine the trained model

1. Generate training data

Let's generate training data for $f(x)$ on the interval $[-0.3, 0.3]$ as an example. We zip() the data up into a sequence of tuples (x,y) to make training with Flux.jl easier in Step 4.

# Function to estimate
f(x) = x < 0 ? 0 : 1

# Training data
dx = 0.01
xs = -0.3:dx:0.3
ys = f.(xs)
data = zip(xs,ys)
zip(-0.3:0.01:0.3, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

2. Define a model

Since we are only dealing with a simple one-dimensional curve, we can afford to use a small model. Let's choose an LBDN with four hidden layers, each with 16 neurons, and a Lipschitz bound of γ = 10.0. This means that the maximum slope the model can achieve between any two points is no greater than 10.0 by construction.
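In other words, the trained network is guaranteed to satisfy

\[|\text{model}(x_1) - \text{model}(x_2)| \le \gamma \, |x_1 - x_2|\]

for all inputs $x_1, x_2$, with $\gamma = 10.0$.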

using Random
using RobustNeuralNetworks

# Random seed for consistency
rng = MersenneTwister(42)

# Model specification
nu = 1                  # Number of inputs
ny = 1                  # Number of outputs
nh = fill(16,4)         # 4 hidden layers, each with 16 neurons
γ = 10                  # Lipschitz bound of 10

# Set up model: define parameters, then create model
model_ps = DenseLBDNParams{Float64}(nu, nh, ny, γ; rng)
model = DiffLBDN(model_ps)
RobustNeuralNetworks.DiffLBDN{Float64}(NNlib.relu, 1, [16, 16, 16, 16], 1, RobustNeuralNetworks.DenseLBDNParams{Float64}(...))

Note that we first constructed the model parameters model_ps, and then created a callable model. In RobustNeuralNetworks.jl, model parameterisations are kept separate from the "explicit" definition of a model used for evaluation on data. See Direct & explicit parameterisations for more information.
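For example, we can evaluate the callable model on a single (arbitrary) test input to check that it behaves like any other Julia model, remembering that it expects an AbstractArray:

# Quick sanity check: evaluate the model on one sample input
x_test = [0.1]
y_test = model(x_test)    # 1-element output vector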

A layer-wise approach

We also provide individual LBDN layers via SandwichFC, mimicking the layer-wise construction of models with Flux.Dense. This may be more convenient for users familiar with Flux.jl.

For example, we can construct an identical model to the LBDN model above with the following.

using Flux

chain_model = Flux.Chain(
    (x) -> (√γ * x),
    SandwichFC(nu => nh[1], Flux.relu; T=Float64, rng),
    SandwichFC(nh[1] => nh[2], Flux.relu; T=Float64, rng),
    SandwichFC(nh[2] => nh[3], Flux.relu; T=Float64, rng),
    SandwichFC(nh[3] => nh[4], Flux.relu; T=Float64, rng),
    (x) -> (√γ * x),
    SandwichFC(nh[4] => ny; output_layer=true, T=Float64, rng),
)

See Section 3.1 of Wang & Manchester (2023) for further details.
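As a rough sanity check, the slope of chain_model between any two test points should also be no greater than γ. A minimal sketch, using two arbitrary inputs:

# Finite-difference slope between two sample points (should not exceed γ = 10)
x1, x2 = [0.1], [0.2]
slope_est = abs(chain_model(x2)[1] - chain_model(x1)[1]) / abs(x2[1] - x1[1])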

3. Define a loss function

Let's stick to a simple loss function based on the mean-squared error (MSE) for this example. All AbstractLBDN models take an AbstractArray as their input, which is why x and y are wrapped in vectors.

# Loss function
loss(model,x,y) = Flux.mse(model([x]),[y])
loss (generic function with 1 method)

4. Train the model

Our objective is to minimise the loss function with a model that has a Lipschitz bound no greater than 10.0. Let's set up a callback function to check the fit error and slope of our model at each training epoch.

using Flux

# Check fit error/slope during training
mse(model, xs, ys) = sum(loss.((model,), xs, ys)) / length(xs)
lip(model, xs, dx) = maximum(abs.(diff(model(xs'), dims=2)))/dx

# Callback function to show results while training
function progress(model, iter, xs, ys, dx)
    fit_error = round(mse(model, xs, ys), digits=4)
    slope = round(lip(model, xs, dx), digits=4)
    @show iter fit_error slope
    println()
end
progress (generic function with 1 method)
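Here, lip estimates a lower bound on the Lipschitz constant by taking the largest finite-difference slope over the grid:

\[\hat{\gamma} = \max_i \frac{|\text{model}(x_{i+1}) - \text{model}(x_i)|}{dx} \le \gamma.\]

Since every difference quotient of the model is at most $\gamma$, this estimate can never exceed the certified bound of 10.0.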

We'll train the model for 300 training epochs with a learning rate of lr = 2e-4. We'll also use the Adam optimiser from Flux.jl and the default Flux.train! method.

# Define hyperparameters and optimiser
num_epochs = 300
lr = 2e-4
opt_state = Flux.setup(Adam(lr), model)

# Train the model
for i in 1:num_epochs
    Flux.train!(loss, model, data, opt_state)
    (i % 100 == 0) && progress(model, i, xs, ys, dx)
end
iter = 100
fit_error = 0.0171
slope = 9.0355

iter = 200
fit_error = 0.0153
slope = 9.8895

iter = 300
fit_error = 0.0148
slope = 9.9432

Note that this training loop is for demonstration only. For a better fit, or on more complex problems, we recommend training for more epochs, batching the training data, and reducing the learning rate as training progresses (sketched below).
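For example, one simple refinement is to continue training with a smaller learning rate, using Flux.adjust! to rescale the rate stored in the optimiser state. A minimal sketch, reusing the loss, data and optimiser state defined above:

# Continue training with a 10x smaller learning rate for a finer final fit
Flux.adjust!(opt_state, lr / 10)
for i in 1:num_epochs
    Flux.train!(loss, model, data, opt_state)
end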

5. Examine the trained model

The final estimated lower bound on our Lipschitz constant is very close to the maximum allowable value of 10.0.

using Printf

# Estimate Lipschitz lower-bound
Empirical_Lipschitz = lip(model, xs, dx)
@printf "Empirical lower Lipschitz bound: %.2f\n" Empirical_Lipschitz
Empirical lower Lipschitz bound: 9.94
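Note that this is only a lower bound, and its value depends on the grid used: evaluating lip on a finer grid may give a slightly different estimate, but by construction it can never exceed 10.0.

# Re-estimate the lower bound on a finer grid (still guaranteed ≤ 10.0)
xs_fine = -0.3:0.001:0.3
lip(model, xs_fine, 0.001)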

We can now plot the results to see what our model looks like.

using CairoMakie

# Create a figure
f1 = Figure(resolution = (600, 400))
ax = Axis(f1[1,1], xlabel="x", ylabel="y")

# Compute the best-possible fit with Lipschitz bound 10.0
get_best(x) = x<-0.05 ? 0 : (x<0.05 ? 10x + 0.5 : 1)
ybest = get_best.(xs)
ŷ = map(x -> model([x])[1], xs)

# Plot
lines!(xs, ys, label = "Data")
lines!(xs, ybest, label = "Maximum slope = 10.0")
lines!(xs, ŷ, label = "LBDN: slope = $(round(Empirical_Lipschitz; digits=2))")
axislegend(ax, position=:lt)
save("lbdn_curve_fit.svg", f1)
CairoMakie.Screen{SVG}

The model roughly approximates the step function $f(x)$, but maintains a maximum Lipschitz constant (slope on the graph) below 10.0. It is reasonably close to the best-possible value, and can easily be improved with a slightly larger model and more training time.
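For instance, a larger parameterisation (not tuned here, chosen purely for illustration) can be constructed in exactly the same way:

# A wider, deeper LBDN with the same Lipschitz bound of γ = 10
nh_big    = fill(32, 6)
big_ps    = DenseLBDNParams{Float64}(nu, nh_big, ny, γ; rng)
big_model = DiffLBDN(big_ps)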

The benefit of using an LBDN is that we have full control over the Lipschitz bound, and can still use standard unconstrained gradient descent tools like Flux.train! to train our models. For examples in which imposing a Lipschitz bound improves model performance and robustness, see Image Classification with LBDN and Reinforcement Learning with LBDN.