Variational Inference in Bayesian Neural Networks
Last Updated: 30 Jun, 2025
Bayesian Neural Networks (BNNs) extend traditional neural networks by treating weights as probability distributions rather than fixed values. This approach quantifies uncertainty and helps prevent overfitting. Variational Inference (VI) provides a scalable method to approximate the intractable posterior distribution of these weights.
- Traditional Neural Networks: Each weight has a single fixed value (point estimate).
- Bayesian Neural Networks: Each weight is treated as a probability distribution, representing uncertainty about its true value.
Why Use Variational Inference?
- Challenge: Computing the exact "posterior" distribution over all weights (what we believe about weights after seeing the data) is mathematically intractable for neural networks.
- Solution: Variational Inference (VI) approximates this complex posterior with a simpler, easy-to-handle distribution, usually a Gaussian.
How Does Variational Inference Work in BNNs?
1. Choose a Simple Distribution: Pick a family of distributions (e.g., a diagonal Gaussian) to approximate the true posterior over weights. Each weight now has a mean and a standard deviation, not just a single value.
2. Optimization Objective: Instead of maximizing the likelihood (as in standard neural networks), VI maximizes a new objective that balances two things:
- Fit to Data: how well the network explains the observed data (as in usual training).
- Closeness to Prior: how close the chosen distribution stays to a prior belief about the weights (regularization).
3. Gradient-Based Training: VI uses gradient descent, just like regular neural networks, but it updates both the means and the standard deviations of the weight distributions.
4. Prediction: At test time, predictions are made by averaging over several weight samples drawn from the learned distribution, which captures model uncertainty (see the sketch after this list).
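The sampling and averaging steps above can be illustrated with a minimal PyTorch-style sketch. The BayesianLinear class below is a simplified, hypothetical layer (its name and parameter names are illustrative, not from a specific library): each weight has a learnable mean and log-standard-deviation, a forward pass draws one weight sample via the reparameterization trick, and prediction averages over several samples.

import torch
import torch.nn as nn

class BayesianLinear(nn.Module):
    """A linear layer whose weights are diagonal Gaussians (illustrative sketch)."""
    def __init__(self, in_features, out_features):
        super().__init__()
        # Variational parameters: a mean and a log-std per weight (and per bias).
        self.w_mu = nn.Parameter(torch.zeros(out_features, in_features))
        self.w_logstd = nn.Parameter(torch.full((out_features, in_features), -3.0))
        self.b_mu = nn.Parameter(torch.zeros(out_features))
        self.b_logstd = nn.Parameter(torch.full((out_features,), -3.0))

    def forward(self, x):
        # Reparameterization trick: sample = mean + std * epsilon, with epsilon ~ N(0, 1).
        w = self.w_mu + self.w_logstd.exp() * torch.randn_like(self.w_mu)
        b = self.b_mu + self.b_logstd.exp() * torch.randn_like(self.b_mu)
        return x @ w.t() + b

# Prediction: average over several weight samples to capture model uncertainty.
layer = BayesianLinear(10, 1)
x = torch.randn(5, 10)
with torch.no_grad():
    preds = torch.stack([layer(x) for _ in range(20)])   # 20 Monte Carlo forward passes
mean_pred, uncertainty = preds.mean(0), preds.std(0)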
Key Points
- Posterior Consistency: Under certain conditions, the variational approximation concentrates around the true parameters as the amount of data grows.
- Trade-off: VI must balance fitting the data and staying close to the prior, especially important in large (overparameterized) networks.
- Choice of Approximation: Simpler distributions (like independent Gaussians) are easier to train but may not capture all uncertainty; more complex ones (like normalizing flows) can be more accurate but harder to optimize.
Practical Implementation of Variational Inference in BNNs
Main Formula (ELBO)
\text{ELBO} = \mathbb{E}_{q(\theta)} \left[ \log p(\text{data} \mid \theta) \right] - \mathrm{KL}\left( q(\theta) \,\|\, p(\theta) \right)
where
- q(\theta) : The variational (approximate) posterior distribution over the network weights \theta (what we’re learning).
- p(\text{data}|\theta) : The likelihood: how likely the observed data is given the weights \theta (model fit).
- p(\theta) : The prior distribution over weights (our initial belief, e.g., a standard normal distribution).
- \mathbb{E}_{q(\theta)}[\log p(\text{data}|\theta)]: The expected log-likelihood, which encourages the model to fit the data.
- \mathrm{KL}\left(q(\theta)\,\|\,p(\theta)\right): The Kullback-Leibler divergence, which regularizes q(\theta) to stay close to p(\theta).
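As a concrete example, for the common choice of a diagonal Gaussian q(\theta) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) with a standard normal prior p(\theta) = \mathcal{N}(0, I), the KL term has the closed form used in step 4 of the training procedure below:

\mathrm{KL}\left(q(\theta)\,\|\,p(\theta)\right) = \frac{1}{2} \sum_{i} \left( \mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1 \right)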
Practical Training Steps
1. Choose Priors and Variational Family: Set p(\theta) (e.g., \mathcal{N}(0, 1) for each weight).
Choose q(\theta) (e.g., a Gaussian with learnable mean and variance per weight).
2. Sample Weights: For each mini-batch, sample weights \theta from q(\theta).
3. Compute Expected Log-Likelihood: \mathbb{E}_{q(\theta)} \left[ \log p(\text{data} \mid \theta) \right] \approx \frac{1}{S} \sum_{s=1}^{S} \log p(\text{data} \mid \theta^{(s)})
- S = number of samples,
- \theta^{(s)} = s-th sample from q(\theta).
4. Compute KL Divergence: \mathrm{KL}\left(q(\theta)\,\|\,p(\theta)\right)
For Gaussians, this has a closed-form expression.
5. Optimize ELBO: Use stochastic gradient descent (SGD/Adam) to maximize the ELBO (or equivalently, minimize the negative ELBO, -\text{ELBO}).
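Putting these steps together, here is a minimal training-loop sketch in the Bayes-by-Backprop style, reusing the hypothetical BayesianLinear layer from the earlier sketch. The kl_to_standard_normal helper is illustrative and implements the closed-form KL against a \mathcal{N}(0, 1) prior; a unit-variance Gaussian likelihood stands in for \log p(\text{data} \mid \theta), and a single weight sample per step (S = 1) is used.

import torch
import torch.nn.functional as F

def kl_to_standard_normal(mu, logstd):
    # Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over all parameters.
    var = (2 * logstd).exp()
    return 0.5 * (mu.pow(2) + var - 2 * logstd - 1).sum()

model = BayesianLinear(10, 1)                      # the layer defined in the earlier sketch
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Toy regression data, only to make the sketch self-contained.
X, y = torch.randn(128, 10), torch.randn(128, 1)

for step in range(1000):
    optimizer.zero_grad()
    # Steps 2-3: sample weights (inside the forward pass) and estimate the expected
    # log-likelihood with S = 1 sample; unit-variance Gaussian likelihood up to a constant.
    pred = model(X)
    log_lik = -0.5 * F.mse_loss(pred, y, reduction="sum")
    # Step 4: KL divergence between q(theta) and the N(0, 1) prior, for weights and biases.
    kl = (kl_to_standard_normal(model.w_mu, model.w_logstd)
          + kl_to_standard_normal(model.b_mu, model.b_logstd))
    # Step 5: minimize the negative ELBO = -(expected log-likelihood - KL).
    loss = -(log_lik - kl)
    loss.backward()
    optimizer.step()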
Advantages
- Uncertainty Quantification: BNNs can say how confident they are in their predictions, which is useful for safety-critical tasks or when data is scarce.
- Regularization: The prior acts as a built-in regularizer, helping prevent overfitting.
- Scalability: VI allows Bayesian ideas to be used in deep learning at scale, since it works with standard training tools and hardware.