在训练 深度学习 模型之前,样本集的制作是非常重要的环节。在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,按照对应格式自定义数据集,才可以使用DataLoader加载数据,下面是自定义样本集的整个流程。
“三步走”的策略
Pytorch输入数据PipeLine一般遵循“三步走”的策略,一般pytorch 的数据加载到模型的操作顺序是这样的:
① 创建一个 Dataset 对象。必须实现__len__()、__getitem__()这两个方法,这里面会用到transform对数据集进行扩充。
② 创建一个 DataLoader 对象。它是对DataSet对象进行迭代的,一般不需要事先里面的其他方法了。
③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练。
代码示例
1、基本格式
from torch.utils import data class MyDataset(data.Dataset): # 需要继承data.Dataset def __init__(self): pass def __len__(self): # 获取数据集的大小 pass def __getitem__(self, index): # 通过索引获取数据 # 如果是图片,可以返回PIL.image、numpy数组或者Tensor # 如果是PIL.image,需要使用transform转化成Tensor pass
2、初始化和获取数据集长度
def __init__(self, root): self.train_data = [] train_file = os.path.join(root, 'train/data.txt') with open(train_file) as f: for line in f: path, label = line.strip().split(' ') file_path = os.path.join(root, 'train', path) self.train_data.append((file_path, label)) def __len__(self): # 获取数据集的大小 return len(self.train_data)
创建文件 data.txt,并编辑文件路径和label,例如 cat/1.png 0,之后导入图片。
3、迭代读取文件def __getitem__(self, index): from PIL import Image file_path, label = self.train_data[index] # 读取图片,并转化为np数组 img = Image.open(file_path) return np.array(img), label
4、添加transform参数
dataset = MyDataset(file_path, transform=transforms.ToTensor()) loader = data.DataLoader(dataset, batch_size=10) for l in loader: print(l) # class MyDataset(): def __init__(self, root, transform): self.transform = transform def __getitem__(self, index): ...... if self.transform: img = self.transform(img) return img, label
参考文档
https://www.jianshu.com/p/2d9927a70594
https://pytorch.org/docs/stable/data.html
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
本文为 陈华 原创,欢迎转载,但请注明出处:http://edu.ichenhua.cn/read/236