网站首页 > 技术文章 正文
上期文章我们分享了使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络
哪里使用pytorch训练了第一个手写字母的神经网络,并保存了预训练模型,本期我们使用上期的模型进行手写字母的识别
搭建神经网络
根据上期文章的分享,我们搭建一个手写字母识别的神经网络
import torch
import torch.nn as nn
from PIL import Image # 导入图片处理工具
import PIL.ImageOps
import numpy as np
from torchvision import transforms
import cv2
import matplotlib.pyplot as plt
# 定义神经网络
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
nn.Conv2d(
in_channels=1, # 输入通道数
out_channels=16, # 输出通道数
kernel_size=5, # 卷积核大小
stride=1, #卷积步数
padding=2, # 如果想要 con2d 出来的图片长宽没有变化,
# padding=(kernel_size-1)/2 当 stride=1
), # output shape (16, 28, 28)
nn.ReLU(), # activation
nn.MaxPool2d(kernel_size=2), # 在 2x2 空间里向下采样, output shape (16, 14, 14)
)
self.conv2 = nn.Sequential( # input shape (16, 14, 14)
nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14)
nn.ReLU(), # activation
nn.MaxPool2d(2), # output shape (32, 7, 7)
)
self.out = nn.Linear(32 * 7 * 7, 37) # 全连接层,A/Z,a/z一共37个类
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
output = self.out(x)
return output
我们手写字母的识别神经网络主要使用了EMNIST数据库
EMNIST 主要分为以下 6 类:
By_Class : 共 814255 张,62 类,与 NIST 相比重新划分类训练集与测试机的图片数
By_Merge: 共 814255 张,47 类, 与 NIST 相比重新划分类训练集与测试机的图片数
Balanced : 共 131600 张,47 类, 每一类都包含了相同的数据,每一类训练集 2400 张,测试集 400 张
Digits :共 28000 张,10 类,每一类包含相同数量数据,每一类训练集 24000 张,测试集 4000 张
Letters : 共 103600 张,37 类,每一类包含相同数据,每一类训练集 2400 张,测试集 400 张
MNIST : 共 70000 张,10 类,每一类包含相同数量数据(注:这里虽然数目和分类都一样,但是图片的处理方式不一样,EMNIST 的 MNIST 子集数字占的比重更大)
这里为什么后面的分类不是26+26?其主要原因是一些大小写字母比较类似的字母就合并了,比如C等等
这就是为什么神经网络的全连接层self.out = nn.Linear(32 * 7 * 7, 37)一共有37类,我们上期代码中的训练是按照letters类进行训练的,神经网络的搭建过程,这里不再一一介绍了,可以参考往期文章进行学习。
使用MNIST数据集训练第一个pytorch CNN手写数字识别神经网络
上期代码
# EMNIST 手写字母 训练集
train_data = torchvision.datasets.EMNIST(
root='./data',
train=True,
transform=torchvision.transforms.ToTensor(),
download = DOWNLOAD_MNIST,
split = 'letters'
)
# EMNIST 手写字母 测试集
test_data = torchvision.datasets.EMNIST(
root='./data',
train=False,
transform=torchvision.transforms.ToTensor(),
download=False,
split = 'letters'
)
加载图片,预处理
神经网络搭建完成后,我们需要加载图片,并进行图片的一些预处理操作
file_name = '55.png' # 导入自己的图片
img = Image.open(file_name)
img = img.convert('L')
img = PIL.ImageOps.invert(img)
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = img.rotate(90)
plt.imshow(img)
plt.show()
train_transform = transforms.Compose([
transforms.Grayscale(),
transforms.Resize((28, 28)),
transforms.ToTensor(),
])
img = train_transform(img)
img = torch.unsqueeze(img, dim=0)
#torch.unsqueeze()这个函数主要是对数据维度进行扩充。
需要通过dim指定位置,给指定位置加上维数为1的维度。
通过往期文章对数据库的可视化,可以得知EMNIST数据库是黑底白字,
但是平时我们自己的照片一般是白底黑字,这里,我们使用
img = img.convert('L')
img = PIL.ImageOps.invert(img)
对图片进行颜色的翻转
且我们知道EMNIST数据库左右翻转图片后,又进行了图片的逆时针旋转90度
这里我们使用PIL库提供的函数进行图片的处理操作
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = img.rotate(90)
Grayscale:将图像转换为灰度
ToTensor:转换为张量
通过train_transform(img)函数我们对输入的图片进行预处理操作,转换为pytorch可以识别的神经网络数据
加载模型
#加载模型
model = CNN()
model.load_state_dict(torch.load('./model/Eminist.pth',map_location='cpu'))
model.eval()
def get_mapping(num, with_type='letters'):
"""
根据 mapping,由传入的 num 计算 UTF8 字符。
"""
if with_type == 'byclass':
if num <= 9:
return chr(num + 48) # 数字
elif num <= 35:
return chr(num + 55) # 大写字母
else:
return chr(num + 61) # 小写字母
elif with_type == 'letters':
return chr(num + 64) + " / " + chr(num + 96) # 大写/小写字母
elif with_type == 'digits':
return chr(num + 96)
else:
return num
model.load_state_dict(torch.load)函数加载上期神经网络训练完成的模型
get_mapping函数:
由于神经网络识别完成后,反馈给程序的是字母的UTF-8编码,我们通过查表来找到对应的字母
字符编码表(UTF-8)
神经网络识别
with torch.no_grad():
y = model(img)
print(y)
output = torch.squeeze(y)
print(output)
predict = torch.softmax(output, dim=0)
print(predict)
predict_cla = torch.argmax(predict).numpy()
print(predict_cla)
print(get_mapping(predict_cla), predict[predict_cla].numpy())
运行以上代码我们可以看到神经网络输出每个字母的识别精度,我们使用torch.argmax(predict).numpy()函数选择其中精度最大字母,并利用get_mapping函数查表选择出神经网络识别出来的字母
tensor([8.1084e-10, 6.5350e-04, 8.5815e-01, 1.2294e-05, 4.6187e-03, 9.3248e-04,
9.5208e-06, 2.7102e-02, 3.7893e-04, 3.2245e-05, 3.4475e-05, 8.7205e-06,
3.5875e-07, 2.1584e-06, 1.9850e-06, 5.4030e-02, 4.6604e-03, 4.5797e-02,
3.4711e-04, 3.0402e-03, 3.9365e-05, 1.6103e-06, 5.6906e-06, 1.3600e-07,
6.6020e-07, 2.8310e-06, 1.3650e-04, 1.1401e-09, 7.8941e-10, 7.8516e-10,
8.4411e-10, 1.2178e-09, 1.3640e-09, 8.4337e-10, 1.1152e-09, 1.0610e-09,
1.0475e-09])
2
B / b 0.8581509
我们可以看到神经网络可以成功地识别出我们手写的字母B
猜你喜欢
- 2024-10-14 程序员用PyTorch实现第一个神经网络前做好这9个准备,事半功倍
- 2024-10-14 PyTorch 分布式训练简明教程 pytorch分批训练
- 2024-10-14 PyTorch入门与实战——必备基础知识(下)01
- 2024-10-14 深度学习pytorch深度学习入门与简明实战教程2022年
- 2024-10-14 深度学习框架PyTorch-trick 集锦 深度学习框架pytorch:入门与实践 第2版
- 2024-10-14 AI | 图神经网络-Pytorch Biggraph简介及官方文档解读
- 2024-10-14 PyTorch入门与实战——数据处理与数据加载02
- 2024-10-14 改动一行代码,PyTorch训练三倍提速,这些「高级技术」是关键
- 2024-10-14 加快Python算法的四个方法(一)PyTorch
- 2024-10-14 PyTorch的Dataset 和TorchData API的比较
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)