Save and Load Models in PyTorch
It often happens that we need to reuse an already-trained model in our development environment. Would you retrain the model from scratch every time, or save it once and load it whenever it is needed? The second option is clearly better, so in this article we will see how to save and load models using PyTorch.
What is PyTorch?
PyTorch is an open-source machine learning library built around a dynamic computation graph. In the static approach, the computation graph is fully defined before execution; in the dynamic approach that PyTorch follows, the structure of the graph can change during execution based on the input data. This makes it straightforward to create and train neural networks that extract hidden patterns from data.
You might wonder what a neural network is. In simple words, a neural network is a collection of interconnected layers of nodes: each node processes the data it receives and passes the result to the next, so the network as a whole learns to extract insights from the data.
Stepwise Guide to Save and Load Models in PyTorch
First, we will see how to create a model using PyTorch.
Creating Model in PyTorch
To demonstrate saving and loading, we will first build a deep learning model for image classification. The model classifies handwritten digits from the MNIST dataset using a convolutional neural network (CNN). The data is loaded and transformed into PyTorch tensors, which are multi-dimensional containers for the data.
The following code shows the creation of the model, step by step.
Importing Necessary Libraries
Python3
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
Data Transformation
The code below defines a preprocessing pipeline using torchvision.transforms.Compose that is applied to each image before it is fed into the PyTorch model:
- transforms.ToTensor(): converts the input image (assumed to be in PIL Image format) to a PyTorch tensor of type torch.FloatTensor and scales the pixel values to the range [0.0, 1.0].
- transforms.Normalize((0.5,), (0.5,)): normalizes the tensor image with the given per-channel mean (0.5,) and standard deviation (0.5,), mapping the values from [0.0, 1.0] to [-1.0, 1.0].
Python3
# Define the transformations to apply to the data
data_transform = transforms.Compose([
    transforms.ToTensor(),                # Convert images to PyTorch tensors
    transforms.Normalize((0.5,), (0.5,))  # Normalize pixel values to the range [-1, 1]
])

# Download the MNIST dataset and apply the transformation
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=data_transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=data_transform, download=True)

# Define data loaders to load the data in batches during training and testing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
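As a quick sanity check (a minimal sketch, not part of the original pipeline), you can pull one batch from the train_loader and confirm the tensor shapes and the normalized value range:
Python3
# Fetch a single batch and inspect its shape and value range
images, labels = next(iter(train_loader))
print(images.shape)    # torch.Size([64, 1, 28, 28])
print(labels.shape)    # torch.Size([64])
print(images.min().item(), images.max().item())    # approximately -1.0 and 1.0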
Defining neural network architecture
- Class definition: the code defines a class SimpleCNN that inherits from nn.Module, the base class for all neural network modules in PyTorch. This class represents a simple convolutional neural network (CNN).
- Initialization: the __init__ method defines the layers of the CNN: two convolutional layers (conv1_layer and conv2_layer) with the specified kernel sizes and padding, and two fully connected layers (fc1_layer and fc2_layer) with the specified input and output sizes.
- Forward pass: the forward method applies a ReLU activation after each convolutional layer and uses max pooling with a kernel size of 2 and a stride of 2 to downsample the feature maps. The output of the second convolutional layer is flattened before being passed to the fully connected layers.
- View operation: view reshapes the output of the second convolutional layer to match the input size of the first fully connected layer. The -1 argument tells PyTorch to infer that dimension from the others.
- Model instance: finally, an instance of SimpleCNN is created and assigned to the variable cnn_model. This instance is the network that will be trained and used for inference (see the shape-check sketch after the code below).
Python3
# Define a CNN with two convolutional layers and two fully connected layers
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1_layer = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2_layer = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1_layer = nn.Linear(32 * 7 * 7, 128)
        self.fc2_layer = nn.Linear(128, 10)

    # ReLU activations with max pooling between the layers
    def forward(self, inputs):
        new_input = torch.relu(self.conv1_layer(inputs))
        new_input = torch.max_pool2d(new_input, kernel_size=2, stride=2)
        new_input = torch.relu(self.conv2_layer(new_input))
        new_input = torch.max_pool2d(new_input, kernel_size=2, stride=2)
        new_input = new_input.view(-1, 32 * 7 * 7)
        new_input = torch.relu(self.fc1_layer(new_input))
        new_input = self.fc2_layer(new_input)
        return new_input

# Create the model instance
cnn_model = SimpleCNN()
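To see where the 32 * 7 * 7 input size of the first fully connected layer comes from, trace the shapes: each of the two pooling layers halves the 28x28 input, leaving 32 feature maps of size 7x7. The following sketch (not part of the original training code) verifies this with a dummy batch:
Python3
# Pass a dummy grayscale 28x28 image through the model to verify shapes
dummy = torch.randn(1, 1, 28, 28)
out = cnn_model(dummy)
print(out.shape)    # torch.Size([1, 10]) - one logit per digit class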
Loss Function and Optimizer
- Loss function: nn.CrossEntropyLoss() is used as the loss function. It is the standard choice for classification problems with multiple classes and computes the cross-entropy between the predicted class scores and the actual class labels.
- Optimizer: optim.Adam is used as the optimizer. Adam is a popular optimization algorithm that computes adaptive learning rates for each parameter and is well suited to training deep neural networks. It is initialized with the parameters of cnn_model and a learning rate of 0.001. A small illustration of the loss function follows the code below.
Python3
# Define loss function and optimizer
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)
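As a small illustration (a sketch with made-up values, not part of the training code), nn.CrossEntropyLoss expects raw logits of shape (batch, classes) together with integer class labels; the softmax is applied internally:
Python3
# CrossEntropyLoss takes raw logits and integer labels; softmax is applied internally
dummy_logits = torch.randn(4, 10)            # batch of 4 samples, 10 classes
dummy_labels = torch.tensor([3, 7, 0, 1])    # ground-truth class indices
print(loss_func(dummy_logits, dummy_labels).item())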
Training the model
The training code implements the following steps:
- Outer loop (epochs): the code iterates over 5 epochs using a for loop. An epoch is a single pass through the entire dataset.
- Inner loop (batches): within each epoch, the code iterates over train_loader, which yields batches of input images (inputs) and their corresponding labels (labels).
- Zero gradients: before the backward pass, optimizer.zero_grad() zeroes out the gradients of the model parameters. This is necessary because PyTorch accumulates gradients by default.
- Forward and backward pass: outputs = cnn_model(inputs) performs the forward pass, producing the predictions. loss = loss_func(outputs, labels) computes the loss between the predictions and the actual labels. loss.backward() computes the gradients of the loss with respect to the model parameters, and optimizer.step() updates the parameters using those gradients (Adam in this case).
- Loss reporting: running_loss accumulates the total loss across batches, and at the end of each epoch the average loss per batch is printed to monitor training progress.
Python3
# Train the model
for epoch in range(5):  # Train for 5 epochs
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()              # Zero the gradients
        outputs = cnn_model(inputs)        # Forward pass
        loss = loss_func(outputs, labels)  # Calculate the loss
        loss.backward()                    # Backward pass
        optimizer.step()                   # Update the weights
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")
Output:
Epoch 1, Loss: 0.22154594235159933
Epoch 2, Loss: 0.05766747533348697
Epoch 3, Loss: 0.04144403319505514
Epoch 4, Loss: 0.029859573355312946
Epoch 5, Loss: 0.024109310584392515
Testing The Model
Python
# Evaluate the model on the test set
cnn_model.eval()  # good practice: switch to evaluation mode before testing
correct_predictions = 0
total_samples = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = cnn_model(inputs)
        _, predicted_labels = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted_labels == labels).sum().item()
print(f"Accuracy of test set: {100 * correct_predictions / total_samples}%")
Output:
Accuracy of test set: 99.16%
Saving and Loading Model
Method 1: Using torch.save() and torch.load()
The following code saves and loads the model using the built-in functions provided by the torch module. torch.save() pickles the entire model object to a file, and torch.load() loads it back into memory. Note that this approach requires the class definition (here SimpleCNN) to be importable when the model is loaded.
Python
# Save the entire model object
torch.save(cnn_model, 'cnn_model_full.pth')

# Load the model back (weights_only=False is required in recent
# PyTorch versions to unpickle a full model object)
loaded_model = torch.load('cnn_model_full.pth', weights_only=False)

# Set the model to evaluation mode
loaded_model.eval()
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
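With the loaded model in evaluation mode, you can run inference directly. The snippet below is a minimal sketch that classifies the first image of the test set:
Python3
# Run a single prediction with the loaded model
with torch.no_grad():
    image, label = test_dataset[0]               # one (tensor, label) pair
    output = loaded_model(image.unsqueeze(0))    # add a batch dimension
    prediction = output.argmax(dim=1).item()
print(f"Predicted: {prediction}, Actual: {label}")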
Method 2: Using model.state_dict()
Now let us see another way to save and load the model, using the state_dict() method. A state_dict stores only the learnable parameters of the model. To load it, a new model with the same architecture is created and its parameters are replaced with the stored ones. Because only the parameters are stored rather than the whole pickled object, this method produces smaller, more portable files and is the approach recommended by PyTorch. The following code snippet illustrates it.
Python
# Saving the model
torch.save(cnn_model.state_dict(), 'cnn_model.pth')
# Loading the model
loaded_model = SimpleCNN()
loaded_model.load_state_dict(torch.load('cnn_model.pth'))
print(loaded_model)
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
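One practical detail worth knowing: if the state_dict was saved on a GPU machine and is loaded on different hardware, pass map_location to torch.load so the tensors are remapped to an available device. A minimal sketch:
Python3
# Load a state_dict onto whatever device is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_model = SimpleCNN()
loaded_model.load_state_dict(torch.load('cnn_model.pth', map_location=device))
loaded_model.to(device)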
Method 3: Saving and Loading using the Checkpoints
The checkpoint method saves a dictionary containing all the information needed to resume training: the model state_dict, the optimizer state_dict, the current epoch, the last loss, and so on. To load the model, the checkpoint file is read and each piece of information is restored. The method is demonstrated below:
Python
# Saving a checkpoint
checkpoint = {
    'epoch': epoch,
    'model_state_dict': cnn_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # add any other information you want to keep
}
torch.save(checkpoint, 'checkpoint.pth')

# Loading the checkpoint
checkpoint = torch.load('checkpoint.pth')
cnn_model = SimpleCNN()
cnn_model.load_state_dict(checkpoint['model_state_dict'])

# Recreate the optimizer for the new model instance before restoring its state
optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(cnn_model)
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
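The main payoff of a checkpoint is being able to resume training where you left off. The following sketch (assuming the checkpoint, model, optimizer, and train_loader defined above) continues training from the saved epoch:
Python3
# Resume training from the saved checkpoint
start_epoch = checkpoint['epoch'] + 1
for epoch in range(start_epoch, start_epoch + 5):  # train for 5 more epochs
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = cnn_model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")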
Conclusion
There are several ways to save and load models created with PyTorch. torch.save() and torch.load() can store and restore the entire model object, while model.state_dict() provides a lighter, more portable approach that saves only the parameters. If you also want to keep training-related information such as the optimizer state, the current epoch, and the loss, a checkpoint dictionary covers that. Choosing between these methods comes down to file size, how much information beyond the parameters you need, and the use case at hand.