计算机系统应用教程网站

网站首页 > 技术文章 正文

101.人工智能——构建残差网络ResNet18网络模型

btikc 2024-10-18 04:43:21 技术文章 9 ℃ 0 评论

残差网络(Residual Network,ResNet)是He Kaiming等人提出的,以解决网络层数加深之后模型效果没有提升的问题。基本设计思想是引入了残差块,2015年ImageNet比赛的冠军,TOP5错误率3.6%,甚至超出人眼识别的准确度。

在神经网络模型中给非线性层增加直连边的方式来缓解梯度消失问题,从而使训练深度神经网络变得更加容易。

在残差网络中,最基本的单位为残差单元。一个残差网络通常有很多个残差单元堆叠而成。

模型构建

实现一个算子ResBlock来构建残差单元,其中定义了use_residual参数,用于控制是否使用残差连接。


残差单元包裹的非线性层的输入和输出形状大小应该一致。如果一个卷积层的输入特征图和输出特征图的通道数不一致,则其输出与输入特征图无法直接相加。为了解决上述问题,我们可以使用1×1大小的卷积将输入特征图的通道数映射为与级联卷积输出特征图的一致通道数。

1×1卷积:与标准卷积完全一样,唯一的特殊点在于卷积核的尺寸是1×1,也就是不去考虑输入数据局部信息之间的关系,而把关注点放在不同通道间。通过使用1×1卷积,可以起到如下作用:

  • 实现信息的跨通道交互与整合。考虑到卷积运算的输入输出都是3个维度(宽、高、多通道),所以1×1卷积实际上就是对每个像素点,在不同的通道上进行线性组合,从而整合不同通道的信息;
  • 对卷积核通道数进行降维和升维,减少参数量。经过1×1卷积后的输出保留了输入数据的原有平面结构,通过调控通道数,从而完成升维或降维的作用;
  • 利用1×1卷积后的非线性激活函数,在保持特征图尺寸不变的前提下,大幅增加非线性。
class ResBlock(nn.Layer):
    def __init__(self, num_channels, num_filters, stride=1, use_1x1conv=False, use_residual=True):
        """
        残差单元
        输入:
            - num_channels:输入通道数
            - num_filters:输出通道数
            - stride:残差单元的步长
            - use_1x1conv:当残差单元包裹的非线性层输入和输出通道数不一致时,需要用1*1卷积调整通道数后再进行相加运算
            - use_residual:用于控制是否使用残差连接
        """
        super(ResBlock, self).__init__()
        self.stride = stride
        self.use_1x1conv = use_1x1conv
        self.use_residual = use_residual
        #级联的等宽卷积
        self.conv1 = nn.Conv2D(num_channels, num_filters, kernel_size=3, padding=1, stride=self.stride)
        self.conv2 = nn.Conv2D(num_filters, num_filters, kernel_size=3, padding=1)
        #如果conv2的输出和此残差块的输入数据形状一致,则use_1x1conv=False
        #否则use_1x1conv=True,添加1个1x1的卷积作用在输入数据上,使其形状变成跟conv2一致
        if self.use_1x1conv:
            self.short = nn.Conv2D(num_channels, num_filters, kernel_size=1, stride=self.stride)
        #每个卷积后会接一个批量归一化层。
        self.bn1 = nn.BatchNorm2D(num_filters)
        self.bn2 = nn.BatchNorm2D(num_filters)

    def forward(self, inputs):
        y = F.relu(self.bn1(self.conv1(inputs)))
        y = self.bn2(self.conv2(y))
        if self.use_residual:
            #如果use_1x1conv=True,需要对inputs进行卷积,将形状调整成跟conv2输出一致
            if self.use_1x1conv:
                short = self.short(inputs)
            #否则直接将inputs和conv2的输出相加
            else:
                short = inputs
            y = paddle.add(x=short, y=y)
        out = F.relu(y)
        return out

残差网络的整体结构

残差网络就是将很多个残差单元串联起来构成的一个非常深的网络。

其中为了便于理解,可以将ResNet18网络划分为6个模块:

  • 第一模块:包含了一个步长为2,大小为7×7的卷积层,卷积层的输出通道数为64,卷积层的输出经过批量归一化、ReLU激活函数的处理后,接了一个步长为2的3×3的最大汇聚层;
  • 第二模块:包含了两个残差单元,经过运算后,输出通道数为64,特征图的尺寸保持不变;
  • 第三模块:包含了两个残差单元,经过运算后,输出通道数为128,特征图的尺寸缩小一半;
  • 第四模块:包含了两个残差单元,经过运算后,输出通道数为256,特征图的尺寸缩小一半;
  • 第五模块:包含了两个残差单元,经过运算后,输出通道数为512,特征图的尺寸缩小一半;
  • 第六模块:包含了一个全局平均汇聚层,将特征图变为1×1的大小,最终经过全连接层计算出最后的输出。

ResNet18模型的代码实现如下:

def make_first_block(in_channels):
    #模块一:7*7卷积、批归一化、汇聚
    b1 = nn.Sequential(nn.Conv2D(in_channels, 64, kernel_size=7, stride=2, padding=3),
                    nn.BatchNorm2D(64), nn.ReLU(),
                    nn.MaxPool2D(kernel_size=3, stride=2, padding=1))
    return b1
def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False, use_residual=True):
    blk = []
    #循环生成模块二到模块五
    for i in range(num_residuals):
        #创建模块三、四、五的第一个block,stride=2,需要使用1*1卷积调整通道数
        if i == 0 and not first_block:
            blk.append(ResBlock(input_channels, num_channels,
                                stride=2, use_1x1conv=True, 
                                use_residual=use_residual))
        #创建其他block
        else:
            blk.append(ResBlock(num_channels, num_channels,
                                use_residual=use_residual))
    return blk
def make_blocks(use_residual):
    #创建模块二到模块五
    b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True, 
                                use_residual=use_residual))
    b3 = nn.Sequential(*resnet_block(64, 128, 2, 
                                use_residual=use_residual))
    b4 = nn.Sequential(*resnet_block(128, 256, 2, 
                                use_residual=use_residual))
    b5 = nn.Sequential(*resnet_block(256, 512, 2, 
                                use_residual=use_residual))
    return b2, b3, b4, b5
#定义完整网络
class Model_ResNet18(nn.Layer):
    def __init__(self, in_channels=3, num_classes=10, use_residual=True):
        super(Model_ResNet18,self).__init__()
        b1 = make_first_block(in_channels)
        b2, b3, b4, b5 = make_blocks(use_residual)
        #封装模块一到模块6
        self.net = nn.Sequential(b1, b2, b3, b4, b5,
                            #模块六:汇聚层、全连接层
                            nn.AdaptiveAvgPool2D(1),
                            nn.Flatten(), nn.Linear(512, num_classes))

    def forward(self, x):
        return self.net(x)

这里同样可以使用paddle.summary统计模型的参数量。

model = Model_ResNet18(in_channels=1, num_classes=10, use_residual=True)
params_info = paddle.summary(model, (1, 1, 32, 32))
print(params_info)
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
     Conv2D-22        [[1, 1, 32, 32]]     [1, 64, 16, 16]         3,200     
   BatchNorm2D-1     [[1, 64, 16, 16]]     [1, 64, 16, 16]          256      
      ReLU-1         [[1, 64, 16, 16]]     [1, 64, 16, 16]           0       
    MaxPool2D-5      [[1, 64, 16, 16]]      [1, 64, 8, 8]            0       
     Conv2D-23        [[1, 64, 8, 8]]       [1, 64, 8, 8]         36,928     
   BatchNorm2D-2      [[1, 64, 8, 8]]       [1, 64, 8, 8]           256      
     Conv2D-24        [[1, 64, 8, 8]]       [1, 64, 8, 8]         36,928     
   BatchNorm2D-3      [[1, 64, 8, 8]]       [1, 64, 8, 8]           256      
    ResBlock-1        [[1, 64, 8, 8]]       [1, 64, 8, 8]            0       
     Conv2D-25        [[1, 64, 8, 8]]       [1, 64, 8, 8]         36,928     
   BatchNorm2D-4      [[1, 64, 8, 8]]       [1, 64, 8, 8]           256      
     Conv2D-26        [[1, 64, 8, 8]]       [1, 64, 8, 8]         36,928     
   BatchNorm2D-5      [[1, 64, 8, 8]]       [1, 64, 8, 8]           256      
    ResBlock-2        [[1, 64, 8, 8]]       [1, 64, 8, 8]            0       
     Conv2D-27        [[1, 64, 8, 8]]       [1, 128, 4, 4]        73,856     
   BatchNorm2D-6      [[1, 128, 4, 4]]      [1, 128, 4, 4]          512      
     Conv2D-28        [[1, 128, 4, 4]]      [1, 128, 4, 4]        147,584    
   BatchNorm2D-7      [[1, 128, 4, 4]]      [1, 128, 4, 4]          512      
     Conv2D-29        [[1, 64, 8, 8]]       [1, 128, 4, 4]         8,320     
    ResBlock-3        [[1, 64, 8, 8]]       [1, 128, 4, 4]           0       
     Conv2D-30        [[1, 128, 4, 4]]      [1, 128, 4, 4]        147,584    
   BatchNorm2D-8      [[1, 128, 4, 4]]      [1, 128, 4, 4]          512      
     Conv2D-31        [[1, 128, 4, 4]]      [1, 128, 4, 4]        147,584    
   BatchNorm2D-9      [[1, 128, 4, 4]]      [1, 128, 4, 4]          512      
    ResBlock-4        [[1, 128, 4, 4]]      [1, 128, 4, 4]           0       
     Conv2D-32        [[1, 128, 4, 4]]      [1, 256, 2, 2]        295,168    
  BatchNorm2D-10      [[1, 256, 2, 2]]      [1, 256, 2, 2]         1,024     
     Conv2D-33        [[1, 256, 2, 2]]      [1, 256, 2, 2]        590,080    
  BatchNorm2D-11      [[1, 256, 2, 2]]      [1, 256, 2, 2]         1,024     
     Conv2D-34        [[1, 128, 4, 4]]      [1, 256, 2, 2]        33,024     
    ResBlock-5        [[1, 128, 4, 4]]      [1, 256, 2, 2]           0       
     Conv2D-35        [[1, 256, 2, 2]]      [1, 256, 2, 2]        590,080    
  BatchNorm2D-12      [[1, 256, 2, 2]]      [1, 256, 2, 2]         1,024     
     Conv2D-36        [[1, 256, 2, 2]]      [1, 256, 2, 2]        590,080    
  BatchNorm2D-13      [[1, 256, 2, 2]]      [1, 256, 2, 2]         1,024     
    ResBlock-6        [[1, 256, 2, 2]]      [1, 256, 2, 2]           0       
     Conv2D-37        [[1, 256, 2, 2]]      [1, 512, 1, 1]       1,180,160   
  BatchNorm2D-14      [[1, 512, 1, 1]]      [1, 512, 1, 1]         2,048     
     Conv2D-38        [[1, 512, 1, 1]]      [1, 512, 1, 1]       2,359,808   
  BatchNorm2D-15      [[1, 512, 1, 1]]      [1, 512, 1, 1]         2,048     
     Conv2D-39        [[1, 256, 2, 2]]      [1, 512, 1, 1]        131,584    
    ResBlock-7        [[1, 256, 2, 2]]      [1, 512, 1, 1]           0       
     Conv2D-40        [[1, 512, 1, 1]]      [1, 512, 1, 1]       2,359,808   
  BatchNorm2D-16      [[1, 512, 1, 1]]      [1, 512, 1, 1]         2,048     
     Conv2D-41        [[1, 512, 1, 1]]      [1, 512, 1, 1]       2,359,808   
  BatchNorm2D-17      [[1, 512, 1, 1]]      [1, 512, 1, 1]         2,048     
    ResBlock-8        [[1, 512, 1, 1]]      [1, 512, 1, 1]           0       
AdaptiveAvgPool2D-1   [[1, 512, 1, 1]]      [1, 512, 1, 1]           0       
     Flatten-1        [[1, 512, 1, 1]]         [1, 512]              0       
     Linear-11           [[1, 512]]            [1, 10]             5,130     
===============================================================================
Total params: 11,186,186
Trainable params: 11,170,570
Non-trainable params: 15,616
-------------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 1.03
Params size (MB): 42.67
Estimated Total Size (MB): 43.70
-------------------------------------------------------------------------------

{'total_params': 11186186, 'trainable_params': 11170570}

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表