网站首页 > 技术文章 正文
前言
昨天我们学习了一些必备知识:PyTorch入门与实战——必备基础知识(下)01 PyTorch入门与实战——必备基础知识(上)01 。今天起我们进入模型的学习。如果将模型看作一辆汽车,那么它的开发过程就可以 看作是一套完整的生产流程,环环相扣、缺一不可。这些环节包括数据的读取、网络的设计、优化方法与损失函数的选择以及一些辅助的工具等。未来你将尝试构建自己的网络模型。
处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化。PyTorch 提供了两个数据库:torch.utils.data. 和torch.utils.data.Dataset 您使用预加载的数据集以及您自己的数据。 Dataset存储样本及其对应的标签,并DataLoader包裹一个可迭代对象Dataset,以便轻松访问样本。
PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集子类torch.utils.data.Dataset化并实现了特定于特定数据的功能。它们可用于对您的模型进行原型设计和基准测试。你可以在这里找到它们:图像数据集、 文本数据集和 音频数据集。
目录
- MNIST数据集
- 数据读入
- 迭代和可视化数据集
- 自定义数据集类
- 使用 DataLoaders 为训练准备数据
- 小结
MNIST数据集
下面是如何从 TorchVision 加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando的文章中图像的数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含 28×28 灰度图像和来自 10 个类别之一的相关标签。
我们使用以下参数加载FashionMNIST 数据集:
- root是存储训练/测试数据的路径,
- train指定训练或测试数据集,
- download=True如果数据不可用,则从 Internet 下载数据root。
- transform和target_transform分别指定特征和标签转换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
数据被处理后的形式并不总是适合训练。所以,我们使用transforms对数据执行一些操作,使其适合于训练。所有TorchVision数据集都有两个参数,其中transform用于修改特征图,target_transform用于修改标签。torchvision.transforms模块提供了几种常用的转换,如下文的ToTensor()、Lambda。
- ToTensor将PIL图像或NumPy ndarray转换为浮点张量(FloatTensor)。并图像的像素值在[限制在[0,1]范围内。
- Lambda转换应用任何用户定义的lambda函数。例如:定义一个函数来将整数转换为一个独热编码张量。首先创建一个大小为class_num的零张量(数据集中标签的数量),并调用scatter_,它在标签y给定的索引上指定值为1。
target_transform = Lambda(lambda y: torch.zeros(
class_num, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
数据读入
训练开始的第一步,首先就是数据读取。PyTorch数据读入是通过Dataset+DataLoader的方式来得到数据迭代器。Dataset定义好数据的格式和数据变换形式,DataLoader用iterative的方式不断读入批次数据。下面我们分别来看下 Dataset 类与 DataLoader 类。
迭代和可视化数据集
我们可以Datasets像列表一样手动索引:training_data[index]. 我们matplotlib用来可视化训练数据中的一些样本。
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
自定义数据集类
自定义 Dataset 类必须实现三个函数:
我们可以定义自己的Dataset类来实现灵活的数据读取,定义的类需要继承PyTorch自身的Dataset类。主要包含三个函数:
- __init__: 用于向类中传入外部参数,同时定义样本集
- __getitem__: 索引数据集中的某一个数据,可以进行一定的变换,并将返回训练/验证所需的数据
- __len__: 用于返回数据集的样本数
例如:FashionMNIST 图像存储在一个目录img_dir中,它们的标签分别存储在一个 CSV 文件annotations_file中。在接下来的部分中,我们将分解每个函数中发生的事情。
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
'''返回数据集中的样本数'''
return len(self.img_labels)
def __getitem__(self, idx):
'''获取数据的方法,会和Dataloader连用''
# 获取图片路径,0表示CSV文件的第一列
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
# 读取图片
image = read_image(img_path)
# 获取图片对应的标签
label = self.img_labels.iloc[idx, 1]
# 如果使用时附加了transform参数,则对图片应用转换。本程序中transform=None表示没有附加参数
if self.transform:
image = self.transform(image)
# 同理,如果使用时附加了target_transform参数,则标签应用转换。本程序中target_transform=None表示没有附加参数
if self.target_transform:
label = self.target_transform(label)
# 返回图片和标签
return image, label
结合代码可以看到,我们定义了一个名字为 CustomImageDataset的数据集,在构造函数中,传入 Tensor 类型的数据与标签;在 __len__ 函数中,直接返回 Tensor 的大小;在 __getitem__ 函数中返回索引的数据与标签。
1、_init_
_ init _ 函数在实例化 Dataset 对象时运行一次。我们初始化包含图像、注释文件和两种转换的目录。
labels.csv 文件如下所示:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
2、_len_
_ len _ 函数返回我们数据集中的样本数。
def __len__(self):
return len(self.img_labels)
3、_getitem_
__ getitem __ 函数从给定索引处的数据集中加载并返回一个样本idx。基于索引,它识别图像在磁盘上的位置,使用 将其转换为张量read_image,从 csv 数据中检索相应的标签self.img_labels,调用它们的变换函数(如果适用),并返回张量图像和相应的标签一个元组。
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
使用 DataLoaders 为训练准备数据
检索我们数据集的Dataset特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”的形式传递样本,在每个回合(epoch) 重新打乱以减少模型过拟合,这需要DataLoader。
DataLoader通过一个简单的API为我们抽象了这种复杂的功能,且是可迭代的。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
我们已将该数据集加载到 DataLoader中,并且可以根据需要遍历数据集。下面的每次迭代都会返回一批train_features(训练特征)和train_labels(训练标签)(分别包含batch_size=64特征和标签)。因为我们指定shuffle=True了 ,所以在我们遍历所有批次之后,数据被打乱(目的:为了更细粒度地控制数据加载顺序)。
用next()对DataLoader进行迭代:
# 显示一张图片和对应标签
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
小结
今天我们主要讲解了数据下载和数据处理,对数据处理有初步认识。主要介绍了如何下载MNIST数据集,如何把MNIST数据集划分为训练集和测试集。自定义数据集Dataset类三个主要函数__init__、 __getitem__、__len__。DataLoaders打乱数据、批次的划分。
思考题
在实际场景,往往需要自定义数据集类,请采用Dataset做一个简单的自定义小型数据集,并通过DataLoaders来加载。
欢迎在评论区交流与讨论
猜你喜欢
- 2024-10-14 程序员用PyTorch实现第一个神经网络前做好这9个准备,事半功倍
- 2024-10-14 PyTorch 分布式训练简明教程 pytorch分批训练
- 2024-10-14 PyTorch入门与实战——必备基础知识(下)01
- 2024-10-14 深度学习pytorch深度学习入门与简明实战教程2022年
- 2024-10-14 深度学习框架PyTorch-trick 集锦 深度学习框架pytorch:入门与实践 第2版
- 2024-10-14 AI | 图神经网络-Pytorch Biggraph简介及官方文档解读
- 2024-10-14 改动一行代码,PyTorch训练三倍提速,这些「高级技术」是关键
- 2024-10-14 利用pytorch CNN手写字母识别神经网络模型识别手写字母
- 2024-10-14 加快Python算法的四个方法(一)PyTorch
- 2024-10-14 PyTorch的Dataset 和TorchData API的比较
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)