pytorch中TensorDataset与DataLoader的使用

文章介绍了在PyTorch中处理数据的基本组件,包括Dataset的抽象类,用于定义数据集的大小和获取单个样本,TensorDataset用于包装数据和目标张量,以及DataLoader用于批量加载和预处理数据。示例展示了如何使用transform对图像进行旋转、翻转等操作,并创建自定义的ImgDataset类。DataLoader还支持多进程加载和数据打乱。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Datatset\TensorDataset\DataLoader

class torch.utils.data.Dataset

表示Dataset的抽象类。所有其他数据集都应该进行子类化。所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。dataset是用来做打包和预处理(比如输入资料路径自动读取)。

当对图片进行处理时,如通过定义一个transforms来随机旋转训练图片,将图片格式变成tensor。

import numpy as np
import torch
from torch.utils.data import TensorDataset, Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

train_data = datasets.MNIST(root="./data/mnist", train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root="./data/mnist", train=False, transform=transforms.ToTensor(), download=True)

batch_size = 128
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(), ]
)
test_transform = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.ToTensor(), ]
)  # 测试集不需要翻转或旋转图片

# 继承Dataset
class ImgDataset(Dataset):
    def __init__(self, x, y=None, transform=None):
        self.x = x
        self.y = y
        
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值