CIFAR-10 数据集,是为数不多的可以用笔记本跑的 深度学习 数据集,总共有6w张彩色图片,图片大小为32x32,分为10类,其中5w张训练集,1w张测试集,本文主要介绍在 Pytorch 中使用该数据集的方法。

关联文档

https://pytorch.org/vision/stable/datasets.html

https://www.cs.toronto.edu/~kriz/cifar.html

代码示例

1、下载数据集

from torchvision import datasets

# 下载数据集到datas目录
file_path = './datas'
# 文件存在后,就不需要重复下载了
cifar10 = datasets.CIFAR10(file_path, download=True)

2、使用 matplotlib 展示图片

# 读取并显示一个样本
from matplotlib import pyplot as plt

img, label = cifar10[0]
plt.imshow(img)
plt.show()

3、数据转为 Tensor 格式

from torch.utils.data import DataLoader

# 读取数据,并转化为Tensor,默认PIL.image
cifar10 = datasets.CIFAR10(file_path, transform=transforms.ToTensor())

# 加载数据,并打乱顺序
data = DataLoader(cifar10, batch_size=10, shuffle=True)

for img, label in data:
    print(img.size())  # 图像数据
    print(label.item())  # 图像类别

4、常用transform方法

train_trans = transforms.Compose([
    transforms.CenterCrop(),  # 中心裁剪
    transforms.Grayscale(),  # 转灰度图
    transforms.Resize((32, 32)),  # 缩放
    transforms.RandomCrop(32, padding=4),  # 随机裁剪
    transforms.ToTensor(),  # 图片转张量,同时归一化0-255,范围0-1
])

data = datasets.CIFAR10(file_path, transform=train_trans)

本文为 陈华 原创,欢迎转载,但请注明出处:http://edu.ichenhua.cn/read/235