How to load Fashion MNIST dataset using PyTorch?
Last Updated: 14 May, 2024
In machine learning, datasets are essential because they serve as benchmarks for comparing and assessing the performance of different algorithms. Fashion MNIST is one such dataset: a drop-in replacement for the original MNIST dataset of handwritten digits that poses a harder classification task. This article explores the Fashion MNIST dataset, including its characteristics and uses, and shows how to load it using PyTorch.
What is Fashion MNIST?
Fashion-MNIST is a dataset developed by Zalando Research as a modern alternative to the original MNIST dataset. It comprises 70,000 grayscale images, each belonging to one of 10 categories of fashion items. Each image is 28x28 pixels, providing a uniform format for machine learning model input. The dataset is divided into a training set of 60,000 images and a test set of 10,000 images.
The ten categories in Fashion MNIST are:
- T-shirt/top
- Trouser
- Pullover
- Dress
- Coat
- Sandal
- Shirt
- Sneaker
- Bag
- Ankle boot
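The categories above are listed in label order, so the integer label of a sample (0 to 9) can be used as an index into the list. The following is a minimal sketch of that mapping; the names `FASHION_MNIST_CLASSES` and `label_to_name` are illustrative and not part of any library.

```python
# Class names in label order: index 0 is "T-shirt/top", index 9 is "Ankle boot".
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_to_name(label: int) -> str:
    """Return the class name for an integer label in the range 0-9."""
    return FASHION_MNIST_CLASSES[label]


print(label_to_name(0))  # T-shirt/top
print(label_to_name(9))  # Ankle boot
```

A dataset object created by torchvision also exposes these names through its `classes` attribute, so a hand-written list like this is mainly useful when only the raw labels are available.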
Characteristics of Fashion MNIST Dataset
Here are the key characteristics of the Fashion-MNIST dataset in bullet points:
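- Images are stored as 8-bit grayscale intensities, with pixel values ranging from 0 to 255. They are not normalized until a transform such as `transforms.ToTensor()` rescales them to the range 0 to 1.
- Images are centered product photos on a plain background, so there is little variation in lighting, pose, or background. The added difficulty over MNIST comes mainly from visually similar categories such as Shirt, T-shirt/top, Pullover, and Coat.
- The dataset is balanced: each category has 6,000 training images and 1,000 test images.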
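Properties such as the pixel range and the number of images per category can be checked directly rather than taken on trust. The sketch below assumes the images and labels are available as tensors; the helper `summarize` is a hypothetical name, and a small synthetic batch stands in for the real data so that the code runs without a download.

```python
import torch


def summarize(images: torch.Tensor, labels: torch.Tensor, num_classes: int = 10):
    """Return (min pixel, max pixel, per-class counts) for a set of images."""
    counts = torch.bincount(labels, minlength=num_classes)
    return int(images.min()), int(images.max()), counts.tolist()


# Synthetic stand-in: 20 random 28x28 uint8 images, two per class.
fake_images = torch.randint(0, 256, (20, 28, 28), dtype=torch.uint8)
fake_labels = torch.arange(10).repeat(2)
print(summarize(fake_images, fake_labels))

# With the real data (assumes torchvision is installed and the files can be downloaded):
# from torchvision import datasets
# train = datasets.FashionMNIST(root="./data", train=True, download=True)
# print(summarize(train.data, train.targets))
```

On the real training set, the `data` and `targets` attributes hold the raw image tensor and the label tensor, so the same function reports the stored pixel range and how many images each category contains.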
Load Fashion MNIST dataset in PyTorch
The 'torchvision.datasets.FashionMNIST' class is used to load the Fashion MNIST dataset in PyTorch. Its constructor has the following signature:
torchvision.datasets.FashionMNIST(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
Breakdown of Parameters:
- root: Location where the dataset is stored or will be stored. Can be a directory path or a Path object.
- train: Indicates whether to load the training set (`True`) or the test set (`False`). Default is training set.
- transform: Transformation to apply to the samples (images). Can be a single transformation or a composition of multiple transformations. Default is no transformation.
- target_transform: Transformation to apply to the targets (labels). Used to preprocess the labels. Default is no transformation.
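- download: If `True`, downloads the dataset into `root` when it is not already there; if the files are already present, they are not downloaded again. Default is `False`, in which case the dataset must already exist in `root` or loading fails with an error.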
Loading Fashion MNIST dataset using PyTorch in Python
In the following code, we load the Fashion MNIST dataset using PyTorch and display a 4x4 grid of images with their labels.
To load the Fashion MNIST dataset, we follow these steps:
- Import the necessary libraries, including PyTorch, torchvision, and matplotlib.
- Define a transformation using transforms.ToTensor() to convert the images into PyTorch tensors.
- Load the FashionMNIST dataset using torchvision.datasets.FashionMNIST(). It downloads the dataset if it is not already present and applies the defined transformation.
- Create a 4x4 grid of subplots, set the title of each subplot to the corresponding label, and turn off the axes.
Python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define the transformation
transform = transforms.ToTensor()

# Load the dataset
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Create a subplot with 4x4 grid
fig, axs = plt.subplots(4, 4, figsize=(8, 8))

# Loop through each subplot and plot an image
for i in range(4):
    for j in range(4):
        image, label = train_dataset[i * 4 + j]  # Get image and label
        image_numpy = image.numpy().squeeze()    # Convert image tensor to numpy array
        axs[i, j].imshow(image_numpy, cmap='gray')  # Plot the image
        axs[i, j].axis('off')  # Turn off axis
        axs[i, j].set_title(f"Label: {label}")  # Set title with label

plt.tight_layout()  # Adjust layout
plt.show()  # Show plot
Output: a 4x4 grid showing the first 16 training images in grayscale, each titled with its numeric label.
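Once individual samples load correctly, a dataset is usually wrapped in a `torch.utils.data.DataLoader` so that it can be read in batches. The sketch below shows one way this might look; `make_loader` is a hypothetical helper, and a synthetic dataset with the same sample shape that `ToTensor()` produces (1x28x28) stands in for the real one so that the code runs without a download.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size: int = 64, shuffle: bool = True) -> DataLoader:
    """Wrap a dataset so that it yields shuffled mini-batches."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Synthetic stand-in: 100 float images of shape (1, 28, 28) with integer labels 0-9.
fake_dataset = TensorDataset(
    torch.rand(100, 1, 28, 28),
    torch.randint(0, 10, (100,)),
)

images, labels = next(iter(make_loader(fake_dataset, batch_size=32)))
print(images.shape, labels.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```

Passing a Fashion MNIST dataset created with `transform=transforms.ToTensor()` to `make_loader` in place of `fake_dataset` should yield batches with the same layout: batch size, channel, height, width.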