计算机系统应用教程网站

网站首页 > 技术文章 正文

第5章 torchvision库与加载内置图片数据集

btikc 2024-10-01 08:27:24 技术文章 11 ℃ 0 评论

4.1 torchvision库

torchvision库是PyTorch框架中用来处理图像和视频的一个辅助库,属于PyTorch项目的一部分。在本书第1章PyTorch安装过程中,我们已一并安装了此库。PyTorch通过torchvision库提供了一些常用的数据集、模型、转换函数等。torchvision库提供的内置数据集可用于测试、学习和创建基准模型,本章我们使用torchvision加载内置的数据集进行分类模型的演示。

为了统一数据加载和处理代码,PyTorch提供了两个类用以处理数据加载,它们分别是torch.utils.data.Dataset类和torch.utils.data.DataLoader类,通过这两个类可使我们的数据集加载和预处理代码与模型训练代码脱钩,从而获得更好的代码模块化和代码可读性。关于如何创建自定义的Dataset类将在下一章介绍,本章我们需要知道,torchvision加载的内置图片数据集均继承自torch.utils.data.Dataset类,因此我们可直接使用加载的内置数据集创建DataLoader。

4.2 加载内置图片数据集

PyTorch内置图片数据集均在torchvision.datasets模块下,包含Caltech、CelebA、CIFAR、Cityscapes、COCO、Fashion-MNIST、ImageNet、MNIST等很多著名的数据集。其中MNIST数据集是手写数字数据集,这是一个很适合入门学习使用的小型计算机视觉数据集,它包含0~9的手写数字图片,也包含每一张图片对应的标签。我们以此数据集为例,来学习如何加载使用内置图片数据集。加载内置图片数据集的代码如下。

import torchvision # 导入 torchvision 库

from torchvision.transforms import ToTensor


train_ds = torchvision.datasets.MNIST('data/',

train=True,

transform=ToTensor(),

download=True

)


test_ds = torchvision.datasets.MNIST('data/',

train=False,

transform=ToTensor(),

download=True

)


上面代码中我们首先导入了torchvision库,并从torchvision.transforms模块下导入ToTensor这个类。torchvision.transforms模块包含了转换函数,使用它可以很方便地对加载的图像做各种变换,具体的转换将在迁移学习一章中做详细的演示。在这里我们用到了ToTensor这个类,该类的主要作用有以下三点。

(1)将输入转换为张量。

(2)将读取图片的格式规范为 (channel, height,width),这与读者以前经常遇到的图片格式可能有些区别,PyTorch中的图片格式一般是通道数(channel)在前,然后是高度(height)和宽度(width)。

(3)将图片像素的取值范围归一化,规范到0~1。

上述加载代码中,我们通过torchvision.datasets.MNIST加载MNIST数据集,这个方法的第一个参数 data/ 代表下载数据集存放的位置,这里我们放在了当前程序目录下的data文件夹中,train参数表示是否是训练数据,若为True,则加载训练集,若为False,表示加载测试数据集;使用transform参数表示对加载数据的预处理,参数值为ToTensor(),最后的参数download=True表示我们将下载此数据集,这样一旦下载完成后,下一次执行此代码时,将优先从本地文件夹直接加载。如果读者的计算机不能连接互联网,也可以直接将文件复制到下载文件夹data中,这样就能从本地直接加载数据了。

现在我们得到了两个数据集(dataset),分别是训练数据集和测试数据集,PyTorch还为我们提供了 torch.utils.data.DataLoader类用以对dataset做进一步的处理,DataLoader接收dataset,并执行复杂的操作,如小批次处理、多线程、随机打乱等,以便从数据集中获取数据。它接收来自用户的dataset实例,并使用采样器策略将数据采样为小批次。DataLoader主要用来完成四个目的。

(1)使用shuffle参数对数据做乱序。一般我们需要对训练数据集进行乱序的操作。因为原始的数据在样本均衡的情况下可能是按照某种顺序进行排列的,如前半部分为某一类别的数据,后半部分为另一类别的数据。但经过打乱顺序之后,数据的排列就会拥有一定的随机性,在顺序读取的情况下,读取一次得到的样本为任何一种类型数据的可能性相同。这样可避免模型反复依次序学习数据的特征或者学习到的只是数据的次序特征。

(2)将数据采样为小批次,可用batch_size参数指定批次大小。我们在线性回归一节中,同时对输入和标签迭代送入模型进行训练,这样单个样本训练有一个很大的缺点,就是损失和梯度会受到单个样本的影响,如果样本分布不均匀,或者有错误标注样本,则会引起梯度的巨大震荡,从而导致模型训练效果很差。为了解决此问题,可考虑使用批量数据训练,也叫作批量梯度下降算法,通过遍历全部数据集算一次损失函数,然后计算损失对各个参数的梯度,更新参数。这种训练方式每更新一次参数都要把数据集里的所有样本都看一遍,不仅计算开销大,而且计算速度慢。为了克服上述两种方法的缺点,一般采用的是一种折中手段进行损失函数计算:即把数据分为若干个小的批次(batch),按批次来更新参数,这样在一个批次中的一组数据共同决定了本次梯度的方向,大大降低了参数更新时的梯度方差,下降起来更加稳定,减少了随机性。与单样本训练相比,小批次训练可利用矩阵操作进行有效的梯度计算,计算量也不是很大,对计算机内存的要求也不高。

(3)可以充分利用多个子进程,加速数据预处理。num_workers参数可以指定子进程数量。

(4)可通过 collate_fn 参数传递批次数据的处理函数,实现在DataLoader中对批次数据做转换处理。关于这个转换,我们将在文本分类入门实例中演示。

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

test_dl = torch.utils.data.DataLoader(test_ds, batch_size=46)


上述代码中我们分别创建了train数据和test数据的DataLoader,并设置它们的批次大小为64,对于train数据设置了shuffle为True,对于test数据,由于仅仅作为测试,没必要做shuffle。

Dataloader是可迭代对象,我们看一下它返回的数据的形状,方便大家对Dataloader和MNIST数据集有一个直观的印象,代码如下。

imgs, labels = next(iter(train_dl)) # 创建生成器,并用next返回一个批次的数据

print(imgs.shape) # 输出 torch.Size([64, 1, 28, 28])

print(labels.shape) # 输出 torch.Size([64])

代码中我们使用iter方法将DataLoader对象创建为生成器,并使用next方法返回了一个批次的图像(imgs)和对应的一个批次的标签(labels),imgs.shape为 torch.Size([64, 1, 28, 28]), 如何理解这个shape?显然这里的64是批次, 我们可以认为这代表 64 张shape为(1, 28, 28)的图片,其中 1 为channel, 28和28分别表示高和宽;既然这里有64张图片,对应的也应该有64个标签,也就是labels.shape所显示的 torch.Size([64]) 。

下面通过绘图来看一下MNIST数据集中的这些图片是什么样子的。使用Matplotlib库绘图,绘制imgs中的前10张图片,对于Matplotlib绘图不熟悉的同学可参考学习日月光华在网易云课堂的数据绘图课程或者数据分析课程进行学习。

plt.figure(figsize=(10, 1)) # 创建画布

for i, img in enumerate(imgs[:10]):

npimg = img.numpy() # 将tensor转为ndarray

npimg = np.squeeze(npimg) # 图片shape由(1, 28, 28) 转为 (28, 28)

plt.subplot(1, 10, i+1) # 初始化子图,三个参数表示1行10列的第i+1个子图

plt.imshow(npimg) # 在子图中绘制单张图片

plt.axis('off') # 关闭显示子图坐标

绘图结果如图4-1所示。

接下来可以打印对应的标签,观察图片与标签是否是对应的,代码如下。

print(labels[:10]) # 输出 tensor([5, 0, 0, 4, 0, 1, 3, 0, 8, 0])


很明显,图片与标签是对应的,这便是我们本章要使用的MNIST内置手写数字数据集。

作者:日月光华 咨询微信: guanghua2025

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表