Splitting a dataset is an important step in training machine learning models. It separates the data into distinct sets, typically training and validation, so we can train our model on one set and evaluate its performance on another.
In this article, we are going to discuss the process of splitting a dataset using PyTorch, a popular framework for deep learning.
Introduction to Dataset Splitting
When we are working with a dataset, it is important not to use all of it just for training the model.
We should split it into different parts:
- Training Set: This is the portion of the dataset used to train our model.
- Validation Set: This set is used to evaluate our model's performance and adjust it accordingly.
- Test Set: Sometimes, a third set is used to test the model after training and validation are complete.
Splitting the data helps to avoid overfitting, a common problem where the model learns the training data too closely and performs well on it but fails to generalize to unseen (new) data.
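As a quick illustration before we dive in, here is a minimal sketch of a three-way split using PyTorch's random_split; the 70/15/15 sizes and the toy tensors are arbitrary choices for this example.

import torch
from torch.utils.data import TensorDataset, random_split

# A toy dataset of 100 samples with 4 features each (sizes are arbitrary)
data = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

# Split into 70 training, 15 validation, and 15 test samples
train_set, val_set, test_set = random_split(data, [70, 15, 15])
print(len(train_set), len(val_set), len(test_set))  # 70 15 15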
Splitting Datasets in PyTorch: A Step-by-Step Guide with Random Split
PyTorch provides a simple function called random_split to help us split our dataset. This function divides the data into non-overlapping subsets based on the lengths we specify.
Step 1: Import Required Libraries
First, we need to import the necessary libraries for our task. We’ll use PyTorch along with the make_blobs utility from the sklearn library to generate sample data.
from sklearn import datasets
import torch
from torch.utils.data import Dataset, random_split
Step 2: Generate Sample Data
Next, we will create some sample data that we can work with. We will use make_blobs from sklearn to generate a simple dataset.
# Define the number of samples
total_samples = 1800
# Generate sample data with 2 features and 2 centers
X_data, Y_data = datasets.make_blobs(n_samples=total_samples, n_features=2, centers=[(-2, 5), (3, -4)], random_state=42)
Here, we are generating a dataset with 1800 samples split across 2 centers. Since each center is a 2-dimensional point, every sample has 2 features (make_blobs infers the feature count from the centers). This gives us some synthetic data to work with.
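As a quick sanity check (not part of the original steps), we can inspect the shapes of the generated arrays:

print(X_data.shape)  # (1800, 2) -- 1800 samples, 2 features each
print(Y_data.shape)  # (1800,)   -- one cluster label per sample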
Step 3: Create a Custom Dataset Class
In PyTorch, it’s common to create a custom Dataset class to handle our data. This class will allow us to manage how data is loaded.
class CustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        # Return a dictionary with 'features' and 'label' as keys
        sample = {
            'features': torch.tensor(self.x[index], dtype=torch.float32),
            'label': torch.tensor(self.y[index], dtype=torch.long)
        }
        return sample

    def __len__(self):
        # Return the total number of samples
        return len(self.x)
This class takes the input data (x) and labels (y) and returns each sample as a dictionary when indexed. The __len__ method returns the number of samples in the dataset.
Step 4: Create the Dataset Instance
Now, we will create an instance of our custom dataset and check its length.
# Create the dataset instance
dataset = CustomDataset(X_data, Y_data)
# Print the length of the dataset
print("Total number of samples in the dataset:", len(dataset))
This will print out the total number of samples in our dataset, which should be 1800.
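We can also index the dataset directly to see what a single sample looks like; this is just an illustrative check, not a required step:

first_sample = dataset[0]
print(first_sample['features'])  # a float32 tensor with 2 values
print(first_sample['label'])     # a scalar long tensor (0 or 1)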
Step 5: Split the Dataset
Finally, we can split the dataset into training and validation sets using random_split.
# Split the dataset into training (1200 samples) and validation (600 samples)
train_data, val_data = random_split(dataset, [1200, 600])
# Print the lengths of the train and validation sets
print("Number of training samples:", len(train_data))
print("Number of validation samples:", len(val_data))
Here, we’re splitting the dataset so that 1200 samples go to the training set and 600 to the validation set.
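Note that random_split shuffles indices randomly, so the split will differ between runs. If you need a reproducible split, random_split accepts an optional seeded generator, as sketched below (the seed value 42 is arbitrary):

# Optional: seed the split so it is reproducible across runs
generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(dataset, [1200, 600], generator=generator)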
Example Code:
Below is example code for splitting a dataset using PyTorch:
from sklearn import datasets
import torch
from torch.utils.data import Dataset, random_split
# Generate Sample Data
total_samples = 1800
X_data, Y_data = datasets.make_blobs(n_samples=total_samples, n_features=2, centers=[(-2, 5), (3, -4)], random_state=42)
# Create a Custom Dataset Class
class CustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        sample = {
            'features': torch.tensor(self.x[index], dtype=torch.float32),
            'label': torch.tensor(self.y[index], dtype=torch.long)
        }
        return sample

    def __len__(self):
        return len(self.x)
# Create the Dataset Instance
dataset = CustomDataset(X_data, Y_data)
print("Total number of samples in the dataset:", len(dataset))
# Split the Dataset
train_data, val_data = random_split(dataset, [1200, 600])
print("Number of training samples:", len(train_data))
print("Number of validation samples:", len(val_data))
Output:
Total number of samples in the dataset: 1800
Number of training samples: 1200
Number of validation samples: 600
How to Split CIFAR-10 Dataset for Training and Validation in PyTorch?
Splitting a dataset into training and validation sets is a crucial step in machine learning to ensure that a model is trained on one subset of data and evaluated on another, unseen subset. Now, we’ll walk through how to split the CIFAR-10 dataset using PyTorch.
Step 1: Import Required Libraries
First, we need to import the necessary libraries for data manipulation and model training. PyTorch provides tools for handling datasets and transformations, which we'll use in this example.
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
Step 2: Define Data Transformations
Data transformations are essential to preprocess the CIFAR-10 images before feeding them into the model. We will convert the images to tensors and normalize them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
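Here, Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps each channel from [0, 1] to roughly [-1, 1] via (x - mean) / std. A common refinement is to normalize with per-channel statistics computed from the CIFAR-10 training set; the values in the optional sketch below are widely quoted approximations, so treat them as such:

# Optional alternative: approximate CIFAR-10 per-channel statistics
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),  # approximate channel means
                         (0.2470, 0.2435, 0.2616))  # approximate channel stds
])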
Step 3: Load the CIFAR-10 Dataset
Download and load the CIFAR-10 dataset. The dataset will be transformed according to the transformations defined earlier.
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
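To confirm the transformations behave as expected, we can look at one transformed sample; this check is illustrative and not part of the split itself:

image, label = dataset[0]
print(image.shape)                             # torch.Size([3, 32, 32])
print(image.min().item(), image.max().item())  # values roughly in [-1, 1]
print(label)                                   # an integer class id (0-9)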
Step 4: Define Split Ratios
Specify the proportions for training and validation splits. In this example, we'll use an 80-20 split.
train_ratio = 0.8
validation_ratio = 0.2
Step 5: Calculate Sizes for Each Split
Calculate the number of samples for each split based on the specified ratios.
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size
Step 6: Perform the Split
Use the random_split function to divide the dataset into training and validation sets.
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])
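As an aside, recent PyTorch releases (1.13 and later) also allow random_split to take fractions that sum to 1, which avoids computing the sizes by hand; on older versions, stick with integer lengths:

# Equivalent split using fractional lengths (requires PyTorch >= 1.13)
train_dataset, validation_dataset = random_split(dataset, [0.8, 0.2])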
Step 7: Create DataLoaders
DataLoaders are used to load the data in batches, which is useful for training and evaluating models.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)
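To see the loaders in action, we can pull a single batch; the shapes below assume the batch size of 64 set above:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 3, 32, 32])
print(labels.shape)  # torch.Size([64])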
Step 8: Verify the Splits
Finally, print out the sizes of the training and validation datasets to verify the splits.
print(f'Total dataset size: {dataset_size}')
print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(validation_dataset)}')
Complete Code
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Step 2: Define data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Step 3: Load the CIFAR-10 dataset
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Step 4: Define the split ratios
train_ratio = 0.8
validation_ratio = 0.2

# Step 5: Calculate the sizes for each split
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size

# Step 6: Perform the split
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])

# Step 7: Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)

# Step 8: Verify the splits
print(f'Total dataset size: {dataset_size}')
print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(validation_dataset)}')
Output:
Total dataset size: 50000
Training dataset size: 40000
Validation dataset size: 10000
Conclusion
Splitting a dataset is a fundamental step in machine learning, and PyTorch's built-in random_split function makes it straightforward. By following the steps above, we can ensure that our model is trained and validated effectively, leading to better generalization and performance on new data.