Benchmarks
Vs Torchdiffeq 1 million and less ODEs
A raw ODE solver benchmark showcases >30x performance advantage for DifferentialEquations.jl for ODEs ranging in size from 3 to nearly 1 million.
Vs Torchdiffeq on neural ODE training
A training benchmark using the spiral ODE from the original neural ODE paper demonstrates a 100x performance advantage for DiffEqFlux in training neural ODEs.
Vs torchsde on small SDEs
Using the code from torchsde's README we demonstrated a >70,000x performance advantage over torchsde. Further benchmarking is planned but was found to be computationally infeasible for the time being.
A bunch of adjoint choices on neural ODEs
Quick summary:
BacksolveAdjoint
can be the fastest (but use with caution!); about 25% faster- Using
ZygoteVJP
is faster than other vjp choices with FastDense due to the overloads
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, DiffEqSensitivity,
Zygote, BenchmarkTools, Random
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
function trueODEfunc(du, u, p, t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u.^3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
dudt2 = FastChain((x, p) -> x.^3,
FastDense(2, 50, tanh),
FastDense(50, 2))
Random.seed!(100)
p = initial_params(dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
function loss_neuralode(p)
pred = Array(prob_neuralode(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode,p)
# 2.709 ms (56506 allocations: 6.62 MiB)
prob_neuralode_interpolating = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))
function loss_neuralode_interpolating(p)
pred = Array(prob_neuralode_interpolating(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode_interpolating,p)
# 5.501 ms (103835 allocations: 2.57 MiB)
prob_neuralode_interpolating_zygote = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))
function loss_neuralode_interpolating_zygote(p)
pred = Array(prob_neuralode_interpolating_zygote(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode_interpolating_zygote,p)
# 2.899 ms (56150 allocations: 6.61 MiB)
prob_neuralode_backsolve = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP(true)))
function loss_neuralode_backsolve(p)
pred = Array(prob_neuralode_backsolve(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode_backsolve,p)
# 4.871 ms (85855 allocations: 2.20 MiB)
prob_neuralode_quad = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
function loss_neuralode_quad(p)
pred = Array(prob_neuralode_quad(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode_quad,p)
# 11.748 ms (79549 allocations: 3.87 MiB)
prob_neuralode_backsolve_tracker = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, sensealg=BacksolveAdjoint(autojacvec=TrackerVJP()))
function loss_neuralode_backsolve_tracker(p)
pred = Array(prob_neuralode_backsolve_tracker(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode_backsolve_tracker,p)
# 27.604 ms (186143 allocations: 12.22 MiB)
prob_neuralode_backsolve_zygote = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()))
function loss_neuralode_backsolve_zygote(p)
pred = Array(prob_neuralode_backsolve_zygote(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode_backsolve_zygote,p)
# 2.091 ms (49883 allocations: 6.28 MiB)
prob_neuralode_backsolve_false = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, sensealg=BacksolveAdjoint(autojacvec=ReverseDiffVJP(false)))
function loss_neuralode_backsolve_false(p)
pred = Array(prob_neuralode_backsolve_false(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode_backsolve_false,p)
# 4.822 ms (9956 allocations: 1.03 MiB)
prob_neuralode_tracker = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, sensealg=TrackerAdjoint())
function loss_neuralode_tracker(p)
pred = Array(prob_neuralode_tracker(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
@btime Zygote.gradient(loss_neuralode_tracker,p)
# 12.614 ms (76346 allocations: 3.12 MiB)