Enforcing Physical Constraints via Universal Differential-Algebraic Equations
As shown in the stiff ODE tutorial, differential-algebraic equations (DAEs) can be used to impose physical constraints. One way to define a DAE is through an ODE with a singular mass matrix. For example, if we write Mu' = f(u) where the last row of M is all zeros, then the last row of the right-hand side defines a constraint. Using NeuralODEMM, we can use this to define a neural ODE whose three state components must sum to one. An example of this is as follows:
using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Plots
using Random
rng = Random.default_rng()
function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
u₀ = [1.0, 0, 0]
M = [1. 0 0
     0 1. 0
     0 0 0]
tspan = (0.0,1.0)
p = [0.04, 3e7, 1e4]
stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
                     Lux.Dense(64, 2))
pinit, st = Lux.setup(rng, nn_dudt2)
model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
                               tspan, M, Rodas5(autodiff=false), saveat = 0.1)
model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st)
function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end
# callback = function (p, l, pred) #callback function to observe training
# display(l)
# return false
# end
l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit)))
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, BFGS(), maxiters=100)
Step-by-Step Description
Load Packages
using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Plots
using Random
rng = Random.default_rng()
Differential Equation
First, we define our differential equations as a highly stiff problem, which makes the fitting difficult.
function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
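Written out as equations, f! combined with a mass matrix whose last row is zero corresponds to the following semi-explicit system. This is only a restatement of the code above in mathematical notation, with y_1, y_2, y_3 the components of u and k_1, k_2, k_3 the components of p:

```latex
\begin{aligned}
y_1' &= -k_1 y_1 + k_3 y_2 y_3 \\
y_2' &= k_1 y_1 - k_3 y_2 y_3 - k_2 y_2^2 \\
0 &= y_1 + y_2 + y_3 - 1
\end{aligned}
```

The first two rows are ordinary differential equations, while the third row has no derivative on the left-hand side and therefore acts as an algebraic constraint.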
Parameters
u₀ = [1.0, 0, 0]
M = [1. 0 0
     0 1. 0
     0 0 0]
tspan = (0.0,1.0)
p = [0.04, 3e7, 1e4]
u₀ = Initial conditions
M = Semi-explicit mass matrix (the last row corresponds to the constraint equation and is therefore all zeros)
tspan = Time span over which to evaluate
p = Parameters k₁, k₂ and k₃ of the differential equation above
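As a quick sanity check on the mass matrix, the sketch below uses only the LinearAlgebra standard library to confirm that M is singular and that exactly one row is algebraic. The variable names is_last_row_zero, mass_rank and n_algebraic are introduced here for illustration and are not part of the tutorial code.

```julia
using LinearAlgebra

# Same mass matrix as in the tutorial
M = [1.0 0 0
     0 1.0 0
     0 0 0]

# A zero last row means the third equation carries no derivative term
is_last_row_zero = all(iszero, M[3, :])

# The rank counts the differential rows; the remainder are algebraic
mass_rank = rank(M)
n_algebraic = size(M, 1) - mass_rank

println("last row zero: ", is_last_row_zero)
println("differential equations: ", mass_rank, ", algebraic constraints: ", n_algebraic)
```

Because M is singular, the problem cannot be rewritten as a plain ODE by inverting M, which is why a DAE-capable solver is needed.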
ODE Function, Problem and Solution
We define and solve our ODE problem to generate the "labeled" data which will be used to train our Neural Network.
stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
Because this is a DAE, we need to make sure to use a compatible solver. Rodas5 works well for this example.
Neural Network Layers
Next, we create our layers using Lux.Chain. We use this instead of Flux.Chain because it is more suited to SciML applications (similarly for Lux.Dense). The input to our network will be the initial conditions fed in as u₀.
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
                     Lux.Dense(64, 2))
pinit, st = Lux.setup(rng, nn_dudt2)
model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
                               tspan, M, Rodas5(autodiff=false), saveat = 0.1)
model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st)
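Because this is a stiff problem, we manually impose the sum constraint via (u,p,t) -> [u[1] + u[2] + u[3] - 1], which makes the fitting easier. The network therefore only has to output the two differential rows, which is why its last layer has 2 outputs rather than 3.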
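One way to picture how the two pieces fit together is that the network supplies the two differential rows and the constraint function supplies the algebraic row. The sketch below illustrates that layout under this assumption; it is not DiffEqFlux's actual implementation. The names W, toy_net and rhs are hypothetical, and the fixed matrix W merely stands in for a trained network.

```julia
# Hypothetical stand-in for the neural network: maps 3 states to 2 derivatives
W = [0.1 0.0 -0.2
     0.0 0.3  0.1]
toy_net(u) = tanh.(W * u)

# Same constraint function as passed to NeuralODEMM in the tutorial
constraint(u, p, t) = [u[1] + u[2] + u[3] - 1]

# Assumed layout: learned derivatives first, constraint residual last,
# matching the zero last row of the mass matrix
rhs(u, p, t) = vcat(toy_net(u), constraint(u, p, t))

u₀ = [1.0, 0.0, 0.0]
r = rhs(u₀, nothing, 0.0)
println("right-hand side has ", length(r), " rows; constraint residual = ", r[3])
```

At the initial condition the residual in the last row is zero, so u₀ is consistent with the constraint.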
Prediction Function
For simplicity, we define a wrapper function that only takes in the model's parameters to make predictions.
function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
Train Parameters
Training our network requires a loss function, an optimizer and a callback function to display the progress.
Loss
We first make our predictions based on the current parameters, then calculate the loss from these predictions. In this case, we use least squares as our loss.
function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    # Convert the reference solution to a plain array before broadcasting
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end
l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit)))
Notice that we are feeding the parameters of model_stiff_ndae to the loss_stiff_ndae function. These parameters, initialized as pinit, are the weights of our NN. There are 386 of them (4 * 64 + 65 * 2), including the biases.
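The parameter count can be reproduced by hand from the layer sizes. The helper dense_params below is a hypothetical name introduced for this sketch; it counts one weight per input-output pair plus one bias per output, which is the standard layout for a dense layer with bias.

```julia
# Parameters in a dense layer: weights (n_in * n_out) plus biases (n_out)
dense_params(n_in, n_out) = n_in * n_out + n_out

# Layer sizes from nn_dudt2: Dense(3, 64, tanh) followed by Dense(64, 2)
layer_sizes = [(3, 64), (64, 2)]
n_params = sum(dense_params(n_in, n_out) for (n_in, n_out) in layer_sizes)

println("total parameters: ", n_params)
```

The first layer contributes 4 * 64 = 256 parameters and the second contributes 65 * 2 = 130.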
Optimizer
The optimizer is BFGS (see below).
Callback
The callback function displays the loss during training.
callback = function (p, l, pred) # callback function to observe training
    display(l)
    return false
end
Train
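Finally, we train with Optimization.solve. The loss function and the initial parameters are wrapped in an OptimizationProblem, which is passed along with the optimizer and the maximum number of iterations. The callback defined above can be supplied through the callback keyword argument to display the loss during training.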
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, BFGS(), maxiters=100)