点赞收藏关注!
如需要转载,请注明出处!
torch的模型加载有两种方式:
Datasets & DataLoaders
torch本身可以提供两数据加载函数:
torch.utils.data.DataLoader()和torch.utils.data.Dataset()
其中torch.utils.data 是PyTorch提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。
加载函数后可以实现数据集代码与模型训练代码分离,以获得更好的可读性和模块化
Dataset定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。制作自己的数据集必须要实现三个函数:
- init()函数在实例化Dataset对象时运行一次
- len()函数返回数据集中样本的数量
- getitem()函数的作用是:从给定索引 index,从数据集中加载并返回一个样本并将其转换为张量。
import torch
from torch.utils.data import Dataset
class CreateDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
# 根据索引获取样本
return self.data[index]
def __len__(self):
# 返回数据集大小
return len(self.data)
# 创建数据集对象
data = [[255,255,255],[255,245,235],[225,226,227]]
dataset = CreateDataset(data)
# 根据索引获取样本
sample = dataset[1]
print(sample)
# [255,245,235]
数据处理模块其他的功能:
- TensorDataset: 继承自 Da