struct ADVI{AD} <: VariationalInference{AD}

Automatic Differentiation Variational Inference (ADVI) with automatic differentiation backend AD.


  • samples_per_step::Int64: Number of samples used to estimate the ELBO in each optimization step.

  • max_iters::Int64: Maximum number of gradient steps.
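
As an illustration of what samples_per_step controls, the following is a minimal sketch of a Monte Carlo ELBO estimate for a mean-field Gaussian approximation. It is not the library's implementation: logjoint, gaussian_entropy and elbo_estimate are hypothetical names, and the toy target is an assumption chosen so that the exact ELBO is known.

```julia
using Random

# Hypothetical toy target: unnormalized log density of a standard normal in z.
logjoint(z) = -0.5 * sum(abs2, z)

# Entropy of a diagonal Gaussian with standard deviations σ.
gaussian_entropy(σ) = sum(log.(σ)) + 0.5 * length(σ) * (1 + log(2π))

# Monte Carlo ELBO estimate: average log p(x, z) over draws z ~ q,
# plus the (closed-form) entropy of q.
function elbo_estimate(rng, logjoint, μ, σ, samples_per_step)
    total = 0.0
    for _ in 1:samples_per_step
        z = μ .+ σ .* randn(rng, length(μ))  # reparameterized draw from q
        total += logjoint(z)
    end
    return total / samples_per_step + gaussian_entropy(σ)
end

rng = MersenneTwister(42)
μ, σ = zeros(2), ones(2)
few = elbo_estimate(rng, logjoint, μ, σ, 1)        # noisy, cheap
many = elbo_estimate(rng, logjoint, μ, σ, 10_000)  # close to the exact value log(2π)
```

More samples per step reduce the variance of the estimate at the cost of more evaluations of the log joint in every optimization step.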

DecayedADAGrad(η=0.1, pre=1.0, post=0.9)

Implements a decayed version of AdaGrad. It uses parameter-specific learning rates based on how frequently each parameter is updated.


  • η: learning rate
  • pre: weight of new gradient norm
  • post: weight of history of gradient norms
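
The roles of η, pre and post can be sketched as follows. This is an assumed form of the update written from the parameter descriptions above, not the library's code; DecayedState, decayed_step! and the stabilizing constant ϵ are hypothetical.

```julia
# Hypothetical stand-in for the optimiser state; field names are illustrative.
struct DecayedState
    η::Float64
    pre::Float64
    post::Float64
    acc::Vector{Float64}   # decayed running sum of squared gradients
end

DecayedState(dim; η=0.1, pre=1.0, post=0.9) = DecayedState(η, pre, post, zeros(dim))

# Update the accumulator in place and return the scaled step for gradient g.
function decayed_step!(s::DecayedState, g; ϵ=1e-8)
    @. s.acc = s.post * s.acc + s.pre * g^2   # old history is down-weighted by post
    step = @. s.η * g / (sqrt(s.acc) + ϵ)     # per-parameter scaling
    return step
end

state = DecayedState(2)
step1 = decayed_step!(state, [1.0, 0.0])
step2 = decayed_step!(state, [1.0, 0.0])  # smaller: the accumulator has grown to 1.9
```

With post < 1 the accumulated history decays, so the effective learning rate does not shrink indefinitely as it does in plain AdaGrad.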



Based on the ADAGrad optimiser, whose parameters typically need no tuning.

TruncatedADAGrad(η=0.1, τ=1.0, n=100)

Implements a truncated version of AdaGrad in the sense that only the n previous gradient norms are used to compute the scaling, rather than all previous ones. It uses parameter-specific learning rates based on how frequently each parameter is updated.


  • η: learning rate
  • τ: constant scale factor
  • n: number of previous gradient norms to use in the scaling.
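
The truncation can be sketched with a ring buffer holding the n most recent squared gradients. This is an illustrative sketch under assumptions, not the library's code; TruncatedState and truncated_step! are hypothetical names, and the exact scaling formula is assumed from the parameter descriptions above.

```julia
# Hypothetical state: a ring buffer of the n most recent squared gradients.
mutable struct TruncatedState
    η::Float64
    τ::Float64
    window::Vector{Vector{Float64}}
    iter::Int
end

TruncatedState(dim; η=0.1, τ=1.0, n=100) =
    TruncatedState(η, τ, [zeros(dim) for _ in 1:n], 1)

function truncated_step!(s::TruncatedState, g)
    n = length(s.window)
    idx = mod(s.iter - 1, n) + 1   # slot overwritten this iteration
    s.window[idx] .= g .^ 2        # the oldest entry drops out of the window
    s.iter += 1
    total = sum(s.window)          # elementwise sum over the last n squared gradients
    step = @. s.η * g / (s.τ + sqrt(total))
    return step
end

# With n = 2, only the two most recent gradients influence each step.
state = TruncatedState(1; n=2)
steps = [truncated_step!(state, [g])[1] for g in (3.0, 4.0, 0.0, 1.0)]
```

In the last call the window holds only the squared gradients 0 and 1, so the early large gradients no longer damp the step.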



Based on the ADAGrad optimiser, whose parameters typically need no tuning.

TruncatedADAGrad (Appendix E).

grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)

Computes the gradients used in optimize!. A default implementation is provided for VariationalInference{AD}, where AD is either ForwardDiffAD or TrackerAD. This implicitly also provides a default implementation of optimize!.

Variance reduction techniques, e.g. control variates, should be implemented in this function.

optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())

Iteratively updates parameters by calling grad! and using the given optimizer to compute the steps.
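
The interplay between a gradient function and an optimizer can be sketched as a simple loop. Everything here is a hypothetical stand-in: toy_grad! plays the role of grad!, scaled_step plays the role of the optimizer, and the toy objective and the ascent direction are assumptions made for the example, not a statement about the library's sign convention.

```julia
# Hypothetical gradient of a toy objective f(θ) = -sum((θ .- 3).^2),
# written into `out` in place, standing in for grad!.
toy_grad!(out, θ) = (out .= -2 .* (θ .- 3); out)

# Hypothetical fixed-rate step rule standing in for the optimizer.
scaled_step(g; η=0.1) = η .* g

function toy_optimize!(θ, grad!, step; max_iters=1000)
    out = similar(θ)
    for _ in 1:max_iters
        grad!(out, θ)      # fill `out` with the current gradient estimate
        θ .+= step(out)    # ascent step, since the toy objective is maximized
    end
    return θ
end

θ = toy_optimize!([0.0, 10.0], toy_grad!, scaled_step; max_iters=200)
```

Keeping the gradient computation and the step rule separate is what lets a variance-reduced gradient or a different optimizer be swapped in without touching the loop.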

vi(model, alg::VariationalInference)
vi(model, alg::VariationalInference, q::VariationalPosterior)
vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray)

Constructs the variational posterior from the model and performs the optimization following the configuration of the given VariationalInference instance.


  • model: Turing.Model or Function z ↦ log p(x, z) where x denotes the observations
  • alg: the VI algorithm used
  • q: a VariationalPosterior for which a specialized implementation of the variational objective is assumed to exist
  • getq: function taking parameters θ as input and returning a VariationalPosterior
  • θ: only required if getq is used, in which case it is the initial parameters for the variational posterior
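
The shape of a getq-style function can be sketched as below. This is a hypothetical illustration only: toy_getq and the packing of θ are assumptions, and a NamedTuple stands in for the VariationalPosterior that a real getq would return.

```julia
# Hypothetical getq-style function: θ packs d means followed by d log standard deviations.
function toy_getq(θ::AbstractArray)
    d = length(θ) ÷ 2
    μ = θ[1:d]
    σ = exp.(θ[d+1:2d])     # exponentiate so the standard deviations stay positive
    return (μ = μ, σ = σ)   # stand-in for a VariationalPosterior
end

θ0 = zeros(4)       # initial parameters, as would be passed alongside getq
q0 = toy_getq(θ0)   # means of zero and unit standard deviations
```

Parameterizing the scale through its logarithm keeps θ unconstrained, which is convenient for gradient-based optimization.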