
How to load Fashion MNIST dataset using PyTorch?

Last Updated : 14 May, 2024

In machine learning, benchmark datasets are essential because they provide a common basis for comparing and assessing the performance of different algorithms. Fashion MNIST is one such dataset: it is a drop-in replacement for the standard MNIST dataset of handwritten digits, with the same image size and split sizes but a more difficult classification task. This article explores the Fashion MNIST dataset, including its characteristics and how to load it using PyTorch.


What is Fashion MNIST?

Fashion-MNIST is a dataset developed by Zalando Research as a modern alternative to the original MNIST dataset. It comprises 70,000 grayscale images, each belonging to one of 10 categories of fashion items. Each image is 28x28 pixels, providing a uniform format for machine learning model input. The dataset is divided into a training set of 60,000 images and a test set of 10,000 images.

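The ten categories in Fashion MNIST are listed below in label order. The integer labels stored in the dataset run from 0 (T-shirt/top) to 9 (Ankle boot), so each label is one less than the list position: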

  1. T-shirt/top
  2. Trouser
  3. Pullover
  4. Dress
  5. Coat
  6. Sandal
  7. Shirt
  8. Sneaker
  9. Bag
  10. Ankle boot
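
Because the dataset stores each category as an integer from 0 to 9 rather than as a name, a small lookup helps when printing or plotting samples. The sketch below is a minimal illustration; the names `FASHION_MNIST_CLASSES` and `label_to_name` are hypothetical helpers introduced here, not part of torchvision.

```python
# Category names in label order (index 0 through 9).
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_to_name(label: int) -> str:
    """Return the category name for an integer label in the range 0-9.

    Hypothetical helper for illustration only.
    """
    if not 0 <= label < len(FASHION_MNIST_CLASSES):
        raise ValueError(f"label must be in 0-9, got {label}")
    return FASHION_MNIST_CLASSES[label]


print(label_to_name(0))
print(label_to_name(9))
```

A loaded `torchvision.datasets.FashionMNIST` object also exposes the same names through its `classes` attribute, which avoids hard-coding the list.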

Characteristics of Fashion MNIST Dataset

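Here are the key characteristics of the Fashion-MNIST dataset:

  • Images are stored as 8-bit grayscale, with integer pixel values ranging from 0 to 255. They are resized to 28x28 pixels with the item centered, but the pixel values are not rescaled; scaling to the range [0, 1] is left to the transform applied at load time (for example, transforms.ToTensor()).
  • Fashion-MNIST is harder to classify than MNIST mainly because several categories look alike (for example, Shirt, T-shirt/top, Pullover, and Coat). The images themselves are centered product shots on a plain background, without background clutter.
  • The dataset is class-balanced: each category has 6,000 training images and 1,000 test images.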

Load Fashion MNIST dataset in PyTorch

The 'torchvision.datasets.FashionMNIST' class is used to load the Fashion MNIST dataset in PyTorch. Its constructor has the following signature:

torchvision.datasets.FashionMNIST(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

Breakdown of Parameters:

  1. root: Location where the dataset is stored or will be stored. Can be a directory path or a Path object.
  2. train: Indicates whether to load the training set (`True`) or the test set (`False`). Default is training set.
  3. transform: Transformation to apply to the samples (images). Can be a single transformation or a composition of multiple transformations. Default is no transformation.
  4. target_transform: Transformation to apply to the targets (labels). Used to preprocess the labels. Default is no transformation.
  5. download: Specifies whether to download the dataset into `root`. If `True`, the files are downloaded only when they are not already present locally. If `False` (the default), nothing is downloaded, and loading raises an error when the files are missing.

Loading Fashion MNIST dataset using PyTorch in Python

The following code loads the Fashion MNIST dataset using PyTorch and displays a 4x4 grid of images with their labels.

To load and display the dataset, the code follows these steps:

  1. Import the necessary libraries: PyTorch, torchvision, and matplotlib.
  2. Define a transformation using transforms.ToTensor() to convert the images into PyTorch tensors.
  3. Load the training split using torchvision.datasets.FashionMNIST(). With download=True, the dataset is downloaded if it is not already present, and the defined transformation is applied to each image.
  4. Create a 4x4 grid of subplots, plot one image in each subplot, set each subplot's title to the corresponding integer label, and turn off the axes.
Python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define the transformation
transform = transforms.ToTensor()

# Load the dataset
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Create a subplot with 4x4 grid
fig, axs = plt.subplots(4, 4, figsize=(8, 8))

# Loop through each subplot and plot an image
for i in range(4):
    for j in range(4):
        image, label = train_dataset[i * 4 + j]  # Get image and label
        image_numpy = image.numpy().squeeze()    # Convert image tensor to numpy array
        axs[i, j].imshow(image_numpy, cmap='gray')  # Plot the image
        axs[i, j].axis('off')  # Turn off axis
        axs[i, j].set_title(f"Label: {label}")  # Set title with label

plt.tight_layout()  # Adjust layout
plt.show()  # Show plot

Output:

[Image: 4x4 grid of the first 16 Fashion MNIST training images, each titled with its integer label]
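
Once the dataset is loaded, it is usually wrapped in a `torch.utils.data.DataLoader` to iterate over it in batches. The sketch below is a minimal example under stated assumptions: it uses a randomly generated stand-in dataset with Fashion MNIST shapes so that it runs without a download, and the batch size of 64 is an arbitrary choice.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset with Fashion MNIST shapes: 256 images of 1x28x28
# with values in [0, 1), and integer labels in 0..9. Random data, for illustration only.
images = torch.rand(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
stand_in_dataset = TensorDataset(images, labels)

# Batch and shuffle the samples; batch_size=64 is an arbitrary example value.
loader = DataLoader(stand_in_dataset, batch_size=64, shuffle=True)

# Fetch one batch and inspect its shapes.
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)
print(batch_labels.shape)
```

Passing the `FashionMNIST` dataset object loaded with `transforms.ToTensor()` in place of `stand_in_dataset` should yield batches with the same image and label shapes.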


