Datatset\TensorDataset\DataLoader
class torch.utils.data.Dataset
表示Dataset的抽象类。所有其他数据集都应该进行子类化。所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。dataset是用来做打包和预处理(比如输入资料路径自动读取)。
当对图片进行处理时,如通过定义一个transforms来随机旋转训练图片,将图片格式变成tensor。
import numpy as np
import torch
from torch.utils.data import TensorDataset, Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
train_data = datasets.MNIST(root="./data/mnist", train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root="./data/mnist", train=False, transform=transforms.ToTensor(), download=True)
batch_size = 128
train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(), ]
)
test_transform = transforms.Compose(
[transforms.ToPILImage(),
transforms.ToTensor(), ]
) # 测试集不需要翻转或旋转图片
# 继承Dataset
class ImgDataset(Dataset):
def __init__(self, x, y=None, transform=None):
self.x = x
self.y = y