PyTorch vs PyTorch Lightning
Last Updated: 21 Mar, 2024
PyTorch Lightning (created by William Falcon and now maintained by the Lightning AI team) was introduced to address the boilerplate and organizational challenges of writing raw PyTorch training code and to provide a more organized, standardized approach. In this article, we will see the major differences between PyTorch Lightning and PyTorch.
PyTorch
PyTorch is widely used for deep learning and artificial intelligence research and applications. It builds a dynamic computational graph as operations run, which allows more flexibility and ease of use than frameworks based on static computational graphs.
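Because the graph is built on the fly as operations execute, gradients can flow through ordinary Python control flow. A minimal sketch using only standard PyTorch:
Python3
import torch

# The graph is defined by running ordinary Python code
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x ** 2).sum() if x.sum() > 0 else (x ** 3).sum()  # data-dependent branch

y.backward()   # gradients flow through whichever branch actually ran
print(x.grad)  # tensor([4., 6.]) since y = x1^2 + x2^2 was taken here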
PyTorch Lightning: A Higher-Level Framework on Top of PyTorch
PyTorch Lightning is a lightweight PyTorch wrapper that provides a high-level interface for training PyTorch models. It is designed to simplify and standardize the training loop, making it easier to write cleaner, more modular code for deep learning projects. PyTorch Lightning introduces a set of abstractions and conventions that remove boilerplate code and allow researchers and practitioners to focus more on the model architecture and experiment configurations.
PyTorch vs PyTorch Lightning
PyTorch and PyTorch Lightning are both frameworks for building and training neural network models, but they differ in terms of abstraction, structure, and ease of use. Here are some key differences between PyTorch and PyTorch Lightning:
| Features | PyTorch | PyTorch Lightning |
|---|---|---|
| Training Loop | User writes the training, validation, and testing loops explicitly, including moving data to the GPU, computing gradients, and updating model parameters. | The loop is handled by the Trainer; users customize behavior through hooks and callbacks instead of editing the loop itself. |
| Model Setup | User defines the model, loss function, optimizer, and other components explicitly. | Standardized with dedicated methods on the LightningModule. |
| Abstraction Level | Lower-level; requires more manual coding. | Higher-level; hides boilerplate code. |
| GPU and Distributed Training | Requires manually moving models and data to the GPU, managing distributed training, and handling multi-GPU scenarios. | Automatic based on user configuration; users simply specify the number and type of devices. |
| Logging and Experiment Tracking | Logging metrics and tracking experiments require manual implementation with tools such as TensorBoard or custom loggers. | Built-in support for logging frameworks (TensorBoard, CSV, etc.) and experiment-tracking platforms (e.g., Weights & Biases, Comet). |
| Debugging and Profiling | Manual instrumentation. | Provides hooks for common debugging and profiling tools. |
| Integration with Other Libraries | User-managed integrations. | Built-in integrations. |
| Best Practices | Discovered and applied independently by the user. | Benefits from a growing community and ecosystem; users can leverage pre-built components and best practices. |
| Standardized Interface | Code structure may vary from project to project. | Enforces a consistent structure through the LightningModule. |
| Module System | Left to developer discretion. | Promotes a modular design with LightningModules. |
| Checkpointing | Users must implement checkpointing logic themselves. | Built-in support for simplified model checkpointing. |
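Several of the rows above (GPU configuration, logging, checkpointing, hooks and callbacks) boil down to arguments passed to Lightning's Trainer. Below is a rough sketch, assuming a recent pytorch_lightning release with TensorBoard available; MyLitModel and the data loaders are placeholders rather than objects defined in this article:
Python3
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Keep the best model according to validation loss (checkpointing row)
checkpoint_cb = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",   # picks GPU or CPU automatically (GPU row)
    devices=1,            # number of devices to use
    logger=TensorBoardLogger("tb_logs", name="mnist"),  # logging row
    callbacks=[checkpoint_cb],                          # hooks/callbacks row
)
# trainer.fit(MyLitModel(), train_loader, val_loader)  # placeholders, not defined here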
Implementation: From PyTorch to PyTorch Lightning
Let's illustrate the difference by comparing the training and validation loops for a simple 3-layer neural network on the MNIST dataset, written first in plain PyTorch and then in PyTorch Lightning. The key ingredients are the same in both versions: the model, the dataset (MNIST), the optimizer, and the loss function.
PyTorch
1. Importing necessary libraries and modules
The code starts by importing the necessary libraries and modules for building and training the neural network. These include torch, torch.nn, torch.optim, and torchvision.
Python3
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
2. Defining the neural network model
Next, the code defines a simple neural network model using PyTorch's nn.Module class. The model consists of three fully connected layers (fc1, fc2, and fc3) with 256, 128, and 10 neurons, respectively. The output layer has 10 neurons, corresponding to the 10 classes of digits in the MNIST dataset.
The forward method defines the forward pass of the neural network, where the input is passed through each layer and transformed using the ReLU activation function.
Python3
# Define the model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
3. Loading the dataset
The code then loads the MNIST dataset using torchvision.datasets.MNIST. The dataset is preprocessed using the transforms.Compose method, which applies a series of transformations to the data. In this case, the data is converted to a tensor and normalized using the MNIST mean (0.1307) and standard deviation (0.3081).
The train_dataset and test_dataset objects are created and loaded into train_loader and test_loader, which are PyTorch DataLoader objects that handle batching and shuffling of the data.
Python3
# Load the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
4. Initializing the model, optimizer, and loss function
The neural network model is initialized, and the stochastic gradient descent (SGD) optimizer and cross-entropy loss function are defined.
Python3
# Initialize the model
model = NeuralNetwork()
# Define the optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
5. Training and validating the model
The code defines two functions, train and validate, which handle the training and validation of the neural network.
- The train function takes in the model, training data loader, optimizer, and loss function, and trains the model on the data in batches: for each batch the gradients are zeroed, the loss is backpropagated, and the model weights are updated by the SGD optimizer.
- The validate function takes in the model, test data loader, and loss function, and evaluates the model on the test data. The test loss is computed and the accuracy of the model is calculated as the percentage of correct predictions.
Finally, the code trains and validates the neural network for 10 epochs using the train and validate functions.
Python3
# Training loop
def train(model, train_loader, optimizer, criterion):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Validation loop
def validate(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))

# Train and validate the model
for epoch in range(10):
    train(model, train_loader, optimizer, criterion)
    validate(model, test_loader, criterion)
Output:
Validation set: Average loss: 0.0052, Accuracy: 9050/10000 (90.50%)
Validation set: Average loss: 0.0041, Accuracy: 9267/10000 (92.67%)
Validation set: Average loss: 0.0034, Accuracy: 9369/10000 (93.69%)
Validation set: Average loss: 0.0030, Accuracy: 9445/10000 (94.45%)
Validation set: Average loss: 0.0026, Accuracy: 9506/10000 (95.06%)
Validation set: Average loss: 0.0024, Accuracy: 9555/10000 (95.55%)
Validation set: Average loss: 0.0021, Accuracy: 9597/10000 (95.97%)
Validation set: Average loss: 0.0020, Accuracy: 9641/10000 (96.41%)
Validation set: Average loss: 0.0018, Accuracy: 9663/10000 (96.63%)
Validation set: Average loss: 0.0017, Accuracy: 9680/10000 (96.80%)
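The loop above runs entirely on the CPU. As the comparison table noted, GPU support in plain PyTorch is wired in by hand: the model and every batch must be moved to the device explicitly. A hedged sketch of the usual pattern, reusing the model, optimizer, criterion, and train_loader defined above:
Python3
# Pick a device once and move the model to it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def train_on_device(model, train_loader, optimizer, criterion):
    model.train()
    for data, target in train_loader:
        # Every batch must be moved to the same device as the model
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()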
PyTorch Lightning
1. Importing necessary libraries and modules
The necessary imports include PyTorch, PyTorch Lightning, the MNIST dataset, and the Adam optimizer.
Python3
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.optim import Adam
import pytorch_lightning as pl
2. Defining the neural network model
Next, the MyModel class is defined, which inherits from pl.LightningModule. This class defines the neural network architecture, the forward pass, the training step, the validation step, and the configuration of the optimizer.
- In the __init__ method, the neural network architecture is defined using PyTorch's nn.Sequential module. It consists of three fully connected layers with ReLU activation functions, and a final softmax layer for outputting probabilities.
- The forward method takes in an input tensor x, reshapes it to have the correct number of dimensions, and passes it through the neural network using the self.model module.
Python3
class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Reshape the input
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log('val_loss', loss)
        # Calculate accuracy
        correct = (y_hat.argmax(1) == y).sum().item()
        total = y.size(0)
        self.log('accuracy', correct / total, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.001)
3. Loading the dataset
The MNIST dataset is loaded using the MNIST class from torchvision.datasets. The training split and the test split (used here for validation) are wrapped in separate DataLoader objects.
Python3
# Load the MNIST dataset
train_dataset = MNIST(root='.', train=True, transform=ToTensor(), download=True)
val_dataset = MNIST(root='.', train=False, transform=ToTensor())
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
4. Initializing the model
Python3
# Initialize the model
model = MyModel()
5. Train the model
The trainer object is created using pl.Trainer. The max_epochs argument is set to 10, which means that the model will be trained for 10 epochs. The accelerator argument is set to "gpu" if a GPU is available, otherwise it is set to "cpu".
Python3
# Initialize the trainer
trainer = pl.Trainer(max_epochs=10, accelerator="gpu" if torch.cuda.is_available() else "cpu")
6. Model fitting
Finally, the model is trained using the trainer.fit method. It takes in the model object, and the train_loader and val_loader objects as arguments.
Python3
# Train the model
trainer.fit(model, train_loader, val_loader)
7. Validating the model
After training, trainer.validate runs the validation_step over the supplied data loader and returns the logged metrics.
Python3
trainer.validate(model, val_loader)
Output:
[{'val_loss': 1.4881926774978638, 'accuracy': 0.9729999899864197}]
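Because the Trainer checkpoints the model automatically (by default under lightning_logs/), a trained LightningModule can be restored later without re-running fit. A small sketch; the checkpoint filename below is hypothetical and will differ from run to run:
Python3
# Path is illustrative -- use the .ckpt file produced by your own run
ckpt_path = "lightning_logs/version_0/checkpoints/epoch=9-step=9380.ckpt"
restored = MyModel.load_from_checkpoint(ckpt_path)
restored.eval()  # ready for inference or further validation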
Code Difference Takeaways
| Feature | PyTorch | PyTorch Lightning |
|---|---|---|
| Inheritance | Inherits from nn.Module | Inherits from pl.LightningModule |
| Architecture Definition | Uses a separate class or custom definition | Uses nn.Sequential within the MyModel class |
| Code Structure | Separate functions for training and validation | Organized within the MyModel class with dedicated methods |
| Training Loop | Explicitly written for loop | Abstracted, handled by trainer.fit |
| Optimizer and Scheduler | Defined and configured within the training loop | Defined in the configure_optimizers method |
| Logging Metrics | Manual implementation with external libraries | Simplified using self.log within the LightningModule |
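The logging row is where the difference is most visible in day-to-day use: in plain PyTorch a metric is only recorded if the user sends it somewhere, for example with torch.utils.tensorboard, whereas in Lightning a single self.log call is routed to whatever logger the Trainer was configured with. A brief sketch of the two styles, assuming TensorBoard is installed; the metric value and step counter are illustrative:
Python3
# Plain PyTorch: manual logging with TensorBoard's SummaryWriter
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/mnist")
writer.add_scalar("train_loss", 0.42, global_step=100)  # illustrative values
writer.close()

# PyTorch Lightning: inside a LightningModule method, one call is enough
# self.log("train_loss", loss)   # routed to the Trainer's configured logger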
Conclusion
PyTorch Lightning serves as a powerful tool for researchers and practitioners in the deep learning community, offering a standardized and organized framework for building and training models. By abstracting away common boilerplate code, automating training processes, and providing a modular structure, PyTorch Lightning simplifies the development workflow and enhances collaboration.