# β-Variational Autoencoder

Variational Autoencoders, first introduced by Kingma and Welling in 2014, are a type of generative model that learns to encode high-dimensional data into a low-dimensional latent space. The main idea behind VAEs is to learn a probabilistic mapping (via variational inference) from the input data to the latent space, which allows for the generation of new data points by sampling from the latent space.

The β-VAE, introduced by Higgins et al. in 2017, is a variant of the original VAE that adds a hyperparameter `β` controlling the relative weight of the reconstruction loss and the KL divergence term in the loss function. By adjusting `β`, the user can control the trade-off between reconstruction quality and disentanglement of the latent space.

In terms of implementation, the `VAE` struct in `AutoEncoderToolkit.jl` is a simple feedforward network composed of variational encoder and decoder parts. This means that the encoder has a log-posterior function and a KL divergence function associated with it, while the decoder has a log-likelihood function associated with it.

## References

### VAE

Kingma, D. P. & Welling, M. Auto-Encoding Variational Bayes. Preprint at http://arxiv.org/abs/1312.6114 (2014).

### β-VAE

Higgins, I. et al. β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. (2017).

## `VAE` struct

`AutoEncoderToolkit.VAEs.VAE` (Type)

`struct VAE{E<:AbstractVariationalEncoder, D<:AbstractVariationalDecoder}`

Variational autoencoder (VAE) model defined for `Flux.jl`.

**Fields**

- `encoder::E`: Neural network that encodes the input into the latent space. `E` is a subtype of `AbstractVariationalEncoder`.
- `decoder::D`: Neural network that decodes the latent representation back to the original input space. `D` is a subtype of `AbstractVariationalDecoder`.

A VAE consists of an encoder and decoder network with a bottleneck latent space in between. The encoder compresses the input into a low-dimensional probabilistic representation q(z|x). The decoder tries to reconstruct the original input from a sampled point in the latent space p(x|z).

## Forward pass

`AutoEncoderToolkit.VAEs.VAE` (Method)

`(vae::VAE)(x::AbstractArray; latent::Bool=false)`

Perform the forward pass of a Variational Autoencoder (VAE).

This function takes as input a VAE and a vector or matrix of input data `x`. It first runs the input through the encoder to obtain the mean and log standard deviation of the latent variables. It then uses the reparameterization trick to sample from the latent distribution. Finally, it runs the latent sample through the decoder to obtain the output.
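The sampling step can be illustrated with a minimal sketch in plain Julia. It assumes the encoder returns a mean `µ` and a log standard deviation `logσ` for a diagonal Gaussian, as described above. The function name `reparameterize` is hypothetical and not part of the toolkit's API.

```julia
using Random

# Hypothetical helper: draw z ~ N(µ, σ²) as a deterministic function of
# (µ, logσ) plus independent standard-normal noise ε, so that gradients
# can flow through µ and logσ.
function reparameterize(rng::AbstractRNG, µ::AbstractArray, logσ::AbstractArray)
    ε = randn(rng, eltype(µ), size(µ)...)   # ε ~ N(0, I)
    return µ .+ exp.(logσ) .* ε             # z = µ + σ ⊙ ε
end

rng = Random.MersenneTwister(42)
µ = zeros(Float32, 20, 5)      # 20 latent dimensions, batch of 5 samples
logσ = fill(-1.0f0, 20, 5)     # σ = exp(-1) for every latent dimension
z = reparameterize(rng, µ, logσ)
```

Because the randomness enters only through `ε`, the sample `z` stays differentiable with respect to the encoder outputs, which is what allows the encoder to be trained by backpropagation.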

**Arguments**

- `vae::VAE`: The VAE used to encode the input data and decode the latent space.
- `x::AbstractArray`: The input data. If array, the last dimension contains each of the samples in a batch.

**Optional Keyword Arguments**

- `latent::Bool`: Whether to return the latent variables along with the decoder output. If `true`, the function returns a tuple containing the encoder outputs, the latent sample, and the decoder outputs. If `false`, the function only returns the decoder outputs. Defaults to `false`.

**Returns**

- If `latent` is `true`, returns a tuple containing:
  - `encoder`: The outputs of the encoder.
  - `z`: The latent sample.
  - `decoder`: The outputs of the decoder.
- If `latent` is `false`, returns the outputs of the decoder.

**Example**

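```
using Flux
using AutoEncoderToolkit
# Define a VAE. `JointGaussianLogEncoder` and `SimpleGaussianDecoder` are
# assumed wrapper types; check the toolkit's encoder/decoder docs.
encoder = AutoEncoderToolkit.JointGaussianLogEncoder(
    Flux.Chain(Flux.Dense(784 => 400, Flux.relu)), # shared hidden layers
    Flux.Dense(400 => 20),                         # latent mean
    Flux.Dense(400 => 20),                         # latent log standard deviation
)
decoder = AutoEncoderToolkit.SimpleGaussianDecoder(
    Flux.Chain(Flux.Dense(20 => 400, Flux.relu), Flux.Dense(400 => 784)),
)
vae = AutoEncoderToolkit.VAEs.VAE(encoder, decoder)
# Define input data
x = rand(Float32, 784)
# Perform the forward pass
outputs = vae(x; latent=true)
```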

## Loss function

`AutoEncoderToolkit.VAEs.loss` (Function)

```
loss(
    vae::VAE,
    x::AbstractArray;
    β::Number=1.0f0,
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    kl_divergence::Function=encoder_kl,
    reg_function::Union{Function,Nothing}=nothing,
    reg_kwargs::NamedTuple=NamedTuple(),
    reg_strength::Number=1.0f0
)
```

Computes the loss for the variational autoencoder (VAE).

The loss function combines the reconstruction loss with the Kullback-Leibler (KL) divergence, and possibly a regularization term, defined as:

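$$
\text{loss} = -\left\langle \log \pi(x \mid z) \right\rangle + \beta \, D_{KL}\left[ q_\phi(z \mid x) \,\|\, \pi(z) \right] + \text{reg\_strength} \times \text{reg\_term}
$$

Where:

- $\pi(x \mid z)$ is a probabilistic decoder: $\pi(x \mid z) = \mathcal{N}(f(z), \sigma^2 I)$.
- $f(z)$ is the function defining the mean of the decoder $\pi(x \mid z)$.
- $q_\phi(z \mid x)$ is the approximated encoder: $q_\phi(z \mid x) = \mathcal{N}(g(x), h(x))$.
- $g(x)$ and $h(x)$ define the mean and covariance of the encoder, respectively.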
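For reference, the KL term has a well-known closed form in one common special case. This is an assumption made for illustration: the encoder is a diagonal Gaussian with mean $\mu_j$ and standard deviation $\sigma_j$ in each of the $d$ latent dimensions, and the prior $\pi(z)$ is a standard normal. The value actually used depends on the `kl_divergence` function passed to `loss`.

$$
D_{KL}\left[ \mathcal{N}\left(\mu, \operatorname{diag}(\sigma^2)\right) \,\|\, \mathcal{N}(0, I) \right] = \frac{1}{2} \sum_{j=1}^{d} \left( \mu_j^2 + \sigma_j^2 - 1 - 2 \log \sigma_j \right)
$$

Each summand is non-negative and vanishes only when $\mu_j = 0$ and $\sigma_j = 1$, so the term penalizes any departure of the encoder distribution from the prior.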

**Arguments**

- `vae::VAE`: A VAE model with encoder and decoder networks.
- `x::AbstractArray`: Input data. The last dimension is taken as having each of the samples in a batch.

**Optional Keyword Arguments**

- `β::Number=1.0f0`: Weighting factor for the KL-divergence term, used for annealing.
- `reconstruction_loglikelihood::Function=decoder_loglikelihood`: A function that computes the reconstruction log likelihood.
- `kl_divergence::Function=encoder_kl`: A function that computes the Kullback-Leibler divergence between the encoder output and a standard normal.
- `reg_function::Union{Function, Nothing}=nothing`: A function that computes the regularization term based on the VAE outputs. Should return a `Float32`. This function must take as input the VAE outputs and the keyword arguments provided in `reg_kwargs`.
- `reg_kwargs::NamedTuple=NamedTuple()`: Keyword arguments to pass to the regularization function.
- `reg_strength::Number=1.0f0`: The strength of the regularization term.

**Returns**

- `T`: The computed average loss value for the input `x` and its reconstructed counterparts, including possible regularization terms.

**Note**

- Ensure that the input data `x` matches the expected input dimensionality for the encoder in the VAE.

```
loss(
    vae::VAE,
    x_in::AbstractArray,
    x_out::AbstractArray;
    β::Number=1.0f0,
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    kl_divergence::Function=encoder_kl,
    reg_function::Union{Function,Nothing}=nothing,
    reg_kwargs::NamedTuple=NamedTuple(),
    reg_strength::Number=1.0f0
)
```

Computes the loss for the variational autoencoder (VAE).

The loss function combines the reconstruction loss with the Kullback-Leibler (KL) divergence and possibly a regularization term, defined as:

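$$
\text{loss} = -\left\langle \log \pi(x_{\text{out}} \mid z) \right\rangle + \beta \, D_{KL}\left[ q_\phi(z \mid x_{\text{in}}) \,\|\, \pi(z) \right] + \text{reg\_strength} \times \text{reg\_term}
$$

Where:

- $\pi(x_{\text{out}} \mid z)$ is a probabilistic decoder: $\pi(x_{\text{out}} \mid z) = \mathcal{N}(f(z), \sigma^2 I)$.
- $f(z)$ is the function defining the mean of the decoder $\pi(x_{\text{out}} \mid z)$.
- $q_\phi(z \mid x_{\text{in}})$ is the approximated encoder: $q_\phi(z \mid x_{\text{in}}) = \mathcal{N}(g(x_{\text{in}}), h(x_{\text{in}}))$.
- $g(x_{\text{in}})$ and $h(x_{\text{in}})$ define the mean and covariance of the encoder, respectively.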

**Arguments**

- `vae::VAE`: A VAE model with encoder and decoder networks.
- `x_in::AbstractArray`: Input data to the VAE encoder. The last dimension is taken as having each of the samples in a batch.
- `x_out::AbstractArray`: Target data to compute the reconstruction error. The last dimension is taken as having each of the samples in a batch.

**Optional Keyword Arguments**

- `β::Number=1.0f0`: Weighting factor for the KL-divergence term, used for annealing.
- `reconstruction_loglikelihood::Function=decoder_loglikelihood`: A function that computes the reconstruction log likelihood.
- `kl_divergence::Function=encoder_kl`: A function that computes the Kullback-Leibler divergence.
- `reg_function::Union{Function, Nothing}=nothing`: A function that computes the regularization term based on the VAE outputs. Should return a `Float32`. This function must take as input the VAE outputs and the keyword arguments provided in `reg_kwargs`.
- `reg_kwargs::NamedTuple=NamedTuple()`: Keyword arguments to pass to the regularization function.
- `reg_strength::Number=1.0f0`: The strength of the regularization term.

**Returns**

- `T`: The computed average loss value for the input `x_in` and its reconstructed counterparts `x_out`, including possible regularization terms.

**Note**

- Ensure that `x_in` matches the expected input dimensionality of the encoder in the VAE, and that `x_out` matches the output dimensionality of the decoder.

The `loss` function takes the optional keyword argument `β`, which turns a vanilla VAE into a β-VAE when set to any value other than the default `1.0`.
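The effect of `β` on the objective can be sketched in plain Julia, independent of the toolkit. The functions `gaussian_kl`, `gaussian_loglik`, and `beta_loss` below are hypothetical names, not the toolkit's `loss`. The sketch assumes a diagonal Gaussian encoder that outputs a mean and a log standard deviation, a standard normal prior, and a Gaussian decoder with fixed unit variance.

```julia
using Statistics

const LOG2PI = log(2f0 * Float32(π))

# Closed-form KL divergence between N(µ, diag(σ²)) and N(0, I),
# one value per sample (samples are stored along the last dimension).
function gaussian_kl(µ, logσ)
    terms = @. 0.5f0 * (µ^2 + exp(2 * logσ) - 1 - 2 * logσ)
    return vec(sum(terms; dims=1))
end

# Gaussian log-likelihood of x given the decoder mean x_rec, with σ = 1,
# one value per sample.
function gaussian_loglik(x, x_rec)
    terms = @. -0.5f0 * ((x - x_rec)^2 + LOG2PI)
    return vec(sum(terms; dims=1))
end

# Hypothetical β-weighted objective: -⟨log π(x|z)⟩ + β × ⟨KL⟩
beta_loss(x, x_rec, µ, logσ; β=1.0f0) =
    -mean(gaussian_loglik(x, x_rec)) + β * mean(gaussian_kl(µ, logσ))

x = Float32[1 2; 3 4]          # 2 features, 2 samples
x_rec = Float32[1 2; 3 5]      # reconstruction, off by 1 in one entry
µ = Float32[0 1; 0 0]          # encoder means
logσ = zeros(Float32, 2, 2)    # σ = 1 everywhere

vanilla = beta_loss(x, x_rec, µ, logσ)             # β = 1: standard VAE
weighted = beta_loss(x, x_rec, µ, logσ; β=4.0f0)   # β > 1: β-VAE
```

Raising `β` from 1 to 4 leaves the reconstruction term unchanged and adds three times the mean KL divergence to the objective, which is how a larger `β` pushes the encoder distribution toward the prior.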

## Training

`AutoEncoderToolkit.VAEs.train!` (Function)

`train!(vae, x, opt; loss_function, loss_kwargs, verbose, loss_return)`

Customized training function to update parameters of a variational autoencoder given a specified loss function.

**Arguments**

- `vae::VAE`: A struct containing the elements of a variational autoencoder.
- `x::AbstractArray`: Data on which to evaluate the loss function. The last dimension is taken as having each of the samples in a batch.
- `opt::NamedTuple`: State of the optimizer for updating parameters. Typically initialized using `Flux.Train.setup`.

**Optional Keyword Arguments**

- `loss_function::Function=loss`: The loss function used for training. It should accept the VAE model, data `x`, and keyword arguments in that order.
- `loss_kwargs::NamedTuple=NamedTuple()`: Arguments for the loss function. These might include parameters like `σ` or `β`, depending on the specific loss function in use.
- `verbose::Bool=false`: If true, the loss value will be printed during training.
- `loss_return::Bool=false`: If true, the loss value will be returned after training.

**Description**

Trains the VAE by:

- Computing the gradient of the loss w.r.t the VAE parameters.
- Updating the VAE parameters using the optimizer.
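The two steps above can be sketched with Flux's explicit-gradient API. This is not the implementation of `train!`, only an illustration of the same pattern. A small `Flux.Chain` and a mean-squared-error loss (`toy_loss`, a hypothetical name) stand in for the VAE and its loss so the example stays self-contained.

```julia
using Flux

# Hypothetical stand-in for a VAE: any Flux model works for this sketch
model = Flux.Chain(Flux.Dense(4 => 2, tanh), Flux.Dense(2 => 4))
# Hypothetical stand-in for the VAE loss (plain mean squared error)
toy_loss(m, x) = Flux.mse(m(x), x)

x = rand(Float32, 4, 16)    # 16 samples along the last dimension
opt_state = Flux.setup(Flux.Adam(1f-2), model)

loss_before = toy_loss(model, x)
# Step 1: compute the loss and its gradient w.r.t. the model parameters
loss_val, grads = Flux.withgradient(m -> toy_loss(m, x), model)
# Step 2: update the parameters in place using the optimizer state
Flux.update!(opt_state, model, grads[1])
loss_after = toy_loss(model, x)
```

`Flux.withgradient` returns the loss value alongside the gradient, which is convenient when the loss should also be printed or returned, as the `verbose` and `loss_return` options suggest.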

```
train!(
    vae, x_in, x_out, opt;
    loss_function, loss_kwargs, verbose, loss_return
)
```

Customized training function to update parameters of a variational autoencoder given a loss function.

**Arguments**

- `vae::VAE`: A struct containing the elements of a variational autoencoder.
- `x_in::AbstractArray`: Input data for the loss function, either an individual sample or a batch. For a batch, the last dimension is taken as having each of the samples.
- `x_out::AbstractArray`: Target output data for the loss function, corresponding to `x_in`. For a batch, the last dimension is taken as having each of the samples.
- `opt::NamedTuple`: State of the optimizer for updating parameters. Typically initialized using `Flux.Train.setup`.

**Optional Keyword Arguments**

- `loss_function::Function=loss`: The loss function used for training. It should accept the VAE model, data `x_in`, `x_out`, and keyword arguments in that order.
- `loss_kwargs::NamedTuple=NamedTuple()`: Arguments for the loss function. These might include parameters like `σ` or `β`, depending on the specific loss function in use.
- `verbose::Bool=false`: Whether to print the loss value after each training step.
- `loss_return::Bool=false`: Whether to return the loss value after each training step.

**Description**

Trains the VAE by:

- Computing the gradient of the loss w.r.t the VAE parameters.
- Updating the VAE parameters using the optimizer.

**Examples**

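```
using Flux
using AutoEncoderToolkit
# `vae` is a previously defined VAE; `dataloader` yields (x_in, x_out) batches
opt = Flux.setup(Flux.Adam(1e-3), vae)
for (x_in, x_out) in dataloader
    AutoEncoderToolkit.VAEs.train!(vae, x_in, x_out, opt)
end
```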