网站首页 > 技术文章 正文
上节课我们学习了多层感知机在回归任务上的应用,这节课我们进一步学习一下多层感知机的分类任务。
任务
使用MLP+Softmax神经网络模型实现手写数字集MNIST分类
MNIST 数据集介绍
MNIST 数据集来自美国国家标准与技术研究所,由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局的工作人员。
MNIST 数据集包含了四个部分:
Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
部分训练图像展示如下:
一张图片包含28*28个像素,我们把这一个数组展开成一个向量,长度是28*28=784.
如果把数据用矩阵表示,可以把MNIST训练数据变成一个形状为[60000,784]的矩阵,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点,图片里的每个像素的强度值介于0-1之间。
程序过程
1.数据集准备
我们使用torchvision.datasets.MNIST自动下载数据集到本地
相关参数解释:
root='dataset 下载数据,并且存放在dataset文件夹中
train=True 用于指定在数据集下载完成后需要载入哪部分数据,如果设置为True,则说明载入的是该数据集的训练集部分; train=False 如果设置为False,则说明载入的是该数据集的测试集部分。
transform=transforms.ToTensor() 数据的标准化等操作都在transforms中,此处是转换
download=True True为自动网络下载,False为使用本地已经下载好的数据集
然后使用torch.utils.data.DataLoader加载器加载数据集
?
相关参数解释:
batch_size 一个批次可以认为是一个包,每个包中含有batch_size张图片
shuffle=True 指数据是否打乱顺序
num_workers 非并行加载就填0
2.网络构建
建立一个四层感知机网络: 一个输入层,两个全连接的隐藏线性层,一个输出层
因为图片是28*28的,需要全部展开,最终我们要输出数字,一共10个数字。
10个数字实际上是10个类别,输出是概率分布,最后选取概率最大的作为预测值输出
?
3.训练
该部分代码过长,只展示部分关键代码
训练效果图如下:
训练300个epoch,测试集acc在0.98,训练集acc在0.993.
训练集和测试集的精度曲线图:
训练集和测试集的损失曲线图:
4.测试
测试部分代码
Pytorch预测效果(从测试集选择一张图预测)
总结,MNIST数据集作为机器学习中"hello world"级别的存在,被业界普遍认为是入门最佳数据集!有条件的快来动手试试吧!
猜你喜欢
- 2024-09-30 效果超过SGD和Adam,谷歌大脑的「神经网络优化器搜索」自动找到更好的训练优化器
- 2024-09-30 AlphaGo:技术上如何运作? alphago主要使用的技术
- 2024-09-30 综述论文:机器学习中的模型评价、模型选择与算法选择
- 2024-09-30 Colab超火的Keras/TPU深度学习实战,会点Python就能看懂的课程
- 2024-09-30 结合符号与连接,斯坦福神经状态机冲刺视觉推理新SOTA
- 2024-09-30 端到端对话模型新突破!Facebook发布大规模个性化对话数据库
- 2024-09-30 神经网络基础篇二 神经网络的基础
- 2024-09-30 吴恩达深度学习笔记(55)-Softmax 回归(Softmax regression)
- 2024-09-30 激活函数:ReLU和Softmax 激活函数relu和sigmoid
- 2024-09-30 火爆全网,只有4页!ICLR爆款论文「你只需要Patch」到底香不香?
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)