PyTorch教程:21个项目玩转PyTorch实战
上QQ阅读APP看书,第一时间看更新

2.1.1 CIFAR-10数据集简介

CIFAR-10数据集共有60 000张32×32的RGB彩色图片,分为10个类别,每个类别有6 000张图片。其中训练集图片为50 000张,测试集有10 000张图片。训练集和测试集的生成方法是,分别从每个类别中随机挑选1 000张图片加入测试集,其余图片便加入训练集。与MNIST手写字符数据集比较来看,CIFAR-10数据集是彩色图片,图片内容是真实世界的物体,噪声更大,物体的比例也不一样,所以在识别上比MNIST困难很多。CIFAR-10数据集样例如图2-1所示。

图2-1 CIFAR-10数据集样例

训练分类器的步骤如下:

(1)使用视觉工具包torchvision加载并且归一化CIFAR-10的训练和测试数据集;

(2)定义一个卷积神经网络;

(3)定义一个损失函数;

(4)在训练样本数据上训练神经网络;

(5)在测试样本数据上测试神经网络。