How to handle overfitting in PyTorch models using Early Stopping

Last Updated : 10 Dec, 2024

Overfitting is a common challenge in machine learning: a model performs well on its training data but poorly on unseen data because it has learned noise or incidental details of the training set.

In the context of deep learning with PyTorch, one effective method to combat overfitting is implementing early stopping. This article explains how early stopping works, demonstrates how to implement it in PyTorch, and explores its benefits and considerations.

What is Early Stopping?

Early stopping is a regularization technique used to reduce overfitting during training. Training is stopped if the model's performance on a validation set does not improve for a specified number of consecutive epochs, called the "patience" period. Stopping at this point limits how much the model fits noise and incidental details of the training data, which tends to improve its ability to generalize.
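
The patience rule can be illustrated without any deep learning code. The following is a minimal sketch in plain Python; the function name `stopping_epoch` and the loss values are made up for illustration and are not part of PyTorch.

```python
def stopping_epoch(val_losses, patience, delta=0.0):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    bad_epochs = 0  # consecutive epochs without sufficient improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss <= best - delta:
            # Improvement: remember the new best loss and reset the counter
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# Hypothetical validation losses: improvement stalls after epoch 3
losses = [0.90, 0.60, 0.50, 0.52, 0.51, 0.53, 0.55]
print(stopping_epoch(losses, patience=3))  # 6
print(stopping_epoch(losses, patience=5))  # None
```

With a patience of 3, epochs 4, 5, and 6 all fail to beat the best loss from epoch 3, so training would stop at epoch 6. A larger patience lets training run to the end of this short list.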

Benefits of Early Stopping

  1. Reduces Overfitting: Halting training once validation performance stops improving limits how far the model can overfit the training data.
  2. Saves Time and Resources: It avoids unnecessary epochs, which reduces training time and compute cost.
  3. Optimizes Model Performance: It helps select the version of the model that performs best on the validation set, which serves as a proxy for unseen data.


Steps needed to Implement Early Stopping in PyTorch

In this section, we are going to walk through the process of creating, training, and evaluating a simple neural network using PyTorch, focusing on the implementation of early stopping to prevent overfitting.

Step 1: Import Libraries

First, we import the necessary libraries:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

Step 2: Define the Neural Network Architecture

Next, we define a simple neural network class using PyTorch's nn.Module. The neural network has:

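  • fc1, fc2: Fully connected layers, each followed by a ReLU activation.
  • fc3: Fully connected output layer that returns raw scores (logits) for the 10 digit classes.
  • forward method: Flattens each 28×28 image into a 784-element vector and defines the forward pass of the network.

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # Flatten (batch, 1, 28, 28) images to (batch, 784)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation here: the loss function expects raw logits
        x = self.fc3(x)
        return x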
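
A quick way to sanity-check the architecture is to pass a dummy batch through it and inspect the output shape. The sketch below repeats the class so that it runs on its own; the random tensor is only a stand-in for a batch of MNIST images.

```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


model = SimpleNN()
batch = torch.randn(4, 1, 28, 28)  # stand-in for 4 grayscale 28x28 images
logits = model(batch)
n_params = sum(p.numel() for p in model.parameters())

print(logits.shape)  # torch.Size([4, 10])
print(n_params)      # 109386
```

Each input image produces 10 scores, one per digit class. The parameter count is the sum of the weights and biases of the three linear layers.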

Step 3: Implement Early Stopping

We implement an EarlyStopping class to halt training if the validation loss stops improving. Its main parameters, attributes, and methods are:

  • patience: Number of epochs to wait before stopping if no improvement.
  • delta: Minimum change in the monitored quantity to qualify as an improvement.
  • best_score, best_model_state: Track the best validation score and model state.
  • __call__ method: Updates the early stopping state after each epoch.
  • load_best_model method: Restores the best saved weights into the model.

import copy

class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        # Negate the loss so that a higher score is better
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            # Deep-copy: state_dict() tensors share memory with the live parameters
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            # Not enough improvement: count this epoch against the patience budget
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)
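
One detail worth checking when saving the best model state: the tensors returned by state_dict() share memory with the model's live parameters, so a snapshot kept for later restoration generally needs to be copied. The sketch below illustrates this with a tiny nn.Linear layer standing in for the real network, and an in-place update standing in for further training.

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(2, 1)

alias = model.state_dict()                    # shares memory with the parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy

# Simulate further training by changing the weights in place
with torch.no_grad():
    model.weight.add_(1.0)

alias_tracks_model = torch.equal(alias["weight"], model.weight)
snapshot_tracks_model = torch.equal(snapshot["weight"], model.weight)

print(alias_tracks_model)     # True: the uncopied dict follows the live weights
print(snapshot_tracks_model)  # False: the deep copy kept the old weights
```

Without the copy, restoring the "best" state at the end of training would simply reload the weights from the last epoch.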

Step 4: Load the Data

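We load and transform the MNIST dataset, then hold out 20% of the training data as a validation set. Early stopping monitors the validation set, so the test set stays untouched until the final evaluation.

from torch.utils.data import random_split

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Split the training dataset into training and validation sets
train_size = int(0.8 * len(train_dataset))  # 80% for training
val_size = len(train_dataset) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)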
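
Early stopping needs a validation set that is separate from the test set. If the split should be the same on every run, random_split accepts a seeded generator. The following is a small sketch that uses a synthetic TensorDataset in place of MNIST so that it runs without a download; the seed value 42 is arbitrary.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for MNIST: 100 fake images with integer labels
full_train = TensorDataset(
    torch.randn(100, 1, 28, 28),
    torch.randint(0, 10, (100,)),
)

train_size = int(0.8 * len(full_train))
val_size = len(full_train) - train_size

# A seeded generator makes the split identical across runs
generator = torch.Generator().manual_seed(42)
train_subset, val_subset = random_split(
    full_train, [train_size, val_size], generator=generator
)

print(len(train_subset), len(val_subset))  # 80 20
```

The two subsets are disjoint, so no validation sample is ever seen during training.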

Step 5: Initialize the Model, Loss Function, and Optimizer

We set up the model, criterion, optimizer, and the early stopping helper.

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
early_stopping = EarlyStopping(patience=5, delta=0.01)
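
Note that the network's forward pass returns raw scores without a softmax. This matches nn.CrossEntropyLoss, which applies log-softmax internally and expects integer class labels as targets. The sketch below checks this on a few made-up logits by comparing the built-in loss with a manual computation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up logits for 2 samples and 3 classes, with their true class indices
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)

# Manual version: mean negative log-softmax probability of the true class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(2), targets].mean()

print(torch.allclose(loss, manual))  # True
```

Adding an explicit softmax layer before this loss would therefore apply the normalization twice.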

Step 6: Train the Model with Early Stopping

We train the model with early stopping. Each epoch has three parts:

  • Train loop: Train the model, update weights, and calculate training loss.
  • Validation loop: Evaluate the model on validation data and calculate validation loss.
  • Early stopping check: Apply early stopping logic after each epoch.

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)

    train_loss /= len(train_loader.dataset)

    # Validation step (using validation set, not test set)
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)

    val_loss /= len(val_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Restore the weights from the best epoch
early_stopping.load_best_model(model)
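
The loops multiply each batch loss by data.size(0) before summing because the criterion returns a per-batch mean and the last batch is usually smaller than the others. Weighting by batch size gives the true per-sample average. A small arithmetic sketch with made-up numbers:

```python
# Hypothetical per-batch mean losses and batch sizes; the last batch is smaller
batch_means = [0.50, 0.40, 0.10]
batch_sizes = [64, 64, 8]

# Weighted accumulation, as in the training and validation loops
running = 0.0
for mean_loss, size in zip(batch_means, batch_sizes):
    running += mean_loss * size
per_sample_avg = running / sum(batch_sizes)

# Naive alternative: average the batch means directly
naive_avg = sum(batch_means) / len(batch_means)

print(f"per-sample average: {per_sample_avg:.4f}")
print(f"average of batch means: {naive_avg:.4f}")
```

The two values differ because the naive average gives the 8-sample batch the same weight as the 64-sample batches.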

Step 7: Evaluate the Model

Finally, we evaluate the model's accuracy on the test dataset. The evaluation loop computes the accuracy by comparing predicted labels with true labels.

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
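
The accuracy computation can be traced on a handful of made-up logits. This sketch uses argmax(dim=1), which returns the same indices as the second value of torch.max(outputs, 1).

```python
import torch

# Made-up logits for 4 samples and 3 classes, with their true labels
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.2, 0.3],
                       [0.0, 0.1, 0.9],
                       [2.2, 2.1, 0.0]])
targets = torch.tensor([1, 0, 1, 0])

predicted = logits.argmax(dim=1)               # index of the highest score per row
correct = (predicted == targets).sum().item()  # number of matching predictions
accuracy = 100 * correct / targets.size(0)

print(predicted.tolist())  # [1, 0, 2, 0]
print(accuracy)            # 75.0
```

Three of the four predictions match their labels, which gives 75% accuracy.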

Building and Training a Simple Neural Network with Early Stopping in PyTorch

Python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import copy

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        # Negate the loss so that a higher score is better
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            # Deep-copy: state_dict() tensors share memory with the live parameters
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

# Data loading
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Split the training dataset into training and validation sets
train_size = int(0.8 * len(train_dataset))  # 80% for training
val_size = len(train_dataset) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Early stopping
early_stopping = EarlyStopping(patience=5, delta=0.01)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)

    train_loss /= len(train_loader.dataset)

    # Validation step (using validation set, not test set)
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)

    val_loss /= len(val_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Load the best model
early_stopping.load_best_model(model)

# Final evaluation on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

Output:

Epoch 1, Train Loss: 0.4373, Val Loss: 0.2750
Epoch 2, Train Loss: 0.2244, Val Loss: 0.1835
Epoch 3, Train Loss: 0.1617, Val Loss: 0.1441
.
.
.
Epoch 14, Train Loss: 0.0445, Val Loss: 0.1036
Epoch 15, Train Loss: 0.0398, Val Loss: 0.1205
Epoch 16, Train Loss: 0.0388, Val Loss: 0.0934
Early stopping
Accuracy of the model on the test images: 97.35%

Training stops at epoch 16 even though its validation loss is lower than in the two preceding epochs. With delta=0.01, an epoch counts as an improvement only if its validation loss is at least 0.01 below the best loss recorded so far, so an epoch with a slightly lower loss does not necessarily reset the patience counter.

Conclusion

In this tutorial, we demonstrated how to build, train, and evaluate a simple neural network using PyTorch, with a focus on implementing early stopping to prevent overfitting. This approach helps achieve better generalization by halting training when the validation performance stops improving.

