Neural Stochastic Differential Equations
With neural stochastic differential equations, there is once again a helper form neural_dmsde
which can be used for the multiplicative noise case (consult the layers API documentation, or this full example using the layer function).
However, the API cannot support every possible combination, so in many cases you will want to define neural differential equations for non-ODE systems from scratch while retaining performance. For these systems, it is generally best to use TrackerAdjoint
with non-mutating (out-of-place) forms. For example, the following defines a neural SDE with neural networks for both the drift and diffusion terms:
dudt(u, p, t) = model(u)
g(u, p, t) = model2(u)
prob = SDEProblem(dudt, g, x, tspan, nothing)
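To make the snippet above runnable on its own, here is a minimal sketch in which simple functions stand in for model and model2; the stand-ins and the initial condition are illustrative assumptions, not trained networks:

```julia
using StochasticDiffEq

# Illustrative stand-ins for the two neural networks (assumptions, not trained models)
model(u) = -0.5f0 .* u    # drift stand-in
model2(u) = 0.1f0 .* u    # diffusion stand-in

# Out-of-place drift and diffusion, as recommended for use with TrackerAdjoint
dudt(u, p, t) = model(u)
g(u, p, t) = model2(u)

x = Float32[2.0, 0.0]     # assumed initial condition
tspan = (0.0f0, 1.0f0)
prob = SDEProblem(dudt, g, x, tspan, nothing)
sol = solve(prob, SOSRI())
```

Any out-of-place functions with the (u, p, t) signature work here, so the stand-ins can be replaced directly by Lux or Flux model calls.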
where model and model2 are different neural networks. The same approach applies to a neural delay differential equation, whose out-of-place formulation is f(u, h, p, t). For example, to define a neural delay differential equation which uses the history value at time t - p.tau, we can write:
dudt_(u, h, p, t) = model([u; h(p, t - p.tau)])
prob = DDEProblem(dudt_, u0, h, tspan, p)
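Filling this in with concrete stand-ins gives a runnable sketch; the network placeholder, the constant history function, and the parameter container holding tau are all illustrative assumptions:

```julia
using DelayDiffEq  # provides DDEProblem, MethodOfSteps, and the ODE solvers

# Hypothetical stand-in for the neural network acting on [u; u(t - tau)]
model(z) = -0.5f0 .* z[1:2] .+ 0.1f0 .* z[3:4]

p = (tau = 0.1f0,)  # assumed parameter container with the delay
dudt_(u, h, p, t) = model([u; h(p, t - p.tau)])

u0 = Float32[1.0, 1.0]
h(p, t) = zeros(Float32, 2)  # constant history for t < 0 (an assumption)
tspan = (0.0f0, 1.0f0)
prob = DDEProblem(dudt_, u0, h, tspan, p; constant_lags = [p.tau])
sol = solve(prob, MethodOfSteps(Tsit5()))
```

Note that inside the DDE function the history is queried as h(p, t - p.tau); passing constant_lags lets the solver track the delay discontinuities.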
First let's build training data from the same example as the neural ODE:
using Plots, Statistics
using Lux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, DiffEqBase.EnsembleAnalysis, Random
rng = Random.default_rng()
u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
function trueSDEfunc(du, u, p, t)
    true_A = Float32[-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
mp = Float32[0.2, 0.2]
function true_noise_func(du, u, p, t)
    du .= mp .* u
end
prob_truesde = SDEProblem(trueSDEfunc, true_noise_func, u0, tspan)
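As a quick check of the transposed-product trick in the drift, ((u.^3)'true_A)' is the same as multiplying by the transposed matrix, since (xᵀA)ᵀ = Aᵀx:

```julia
u = Float32[2.0, 0.5]
true_A = Float32[-0.1 2.0; -2.0 -0.1]

v1 = ((u.^3)'true_A)'   # form used in trueSDEfunc
v2 = true_A' * u.^3     # equivalent formulation
# v1 and v2 agree elementwise
```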
For our dataset we will use DifferentialEquations.jl's parallel ensemble interface to generate data from the average of 10,000 runs of the SDE:
# Take a typical sample from the mean
ensemble_prob = EnsembleProblem(prob_truesde)
ensemble_sol = solve(ensemble_prob, SOSRI(), trajectories = 10000)
ensemble_sum = EnsembleSummary(ensemble_sol)
sde_data, sde_data_vars = Array.(timeseries_point_meanvar(ensemble_sol, tsteps))
Now we build a neural SDE. For simplicity, we will use the NeuralDSDE
layer function, which defines a neural SDE with diagonal noise:
drift_dudt = Lux.Chain(x -> x.^3,
                       Lux.Dense(2, 50, tanh),
                       Lux.Dense(50, 2))
p1, st1 = Lux.setup(rng, drift_dudt)
diffusion_dudt = Lux.Chain(Lux.Dense(2, 2))
p2, st2 = Lux.setup(rng, diffusion_dudt)
using ComponentArrays
p1 = ComponentArray(p1)
p2 = ComponentArray(p2)
# ComponentArrays doesn't give the top-level ComponentVector a name; only nested ones get names for dereferencing
p = [p1, p2]
neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
saveat = tsteps, reltol = 1e-1, abstol = 1e-1)
Let's see what that looks like:
# Get the prediction using the correct initial condition
prediction0, st1, st2 = neuralsde(u0, p, st1, st2)

drift_(u, p, t) = drift_dudt(u, p[1], st1)[1]
diffusion_(u, p, t) = diffusion_dudt(u, p[2], st2)[1]

prob_neuralsde = SDEProblem(drift_, diffusion_, u0, (0.0f0, 1.2f0), p)
ensemble_nprob = EnsembleProblem(prob_neuralsde)
ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100,
saveat = tsteps)
ensemble_nsum = EnsembleSummary(ensemble_nsol)
plt1 = plot(ensemble_nsum, title = "Neural SDE: Before Training")
scatter!(plt1, tsteps, sde_data', lw = 3)
scatter(tsteps, sde_data[1,:], label = "data")
scatter!(tsteps, prediction0[1,:], label = "prediction")
Now, just as with the neural ODE, we define a loss function that calculates the mean and variance from n runs at each time point and penalizes the distance from the data's mean and variance:
function predict_neuralsde(p, u = u0)
    return Array(neuralsde(u, p, st1, st2)[1])
end
function loss_neuralsde(p; n = 100)
    u = repeat(reshape(u0, :, 1), 1, n)
    samples = predict_neuralsde(p, u)
    means = mean(samples, dims = 2)
    vars = var(samples, dims = 2, mean = means)[:, 1, :]
    means = means[:, 1, :]
    loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars)
    return loss, means, vars
end
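The reduction pattern in the loss, taking the mean and variance across trajectories and dropping the resulting singleton dimension, can be sanity-checked on a synthetic array of shape (state, trajectories, time), independent of any SDE solve:

```julia
using Statistics

samples = randn(Float32, 2, 100, 30)  # (state, trajectories, time)
means = mean(samples, dims = 2)       # shape (2, 1, 30)
vars = var(samples, dims = 2, mean = means)[:, 1, :]  # drop singleton dim -> (2, 30)
means = means[:, 1, :]                # -> (2, 30), matching sde_data's layout
```

After the reduction, both arrays have the same (state, time) layout as sde_data and sde_data_vars, so elementwise differences in the loss line up correctly.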
list_plots = []
iter = 0
# Callback function to observe training
callback = function (p, loss, means, vars; doplot = false)
    global list_plots, iter

    if iter == 0
        list_plots = []
    end
    iter += 1

    # loss against current data
    display(loss)

    # plot current prediction against data
    plt = Plots.scatter(tsteps, sde_data[1, :], yerror = sde_data_vars[1, :],
                        ylim = (-4.0, 8.0), label = "data")
    Plots.scatter!(plt, tsteps, means[1, :], ribbon = vars[1, :], label = "prediction")
    push!(list_plots, plt)

    if doplot
        display(plt)
    end
    return false
end
Now we train using this loss function. We can pre-train a little bit using a smaller n and then increase it after the model has had some time to adjust towards the right mean behavior:
opt = ADAM(0.025)
# First round of training with n = 10
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_neuralsde(x, n=10), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
result1 = Optimization.solve(optprob, opt,
callback = callback, maxiters = 100)
We resume the training with a larger n. (Warning: this step takes a couple of orders of magnitude longer than the previous one.)
optf2 = Optimization.OptimizationFunction((x,p) -> loss_neuralsde(x, n=100), adtype)
optprob2 = Optimization.OptimizationProblem(optf2, result1.u)
result2 = Optimization.solve(optprob2, opt,
callback = callback, maxiters = 100)
And now we plot the solution to an ensemble of the trained neural SDE:
_, means, vars = loss_neuralsde(result2.u, n = 1000)
plt2 = Plots.scatter(tsteps, sde_data', yerror = sde_data_vars',
label = "data", title = "Neural SDE: After Training",
xlabel = "Time")
plot!(plt2, tsteps, means', lw = 8, ribbon = vars', label = "prediction")
plt = plot(plt1, plt2, layout = (2, 1))
savefig(plt, "NN_sde_combined.png"); nothing
Try this with GPUs as well!