网站首页 > 技术文章 正文
之前的几篇介绍完PyTorch基本并且常用的API后,大家肯定对PyTorch对Tensor的操作有了一个基本的理解,当然我们介绍的API并不是很完整,也相对比较基础。还有很多复杂的API和一些比较不好理解的API没有讲解。不过我们后面遇到一例case也可以单独和大家一起学习分享吃透相关的问题。
今天开始我们就要开始学习PyTorch我们日常编程的基本套路,这个套路基本上是不怎么变换的,换的也是一些细节,换的也是对应场景的模型,当然本次学习大家最好都能够手动coding一下代码,主要是留下一些印象。
流程基本介绍
回归正题,正常来说PyTorch编程的常见套路如下:
一般来说,整体流程如下解释:
1.首先需要获取当前需要处理的数据,不管是图片,文字,声音,表格,视频等都需要通过PyTorch提供的一些API把这些数据转换成Tensor
2.我们需要选择一个和当前场景比较适合的深度学习模型,这边一般需要遵循如下的一些条件和规律:
- 首先这个模型是能解决你的问题的,没有万能模型。
- 我们也要选择对应的损失函数和优化器
- 训练阶段一般都是几个循环
3.有了处理好的数据,也有了正确的模型(最好走一个sample跑一下模型)接下来我们就应该直接跑一下训练一下模型的参数,通过训练让模型学习到这些数据的特性,简而言之,找到相关规律,将这些规律能够很好地在模型参数里面进行体现。
4.训练好了模型接下来一步就是评估模型是否符合预期,就是用测试数据来评估模型,是否真正地学习到这些数据潜在的规律。
5.在验证完模型后,如果模型符合预期,就需要把训练好的模型参数保存下来,持久化好参数,方便后续就可以直接加载使用这些训练好的模型参数。
基本上所有的PyTorch的编程范式和流程都是这样,在实际使用深度学习解决生活中实际问题的场景中,不同的地方可能就是一些细节,例如模型的复杂度,将实际的数据处理成Tensor的复杂度,训练过程中一些技巧。这边就是需要大家快速了解一下,有一个大概的印象。等后续看到代码的时候,大家就能够快速了解其中的逻辑。至少在hello world中这块还是很简单的。
接下来和大家使用jupyter notebook或者google的colab快速实现一下,废话不用多说,只需要“show me code”就ok,任何语言都是一样的,首先先导入包,如下图所示,后续导入torch和torch.nn包大家可以默认养成习惯,几乎是必须的包,这边还导入了matplotlib包,主要是用来方便用户进行数据可视化,方便用户理解,还记得之前PyTorch入门教材里面写的,可视化和coding是学习深度学习框架PyTorch中不二的法门。
数据部分处理
导入好相关的包,接下来我们就开始准备相关的数据了。刚才也说了,我们可以将我们生活各类的数据转换成Tensor,因为我们这边是入门,所以我们就先造一些数值类数据,方便大家理解,
这边是使用y=ax+b的模式来造数据的,这边我们是知道a和b的值的,这样我们就可以造出很多成对的(x,y)了,这里我们造出100对(x,y)送入到模型,看看模型能不能够通过我们送入的(x,y)键值对能不能够学习出a和b的值。
这边我们假设a等于0.7,b等于0.3。然后我们使用我们之前学到的torch.arange的api,从0开始,步长是0.02,到1为止(不包含1,左闭右开),这边大家顺便理解一下unsqueeze的作用。
好了,我们现在已经有了mock的数据了,并且我们也知道我们这次的模型最后需要得到的“答案”了,接下来我们要对数据进行简单的划分,我们需要将我们自造的这些数据划分为“训练集”,“验证集”和“测试集”了,相信如果对机器学习有过了解的话,这块应该是不陌生的。
- 训练集:这个集合里面的数据就是我们需要传送给模型,让模型进行学习,通过训练集里面的数据可以让模型找到其中的规律,当然这部分的数据的量级是比较大的,大概是80%
- 验证集:这部分数据是不给模型进行学习的,也就是说这部分的数据就像单元测试,模型无法提前获取到这个数据集的答案的,需要模型根据输出,根据平时在训练集所获得到的经验来计算相关的输出,当然这个验证集的数据不是必须的,就像素质教育一样,是没有单元测试一样。
- 测试集:这是需要模型根据测试集的输入传递给模型,让模型输出相关的结果,并且让这个结果和测试集正确的答案做对比,看看模型经过学习的准确率是否符合预期,这是评估模型好坏的一个重要数据集。
了解到数据集的理论知识后,接下来就用python的一些基本操作来进行数据集的划分。因为这是我们的比较简单的测试项目,就没有划分验证集,只有训练集和测试集,我们一共有50个数据对,按照4:1的原则,我们训练集是40个数据对,10个测试数据对。如下所示:
好了,我们现在就是用之前介绍的用“可视化”来方便我们理解,输入和输出之间的关系,这块就涉及到matplotlib里面的基本知识了,我们这里主要是学习PyTorch的,这边就不过多介绍了,我们就简单地把代码敲一遍,有一个大概印象就可以。
我们简单运行一下:
通过可视化,我们能够很清晰地发现是成线性关系的,说明在一些案例中,进行可视化确实可以方便我们理解,不过后面的课程我在学习的过程中,模型的复杂程度超过了可视化的范畴,更多地还是需要自己多看,来回看和来回理解的。
模型搭建
接下来就是正式的环节,需要搭建我们的模型了,我们之前说过选择一个合适的模型来解决我们的问题是至关重要的,但是在这个Hello world关卡里,我们是知道“答案”的,所以我们知道可以使用线性回归模型是能够解决我们的问题的。所以反而对于我们新手而言,更加重要的就是使用PyTorch来写出我们的线性回归模型。
我们来稍微解释一下这段代码,毕竟这是入门级的模型,里面还有的信息相对比较少也比较好理解:
- 我们构建了一个类,且这个类继承于PyTorch的nn.Module
- 在init初始化函数里面,我们定义了weights和bias这2个可学习的参数,是使用我们之前学习过torch.randn初始化的一个参数,类型是torch.float,并且requires_grad是true
- forward方法也是必须要写的,这个方法就是记录前向过程的函数,入参就是一个tensor,最终就是返回这个模型使用weights* x + bias 这个的前向结果作为这个模型的输出。
详细解释
一般来说构建一个网络模型PyTorch提供了四个比较基础的模块方便我们来构建各色各样的网络模型来满足各个不同的复杂场景,不管是后面的LLM模型还是Stable Diffusion模型,我们都是使用PyTorch提供的这四个模块进行构建的,这四个基础模型是
这些模块包含的东西确实非常多,看起来也比较复杂,几乎涵盖了PyTorch搭建模型的所有东西,但是仔细梳理一下你也会发现并不是很难。
torch.nn的模块里面包含了构建复杂模型的基础模块。
torch.nn.Parameter这个里面存储了模型里面的各项tensor参数,我们通过设置requires_grad设置为True的方式来确保可以用梯度下降的方式来更新模型的参数。
torch.nn.Module 这个是基类,是所有子模型的父类,也就是说如果我们使用PyTorch来构建,我们都需要继承torch.nn.Module,并且实现forward方法。
torch.optim 这个包内包含了各项各种各样的优化算法,这个里面optim里面是说在梯度下降的时候如何更好地提高梯度下降的效率和降低相关的loss。
forward函数,这个是所有nn.Module必须需要实现的类,定义了我们在具体模型前向传播的逻辑。
好的,到此为止,我们已经快速讲解了nn.Module模型信息了,我们需要手敲相关的模型,加深我们的理解,下一个小节,我们就开始讲解模型的推理和训练代码coding的流程了,大家周末愉快,本周降温,大家记得保暖。
猜你喜欢
- 2024-10-14 程序员用PyTorch实现第一个神经网络前做好这9个准备,事半功倍
- 2024-10-14 PyTorch 分布式训练简明教程 pytorch分批训练
- 2024-10-14 PyTorch入门与实战——必备基础知识(下)01
- 2024-10-14 深度学习pytorch深度学习入门与简明实战教程2022年
- 2024-10-14 深度学习框架PyTorch-trick 集锦 深度学习框架pytorch:入门与实践 第2版
- 2024-10-14 AI | 图神经网络-Pytorch Biggraph简介及官方文档解读
- 2024-10-14 PyTorch入门与实战——数据处理与数据加载02
- 2024-10-14 改动一行代码,PyTorch训练三倍提速,这些「高级技术」是关键
- 2024-10-14 利用pytorch CNN手写字母识别神经网络模型识别手写字母
- 2024-10-14 加快Python算法的四个方法(一)PyTorch
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)