# Simultaneous Fitting of Multiple Neural Networks

In many cases users are interested in fitting multiple neural networks or parameters simultaneously. This tutorial addresses how to perform this kind of study.

The following is a fully working demo on the FitzHugh-Nagumo ODE:

```julia
using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, OptimizationOptimisers,
      DifferentialEquations, SciMLSensitivity, ComponentArrays, Random
rng = Random.default_rng()

# FitzHugh-Nagumo model: u = [v, w], p = [a, b, τinv, l]
function fitz(du, u, p, t)
    v, w = u
    a, b, τinv, l = p
    du[1] = v - v^3 / 3 - w + l
    du[2] = τinv * (v + a - b * w)
end
p_ = Float32[0.7, 0.8, 1 / 12.5, 0.5]
u0 = [1f0; 1f0]
tspan = (0f0, 10f0)
prob = ODEProblem(fitz, u0, tspan, p_)
sol = solve(prob, Tsit5(), saveat = 0.5)

# Ideal data
X = Array(sol)
# Noisy data
Xₙ = X + Float32(1e-3) * randn(eltype(X), size(X))

# Network for the v equation (inputs: v, w)
NN_1 = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 1))
p1, st1 = Lux.setup(rng, NN_1)
# Network for the w equation (inputs: v, w, t)
NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
p2, st2 = Lux.setup(rng, NN_2)
scaling_factor = 1f0

# Concatenate all trainable parameters into one flat vector with named blocks
p = ComponentArray(p1 = p1, p2 = p2, scaling_factor = scaling_factor)

function dudt_(u, p, t)
    v, w = u
    # Each network only sees its own block of the parameter vector
    z1 = NN_1([v, w], p.p1, st1)[1]
    z2 = NN_2([v, w, t], p.p2, st2)[1]
    [z1[1], p.scaling_factor * z2[1]]
end
prob_nn = ODEProblem(dudt_, u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(), saveat = sol.t)

function predict(θ)
    Array(solve(prob_nn, Vern7(), p = θ, saveat = sol.t,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end

# No regularisation right now
function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end
loss(p)

const losses = []
callback(θ, l, pred) = begin
    push!(losses, l)
    if length(losses) % 50 == 0
        println(losses[end])
    end
    false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
# First pass with Adam, then refine the result with BFGS
res1_uode = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.01),
                               callback = callback, maxiters = 500)
optprob2 = Optimization.OptimizationProblem(optf, res1_uode.u)
res2_uode = Optimization.solve(optprob2, BFGS(), maxiters = 10000)
```
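For reference, the demo can be summarized in equations read off directly from `fitz`, `dudt_`, and `loss`. Here $\tau^{-1}$ is `τinv`, $\theta_1$ and $\theta_2$ are the parameters of `NN_1` and `NN_2`, $s$ is `scaling_factor`, $\hat u$ is the solution returned by `predict`, and the $t_k$ are the save points $0, 0.5, \dots, 10$. The first line generates the data, the second is the neural model being fitted, and the third is the quantity being minimized:

```math
\begin{aligned}
&\text{data:}  && \frac{dv}{dt} = v - \frac{v^3}{3} - w + l, && \frac{dw}{dt} = \tau^{-1}\,(v + a - b\,w),\\
&\text{model:} && \frac{dv}{dt} = \mathrm{NN}_1(v, w;\,\theta_1), && \frac{dw}{dt} = s\,\mathrm{NN}_2(v, w, t;\,\theta_2),\\
&\text{loss:}  && \mathcal{L}(\theta) = \sum_k \big\lVert X_n(t_k) - \hat u(t_k;\,\theta)\big\rVert_2^2, && \theta = (\theta_1, \theta_2, s).
\end{aligned}
```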

The key is that `Optimization.solve` acts on a single parameter vector `p`. Thus what we do here is concatenate all of the parameters into a single vector, `p = [p1; p2; scaling_factor]`, and then train on this parameter vector. Whenever we need to evaluate the neural networks, we cut the vector and grab the portion that corresponds to each network. In the demo this bookkeeping is handled by a `ComponentArray`, so the `p1` portion is accessed by name as `p.p1`; on a plain flat vector, the same portion would be `p[1:n1]`, where `n1` is the number of parameters in the first network. This is why the first neural network's evolution is written like `NN_1([v,w], p.p1, st1)`.
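To make the slicing idea concrete without any of the packages above, here is a minimal sketch on a plain `Vector{Float32}`. The names `θ1`, `θ2`, `s`, `unpack`, and `toy_model` are hypothetical stand-ins chosen for illustration; they are not part of DiffEqFlux or Lux. The point is only that the block lengths are fixed up front, so each block can always be recovered from its index range.

```julia
# Hypothetical stand-ins for the flattened weights of two networks and a
# scalar scaling factor; sizes and values are arbitrary.
θ1 = Float32[0.1, 0.2, 0.3]
θ2 = Float32[1.0, 2.0, 3.0, 4.0]
s  = 2f0

# Concatenate everything into one flat vector, the form an optimizer works on.
p = [θ1; θ2; s]
n1, n2 = length(θ1), length(θ2)   # block lengths, fixed before training

# Cut the flat vector back into its pieces (these range slices are copies).
unpack(p, n1, n2) = (p[1:n1], p[n1+1:n1+n2], p[n1+n2+1])

# Hypothetical model: each block is used on its own and the scalar rescales
# one term, loosely mirroring how `dudt_` combines its two networks.
function toy_model(p)
    a, b, scale = unpack(p, n1, n2)
    return sum(a) + scale * sum(b)
end

toy_model(p)
```

An optimizer only ever sees `p`; the model decides how to interpret it, which is why adding a third network or another scalar only means extending the concatenation and the index ranges.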

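This method works with any optimizer that accepts a single parameter vector, and it can be made fairly efficient. When slicing a plain vector by hand, indexing with a range such as `p[1:n1]` (with `n1` the length of the first block) copies that slice; the allocation can be avoided by using `@view p[1:n1]` instead. Accessing a block of a `ComponentArray`, as in `p.p1`, already returns a view. The `scaling_factor` also shows that we can grab individual parameters directly out of the vector and use them as needed.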
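The difference between a copied slice and a view can be checked with Base Julia alone. The sketch below is hedged: `p`, `n1`, `block_copy`, `block_view`, and `measure` are hypothetical names, the vector is arbitrary data rather than real network weights, and the exact byte counts depend on the Julia version, so only their relative size is meaningful.

```julia
# Hypothetical flat parameter vector: 1000 entries, the last one playing the
# role of a scaling factor. Values are arbitrary.
p = collect(Float32, 1:1000)
n1 = 600  # hypothetical length of the first network's block

block_copy(p, n) = sum(p[1:n])        # p[1:n] materializes a new array
block_view(p, n) = sum(@view p[1:n])  # @view reads p's memory in place

# Measure allocations inside a function so global-scope overhead is excluded;
# each call is made once first so compilation is not counted.
function measure(p, n)
    block_copy(p, n); block_view(p, n)
    return (@allocated block_copy(p, n)), (@allocated block_view(p, n))
end
bytes_copy, bytes_view = measure(p, n1)

# A view aliases the original storage, so writes through it update p too.
first_block = @view p[1:n1]
first_block[1] = 42f0

# A scalar parameter can be read directly from its position in the vector.
scaling = p[end]
```

The aliasing shown by `first_block` is also why views are safe to hand to a network: they always reflect the current values of the optimizer's vector.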