Parameter Estimation on Highly Stiff Systems

This tutorial goes into training a model on stiff chemical reaction system data.

Copy-Pasteable Code

Before getting to the explanation, here's some code to start with. We will follow a full explanation of the definition and training process:

using DifferentialEquations, DiffEqFlux, Optimization, OptimizationOptimJL, LinearAlgebra
using ForwardDiff
using DiffEqBase: UJacobianWrapper
using Plots
function rober(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p
    du[1] = -k₁*y₁+k₃*y₂*y₃
    du[2] =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
    du[3] =  k₂*y₂^2
    nothing
end

p = [0.04,3e7,1e4]
u0 = [1.0,0.0,0.0]
prob = ODEProblem(rober,u0,(0.0,1e5),p)
sol = solve(prob,Rosenbrock23())
ts = sol.t
Js = map(u->I + 0.1*ForwardDiff.jacobian(UJacobianWrapper(rober, 0.0, p), u), sol.u)

function predict_adjoint(p)
    p = exp.(p)
    _prob = remake(prob,p=p)
    Array(solve(_prob,Rosenbrock23(autodiff=false),saveat=ts,sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

function loss_adjoint(p)
    prediction = predict_adjoint(p)
    prediction = [prediction[:, i] for i in axes(prediction, 2)]
    diff = map((J,u,data) -> J * (abs2.(u .- data)) , Js, prediction, sol.u)
    loss = sum(abs, sum(diff)) |> sqrt
    loss, prediction
end

callback = function (p,l,pred) #callback function to observe training
    println("Loss: $l")
    println("Parameters: $(exp.(p))")
    # using `remake` to re-create our `prob` with current parameters `p`
    plot(solve(remake(prob, p=exp.(p)), Rosenbrock23())) |> display
    return false # Tell it to not halt the optimization. If return true, then optimization stops
end

initp = ones(3)
# Display the ODE with the initial parameter values.
callback(initp,loss_adjoint(initp)...)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss_adjoint(x), adtype)
optprob = Optimization.OptimizationProblem(optf, initp)

res = Optimization.solve(optprob, ADAM(0.01), callback = callback, maxiters = 300)

optprob2 = Optimization.OptimizationProblem(optf, res.u)

res2 = Optimization.solve(optprob2, BFGS(), callback = callback, maxiters = 30, allow_f_increases=true)
println("Ground truth: $(p)\nFinal parameters: $(round.(exp.(res2.u), sigdigits=5))\nError: $(round(norm(exp.(res2.u) - p) ./ norm(p) .* 100, sigdigits=3))%")

Output:

Ground truth: [0.04, 3.0e7, 10000.0]
Final parameters: [0.040002, 3.0507e7, 10084.0]
Error: 1.69%

Explanation

First, let's get a time series array from the Robertson's equation as data.

using DifferentialEquations, DiffEqFlux, Optimization, OptimizationOptimJL, LinearAlgebra
using ForwardDiff
using DiffEqBase: UJacobianWrapper
using Plots
function rober(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p
    du[1] = -k₁*y₁+k₃*y₂*y₃
    du[2] =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
    du[3] =  k₂*y₂^2
    nothing
end

p = [0.04,3e7,1e4]
u0 = [1.0,0.0,0.0]
prob = ODEProblem(rober,u0,(0.0,1e5),p)
sol = solve(prob,Rosenbrock23())
ts = sol.t
Js = map(u->I + 0.1*ForwardDiff.jacobian(UJacobianWrapper(rober, 0.0, p), u), sol.u)

Note that we also computed a shifted and scaled Jacobian along with the solution. We will use this matrix to scale the loss later.

We fit the parameters in log space, so we need to compute exp.(p) to get back the original parameters.

function predict_adjoint(p)
    p = exp.(p)
    _prob = remake(prob,p=p)
    Array(solve(_prob,Rosenbrock23(autodiff=false),saveat=ts,sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

function loss_adjoint(p)
    prediction = predict_adjoint(p)
    prediction = [prediction[:, i] for i in axes(prediction, 2)]
    diff = map((J,u,data) -> J * (abs2.(u .- data)) , Js, prediction, sol.u)
    loss = sum(abs, sum(diff)) |> sqrt
    loss, prediction
end

The difference between the data and the prediction is weighted by the transformed Jacobian to do a relative scaling of the loss.

We define a callback function.

callback = function (p,l,pred) #callback function to observe training
    println("Loss: $l")
    println("Parameters: $(exp.(p))")
    # using `remake` to re-create our `prob` with current parameters `p`
    plot(solve(remake(prob, p=exp.(p)), Rosenbrock23())) |> display
    return false # Tell it to not halt the optimization. If return true, then optimization stops
end

We then use a combination of ADAM and BFGS to minimize the loss function to accelerate the optimization. The initial guess of the parameters are chosen to be [1, 1, 1.0].

initp = ones(3)
# Display the ODE with the initial parameter values.
callback(initp,loss_adjoint(initp)...)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss_adjoint(x), adtype)

optprob = Optimization.OptimizationProblem(optf, initp)
res = Optimization.solve(optprob, ADAM(0.01), callback = callback, maxiters = 300)

optprob2 = Optimization.OptimizationProblem(optf, res.u)
res2 = Optimization.solve(optprob2, BFGS(), callback = callback, maxiters = 30, allow_f_increases=true)

Finally, we can analyze the difference between the fitted parameters and the ground truth.

println("Ground truth: $(p)\nFinal parameters: $(round.(exp.(res2.u), sigdigits=5))\nError: $(round(norm(exp.(res2.u) - p) ./ norm(p) .* 100, sigdigits=3))%")

It gives the output

Ground truth: [0.04, 3.0e7, 10000.0]
Final parameters: [0.040002, 3.0507e7, 10084.0]
Error: 1.69%