Adjoint Sensitivity Analysis of Continuous Functionals

The automatic differentiation tutorial demonstrated how to use AD packages like ForwardDiff.jl and Zygote.jl to compute derivatives of differential equation solutions with respect to initial conditions and parameters. The subsequent direct sensitivity analysis tutorial showed how to directly use the SciMLSensitivity.jl internals to define and solve the augmented differential equation systems which are used in the automatic differentiation process.

While these internal functions give more flexibility, the previous demonstration focused on a case that was also possible via automatic differentiation: discrete cost functionals. A discrete cost functional is a cost that depends on the solution at only a finite number of time points. In the automatic differentiation case, these are the points returned by solve, i.e., those chosen by the saveat option in the solve call. In the direct adjoint sensitivity tooling, they are the time points given in the ts vector.

However, SciMLSensitivity.jl also supports a wider class of cost functionals, continuous cost functionals, which cannot be handled through the automatic differentiation interfaces. Abstractly, a continuous cost functional is a total cost $G$ defined as the integral of an instantaneous cost $g$ over the whole time span. In other words, the total cost is defined as:

\[G(u,p)=G(u(\cdot,p))=\int_{t_{0}}^{T}g(u(t,p),p)dt\]

Notice that this cost cannot be computed exactly from estimates of u at finitely many time points, because it depends on u at every $t$ in $[t_{0},T]$. The purpose of this tutorial is to demonstrate how such cost functionals can be evaluated and differentiated with the direct sensitivity analysis interfaces.
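The limitation of discrete time points can be seen without any ODE machinery. The sketch below uses a hypothetical state with a closed-form trajectory, $u(t)=e^{-t}$ with $g(u)=u^{2}/2$. It compares the exact $G$ with a trapezoidal estimate that, like a discrete cost functional, only sees $u$ at a fixed set of saved points. The functions `uexact` and `trapezoid_cost` are illustrative names, not part of any package.

```julia
# Hypothetical test case with a closed form: u(t) = exp(-t) and g(u) = u^2 / 2,
# so G = ∫_0^T exp(-2t)/2 dt = (1 - exp(-2T)) / 4 exactly.
uexact(t) = exp(-t)
g(u) = u^2 / 2
T = 10.0
G_exact = (1 - exp(-2T)) / 4

# Trapezoidal estimate of G that only uses u at the time points in ts,
# the way a discrete cost functional only sees the saved points.
function trapezoid_cost(ts)
    total = 0.0
    for k in 1:(length(ts) - 1)
        h = ts[k + 1] - ts[k]
        total += h * (g(uexact(ts[k])) + g(uexact(ts[k + 1]))) / 2
    end
    return total
end

coarse = trapezoid_cost(range(0.0, T; length = 6))    # 6 saved points
fine = trapezoid_cost(range(0.0, T; length = 2001))   # 2001 saved points
println("exact = $G_exact, coarse = $coarse, fine = $fine")
```

With six points the estimate is more than double the exact value. It converges only as more points are saved, which is why a continuous functional is better evaluated against the dense solution.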

Example: Continuous Functionals with Forward Sensitivity Analysis via Interpolation

Evaluating continuous cost functionals with forward sensitivity analysis is straightforward, since the solution returned for an ODEForwardSensitivityProblem is a continuous interpolant when dense=true (the default when saveat is not set). For example,

using OrdinaryDiffEq, DiffEqSensitivity  # DiffEqSensitivity.jl is the former name of SciMLSensitivity.jl

function f(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = dy = -p[3]*u[2] + u[1]*u[2]
end

p = [1.5,1.0,3.0]
prob = ODEForwardSensitivityProblem(f,[1.0;1.0],(0.0,10.0),p)
sol = solve(prob,DP8())
retcode: Success
Interpolation: specialized 7th order interpolation
t: 29-element Vector{Float64}:
  0.0
  0.0008156803234081559
  0.005709762263857092
  0.0350742539065507
  0.21126120376271237
  0.7310736576107115
  1.540222712617339
  1.8813609521809873
  2.152579320783659
  2.4063310936458313
  ⋮
  7.06335508324238
  7.725935931520461
  8.248979464799838
  8.558003561400229
  8.826370809858448
  9.171011735077728
  9.493946491887929
  9.834929243565446
 10.0
u: 29-element Vector{Vector{Float64}}:
 [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
 [1.000408588538797, 0.9983701355642965, 0.000816013510644824, 3.3221540105064986e-7, -0.0008153483114736793, -3.320348233838803e-7, 3.3244140163739386e-7, -0.0008143507848015043]
 [1.0028915156848317, 0.9886535575433295, 0.005726241237749595, 1.6146661657253714e-5, -0.00569368529833123, -1.608538141187717e-5, 1.622392391840196e-5, -0.005644946211107334]
 [1.0189121274128703, 0.9325571368196612, 0.035730565614057186, 0.0005806610224735019, -0.03450963303532002, -0.0005673264615114388, 0.0005982168608307346, -0.03270217921302392]
 [1.1547444753248808, 0.6650474487943201, 0.2425299727395315, 0.016200587979236063, -0.19849894573997479, -0.014142604007354115, 0.019606496433572696, -0.13955624899729313]
 [1.9959026411822738, 0.30725094819300397, 1.3911332658550344, 0.12370946057241991, -0.7599570019860618, -0.07996649485520481, 0.22938885567774317, -0.20775306589835107]
 [5.301945679460528, 0.42750381763015505, 6.007173244666822, 1.3784417780305767, -2.2838704251834256, -0.6234622488931644, 1.6987114297214112, -0.37086469132837896]
 [6.87027218261142, 1.271239694305228, 2.4824531036401054, 6.439149047129174, -1.3710515178380684, -2.7849756817446436, 3.27150137715362, -0.4673136955867301]
 [5.679447262940526, 3.3399696666276184, -12.31026712685238, 12.730861084274661, 2.9056310552521434, -6.735837810357075, 2.6735636536872054, 0.8629095700513469]
 [2.883829002749019, 4.5740386280108, -11.501927808807801, 1.5722522609380736, 3.149561013903327, -5.1291264008317885, 0.2620893294221336, 1.5914283310423607]
 ⋮
 [1.4188559781481476, 0.45747115725464815, 4.034573276929158, -1.6292338883363282, -0.023285325083547498, -0.22674038895044946, 0.9917653080111668, -0.6383372364225355]
 [3.1052863771479244, 0.25723901241904107, 12.030706368117613, 0.365728839531434, -0.35946707519700205, -0.15629621407194033, 2.9860723747763775, -0.21350725830194323]
 [5.721950238349628, 0.5147791777176387, 19.32271372300001, 5.150408334846674, -0.8979267459804972, -0.4767888291176948, 5.486826693938301, 0.45258313578635195]
 [6.904564478528068, 1.495971008987011, 0.8111489963245155, 21.257021239671847, -0.8855299499694995, -1.8375562968810901, 3.431317525524077, 3.2107866575616226]
 [5.2514090115128305, 3.677081620267781, -40.30711363477381, 31.383649530201538, 0.415505434372468, -4.82219205478915, -4.808987858125876, 6.2826997734128]
 [1.9684994758096335, 4.224446725401004, -19.720107818957793, -13.79059338129197, 0.6970930457396636, -4.432963103736415, -3.4673015480941123, -1.7679943278194123]
 [1.0719720094857825, 2.526693931636829, -4.294296853936208, -16.729797938117194, 0.340702293847172, -2.2485501722107686, -0.8127688764749763, -3.4279769771844464]
 [0.9574127406054814, 1.26808182822097, 0.6626267557386809, -9.021543589083157, 0.21249817674662475, -1.0143393705751722, 0.214003733351684, -2.2552164416099822]
 [1.0265055472929496, 0.9095251254980055, 2.1626974581623566, -6.256489916604859, 0.18838932135666003, -0.6976152811504686, 0.5638188540587054, -1.7090441865195158]

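gives a continuous solution sol(t) that carries the sensitivities at every time point. Each entry holds the two states followed by their derivatives with respect to p[1], p[2] and p[3] in turn, i.e. [u1, u2, ∂u1/∂p1, ∂u2/∂p1, ∂u1/∂p2, ∂u2/∂p2, ∂u1/∂p3, ∂u2/∂p3]. This can then be used to define a continuous cost function via Integrals.jl. Its derivative, however, would need to be assembled by hand from these extra sensitivity terms.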
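Assembling that derivative by hand amounts to the chain rule under the integral: $\frac{dG}{dp}=\int_{t_{0}}^{T}\left(\frac{\partial g}{\partial u}S(t)+\frac{\partial g}{\partial p}\right)dt$ with $S=\partial u/\partial p$.

The sketch below applies this formula. For concreteness it uses the cost $g(u)=(u_{1}+u_{2})^{2}/2$ from the adjoint example further down, so $\partial g/\partial u=(u_{1}+u_{2})[1,1]$ and $\partial g/\partial p=0$. It assumes the current package name SciMLSensitivity.jl and its `extract_local_sensitivities(sol, t)` helper, which splits `sol(t)` into the state and a vector of $\partial u/\partial p_j$. It also uses QuadGK.jl in place of Integrals.jl. The tolerances are illustrative choices.

```julia
using OrdinaryDiffEq, SciMLSensitivity, QuadGK

# Same Lotka-Volterra model and parameters as in the text
function f(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + u[1] * u[2]
    return nothing
end

p = [1.5, 1.0, 3.0]
prob = ODEForwardSensitivityProblem(f, [1.0, 1.0], (0.0, 10.0), p)
# Tight tolerances so that the interpolated sensitivities are accurate
sol = solve(prob, DP8(); abstol = 1e-10, reltol = 1e-10)

# Integrand of dG/dp for g(u) = (u1 + u2)^2 / 2:
# (∂g/∂u) * ∂u/∂p_j = (u1 + u2) * (∂u1/∂p_j + ∂u2/∂p_j), and ∂g/∂p = 0.
function dGdp_integrand(t)
    u, S = extract_local_sensitivities(sol, t)  # S[j] is ∂u/∂p_j at time t
    s = sum(u)
    return [s * sum(S[j]) for j in eachindex(S)]
end

# Adaptive quadrature of the vector-valued integrand over [0, 10]
dGdp, err = quadgk(dGdp_integrand, 0.0, 10.0; rtol = 1e-8)
println(dGdp)
```

If the forward solve is accurate, `dGdp` should agree with the ForwardDiff/Calculus check at the end of this page to within solver and quadrature tolerances.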

Example: Continuous Adjoints on an Energy Functional

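Continuous adjoints handle continuous functionals more automatically than forward mode. Instead of assembling the derivative by hand, one passes the instantaneous cost and its gradient with respect to u to adjoint_sensitivities. In this case we'd like to calculate the adjoint sensitivity of the scalar energy functional: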

\[G(u,p)=\int_{0}^{T}\frac{\left(\sum_{i=1}^{n}u_{i}(t)\right)^{2}}{2}dt\]

which is:

g(u,p,t) = (sum(u).^2) ./ 2
g (generic function with 1 method)

Notice that the gradient of this function with respect to the state u is:

function dg(out,u,p,t)
  # ∂g/∂u_i = u[1] + u[2] for both components, written in place into out
  out[1] = u[1] + u[2]
  out[2] = u[1] + u[2]
end
dg (generic function with 1 method)
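A hand-written gradient is easy to get wrong, so it can be worth checking dg against automatic differentiation before handing it to the adjoint solver. The sketch below does this with `ForwardDiff.gradient`. It also shows a hypothetical n-state generalization, `dg_general!`, which uses the fact that every component of $\partial g/\partial u$ equals $\sum_i u_i$. The test state `u` is arbitrary.

```julia
using ForwardDiff

# Instantaneous cost from the text: g(u) = (Σ u_i)^2 / 2
g(u, p, t) = (sum(u)^2) / 2

# Hypothetical n-state version of dg: each component of ∂g/∂u equals Σ u_i
function dg_general!(out, u, p, t)
    fill!(out, sum(u))
    return nothing
end

u = [1.3, 0.7, 2.0]                 # arbitrary test state with three components
out = similar(u)
dg_general!(out, u, nothing, 0.0)

# Reference gradient from forward-mode automatic differentiation
ad = ForwardDiff.gradient(v -> g(v, nothing, 0.0), u)
println(out, " ", ad)
```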

To get the adjoint sensitivities, we call:

prob = ODEProblem(f,[1.0;1.0],(0.0,10.0),p)
sol = solve(prob,DP8())
res = adjoint_sensitivities(sol,Vern9(),g,nothing,dg,abstol=1e-8,reltol=1e-8)  # g: cost, nothing: no discrete time points, dg: ∂g/∂u
([-57.43046098128135, -14.286736992838398], [21.070498390881525 -101.36640650577681 63.16461880750126])
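For orientation, the textbook continuous adjoint system for $G=\int_{t_{0}}^{T}g\,dt$ with $u'=f(u,p,t)$ and $u_{0}$ independent of $p$ is written out below. It is the standard derivation, not a description of how SciMLSensitivity.jl organizes its internals. The function dg supplies the term $(\partial g/\partial u)^{T}$ that drives the adjoint $\lambda$ backward from $T$:

\[\frac{d\lambda}{dt}=-\left(\frac{\partial f}{\partial u}\right)^{T}\lambda-\left(\frac{\partial g}{\partial u}\right)^{T},\qquad\lambda(T)=0\]

\[\frac{dG}{du_{0}}=\lambda(t_{0})^{T},\qquad\frac{dG}{dp}=\int_{t_{0}}^{T}\left(\frac{\partial g}{\partial p}+\lambda^{T}\frac{\partial f}{\partial p}\right)dt\]

Here $g$ does not depend on $p$, so $\partial g/\partial p=0$ and only the $\lambda^{T}\partial f/\partial p$ term contributes. These two quantities presumably correspond to the two arrays returned by adjoint_sensitivities above.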

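We can check this against forward-mode automatic differentiation (ForwardDiff.jl) and finite differencing (Calculus.jl) of a direct evaluation of G, which integrates the interpolated solution with QuadGK.jl at tight tolerances. The first array returned above is the gradient with respect to the initial condition, and the second, a 1×3 row, is the gradient with respect to p, so it is the second that should be compared. The two should agree up to the solver and quadrature tolerances used: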

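using QuadGK, ForwardDiff, Calculus
# Direct evaluation of G(p): re-solve with parameters p at tight tolerances,
# then integrate g = (sum(u))^2/2 over [0, 10] with adaptive Gauss-Kronrod quadrature
function G(p)
  tmp_prob = remake(prob,p=p)
  sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14)
  res,err = quadgk((t)-> (sum(sol(t)).^2)./2,0.0,10.0,atol=1e-14,rtol=1e-10)
  res
end
res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0])  # forward-mode AD through the solver
res3 = Calculus.gradient(G,[1.5,1.0,3.0])     # finite differences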
3-element Vector{Float64}:
   21.0514703701335
 -101.40824363829655
   63.1928818469471
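To quantify the agreement, the sketch below computes the componentwise relative error between the adjoint gradient with respect to p and the finite-difference reference. Both vectors are copied from the outputs above. The adjoint run used abstol = reltol = 1e-8 on top of a DP8 forward solve at default tolerances, whereas the reference used 1e-14. Tightening those tolerances would therefore presumably shrink the gap, but that is an expectation rather than something shown here.

```julia
# Gradients with respect to p, copied from the outputs above
adjoint_dGdp = [21.070498390881525, -101.36640650577681, 63.16461880750126]
reference_dGdp = [21.0514703701335, -101.40824363829655, 63.1928818469471]

# Componentwise relative error of the adjoint result against the reference
rel_err = abs.(adjoint_dGdp .- reference_dGdp) ./ abs.(reference_dGdp)
println(rel_err)
println("max relative error = ", maximum(rel_err))
```

All three components agree to better than 0.1%, with the first parameter showing the largest discrepancy.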