How to handle overfitting in PyTorch models using Early Stopping
Overfitting is a common challenge in machine learning: a model performs well on training data but poorly on unseen data because it has learned noise and overly specific details of the training set.
In the context of deep learning with PyTorch, one effective method to combat overfitting is implementing early stopping. This article explains how early stopping works, demonstrates how to implement it in PyTorch, and explores its benefits and considerations.
What is Early Stopping?
Early stopping is a regularization technique used to avoid overfitting during the training process. It involves stopping the training phase if the model's performance on a validation set does not improve for a specified number of consecutive epochs, called the "patience" period. This ensures the model does not learn the noise and specific details of the training data, thereby enhancing its generalization capabilities.
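To make the patience rule concrete before we wire it into PyTorch, here is a minimal, self-contained sketch; the per-epoch loss values are made up for illustration:
# Illustrative patience logic; the loss values below are hypothetical
val_losses = [0.40, 0.35, 0.33, 0.34, 0.36, 0.35, 0.37]
patience = 3
best_loss, counter = float('inf'), 0

for epoch, loss in enumerate(val_losses, start=1):
    if loss < best_loss:
        best_loss, counter = loss, 0  # improvement: reset the counter
    else:
        counter += 1                  # no improvement this epoch
        if counter >= patience:
            print(f'Stopping at epoch {epoch}: no improvement for {patience} epochs')
            break
With these numbers, training stops after epoch 6, three non-improving epochs after the best loss at epoch 3.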
Benefits of Early Stopping
- Prevents Overfitting: By halting training at the right time, early stopping stops the model before it starts fitting noise in the training data.
- Saves Time and Resources: It reduces unnecessary training time and computational resources by stopping the training early.
- Optimizes Model Performance: Helps in selecting the version of the model that performs best on unseen data.
Steps to Implement Early Stopping in PyTorch
In this section, we are going to walk through the process of creating, training, and evaluating a simple neural network using PyTorch, focusing on the implementation of early stopping to prevent overfitting.
Step 1: Import Libraries
First, we import the necessary libraries:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import copy
Step 2: Define the Neural Network Architecture
Next, we define a simple neural network class using PyTorch's nn.Module. The network has:
- fc1, fc2, fc3: Fully connected layers, with ReLU activations after the first two.
- forward method: Defines the forward pass, flattening each 28x28 image into a 784-dimensional vector before the linear layers.
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 28x28 input pixels -> 128 hidden units
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)    # 10 output classes (digits 0-9)

    def forward(self, x):
        x = torch.flatten(x, 1)         # flatten everything except the batch dimension
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)                 # raw logits; CrossEntropyLoss applies softmax
        return x
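As a quick sanity check (our own addition, not required by the tutorial), you can pass a dummy MNIST-shaped batch through the network and confirm the output shape:
model = SimpleNN()
dummy = torch.randn(2, 1, 28, 28)  # a fake batch of two 28x28 grayscale images
print(model(dummy).shape)          # expected: torch.Size([2, 10])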
Step 3: Implement Early Stopping
We implement an EarlyStopping class that halts training when the validation loss stops improving. Its parameters and attributes are:
- patience: Number of epochs to wait before stopping if no improvement.
- delta: Minimum change in the monitored quantity to qualify as an improvement.
- best_score, best_model_state: Track the best validation score and model state.
- __call__ method: Updates the counter and best score after each validation pass.
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss  # higher score means lower validation loss
        if self.best_score is None:
            self.best_score = score
            # deepcopy so the snapshot is not mutated by later optimizer steps
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            self.counter += 1  # no sufficient improvement this epoch
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)
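Before entering the training loop, we create an instance of the class. The patience and delta values below are the ones used in the complete example at the end of this article:
early_stopping = EarlyStopping(patience=5, delta=0.01)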
Step 4: Load the Data
We load the MNIST dataset, normalize it, and hold out 20% of the training set as a validation set for early stopping; the test set stays untouched for the final evaluation.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Hold out 20% of the training data as a validation set for early stopping
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
Step 5: Initialize the Model, Loss Function, and Optimizer
We set up the model, criterion, and optimizer.
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
Step 6: Train the Model with Early Stopping
We train the model, incorporating early stopping. In each epoch:
- Train loop: Train the model, update weights, and calculate training loss.
- Validation loop: Evaluate the model on validation data and calculate validation loss.
- Early stopping check: Apply early stopping logic after each epoch.
# Training loop with early stopping
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)
    train_loss /= len(train_loader.dataset)

    # Validation step (on the held-out validation set, not the test set)
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)
    val_loss /= len(val_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # Early stopping check
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Restore the weights from the best epoch
early_stopping.load_best_model(model)
Step 7: Evaluate the Model
Finally, we evaluate the model's accuracy on the test dataset. The evaluation loop computes the accuracy by comparing predicted labels with true labels.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)  # index of the highest logit per sample
        total += target.size(0)
        correct += (predicted == target).sum().item()
print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
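If you also want the early-stopped weights to survive beyond the current Python session, you can persist them to disk with torch.save. A minimal sketch; the file name checkpoint.pt is an arbitrary choice of ours:
# Save the restored best weights to disk ('checkpoint.pt' is an arbitrary name)
torch.save(model.state_dict(), 'checkpoint.pt')

# Later: rebuild the architecture and reload the saved weights
restored = SimpleNN()
restored.load_state_dict(torch.load('checkpoint.pt'))
restored.eval()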
Building and Training a Simple Neural Network with Early Stopping in PyTorch
Putting all of the steps together, here is the complete runnable program:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import copy

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            # deepcopy so the snapshot is not mutated by later optimizer steps
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

# Data loading
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Split the training dataset into training and validation sets
train_size = int(0.8 * len(train_dataset))  # 80% for training
val_size = len(train_dataset) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Early stopping
early_stopping = EarlyStopping(patience=5, delta=0.01)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)
    train_loss /= len(train_loader.dataset)

    # Validation step (using the validation set, not the test set)
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)
    val_loss /= len(val_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Load the best model
early_stopping.load_best_model(model)

# Final evaluation on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
Output:
Epoch 1, Train Loss: 0.4373, Val Loss: 0.2750
Epoch 2, Train Loss: 0.2244, Val Loss: 0.1835
Epoch 3, Train Loss: 0.1617, Val Loss: 0.1441
...
Epoch 14, Train Loss: 0.0445, Val Loss: 0.1036
Epoch 15, Train Loss: 0.0398, Val Loss: 0.1205
Epoch 16, Train Loss: 0.0388, Val Loss: 0.0934
Early stopping
Accuracy of the model on the test images: 97.35%
Conclusion
In this tutorial, we demonstrated how to build, train, and evaluate a simple neural network using PyTorch, with a focus on implementing early stopping to prevent overfitting. This approach helps achieve better generalization by halting training when the validation performance stops improving.