What is Batch Normalization in CNN?
Batch Normalization is a technique used to improve the training and performance of neural networks, particularly CNNs. This article provides an overview of batch normalization in CNNs along with implementations in PyTorch and TensorFlow.
Overview of Batch Normalization
Batch normalization is a technique to improve the training of deep neural networks by stabilizing and accelerating the learning process. Introduced by Sergey Ioffe and Christian Szegedy in 2015, it addresses the issue known as "internal covariate shift" where the distribution of each layer's inputs changes during training, as the parameters of the previous layers change.
Need for Batch Normalization in CNN model
Batch Normalization in CNNs addresses several challenges encountered during training. The following reasons highlight the need for batch normalization in CNNs:
- Addressing Internal Covariate Shift: Internal covariate shift occurs when the distribution of network activations changes as parameters are updated during training. Batch normalization addresses this by normalizing the activations in each layer, maintaining consistent mean and variance across inputs throughout training. This stabilizes training and speeds up convergence.
- Improving Gradient Flow: Batch normalization contributes to stabilizing the gradient flow during backpropagation by reducing the reliance of gradients on parameter scales. As a result, training becomes faster and more stable, enabling effective training of deeper networks without facing issues like vanishing or exploding gradients.
- Regularization Effect: During training, batch normalization introduces noise to the network activations, serving as a regularization technique. This noise aids in averting overfitting by injecting randomness and decreasing the network's sensitivity to minor fluctuations in the input data.
How Does Batch Normalization Work in CNN?
Batch normalization works in convolutional neural networks (CNNs) by normalizing the activations of each layer across each mini-batch during training. The process works as follows:
1. Normalization within Mini-Batch
In a CNN, each layer receives inputs from multiple channels (feature maps) and processes them through convolutional filters. Batch Normalization operates on each feature map separately, normalizing the activations across the mini-batch.
During training, batch normalization (BN) standardizes the activations of each layer by subtracting the mean and dividing by the standard deviation of each mini-batch.
- Mean Calculation: \mu_B = \frac{1}{m}\sum_{i=1}^{m}{x_i}
- Variance Calculation: \sigma_{B}^{2} = \frac{1}{m} \sum_{i=1}^{m}{(x_i - \mu_B)^2}
- Normalization: \widehat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_{B}^{2} + \epsilon}}
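To make these formulas concrete, here is a minimal sketch (using PyTorch, with a randomly generated mini-batch standing in for real activations) that computes the three quantities above for a single feature map:
import torch

# A hypothetical mini-batch of activations for one feature map:
# shape (batch_size, height, width) = (8, 4, 4)
x = torch.randn(8, 4, 4)
eps = 1e-5

mu = x.mean()                              # mini-batch mean
var = x.var(unbiased=False)                # mini-batch (biased) variance, as in the formula
x_hat = (x - mu) / torch.sqrt(var + eps)   # normalized activations

# The normalized activations have (approximately) zero mean and unit variance
print(x_hat.mean().item(), x_hat.var(unbiased=False).item())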
2. Scaling and Shifting
After normalization, BN adjusts the normalized activations using learned scaling and shifting parameters. These parameters enable the network to adaptively scale and shift the activations, thereby maintaining the network's ability to represent complex patterns in the data.
- Scaling: y_i = \gamma \widehat{x_i}
- Shifting: z_i = y_i + \beta
3. Learnable Parameters
The parameters \gamma and \beta are learned during training through backpropagation. This allows the network to adaptively adjust the normalization and ensure that the activations are in the appropriate range for learning.
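Continuing the sketch above, the scaling and shifting steps with learnable parameters can be written as follows. The channel count and tensor shapes are illustrative assumptions; in a real layer such as 'nn.BatchNorm2d', \gamma and \beta are stored as the layer's weight and bias.
import torch
import torch.nn as nn

num_channels = 16  # illustrative choice

# One learnable (gamma, beta) pair per channel; gamma starts at 1 and beta at 0,
# so the layer initially behaves as a plain normalization.
gamma = nn.Parameter(torch.ones(num_channels))   # scale
beta = nn.Parameter(torch.zeros(num_channels))   # shift

# Hypothetical normalized activations of shape (batch, channels, height, width)
x_hat = torch.randn(8, num_channels, 4, 4)

# Scale and shift combined: y_i = gamma * x_hat_i + beta, broadcast over the spatial dimensions
y = gamma.view(1, -1, 1, 1) * x_hat + beta.view(1, -1, 1, 1)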
4. Applying Batch Normalization
Batch Normalization is typically inserted after a convolutional layer in a CNN, before its output is passed to the next layer. Depending on the network architecture, it may be placed either before or after the activation function, as illustrated in the sketch below.
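As a rough illustration, both placements can be expressed in PyTorch as follows (the layer sizes are arbitrary; which ordering works better depends on the architecture):
import torch.nn as nn

# Conv -> BN -> activation (the placement described in the original paper)
block_bn_before_act = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Conv -> activation -> BN (an alternative ordering used in some architectures)
block_bn_after_act = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(16),
)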
5. Training and Inference
During training, Batch Normalization computes the mean and variance of each mini-batch and keeps running estimates of these statistics. During inference (testing), it normalizes the activations using the running estimates accumulated during training rather than batch statistics, so the model's behavior is consistent and does not depend on the composition of the test batch.
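The sketch below (with arbitrary tensor shapes) shows how this switch is handled in PyTorch through the train() and eval() modes:
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
x = torch.randn(8, 16, 32, 32)  # a hypothetical mini-batch

# Training mode: normalization uses the current mini-batch statistics,
# and running_mean / running_var are updated as moving averages.
bn.train()
_ = bn(x)

# Evaluation mode: normalization uses the stored running statistics,
# so the output no longer depends on the composition of the batch.
bn.eval()
with torch.no_grad():
    _ = bn(x)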
Applying Batch Normalization in CNN model using TensorFlow
In this section, we provide a code example to illustrate how batch normalization can be applied in a CNN model using TensorFlow. Batch normalization layers are added after the convolutional layers using 'tf.keras.layers.BatchNormalization()'.
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, BatchNormalization
# Build the CNN model
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    BatchNormalization(),  # Add batch normalization layer
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    BatchNormalization(),  # Add batch normalization layer
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(10, activation='softmax')
])
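To train this model, one would typically compile and fit it as shown below; the optimizer, loss, and data are illustrative assumptions (e.g. 32x32 RGB images such as CIFAR-10 with 10 classes), not part of the original example.
# Illustrative training setup; 'x_train' and 'y_train' are assumed to be
# 32x32 RGB images and integer class labels, respectively.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
# model.fit(x_train, y_train, epochs=10, batch_size=64, validation_split=0.1)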
Applying Batch Normalization in CNN model using PyTorch
In PyTorch, we can easily apply batch normalization in a CNN model.
For applying batch normalization in a 1D convolutional neural network model, we use 'nn.BatchNorm1d()'.
import torch
import torch.nn as nn
class CNN1D(nn.Module):
    def __init__(self):
        super(CNN1D, self).__init__()
        self.conv1 = nn.Conv1d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm1d(16)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm1d(32)
        self.fc = nn.Linear(32 * 28, 10)  # Example fully connected layer

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = x.view(-1, 32 * 28)  # Reshape for fully connected layer
        x = self.fc(x)
        return x

# Instantiate the model
model = CNN1D()
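As a quick sanity check, the model can be run on a dummy batch. The input shape here is an assumption implied by the fully connected layer above: 3 input channels and a sequence length of 28.
# Dummy batch of shape (batch_size, channels, sequence_length)
dummy = torch.randn(4, 3, 28)
out = model(dummy)
print(out.shape)  # expected: torch.Size([4, 10])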
For applying batch normalization in a 2D convolutional neural network model, we use 'nn.BatchNorm2d()'.
import torch
import torch.nn as nn
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc = nn.Linear(32 * 28 * 28, 10)  # Example fully connected layer

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = x.view(-1, 32 * 28 * 28)  # Reshape for fully connected layer
        x = self.fc(x)
        return x

# Instantiate the model
model = CNN()
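Again as a sanity check, the model can be exercised with a dummy batch; the fully connected layer above implies 3-channel 28x28 inputs, which is an assumption of this example.
# Dummy batch of shape (batch_size, channels, height, width)
dummy = torch.randn(4, 3, 28, 28)
out = model(dummy)
print(out.shape)  # expected: torch.Size([4, 10])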
Advantages of Batch Normalization in CNN
- Faster convergence: training typically reaches good performance in fewer epochs.
- Improved generalization: the mild regularization effect helps reduce overfitting.
- Reduced sensitivity to weight initialization.
- Support for higher learning rates without training diverging.
- Improvement in overall model accuracy.
Conclusion
In conclusion, batch normalization stands as a pivotal technique in enhancing the training and performance of convolutional neural networks (CNNs). Its implementation addresses critical challenges such as internal covariate shift, thereby stabilizing training, accelerating convergence, and facilitating deeper network architectures.