# Riemannian Hamiltonian Variational Autoencoder

The Riemannian Hamiltonian Variational Autoencoder (RHVAE) is a variant of the Hamiltonian Variational Autoencoder (HVAE) that uses concepts from Riemannian geometry to improve sampling in the latent space. Like the HVAE, the RHVAE uses Hamiltonian dynamics to refine samples of the latent variables. However, the RHVAE also accounts for the geometry of the latent space by learning a Riemannian metric tensor that enters the kinetic energy of the dynamical system. This allows the RHVAE to sample the latent space more evenly while learning its curvature.

For the implementation of the RHVAE in `AutoEncoderToolkit.jl`, the `RHVAE` requires two arguments to construct: the original `VAE` and a separate neural network used to compute the metric tensor. To facilitate the dispatch of the functions associated with this second network, we also provide a `MetricChain` struct.

RHVAEs require the computation of nested gradients. This means that the AutoDiff framework must differentiate a function of an already AutoDiff-differentiated function. This is known to be problematic for `Julia`'s AutoDiff backends. See the details below on how we circumvent this problem.

## Reference

Chadebec, C., Mantoux, C. & Allassonnière, S. Geometry-Aware Hamiltonian Variational Auto-Encoder. Preprint at http://arxiv.org/abs/2010.11518 (2020).

## `MetricChain` struct

`AutoEncoderToolkit.RHVAEs.MetricChain` (Type): `MetricChain <: AbstractMetricChain`

A `MetricChain` is used to compute the Riemannian metric tensor in the latent space of a Riemannian Hamiltonian Variational AutoEncoder (RHVAE).

**Fields**

- `mlp::Flux.Chain`: A multi-layer perceptron (MLP) consisting of the hidden layers. The inputs are first run through this MLP.
- `diag::Flux.Dense`: A dense layer that computes the diagonal elements of a lower-triangular matrix. The output of the `mlp` is fed into this layer.
- `lower::Flux.Dense`: A dense layer that computes the off-diagonal elements of the lower-triangular matrix. The output of the `mlp` is also fed into this layer.

The outputs of `diag` and `lower` are used to construct a lower-triangular matrix, which is in turn used to compute the Riemannian metric tensor in latent space.

**Note**

If the dimension of the latent space is `n`, the number of neurons in the output layer of `diag` must be `n`, and the number of neurons in the output layer of `lower` must be `n * (n - 1) ÷ 2`.
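The entry counts in the note follow from the shape of a lower-triangular $n \times n$ matrix: $n$ diagonal entries plus $n(n-1)/2$ strictly-lower entries. The minimal sketch below packs two such output vectors into a matrix using only the `LinearAlgebra` standard library. The helper `build_lower_triangular` is hypothetical, and the column-by-column fill order is an assumption; `MetricChain` may order or transform the entries differently.

```julia
using LinearAlgebra

# Hypothetical helper: assemble an n×n lower-triangular matrix from `n`
# diagonal values and `n * (n - 1) ÷ 2` strictly-lower values.
# The column-by-column fill order is an assumption for illustration.
function build_lower_triangular(diag_vals::AbstractVector, lower_vals::AbstractVector)
    n = length(diag_vals)
    @assert length(lower_vals) == n * (n - 1) ÷ 2 "expected n(n-1)/2 off-diagonal entries"
    L = zeros(eltype(diag_vals), n, n)
    k = 1
    for j in 1:n, i in (j + 1):n   # strictly-lower entries, column by column
        L[i, j] = lower_vals[k]
        k += 1
    end
    for i in 1:n                   # diagonal entries
        L[i, i] = diag_vals[i]
    end
    return LowerTriangular(L)
end

# n = 3 latent dimensions: 3 diagonal and 3 off-diagonal values
L = build_lower_triangular([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
```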

**Example**

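```
using Flux
using AutoEncoderToolkit.RHVAEs: MetricChain

# Latent dimension n = 5
mlp = Flux.Chain(Dense(10, 10, relu), Dense(10, 10, relu))
diag = Flux.Dense(10, 5)    # n = 5 diagonal entries
lower = Flux.Dense(10, 10)  # n * (n - 1) ÷ 2 = 10 off-diagonal entries
metric_chain = MetricChain(mlp, diag, lower)
```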

## `RHVAE` struct

`AutoEncoderToolkit.RHVAEs.RHVAE` (Type)

```
RHVAE{
V<:VAE{<:AbstractVariationalEncoder,<:AbstractVariationalDecoder}
} <: AbstractVariationalAutoEncoder
```

A Riemannian Hamiltonian Variational AutoEncoder (RHVAE) as described in Chadebec, C., Mantoux, C. & Allassonnière, S. Geometry-Aware Hamiltonian Variational Auto-Encoder. Preprint at http://arxiv.org/abs/2010.11518 (2020).

The RHVAE is a type of Variational AutoEncoder (VAE) that incorporates a Riemannian metric in the latent space. This metric is computed by a `MetricChain`, a struct that contains a multi-layer perceptron (MLP) and two dense layers for computing the elements of a lower-triangular matrix.

The inverse metric is computed as follows:

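$$G^{-1}(z) = \sum_{i=1}^{N} L_{\psi_i} L_{\psi_i}^\top \exp\left(-\frac{\lVert z - c_i \rVert_2^2}{T^2}\right) + \lambda I_\ell,$$

where $L_{\psi_i}$ is computed by the `MetricChain`, $T$ is the temperature, $\lambda$ is a regularization factor, $I_\ell$ is the identity matrix with the dimension of the latent space, and the $N$ centroids $c_i$ are the columns of `centroids_latent`.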
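To make the formula concrete, here is a minimal sketch that evaluates $G^{-1}(z)$ for a single latent point, assuming `L` stores one $L_{\psi_i}$ per slice along the third dimension and `centroids` stores one $c_i$ per column, as the field descriptions below suggest. The function name `inverse_metric` is hypothetical; it is not the package's `G_inv`.

```julia
using LinearAlgebra

# Hypothetical helper evaluating
#   G⁻¹(z) = Σᵢ Lᵢ Lᵢᵀ exp(-‖z - cᵢ‖² / T²) + λ I
# `L` is d×d×N (one Lᵢ per centroid); `centroids` is d×N (one cᵢ per column).
function inverse_metric(z::AbstractVector, L::AbstractArray{<:Real,3},
                        centroids::AbstractMatrix, T::Real, λ::Real)
    d, N = size(centroids)
    Ginv = λ .* Matrix{Float64}(I, d, d)                  # regularization term λI
    for i in 1:N
        Lᵢ = L[:, :, i]
        w = exp(-sum(abs2, z .- centroids[:, i]) / T^2)   # Gaussian kernel weight
        Ginv .+= w .* (Lᵢ * Lᵢ')                          # weighted Lᵢ Lᵢᵀ
    end
    return Ginv
end

# One centroid at the origin with Lᵢ = I, T = 1, λ = 0.1
L = reshape(Matrix{Float64}(I, 2, 2), 2, 2, 1)
c = zeros(2, 1)
Ginv_near = inverse_metric([0.0, 0.0], L, c, 1.0, 0.1)
Ginv_far = inverse_metric([100.0, 100.0], L, c, 1.0, 0.1)
```

Far from every centroid the kernel weights vanish and $G^{-1}(z)$ reduces to $\lambda I_\ell$, so the regularization factor controls the metric away from the encoded data.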

**Fields**

- `vae::V`: The underlying VAE, where `V` is a subtype of `VAE` with an `AbstractVariationalEncoder` and an `AbstractVariationalDecoder`.
- `metric_chain::MetricChain`: The `MetricChain` that computes the Riemannian metric in the latent space.
- `centroids_data::AbstractArray`: An array where each slice along the last dimension is a data point $x_i$; the centroids $c_i$ are computed by passing these data points through the encoder.
- `centroids_latent::AbstractMatrix`: A matrix where each column represents a centroid $c_i$ in the inverse metric computation.
- `L::AbstractArray{<:Number, 3}`: A 3D array where each slice represents an $L_{\psi_i}$ matrix. Intuitively, $L_{\psi_i}$ can be seen as the triangular matrix in the Cholesky decomposition of $G^{-1}(c_i)$, up to a regularization factor.
- `M::AbstractArray{<:Number, 3}`: A 3D array where each slice represents the product $L_{\psi_i} L_{\psi_i}^\top$.
- `T::Number`: The temperature parameter in the inverse metric computation.
- `λ::Number`: The regularization factor in the inverse metric computation.

## Forward pass

### Metric Network

`AutoEncoderToolkit.RHVAEs.MetricChain` (Method): `(m::MetricChain)(x::AbstractArray; matrix::Bool=false)`

Perform a forward pass through the MetricChain.

**Arguments**

- `x::AbstractArray`: The input data to be processed.
- `matrix::Bool=false`: A boolean flag indicating whether to return the result as a lower-triangular matrix (if `true`) or as a `NamedTuple` of diagonal and lower off-diagonal elements (if `false`). Defaults to `false`.

**Returns**

- If `matrix` is `true`, returns a lower-triangular matrix constructed from the outputs of the `diag` and `lower` components of the MetricChain.
- If `matrix` is `false`, returns a `NamedTuple` with two elements: `diag`, the output of the `diag` component of the MetricChain, and `lower`, the output of the `lower` component of the MetricChain.

**Example**

```
m = MetricChain(...)
x = rand(Float32, 100, 10)
m(x, matrix=true) # Returns a lower triangular matrix
```

### RHVAE

`AutoEncoderToolkit.RHVAEs.RHVAE` (Method)

```
(rhvae::RHVAE{VAE{E,D}})(
x::AbstractArray;
ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
K::Int=3,
βₒ::Number=0.3f0,
∇H::Function=∇hamiltonian_TaylorDiff,
∇H_kwargs::Union{NamedTuple,Dict}=(
reconstruction_loglikelihood=decoder_loglikelihood,
position_logprior=spherical_logprior,
momentum_logprior=riemannian_logprior,
G_inv=G_inv,
),
tempering_schedule::Function=quadratic_tempering,
latent::Bool=false,
) where {E<:AbstractGaussianLogEncoder,D<:AbstractVariationalDecoder}
```

Run the Riemannian Hamiltonian Variational Autoencoder (RHVAE) on the given input.

**Arguments**

- `x::AbstractArray`: The input to the RHVAE. If it is a vector, it represents a single data point. If it is a higher-dimensional array, the last dimension must index the data points.

**Optional Keyword Arguments**

- `K::Int=3`: The number of leapfrog steps to perform in the Hamiltonian Monte Carlo (HMC) part of the RHVAE.
- `ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4)`: The step size for the leapfrog steps in the HMC part of the RHVAE. If it is a scalar, the same step size is used for all dimensions. If it is an array, each element corresponds to the step size for a specific dimension.
- `βₒ::Number=0.3f0`: The initial inverse temperature for the tempering schedule.
- `steps::Int`: The number of fixed-point iterations to perform. Default is 3.
- `∇H::Function=∇hamiltonian_finite`: The function to compute the gradient of the Hamiltonian in the HMC part of the RHVAE.
- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments to be passed to the `∇H` function. Default is a NamedTuple with `reconstruction_loglikelihood`, `position_logprior`, `momentum_logprior`, and `G_inv`.
- `G_inv::Function=G_inv`: The function to compute the inverse of the Riemannian metric tensor.
- `tempering_schedule::Function=quadratic_tempering`: The function to compute the tempering schedule in the RHVAE.
- `latent::Bool=false`: If `true`, the function returns a NamedTuple containing the outputs of the encoder and decoder, and the final state of the phase space after the leapfrog and tempering steps. If `false`, the function only returns the output of the decoder.
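The `βₒ` and `tempering_schedule` arguments work together: the schedule maps the initial inverse temperature to one inverse temperature per leapfrog step. As an illustration only, the sketch below implements a quadratic schedule of the kind used for tempered HVAEs, in which $\sqrt{\beta_k}$ moves from $\sqrt{\beta_0}$ at $k = 0$ to 1 at $k = K$. Whether `quadratic_tempering` uses exactly this formula is an assumption, and `quadratic_tempering_sketch` is a hypothetical name.

```julia
# Hypothetical quadratic tempering schedule (assumed form):
#   √βₖ = [(1 - 1/√β₀) k²/K² + 1/√β₀]⁻¹,  k = 0, …, K
# It returns βₖ, which starts at β₀ and ends at 1.
function quadratic_tempering_sketch(βₒ::Real, k::Integer, K::Integer)
    inv_sqrt_β0 = 1 / sqrt(βₒ)
    sqrt_βk = 1 / ((1 - inv_sqrt_β0) * k^2 / K^2 + inv_sqrt_β0)
    return sqrt_βk^2
end

# Inverse temperatures for K = 3 leapfrog steps starting from βₒ = 0.3
βs = [quadratic_tempering_sketch(0.3, k, 3) for k in 0:3]
```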

**Returns**

If `latent=true`, the function returns a NamedTuple with the following fields:

- `encoder`: The outputs of the encoder.
- `decoder`: The output of the decoder.
- `phase_space`: The final state of the phase space after the leapfrog and tempering steps.

If `latent=false`, the function only returns the output of the decoder.

**Description**

This function runs the RHVAE on the given input. It first passes the input through the encoder to obtain the mean and log standard deviation of the latent distribution. It then uses the reparameterization trick to sample from the latent space. After that, it performs the leapfrog and tempering steps to refine the sample from the latent space. Finally, it passes the refined sample through the decoder to obtain the output.
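The reparameterization step in this description can be sketched in a few lines. This assumes a diagonal Gaussian encoder returning the mean `μ` and log standard deviation `logσ` for one data point; `reparameterize` is a hypothetical helper, not part of the package API.

```julia
using Random

# Reparameterization trick for a diagonal Gaussian:
#   z = μ + exp(logσ) ⊙ ε,  ε ~ N(0, I)
# The randomness lives in ε, so z is a differentiable function of μ and logσ.
function reparameterize(μ::AbstractVector{<:AbstractFloat}, logσ::AbstractVector{<:AbstractFloat};
                        rng::AbstractRNG=Random.default_rng())
    ε = randn(rng, eltype(μ), length(μ))
    return μ .+ exp.(logσ) .* ε
end

# With a tiny standard deviation the sample collapses onto the mean
z = reparameterize([0.0, 1.0], [-20.0, -20.0]; rng=MersenneTwister(42))
```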

**Notes**

Ensure that the dimensions of `x` match the input dimensions of the RHVAE, and that the dimensions of `ϵ` match the dimensions of the latent space.
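The leapfrog refinement presumably relies on the generalized leapfrog integrator for position-dependent metrics, whose two implicit updates are resolved by fixed-point iteration (compare the `steps` keyword listed above). The sketch below shows one such step for generic gradient functions `∇z_H` and `∇ρ_H`. All names are hypothetical, the package's batched implementation may differ, and the toy usage takes a separable Euclidean Hamiltonian, for which the scheme reduces to standard leapfrog.

```julia
# Sketch of one generalized leapfrog step with fixed-point iterations.
# `∇z_H(z, ρ)` and `∇ρ_H(z, ρ)` return gradients of H; `steps` sets the
# number of fixed-point iterations for the two implicit updates.
function generalized_leapfrog(z, ρ, ∇z_H, ∇ρ_H, ϵ; steps::Int=3)
    # 1) implicit half step in momentum: ρ½ = ρ - ϵ/2 ∇z H(z, ρ½)
    ρ_half = copy(ρ)
    for _ in 1:steps
        ρ_half = ρ .- (ϵ / 2) .* ∇z_H(z, ρ_half)
    end
    # 2) implicit full step in position:
    #    z′ = z + ϵ/2 [∇ρ H(z, ρ½) + ∇ρ H(z′, ρ½)]
    g0 = ∇ρ_H(z, ρ_half)
    z_new = copy(z)
    for _ in 1:steps
        z_new = z .+ (ϵ / 2) .* (g0 .+ ∇ρ_H(z_new, ρ_half))
    end
    # 3) explicit half step in momentum
    ρ_new = ρ_half .- (ϵ / 2) .* ∇z_H(z_new, ρ_half)
    return z_new, ρ_new
end

# Toy separable Hamiltonian H = ‖z‖²/2 + ‖ρ‖²/2
∇z_H(z, ρ) = z
∇ρ_H(z, ρ) = ρ
z1, ρ1 = generalized_leapfrog([1.0], [0.0], ∇z_H, ∇ρ_H, 0.1)
```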

## Loss function

`AutoEncoderToolkit.RHVAEs.loss` (Function)

```
loss(
rhvae::RHVAE,
x::AbstractArray;
K::Int=3,
ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
βₒ::Number=0.3f0,
steps::Int=3,
∇H_kwargs::Union{NamedTuple,Dict}=(
reconstruction_loglikelihood=decoder_loglikelihood,
position_logprior=spherical_logprior,
momentum_logprior=riemannian_logprior,
),
G_inv::Function=G_inv,
tempering_schedule::Function=quadratic_tempering,
reg_function::Union{Function,Nothing}=nothing,
reg_kwargs::Union{NamedTuple,Dict}=Dict(),
reg_strength::Number=1.0f0,
logp_prefactor::AbstractArray=ones(Float32, 3),
logq_prefactor::AbstractArray=ones(Float32, 3),
)
```

Compute the loss for a Riemannian Hamiltonian Variational Autoencoder (RHVAE).

**Arguments**

- `rhvae::RHVAE`: The RHVAE used to encode the input data and decode the latent space.
- `x::AbstractArray`: Input data to the RHVAE encoder. The last dimension indexes the samples in the batch.

**Optional Keyword Arguments**

- `K::Int`: The number of HMC steps (default is 3).
- `ϵ::Union{<:Number,<:AbstractVector}`: The step size for the leapfrog integrator (default is `Float32(1E-4)`).
- `βₒ::Number`: The initial inverse temperature (default is 0.3).
- `steps::Int`: The number of fixed-point iterations in each generalized leapfrog step (default is 3).
- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments to be passed to the function that computes the gradient of the Hamiltonian.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor (default is `G_inv`).
- `tempering_schedule::Function`: The tempering schedule function used in the HMC (default is `quadratic_tempering`).
- `reg_function::Union{Function, Nothing}=nothing`: A function that computes the regularization term based on the VAE outputs. This function must take as input the VAE outputs and the keyword arguments provided in `reg_kwargs`.
- `reg_kwargs::Union{NamedTuple,Dict}=Dict()`: Keyword arguments to pass to the regularization function.
- `reg_strength::Number=1.0f0`: The strength of the regularization term.
- `logp_prefactor::AbstractArray`: A 3-element array to scale the log likelihood, log prior of the latent variables, and log prior of the momentum variables. Default is an array of ones.
- `logq_prefactor::AbstractArray`: A 3-element array to scale the log posterior of the initial latent variables, log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

**Returns**

- The computed loss.
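The prefactor arrays weight individual terms of the loss. The sketch below shows one plausible way such weights could combine precomputed log-density terms into a negative ELBO plus a regularization penalty. It assumes the three `logp` terms enter with a positive sign and the three `logq` terms with a negative sign; the helper `weighted_loss` is hypothetical, and the package's actual term definitions and sign conventions may differ.

```julia
using LinearAlgebra

# Hypothetical weighting of ELBO terms by the two 3-element prefactor arrays.
#   logp_terms = [log p(x|z_K), log p(z_K), log p(ρ_K)]
#   logq_terms = [log q(z_0|x), log p(ρ_0), tempering term]
function weighted_loss(logp_terms::AbstractVector, logq_terms::AbstractVector;
                       logp_prefactor=ones(3), logq_prefactor=ones(3),
                       reg=0.0, reg_strength=1.0)
    elbo = dot(logp_prefactor, logp_terms) - dot(logq_prefactor, logq_terms)
    return -elbo + reg_strength * reg   # minimize the negative ELBO plus penalty
end

loss_all = weighted_loss([-1.0, -2.0, -3.0], [-0.5, -0.5, 0.0])
# Zeroing the prior prefactors keeps only the reconstruction term on the logp side
loss_recon = weighted_loss([-1.0, -2.0, -3.0], [-0.5, -0.5, 0.0]; logp_prefactor=[1.0, 0.0, 0.0])
```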

```
loss(
rhvae::RHVAE,
x_in::AbstractArray,
x_out::AbstractArray;
K::Int=3,
ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
βₒ::Number=0.3f0,
steps::Int=3,
∇H_kwargs::Union{NamedTuple,Dict}=(
reconstruction_loglikelihood=decoder_loglikelihood,
position_logprior=spherical_logprior,
momentum_logprior=riemannian_logprior,
),
G_inv::Function=G_inv,
tempering_schedule::Function=quadratic_tempering,
reg_function::Union{Function,Nothing}=nothing,
reg_kwargs::Union{NamedTuple,Dict}=Dict(),
reg_strength::Number=1.0f0,
logp_prefactor::AbstractArray=ones(Float32, 3),
logq_prefactor::AbstractArray=ones(Float32, 3),
)
```

Compute the loss for a Riemannian Hamiltonian Variational Autoencoder (RHVAE).

**Arguments**

- `rhvae::RHVAE`: The RHVAE used to encode the input data and decode the latent space.
- `x_in::AbstractArray`: Input data to the RHVAE encoder. The last dimension indexes the samples in the batch.
- `x_out::AbstractArray`: Target data to compute the reconstruction error. The last dimension indexes the samples in the batch.

**Optional Keyword Arguments**

- `K::Int`: The number of HMC steps (default is 3).
- `ϵ::Union{<:Number,<:AbstractVector}`: The step size for the leapfrog integrator (default is `Float32(1E-4)`).
- `βₒ::Number`: The initial inverse temperature (default is 0.3).
- `steps::Int`: The number of fixed-point iterations in each generalized leapfrog step (default is 3).
- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments to be passed to the function that computes the gradient of the Hamiltonian.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor (default is `G_inv`).
- `tempering_schedule::Function`: The tempering schedule function used in the HMC (default is `quadratic_tempering`).
- `reg_function::Union{Function, Nothing}=nothing`: A function that computes the regularization term based on the VAE outputs. This function must take as input the VAE outputs and the keyword arguments provided in `reg_kwargs`.
- `reg_kwargs::Union{NamedTuple,Dict}=Dict()`: Keyword arguments to pass to the regularization function.
- `reg_strength::Number=1.0f0`: The strength of the regularization term.
- `logp_prefactor::AbstractArray`: A 3-element array to scale the log likelihood, log prior of the latent variables, and log prior of the momentum variables. Default is an array of ones.
- `logq_prefactor::AbstractArray`: A 3-element array to scale the log posterior of the initial latent variables, log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

**Returns**

- The computed loss.

## Training

`AutoEncoderToolkit.RHVAEs.train!` (Function)

```
train!(
rhvae::RHVAE,
x::AbstractArray,
opt::NamedTuple;
loss_function::Function=loss,
loss_kwargs::Union{NamedTuple,Dict}=Dict(),
verbose::Bool=false,
loss_return::Bool=false,
)
```

Customized training function to update parameters of a Riemannian Hamiltonian Variational Autoencoder given a specified loss function.

**Arguments**

- `rhvae::RHVAE`: A struct containing the elements of a Riemannian Hamiltonian Variational Autoencoder.
- `x::AbstractArray`: Input data to the RHVAE encoder. The last dimension indexes the samples in the batch.
- `opt::NamedTuple`: State of the optimizer for updating parameters. Typically initialized using `Flux.setup`.

**Optional Keyword Arguments**

- `loss_function::Function=loss`: The loss function used for training. It should accept the RHVAE model, data `x`, and keyword arguments, in that order.
- `loss_kwargs::Union{NamedTuple,Dict}=Dict()`: Arguments for the loss function. These might include parameters like `K`, `ϵ`, `βₒ`, `steps`, `∇H`, `∇H_kwargs`, `tempering_schedule`, `reg_function`, `reg_kwargs`, and `reg_strength`, depending on the specific loss function in use.
- `verbose::Bool=false`: Whether to print the loss at each iteration.
- `loss_return::Bool=false`: Whether to return the loss at each iteration.

**Description**

Trains the RHVAE by:

- Computing the gradient of the loss with respect to the RHVAE parameters.
- Updating the RHVAE parameters using the optimizer.
- Updating the metric parameters.

```
train!(
rhvae::RHVAE,
x_in::AbstractArray,
x_out::AbstractArray,
opt::NamedTuple;
loss_function::Function=loss,
loss_kwargs::Union{NamedTuple,Dict}=Dict(),
verbose::Bool=false,
loss_return::Bool=false,
)
```

Customized training function to update parameters of a Riemannian Hamiltonian Variational Autoencoder given a specified loss function.

**Arguments**

- `rhvae::RHVAE`: A struct containing the elements of a Riemannian Hamiltonian Variational Autoencoder.
- `x_in::AbstractArray`: Input data to the RHVAE encoder. The last dimension indexes the samples in the batch.
- `x_out::AbstractArray`: Target data to compute the reconstruction error. The last dimension indexes the samples in the batch.
- `opt::NamedTuple`: State of the optimizer for updating parameters. Typically initialized using `Flux.setup`.

**Optional Keyword Arguments**

- `loss_function::Function=loss`: The loss function used for training. It should accept the RHVAE model, data `x_in` and `x_out`, and keyword arguments, in that order.
- `loss_kwargs::Union{NamedTuple,Dict}=Dict()`: Arguments for the loss function. These might include parameters like `K`, `ϵ`, `βₒ`, `steps`, `∇H`, `∇H_kwargs`, `tempering_schedule`, `reg_function`, `reg_kwargs`, and `reg_strength`, depending on the specific loss function in use.
- `verbose::Bool=false`: Whether to print the loss at each iteration.
- `loss_return::Bool=false`: Whether to return the loss at each iteration.

**Description**

Trains the RHVAE by:

- Computing the gradient of the loss with respect to the RHVAE parameters.
- Updating the RHVAE parameters using the optimizer.
- Updating the metric parameters.

## Computing the gradient of the potential energy

One of the crucial components in training the RHVAE is the computation of the gradient of the Hamiltonian $\nabla H$ with respect to the latent space representation. This gradient is used in the leapfrog steps of the generalized Hamiltonian dynamics. When training the RHVAE, we need to backpropagate through the leapfrog steps to update the parameters of the neural network. This requires computing the gradient of a function of the gradient of the Hamiltonian, i.e., nested gradients. `Zygote.jl`, the main AutoDiff backend in `Flux.jl`, famously struggles with this type of computation. Specifically, `Zygote.jl` does not support `Zygote`-over-`Zygote` differentiation (differentiating with `Zygote` a function of something previously differentiated with `Zygote`) or `Zygote`-over-`ForwardDiff` differentiation (differentiating with `Zygote` a function of something differentiated with `ForwardDiff`).

This leaves us with two options to compute the gradient of the potential energy:

- Use finite differences to approximate the gradient of the potential energy.
- Use the relatively new `TaylorDiff.jl` AutoDiff backend to compute the gradient of the potential energy. This backend is composable with `Zygote.jl`, so we can, in principle, do `Zygote`-over-`TaylorDiff` differentiation.
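To see why nested gradients appear, consider a toy potential $U(z; \theta) = \theta z^2 / 2$, with a parameter $\theta$ standing in for network weights. One gradient step moves $z$ to $z - \epsilon \nabla_z U$, and training needs the derivative of that new position with respect to $\theta$, i.e., a derivative of a gradient. The sketch below, which uses only Base Julia, computes both levels by central finite differences and can be checked against the analytic value $-\epsilon z$. In the package, the outer derivative would presumably be taken by the AutoDiff backend during backpropagation; all names here are hypothetical.

```julia
# Toy potential depending on a "network parameter" θ
U(z, θ) = θ * z^2 / 2

# Inner gradient ∂U/∂z by central finite differences: only evaluations of U,
# so an outer differentiation pass sees ordinary arithmetic.
∇U_fd(z, θ; h=1e-5) = (U(z + h, θ) - U(z - h, θ)) / (2h)

# Position after one explicit gradient step of size ϵ
leap_position(z, θ; ϵ=0.1) = z - ϵ * ∇U_fd(z, θ)

# Outer derivative of the new position with respect to θ, again by central differences
dpos_dθ(z, θ; h=1e-4) = (leap_position(z, θ + h) - leap_position(z, θ - h)) / (2h)

# Analytically: leap_position = z - ϵθz, so the derivative is -ϵz = -0.2 here
nested = dpos_dθ(2.0, 1.5)
```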

The second option would be preferred, as the gradients computed with `TaylorDiff` are much more accurate than the ones computed with finite differences. However, there are two problems with this approach:

- The `TaylorDiff` nested gradient capability stopped working with `Julia ≥ 1.10`, as discussed in #70.
- Even for `Julia < 1.10`, we could not get `TaylorDiff` to work on `CUDA` devices. (PRs are welcome!)

With these limitations in mind, we have implemented the gradient of the potential using both finite differences and `TaylorDiff`. The user can choose which method to use by setting the `adtype` keyword argument in the `∇H_kwargs` of the `loss` function to either `:finite` or `:TaylorDiff`. This means that for the `train!` function, the user can pass `loss_kwargs` that look like this:

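```
# Define the autodiff backend used for the gradient of the Hamiltonian
loss_kwargs = Dict(
:∇H_kwargs => Dict(
:adtype => :finite  # or :TaylorDiff
)
)

# Pass it to `train!` (assumes `rhvae`, `x`, and `opt_state` already exist)
AutoEncoderToolkit.RHVAEs.train!(rhvae, x, opt_state; loss_kwargs=loss_kwargs)
```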

Although verbose, the nested dictionaries help to keep everything organized. (PRs with better design ideas are welcome!)

The default for both `cpu` and `gpu` devices is `:finite`.

`AutoEncoderToolkit.RHVAEs.∇hamiltonian_finite` (Function)

```
∇hamiltonian_finite(
x::AbstractArray,
z::AbstractVecOrMat,
ρ::AbstractVecOrMat,
G⁻¹::AbstractArray,
logdetG::Union{<:Number,AbstractVector},
decoder::AbstractVariationalDecoder,
decoder_output::NamedTuple,
var::Symbol;
reconstruction_loglikelihood::Function=decoder_loglikelihood,
position_logprior::Function=spherical_logprior,
momentum_logprior::Function=riemannian_logprior,
fdtype::Symbol=:central,
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using a naive finite difference method.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse of the Riemannian metric tensor `G⁻¹` and its log determinant `logdetG`, a `decoder` of type `AbstractVariationalDecoder`, a `decoder_output` NamedTuple, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using a simple finite-difference method. The computation is based on the log-likelihood of the decoder, the log-prior of the latent space, and `G⁻¹`.

The Hamiltonian is computed as follows:

$$H_x(z, \rho) = U_x(z) + \kappa(\rho),$$

where $U_x(z)$ is the potential energy and $\kappa(\rho)$ is the kinetic energy. The potential energy is defined as

$$U_x(z) = -\log p(x \mid z) - \log p(z),$$

where $\log p(x \mid z)$ is the log-likelihood of the decoder and $\log p(z)$ is the log-prior in latent space. The kinetic energy is defined as

$$\kappa(\rho) = \frac{1}{2} \log\left((2\pi)^D \det G(z)\right) + \frac{1}{2} \rho^\top G^{-1} \rho,$$

where $D$ is the dimension of the latent space and $G(z)$ is the metric tensor at the point `z`.
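The three formulas above can be combined into a single evaluation. The sketch below assumes a standard normal prior $p(z) = \mathcal{N}(0, I)$ as a stand-in for `spherical_logprior` and takes the decoder log-likelihood as an already computed number; `hamiltonian_sketch` is a hypothetical name. It uses $\log\det G = -\log\det G^{-1}$, so only the inverse metric is needed.

```julia
using LinearAlgebra

# Hypothetical helper evaluating Hₓ(z, ρ) = Uₓ(z) + κ(ρ), with
#   Uₓ(z) = -log p(x|z) - log p(z)               (p(z) = N(0, I) assumed)
#   κ(ρ)  = ½ log((2π)ᴰ det G(z)) + ½ ρᵀ G⁻¹ ρ
# `loglik` is log p(x|z), already evaluated at z.
function hamiltonian_sketch(loglik::Real, z::AbstractVector, ρ::AbstractVector,
                            Ginv::AbstractMatrix)
    D = length(z)
    logprior = -0.5 * (D * log(2π) + sum(abs2, z))   # log N(z | 0, I)
    U = -loglik - logprior                            # potential energy
    logdetG = -logdet(Ginv)                           # log det G = -log det G⁻¹
    κ = 0.5 * (D * log(2π) + logdetG) + 0.5 * dot(ρ, Ginv * ρ)
    return U + κ
end

H0 = hamiltonian_sketch(0.0, zeros(2), zeros(2), Matrix{Float64}(I, 2, 2))
```

Written this way, $\kappa(\rho)$ is exactly $-\log \mathcal{N}(\rho \mid 0, G(z))$, which is why the momentum term is handled by a log-prior function (`momentum_logprior`).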

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If a matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If a matrix, each column represents a momentum vector.
- `G⁻¹::AbstractArray`: The inverse of the Riemannian metric tensor. If a 3D array, each slice along the third dimension represents the inverse of the metric tensor at the corresponding column of `z`.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor. If a vector, each element represents the log determinant of the metric tensor at the corresponding column of `z`.
- `decoder::AbstractVariationalDecoder`: The decoder instance.
- `decoder_output::NamedTuple`: The output of the decoder.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and `G⁻¹`.
- `fdtype::Symbol=:central`: The type of finite difference method to use. Must be `:central` or `:forward`. Default is `:central`.

**Returns**

A vector representing the gradient of the Hamiltonian at the point `(z, ρ)` with respect to variable `var`.
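For reference, a naive finite-difference gradient of a scalar function looks like the sketch below: `:central` uses two evaluations per coordinate with $O(h^2)$ truncation error, while `:forward` reuses $f(v)$ and needs one extra evaluation per coordinate with $O(h)$ error. The helper `fd_gradient` and its step size `h` are assumptions for illustration, not the package's internals.

```julia
# Hypothetical naive finite-difference gradient of a scalar function `f` at `v`.
function fd_gradient(f, v::AbstractVector{<:Real}; fdtype::Symbol=:central, h::Real=1e-6)
    fdtype in (:central, :forward) || throw(ArgumentError("fdtype must be :central or :forward"))
    x = float.(v)
    g = similar(x)
    f0 = fdtype == :forward ? f(x) : zero(eltype(x))   # reused by every forward difference
    for i in eachindex(x)
        e = zero(x)
        e[i] = h                                        # perturb one coordinate
        g[i] = fdtype == :central ? (f(x .+ e) - f(x .- e)) / (2h) : (f(x .+ e) - f0) / h
    end
    return g
end

# Gradient of ‖v‖²/2 is v itself
g_central = fd_gradient(v -> sum(abs2, v) / 2, [1.0, -2.0])
g_forward = fd_gradient(v -> sum(abs2, v) / 2, [1.0, -2.0]; fdtype=:forward)
```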

```
∇hamiltonian_finite(
x::AbstractArray,
z::AbstractVecOrMat,
ρ::AbstractVecOrMat,
rhvae::RHVAE,
var::Symbol;
reconstruction_loglikelihood::Function=decoder_loglikelihood,
position_logprior::Function=spherical_logprior,
momentum_logprior::Function=riemannian_logprior,
G_inv::Function=G_inv,
fdtype::Symbol=:central,
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using a naive finite difference method.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, an instance of `RHVAE`, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using a simple finite-difference method. The computation is based on the log-likelihood of the decoder, the log-prior of the latent space, and the inverse of the metric tensor G at the point `z`.

The Hamiltonian is computed as follows:

$$H_x(z, \rho) = U_x(z) + \kappa(\rho),$$

where $U_x(z)$ is the potential energy and $\kappa(\rho)$ is the kinetic energy. The potential energy is defined as

$$U_x(z) = -\log p(x \mid z) - \log p(z),$$

where $\log p(x \mid z)$ is the log-likelihood of the decoder and $\log p(z)$ is the log-prior in latent space. The kinetic energy is defined as

$$\kappa(\rho) = \frac{1}{2} \log\left((2\pi)^D \det G(z)\right) + \frac{1}{2} \rho^\top G^{-1} \rho,$$

where $D$ is the dimension of the latent space and $G(z)$ is the metric tensor at the point `z`.

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If a matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If a matrix, each column represents a momentum vector.
- `rhvae::RHVAE`: An instance of the RHVAE model.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and the inverse of the Riemannian metric tensor `G⁻¹`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`. This function must take as input the point `z` in the latent space and the `rhvae` instance.
- `fdtype::Symbol=:central`: The type of finite difference method to use. Must be `:central` or `:forward`. Default is `:central`.

**Returns**

A vector representing the gradient of the Hamiltonian at the point `(z, ρ)` with respect to variable `var`.

**Note**

The inverse of the Riemannian metric tensor `G⁻¹`, the log determinant of the metric tensor, and the output of the decoder are computed internally in this function. The user does not need to provide these as inputs.

`AutoEncoderToolkit.RHVAEs.∇hamiltonian_TaylorDiff` (Function)

```
∇hamiltonian_TaylorDiff(
x::AbstractArray,
z::AbstractVector,
ρ::AbstractVector,
G⁻¹::AbstractMatrix,
logdetG::Union{<:Number,AbstractVector},
decoder::AbstractVariationalDecoder,
decoder_output::NamedTuple,
var::Symbol;
reconstruction_loglikelihood::Function=decoder_loglikelihood,
position_logprior::Function=spherical_logprior,
momentum_logprior::Function=riemannian_logprior,
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using the TaylorDiff.jl automatic differentiation library.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse metric `G⁻¹` and its log determinant `logdetG`, a `decoder` of type `AbstractVariationalDecoder` with its `decoder_output`, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using TaylorDiff.jl.

The Hamiltonian is computed as follows:

$$H_x(z, \rho) = U_x(z) + \kappa(\rho),$$

where $U_x(z)$ is the potential energy and $\kappa(\rho)$ is the kinetic energy. The potential energy is defined as

$$U_x(z) = -\log p(x \mid z) - \log p(z),$$

where $\log p(x \mid z)$ is the log-likelihood of the decoder and $\log p(z)$ is the log-prior in latent space. The kinetic energy is defined as

$$\kappa(\rho) = \frac{1}{2} \log\left((2\pi)^D \det G(z)\right) + \frac{1}{2} \rho^\top G^{-1} \rho,$$

where $D$ is the dimension of the latent space and $G(z)$ is the metric tensor at the point `z`.

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVector`: The point in the latent space.
- `ρ::AbstractVector`: The momentum.
- `G⁻¹::AbstractMatrix`: The inverse of the Riemannian metric tensor.
- `logdetG::Number`: The logarithm of the determinant of the Riemannian metric tensor.
- `decoder::AbstractVariationalDecoder`: An instance of the decoder model.
- `decoder_output::NamedTuple`: The output of the decoder model.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and the inverse of the Riemannian metric tensor `G⁻¹`.

**Returns**

A vector representing the gradient of the Hamiltonian at the point `(z, ρ)` with respect to variable `var`.

**Note**

`TaylorDiff.jl` is composable with `Zygote.jl`. Thus, for backpropagation through this function, one should use `Zygote.jl`.

```
∇hamiltonian_TaylorDiff(
x::AbstractArray,
z::AbstractVecOrMat,
ρ::AbstractVecOrMat,
rhvae::RHVAE,
var::Symbol;
reconstruction_loglikelihood::Function=decoder_loglikelihood,
position_logprior::Function=spherical_logprior,
momentum_logprior::Function=riemannian_logprior,
G_inv::Function=G_inv,
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using the TaylorDiff.jl automatic differentiation library.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, an instance of `RHVAE`, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using TaylorDiff.jl.

The Hamiltonian is computed as follows:

$$H_x(z, \rho) = U_x(z) + \kappa(\rho),$$

$$U_x(z) = -\log p(x \mid z) - \log p(z),$$

$$\kappa(\rho) = \frac{1}{2} \log\left((2\pi)^D \det G(z)\right) + \frac{1}{2} \rho^\top G^{-1} \rho,$$

where $D$ is the dimension of the latent space and $G(z)$ is the metric tensor at the point `z`.

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If a matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If a matrix, each column represents a momentum vector.
- `rhvae::RHVAE`: An instance of the RHVAE model.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and the inverse of the Riemannian metric tensor `G⁻¹`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`. This function must take as input the point `z` in the latent space and the `rhvae` instance.

**Returns**

A matrix representing the gradient of the Hamiltonian at the point `(z, ρ)` with respect to variable `var`.

`AutoEncoderToolkit.RHVAEs.∇hamiltonian_ForwardDiff` (Function)

```
∇hamiltonian_ForwardDiff(
    x::AbstractArray,
    z::AbstractVector,
    ρ::AbstractVector,
    G⁻¹::AbstractMatrix,
    logdetG::Union{<:Number,AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple,
    var::Symbol;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    position_logprior::Function=spherical_logprior,
    momentum_logprior::Function=riemannian_logprior,
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using the ForwardDiff.jl automatic differentiation library.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse of the Riemannian metric tensor `G⁻¹`, the log determinant of the metric tensor `logdetG`, a `decoder` of type `AbstractVariationalDecoder`, a `decoder_output` NamedTuple, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using ForwardDiff.jl.

The Hamiltonian is computed as follows:

Hₓ(z, ρ) = Uₓ(z) + κ(ρ),

Uₓ(z) = -log p(x|z) - log p(z),

κ(ρ) = 0.5 * log((2π)ᴰ det G(z)) + 0.5 * ρᵀ G⁻¹ ρ

where D is the dimension of the latent space, and G(z) is the metric tensor at the point `z`.
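For `var = :ρ` the gradient has a simple closed form: only κ depends on ρ, and G⁻¹ is symmetric and independent of ρ, so ∇_ρ H = G⁻¹ ρ. The minimal sketch below checks this identity against a central finite difference. It uses only the standard library and is not the package's ForwardDiff code path; the values of `Ginv` and `ρ` are arbitrary.

```julia
using LinearAlgebra

# κ(ρ) written in terms of G⁻¹ and log det G, the quantities the function receives.
κ(ρ, Ginv, logdetG) = 0.5 * (length(ρ) * log(2π) + logdetG) + 0.5 * dot(ρ, Ginv * ρ)

# Central finite-difference gradient of a scalar function f at v.
function fd_gradient(f, v; h=1e-6)
    g = similar(v, Float64)
    for i in eachindex(v)
        e = zeros(length(v)); e[i] = h
        g[i] = (f(v .+ e) - f(v .- e)) / (2h)
    end
    return g
end

Ginv = [0.5 0.1; 0.1 0.4]   # symmetric positive-definite G⁻¹
logdetG = -logdet(Ginv)     # log det G = -log det G⁻¹
ρ = [0.3, -1.2]

analytic = Ginv * ρ
numeric = fd_gradient(r -> κ(r, Ginv, logdetG), ρ)
```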

**Arguments**

- `x::AbstractArray`: The point in the data space. It does not need to be a vector; array inputs are supported, with the last dimension indexing the data points.
- `z::AbstractVector`: The point in the latent space.
- `ρ::AbstractVector`: The momentum.
- `G⁻¹::AbstractMatrix`: The inverse of the Riemannian metric tensor.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor.
- `decoder::AbstractVariationalDecoder`: The decoder instance.
- `decoder_output::NamedTuple`: The output of the decoder.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and `G⁻¹`.

**Returns**

`(z, ρ)`

with respect to variable `var`

.

**Note**

`ForwardDiff.jl` is not composable with `Zygote.jl`. Thus, for backpropagation through this function, one should use `ReverseDiff.jl`.

```
∇hamiltonian_ForwardDiff(
    x::AbstractArray,
    z::AbstractMatrix,
    ρ::AbstractMatrix,
    G⁻¹::AbstractArray,
    logdetG::Union{<:Number,AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple,
    var::Symbol;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    position_logprior::Function=spherical_logprior,
    momentum_logprior::Function=riemannian_logprior,
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using the ForwardDiff.jl automatic differentiation library.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse of the Riemannian metric tensor `G⁻¹`, the log determinant of the metric tensor `logdetG`, a `decoder` of type `AbstractVariationalDecoder`, a `decoder_output` NamedTuple, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using ForwardDiff.jl.

The Hamiltonian is computed as follows:

Hₓ(z, ρ) = Uₓ(z) + κ(ρ),

Uₓ(z) = -log p(x|z) - log p(z),

κ(ρ) = 0.5 * log((2π)ᴰ det G(z)) + 0.5 * ρᵀ G⁻¹ ρ

where D is the dimension of the latent space, and G(z) is the metric tensor at the point `z`.

The Jacobian with respect to `var` is computed so that the derivatives for all columns are obtained in a single call; the entries relevant to each column's gradient are then extracted from this Jacobian.
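The column-wise extraction can be sketched as follows. Stack the per-column Hamiltonians into a vector h(Z) with one entry per column, take its Jacobian with respect to the D×n matrix (flattened column-major), and keep, for column j, the D entries of row j that belong to column j. In this sketch a central finite difference stands in for ForwardDiff so that only the standard library is needed, and the quadratic `per_column_h` is a hypothetical stand-in for the Hamiltonian.

```julia
using LinearAlgebra

# Hypothetical per-column "Hamiltonian": h_j = 0.5 * ρ_jᵀ A ρ_j for each column ρ_j of P.
per_column_h(P, A) = [0.5 * dot(P[:, j], A * P[:, j]) for j in axes(P, 2)]

# Finite-difference Jacobian of a vector-valued f with respect to vec(X) (column-major order).
function fd_jacobian(f, X; h=1e-6)
    J = zeros(length(f(X)), length(X))
    for k in eachindex(X)
        Xp = copy(X); Xp[k] += h
        Xm = copy(X); Xm[k] -= h
        J[:, k] = (f(Xp) .- f(Xm)) ./ (2h)
    end
    return J
end

# Row j of J, restricted to the D entries of column j of X, is column j's gradient.
function columnwise_gradient(J, D, n)
    grads = zeros(D, n)
    for j in 1:n
        grads[:, j] = J[j, (j - 1) * D + 1:j * D]
    end
    return grads
end

A = [2.0 0.5; 0.5 1.0]
P = [1.0 -1.0 0.5; 2.0 0.0 -0.5]   # D = 2, n = 3 columns
J = fd_jacobian(X -> per_column_h(X, A), P)
grads = columnwise_gradient(J, size(P)...)
```

Because each column's Hamiltonian depends only on its own column, the off-block entries of the Jacobian are zero, which is what makes this extraction valid.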

**Arguments**

- `x::AbstractArray`: The point in the data space. It does not need to be a vector; array inputs are supported, with the last dimension indexing the data points.
- `z::AbstractMatrix`: The points in the latent space, one per column.
- `ρ::AbstractMatrix`: The momenta, one per column.
- `G⁻¹::AbstractArray`: The inverse of the Riemannian metric tensor.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor.
- `decoder::AbstractVariationalDecoder`: The decoder instance.
- `decoder_output::NamedTuple`: The output of the decoder.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and `G⁻¹`.

**Returns**

A matrix representing the gradient of the Hamiltonian at the point `(z, ρ)` with respect to variable `var`.

**Note**

`ForwardDiff.jl` is not composable with `Zygote.jl`. Thus, for backpropagation through this function, one should use `ReverseDiff.jl`.

```
∇hamiltonian_ForwardDiff(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    rhvae::RHVAE,
    var::Symbol;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    position_logprior::Function=spherical_logprior,
    momentum_logprior::Function=riemannian_logprior,
    G_inv::Function=G_inv,
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using the ForwardDiff.jl automatic differentiation library.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, an instance of `RHVAE`, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using ForwardDiff.jl.

The Hamiltonian is computed as follows:

Hₓ(z, ρ) = Uₓ(z) + κ(ρ),

Uₓ(z) = -log p(x|z) - log p(z),

κ(ρ) = 0.5 * log((2π)ᴰ det G(z)) + 0.5 * ρᵀ G⁻¹ ρ

where D is the dimension of the latent space, and G(z) is the metric tensor at the point `z`.

**Arguments**

- `x::AbstractArray`: The point in the data space. It does not need to be a vector; array inputs are supported, with the last dimension indexing the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `rhvae::RHVAE`: An instance of the RHVAE model.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and the inverse of the Riemannian metric tensor `G⁻¹`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`. This function must take as input the point `z` in the latent space and the `rhvae` instance.

**Returns**

A matrix representing the gradient of the Hamiltonian at the point `(z, ρ)` with respect to variable `var`.

**Note**

`ForwardDiff.jl` is not composable with `Zygote.jl`. Thus, for backpropagation through this function, one should use `ReverseDiff.jl`.

## Other Functions

`AutoEncoderToolkit.RHVAEs.update_metric` (Function)

```
update_metric(
    rhvae::RHVAE{<:VAE{<:AbstractGaussianEncoder,<:AbstractVariationalDecoder}}
)
```

Compute the `centroids_latent`, `L`, and `M` fields of a `RHVAE` instance without modifying the instance. This method is used when backpropagating through the RHVAE during training.

**Arguments**

- `rhvae::RHVAE{<:VAE{<:AbstractGaussianEncoder,<:AbstractVariationalDecoder}}`: The `RHVAE` instance to be updated.

**Returns**

A NamedTuple with the following fields:

- `centroids_latent::Matrix`: A matrix where each column represents a centroid cᵢ in the inverse metric computation.
- `L::Array{<:Number, 3}`: A 3D array where each slice represents an L_ψᵢ matrix.
- `M::Array{<:Number, 3}`: A 3D array where each slice represents L_ψᵢ L_ψᵢᵀ.

`AutoEncoderToolkit.RHVAEs.update_metric!` (Function)

```
update_metric!(
    rhvae::RHVAE{<:VAE{<:AbstractGaussianEncoder,<:AbstractVariationalDecoder}},
    params::NamedTuple
)
```

Update the `centroids_latent`, `L`, and `M` fields of a `RHVAE` instance in place.

This function takes a `RHVAE` instance and a named tuple `params` containing the new values for `centroids_latent`, `L`, and `M`, and overwrites the corresponding fields of the `RHVAE` instance with the provided values.

**Arguments**

- `rhvae::RHVAE{<:VAE{<:AbstractGaussianEncoder,<:AbstractVariationalDecoder}}`: The `RHVAE` instance to update.
- `params::NamedTuple`: A named tuple containing the new values for `centroids_latent`, `L`, and `M`. Must have the keys `:centroids_latent`, `:L`, and `:M`.

**Returns**

Nothing. The `RHVAE` instance is updated in place.

```
update_metric!(
    rhvae::RHVAE{
        <:VAE{<:AbstractGaussianEncoder,<:AbstractVariationalDecoder}
    }
)
```

Update the `centroids_latent`, `L`, and `M` fields of a `RHVAE` instance in place.

This function takes a `RHVAE` instance as input and modifies its `centroids_latent`, `L`, and `M` fields. The `centroids_latent` field is updated by running the `centroids_data` through the encoder of the underlying VAE and extracting the mean (µ) of the resulting Gaussian distribution. The `L` field is updated by running each column of the `centroids_data` through the `metric_chain` and concatenating the results along the third dimension. The `M` field is then computed by multiplying each slice of `L` by its transpose and concatenating the results along the third dimension.
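The relationship between `L` and `M` can be illustrated with plain arrays. In this sketch the lower-triangular slices of `L` are made-up values standing in for the `metric_chain` outputs; the point is only that each slice of `M` is L_ψᵢ L_ψᵢᵀ, stacked along the third dimension, which makes every slice symmetric and, for a nonsingular L_ψᵢ, positive definite.

```julia
using LinearAlgebra

# Hypothetical lower-triangular factors, one D×D slice per centroid (stand-ins for metric_chain outputs).
D, n = 2, 3
L = zeros(D, D, n)
for i in 1:n
    L[:, :, i] = [1.0 0.0; 0.5i 2.0]   # lower-triangular with positive diagonal
end

# M[:, :, i] = L[:, :, i] * L[:, :, i]ᵀ, concatenated along the third dimension.
M = cat((L[:, :, i] * L[:, :, i]' for i in 1:n)...; dims=3)
```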

**Arguments**

- `rhvae::RHVAE{<:VAE{<:AbstractGaussianEncoder,<:AbstractVariationalDecoder}}`: The `RHVAE` instance to be updated.

**Notes**

This function modifies the `RHVAE` instance in place, so it does not return anything. The changes are made directly to the `centroids_latent`, `L`, and `M` fields of the input `RHVAE` instance.

`AutoEncoderToolkit.RHVAEs.G_inv` (Function)

```
G_inv(
    z::AbstractVecOrMat,
    centroids_latent::AbstractMatrix,
    M::AbstractArray{<:Number,3},
    T::Number,
    λ::Number,
)
```

Compute the inverse of the metric tensor G for a given point in the latent space.

This function takes a point `z` in the latent space, the `centroids_latent` of the RHVAE instance, a 3D array `M` whose slices are the symmetric matrices L_ψᵢ L_ψᵢᵀ, a temperature `T`, and a regularization factor `λ`, and computes the inverse of the metric tensor G at that point. The computation is based on the centroids and the temperature, as well as a regularization term. The inverse metric is computed as follows:

G⁻¹(z) = ∑ᵢ₌₁ⁿ L_ψᵢ L_ψᵢᵀ exp(-‖z - cᵢ‖₂² / T²) + λIₗ,

where L_ψᵢ is computed by the MetricChain, T is the temperature, λ is a regularization factor, and the cᵢ are the columns of `centroids_latent`.
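A minimal single-point sketch of this formula, assuming `Float64` inputs. The function name `g_inv_sketch` is hypothetical; the package's `G_inv` additionally handles matrix inputs and GPU arrays, which this sketch does not attempt.

```julia
using LinearAlgebra

# G⁻¹(z) = Σᵢ M[:, :, i] * exp(-‖z - cᵢ‖₂² / T²) + λ I for a single point z.
# `centroids` holds the cᵢ as columns; each M[:, :, i] plays the role of L_ψᵢ L_ψᵢᵀ.
function g_inv_sketch(z::AbstractVector, centroids::AbstractMatrix,
                      M::AbstractArray{<:Number,3}, T::Number, λ::Number)
    D = length(z)
    Ginv = λ * Matrix{Float64}(I, D, D)          # regularization term λI
    for i in axes(centroids, 2)
        w = exp(-sum(abs2, z .- centroids[:, i]) / T^2)   # kernel weight of centroid i
        Ginv .+= w .* M[:, :, i]
    end
    return Ginv
end

centroids = [0.0 1.0; 0.0 1.0]   # two centroids: (0, 0) and (1, 1)
M = cat(Matrix(1.0I, 2, 2), Matrix(2.0I, 2, 2); dims=3)
Ginv = g_inv_sketch([0.0, 0.0], centroids, M, 1.0, 0.01)
```

At z = (0, 0) the first centroid gets weight 1 and the second gets exp(-2), so the result is (1.01 + 2e⁻²) I.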

**Arguments**

- `z::AbstractVecOrMat`: The point in the latent space. If a matrix, each column represents a point in the latent space.
- `centroids_latent::AbstractMatrix`: The centroids in the latent space.
- `M::AbstractArray{<:Number,3}`: The 3D array containing the symmetric matrices used to compute the inverse metric tensor.
- `T::Number`: The temperature.
- `λ::Number`: The regularization factor.

**Returns**

A matrix or 3D array representing the inverse of the metric tensor G at the point `z`. If a 3D array, each slice represents the inverse metric tensor at a different point in the latent space.

**Notes**

The computation involves the squared Euclidean distance between z and each centroid, the exponential of the negative of these distances divided by the square of the temperature, and a regularization term proportional to the identity matrix. For a single point, the result is a square matrix whose side length equals the latent-space dimension.

**GPU support**

This function supports CPU and GPU arrays.

```
G_inv(
    z::AbstractVecOrMat,
    metric_param::Union{RHVAE,NamedTuple},
)
```

Compute the inverse of the metric tensor G for a given point in the latent space.

This function takes a `RHVAE` instance and a point `z` in the latent space, and computes the inverse of the metric tensor G at that point. The computation is based on the centroids and the temperature of the `RHVAE` instance, as well as a regularization term. The inverse metric is computed as follows:

G⁻¹(z) = ∑ᵢ₌₁ⁿ L_ψᵢ L_ψᵢᵀ exp(-‖z - cᵢ‖₂² / T²) + λIₗ,

where L_ψᵢ is computed by the MetricChain, T is the temperature, λ is a regularization factor, and the cᵢ are the columns of `centroids_latent`.

**Arguments**

- `z::AbstractVecOrMat`: The point in the latent space. If a matrix, each column represents a point in the latent space.
- `metric_param::Union{RHVAE,NamedTuple}`: Either an `RHVAE` instance or a named tuple containing the fields `centroids_latent`, `M`, `T`, and `λ`.

**Returns**

A matrix representing the inverse of the metric tensor G at the point `z`.

**Notes**

The computation involves the squared Euclidean distance between z and each centroid of the RHVAE instance, the exponential of the negative of these distances divided by the square of the temperature, and a regularization term proportional to the identity matrix. For a single point, the result is a square matrix whose side length equals the latent-space dimension.

`AutoEncoderToolkit.RHVAEs.metric_tensor` (Function)

```
metric_tensor(
    z::AbstractVecOrMat,
    metric_param::Union{RHVAE,NamedTuple},
)
```

Compute the metric tensor G for a given point in the latent space. This function is a wrapper that determines the type of the input `z` and calls the appropriate specialized function `_metric_tensor` to perform the actual computation.

This function takes a `RHVAE` instance or a named tuple containing the fields `centroids_latent`, `M`, `T`, and `λ`, and a point `z` in the latent space, and computes the metric tensor G at that point. The computation is based on the inverse of the metric tensor G, which is computed by the `G_inv` function.
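Since G⁻¹(z) is symmetric positive definite (a sum of weighted L_ψᵢ L_ψᵢᵀ terms plus λI), one way to obtain G and log det G from it is through a Cholesky factorization, using log det G = -log det G⁻¹. This is a sketch with hypothetical helper names, not necessarily how `_metric_tensor` computes the result.

```julia
using LinearAlgebra

# G(z) = (G⁻¹(z))⁻¹, computed from a Cholesky factorization of the SPD matrix G⁻¹.
metric_from_inverse(Ginv::AbstractMatrix) = inv(cholesky(Symmetric(Ginv)))

# log det G = -log det G⁻¹, which avoids forming G explicitly.
logdet_metric(Ginv::AbstractMatrix) = -logdet(cholesky(Symmetric(Ginv)))

Ginv = [0.5 0.1; 0.1 0.4]
G = metric_from_inverse(Ginv)
```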

**Arguments**

- `z::AbstractVecOrMat`: The point in the latent space. If a matrix, each column represents a point in the latent space.
- `metric_param::Union{RHVAE,NamedTuple}`: Either an `RHVAE` instance or a named tuple containing the fields `centroids_latent`, `M`, `T`, and `λ`.

**Returns**

A matrix representing the metric tensor G at the point `z`.

**Notes**

The metric tensor is obtained by inverting G⁻¹(z). For a single point, the result is a square matrix whose side length equals the latent-space dimension.

**GPU Support**

This function supports CPU and GPU arrays.

`AutoEncoderToolkit.RHVAEs.riemannian_logprior` (Function)

```
riemannian_logprior(
    ρ::AbstractVector,
    G⁻¹::AbstractMatrix,
    logdetG::Number;
)
```

CPU AbstractVector version of the riemannian_logprior function.

```
riemannian_logprior(
    ρ::AbstractVector,
    G⁻¹::AbstractMatrix,
    logdetG::Number,
)
```

CPU AbstractMatrix version of the riemannian_logprior function.
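Consistent with the kinetic energy used in the Hamiltonians on this page (κ(ρ) = -log p(ρ) = 0.5 * log((2π)ᴰ det G) + 0.5 * ρᵀ G⁻¹ ρ), the momentum log-prior corresponds to a zero-mean Gaussian with covariance G(z). The sketch below writes the vector case in terms of the documented inputs and checks it against the explicit Gaussian density; the name `riemannian_logprior_sketch` is hypothetical.

```julia
using LinearAlgebra

# log N(ρ; 0, G) = -0.5 * (D log 2π + log det G) - 0.5 * ρᵀ G⁻¹ ρ
function riemannian_logprior_sketch(ρ::AbstractVector, Ginv::AbstractMatrix, logdetG::Number)
    D = length(ρ)
    return -0.5 * (D * log(2π) + logdetG) - 0.5 * dot(ρ, Ginv * ρ)
end

G = [2.0 0.3; 0.3 1.0]
Ginv = inv(G)
ρ = [0.4, -0.7]
lp = riemannian_logprior_sketch(ρ, Ginv, logdet(G))

# Explicit two-dimensional Gaussian density for comparison.
explicit = log(exp(-0.5 * dot(ρ, G \ ρ)) / sqrt((2π)^2 * det(G)))
```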

`AutoEncoderToolkit.RHVAEs.hamiltonian` (Function)

```
hamiltonian(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    G⁻¹::AbstractArray,
    logdetG::Union{<:Number,<:AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    decoder_loglikelihood::Function=decoder_loglikelihood,
    position_logprior::Function=spherical_logprior,
    momentum_logprior::Function=riemannian_logprior,
)
```

Compute the Hamiltonian for a given point in the latent space and a given momentum.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse of the Riemannian metric tensor `G⁻¹`, the log determinant of the metric tensor `logdetG`, a `decoder` of type `AbstractVariationalDecoder`, and a `decoder_output` NamedTuple, and computes the Hamiltonian. The computation is based on the log-likelihood of the decoder, the log-prior of the latent space, and the inverse of the metric tensor G at the point `z`.

The Hamiltonian is computed as follows:

Hₓ(z, ρ) = Uₓ(z) + κ(ρ),

Uₓ(z) = -log p(x|z) - log p(z),

κ(ρ) = -log p(ρ),

where p(ρ) is the prior density of the momentum, whose logarithm is computed by `momentum_logprior`.

**Arguments**

- `x::AbstractArray`: The point in the data space. It does not need to be a vector; array inputs are supported, but the last dimension of the array should be of size 1.
- `z::AbstractVecOrMat`: The point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum.
- `G⁻¹::AbstractArray`: The inverse of the Riemannian metric tensor. This should be computed elsewhere and must correspond to the given `z` value.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor. This should be computed elsewhere and must correspond to the given `z` value.
- `decoder::AbstractVariationalDecoder`: The decoder instance. It is not used directly in the computation of the Hamiltonian, but is passed to the `decoder_loglikelihood` function to select the appropriate method.
- `decoder_output::NamedTuple`: The output of the decoder.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and the inverse of the Riemannian metric tensor `G⁻¹`.

**Returns**

A scalar representing the Hamiltonian at the point `z` with the momentum `ρ`.

**Note**

The inverse of the Riemannian metric tensor `G⁻¹` is assumed to be computed elsewhere. The user must ensure that the provided `G⁻¹` corresponds to the given `z` value.

```
hamiltonian(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    rhvae::RHVAE;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    position_logprior::Function=spherical_logprior,
    momentum_logprior::Function=riemannian_logprior,
    G_inv::Function=G_inv,
)
```

Compute the Hamiltonian for a given point in the latent space and a given momentum.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, and an instance of `RHVAE`. It computes the inverse of the Riemannian metric tensor `G⁻¹` and the output of the decoder internally, and then computes the Hamiltonian. The computation is based on the log-likelihood of the decoder, the log-prior of the latent space, and the inverse of the metric tensor G at the point `z`.

The Hamiltonian is computed as follows:

Hₓ(z, ρ) = Uₓ(z) + κ(ρ),

Uₓ(z) = -log p(x|z) - log p(z),

κ(ρ) = -log p(ρ),

where p(ρ) is the prior density of the momentum, whose logarithm is computed by `momentum_logprior`.

**Arguments**

- `x::AbstractArray`: The point in the data space. It does not need to be a vector; array inputs are supported, but the last dimension of the array should be of size 1.
- `z::AbstractVecOrMat`: The point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum.
- `rhvae::RHVAE`: An instance of the RHVAE model.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and the inverse of the Riemannian metric tensor `G⁻¹`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`. This function must take as input the point `z` in the latent space and the `rhvae` instance.

**Returns**

A scalar representing the Hamiltonian at the point `z` with the momentum `ρ`.

**Note**

The inverse of the Riemannian metric tensor `G⁻¹`, the log determinant of the metric tensor, and the output of the decoder are computed internally in this function. The user does not need to provide these as inputs.

`AutoEncoderToolkit.RHVAEs.∇hamiltonian` (Function)

```
∇hamiltonian(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    G⁻¹::AbstractArray,
    logdetG::Union{<:Number,AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple,
    var::Symbol;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    position_logprior::Function=spherical_logprior,
    momentum_logprior::Function=riemannian_logprior,
    adtype::Symbol=:TaylorDiff,
    adkwargs::Union{NamedTuple,Dict}=Dict(),
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using a specified automatic differentiation method.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse of the Riemannian metric tensor `G⁻¹`, the log determinant of the metric tensor `logdetG`, a `decoder` of type `AbstractVariationalDecoder`, a `decoder_output` NamedTuple, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using the specified automatic differentiation method. The computation is based on the log-likelihood of the decoder, the log-prior of the latent space, and `G⁻¹`.

The Hamiltonian is computed as follows:

Hₓ(z, ρ) = Uₓ(z) + κ(ρ),

Uₓ(z) = -log p(x|z) - log p(z),

κ(ρ) = 0.5 * log((2π)ᴰ det G(z)) + 0.5 * ρᵀ G⁻¹ ρ

where D is the dimension of the latent space, and G(z) is the metric tensor at the point `z`.

**Arguments**

- `x::AbstractArray`: The point in the data space. It does not need to be a vector; array inputs are supported, with the last dimension indexing the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `G⁻¹::AbstractArray`: The inverse of the Riemannian metric tensor. If 3D array, each slice along the third dimension represents the inverse of the metric tensor at the corresponding column of `z`.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor. If vector, each element represents the log determinant of the metric tensor at the corresponding column of `z`.
- `decoder::AbstractVariationalDecoder`: The decoder instance.
- `decoder_output::NamedTuple`: The output of the decoder.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and `G⁻¹`.
- `adtype::Symbol`: The type of automatic differentiation method to use. Must be `:finite`, `:ForwardDiff`, or `:TaylorDiff`.
- `adkwargs::Union{NamedTuple,Dict}=Dict()`: Additional keyword arguments to pass to the automatic differentiation method.

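**Returns**

A vector or matrix (matching the shape of `var`) representing the gradient of the Hamiltonian at the point `(z, ρ)` with respect to variable `var`.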
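For `var = :z`, both G⁻¹(z) and log det G(z) depend on z, so no simple closed form is available in general. The sketch below shows one way a central finite-difference estimate of ∇_z H could look. Everything in it is an assumption for illustration: the toy inverse metric uses a single centroid at the origin, the position prior is standard normal, the decoder term is omitted, and the step size `h` is arbitrary. It is not the package's `:finite` implementation.

```julia
using LinearAlgebra

# Toy inverse metric with one centroid at the origin: G⁻¹(z) = exp(-‖z‖² / T²) * A + λ I.
function toy_Ginv(z; A=[1.0 0.2; 0.2 0.5], T=1.0, λ=0.1)
    return exp(-sum(abs2, z) / T^2) .* A .+ λ .* Matrix(1.0I, 2, 2)
end

# Toy Hamiltonian: standard-normal position prior plus the Gaussian kinetic term κ(ρ).
function toy_H(z, ρ)
    Ginv = toy_Ginv(z)
    logdetG = -logdet(Ginv)
    U = 0.5 * (length(z) * log(2π) + sum(abs2, z))
    κ = 0.5 * (length(ρ) * log(2π) + logdetG) + 0.5 * dot(ρ, Ginv * ρ)
    return U + κ
end

# Central finite-difference gradient with respect to z, holding ρ fixed.
function fd_grad_z(H, z, ρ; h=1e-5)
    g = zeros(length(z))
    for i in eachindex(z)
        e = zeros(length(z)); e[i] = h
        g[i] = (H(z .+ e, ρ) - H(z .- e, ρ)) / (2h)
    end
    return g
end

z, ρ = [0.3, -0.4], [0.5, 1.0]
g = fd_grad_z(toy_H, z, ρ)
```

For this toy model the gradient can also be derived by hand as z ⊙ (1 + s tr(G A) - s ρᵀAρ) with s = exp(-‖z‖²), which gives an independent check on the finite-difference estimate.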

```
∇hamiltonian(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    rhvae::RHVAE,
    var::Symbol;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    position_logprior::Function=spherical_logprior,
    momentum_logprior::Function=riemannian_logprior,
    G_inv::Function=G_inv,
    adtype::Symbol=:TaylorDiff,
    adkwargs::Union{NamedTuple,Dict}=Dict(),
)
```

Compute the gradient of the Hamiltonian with respect to a given variable using a specified automatic differentiation method.

This function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, an instance of `RHVAE`, and a variable `var` (`:z` or `:ρ`), and computes the gradient of the Hamiltonian with respect to `var` using the specified automatic differentiation method. The computation is based on the log-likelihood of the decoder, the log-prior of the latent space, and the inverse metric computed by `G_inv`.

**Arguments**

- `x::AbstractArray`: The point in the data space. It does not need to be a vector; array inputs are supported, with the last dimension indexing the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `rhvae::RHVAE`: An instance of the RHVAE model.
- `var::Symbol`: The variable with respect to which the gradient is computed. Must be `:z` or `:ρ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log-likelihood of the decoder reconstruction. Default is `decoder_loglikelihood`. This function must take as input the decoder, the point `x` in the data space, and the `decoder_output`.
- `position_logprior::Function`: The function to compute the log-prior of the latent space position. Default is `spherical_logprior`. This function must take as input the point `z` in the latent space.
- `momentum_logprior::Function`: The function to compute the log-prior of the momentum. Default is `riemannian_logprior`. This function must take as input the momentum `ρ` and the inverse of the Riemannian metric tensor `G⁻¹`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`.
- `adtype::Symbol`: The type of automatic differentiation method to use. Must be `:finite`, `:ForwardDiff`, or `:TaylorDiff`.
- `adkwargs::Union{NamedTuple,Dict}=Dict()`: Additional keyword arguments to pass to the automatic differentiation method.

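**Returns**

A vector or matrix (matching the shape of `var`) representing the gradient of the Hamiltonian at the point `(z, ρ)` with respect to variable `var`.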

`AutoEncoderToolkit.RHVAEs._leapfrog_first_step` (Function)

```
_leapfrog_first_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    G⁻¹::AbstractArray,
    logdetG::Union{<:Number,AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
    ),
)
```

Perform the first step of the generalized leapfrog integrator for Hamiltonian dynamics, defined as

ρ(t + ϵ/2) = ρ(t) - 0.5 * ϵ * ∇_z H(z(t), ρ(t + ϵ/2)).

This function is part of the generalized leapfrog integrator used in Hamiltonian dynamics. Unlike the standard leapfrog integrator, the generalized leapfrog integrator is implicit, which means it requires the use of fixed-point iterations to be solved.

The function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse of the Riemannian metric tensor `G⁻¹`, the log determinant of the metric tensor `logdetG`, a `decoder` of type `AbstractVariationalDecoder`, the output of the decoder `decoder_output`, a step size `ϵ`, and optionally the number of fixed-point iterations to perform (`steps`) and a set of keyword arguments for `∇hamiltonian` (`∇H_kwargs`).

Starting from ρ̃ = ρ, the function performs the following fixed-point update `steps` times:

ρ̃ = ρ - 0.5 * ϵ * ∇hamiltonian(x, z, ρ̃, G⁻¹, logdetG, decoder, decoder_output, :z; ∇H_kwargs...)

where the call to `∇hamiltonian` with `:z` returns the gradient of the Hamiltonian with respect to the position variables `z`. The result is returned as ρ̃.
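The fixed-point solve can be sketched independently of the model. Here `∇zH` is a hypothetical closure standing in for the `∇hamiltonian(..., :z; ∇H_kwargs...)` call, and the toy gradient ∇_z H(z, ρ) = z + 0.1 ρ is chosen only so that the fixed point ρ̃ = ρ - (ϵ/2)(z + 0.1 ρ̃) has a closed form to compare against.

```julia
# Fixed-point iteration for ρ̃ = ρ - (ϵ/2) * ∇_z H(z, ρ̃), starting from ρ̃ = ρ.
function leapfrog_first_step_sketch(∇zH, z, ρ; ϵ=0.01, steps=3)
    ρ̃ = copy(ρ)
    for _ in 1:steps
        ρ̃ = ρ .- (0.5 * ϵ) .* ∇zH(z, ρ̃)   # always restart from the original ρ
    end
    return ρ̃
end

∇zH_toy(z, ρ) = z .+ 0.1 .* ρ   # hypothetical gradient, linear in ρ

z, ρ = [1.0, -2.0], [0.5, 0.5]
ϵ = 0.1
ρ_half = leapfrog_first_step_sketch(∇zH_toy, z, ρ; ϵ=ϵ, steps=50)

# Closed-form fixed point: ρ̃ (1 + 0.05ϵ) = ρ - (ϵ/2) z.
exact = (ρ .- (ϵ / 2) .* z) ./ (1 + 0.05 * ϵ)
```

Because the update is a contraction for small ϵ, a few iterations are usually enough, which is why a small default for `steps` can be adequate.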

**Arguments**

- `x::AbstractArray`: The point in the data space. It does not need to be a vector; array inputs are supported, with the last dimension indexing the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `G⁻¹::AbstractArray`: The inverse of the Riemannian metric tensor. If 3D array, each slice along the third dimension represents the inverse of the metric tensor at the corresponding column of `z`.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor. If vector, each element represents the log determinant of the metric tensor at the corresponding column of `z`.
- `decoder::AbstractVariationalDecoder`: The decoder instance.
- `decoder_output::NamedTuple`: The output of the decoder.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}`: The leapfrog step size.
- `steps::Int=3`: The number of fixed-point iterations to perform. Default is 3.
- `∇H_kwargs::Union{NamedTuple,Dict}`: The keyword arguments for `∇hamiltonian`. Default is a named tuple with `reconstruction_loglikelihood`, `position_logprior`, and `momentum_logprior`.

**Returns**

A vector representing the updated momentum after performing the first step of the generalized leapfrog integrator.

```
_leapfrog_first_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    rhvae::RHVAE;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
    ),
    G_inv::Function=G_inv,
)
```

Perform the first step of the generalized leapfrog integrator for Hamiltonian dynamics, defined as

ρ(t + ϵ/2) = ρ(t) - 0.5 * ϵ * ∇_z H(z(t), ρ(t + ϵ/2)).

This function is part of the generalized leapfrog integrator used in Hamiltonian dynamics. Unlike the standard leapfrog integrator, the generalized leapfrog integrator is implicit, which means it requires the use of fixed-point iterations to be solved.

The function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, a `RHVAE` instance, a step size `ϵ`, and optionally the number of fixed-point iterations to perform (`steps`), a set of keyword arguments for `∇hamiltonian` (`∇H_kwargs`), and the function used to compute the inverse metric (`G_inv`).

The function starts from ρ̃ = ρ and performs the following update `steps` times:

`ρ̃ = ρ - 0.5 * ϵ * ∇hamiltonian(rhvae, x, z, ρ̃, :z; ∇H_kwargs...)`

where `∇hamiltonian(..., :z; ...)` computes the gradient of the Hamiltonian with respect to the position variables `z`. The result is returned as ρ̃.
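To make the fixed-point iteration concrete, here is a minimal, self-contained Julia sketch of the first step. It does not call `AutoEncoderToolkit.jl`: the metric G(z) = (1 + ‖z‖²)I and the helpers `g`, `∇z_H`, and `leapfrog_first_step` are hypothetical stand-ins for the learned metric and for `∇hamiltonian(..., :z; ...)`. Starting from ρ̃ = ρ, it repeatedly applies ρ̃ ← ρ - (ϵ/2)∇z H(z, ρ̃), which is the defining equation of the first step.

```julia
using LinearAlgebra

# Hypothetical toy metric G(z) = (1 + ‖z‖²) I, standing in for the learned metric.
# Toy Hamiltonian: H(z, ρ) = ½‖z‖² + ½ log det G(z) + ½ ρᵀ G(z)⁻¹ ρ (constants dropped)
g(z) = 1 + dot(z, z)

# Analytic gradient of the toy H with respect to the position z
∇z_H(z, ρ) = z .+ length(z) .* z ./ g(z) .- dot(ρ, ρ) .* z ./ g(z)^2

# First generalized-leapfrog step: solve ρ̃ = ρ - (ϵ/2) ∇z_H(z, ρ̃)
# by fixed-point iteration, starting from ρ̃ = ρ.
function leapfrog_first_step(z, ρ, ϵ; steps=3)
    ρ̃ = copy(ρ)
    for _ in 1:steps
        ρ̃ = ρ .- (ϵ / 2) .* ∇z_H(z, ρ̃)
    end
    return ρ̃
end

z = [0.5, -0.3]
ρ = [1.0, 0.2]
ϵ = 0.1
ρ̃ = leapfrog_first_step(z, ρ, ϵ; steps=50)

# Residual of the implicit equation; close to zero once the iteration has converged
residual = norm(ρ̃ .- (ρ .- (ϵ / 2) .* ∇z_H(z, ρ̃)))
```

For this toy problem the iteration contracts quickly, so the residual becomes negligible; how many iterations a trained metric needs depends on the step size and on the metric itself.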

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `rhvae::RHVAE`: The `RHVAE` instance.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4)`: The leapfrog step size. Default is `Float32(1E-4)`.
- `steps::Int=3`: The number of fixed-point iterations to perform. Default is 3.
- `∇H_kwargs::Union{NamedTuple,Dict}`: The keyword arguments for `∇hamiltonian`. Default is a `NamedTuple` with `reconstruction_loglikelihood`, `position_logprior`, and `momentum_logprior`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`.

**Returns**

A vector representing the updated momentum after performing the first step of the generalized leapfrog integrator.

`AutoEncoderToolkit.RHVAEs._leapfrog_second_step` — Function

```julia
_leapfrog_second_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    G⁻¹::AbstractArray,
    logdetG::Union{<:Number,AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
    ),
)
```

Perform the second step of the generalized leapfrog integrator for Hamiltonian dynamics, defined as

$$z(t + \epsilon) = z(t) + \frac{\epsilon}{2} \Big[\nabla_\rho H\big(z(t), \rho(t + \epsilon/2)\big) + \nabla_\rho H\big(z(t + \epsilon), \rho(t + \epsilon/2)\big)\Big]$$

This function is part of the generalized leapfrog integrator used in Hamiltonian dynamics. Unlike the standard leapfrog integrator, the generalized leapfrog integrator is implicit: the updated position z(t + ϵ) appears on both sides of the equation, so it must be solved by fixed-point iteration.

The function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse of the Riemannian metric tensor `G⁻¹` and its log determinant `logdetG`, a `decoder` of type `AbstractVariationalDecoder`, and the output of the decoder `decoder_output`. Optionally, it also takes a step size `ϵ`, the number of fixed-point iterations to perform (`steps`), and a set of keyword arguments for `∇hamiltonian` (`∇H_kwargs`).

The function starts from z̄ = z and performs the following update `steps` times:

`z̄ = z + 0.5 * ϵ * (∇hamiltonian(x, z̄, ρ, G⁻¹, decoder, decoder_output, :ρ; ∇H_kwargs...) + ∇hamiltonian(x, z, ρ, G⁻¹, decoder, decoder_output, :ρ; ∇H_kwargs...))`

where `∇hamiltonian(..., :ρ; ...)` computes the gradient of the Hamiltonian with respect to the momentum variables `ρ`. The result is returned as z̄.
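The second step's defining equation has z(t + ϵ) on both sides, so it can be solved by fixed-point iteration starting from z̄ = z. The minimal sketch below shows this under a hypothetical toy metric G(z) = (1 + ‖z‖²)I, for which ∇ρ H = G(z)⁻¹ρ; the helpers `∇ρ_H` and `leapfrog_second_step` are illustrations, not part of `AutoEncoderToolkit.jl`.

```julia
using LinearAlgebra

# Hypothetical toy metric G(z) = (1 + ‖z‖²) I, standing in for the learned metric.
# With H containing ½ ρᵀ G(z)⁻¹ ρ, the momentum gradient is ∇ρ H = G(z)⁻¹ ρ.
∇ρ_H(z, ρ) = ρ ./ (1 + dot(z, z))

# Second generalized-leapfrog step: solve
#   z̄ = z + (ϵ/2) [∇ρ_H(z, ρ) + ∇ρ_H(z̄, ρ)]
# by fixed-point iteration, starting from z̄ = z.
function leapfrog_second_step(z, ρ, ϵ; steps=3)
    fixed_term = ∇ρ_H(z, ρ)  # evaluated once: it does not depend on z̄
    z̄ = copy(z)
    for _ in 1:steps
        z̄ = z .+ (ϵ / 2) .* (fixed_term .+ ∇ρ_H(z̄, ρ))
    end
    return z̄
end

z = [0.5, -0.3]
ρ = [0.98, 0.21]
ϵ = 0.1
z̄ = leapfrog_second_step(z, ρ, ϵ; steps=50)

# Residual of the implicit equation; close to zero once the iteration has converged
residual = norm(z̄ .- (z .+ (ϵ / 2) .* (∇ρ_H(z, ρ) .+ ∇ρ_H(z̄, ρ))))
```

Evaluating the term at the starting point once, outside the loop, mirrors the structure of the update: only the term at z̄ changes between iterations.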

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `G⁻¹::AbstractArray`: The inverse of the Riemannian metric tensor. If 3D array, each slice along the third dimension represents the inverse of the metric tensor at the corresponding column of `z`.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor. If vector, each element represents the log determinant of the metric tensor at the corresponding column of `z`.
- `decoder::AbstractVariationalDecoder`: The decoder instance.
- `decoder_output::NamedTuple`: The output of the decoder.
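The batched layout described for `z`, `ρ`, `G⁻¹`, and `logdetG` can be illustrated with plain arrays. In this sketch the batch values and the stand-in metric G(z) = (1 + ‖z‖²)I are hypothetical; it only shows how a d×d×n array of inverse metrics lines up with the columns of `z` and `ρ`, and how G⁻¹ρ would be evaluated slice by slice.

```julia
using LinearAlgebra

# Hypothetical batch: 2-dimensional latent space, 3 samples (one per column)
z = [0.5 -1.0 0.0;
     -0.3 0.2 1.5]
ρ = [1.0 0.0 -0.5;
     0.2 1.0 0.5]
d, n = size(z)

# Stand-in metric G(z) = (1 + ‖z‖²) I. Stack one d×d inverse metric per column
# of z along the third dimension, giving a d×d×n array.
G⁻¹ = cat([Matrix{Float64}(I, d, d) ./ (1 + dot(c, c)) for c in eachcol(z)]...; dims=3)

# One log determinant per column: log det G = d log(1 + ‖z‖²)
logdetG = [d * log(1 + dot(c, c)) for c in eachcol(z)]

# Gradient of the kinetic term with respect to ρ, G⁻¹ρ, evaluated slice by slice
∇ρ_H = reduce(hcat, [G⁻¹[:, :, i] * ρ[:, i] for i in 1:n])
```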

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4)`: The step size. Default is `Float32(1E-4)`.
- `steps::Int=3`: The number of fixed-point iterations to perform. Default is 3.
- `∇H_kwargs::Union{NamedTuple,Dict}`: The keyword arguments for `∇hamiltonian`. Default is a `NamedTuple` with `reconstruction_loglikelihood`, `position_logprior`, and `momentum_logprior`.

**Returns**

A vector representing the updated position after performing the second step of the generalized leapfrog integrator.

```julia
_leapfrog_second_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    rhvae::RHVAE;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
    ),
    G_inv::Function=G_inv,
)
```

Perform the second step of the generalized leapfrog integrator for Hamiltonian dynamics, defined as

$$z(t + \epsilon) = z(t) + \frac{\epsilon}{2} \Big[\nabla_\rho H\big(z(t), \rho(t + \epsilon/2)\big) + \nabla_\rho H\big(z(t + \epsilon), \rho(t + \epsilon/2)\big)\Big]$$

The function takes a `RHVAE` instance, a point `x` in the data space, a point `z` in the latent space, and a momentum `ρ`. Optionally, it also takes a step size `ϵ`, the number of fixed-point iterations to perform (`steps`), a set of keyword arguments for `∇hamiltonian` (`∇H_kwargs`), and the function used to compute the inverse metric tensor (`G_inv`).

The function starts from z̄ = z and performs the following update `steps` times:

`z̄ = z + 0.5 * ϵ * (∇hamiltonian(rhvae, x, z̄, ρ, :ρ; ∇H_kwargs...) + ∇hamiltonian(rhvae, x, z, ρ, :ρ; ∇H_kwargs...))`

where `∇hamiltonian(..., :ρ; ...)` computes the gradient of the Hamiltonian with respect to the momentum variables `ρ`. The result is returned as z̄.

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `rhvae::RHVAE`: The `RHVAE` instance.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4)`: The leapfrog step size. Default is `Float32(1E-4)`.
- `steps::Int=3`: The number of fixed-point iterations to perform. Default is 3. Typically, 3 iterations are sufficient.
- `∇H_kwargs::Union{NamedTuple,Dict}`: The keyword arguments for `∇hamiltonian`. Default is a `NamedTuple` with `reconstruction_loglikelihood`, `position_logprior`, and `momentum_logprior`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`.

**Returns**

A vector representing the updated position after performing the second step of the generalized leapfrog integrator.

`AutoEncoderToolkit.RHVAEs._leapfrog_third_step` — Function

```julia
_leapfrog_third_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    G⁻¹::AbstractArray,
    logdetG::Union{<:Number,AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
    ),
)
```

Perform the third step of the generalized leapfrog integrator for Hamiltonian dynamics, defined as

$$\rho(t + \epsilon) = \rho(t + \epsilon/2) - \frac{\epsilon}{2} \nabla_z H\big(z(t + \epsilon), \rho(t + \epsilon/2)\big)$$

The function takes a point `x` in the data space, a point `z` in the latent space, a momentum `ρ`, the inverse of the Riemannian metric tensor `G⁻¹` and its log determinant `logdetG`, a `decoder` of type `AbstractVariationalDecoder`, and the output of the decoder `decoder_output`. Optionally, it also takes a step size `ϵ` and a set of keyword arguments for `∇hamiltonian` (`∇H_kwargs`).

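Because both z(t + ϵ) and ρ(t + ϵ/2) are already known at this stage, this step is explicit and needs no fixed-point iterations. The function performs the following update:

`ρ̃ = ρ - 0.5 * ϵ * ∇hamiltonian(x, z, ρ, G⁻¹, decoder, decoder_output, :z; ∇H_kwargs...)`

where `∇hamiltonian(..., :z; ...)` computes the gradient of the Hamiltonian with respect to the position variables `z`. The result is returned as ρ̃.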
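A minimal sketch of this single explicit update, under a hypothetical toy metric G(z) = (1 + ‖z‖²)I. The helpers `g`, `∇z_H`, and `leapfrog_third_step` and the input values are illustrations, not part of `AutoEncoderToolkit.jl`; in the real integrator `z_new` and `ρ_half` would come from the second and first steps.

```julia
using LinearAlgebra

# Hypothetical toy metric G(z) = (1 + ‖z‖²) I, standing in for the learned metric,
# and the analytic position gradient of H = ½‖z‖² + ½ log det G + ½ ρᵀ G⁻¹ ρ
g(z) = 1 + dot(z, z)
∇z_H(z, ρ) = z .+ length(z) .* z ./ g(z) .- dot(ρ, ρ) .* z ./ g(z)^2

# Third generalized-leapfrog step: a single explicit update, no fixed-point loop
leapfrog_third_step(z_new, ρ_half, ϵ) = ρ_half .- (ϵ / 2) .* ∇z_H(z_new, ρ_half)

z_new = [0.55, -0.28]   # hypothetical z(t + ϵ)
ρ_half = [0.98, 0.21]   # hypothetical ρ(t + ϵ/2)
ρ_new = leapfrog_third_step(z_new, ρ_half, 0.1)
```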

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `G⁻¹::AbstractArray`: The inverse of the Riemannian metric tensor. If 3D array, each slice along the third dimension represents the inverse of the metric tensor at the corresponding column of `z`.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor. If vector, each element represents the log determinant of the metric tensor at the corresponding column of `z`.
- `decoder::AbstractVariationalDecoder`: The decoder instance.
- `decoder_output::NamedTuple`: The output of the decoder.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4)`: The step size. Default is `Float32(1E-4)`.
- `∇H_kwargs::Union{NamedTuple,Dict}`: The keyword arguments for `∇hamiltonian`. Default is a `NamedTuple` with `reconstruction_loglikelihood`, `position_logprior`, and `momentum_logprior`.

**Returns**

A vector representing the updated momentum after performing the third step of the generalized leapfrog integrator.

```julia
_leapfrog_third_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    rhvae::RHVAE;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
    ),
    G_inv::Function=G_inv,
)
```

Perform the third step of the generalized leapfrog integrator for Hamiltonian dynamics, defined as

$$\rho(t + \epsilon) = \rho(t + \epsilon/2) - \frac{\epsilon}{2} \nabla_z H\big(z(t + \epsilon), \rho(t + \epsilon/2)\big)$$

The function takes a `RHVAE` instance, a point `x` in the data space, a point `z` in the latent space, and a momentum `ρ`. Optionally, it also takes a step size `ϵ`, the number of fixed-point iterations (`steps`), a set of keyword arguments for `∇hamiltonian` (`∇H_kwargs`), and the function used to compute the inverse metric tensor (`G_inv`).

The function performs the following update:

`ρ̃ = ρ - 0.5 * ϵ * ∇hamiltonian(rhvae, x, z, ρ, :z; ∇H_kwargs...)`

where `∇hamiltonian(..., :z; ...)` computes the gradient of the Hamiltonian with respect to the position variables `z`. The result is returned as ρ̃.

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `rhvae::RHVAE`: The `RHVAE` instance.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}`: The leapfrog step size. Default is `Float32(1E-4)`.
- `steps::Int=3`: The number of fixed-point iterations to perform. Default is 3.
- `∇H_kwargs::Union{NamedTuple,Dict}`: The keyword arguments for `∇hamiltonian`. Default is a `NamedTuple` with `reconstruction_loglikelihood`, `position_logprior`, and `momentum_logprior`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`.

**Returns**

A vector representing the updated momentum after performing the third step of the generalized leapfrog integrator.

`AutoEncoderToolkit.RHVAEs.general_leapfrog_step` — Function

```julia
general_leapfrog_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    G⁻¹::AbstractArray,
    logdetG::Union{<:Number,AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple,
    metric_param::NamedTuple;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
    ),
    G_inv::Function=G_inv,
)
```

Perform a full step of the generalized leapfrog integrator for Hamiltonian dynamics.

The leapfrog integrator is a numerical integration scheme used to simulate Hamiltonian dynamics. It consists of three steps:

1. Half update of the momentum variable:

   $$\rho(t + \epsilon/2) = \rho(t) - \frac{\epsilon}{2} \nabla_z H\big(z(t), \rho(t + \epsilon/2)\big)$$

2. Full update of the position variable:

   $$z(t + \epsilon) = z(t) + \frac{\epsilon}{2} \Big[\nabla_\rho H\big(z(t), \rho(t + \epsilon/2)\big) + \nabla_\rho H\big(z(t + \epsilon), \rho(t + \epsilon/2)\big)\Big]$$

3. Half update of the momentum variable:

   $$\rho(t + \epsilon) = \rho(t + \epsilon/2) - \frac{\epsilon}{2} \nabla_z H\big(z(t + \epsilon), \rho(t + \epsilon/2)\big)$$

This function performs these three steps in sequence, using the `_leapfrog_first_step`, `_leapfrog_second_step`, and `_leapfrog_third_step` helper functions.
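The three steps can be chained into one self-contained sketch. It uses a hypothetical toy Riemannian Hamiltonian with metric G(z) = (1 + ‖z‖²)I instead of a learned metric, and plain functions (`H`, `∇z_H`, `∇ρ_H`, `general_leapfrog`) that are illustrations rather than the package's API. Running one step forward and then one step back from the negated momentum checks time reversibility, a property of the generalized leapfrog integrator when the implicit equations are solved accurately.

```julia
using LinearAlgebra

# Hypothetical toy Riemannian Hamiltonian with metric G(z) = (1 + ‖z‖²) I:
#   H(z, ρ) = ½‖z‖² + ½ log det G(z) + ½ ρᵀ G(z)⁻¹ ρ (constants dropped)
g(z) = 1 + dot(z, z)
H(z, ρ) = dot(z, z) / 2 + length(z) * log(g(z)) / 2 + dot(ρ, ρ) / (2 * g(z))
∇z_H(z, ρ) = z .+ length(z) .* z ./ g(z) .- dot(ρ, ρ) .* z ./ g(z)^2
∇ρ_H(z, ρ) = ρ ./ g(z)

function general_leapfrog(z, ρ, ϵ; steps=3)
    # 1. Implicit half step in momentum (fixed-point iteration)
    ρ_half = copy(ρ)
    for _ in 1:steps
        ρ_half = ρ .- (ϵ / 2) .* ∇z_H(z, ρ_half)
    end
    # 2. Implicit full step in position (fixed-point iteration)
    z_new = copy(z)
    for _ in 1:steps
        z_new = z .+ (ϵ / 2) .* (∇ρ_H(z, ρ_half) .+ ∇ρ_H(z_new, ρ_half))
    end
    # 3. Explicit half step in momentum
    ρ_new = ρ_half .- (ϵ / 2) .* ∇z_H(z_new, ρ_half)
    return z_new, ρ_new
end

z0, ρ0 = [0.5, -0.3], [1.0, 0.2]
z1, ρ1 = general_leapfrog(z0, ρ0, 0.1; steps=100)

# Time reversibility: stepping from (z1, -ρ1) should return to (z0, -ρ0)
z_back, ρ_back = general_leapfrog(z1, -ρ1, 0.1; steps=100)

# The energy is not conserved exactly, but the error stays small for small ϵ
energy_error = abs(H(z1, ρ1) - H(z0, ρ0))
```

With too few fixed-point iterations the implicit equations are solved only approximately, and reversibility then holds only approximately as well.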

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `G⁻¹::AbstractArray`: The inverse of the Riemannian metric tensor. If 3D array, each slice along the third dimension represents the inverse of the metric tensor at the corresponding column of `z`.
- `logdetG::Union{<:Number,AbstractVector}`: The log determinant of the Riemannian metric tensor. If vector, each element represents the log determinant of the metric tensor at the corresponding column of `z`.
- `decoder::AbstractVariationalDecoder`: The decoder instance.
- `decoder_output::NamedTuple`: The output of the decoder.
- `metric_param::NamedTuple`: The parameters for the metric tensor.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4)`: The step size. Default is `Float32(1E-4)`.
- `steps::Int=3`: The number of fixed-point iterations to perform. Default is 3. Typically, 3 iterations are sufficient.
- `∇H_kwargs::Union{NamedTuple,Dict}`: The keyword arguments for `∇hamiltonian`. Default is a `NamedTuple` with `reconstruction_loglikelihood=decoder_loglikelihood`, `position_logprior`, and `momentum_logprior`.
- `G_inv::Function=G_inv`: The function to compute the inverse of the Riemannian metric tensor.

**Returns**

A tuple `(z̄, ρ̄, Ḡ⁻¹, logdetḠ, decoder_update)` containing the updated position, the updated momentum, the inverse of the updated Riemannian metric tensor, the log determinant of the updated metric tensor, and the updated decoder outputs after the full leapfrog step.

```julia
general_leapfrog_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    rhvae::RHVAE;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
        G_inv=G_inv,
    ),
)
```

Perform a full step of the generalized leapfrog integrator for Hamiltonian dynamics.

The leapfrog integrator is a numerical integration scheme used to simulate Hamiltonian dynamics. It consists of three steps:

1. Half update of the momentum variable:

   $$\rho(t + \epsilon/2) = \rho(t) - \frac{\epsilon}{2} \nabla_z H\big(z(t), \rho(t + \epsilon/2)\big)$$

2. Full update of the position variable:

   $$z(t + \epsilon) = z(t) + \frac{\epsilon}{2} \Big[\nabla_\rho H\big(z(t), \rho(t + \epsilon/2)\big) + \nabla_\rho H\big(z(t + \epsilon), \rho(t + \epsilon/2)\big)\Big]$$

3. Half update of the momentum variable:

   $$\rho(t + \epsilon) = \rho(t + \epsilon/2) - \frac{\epsilon}{2} \nabla_z H\big(z(t + \epsilon), \rho(t + \epsilon/2)\big)$$

This function performs these three steps in sequence, using the `_leapfrog_first_step`, `_leapfrog_second_step`, and `_leapfrog_third_step` helper functions.

**Arguments**

- `x::AbstractArray`: The point in the data space. This does not necessarily need to be a vector; array inputs are supported. The last dimension is assumed to index the data points.
- `z::AbstractVecOrMat`: The point in the latent space. If matrix, each column represents a point in the latent space.
- `ρ::AbstractVecOrMat`: The momentum. If matrix, each column represents a momentum vector.
- `rhvae::RHVAE`: The `RHVAE` instance.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4)`: The leapfrog step size. Default is `Float32(1E-4)`.
- `steps::Int=3`: The number of fixed-point iterations to perform. Default is 3. Typically, 3 iterations are sufficient.
- `∇H_kwargs::Union{NamedTuple,Dict}`: The keyword arguments for `∇hamiltonian`. Default is a `NamedTuple` with `reconstruction_loglikelihood=decoder_loglikelihood`, `position_logprior`, `momentum_logprior`, and `G_inv`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor, passed as the `G_inv` entry of `∇H_kwargs`. Default is `G_inv`.

**Returns**

A tuple `(z̄, ρ̄, Ḡ⁻¹, logdetḠ, decoder_update)` containing the updated position, the updated momentum, the inverse of the updated Riemannian metric tensor, the log determinant of the updated metric tensor, and the updated decoder outputs after the full leapfrog step.

`AutoEncoderToolkit.RHVAEs.general_leapfrog_tempering_step` — Function

```julia
general_leapfrog_tempering_step(
    x::AbstractArray,
    zₒ::AbstractVecOrMat,
    Gₒ⁻¹::AbstractArray,
    logdetGₒ::Union{<:Number,AbstractVector},
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple,
    metric_param::NamedTuple;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
        G_inv=G_inv,
    ),
    tempering_schedule::Function=quadratic_tempering,
)
```

Combines the leapfrog and tempering steps into a single function for the Riemannian Hamiltonian Variational Autoencoder (RHVAE).

**Arguments**

- `x::AbstractArray`: The data to be processed. If `Array`, the last dimension must be of size 1.
- `zₒ::AbstractVecOrMat`: The initial latent variable.
- `Gₒ⁻¹::AbstractArray`: The initial inverse of the Riemannian metric tensor.
- `logdetGₒ::Union{<:Number,AbstractVector}`: The log determinant of the initial Riemannian metric tensor. If vector, each element represents the log determinant of the metric tensor at the corresponding column of `zₒ`.
- `decoder::AbstractVariationalDecoder`: The decoder of the RHVAE model.
- `decoder_output::NamedTuple`: The output of the decoder.
- `metric_param::NamedTuple`: The parameters of the metric tensor.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}`: The step size for the leapfrog steps in the HMC algorithm. This can be a scalar or an array. Default is `Float32(1E-4)`.
- `K::Int`: The number of leapfrog steps to perform in the Hamiltonian Monte Carlo (HMC) algorithm. Default is 3.
- `βₒ::Number`: The initial inverse temperature for the tempering schedule. Default is 0.3f0.
- `steps::Int`: The number of fixed-point iterations to perform. Default is 3.
- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments to be passed to the `∇hamiltonian` function. Default is a `NamedTuple` with `reconstruction_loglikelihood`, `position_logprior`, `momentum_logprior`, and `G_inv`.
- `tempering_schedule::Function`: The function to compute the inverse temperature at each step in the HMC algorithm. Defaults to `quadratic_tempering`. This function must take three arguments: first, `βₒ`, the initial inverse temperature; second, `k`, the current step in the tempering schedule; and third, `K`, the total number of steps in the tempering schedule.
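As an illustration of the `(βₒ, k, K)` contract, the sketch below defines a hypothetical schedule `my_quadratic_tempering`. It implements the quadratic schedule proposed for the original HVAE (Caterini et al., 2018), 1/√βₖ = (1 - 1/√βₒ)(k/K)² + 1/√βₒ, which moves the inverse temperature from βₒ at k = 0 to 1 at k = K. Whether `quadratic_tempering` in `AutoEncoderToolkit.jl` uses exactly this form is an assumption.

```julia
# Hypothetical schedule with the (βₒ, k, K) signature required of `tempering_schedule`.
# Quadratic form assumed from the HVAE literature, not read from the library source:
#   1/√βₖ = (1 - 1/√βₒ)(k/K)² + 1/√βₒ
function my_quadratic_tempering(βₒ::Number, k::Int, K::Int)
    inv_sqrt_βₒ = 1 / sqrt(βₒ)
    return ((1 - inv_sqrt_βₒ) * (k / K)^2 + inv_sqrt_βₒ)^(-2)
end

# Inverse temperatures for βₒ = 0.3 and K = 3 steps (k = 0, 1, 2, 3)
βs = [my_quadratic_tempering(0.3, k, 3) for k in 0:3]
```

Any function with the same three-argument signature that returns an inverse temperature could be passed as `tempering_schedule` in the same way.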

**Returns**

- A `NamedTuple` with the following keys:
  - `z_init`: The initial latent variable.
  - `ρ_init`: The initial momentum variable.
  - `Ginv_init`: The initial inverse of the Riemannian metric tensor.
  - `logdetG_init`: The initial log determinant of the Riemannian metric tensor.
  - `z_final`: The final latent variable after `K` leapfrog steps.
  - `ρ_final`: The final momentum variable after `K` leapfrog steps.
  - `Ginv_final`: The final inverse of the Riemannian metric tensor after `K` leapfrog steps.
  - `logdetG_final`: The final log determinant of the Riemannian metric tensor after `K` leapfrog steps.
- The decoder output at the final latent variable is also returned. Note: this is not part of the `NamedTuple` above but a separate output.

**Description**

The function first samples a random momentum variable `γₒ` from a standard normal distribution and scales it by the inverse square root of the initial inverse temperature `βₒ` to obtain the initial momentum variable `ρₒ = γₒ / √βₒ`. Then, it performs `K` leapfrog steps, each followed by a tempering step, to generate a new sample from the latent space.
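The momentum initialization and tempering described above can be sketched as follows. This is a simplified illustration: the leapfrog update between temperings is left as a placeholder comment, the inverse temperatures are supplied as a precomputed vector, and the per-step rescaling ρ ← √(βₖ₋₁/βₖ) ρ follows the tempering scheme of the original HVAE. Whether `AutoEncoderToolkit.jl` applies it in exactly this form is an assumption, and `tempered_momenta` is a hypothetical name.

```julia
using Random

# Tempered momentum initialization followed by per-step tempering, assuming the
# inverse temperatures βs = [β₀, β₁, …, β_K] are precomputed by some schedule.
function tempered_momenta(βs::AbstractVector, d::Int; rng=Random.default_rng())
    γₒ = randn(rng, d)            # γₒ ~ N(0, I), as in the description above
    ρ = γₒ ./ sqrt(βs[1])         # ρₒ = γₒ / √βₒ
    trajectory = [copy(ρ)]
    for k in 2:length(βs)
        # (a generalized leapfrog step would update z and ρ here; omitted)
        ρ = sqrt(βs[k-1] / βs[k]) .* ρ   # tempering: ρ ← √(βₖ₋₁ / βₖ) ρ
        push!(trajectory, copy(ρ))
    end
    return γₒ, trajectory
end

βs = [0.3, 0.45, 0.8, 1.0]        # hypothetical increasing schedule ending at 1
γₒ, ρs = tempered_momenta(βs, 2; rng=MersenneTwister(1))
```

Without the omitted leapfrog updates the rescalings telescope, so the final momentum equals γₒ/√β_K; with β_K = 1 this recovers γₒ, which is a quick sanity check of the schedule's endpoints.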

**Note**

Ensure that the input data `x` and the initial latent variable `zₒ` match the dimensionality expected by the RHVAE model.

```julia
general_leapfrog_tempering_step(
    x::AbstractArray,
    zₒ::AbstractVecOrMat,
    rhvae::RHVAE;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
    ),
    G_inv::Function=G_inv,
    tempering_schedule::Function=quadratic_tempering,
)
```

Combines the leapfrog and tempering steps into a single function for the Riemannian Hamiltonian Variational Autoencoder (RHVAE).

**Arguments**

- `x::AbstractArray`: The data to be processed. If `Array`, the last dimension must be of size 1.
- `zₒ::AbstractVecOrMat`: The initial latent variable.
- `rhvae::RHVAE`: The `RHVAE` instance.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}`: The step size for the leapfrog steps in the HMC algorithm. This can be a scalar or an array. Default is `Float32(1E-4)`.
- `K::Int`: The number of leapfrog steps to perform in the Hamiltonian Monte Carlo (HMC) algorithm. Default is 3.
- `βₒ::Number`: The initial inverse temperature for the tempering schedule. Default is 0.3f0.
- `steps::Int`: The number of fixed-point iterations to perform. Default is 3.
- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments to be passed to the `∇hamiltonian` function. Default is a `NamedTuple` with `reconstruction_loglikelihood`, `position_logprior`, and `momentum_logprior`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor. Default is `G_inv`.
- `tempering_schedule::Function`: The function to compute the inverse temperature at each step in the HMC algorithm. Defaults to `quadratic_tempering`. This function must take three arguments: first, `βₒ`, the initial inverse temperature; second, `k`, the current step in the tempering schedule; and third, `K`, the total number of steps in the tempering schedule.

**Returns**

- A `NamedTuple` with the following keys:
  - `z_init`: The initial latent variable.
  - `ρ_init`: The initial momentum variable.
  - `Ginv_init`: The initial inverse of the Riemannian metric tensor.
  - `z_final`: The final latent variable after `K` leapfrog steps.
  - `ρ_final`: The final momentum variable after `K` leapfrog steps.
  - `Ginv_final`: The final inverse of the Riemannian metric tensor after `K` leapfrog steps.
- The decoder output at the final latent variable is also returned. Note: this is not part of the `NamedTuple` above but a separate output.

**Description**

The function first samples a random momentum variable `γₒ` from a standard normal distribution and scales it by the inverse square root of the initial inverse temperature `βₒ` to obtain the initial momentum variable `ρₒ = γₒ / √βₒ`. Then, it performs `K` leapfrog steps, each followed by a tempering step, to generate a new sample from the latent space.

**Note**

Ensure that the input data `x` and the initial latent variable `zₒ` match the dimensionality expected by the RHVAE model.

`AutoEncoderToolkit.RHVAEs._log_p̄` — Function

```julia
_log_p̄(
    x::AbstractArray,
    rhvae::RHVAE{VAE{E,D}},
    rhvae_outputs::NamedTuple;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    position_logprior::Function=spherical_logprior,
    momentum_logprior::Function=riemannian_logprior,
    prefactor::AbstractArray=ones(Float32, 3),
)
```

This is an internal function used in `riemannian_hamiltonian_elbo` to compute the numerator of the unbiased estimator of the marginal likelihood. The function computes the sum of the log likelihood of the data given the latent variables, the log prior of the latent variables, and the log prior of the momentum variables:

`log p̄ = log p(x | zₖ) + log p(zₖ) + log p(ρₖ(zₖ))`
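As an illustration of how the three terms and `prefactor` combine, the sketch below evaluates log p̄ for a single data point. The helper names are hypothetical, the reconstruction log-likelihood is passed in as a precomputed number, and it assumes that `spherical_logprior` is a standard normal prior and that `riemannian_logprior` is the density of N(0, G(zₖ)); this page does not confirm either assumption.

```julia
using LinearAlgebra

# Assumed standard-normal prior on the latent position: log N(z; 0, I)
spherical_logprior_sketch(z) = -0.5 * (length(z) * log(2π) + dot(z, z))

# Assumed Riemannian momentum prior log N(ρ; 0, G(z)), given G⁻¹ and log det G
riemannian_logprior_sketch(ρ, Ginv, logdetG) =
    -0.5 * (length(ρ) * log(2π) + logdetG + dot(ρ, Ginv * ρ))

# log p̄ = λ₁ log p(x | zₖ) + λ₂ log p(zₖ) + λ₃ log p(ρₖ | zₖ), with λ = prefactor
function log_p̄_sketch(loglik, zₖ, ρₖ, Ginvₖ, logdetGₖ; prefactor=ones(3))
    return prefactor[1] * loglik +
           prefactor[2] * spherical_logprior_sketch(zₖ) +
           prefactor[3] * riemannian_logprior_sketch(ρₖ, Ginvₖ, logdetGₖ)
end

# Example with hypothetical values for a 2-dimensional latent space
value = log_p̄_sketch(-1.0, zeros(2), zeros(2), Matrix(1.0I, 2, 2), 0.0)
```

Setting an entry of `prefactor` to zero drops the corresponding term, which is one way to inspect each contribution separately.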

**Arguments**

- `x::AbstractArray`: The input data. If `Array`, the last dimension must index the data points.
- `rhvae::RHVAE{<:VAE{<:AbstractGaussianEncoder,<:AbstractGaussianLogDecoder}}`: The Riemannian Hamiltonian Variational Autoencoder (RHVAE) model.
- `rhvae_outputs::NamedTuple`: The outputs of the RHVAE, including the final latent variables `zₖ` and the final momentum variables `ρₖ`.

**Optional Keyword Arguments**

- `reconstruction_loglikelihood::Function`: The function to compute the log likelihood of the data given the latent variables. Default is `decoder_loglikelihood`.
- `position_logprior::Function`: The function to compute the log prior of the latent variables. Default is `spherical_logprior`.
- `momentum_logprior::Function`: The function to compute the log prior of the momentum variables. Default is `riemannian_logprior`.
- `prefactor::AbstractArray`: A 3-element array to scale the log likelihood, the log prior of the latent variables, and the log prior of the momentum variables. Default is an array of ones.

**Returns**

- `log_p̄::AbstractVector`: The first term of the log of the unbiased estimator of the marginal likelihood for each data point.

**Note**

This is an internal function and should not be called directly. It is used as part of the `riemannian_hamiltonian_elbo` function.

`AutoEncoderToolkit.RHVAEs._log_q̄` — Function

```julia
_log_q̄(
    rhvae::RHVAE,
    rhvae_outputs::NamedTuple,
    βₒ::Number;
    momentum_logprior::Function=riemannian_logprior,
    prefactor::AbstractArray=ones(Float32, 3),
)
```

This is an internal function used in `riemannian_hamiltonian_elbo` to compute the second term of the unbiased estimator of the marginal likelihood. The function computes the sum of the log posterior of the initial latent variables and the log prior of the initial momentum variables, minus a term that depends on the dimensionality `d` of the latent space and the initial inverse temperature `βₒ`:

`log q̄ = log q(zₒ) + log p(ρₒ) - d/2 log(βₒ)`
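A minimal sketch of this formula for a single data point, with each term weighted by `prefactor`. It assumes a diagonal Gaussian encoder parameterized by a mean `μ` and log standard deviation `logσ`, and a momentum prior N(0, G(zₒ)) given through G⁻¹ and log det G; the helper names are hypothetical and the inputs are made up.

```julia
using LinearAlgebra

# Assumed diagonal-Gaussian encoder density log N(z; μ, diag(σ²)), with σ = exp(logσ)
encoder_logpdf_sketch(z, μ, logσ) =
    -0.5 * sum(log(2π) .+ 2 .* logσ .+ ((z .- μ) ./ exp.(logσ)) .^ 2)

# Assumed Riemannian momentum prior log N(ρ; 0, G(z)), given G⁻¹ and log det G
momentum_logpdf_sketch(ρ, Ginv, logdetG) =
    -0.5 * (length(ρ) * log(2π) + logdetG + dot(ρ, Ginv * ρ))

# log q̄ = λ₁ log q(zₒ) + λ₂ log p(ρₒ) - λ₃ (d/2) log βₒ, with λ = prefactor
function log_q̄_sketch(zₒ, μ, logσ, ρₒ, Ginvₒ, logdetGₒ, βₒ; prefactor=ones(3))
    d = length(zₒ)
    return prefactor[1] * encoder_logpdf_sketch(zₒ, μ, logσ) +
           prefactor[2] * momentum_logpdf_sketch(ρₒ, Ginvₒ, logdetGₒ) -
           prefactor[3] * (d / 2) * log(βₒ)
end

Id = Matrix(1.0I, 2, 2)
value_β1 = log_q̄_sketch(zeros(2), zeros(2), zeros(2), zeros(2), Id, 0.0, 1.0)
value_β03 = log_q̄_sketch(zeros(2), zeros(2), zeros(2), zeros(2), Id, 0.0, 0.3)
```

With βₒ = 1 the tempering term vanishes; lowering βₒ below 1 makes -(d/2) log βₒ positive and raises log q̄.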

**Arguments**

- `rhvae::RHVAE`: The Riemannian Hamiltonian Variational Autoencoder (RHVAE) model.
- `rhvae_outputs::NamedTuple`: The outputs of the RHVAE, including the initial latent variables `zₒ` and the initial momentum variables `ρₒ`.
- `βₒ::Number`: The initial inverse temperature for the tempering steps.

**Optional Keyword Arguments**

- `momentum_logprior::Function`: The function to compute the log prior of the momentum variables. Default is `riemannian_logprior`.
- `prefactor::AbstractArray`: A 3-element array to scale the log posterior of the initial latent variables, the log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

**Returns**

- `log_q̄::Vector`: The second term of the log of the unbiased estimator of the marginal likelihood for each data point.

**Note**

This is an internal function and should not be called directly. It is used as part of the `riemannian_hamiltonian_elbo` function.

`AutoEncoderToolkit.RHVAEs.riemannian_hamiltonian_elbo` — Function

```julia
riemannian_hamiltonian_elbo(
    rhvae::RHVAE,
    metric_param::NamedTuple,
    x::AbstractArray;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
        G_inv=G_inv,
    ),
    tempering_schedule::Function=quadratic_tempering,
    return_outputs::Bool=false,
    logp_prefactor::AbstractArray=ones(Float32, 3),
    logq_prefactor::AbstractArray=ones(Float32, 3),
)
```

Compute the Riemannian Hamiltonian Monte Carlo (RHMC) estimate of the evidence lower bound (ELBO) for a Riemannian Hamiltonian Variational Autoencoder (RHVAE).

This function takes as input an RHVAE, a NamedTuple of metric parameters, and an array of input data `x`. It performs `K` RHMC steps with a leapfrog integrator and a tempering schedule to estimate the ELBO. The ELBO is computed as the difference between `log p̄` and `log q̄`, averaged over the data points:

$$\text{elbo} = \operatorname{mean}\left(\log \bar{p} - \log \bar{q}\right)$$
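The averaging step is simple but worth pinning down: `log p̄` and `log q̄` are per-data-point vectors (as the `_log_p̄` and `_log_q̄` entries describe), and the ELBO is the mean of their difference. A minimal sketch with a hypothetical helper `elbo_sketch` and made-up values:

```julia
using Statistics

# Batch ELBO: mean over data points of log p̄ - log q̄.
# The per-sample values below are hypothetical numbers, not model outputs.
elbo_sketch(log_p̄::AbstractVector, log_q̄::AbstractVector) = mean(log_p̄ .- log_q̄)

log_p̄_values = [-10.0, -12.0, -11.0]
log_q̄_values = [-3.0, -4.0, -2.0]
elbo = elbo_sketch(log_p̄_values, log_q̄_values)   # mean of [-7, -8, -9]
```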

**Arguments**

- `rhvae::RHVAE`: The RHVAE used to encode the input data and decode the latent space.
- `metric_param::NamedTuple`: The parameters used to compute the metric tensor.
- `x::AbstractArray`: The input data. If `Array`, the last dimension must contain each of the data points.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}`: The step size for the leapfrog integrator (default is `Float32(1E-4)`).
- `K::Int`: The number of RHMC steps (default is 3).
- `βₒ::Number`: The initial inverse temperature (default is 0.3).
- `steps::Int`: The number of fixed-point iterations performed within each generalized leapfrog step (default is 3).
- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments to be passed to the `∇hamiltonian` function. Defaults to a `NamedTuple` with `:reconstruction_loglikelihood` set to `decoder_loglikelihood`, `:position_logprior` set to `spherical_logprior`, `:momentum_logprior` set to `riemannian_logprior`, and `:G_inv` set to `G_inv`.
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor, passed as the `G_inv` entry of `∇H_kwargs`. Defaults to `G_inv`.
- `tempering_schedule::Function`: The tempering schedule function used in the RHMC (default is `quadratic_tempering`).
- `return_outputs::Bool`: Whether to return the outputs of the RHVAE. Defaults to `false`. NOTE: Returning the outputs avoids computing the forward pass twice when the loss function includes regularization.
- `logp_prefactor::AbstractArray`: A 3-element array to scale the log likelihood, the log prior of the latent variables, and the log prior of the momentum variables. Default is an array of ones.
- `logq_prefactor::AbstractArray`: A 3-element array to scale the log posterior of the initial latent variables, the log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

**Returns**

- `elbo::Number`: The RHMC estimate of the ELBO. If `return_outputs` is `true`, also returns the outputs of the RHVAE.

```julia
riemannian_hamiltonian_elbo(
    rhvae::RHVAE,
    x::AbstractVector;
    K::Int=3,
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    βₒ::Number=0.3f0,
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
        G_inv=G_inv,
    ),
    tempering_schedule::Function=quadratic_tempering,
    return_outputs::Bool=false,
    logp_prefactor::AbstractArray=ones(Float32, 3),
    logq_prefactor::AbstractArray=ones(Float32, 3),
)
```

Compute the Riemannian Hamiltonian Monte Carlo (RHMC) estimate of the evidence lower bound (ELBO) for a Riemannian Hamiltonian Variational Autoencoder (RHVAE).

This function takes as input an RHVAE and a vector of input data `x`. It performs `K` RHMC steps with a leapfrog integrator and a tempering schedule to estimate the ELBO. The ELBO is computed as the difference between `log p̄` and `log q̄`:

$$\text{elbo} = \operatorname{mean}\left(\log \bar{p} - \log \bar{q}\right)$$

**Arguments**

- `rhvae::RHVAE`: The RHVAE used to encode the input data and decode the latent space.
- `x::AbstractVector`: The input data.

**Optional Keyword Arguments**

- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments to be passed to the `∇hamiltonian` function. Defaults to a `NamedTuple` with `:reconstruction_loglikelihood` set to `decoder_loglikelihood`, `:position_logprior` set to `spherical_logprior`, `:momentum_logprior` set to `riemannian_logprior`, and `:G_inv` set to `G_inv`.
- `K::Int`: The number of RHMC steps (default is 3).
- `ϵ::Union{<:Number,<:AbstractVector}`: The step size for the leapfrog integrator (default is `Float32(1E-4)`).
- `βₒ::Number`: The initial inverse temperature (default is 0.3).
- `steps::Int`: The number of fixed-point iterations performed within each generalized leapfrog step (default is 3).
- `G_inv::Function`: The function to compute the inverse of the Riemannian metric tensor, passed as the `G_inv` entry of `∇H_kwargs` (default is `G_inv`).
- `tempering_schedule::Function`: The tempering schedule function used in the RHMC (default is `quadratic_tempering`).
- `return_outputs::Bool`: Whether to return the outputs of the RHVAE. Defaults to `false`. NOTE: Returning the outputs avoids computing the forward pass twice when the loss function includes regularization.
- `logp_prefactor::AbstractArray`: A 3-element array to scale the log likelihood, the log prior of the latent variables, and the log prior of the momentum variables. Default is an array of ones.
- `logq_prefactor::AbstractArray`: A 3-element array to scale the log posterior of the initial latent variables, the log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

**Returns**

- `elbo::Number`: The RHMC estimate of the ELBO. If `return_outputs` is `true`, also returns the outputs of the RHVAE.
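Because the kinetic energy of an RHVAE depends on position through the metric, the leapfrog integrator used by Riemannian HMC is typically a *generalized* (implicit) leapfrog step. The Julia sketch below shows one such step on a toy 2D problem with the Hamiltonian H(z, ρ) = -log p(z) + ½ log det G(z) + ½ ρᵀ G(z)⁻¹ ρ. Everything in it is hypothetical and for illustration only: the target `logp`, the diagonal metric `G`, the finite-difference helper `fd_grad` (used only to avoid pulling in an AD package), and the fixed-point count `n_fixed_point` are not part of `AutoEncoderToolkit.jl`, and this section does not say how the package's `steps` keyword maps onto them.

```julia
using LinearAlgebra

# Hypothetical toy target: log density of a 2D standard normal (up to a constant).
logp(z) = -0.5 * sum(abs2, z)

# Hypothetical position-dependent metric: diagonal and positive definite.
G(z) = Diagonal(1 .+ z .^ 2)

# Riemannian Hamiltonian (additive constants dropped):
#   H(z, ρ) = -log p(z) + ½ log det G(z) + ½ ρᵀ G(z)⁻¹ ρ
H(z, ρ) = -logp(z) + 0.5 * logdet(G(z)) + 0.5 * dot(ρ, G(z) \ ρ)

# Central finite-difference gradient; keeps the sketch free of AD packages.
function fd_grad(f, x; h=1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

∇z_H(z, ρ) = fd_grad(zz -> H(zz, ρ), z)  # ∂H/∂z, computed numerically
∇ρ_H(z, ρ) = G(z) \ ρ                     # ∂H/∂ρ = G(z)⁻¹ ρ, computed analytically

# One generalized leapfrog step. The two implicit updates are solved with a
# fixed number of fixed-point iterations; ϵ may be a scalar or a vector.
function generalized_leapfrog(z, ρ, ϵ; n_fixed_point=3)
    # Implicit momentum half step: ρ½ = ρ - (ϵ/2) ∇z H(z, ρ½)
    ρ_half = copy(ρ)
    for _ in 1:n_fixed_point
        ρ_half = ρ .- (ϵ ./ 2) .* ∇z_H(z, ρ_half)
    end
    # Implicit position full step: z′ = z + (ϵ/2) [∇ρ H(z, ρ½) + ∇ρ H(z′, ρ½)]
    z_new = copy(z)
    for _ in 1:n_fixed_point
        z_new = z .+ (ϵ ./ 2) .* (∇ρ_H(z, ρ_half) .+ ∇ρ_H(z_new, ρ_half))
    end
    # Explicit momentum half step
    ρ_new = ρ_half .- (ϵ ./ 2) .* ∇z_H(z_new, ρ_half)
    return z_new, ρ_new
end

z0, ρ0 = [0.5, -0.3], [0.2, 0.1]
z1, ρ1 = generalized_leapfrog(z0, ρ0, 1e-2)
ΔH = abs(H(z1, ρ1) - H(z0, ρ0))  # small for a small step size
```

Since `ϵ` is broadcast, the same function accepts a scalar or a per-dimension vector of step sizes, mirroring the `ϵ::Union{<:Number,<:AbstractVector}` keyword.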

```
riemannian_hamiltonian_elbo(
    rhvae::RHVAE,
    metric_param::NamedTuple,
    x_in::AbstractArray,
    x_out::AbstractArray;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
        G_inv=G_inv,
    ),
    tempering_schedule::Function=quadratic_tempering,
    return_outputs::Bool=false,
    logp_prefactor::AbstractArray=ones(Float32, 3),
    logq_prefactor::AbstractArray=ones(Float32, 3),
)
```

Compute the Riemannian Hamiltonian Monte Carlo (RHMC) estimate of the evidence lower bound (ELBO) for a Riemannian Hamiltonian Variational Autoencoder (RHVAE).

This function takes as input an RHVAE, a NamedTuple of metric parameters `metric_param`, input data `x_in`, and target data `x_out`. It performs `K` RHMC steps with a leapfrog integrator and a tempering schedule to estimate the ELBO. The ELBO is computed from the difference between `log p̄` and `log q̄` as

`elbo = mean(log p̄ - log q̄)`

**Arguments**

- `rhvae::RHVAE`: The RHVAE used to encode the input data and decode the latent space.
- `metric_param::NamedTuple`: The parameters used to compute the metric tensor.
- `x_in::AbstractArray`: Input data to the RHVAE encoder. The last dimension indexes the samples in a batch.
- `x_out::AbstractArray`: Target data used to compute the reconstruction error. The last dimension indexes the samples in a batch.

**Optional Keyword Arguments**

- `ϵ::Union{<:Number,<:AbstractVector}`: The step size for the leapfrog integrator (default is `Float32(1E-4)`).
- `K::Int`: The number of RHMC steps (default is 3).
- `βₒ::Number`: The initial inverse temperature (default is 0.3).
- `steps::Int`: The number of leapfrog steps (default is 3).
- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments passed to the `∇hamiltonian` function. Defaults to a NamedTuple with `:reconstruction_loglikelihood` set to `decoder_loglikelihood`, `:position_logprior` set to `spherical_logprior`, `:momentum_logprior` set to `riemannian_logprior`, and `:G_inv` set to `G_inv`.
- `G_inv::Function`: The function used to compute the inverse of the Riemannian metric tensor (default is `G_inv`). It is supplied through `∇H_kwargs` rather than as a separate keyword argument.
- `tempering_schedule::Function`: The tempering schedule function used in the RHMC (default is `quadratic_tempering`).
- `return_outputs::Bool`: Whether to also return the outputs of the RHVAE (default is `false`). Returning them avoids computing the forward pass twice when the loss function includes regularization.
- `logp_prefactor::AbstractArray`: A 3-element array that scales the log likelihood, the log prior of the latent variables, and the log prior of the momentum variables, respectively. Default is an array of ones.
- `logq_prefactor::AbstractArray`: A 3-element array that scales the log posterior of the initial latent variables, the log prior of the initial momentum variables, and the tempering Jacobian term, respectively. Default is an array of ones.

**Returns**

- `elbo::Number`: The RHMC estimate of the ELBO. If `return_outputs` is `true`, also returns the outputs of the RHVAE.
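For intuition about the `tempering_schedule` and `βₒ` keywords, here is a minimal Julia sketch of a quadratic schedule of the form used in the original HVAE formulation, which moves the inverse temperature from βₒ at step 0 to 1 at step `K`. It is an assumption that `quadratic_tempering` follows exactly this formula; `quadratic_tempering_sketch` is a hypothetical stand-in, not the package function.

```julia
# Hypothetical quadratic schedule:
#   √βₖ = 1 / ((1 - 1/√βₒ) k²/K² + 1/√βₒ),  so β₀ = βₒ and β_K = 1
function quadratic_tempering_sketch(βₒ::Real, k::Integer, K::Integer)
    inv_sqrt_β0 = 1 / sqrt(βₒ)
    sqrt_βk = 1 / ((1 - inv_sqrt_β0) * k^2 / K^2 + inv_sqrt_β0)
    return sqrt_βk^2
end

K = 3
βs = [quadratic_tempering_sketch(0.3, k, K) for k in 0:K]  # βs[1] ≈ 0.3, βs[end] ≈ 1.0

# In tempered HVAE-style samplers the momentum is typically rescaled by
# √(βₖ₋₁/βₖ) ≤ 1 after step k, gradually "cooling" the dynamics.
momentum_scales = sqrt.(βs[1:end-1] ./ βs[2:end])
```

Starting from βₒ < 1 lets the early steps explore with a "hotter" momentum before the schedule anneals to the target at β = 1.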

```
riemannian_hamiltonian_elbo(
    rhvae::RHVAE,
    x_in::AbstractArray,
    x_out::AbstractArray;
    K::Int=3,
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    βₒ::Number=0.3f0,
    steps::Int=3,
    ∇H_kwargs::Union{NamedTuple,Dict}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        position_logprior=spherical_logprior,
        momentum_logprior=riemannian_logprior,
        G_inv=G_inv,
    ),
    tempering_schedule::Function=quadratic_tempering,
    return_outputs::Bool=false,
    logp_prefactor::AbstractArray=ones(Float32, 3),
    logq_prefactor::AbstractArray=ones(Float32, 3),
)
```

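Compute the Riemannian Hamiltonian Monte Carlo (RHMC) estimate of the evidence lower bound (ELBO) for a Riemannian Hamiltonian Variational Autoencoder (RHVAE).

This function takes as input an RHVAE, input data `x_in`, and target data `x_out`. It performs `K` RHMC steps with a leapfrog integrator and a tempering schedule to estimate the ELBO. The ELBO is computed from the difference between `log p̄` and `log q̄` as

`elbo = mean(log p̄ - log q̄)`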

**Arguments**

- `rhvae::RHVAE`: The RHVAE used to encode the input data and decode the latent space.
- `x_in::AbstractArray`: Input data to the RHVAE encoder. The last dimension indexes the samples in a batch.
- `x_out::AbstractArray`: Target data used to compute the reconstruction error. The last dimension indexes the samples in a batch.

**Optional Keyword Arguments**

- `∇H_kwargs::Union{NamedTuple,Dict}`: Additional keyword arguments passed to the `∇hamiltonian` function. Defaults to a NamedTuple with `:reconstruction_loglikelihood` set to `decoder_loglikelihood`, `:position_logprior` set to `spherical_logprior`, `:momentum_logprior` set to `riemannian_logprior`, and `:G_inv` set to `G_inv`.
- `K::Int`: The number of RHMC steps (default is 3).
- `ϵ::Union{<:Number,<:AbstractVector}`: The step size for the leapfrog integrator (default is `Float32(1E-4)`).
- `βₒ::Number`: The initial inverse temperature (default is 0.3).
- `steps::Int`: The number of leapfrog steps (default is 3).
- `G_inv::Function`: The function used to compute the inverse of the Riemannian metric tensor (default is `G_inv`). It is supplied through `∇H_kwargs` rather than as a separate keyword argument.
- `tempering_schedule::Function`: The tempering schedule function used in the RHMC (default is `quadratic_tempering`).
- `return_outputs::Bool`: Whether to also return the outputs of the RHVAE (default is `false`). Returning them avoids computing the forward pass twice when the loss function includes regularization.
- `logp_prefactor::AbstractArray`: A 3-element array that scales the log likelihood, the log prior of the latent variables, and the log prior of the momentum variables, respectively. Default is an array of ones.
- `logq_prefactor::AbstractArray`: A 3-element array that scales the log posterior of the initial latent variables, the log prior of the initial momentum variables, and the tempering Jacobian term, respectively. Default is an array of ones.

**Returns**

- `elbo::Number`: The RHMC estimate of the ELBO. If `return_outputs` is `true`, also returns the outputs of the RHVAE.
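The ELBO formula `elbo = mean(log p̄ - log q̄)`, the `logp_prefactor`/`logq_prefactor` keywords, and the convention that the last dimension of `x_in`/`x_out` indexes samples can be illustrated with a small Julia sketch. The functions `per_sample_sum` and `elbo_sketch` are hypothetical, the per-sample log terms are made-up numbers, and the sketch simply assumes that each prefactor multiplies its matching term before the batch mean is taken; it is not the package's internal implementation.

```julia
using Statistics: mean

# Sum over every dimension except the last one, which indexes the batch samples.
per_sample_sum(A::AbstractArray) = vec(sum(A; dims=1:ndims(A)-1))

# Hypothetical ELBO from three per-sample log p̄ terms
# (log likelihood, log prior of z, log prior of ρ) and three per-sample log q̄ terms
# (log posterior of z₀, log prior of ρ₀, tempering Jacobian term).
function elbo_sketch(logp_terms, logq_terms;
                     logp_prefactor=ones(3), logq_prefactor=ones(3))
    # Weighted sum of the three terms, sample by sample
    log_pbar = sum(logp_prefactor[i] .* logp_terms[i] for i in 1:3)
    log_qbar = sum(logq_prefactor[i] .* logq_terms[i] for i in 1:3)
    # Average the difference over the batch
    return mean(log_pbar .- log_qbar)
end

# Made-up per-sample terms for a batch of two samples
logp_terms = ([-10.0, -12.0], [-1.0, -1.5], [-0.5, -0.5])
logq_terms = ([-2.0, -2.5], [-0.5, -0.5], [0.0, 0.0])

elbo_sketch(logp_terms, logq_terms)                                  # -10.0
elbo_sketch(logp_terms, logq_terms; logp_prefactor=[1.0, 0.5, 1.0])  # -9.375
```

Down-weighting one prefactor entry, as in the second call, changes only its term's contribution, which is the kind of reweighting these keywords are described as controlling.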

## Default initializations

`AutoEncoderToolkit.jl` provides default initializations for both the metric tensor network and the RHVAE. Although less flexible than defining your own networks, these can serve as a good starting point for your experiments.

`AutoEncoderToolkit.RHVAEs.MetricChain` — Method

```
MetricChain(
    n_input::Int,
    n_latent::Int,
    metric_neurons::Vector{<:Int},
    metric_activation::Vector{<:Function},
    output_activation::Function;
    init::Function=Flux.glorot_uniform
) -> MetricChain
```

Construct a `MetricChain` for computing the Riemannian metric tensor in the latent space.

**Arguments**

- `n_input::Int`: The number of input features.
- `n_latent::Int`: The dimension of the latent space.
- `metric_neurons::Vector{<:Int}`: The number of neurons in each hidden layer of the MLP.
- `metric_activation::Vector{<:Function}`: The activation function for each hidden layer of the MLP.
- `output_activation::Function`: The activation function for the output layer.
- `init::Function`: The initialization function for the weights in the layers (default is `Flux.glorot_uniform`).

**Returns**

- `MetricChain`: A `MetricChain` object that includes the MLP and two dense layers for computing the elements of a lower-triangular matrix used to compute the Riemannian metric tensor in latent space.
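A common way to turn two such dense-layer outputs into a valid metric is to assemble a lower-triangular matrix L and form M = L Lᵀ, which is symmetric positive definite whenever the diagonal of L is nonzero. The Julia sketch below assumes that the `diag` output fills the diagonal (exponentiated here as one simple way to keep it positive) and that the `lower` output fills the strictly lower entries in column-major order; `assemble_lower` and this layout are hypothetical, not the internal layout of `MetricChain`.

```julia
using LinearAlgebra

# Hypothetical assembly of a lower-triangular factor L from two vectors:
# `diag_vals` (length n) fills the diagonal and `lower_vals` (length n(n-1)/2)
# fills the strictly lower part in column-major order (an assumed layout).
function assemble_lower(diag_vals::AbstractVector, lower_vals::AbstractVector)
    n = length(diag_vals)
    length(lower_vals) == n * (n - 1) ÷ 2 ||
        throw(ArgumentError("expected $(n * (n - 1) ÷ 2) strictly-lower entries"))
    L = zeros(promote_type(eltype(diag_vals), eltype(lower_vals)), n, n)
    idx = 1
    for j in 1:n, i in (j + 1):n  # walk the strictly lower part column by column
        L[i, j] = lower_vals[idx]
        idx += 1
    end
    for i in 1:n
        L[i, i] = diag_vals[i]
    end
    return LowerTriangular(L)
end

diag_raw = [0.1, -0.2]  # stand-in for the `diag` layer output
lower_raw = [0.3]       # stand-in for the `lower` layer output (n(n-1)/2 = 1)
L = assemble_lower(exp.(diag_raw), lower_raw)  # exp keeps the diagonal positive
M = Symmetric(Matrix(L * L'))                  # M = L Lᵀ is symmetric positive definite
```

Parameterizing the factor L rather than M itself guarantees positive definiteness by construction, so no constraint has to be enforced during training.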

`AutoEncoderToolkit.RHVAEs.RHVAE` — Method

```
RHVAE(
    vae::VAE,
    metric_chain::MetricChain,
    centroids_data::AbstractArray,
    T::Number,
    λ::Number
)
```

Construct a Riemannian Hamiltonian Variational Autoencoder (RHVAE) from a standard VAE and a metric chain.

**Arguments**

- `vae::VAE`: A standard Variational Autoencoder (VAE) model.
- `metric_chain::MetricChain`: The `MetricChain` network used to compute the Riemannian metric tensor for the Riemannian Hamiltonian Monte Carlo (RHMC) sampler.
- `centroids_data::AbstractArray`: An array of data centroids. Each column represents a centroid.
- `T::Number`: The temperature parameter for the inverse metric tensor.
- `λ::Number`: The regularization parameter for the inverse metric tensor.

**Returns**

- A new `RHVAE` object.

**Description**

The constructor initializes the latent centroids and the metric tensor `M` to their default values. The latent centroids are initialized to a zero matrix of the same size as `centroids_data`, and `M` is initialized to a 3D array of identity matrices, one for each centroid.
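To see how `T`, `λ`, the latent centroids, and `M` shape the geometry, the Julia sketch below evaluates an inverse metric of the form given by Chadebec et al. (2020), G⁻¹(z) = Σᵢ Mᵢ exp(-‖z − cᵢ‖²/T²) + λI, starting from zero latent centroids and one identity matrix per centroid. `G_inv_sketch` is a hypothetical stand-in for the package's `G_inv`, and storing the latent centroids as the columns of an `n_latent × n_centroids` matrix is an assumption made for illustration.

```julia
using LinearAlgebra

# Inverse metric of the form in Chadebec et al. (2020):
#   G⁻¹(z) = Σᵢ Mᵢ exp(-‖z - cᵢ‖² / T²) + λ I
# `centroids_latent` holds one latent centroid per column (assumed layout) and
# `M[:, :, i]` is the matrix attached to centroid i.
function G_inv_sketch(z::AbstractVector, centroids_latent::AbstractMatrix,
                      M::AbstractArray{<:Real,3}, T::Real, λ::Real)
    d = length(z)
    Ginv = Matrix{Float64}(λ * I, d, d)  # regularization term λ I
    for i in axes(centroids_latent, 2)
        # Gaussian-kernel weight of centroid i; T controls how fast it decays
        w = exp(-sum(abs2, z .- centroids_latent[:, i]) / T^2)
        Ginv .+= w .* M[:, :, i]
    end
    return Ginv
end

# Zero latent centroids and an identity matrix per centroid
n_latent, n_centroids = 2, 5
centroids_latent = zeros(n_latent, n_centroids)
M = repeat(Matrix{Float64}(I, n_latent, n_latent), 1, 1, n_centroids)

G_inv_sketch(zeros(2), centroids_latent, M, 0.3, 0.01)      # ≈ 5.01 I: every centroid contributes fully
G_inv_sketch([10.0, 10.0], centroids_latent, M, 0.3, 0.01)  # ≈ 0.01 I: far from every centroid
```

Far from all centroids the kernel weights vanish and only λI remains, which is why `λ` keeps the inverse metric well conditioned, while `T` sets how far each centroid's influence extends.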