Torch load载入数据集的标准流程

本文详细介绍了如何使用PyTorch的DataLoader对图像和文本数据进行预处理,包括创建myDataset类,设置数据加载规则,以及配置batch_size和shuffle。通过实例演示了如何从原始文件加载数据并进行批处理操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

不管是图像识别还是自然语言处理任务,在训练时都需要将数据按照batch的方式载入。本文记录了用Torch load数据的流程

  •  第一步当然是先有数据的原始文件,图像处理就是一堆图片,NLP就是很多句子。这个原始文件保存成任何数据格式都没关系,因为之后会先处理成Torch.utils.data.Dataset的一个子类。
  •  转化数据为Torch.utils.data.Dataset,代码如下:
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler

class myDataset(Dataset):  # 这是一个Dataset子类
    def __init__(self):
        self.Data = np.asarray([[2, 2], [1, 4], [4, 1], [12, 4], [8, 5]])  
        self.Label = np.asarray([4, 2, 2, 4, 2]) 
    def __getitem__(self, index):
        x = torch.from_numpy(self.Data[index])
        y = (self.Label[index])
        return x, y  
    def __len__(self):
        return len(self.Data)

mydata = myDataset()

这个时候可以用下标或者len访问数据集的某个元素或者长度

  • 定义一个sampler,也就是数据被返回的规则。这个不一定需要,如果定义这个,在下一步的shuffle参数需要改为false

rand_sampler = RandomSampler(mydata)
  •  定义Dataloader 
data_loader = DataLoader(mydata,batch_size=2,shuffle=False,sampler=rand_sampler,num_workers=4)
  • 然后就可以用循环访问出一个batch一个batch的数据了。
for i,traindata in enumerate(data_loader):
    print('i:',i)
    Data,Label=traindata
    print('data:',Data)
    print('Label:',Label)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值