计算机系统应用教程网站

网站首页 > 技术文章 正文

PyTorch入门与实战——数据处理与数据加载02

btikc 2024-10-14 08:49:44 技术文章 4 ℃ 0 评论

前言

昨天我们学习了一些必备知识:PyTorch入门与实战——必备基础知识(下)01 PyTorch入门与实战——必备基础知识(上)01 。今天起我们进入模型的学习。如果将模型看作一辆汽车,那么它的开发过程就可以 看作是一套完整的生产流程,环环相扣、缺一不可。这些环节包括数据的读取、网络的设计、优化方法与损失函数的选择以及一些辅助的工具等。未来你将尝试构建自己的网络模型。

处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化。PyTorch 提供了两个数据库:torch.utils.data. 和torch.utils.data.Dataset 您使用预加载的数据集以及您自己的数据。 Dataset存储样本及其对应的标签,并DataLoader包裹一个可迭代对象Dataset,以便轻松访问样本。

PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集子类torch.utils.data.Dataset化并实现了特定于特定数据的功能。它们可用于对您的模型进行原型设计和基准测试。你可以在这里找到它们:图像数据集、 文本数据集和 音频数据集。

目录

  1. MNIST数据集
  2. 数据读入
  3. 迭代和可视化数据集
  4. 自定义数据集类
  5. 使用 DataLoaders 为训练准备数据
  6. 小结

MNIST数据集

下面是如何从 TorchVision 加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando的文章中图像的数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含 28×28 灰度图像和来自 10 个类别之一的相关标签。

我们使用以下参数加载FashionMNIST 数据集:

  • root是存储训练/测试数据的路径,
  • train指定训练或测试数据集,
  • download=True如果数据不可用,则从 Internet 下载数据root。
  • transform和target_transform分别指定特征和标签转换

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

数据被处理后的形式并不总是适合训练。所以,我们使用transforms对数据执行一些操作,使其适合于训练。所有TorchVision数据集都有两个参数,其中transform用于修改特征图,target_transform用于修改标签。torchvision.transforms模块提供了几种常用的转换,如下文的ToTensor()、Lambda。

  • ToTensor将PIL图像或NumPy ndarray转换为浮点张量(FloatTensor)。并图像的像素值在[限制在[0,1]范围内。
  • Lambda转换应用任何用户定义的lambda函数。例如:定义一个函数来将整数转换为一个独热编码张量。首先创建一个大小为class_num的零张量(数据集中标签的数量),并调用scatter_,它在标签y给定的索引上指定值为1。
target_transform = Lambda(lambda y: torch.zeros(
    class_num, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

数据读入

训练开始的第一步,首先就是数据读取。PyTorch数据读入是通过Dataset+DataLoader的方式来得到数据迭代器。Dataset定义好数据的格式和数据变换形式,DataLoader用iterative的方式不断读入批次数据下面我们分别来看下 Dataset 类与 DataLoader 类。

迭代和可视化数据集

我们可以Datasets像列表一样手动索引:training_data[index]. 我们matplotlib用来可视化训练数据中的一些样本。


labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

自定义数据集类

自定义 Dataset 类必须实现三个函数

我们可以定义自己的Dataset类来实现灵活的数据读取,定义的类需要继承PyTorch自身的Dataset类。主要包含三个函数:

  • __init__: 用于向类中传入外部参数,同时定义样本集
  • __getitem__: 索引数据集中的某一个数据,可以进行一定的变换,并将返回训练/验证所需的数据
  • __len__: 用于返回数据集的样本数

例如:FashionMNIST 图像存储在一个目录img_dir中,它们的标签分别存储在一个 CSV 文件annotations_file中。在接下来的部分中,我们将分解每个函数中发生的事情。


import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
       '''返回数据集中的样本数'''
        return len(self.img_labels)

    def __getitem__(self, idx):
      '''获取数据的方法,会和Dataloader连用''
        # 获取图片路径,0表示CSV文件的第一列
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # 读取图片
        image = read_image(img_path)
        # 获取图片对应的标签
        label = self.img_labels.iloc[idx, 1]
       # 如果使用时附加了transform参数,则对图片应用转换。本程序中transform=None表示没有附加参数
        if self.transform:
            image = self.transform(image)
       # 同理,如果使用时附加了target_transform参数,则标签应用转换。本程序中target_transform=None表示没有附加参数
        if self.target_transform:
            label = self.target_transform(label)
				# 返回图片和标签
        return image, label


结合代码可以看到,我们定义了一个名字为 CustomImageDataset的数据集,在构造函数中,传入 Tensor 类型的数据与标签;在 __len__ 函数中,直接返回 Tensor 的大小;在 __getitem__ 函数中返回索引的数据与标签。

1、_init_

_ init _ 函数在实例化 Dataset 对象时运行一次。我们初始化包含图像、注释文件和两种转换的目录。

labels.csv 文件如下所示:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

2、_len_

_ len _ 函数返回我们数据集中的样本数。



def __len__(self):
    return len(self.img_labels)

3、_getitem_

__ getitem __ 函数从给定索引处的数据集中加载并返回一个样本idx。基于索引,它识别图像在磁盘上的位置,使用 将其转换为张量read_image,从 csv 数据中检索相应的标签self.img_labels,调用它们的变换函数(如果适用),并返回张量图像和相应的标签一个元组。


def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

使用 DataLoaders 为训练准备数据

检索我们数据集的Dataset特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”的形式传递样本,在每个回合(epoch) 重新打乱以减少模型过拟合,这需要DataLoader。

DataLoader通过一个简单的API为我们抽象了这种复杂的功能,且是可迭代的


from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

我们已将该数据集加载到 DataLoader中,并且可以根据需要遍历数据集。下面的每次迭代都会返回一批train_features(训练特征)和train_labels(训练标签)(分别包含batch_size=64特征和标签)。因为我们指定shuffle=True了 ,所以在我们遍历所有批次之后,数据被打乱(目的:为了更细粒度地控制数据加载顺序)。

用next()对DataLoader进行迭代:


# 显示一张图片和对应标签
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

小结

今天我们主要讲解了数据下载和数据处理,对数据处理有初步认识。主要介绍了如何下载MNIST数据集,如何把MNIST数据集划分为训练集和测试集。自定义数据集Dataset类三个主要函数__init__、 __getitem__、__len__。DataLoaders打乱数据、批次的划分。

思考题

在实际场景,往往需要自定义数据集类,请采用Dataset做一个简单的自定义小型数据集,并通过DataLoaders来加载。

欢迎在评论区交流与讨论

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表