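In the previous article, "44. Artificial Intelligence: Custom Datasets in the PaddlePaddle Deep Learning Framework", we defined the dataset. This article builds a model on top of that dataset to perform multi-class prediction of both the time period and the weather.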
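# Define the model
# ResNet-50 is used here as the classification backbone
import paddle
from paddle.vision.models import resnet50

class PWModel(paddle.nn.Layer):
    def __init__(self):
        super(PWModel, self).__init__()
        # ResNet-50 backbone, loaded with pretrained weights
        backbone = resnet50(pretrained=True)
        # Replace the backbone's 1000-class fc layer with Identity,
        # so the backbone outputs the pooled 2048-d feature vector unchanged
        backbone.fc = paddle.nn.Identity()
        self.backbone = backbone
        # Two fully connected heads share the backbone features
        # Time-period classification (4 classes)
        self.fc1 = paddle.nn.Linear(in_features=2048, out_features=4)
        # Weather classification (3 classes)
        self.fc2 = paddle.nn.Linear(in_features=2048, out_features=3)

    # Forward pass
    def forward(self, x):
        x = self.backbone(x)
        # Classify time period and weather at the same time
        period = self.fc1(x)
        weather = self.fc2(x)
        return period, weather

# Inspect the model structure
model = paddle.Model(PWModel())
model.summary((1, 3, 256, 256))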
Output:
-------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===============================================================================
Conv2D-180 [[1, 3, 256, 256]] [1, 64, 128, 128] 9,408
BatchNorm2D-180 [[1, 64, 128, 128]] [1, 64, 128, 128] 256
ReLU-61 [[1, 64, 128, 128]] [1, 64, 128, 128] 0
MaxPool2D-5 [[1, 64, 128, 128]] [1, 64, 64, 64] 0
Conv2D-182 [[1, 64, 64, 64]] [1, 64, 64, 64] 4,096
BatchNorm2D-182 [[1, 64, 64, 64]] [1, 64, 64, 64] 256
ReLU-62 [[1, 256, 64, 64]] [1, 256, 64, 64] 0
Conv2D-183 [[1, 64, 64, 64]] [1, 64, 64, 64] 36,864
BatchNorm2D-183 [[1, 64, 64, 64]] [1, 64, 64, 64] 256
Conv2D-184 [[1, 64, 64, 64]] [1, 256, 64, 64] 16,384
BatchNorm2D-184 [[1, 256, 64, 64]] [1, 256, 64, 64] 1,024
Conv2D-181 [[1, 64, 64, 64]] [1, 256, 64, 64] 16,384
BatchNorm2D-181 [[1, 256, 64, 64]] [1, 256, 64, 64] 1,024
BottleneckBlock-49 [[1, 64, 64, 64]] [1, 256, 64, 64] 0
Conv2D-185 [[1, 256, 64, 64]] [1, 64, 64, 64] 16,384
BatchNorm2D-185 [[1, 64, 64, 64]] [1, 64, 64, 64] 256
ReLU-63 [[1, 256, 64, 64]] [1, 256, 64, 64] 0
Conv2D-186 [[1, 64, 64, 64]] [1, 64, 64, 64] 36,864
BatchNorm2D-186 [[1, 64, 64, 64]] [1, 64, 64, 64] 256
Conv2D-187 [[1, 64, 64, 64]] [1, 256, 64, 64] 16,384
BatchNorm2D-187 [[1, 256, 64, 64]] [1, 256, 64, 64] 1,024
BottleneckBlock-50 [[1, 256, 64, 64]] [1, 256, 64, 64] 0
Conv2D-188 [[1, 256, 64, 64]] [1, 64, 64, 64] 16,384
BatchNorm2D-188 [[1, 64, 64, 64]] [1, 64, 64, 64] 256
ReLU-64 [[1, 256, 64, 64]] [1, 256, 64, 64] 0
Conv2D-189 [[1, 64, 64, 64]] [1, 64, 64, 64] 36,864
BatchNorm2D-189 [[1, 64, 64, 64]] [1, 64, 64, 64] 256
Conv2D-190 [[1, 64, 64, 64]] [1, 256, 64, 64] 16,384
BatchNorm2D-190 [[1, 256, 64, 64]] [1, 256, 64, 64] 1,024
BottleneckBlock-51 [[1, 256, 64, 64]] [1, 256, 64, 64] 0
Conv2D-192 [[1, 256, 64, 64]] [1, 128, 64, 64] 32,768
BatchNorm2D-192 [[1, 128, 64, 64]] [1, 128, 64, 64] 512
ReLU-65 [[1, 512, 32, 32]] [1, 512, 32, 32] 0
Conv2D-193 [[1, 128, 64, 64]] [1, 128, 32, 32] 147,456
BatchNorm2D-193 [[1, 128, 32, 32]] [1, 128, 32, 32] 512
Conv2D-194 [[1, 128, 32, 32]] [1, 512, 32, 32] 65,536
BatchNorm2D-194 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048
Conv2D-191 [[1, 256, 64, 64]] [1, 512, 32, 32] 131,072
BatchNorm2D-191 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048
BottleneckBlock-52 [[1, 256, 64, 64]] [1, 512, 32, 32] 0
Conv2D-195 [[1, 512, 32, 32]] [1, 128, 32, 32] 65,536
BatchNorm2D-195 [[1, 128, 32, 32]] [1, 128, 32, 32] 512
ReLU-66 [[1, 512, 32, 32]] [1, 512, 32, 32] 0
Conv2D-196 [[1, 128, 32, 32]] [1, 128, 32, 32] 147,456
BatchNorm2D-196 [[1, 128, 32, 32]] [1, 128, 32, 32] 512
Conv2D-197 [[1, 128, 32, 32]] [1, 512, 32, 32] 65,536
BatchNorm2D-197 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048
BottleneckBlock-53 [[1, 512, 32, 32]] [1, 512, 32, 32] 0
Conv2D-198 [[1, 512, 32, 32]] [1, 128, 32, 32] 65,536
BatchNorm2D-198 [[1, 128, 32, 32]] [1, 128, 32, 32] 512
ReLU-67 [[1, 512, 32, 32]] [1, 512, 32, 32] 0
Conv2D-199 [[1, 128, 32, 32]] [1, 128, 32, 32] 147,456
BatchNorm2D-199 [[1, 128, 32, 32]] [1, 128, 32, 32] 512
Conv2D-200 [[1, 128, 32, 32]] [1, 512, 32, 32] 65,536
BatchNorm2D-200 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048
BottleneckBlock-54 [[1, 512, 32, 32]] [1, 512, 32, 32] 0
Conv2D-201 [[1, 512, 32, 32]] [1, 128, 32, 32] 65,536
BatchNorm2D-201 [[1, 128, 32, 32]] [1, 128, 32, 32] 512
ReLU-68 [[1, 512, 32, 32]] [1, 512, 32, 32] 0
Conv2D-202 [[1, 128, 32, 32]] [1, 128, 32, 32] 147,456
BatchNorm2D-202 [[1, 128, 32, 32]] [1, 128, 32, 32] 512
Conv2D-203 [[1, 128, 32, 32]] [1, 512, 32, 32] 65,536
BatchNorm2D-203 [[1, 512, 32, 32]] [1, 512, 32, 32] 2,048
BottleneckBlock-55 [[1, 512, 32, 32]] [1, 512, 32, 32] 0
Conv2D-205 [[1, 512, 32, 32]] [1, 256, 32, 32] 131,072
BatchNorm2D-205 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
ReLU-69 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-206 [[1, 256, 32, 32]] [1, 256, 16, 16] 589,824
BatchNorm2D-206 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
Conv2D-207 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144
BatchNorm2D-207 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096
Conv2D-204 [[1, 512, 32, 32]] [1, 1024, 16, 16] 524,288
BatchNorm2D-204 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096
BottleneckBlock-56 [[1, 512, 32, 32]] [1, 1024, 16, 16] 0
Conv2D-208 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144
BatchNorm2D-208 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
ReLU-70 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-209 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824
BatchNorm2D-209 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
Conv2D-210 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144
BatchNorm2D-210 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096
BottleneckBlock-57 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-211 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144
BatchNorm2D-211 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
ReLU-71 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-212 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824
BatchNorm2D-212 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
Conv2D-213 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144
BatchNorm2D-213 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096
BottleneckBlock-58 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-214 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144
BatchNorm2D-214 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
ReLU-72 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-215 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824
BatchNorm2D-215 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
Conv2D-216 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144
BatchNorm2D-216 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096
BottleneckBlock-59 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-217 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144
BatchNorm2D-217 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
ReLU-73 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-218 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824
BatchNorm2D-218 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
Conv2D-219 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144
BatchNorm2D-219 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096
BottleneckBlock-60 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-220 [[1, 1024, 16, 16]] [1, 256, 16, 16] 262,144
BatchNorm2D-220 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
ReLU-74 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-221 [[1, 256, 16, 16]] [1, 256, 16, 16] 589,824
BatchNorm2D-221 [[1, 256, 16, 16]] [1, 256, 16, 16] 1,024
Conv2D-222 [[1, 256, 16, 16]] [1, 1024, 16, 16] 262,144
BatchNorm2D-222 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 4,096
BottleneckBlock-61 [[1, 1024, 16, 16]] [1, 1024, 16, 16] 0
Conv2D-224 [[1, 1024, 16, 16]] [1, 512, 16, 16] 524,288
BatchNorm2D-224 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
ReLU-75 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0
Conv2D-225 [[1, 512, 16, 16]] [1, 512, 8, 8] 2,359,296
BatchNorm2D-225 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048
Conv2D-226 [[1, 512, 8, 8]] [1, 2048, 8, 8] 1,048,576
BatchNorm2D-226 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 8,192
Conv2D-223 [[1, 1024, 16, 16]] [1, 2048, 8, 8] 2,097,152
BatchNorm2D-223 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 8,192
BottleneckBlock-62 [[1, 1024, 16, 16]] [1, 2048, 8, 8] 0
Conv2D-227 [[1, 2048, 8, 8]] [1, 512, 8, 8] 1,048,576
BatchNorm2D-227 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048
ReLU-76 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0
Conv2D-228 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,359,296
BatchNorm2D-228 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048
Conv2D-229 [[1, 512, 8, 8]] [1, 2048, 8, 8] 1,048,576
BatchNorm2D-229 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 8,192
BottleneckBlock-63 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0
Conv2D-230 [[1, 2048, 8, 8]] [1, 512, 8, 8] 1,048,576
BatchNorm2D-230 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048
ReLU-77 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0
Conv2D-231 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,359,296
BatchNorm2D-231 [[1, 512, 8, 8]] [1, 512, 8, 8] 2,048
Conv2D-232 [[1, 512, 8, 8]] [1, 2048, 8, 8] 1,048,576
BatchNorm2D-232 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 8,192
BottleneckBlock-64 [[1, 2048, 8, 8]] [1, 2048, 8, 8] 0
AdaptiveAvgPool2D-5 [[1, 2048, 8, 8]] [1, 2048, 1, 1] 0
Identity-5 [[1, 2048]] [1, 2048] 0
ResNet-5 [[1, 3, 256, 256]] [1, 2048] 0
Linear-14 [[1, 2048]] [1, 4] 8,196
Linear-15 [[1, 2048]] [1, 3] 6,147
===============================================================================
Total params: 23,575,495
Trainable params: 23,469,255
Non-trainable params: 106,240
-------------------------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 341.55
Params size (MB): 89.93
Estimated Total Size (MB): 432.23
-------------------------------------------------------------------------------
{'total_params': 23575495, 'trainable_params': 23469255}
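In the Output Shape column, the spatial size shrinks from 256 to 8 before AdaptiveAvgPool2D reduces it to 1x1. The small pure-Python sketch below reproduces that arithmetic with the standard conv/pool output-size formula. It assumes the usual ResNet-50 layout (kernel sizes, strides and padding) rather than reading them from Paddle, and the helper names `conv_out` and `resnet50_spatial_sizes` are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length along one spatial axis for a conv or pool layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet50_spatial_sizes(size):
    """Spatial size after each downsampling step of a standard ResNet-50.

    Assumed layout (standard ResNet-50, not read from Paddle):
    7x7 stride-2 stem conv with padding 3, 3x3 stride-2 max pool with padding 1,
    then a 3x3 stride-2 conv with padding 1 at the start of layer2, layer3 and layer4.
    layer1 keeps the size it receives from the max pool.
    """
    sizes = [size]
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem Conv2D
    sizes.append(size)
    size = conv_out(size, kernel=3, stride=2, padding=1)  # MaxPool2D (layer1 keeps this size)
    sizes.append(size)
    for _ in range(3):  # first block of layer2, layer3, layer4
        size = conv_out(size, kernel=3, stride=2, padding=1)
        sizes.append(size)
    return sizes


print(resnet50_spatial_sizes(256))  # [256, 128, 64, 32, 16, 8]
# After the last stage, AdaptiveAvgPool2D averages the 8x8 map to 1x1,
# leaving 512 * 4 = 2048 channels (bottleneck expansion factor 4) as the feature vector.
print(512 * 4)  # 2048
```

With a 224x224 input the same arithmetic ends at 7x7 instead of 8x8, but the adaptive pooling still yields a 2048-dimensional vector, which is why `fc1` and `fc2` can use `in_features=2048` independently of the input resolution.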
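As the summary shows, the two fully connected layers Linear-14 and Linear-15 produce the two classification outputs: 4 time-period scores and 3 weather scores, both computed from the same 2048-dimensional ResNet-50 feature vector.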
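The Param # values of the two heads follow from the fully connected layer formula in_features x out_features + out_features (weights plus bias): 2048 x 4 + 4 = 8196 and 2048 x 3 + 3 = 6147. At prediction time each head is decoded on its own by taking the index of its largest score; in Paddle this would typically be `paddle.argmax(period, axis=-1)` and `paddle.argmax(weather, axis=-1)`. The sketch below does both steps in plain Python so it runs without Paddle. The helper names and the logits are made up for illustration, and mapping an index back to a label name depends on the label encoding defined in the dataset article, which is not shown here.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


FEATURE_DIM = 2048   # width of the ResNet-50 feature vector after the Identity fc
NUM_PERIODS = 4      # out_features of fc1 (time-period head)
NUM_WEATHERS = 3     # out_features of fc2 (weather head)

print(linear_params(FEATURE_DIM, NUM_PERIODS))   # 8196, matches Linear-14
print(linear_params(FEATURE_DIM, NUM_WEATHERS))  # 6147, matches Linear-15


def argmax(scores):
    """Index of the largest score (first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def decode(period_logits, weather_logits):
    """Hypothetical helper: turn one sample's two logit vectors into (period index, weather index)."""
    return argmax(period_logits), argmax(weather_logits)


# Made-up logits for a single image (illustration only, not real model output).
print(decode([0.1, 2.3, -0.5, 0.7], [1.2, -0.3, 0.4]))  # (1, 0)
```

Because softmax is monotonic, taking the argmax of the raw logits gives the same class as taking it after softmax, so no normalization is needed just to read off the predicted labels.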