Bayesian Neural Networks

Bayesian neural networks differ from plain neural networks in that their weights and biases are treated probabilistically. Instead of computing a point estimate of the weights and biases, we assign them a prior distribution and infer a posterior distribution from the data. The inference relies on Bayes' formula

\[p(w|\mathcal{D}) = \frac{p(\mathcal{D}|w)p(w)}{p(\mathcal{D})}\]

Here $\mathcal{D}$ is the data, e.g., the input-output pairs $\{(x_i, y_i)\}$ of the neural network, $w$ denotes the weights and biases of the neural network, $p(\mathcal{D}|w)$ is the likelihood, and $p(w)$ is the prior distribution.

If we have a full posterior distribution $p(w|\mathcal{D})$, we can conduct predictive modeling using

\[p(y|x, \mathcal{D}) = \int p(y|x, w) p(w|\mathcal{D})d w\]
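
The integral above rarely has a closed form. A common approximation, written here as a sketch in which $S$ is an assumed number of posterior samples, replaces it with a sample average

\[p(y|x, \mathcal{D}) \approx \frac{1}{S}\sum_{s=1}^S p(y|x, w_s), \qquad w_s \sim p(w|\mathcal{D})\]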

However, computing $p(w|\mathcal{D})$ is usually intractable because it requires the normalizing constant $p(\mathcal{D}) = \int p(\mathcal{D}|w)p(w) dw$, which is an integral over all possible $w$. Traditionally, Markov chain Monte Carlo (MCMC) has been used to sample from $p(w|\mathcal{D})$ without evaluating $p(\mathcal{D})$. However, MCMC can converge very slowly and may require a large number of samples, which can be quite expensive.
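
To illustrate why MCMC does not need $p(\mathcal{D})$, here is a minimal random-walk Metropolis sketch in plain Julia. The toy model (a scalar weight with prior $\mathcal{N}(0, 1)$ and likelihood $y_i \sim \mathcal{N}(w, 1)$), the function names, and the step size are assumptions made for illustration; they are not part of the example developed later.

```julia
using Random, Statistics

# Unnormalized log posterior log p(D|w) + log p(w) for a toy model (assumed for
# illustration): prior w ~ N(0, 1), likelihood y_i ~ N(w, 1).
log_unnorm_post(w, ys) = -0.5 * w^2 - 0.5 * sum((y - w)^2 for y in ys)

# Random-walk Metropolis: only ratios of the posterior are needed, so p(D) cancels.
function metropolis(ys; n_steps=50_000, step=1.0, rng=MersenneTwister(0))
    w = 0.0
    lp = log_unnorm_post(w, ys)
    samples = Vector{Float64}(undef, n_steps)
    for i in 1:n_steps
        w_new = w + step * randn(rng)
        lp_new = log_unnorm_post(w_new, ys)
        if log(rand(rng)) < lp_new - lp      # accept with probability min(1, ratio)
            w, lp = w_new, lp_new
        end
        samples[i] = w
    end
    return samples
end

ys = [0.8, 1.1, 1.3]
samples = metropolis(ys)
post_mean = mean(samples[5_001:end])   # discard the first 5000 steps as burn-in
```

For this conjugate toy model the exact posterior is Gaussian with mean $\sum_i y_i/(n+1)$ and variance $1/(n+1)$, which gives a reference value to compare the chain against. For a neural network, each step would require a full pass over the data, which is where the cost comes from.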

Variational Inference

In Bayesian neural networks, the idea is to approximate $p(w|\mathcal{D})$ with a parametrized family of distributions $p(w|\theta)$, where $\theta$ denotes the parameters of the family. This method is called variational inference. To find the optimal $\theta$, we minimize the KL divergence between the approximate posterior and the true posterior

\[\text{KL}(p(w|\theta)||p(w|\mathcal{D})) = \text{KL}(p(w|\theta)||p(w)) - \mathbb{E}_{p(w|\theta)}\log p(\mathcal{D}|w) + \log p(\mathcal{D})\]
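
This decomposition follows from the definition of the KL divergence after substituting Bayes' formula for $p(w|\mathcal{D})$. The intermediate steps are

\[\begin{aligned}
\text{KL}(p(w|\theta)||p(w|\mathcal{D})) &= \mathbb{E}_{p(w|\theta)}\left[\log p(w|\theta) - \log p(w|\mathcal{D})\right]\\
&= \mathbb{E}_{p(w|\theta)}\left[\log p(w|\theta) - \log p(w) - \log p(\mathcal{D}|w) + \log p(\mathcal{D})\right]\\
&= \text{KL}(p(w|\theta)||p(w)) - \mathbb{E}_{p(w|\theta)}\log p(\mathcal{D}|w) + \log p(\mathcal{D})
\end{aligned}\]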

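Evaluating $\log p(\mathcal{D})$ is intractable, but this term does not depend on $\theta$. Minimizing the KL divergence is therefore equivalent to minimizing the remaining terms, which are known as the variational free energy (the negative of the evidence lower bound)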

\[F(\mathcal{D}, \theta) = \text{KL}(p(w|\theta)||p(w)) - \mathbb{E}_{p(w|\theta)}\log p(\mathcal{D}|w)\]

In practice, the variational free energy is approximated with $N$ Monte Carlo samples $w_i \sim p(w|\theta)$

\[F(\mathcal{D}, \theta) \approx \frac{1}{N}\sum_{i=1}^N \left[\log p(w_i|\theta) - \log p(w_i) - \log p(\mathcal{D}|w_i)\right]\]
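
The sample average can be checked on a problem where the free energy has a closed form. The following is a minimal sketch in plain Julia under an assumed toy model: a scalar weight with prior $\mathcal{N}(0, 1)$, likelihood $y_i \sim \mathcal{N}(w, 1)$, and a Gaussian variational family. The names `free_energy_mc` and `free_energy_exact` are hypothetical helpers, not ADCME functions.

```julia
using Random

# Log density of a univariate Gaussian with mean m and standard deviation s.
lognormal(x, m, s) = -0.5 * log(2π) - log(s) - 0.5 * ((x - m) / s)^2

# Monte Carlo estimate of F(D, θ) for a toy model (an assumption for illustration):
# prior w ~ N(0, 1), likelihood y_i ~ N(w, 1), variational family p(w|θ) = N(μ, σ²).
function free_energy_mc(ys, μ, σ; N=100_000, rng=MersenneTwister(0))
    total = 0.0
    for _ in 1:N
        w = μ + σ * randn(rng)                          # w_i ~ p(w|θ)
        logq = lognormal(w, μ, σ)                       # log p(w_i|θ)
        logp = lognormal(w, 0.0, 1.0)                   # log p(w_i)
        loglik = sum(lognormal(y, w, 1.0) for y in ys)  # log p(D|w_i)
        total += logq - logp - loglik
    end
    return total / N
end

# Closed form for the same toy model, used only as a sanity check.
function free_energy_exact(ys, μ, σ)
    kl = 0.5 * (σ^2 + μ^2 - 1 - 2 * log(σ))
    nll = sum(0.5 * log(2π) + 0.5 * ((y - μ)^2 + σ^2) for y in ys)
    return kl + nll
end

ys = [0.8, 1.1, 1.3]
F_mc = free_energy_mc(ys, 0.7, 0.5)
F_exact = free_energy_exact(ys, 0.7, 0.5)
```

With a large $N$ the two values should agree closely. In training, a single sample ($N = 1$) per optimizer step is a common choice, and the resulting noise in the loss is handled by a stochastic optimizer.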

Parametric Family

In Bayesian neural networks, the parametric family is usually chosen to be Gaussian. So that gradients with respect to $\theta$ can be computed by automatic differentiation, we usually parametrize $w$ as

\[w = \mu + \sigma \otimes z\qquad z \sim \mathcal{N}(0, I) \tag{1}\]

Here $\theta = (\mu, \sigma)$, and $\otimes$ denotes elementwise multiplication. The parameters of the prior distribution $p(w)$ are given as hyperparameters. For example, we can use a mixture of two Gaussians as the prior

\[p(w) = \pi_1 \mathcal{N}(0, \sigma_1^2) + \pi_2 \mathcal{N}(0, \sigma_2^2), \qquad \pi_1 + \pi_2 = 1\]
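
When the log density of this mixture is evaluated for weights far from zero, the exponentials of the component log densities can underflow. One possible remedy, sketched here in plain Julia, is the log-sum-exp trick. The helper names are hypothetical, and $\sigma_1$, $\sigma_2$ are treated as standard deviations, which is an assumption about the notation.

```julia
# Log density of a zero-mean Gaussian with standard deviation s.
lognormal0(w, s) = -0.5 * log(2π) - log(s) - 0.5 * (w / s)^2

# log(π1 * N(w; 0, σ1²) + (1 - π1) * N(w; 0, σ2²)) via the log-sum-exp trick.
function log_mixture_prior(w, π1, σ1, σ2)
    a = log(π1) + lognormal0(w, σ1)
    b = log(1 - π1) + lognormal0(w, σ2)
    m = max(a, b)                      # factor out the larger term
    return m + log(exp(a - m) + exp(b - m))
end

lp = log_mixture_prior(0.3, 0.5, 1.5, 0.1)        # typical weight
lp_far = log_mixture_prior(100.0, 0.5, 1.5, 0.1)  # stays finite far from zero
```

For moderate weights the result matches the direct formula; the difference only shows up in the tails, where the direct formula would return `-Inf`.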

The advantage of Equation 1 is that we can easily obtain the log probability $\log p(w|\theta)$.
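
As a minimal sketch of this point in plain Julia (the function name `sample_with_logprob` is a hypothetical helper, not an ADCME function), the sample and its log probability can be computed together. Since $(w - \mu)/\sigma = z$, the log density only requires $z$ and $\sigma$.

```julia
using Random

# Draw w = μ + σ ⊗ z with z ~ N(0, I) and return w together with log p(w|θ).
# Because (w - μ) ./ σ equals z, the log density only needs z and σ.
function sample_with_logprob(μ, σ; rng=Random.default_rng())
    z = randn(rng, size(μ)...)
    w = μ .+ σ .* z
    logq = sum(@. -0.5 * log(2π) - log(σ) - 0.5 * z^2)
    return w, logq
end

μ = [0.0, 1.0, -2.0]
σ = [1.0, 0.5, 2.0]
w, logq = sample_with_logprob(μ, σ; rng=MersenneTwister(1))
```

Because the randomness enters only through $z$, the sample $w$ is a differentiable function of $\mu$ and $\sigma$.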

Because $\sigma$ must be positive, we instead optimize an unconstrained parameter $\rho$ and map it to $\sigma$ with the softplus function

\[\sigma = \log (1+\exp(\rho))\]
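
Evaluating the formula literally overflows for large $\rho$. The sketch below, in plain Julia with hypothetical helper names, shows one numerically stable way to write the softplus function together with its inverse, which can be used to pick an initial $\rho$ for a desired initial $\sigma$.

```julia
# Numerically stable softplus: log(1 + exp(ρ)) without overflow for large ρ.
softplus_stable(ρ) = max(ρ, 0) + log1p(exp(-abs(ρ)))

# Inverse map, handy for choosing ρ so that σ starts at a desired value.
inv_softplus(σ) = log(expm1(σ))

σ0 = softplus_stable(0.0)          # equals log(2)
ρ_for_small = inv_softplus(0.05)   # ρ that gives σ = 0.05
```

In particular, initializing $\rho$ to zero corresponds to an initial standard deviation of $\log 2$.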


Now let us consider a concrete example. The following example is adapted from this post.

Generating Training Data

We first generate some 1D training data

using ADCME
using PyPlot
using ProgressMeter
using Statistics

# Noisy samples of 10 sin(2πx); σ is the standard deviation of the additive noise
function f(x, σ)
    ε = randn(size(x)...) * σ
    return 10 * sin.(2π*x) + ε
end

batch_size = 32
noise = 1.0

# Inputs form a batch_size × 1 matrix on [-0.5, 0.5]
X = reshape(LinRange(-0.5, 0.5, batch_size)|>Array, :, 1)
y = f(X, noise)
y_true = f(X, 0.0)

scatter(X, y, marker="+", label="Training Data")
plot(X, y_true, label="Truth")
legend()

Construct Bayesian Neural Network

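mutable struct VariationalLayer
    units
    activation
    prior_σ1
    prior_σ2
    prior_π1
    prior_π2
    Wμ
    bμ
    Wρ
    bρ
    init_σ
end

# prior_π1 is the weight of the first mixture component (0.5 is an assumed default)
function VariationalLayer(units; activation=relu, prior_σ1=1.5, prior_σ2=0.1,
        prior_π1=0.5)
    # Standard deviation of the mixture prior, used to initialize the means
    init_σ = sqrt(
        prior_π1 * prior_σ1^2 + (1-prior_π1)*prior_σ2^2
    )
    VariationalLayer(units, activation, prior_σ1, prior_σ2, prior_π1, 1-prior_π1,
                        missing, missing, missing, missing, init_σ)
end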

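# Single-sample estimate of KL(p(w|θ)||p(w)): log p(w|θ) - log p(w), summed over entries
function kl_loss(vl, w, μ, σ)
    dist = ADCME.Normal(μ,σ)
    return sum(logpdf(dist, w)-logprior(vl, w))
end

# Log density of the two-component Gaussian mixture prior
function logprior(vl, w)
    dist1 = ADCME.Normal(constant(0.0), vl.prior_σ1)
    dist2 = ADCME.Normal(constant(0.0), vl.prior_σ2)
    log(vl.prior_π1*exp(logpdf(dist1, w)) + vl.prior_π2*exp(logpdf(dist2, w)))
end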

function (vl::VariationalLayer)(x)
    x = constant(x)
    # Variational parameters are created on the first call, once the input width is known
    if ismissing(vl.bμ)
        vl.Wμ = get_variable(vl.init_σ*randn(size(x,2), vl.units))
        vl.Wρ = get_variable(zeros(size(x,2), vl.units))
        vl.bμ = get_variable(vl.init_σ*randn(1, vl.units))
        vl.bρ = get_variable(zeros(1, vl.units))
    end
    # Reparametrization: w = μ + σ ⊗ z with σ = softplus(ρ)
    Wσ = softplus(vl.Wρ)
    W = vl.Wμ + Wσ.*normal(size(vl.Wμ)...)
    bσ = softplus(vl.bρ)
    b = vl.bμ + bσ.*normal(size(vl.bμ)...)
    loss = kl_loss(vl, W, vl.Wμ, Wσ) + kl_loss(vl, b, vl.bμ, bσ)
    out = vl.activation(x * W + b)
    return out, loss
end

# Negative Gaussian log likelihood of the observations given the prediction
function neg_log_likelihood(y_obs, y_pred, σ)
    y_obs = constant(y_obs)
    dist = ADCME.Normal(y_pred, σ)
    sum(-logpdf(dist, y_obs))
end

ipt = placeholder(X)
x, loss1 = VariationalLayer(20, activation=relu)(ipt)
x, loss2 = VariationalLayer(20, activation=relu)(x)
x, loss3 = VariationalLayer(1, activation=x->x)(x)

loss_lf = neg_log_likelihood(y, x, noise)
loss = loss1 + loss2 + loss3 + loss_lf


We use the Adam optimizer to minimize the loss function. Quasi-Newton methods, which are typically used for deterministic optimization, are not appropriate here because the loss function is stochastic: every evaluation draws new samples of the weights and biases.

Another caveat is that the loss function may have many local minima, so we may need to run the optimizer several times to obtain a good one.

opt = AdamOptimizer(0.08).minimize(loss)
sess = Session(); init(sess)
@showprogress for i = 1:5000
    run(sess, opt)
end

# Every forward pass draws new weights, so repeated runs sample the predictive distribution
X_test = reshape(LinRange(-1.5,1.5,32)|>Array, :, 1)
y_pred_list = []
@showprogress for i = 1:10000
    y_pred = run(sess, x, ipt=>X_test)
    push!(y_pred_list, y_pred)
end

y_preds = hcat(y_pred_list...)

y_mean = mean(y_preds, dims=2)[:]
y_std = std(y_preds, dims=2)[:]

# Predictive mean with a band of two standard deviations on each side
plot(X_test, y_mean)
scatter(X[:], y[:], marker="+")
fill_between(X_test[:], y_mean-2y_std, y_mean+2y_std, alpha=0.5)