Load a Computer Vision Dataset in PyTorch
Last Updated: 24 Apr, 2025
Computer vision is a subfield of Artificial Intelligence that enables computers to interpret images. In deep learning, Convolutional Neural Networks (CNNs) are commonly used to process images. Building a good model requires a large number of images.
There are several ways to load a computer vision dataset in PyTorch, depending on the format of the dataset and the specific requirements of your project.
One popular method is to use the dataset classes provided by torchvision.datasets. They offer a convenient way to load and preprocess common computer vision datasets, such as CIFAR-10 and ImageNet. For example, to load the CIFAR-10 dataset, you can use the following code:
Python3
# Import the necessary library
import torchvision.datasets as datasets
# Download the cifar Dataset
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True)
The code above will download the CIFAR-10 dataset and save it in the './data' directory.
Another building block is the torch.utils.data.DataLoader class, which wraps a dataset and iterates over it in batches. It is especially useful when the data is on your local machine and you want to combine data augmentation with shuffling and a chosen batch size. It also lets you customize the loading order, the batching, and whether data is loaded in a single process or in several.
Here we use transforms.Compose from torchvision to randomly flip and rotate each image, convert it to a tensor, and normalize it.
Python3
# Import the necessary libraries
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Image Transformation
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.35, 0.35, 0.406], [0.30, 0.34, 0.35])
])
# Load the dataset with transformation (already downloaded, so download=False)
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=False, transform=transform)
# Make batches of size 32
train_loader = DataLoader(cifar10_train, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar10_test, batch_size=32, shuffle=False, num_workers=2)
View the train and test data
Python3
#Train Dataset
print(train_loader.dataset)
#Test Dataset
print(test_loader.dataset)
Output:
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
Plot the image:
Python3
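# Import the plotting library
import matplotlib.pyplot as plt
# Fetch one batch of images and labels
inputs, Class = next(iter(train_loader))
# Define the class names
class_name = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'
              }
# Plot the figure (a moderate dpi keeps the rendered image at a manageable size)
plt.figure(figsize=(30, 16), dpi=100)
for i in range(32):
    plt.subplot(4, 8, i + 1)
    # Reorder (C, H, W) to (H, W, C); clip because normalized values can leave [0, 1]
    plt.imshow(inputs[i].numpy().transpose((1, 2, 0)).clip(0, 1))
    plt.axis('off')
    plt.title(class_name[int(Class[i])])
plt.show()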
Output:
A 4 x 8 grid of 32 CIFAR-10 training images, each titled with its class name.
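Because the images were normalized, their pixel values can fall outside [0, 1], so the plotted colors look distorted. One possible remedy is to undo the normalization before plotting. The sketch below assumes the same mean and standard deviation that were passed to transforms.Normalize earlier. The helper name denormalize is hypothetical, and a random tensor stands in for a real CIFAR-10 sample.

```python
import torch

# Same statistics as passed to transforms.Normalize in this article
mean = torch.tensor([0.35, 0.35, 0.406]).view(3, 1, 1)
std = torch.tensor([0.30, 0.34, 0.35]).view(3, 1, 1)

def denormalize(img):
    """Invert Normalize for one (C, H, W) tensor and clamp to [0, 1]."""
    return (img * std + mean).clamp(0, 1)

# Demo on a random image standing in for a CIFAR-10 sample (assumption)
original = torch.rand(3, 32, 32)
normalized = (original - mean) / std
restored = denormalize(normalized)
# Check that the round trip recovers the original up to float error
print(torch.allclose(restored, original, atol=1e-5))
```

In the plotting loop, denormalize(inputs[i]) could then be used in place of inputs[i] before the call to .numpy().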
Other libraries, such as albumentations, can also be used to augment and preprocess the images. The right choice depends on the format of your data and what you are trying to achieve.
You might also want to check the version of PyTorch you're using, as well as the format of the dataset you're trying to load. Some datasets might be in a custom format and you might need to write your own code to load it correctly.
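For a custom format, the usual approach is to subclass torch.utils.data.Dataset and implement __len__ and __getitem__. The following is a minimal sketch rather than a definitive implementation. The class name ArrayImageDataset is hypothetical, and random NumPy arrays stand in for images read from a custom file format.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ArrayImageDataset(Dataset):
    """Hypothetical dataset wrapping images stored as (N, H, W, C) uint8 arrays."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Convert an HWC uint8 array to a CHW float tensor in [0, 1]
        image = torch.from_numpy(self.images[index]).permute(2, 0, 1).float() / 255.0
        if self.transform is not None:
            image = self.transform(image)
        return image, int(self.labels[index])

# Random arrays stand in for data read from a custom file format (assumption)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(10, 32, 32, 3), dtype=np.uint8)
labels = rng.integers(0, 10, size=10)

custom_data = ArrayImageDataset(images, labels)
custom_loader = DataLoader(custom_data, batch_size=4, shuffle=False)
batch_images, batch_labels = next(iter(custom_loader))
print(batch_images.shape, batch_labels.shape)
```

Keeping the file parsing inside __getitem__ means the DataLoader can still handle batching and shuffling without further changes.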