# Neural Stochastic Differential Equations

With neural stochastic differential equations, there is once again a helper form `neural_dmsde` which can be used for the multiplicative noise case (consult the layers API documentation, or this full example using the layer function).

However, since there are far too many possible combinations for the API to support, in many cases you will want to define performant neural differential equations for non-ODE systems from scratch. For these systems, it is generally best to use `TrackerAdjoint` with non-mutating (out-of-place) forms. For example, the following defines a neural SDE with neural networks for both the drift and diffusion terms:

```
dudt(u, p, t) = model(u)
g(u, p, t) = model2(u)
prob = SDEProblem(dudt, g, x, tspan, nothing)
```

where `model` and `model2` are different neural networks. The same can apply to a neural delay differential equation. Its out-of-place formulation is `f(u,h,p,t)`. Thus, for example, if we want to define a neural delay differential equation which uses the history value at `p.tau` in the past, we can define:

```
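# Out-of-place DDE right-hand side; the history function is called as h(p, t)
dudt_(u, h, p, t) = model([u; h(p, t - p.tau)])
# p must carry the delay in a `tau` field (e.g. a NamedTuple), so it is passed instead of `nothing`
prob = DDEProblem(dudt_, u0, h, tspan, p)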
```
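In equation form, the two problems above correspond roughly to the following. The symbols $W_t$ (a Wiener process), $\tau$ (the delay stored in `p.tau`), and $t_0$ (the start of `tspan`) are notation assumed here rather than names from the code, and the elementwise product $\odot$ assumes the default diagonal-noise interpretation of `SDEProblem`:

$$
\mathrm{d}u(t) = \mathrm{model}\big(u(t)\big)\,\mathrm{d}t + \mathrm{model2}\big(u(t)\big) \odot \mathrm{d}W_t
$$

$$
\frac{\mathrm{d}u(t)}{\mathrm{d}t} = \mathrm{model}\big([\,u(t);\ u(t - \tau)\,]\big), \qquad u(s) = h(s) \ \text{for } s \le t_0
$$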

First let's build training data from the same example as the neural ODE:

```
using Plots, Statistics
using Lux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, DiffEqBase.EnsembleAnalysis, Random
rng = Random.default_rng()
u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
```

```
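# Drift: cubic dynamics, du = true_A' * u.^3 (written in-place)
function trueSDEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
# Diffusion: diagonal multiplicative noise with coefficient 0.2 per component
mp = Float32[0.2, 0.2]
function true_noise_func(du, u, p, t)
    du .= mp.*u
end
prob_truesde = SDEProblem(trueSDEfunc, true_noise_func, u0, tspan)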
```
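Written out, the data-generating SDE defined by `trueSDEfunc` and `true_noise_func` is, as far as can be read off the code above (with $u^{\circ 3}$ denoting the elementwise cube and $\odot$ the elementwise product, both notation introduced here):

$$
\mathrm{d}u = A^{\top} u^{\circ 3}\,\mathrm{d}t + 0.2\,u \odot \mathrm{d}W_t,
\qquad
A = \begin{pmatrix} -0.1 & 2.0 \\ -2.0 & -0.1 \end{pmatrix},
\qquad
u(0) = \begin{pmatrix} 2 \\ 0 \end{pmatrix}
$$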

For our dataset we will use DifferentialEquations.jl's parallel ensemble interface to generate data from the average of 10,000 runs of the SDE:

```
# Take a typical sample from the mean
ensemble_prob = EnsembleProblem(prob_truesde)
ensemble_sol = solve(ensemble_prob, SOSRI(), trajectories = 10000)
ensemble_sum = EnsembleSummary(ensemble_sol)
sde_data, sde_data_vars = Array.(timeseries_point_meanvar(ensemble_sol, tsteps))
```
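To make explicit what the ensemble mean and variance represent, here is a minimal sketch that simulates the same drift and noise using only the Julia standard library. It swaps in a plain fixed-step Euler–Maruyama scheme rather than `SOSRI`, uses far fewer trajectories, and the names `drift`, `noise`, `euler_maruyama`, `em_mean`, and `em_var` are illustrative assumptions, not part of any package API:

```julia
using Random, Statistics

# Out-of-place drift and diagonal multiplicative noise, mirroring the true SDE above
const A = Float32[-0.1 2.0; -2.0 -0.1]
drift(u) = A' * (u .^ 3)
noise(u) = 0.2f0 .* u

# Fixed-step Euler–Maruyama: u ← u + f(u) dt + g(u) .* sqrt(dt) .* ξ, with ξ ~ N(0, I)
function euler_maruyama(rng, u0, tspan, nsteps)
    dt = (tspan[2] - tspan[1]) / nsteps
    path = zeros(Float32, length(u0), nsteps + 1)
    path[:, 1] = u0
    u = copy(u0)
    for k in 1:nsteps
        ξ = randn(rng, Float32, length(u0))
        u = u .+ drift(u) .* dt .+ noise(u) .* sqrt(dt) .* ξ
        path[:, k + 1] = u
    end
    return path
end

rng = MersenneTwister(1234)
u0 = Float32[2.0, 0.0]
paths = [euler_maruyama(rng, u0, (0.0f0, 1.0f0), 1000) for _ in 1:200]

stacked = cat(paths...; dims = 3)            # size (2, nsteps + 1, ntrajectories)
em_mean = mean(stacked, dims = 3)[:, :, 1]   # mean over trajectories at each time point
em_var = var(stacked, dims = 3)[:, :, 1]     # variance over trajectories at each time point
```

The arrays `em_mean` and `em_var` play the same role as `sde_data` and `sde_data_vars`, although they are sampled on the integrator's own time grid rather than on `tsteps`.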

Now we build a neural SDE. For simplicity we will use the `NeuralDSDE` neural SDE with diagonal noise layer function:

```
drift_dudt = Lux.Chain(ActivationFunction(x -> x.^3),
                       Lux.Dense(2, 50, tanh),
                       Lux.Dense(50, 2))
p1, st1 = Lux.setup(rng, drift_dudt)
diffusion_dudt = Lux.Chain(Lux.Dense(2, 2))
p2, st2 = Lux.setup(rng, diffusion_dudt)
p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)
# ComponentArrays doesn't provide a name to the first ComponentVector, only subsequent ones get a name for dereferencing
p = [p1, p2]
neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                       saveat = tsteps, reltol = 1e-1, abstol = 1e-1)
```

Let's see what that looks like:

```
# Get the prediction using the correct initial condition
prediction0, st1, st2 = neuralsde(u0,p,st1,st2)
drift_(u, p, t) = drift_dudt(u, p[1], st1)[1]
diffusion_(u, p, t) = diffusion_dudt(u, p[2], st2)[1]
prob_neuralsde = SDEProblem(drift_, diffusion_, u0,(0.0f0, 1.2f0), p)
ensemble_nprob = EnsembleProblem(prob_neuralsde)
ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100,
                      saveat = tsteps)
ensemble_nsum = EnsembleSummary(ensemble_nsol)
plt1 = plot(ensemble_nsum, title = "Neural SDE: Before Training")
scatter!(plt1, tsteps, sde_data', lw = 3)
scatter(tsteps, sde_data[1,:], label = "data")
scatter!(tsteps, prediction0[1,:], label = "prediction")
```

Now, just as with the neural ODE, we define a loss function that calculates the mean and variance from `n` runs at each time point and uses the distance from the data values:

```
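function predict_neuralsde(p, u = u0)
    return Array(neuralsde(u, p, st1, st2)[1])
end
function loss_neuralsde(p; n = 100)
    # Repeat the initial condition column-wise so that n trajectories are solved at once
    u = repeat(reshape(u0, :, 1), 1, n)
    samples = predict_neuralsde(p, u)
    # Reduce over the trajectory dimension (dims = 2), then drop it
    means = mean(samples, dims = 2)
    vars = var(samples, dims = 2, mean = means)[:, 1, :]
    means = means[:, 1, :]
    # Squared error in both the mean and the variance at each saved time point
    loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars)
    return loss, means, vars
end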
```
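The array shapes in `loss_neuralsde` are easy to get wrong, so the following standalone sketch walks through the same reductions on random data. It assumes the batched solution has layout (state, trajectory, time), which is what the indexing `[:, 1, :]` above implies. The names `samples`, `target_means`, and `target_vars` here are hypothetical stand-ins for the solver output, `sde_data`, and `sde_data_vars`:

```julia
using Random, Statistics

rng = MersenneTwister(0)
nstate, n, ntimes = 2, 100, 30

# Stand-in for Array(neuralsde(u, p, st1, st2)[1]); assumed layout: (state, trajectory, time)
samples = randn(rng, Float32, nstate, n, ntimes)

means = mean(samples, dims = 2)                        # size (2, 1, 30)
vars = var(samples, dims = 2, mean = means)[:, 1, :]   # size (2, 30)
means = means[:, 1, :]                                 # size (2, 30)

# Hypothetical target statistics with the same layout as sde_data and sde_data_vars
target_means = zeros(Float32, nstate, ntimes)
target_vars = ones(Float32, nstate, ntimes)

# Same form as the loss above: squared error in the mean plus squared error in the variance
loss = sum(abs2, target_means - means) + sum(abs2, target_vars - vars)
```

Passing the precomputed `means` to `var` avoids computing the mean twice, and the singleton trajectory dimension is dropped only afterwards so that the `mean` keyword still matches the reduced shape.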

```
list_plots = []
iter = 0
# Callback function to observe training
callback = function (p, loss, means, vars; doplot = false)
    global list_plots, iter
    if iter == 0
        list_plots = []
    end
    iter += 1
    # loss against current data
    display(loss)
    # plot current prediction against data
    plt = Plots.scatter(tsteps, sde_data[1,:], yerror = sde_data_vars[1,:],
                        ylim = (-4.0, 8.0), label = "data")
    Plots.scatter!(plt, tsteps, means[1,:], ribbon = vars[1,:], label = "prediction")
    push!(list_plots, plt)
    if doplot
        display(plt)
    end
    return false
end
```

Now we train using this loss function. We can pre-train a little bit using a smaller `n` and then increase it after the model has had some time to adjust towards the right mean behavior:

```
opt = ADAM(0.025)
# First round of training with n = 10
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_neuralsde(x, n=10), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
result1 = Optimization.solve(optprob, opt,
                             callback = callback, maxiters = 100)
```

We resume the training with a larger `n`. (WARNING: this step takes a couple of orders of magnitude longer than the previous one.)

```
optf2 = Optimization.OptimizationFunction((x,p) -> loss_neuralsde(x, n=100), adtype)
optprob2 = Optimization.OptimizationProblem(optf2, result1.u)
result2 = Optimization.solve(optprob2, opt,
                             callback = callback, maxiters = 100)
```
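One way to see why the two-stage schedule is reasonable: the loss compares sample means and variances, and these estimates get less noisy as `n` grows. The sketch below is a toy illustration under the assumption of independent standard normal samples (a hypothetical stand-in for the `n` SDE trajectories, not the actual solver output), with `estimator_spread` being a helper name introduced here:

```julia
using Random, Statistics

rng = MersenneTwister(42)

# Spread (standard deviation) of the sample-mean estimate across repeated draws of n samples
estimator_spread(rng, n; trials = 2000) = std([mean(randn(rng, n)) for _ in 1:trials])

spread_10 = estimator_spread(rng, 10)
spread_100 = estimator_spread(rng, 100)

println("n = 10:  ", spread_10)
println("n = 100: ", spread_100)
```

For independent samples the spread shrinks roughly like `1/sqrt(n)`, so going from `n = 10` to `n = 100` buys a cleaner loss at about ten times the simulation cost per iteration. That trade-off is why cheap, noisy iterations are used first and expensive, more accurate ones later.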

And now we plot the solution to an ensemble of the trained neural SDE:

```
_, means, vars = loss_neuralsde(result2.u, n = 1000)
plt2 = Plots.scatter(tsteps, sde_data', yerror = sde_data_vars',
                     label = "data", title = "Neural SDE: After Training",
                     xlabel = "Time")
plot!(plt2, tsteps, means', lw = 8, ribbon = vars', label = "prediction")
plt = plot(plt1, plt2, layout = (2, 1))
savefig(plt, "NN_sde_combined.png"); nothing # sde
```

Try this with GPUs as well!