ContinuousNormalizingFlows.jl


Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia
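The common idea behind continuous normalizing flows is the instantaneous change of variables from the neural ODE literature. If a sample evolves as dz/dt = f(z, t), its log-density evolves as d(log p(z(t)))/dt = -tr(∂f/∂z). The sketch below is a minimal, dependency-free illustration of that formula and is not code from this package. It assumes a hypothetical scalar linear field f(z) = a*z, integrates the state and its log-density together with forward Euler steps, and compares the result with the closed-form answer. The function name `euler_cnf` is illustrative only.

```julia
# Hypothetical illustration (not part of ContinuousNormalizingFlows.jl):
# integrate a 1-D continuous normalizing flow with the linear field f(z) = a*z.
# For this field the trace of the Jacobian is simply a, so
#   dz/dt = a*z   and   d(log p)/dt = -a.

"""
    euler_cnf(z0, logp0, a, T; nsteps = 10_000)

Jointly integrate the state `z` and its log-density `logp` from t = 0 to t = T
with forward Euler steps, using the instantaneous change of variables.
"""
function euler_cnf(z0::Float64, logp0::Float64, a::Float64, T::Float64; nsteps::Int = 10_000)
    dt = T / nsteps
    z, logp = z0, logp0
    for _ in 1:nsteps
        f = a * z            # vector field f(z)
        trJ = a              # tr(∂f/∂z) for the linear field
        z += dt * f          # state update
        logp -= dt * trJ     # log-density update: d(log p)/dt = -tr(∂f/∂z)
    end
    return z, logp
end

# Start from a standard normal log-density evaluated at z0.
z0 = 0.5
logp0 = -0.5 * z0^2 - 0.5 * log(2π)
a, T = 0.3, 2.0

zT, logpT = euler_cnf(z0, logp0, a, T)

# Closed form: z(T) = z0 * exp(a*T) and log p(z(T)) = log p(z0) - a*T.
println("Euler z(T)    = ", zT, "  exact = ", z0 * exp(a * T))
println("Euler logp(T) = ", logpT, "  exact = ", logp0 - a * T)
```

The Euler state only approximates z0 * exp(a*T), with an error that shrinks as `nsteps` grows, while the log-density update is exact here because the trace is constant. Learned flows replace the fixed field with a neural network and an adaptive ODE solver.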

Citing

See CITATION.bib for the relevant reference(s).

Usage

To install this package, run:

using Pkg
Pkg.add("ContinuousNormalizingFlows")

Here is an example of how to use this package:

using ContinuousNormalizingFlows
using Distributions, Lux
# using Flux
# using ForwardDiff, ADTypes
# using CUDA, ComputationalResources

# Parameters
nvars = 1
n = 1024

# Data
data_dist = Beta{Float32}(2.0f0, 4.0f0)
r = rand(data_dist, nvars, n)
r = convert.(Float32, r)

# Model
nn = Lux.Chain(Lux.Dense(nvars => 4 * nvars, tanh), Lux.Dense(4 * nvars => nvars, tanh)) # use Lux
# nn = Flux.Chain(Flux.Dense(nvars => 4 * nvars, tanh), Flux.Dense(4 * nvars => nvars, tanh)) |> FluxCompatLayer # use Flux

icnf = construct(RNODE, nn, nvars; tspan = (0.0f0, 32.0f0)) # process data one by one
# icnf = construct(RNODE, nn, nvars; compute_mode = ZygoteMatrixMode) # process data in batches
# icnf = construct(RNODE, nn, nvars; array_type = CuArray) # process data by GPU

# Train It
using DataFrames, MLJBase
df = DataFrame(transpose(r), :auto)
model = ICNFModel(icnf; n_epochs = 300, batch_size = 32) # use Zygote
# model = ICNFModel(icnf; adtype = AutoForwardDiff()) # use ForwardDiff
# model = ICNFModel(icnf; resource = CUDALibs()) # use GPU
mach = machine(model, df)
fit!(mach)
ps, st = fitted_params(mach)

# Use It
d = ICNFDist(icnf, TestMode(), ps, st) # direct way
# d = ICNFDist(icnf, mach, TestMode()) # alternative way
actual_pdf = pdf.(data_dist, vec(r))
estimated_pdf = pdf(d, r)
new_data = rand(d, n)

# Evaluate It
using Distances
mad_ = meanad(estimated_pdf, actual_pdf)
msd_ = msd(estimated_pdf, actual_pdf)
tv_dis = totalvariation(estimated_pdf, actual_pdf) / n

# Plot It
using Plots
p = plot(x -> pdf(data_dist, x), 0, 1; label = "actual")
p = plot!(p, x -> pdf(d, convert.(Float32, vcat(x))), 0, 1; label = "estimated")
savefig(p, "plot.png")
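The evaluation step relies on Distances.jl. The three numbers reduce to simple element-wise formulas: `meanad` is the mean absolute deviation, `msd` is the mean squared deviation, and `totalvariation` is half the L1 distance, which the example then divides by `n`. The dependency-free sketch below recomputes these by hand so the meaning of `mad_`, `msd_` and `tv_dis` is explicit. The helper names and the input vectors are hypothetical toy values, not results from a trained model.

```julia
# Hypothetical helpers mirroring what the Distances.jl calls compute
# (meanad, msd, totalvariation); for illustration only.

# Mean absolute deviation: mean(|x - y|)
mean_abs_dev(x, y) = sum(abs.(x .- y)) / length(x)

# Mean squared deviation: mean((x - y)^2)
mean_sq_dev(x, y) = sum(abs2.(x .- y)) / length(x)

# Total variation as defined in Distances.jl: sum(|x - y|) / 2
total_variation(x, y) = sum(abs.(x .- y)) / 2

# Toy stand-ins for estimated_pdf and actual_pdf (made-up numbers).
estimated_toy = [1.0, 2.0, 0.5, 1.5]
actual_toy    = [1.5, 1.5, 0.5, 1.0]
len = length(actual_toy)

mad_toy = mean_abs_dev(estimated_toy, actual_toy)           # 0.375
msd_toy = mean_sq_dev(estimated_toy, actual_toy)            # 0.1875
tv_toy  = total_variation(estimated_toy, actual_toy) / len  # 0.1875, normalized like tv_dis
```

Distinct variable names are used so the sketch does not overwrite `n`, `mad_`, `msd_` or `tv_dis` if it is run in the same session as the example.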