网站首页 > 技术文章 正文
作者 | 小马
来源 | FightingCV
一、残差结构
Residual net(残差网络)
将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。
意味着后面的特征层的内容会有一部分由其前面的某一层线性贡献。
深度残差网络的设计是为了克服由于网络深度加深而产生的学习效率变低与准确率无法有效提升的问题。
二、ResNet50模型基本构成
ResNet50有两个基本的块,分别名为Conv Block和Identity Block
Conv Block输入和输出的维度(通道数和size)是不一样的,所以不能连续串联,它的作用是改变网络的维度;
Identity Block输入维度和输出维度(通道数和size)相同,可以串联,用于加深网络的。
Conv Block结构
Identity Block的结构
三、总体的网络结构
四、代码复现
1.导库
import torch
from torch import nn
2.写Block类
'''
Block的各个plane值:
inplane:输出block的之前的通道数
midplane:在block中间处理的时候的通道数(这个值是输出维度的1/4)
midplane*self.extention:输出的维度
'''
class Bottleneck(nn.Module):
#每个stage中维度拓展的倍数
extention=4
#定义初始化的网络和参数
def __init__(self,inplane,midplane,stride,downsample=None):
super(Bottleneck,self).__init__()
self.conv1=nn.Conv2d(inplane,midplane,kernel_size=1,stride=stride,bias=False)
self.bn1=nn.BatchNorm2d(midplane)
self.conv2=nn.Conv2d(midplane,midplane,kernel_size=3,stride=1,padding=1,bias=False)
self.bn2=nn.BatchNorm2d(midplane)
self.conv3=nn.Conv2d(midplane,midplane*self.extention,kernel_size=1,stride=1,bias=False)
self.bn3=nn.BatchNorm2d(midplane*self.extention)
self.relu=nn.ReLU(inplace=False)
self.downsample=downsample
self.stride=stride
def forward(self,x):
#参差数据
residual=x
#卷积操作
out=self.relu(self.bn1(self.conv1(x)))
out=self.relu(self.bn2(self.conv2(out)))
out=self.relu(self.bn3(self.conv3(out)))
#是否直连(如果时Identity block就是直连;如果是Conv Block就需要对参差边进行卷积,改变通道数和size)
if(self.downsample!=None):
residual=self.downsample(x)
#将参差部分和卷积部分相加
out+=residual
out=self.relu(out)
return out
3.写Resnet结构
class ResNet(nn.Module):
#初始化网络结构和参数
def __init__(self,block,layers,num_classes=1000):
#self.inplane为当前的fm的通道数
self.inplane=64
super(ResNet,self).__init__()
#参数
self.block=block
self.layers=layers
#stem的网络层
self.conv1=nn.Conv2d(3,self.inplane,kernel_size=7,stride=2,padding=3,bias=False)
self.bn1=nn.BatchNorm2d(self.inplane)
self.relu=nn.ReLU()
self.maxpool=nn.MaxPool2d(kernel_size=3,padding=1,stride=2)
#64,128,256,512是指扩大4倍之前的维度,即Identity Block的中间维度
self.stage1=self.make_layer(self.block,64,self.layers[0],stride=1)
self.stage2=self.make_layer(self.block,128,self.layers[1],stride=2)
self.stage3=self.make_layer(self.block,256,self.layers[2],stride=2)
self.stage4=self.make_layer(self.block,512,self.layers[3],stride=2)
#后续的网络
self.avgpool=nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.extention, num_classes)
def forward(self,x):
#stem部分:conv+bn+relu+maxpool
out=self.conv1(x)
out=self.bn1(out)
out=self.relu(out)
out=self.maxpool(out)
#block
out=self.stage1(out)
out=self.stage2(out)
out=self.stage3(out)
out=self.stage4(out)
#分类
out=self.avgpool(out)
out = torch.flatten(out, 1)
out=self.fc(out)
return out
def make_layer(self,block,midplane,block_num,stride=1):
'''
block:block模块
midplane:每个模块中间运算的维度,一般等于输出维度/4
block_num:重复次数
stride:Conv Block的步长
'''
block_list=[]
#先计算要不要加downsample模块
downsample=None
if(stride!=1or self.inplane!=midplane*block.extention):
downsample=nn.Sequential(
nn.Conv2d(self.inplane,midplane*block.extention,stride=stride,kernel_size=1,bias=False),
nn.BatchNorm2d(midplane*block.extention)
)
#Conv Block
conv_block=block(self.inplane,midplane,stride=stride,downsample=downsample)
block_list.append(conv_block)
self.inplane=midplane*block.extention
#Identity Block
for i in range(1,block_num):
block_list.append(block(self.inplane,midplane,stride=1))
return nn.Sequential(*block_list)
4.调用
resnet = ResNet(Bottleneck, [3, 4, 6, 3])
x=torch.randn(1,3,224,224)
x=resnet(x)
print(x.shape)
复现之后resnet的图
链接: https://pan.baidu.com/s/1pazTBDtVMb68tECRR1sM4Q 密码: ps9c
复现之后resnet的代码
链接: https://pan.baidu.com/s/1SKfgCnx_excnc-AzvfK8IQ 密码: fi1w
零、复现参考图:
参考
https://blog.csdn.net/weixin_44791964/article/details/102790260
猜你喜欢
- 2024-10-20 MindSpore网络实战系列:使用ResNet-50实现图像分类任务
- 2024-10-20 芯语 | 进行两阶段人体姿态估计的研究
- 2024-10-20 CNN网络结构总结[一] cnn网络层数
- 2024-10-20 COCO2018 Keypoint冠军算法解读 coco2018+keypoint冠军算法解读最新
- 2024-10-20 网络退化问题:ResNet 通过残差连接建立高速网络,实现恒等映射
- 2024-10-20 「网络结构比较」 基于Resnet两个注意力:BAM与CBAM
- 2024-10-20 MindSpore实现ResNet50详解(附单机+集群代码)
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)