Differentiating an ODE Solution with Automatic Differentiation
This tutorial assumes familiarity with DifferentialEquations.jl. If you are not familiar with it, please consult the DifferentialEquations.jl documentation.
In this tutorial we introduce local sensitivity analysis via automatic differentiation (AD). The AD interfaces are the most common way to do local sensitivity analysis. They are fairly fast and flexible but, most notably, they are a very small and natural extension of ordinary differential equation solving code, which makes them the easiest approach for most tasks.
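This works because the native solvers are generic Julia code, so ForwardDiff.jl's dual numbers can be pushed straight through them. The toy sketch below is only an illustration and not how DifferentialEquations.jl is implemented. It differentiates a hand-written explicit Euler loop for the test problem u' = -k*u and checks the result against the closed-form derivative of the same recurrence. The function `euler_decay` and its arguments are hypothetical names introduced here.

```julia
using ForwardDiff

# Toy fixed-step explicit Euler solver for the test problem u' = -k*u on [0, tend].
# It only uses generic arithmetic, so it accepts any Real k, including the dual
# numbers that ForwardDiff.jl passes in.
function euler_decay(k; u0 = 1.0, tend = 1.0, n = 100)
    h = tend / n
    u = u0 * one(k)        # promote the state to the number type of k
    for _ in 1:n
        u += h * (-k * u)  # one explicit Euler step
    end
    return u
end

k = 0.7
ad = ForwardDiff.derivative(euler_decay, k)

# The Euler recurrence has the closed form u_n = u0 * (1 - k*h)^n, so its exact
# derivative is d u_n / dk = -n * h * u0 * (1 - k*h)^(n - 1)  (here u0 = 1).
n, h = 100, 1.0 / 100
exact = -n * h * (1 - k * h)^(n - 1)

println("AD: ", ad, "   closed form: ", exact)
```

AD here returns the exact derivative of the discretized solution (up to rounding), which is also why solver tolerances matter for the accuracy of the sensitivities later in the tutorial.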
Setup
Let's first define the differential equation we wish to solve. We will use the Lotka-Volterra equations, set up in DifferentialEquations.jl as follows:
using DifferentialEquations
# Lotka-Volterra predator-prey model: u[1] is the prey population, u[2] the predator population
function lotka_volterra!(du,u,p,t)
    du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
prob = ODEProblem(lotka_volterra!,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5(),reltol=1e-6,abstol=1e-6)
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 104-element Vector{Float64}:
0.0
0.02238867177415836
0.06688455734042167
0.12204058370592046
0.1901739257791462
0.2700959083607983
0.36248999526528664
0.46634994723326256
0.5804932836619872
0.7035671542772408
⋮
9.363458886638309
9.438253922310366
9.514924252134602
9.594877279366209
9.679331500393587
9.769895418274855
9.868269463612691
9.975570519524737
10.0
u: 104-element Vector{Vector{Float64}}:
[1.0, 1.0]
[1.0117558257818347, 0.9563342092954507]
[1.0384182069752141, 0.8758683256119991]
[1.077484888768547, 0.7868751594529567]
[1.134905798469782, 0.6915812943328938]
[1.2153494575722215, 0.5976695148897933]
[1.3266197572115415, 0.509348485294802]
[1.4766111777880502, 0.4313359515293469]
[1.6746724108423345, 0.3664853909559147]
[1.9317154837401005, 0.3161360659954519]
⋮
[1.280491826065612, 3.211190344771746]
[1.1439255609696892, 2.808359820206377]
[1.0502507947520396, 2.426494798337826]
[0.9895322978770432, 2.070758606013577]
[0.9563824393777867, 1.7446019564591093]
[0.9484176810689373, 1.4490311382147962]
[0.9660833878411721, 1.1849913769152567]
[1.0122116164569332, 0.9547857522939497]
[1.0263542618115815, 0.9096831916565506]
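For reference, the lotka_volterra! function above encodes the system below, with u = (x, y) and p = (α, β, γ, δ) = (1.5, 1.0, 3.0, 1.0). The Greek letters are only labels introduced here for the entries of p.

```latex
\begin{aligned}
\frac{dx}{dt} &= \alpha x - \beta x y, \\
\frac{dy}{dt} &= -\gamma y + \delta x y,
\end{aligned}
\qquad x(0) = y(0) = 1, \quad t \in [0, 10].
```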
Now let's differentiate the solution to this ODE using a few different automatic differentiation methods.
Forward-Mode Automatic Differentiation with ForwardDiff.jl
Let's say we need the derivative of the solution with respect to the initial condition u0 and the parameters p. One of the simplest ways to do this is via ForwardDiff.jl: all one needs to do is use ForwardDiff.jl to differentiate some function f that calls the differential equation solver inside of it. For example, let's say we want the derivative of the first component of the ODE solution with respect to these quantities at evenly spaced time points with dt = 1. We can compute this via:
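using ForwardDiff
# x = [u0; p]: entries 1:2 are the initial conditions, entries 3:end the parameters
function f(x)
    _prob = remake(prob,u0=x[1:2],p=x[3:end])
    # first component of the solution at t = 0, 1, ..., 10
    solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=1)[1,:]
end
x = [u0;p]
dx = ForwardDiff.jacobian(f,x)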
11×6 Matrix{Float64}:
1.0 0.0 0.0 0.0 0.0 0.0
2.14463 -1.1848 2.54832 -1.1848 0.477483 -0.628218
-5.88478 0.266338 -3.38158 0.266338 3.50594 -12.662
0.691824 0.3718 -0.762033 0.3718 -0.0477691 -0.278507
2.7989 -0.408784 3.80837 -0.408784 0.883252 0.914524
4.0171 -1.65424 12.3007 -1.65424 3.95659 -2.0814
-2.07453 0.851802 -7.0992 0.851802 -1.06005 -3.46806
2.63655 -0.00114306 3.54679 -0.00114306 0.872776 1.30651
7.88534 -0.610538 16.9144 -0.610538 4.3355 3.54215
-16.5707 0.866198 -36.104 0.866198 -5.67502 -19.8444
1.96602 0.188561 2.16063 0.188561 0.563199 0.939672
Let's dig into what this is saying a bit. x is a vector which concatenates the initial condition and parameters, meaning that the first 2 values are the initial conditions and the last 4 are the parameters. We use the remake function to build a function f(x) which uses these new initial conditions and parameters to solve the differential equation and return the time series of the first component at t = 0, 1, ..., 10 (11 values).
Then ForwardDiff.jacobian(f,x) computes the Jacobian of f with respect to x: an 11×6 matrix with one row per saved time point and one column per entry of x. The output dx[i,j] corresponds to the derivative of the first component of the solution at time t=i-1 with respect to x[j]. For example, dx[2,3] is the derivative of the first component of the solution at time t=1 with respect to p[1] (since x[3] = p[1]). Likewise, the first row is [1 0 0 0 0 0], because at t=0 the first component is simply u0[1].
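One way to sanity-check both the indexing and the AD result is to compare one column of the Jacobian against a central finite difference. The sketch below does this for column 3 (derivatives with respect to p[1] = x[3]). It re-creates the problem so it runs on its own. The tighter tolerance (1e-10), the optional `tol` keyword of `f`, and the step `h = 1e-4` are choices made for this illustration, not values from the tutorial; the tight tolerance keeps the finite difference from being dominated by solver error.

```julia
using DifferentialEquations, ForwardDiff

# Same Lotka-Volterra setup as above, repeated so this block runs on its own.
function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
prob = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# First solution component at t = 0, 1, ..., 10 for x = [u0; p].
function f(x; tol = 1e-10)
    _prob = remake(prob, u0 = x[1:2], p = x[3:end])
    solve(_prob, Tsit5(), reltol = tol, abstol = tol, saveat = 1)[1, :]
end

x = [prob.u0; prob.p]
dx = ForwardDiff.jacobian(f, x)  # rows: t = 0, ..., 10; columns: u0[1], u0[2], p[1], ..., p[4]

# Central finite difference with respect to x[3] = p[1].
h = 1e-4
e3 = [0, 0, 1, 0, 0, 0]
fd = (f(x .+ h .* e3) .- f(x .- h .* e3)) ./ (2h)

println("d u1(t=1)/d p[1]: AD = ", dx[2, 3], ", FD = ", fd[2])
println("max |AD - FD| over the p[1] column: ", maximum(abs.(dx[:, 3] .- fd)))
```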
Since the global error is typically 1-2 orders of magnitude larger than the local error tolerance, we use tolerances of 1e-6 (instead of the default reltol of 1e-3) to get reasonably accurate sensitivities.
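The effect of the tolerances can be checked directly. The sketch below is an illustration rather than part of the original tutorial. It compares the Jacobian computed with the solver's default tolerances and with 1e-6 against a reference computed at 1e-10. The helper name `jac` and the use of the 1e-10 result as a stand-in for the exact answer are assumptions made here.

```julia
using DifferentialEquations, ForwardDiff

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
prob = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])
x = [prob.u0; prob.p]

# Jacobian of the first solution component at t = 0, 1, ..., 10 with respect to
# [u0; p]. Extra keywords (e.g. tolerances) are forwarded to `solve`; with none,
# the solver's default tolerances are used.
function jac(; kwargs...)
    g(y) = solve(remake(prob, u0 = y[1:2], p = y[3:end]), Tsit5(); saveat = 1, kwargs...)[1, :]
    return ForwardDiff.jacobian(g, x)
end

reference = jac(reltol = 1e-10, abstol = 1e-10)  # treated as "exact" here
err_default = maximum(abs.(jac() .- reference))
err_tight = maximum(abs.(jac(reltol = 1e-6, abstol = 1e-6) .- reference))

println("max sensitivity error, default tolerances: ", err_default)
println("max sensitivity error, 1e-6 tolerances:    ", err_tight)
```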
Reverse-Mode Automatic Differentiation
The solve function is automatically compatible with AD systems like Zygote.jl, so no extra machinery is needed beyond loading SciMLSensitivity.jl and putting solve inside of a function that is differentiated by Zygote. For example, the following computes the solution to an ODE and the gradient of a loss function (the sum of the ODE's output at each time point with dt=0.1) via the adjoint method:
using Zygote, SciMLSensitivity
# Loss: sum of both solution components over all saved time points (dt = 0.1)
function sum_of_solution(u0,p)
    _prob = remake(prob,u0=u0,p=p)
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1))
end
du01,dp1 = Zygote.gradient(sum_of_solution,u0,p)
([-39.12773752717614, -8.787495434529855], [8.304244027642122, -159.48401961535745, 75.20316229889958, -339.1951631373864])
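Because forward and reverse mode differentiate the same function, a quick consistency check is to compute the gradient of sum_of_solution with ForwardDiff.jl as well and compare. This is a sketch assuming SciMLSensitivity.jl, Zygote.jl and ForwardDiff.jl are installed; the 1e-2 relative tolerance in the comparison is an arbitrary choice for illustration.

```julia
using DifferentialEquations, SciMLSensitivity, Zygote, ForwardDiff

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)

function sum_of_solution(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1))
end

# Reverse mode (adjoint) gradient, as in the tutorial.
du0_rev, dp_rev = Zygote.gradient(sum_of_solution, u0, p)

# Forward mode gradient of the same loss, with u0 and p packed into one vector.
g_fwd = ForwardDiff.gradient(x -> sum_of_solution(x[1:2], x[3:end]), [u0; p])

println("reverse: ", [du0_rev; dp_rev])
println("forward: ", g_fwd)
```

Small differences between the two are expected, since each method approximates the true gradient only up to the solver tolerances.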
Zygote.jl's automatic differentiation system is overloaded to allow SciMLSensitivity.jl to redefine the way the derivatives are computed, allowing trade-offs between numerical stability, memory, and compute performance, similar to how ODE solver algorithms are chosen. The algorithms for the differentiation calculation are subtypes of AbstractSensitivityAlgorithm, or sensealgs for short. They are chosen by passing the sensealg keyword argument to solve.
Let's demonstrate this by choosing the QuadratureAdjoint sensealg for the differentiation of this system:
function sum_of_solution(u0,p)
    _prob = remake(prob,u0=u0,p=p)
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1,sensealg=QuadratureAdjoint()))
end
du01,dp1 = Zygote.gradient(sum_of_solution,u0,p)
([-39.12610324833554, -8.78792570678882], [8.30761040004664, -159.4845962251739, 75.20354296957558, -339.1934967454819])
This computes the derivative of the output with respect to the initial condition and with respect to the parameters, respectively, using QuadratureAdjoint(). For more information on the choice of sensitivity algorithm, see the reference documentation on choosing sensitivity algorithms.
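To see the trade-off in practice, the same gradient can be computed with a few different sensealg choices and compared. This sketch assumes SciMLSensitivity.jl provides InterpolatingAdjoint, QuadratureAdjoint and ForwardDiffSensitivity; the three-argument variant of sum_of_solution is introduced here only so the algorithm can be looped over. The gradients should agree up to solver tolerance; what differs between the methods is memory use and run time.

```julia
using DifferentialEquations, SciMLSensitivity, Zygote

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)

# Same loss as in the tutorial, but with the sensitivity algorithm passed in explicitly.
function sum_of_solution(u0, p, sensealg)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1, sensealg = sensealg))
end

# Gradient with respect to [u0; p] for each sensitivity algorithm, keyed by its type name.
results = Dict{String, Vector{Float64}}()
for alg in (InterpolatingAdjoint(), QuadratureAdjoint(), ForwardDiffSensitivity())
    du0, dp = Zygote.gradient((u0, p) -> sum_of_solution(u0, p, alg), u0, p)
    results[string(nameof(typeof(alg)))] = [du0; dp]
end

for (name, grad) in results
    println(rpad(name, 24), round.(grad; digits = 3))
end
```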
When Should You Use Forward or Reverse Mode?
Good question! The simple answer is: as a rough rule of thumb, if you are differentiating a system of 100 equations or fewer, use forward mode; otherwise, use reverse mode. But it can be a lot more complicated than that! For more information, see the reference documentation on choosing sensitivity algorithms.