计算机系统应用教程网站

网站首页 > 技术文章 正文

利用pytorch CNN手写字母识别神经网络模型识别手写字母

btikc 2024-10-14 08:49:35 技术文章 3 ℃ 0 评论

上期文章我们分享了使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

哪里使用pytorch训练了第一个手写字母的神经网络,并保存了预训练模型,本期我们使用上期的模型进行手写字母的识别

搭建神经网络

根据上期文章的分享,我们搭建一个手写字母识别的神经网络

import torch
import torch.nn as nn
from PIL import Image  # 导入图片处理工具
import PIL.ImageOps
import numpy as np
from torchvision import transforms
import cv2
import matplotlib.pyplot as plt
# 定义神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(  # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,  # 输入通道数
                out_channels=16,  # 输出通道数
                kernel_size=5,   # 卷积核大小
                stride=1,  #卷积步数
                padding=2,  # 如果想要 con2d 出来的图片长宽没有变化, 
                            # padding=(kernel_size-1)/2 当 stride=1
            ),  # output shape (16, 28, 28)
            nn.ReLU(),  # activation
            nn.MaxPool2d(kernel_size=2),  # 在 2x2 空间里向下采样, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(  # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),  # output shape (32, 14, 14)
            nn.ReLU(),  # activation
            nn.MaxPool2d(2),  # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 37)  # 全连接层,A/Z,a/z一共37个类

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

我们手写字母的识别神经网络主要使用了EMNIST数据库

EMNIST 主要分为以下 6 类:

By_Class : 共 814255 张,62 类,与 NIST 相比重新划分类训练集与测试机的图片数

By_Merge: 共 814255 张,47 类, 与 NIST 相比重新划分类训练集与测试机的图片数

Balanced : 共 131600 张,47 类, 每一类都包含了相同的数据,每一类训练集 2400 张,测试集 400 张

Digits :共 28000 张,10 类,每一类包含相同数量数据,每一类训练集 24000 张,测试集 4000 张

Letters : 共 103600 张,37 类,每一类包含相同数据,每一类训练集 2400 张,测试集 400 张

MNIST : 共 70000 张,10 类,每一类包含相同数量数据(注:这里虽然数目和分类都一样,但是图片的处理方式不一样,EMNIST 的 MNIST 子集数字占的比重更大)

这里为什么后面的分类不是26+26?其主要原因是一些大小写字母比较类似的字母就合并了,比如C等等

这就是为什么神经网络的全连接层self.out = nn.Linear(32 * 7 * 7, 37)一共有37类,我们上期代码中的训练是按照letters类进行训练的,神经网络的搭建过程,这里不再一一介绍了,可以参考往期文章进行学习。

使用MNIST数据集训练第一个pytorch CNN手写数字识别神经网络

pytorch利用CNN卷积神经网络来识别手写数字

上期代码
# EMNIST 手写字母 训练集
train_data = torchvision.datasets.EMNIST(
    root='./data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download = DOWNLOAD_MNIST,
    split = 'letters' 
)
# EMNIST 手写字母 测试集
test_data = torchvision.datasets.EMNIST(
    root='./data',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=False,
    split = 'letters'     
)

加载图片,预处理

神经网络搭建完成后,我们需要加载图片,并进行图片的一些预处理操作


file_name = '55.png'  # 导入自己的图片
img = Image.open(file_name)
img = img.convert('L')

img = PIL.ImageOps.invert(img)
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = img.rotate(90)

plt.imshow(img)
plt.show()

train_transform = transforms.Compose([
       transforms.Grayscale(),
         transforms.Resize((28, 28)),
         transforms.ToTensor(),
 ])

img = train_transform(img)
img = torch.unsqueeze(img, dim=0)
#torch.unsqueeze()这个函数主要是对数据维度进行扩充。
需要通过dim指定位置,给指定位置加上维数为1的维度。
通过往期文章对数据库的可视化,可以得知EMNIST数据库是黑底白字,
但是平时我们自己的照片一般是白底黑字,这里,我们使用
img = img.convert('L')
img = PIL.ImageOps.invert(img)
对图片进行颜色的翻转
且我们知道EMNIST数据库左右翻转图片后,又进行了图片的逆时针旋转90度
这里我们使用PIL库提供的函数进行图片的处理操作
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = img.rotate(90)
Grayscale:将图像转换为灰度
ToTensor:转换为张量

通过train_transform(img)函数我们对输入的图片进行预处理操作,转换为pytorch可以识别的神经网络数据

加载模型

#加载模型
model = CNN()
model.load_state_dict(torch.load('./model/Eminist.pth',map_location='cpu'))
model.eval()

def get_mapping(num, with_type='letters'):
    """
    根据 mapping,由传入的 num 计算 UTF8 字符。
    """
    if with_type == 'byclass':
        if num <= 9:
            return chr(num + 48)  # 数字
        elif num <= 35:
            return chr(num + 55)  # 大写字母
        else:
            return chr(num + 61)  # 小写字母
    elif with_type == 'letters':
        return chr(num + 64) + " / " + chr(num + 96)  # 大写/小写字母
    elif with_type == 'digits':
        return chr(num + 96)
    else:
        return num

model.load_state_dict(torch.load)函数加载上期神经网络训练完成的模型

get_mapping函数:

由于神经网络识别完成后,反馈给程序的是字母的UTF-8编码,我们通过查表来找到对应的字母

字符编码表(UTF-8)

神经网络识别

with torch.no_grad():
    y = model(img)
    print(y)
    output = torch.squeeze(y)
    print(output)
    predict = torch.softmax(output, dim=0)
    print(predict)
    predict_cla = torch.argmax(predict).numpy()
    print(predict_cla)
print(get_mapping(predict_cla), predict[predict_cla].numpy())

运行以上代码我们可以看到神经网络输出每个字母的识别精度,我们使用torch.argmax(predict).numpy()函数选择其中精度最大字母,并利用get_mapping函数查表选择出神经网络识别出来的字母

tensor([8.1084e-10, 6.5350e-04, 8.5815e-01, 1.2294e-05, 4.6187e-03, 9.3248e-04,
        9.5208e-06, 2.7102e-02, 3.7893e-04, 3.2245e-05, 3.4475e-05, 8.7205e-06,
        3.5875e-07, 2.1584e-06, 1.9850e-06, 5.4030e-02, 4.6604e-03, 4.5797e-02,
        3.4711e-04, 3.0402e-03, 3.9365e-05, 1.6103e-06, 5.6906e-06, 1.3600e-07,
        6.6020e-07, 2.8310e-06, 1.3650e-04, 1.1401e-09, 7.8941e-10, 7.8516e-10,
        8.4411e-10, 1.2178e-09, 1.3640e-09, 8.4337e-10, 1.1152e-09, 1.0610e-09,
        1.0475e-09])
2
B / b 0.8581509

我们可以看到神经网络可以成功地识别出我们手写的字母B

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表