计算机系统应用教程网站

网站首页 > 技术文章 正文

【AI 和机器学习】PyTorch BASIC 基础知识(节2):数据集数据加载器

btikc 2024-12-16 11:13:35 技术文章 34 ℃ 0 评论

【AI 和机器学习】PyTorch BASIC 基础知识:节2

—— 数据集和数据加载器(Datasets & DataLoaders)

前言

—— 哪个更适合初学者?

想要学习并掌握 AI,最直接的办法就是自己动手进行实操。有一些流行的来源可供练习 AI 技能,例如:

  • Kaggle:一个托管各种数据集和机器学习竞赛的平台。
  • UCI 机器学习存储库:用于机器学习研究的数据集集合。
  • TensorFlow 教程:TensorFlow 团队提供的教程和示例。
  • PyTorch 教程:PyTorch 团队提供的教程和示例。

其中PyTorch和TensorFlow的AI教程资源非常丰富。但对于初学者来说,哪个更合适,可能还得取决于您的特定目标(研究与生产)以及您的偏好等:

  • PyTorch 因其简单、易读和易于调试而通常被认为更适合初学者。PyTorch 的动态特性使新手可以学习概念而不会被复杂的语法所困扰。
  • TensorFlow 随着 TensorFlow 2.x 和 Keras 的推出变得更加适合初学者,但它仍可能对初学者构成挑战。

本文先选择PyTorch来和大家一起学习,学习它的一些基础内容。其中所有素材均取自其教程。对于每一节内容,我们都将先给出摘要,然后把译文稍作整理后附在后面,供参考。

目录

【续前文】

本节摘要

本节讨论了 PyTorch 中数据集和数据加载器的使用,强调了将数据集处理与模型训练分离,以便提高可读性和模块化的重要性。PyTorch 提供了两个关键组件:“torch.utils.data.Dataset”用于存储样本及其标签,以及“torch.utils.data.DataLoader”用于方便访问这些样本。

本节解释了如何使用 TorchVision 加载 Fashion-MNIST 数据集(包含 60000 张训练图像和 10000 张测试灰度图像的集合),详细说明了 “root”、“train” 和transformations等参数。本文包含加载数据集和使用 Matplotlib 可视化样本的代码示例。

此外,还介绍了通过实现“__init__”、“__len__”和“__getitem__”方法来创建自定义数据集类,允许用户从指定目录加载图像,并从 CSV 文件加载其标签。

最后,本节重点介绍了 DataLoader 在模型训练期间的批处理、随机排序和加速数据检索功能,从而可以在机器学习流中更加轻松地有效地管理数据集。


本节正文

用于处理数据样本的代码,可能会变得混乱而且难以维护;理想情况下,我们希望数据集代码与我们的模型训练代码解耦,以此提高可读性和模块化。PyTorch 提供了两个数据原语: torch.utils.data.DataLoader 和 torch.utils.data.Dataset,它们允许您使用预加载的数据集以及您自己的数据。Dataset 存储样本及其相应的标签,DataLoader 在 Dataset 周围包装了一个可迭代对象,以便于访问样本。

PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集子类化 torch.utils.data.Dataset 并实现特定于特定数据的功能。它们可以用于对模型进行原型设计和基准测试。您可以在此处找到它们:图像数据集、文本数据集和音频数据集

加载数据集

以下是如何从 TorchVision 加载 Fashion-MNIST 数据集的示例。Fashion-MNIST 是 Zalando 的文章图像数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含一个 28×28 灰度图像和一个来自 10 个类之一的关联标签。

我们使用以下参数加载 FashionMNIST 数据集:

  • root 存储训练/测试数据的路径
  • train 指定训练或测试数据集
  • download=True 如果数据在 root 中不可用,则从 Internet 下载数据。
  • transform 和 target_transform 指定特征和标签转换

输出:

迭代和可视化数据集

我们可以像列表一样手动索引 DataDataset:training_data[index]。我们使用 matplotlib 来可视化训练数据中的一些样本。

(踝靴, 衬衫, 包, 踝靴, 长裤, 凉鞋, 外套, 凉鞋, 套头衫)


为您的文件创建自定义数据集

自定义 Dataset 类必须实现三个函数:__init__、__len__ 和 __getitem__。来看一下这个实现;FashionMNIST 图像存储在目录img_dir中,其标签单独存储在 CSV 文件annotations_file中。

在接下来的部分中,我们将分解每个函数中发生的情况。


__init__

__init__ 函数在实例化 Dataset 对象时运行一次。我们初始化包含图像、annotations 文件和两个transform的目录(下一节将更详细地介绍)。

labels.csv 文件如下所示:

__len__

__len__ 函数返回数据集中的样本数。

例:

__getitem__

__getitem__ 函数加载并返回给定索引 idx 处的数据集中的样本。根据索引,它识别图像在磁盘上的位置,使用 read_image 将其转换为张量,从 self.img_labels 中的 csv 数据中检索相应的标签,对其调用transform函数(如果适用),并在元组中返回张量图像和相应的标签。


准备用于DataLoader 进行训练的数据

Dataset 检索我们数据集的特征,并一次标记一个样本。在训练模型时,我们通常希望以 “小批量” 的形式传递样本,在每个 epoch 重新洗牌数据以减少模型过度拟合,并使用 Python 的multiprocessing来加快数据检索速度。

DataLoader 是一个可迭代对象,它通过一个简单的 API 为我们抽象了这种复杂性。

通过DataLoader 迭代

我们已将该数据集加载到 DataLoader 中,并可以根据需要迭代数据集。下面的每次迭代都会返回一批 train_features 和 train_labels(分别包含 batch_size=64 个特征和标签)。因为我们指定了 shuffle=True,所以在我们迭代所有批次后,数据会被随机排序(要对数据加载顺序进行更精细的控制,请查看 Samplers)。

输出:

延伸阅读:torch.utils.data API 接口

【未完待续】

农历甲辰十月廿五

2024.11.25

【部分图片来源网络,侵删】

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表