网站首页 > 技术文章 正文
文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition[1], ICCV2021
下面说一下我对这篇文章的浅陋之见, 如有错误, 请多包涵指正.
文章的核心方法
如下图所示为其处理流程:
图中 X 为CNN骨干网络提取得到的feature, 其大小为 d*h*w , 为1个batch数据. 一般 d*h*w=2048*7*7 .
从图中可以看到, 有2个分支, 一个是 average pooling, 一个是 spatial pooling, 最后二者加权融合得到 residual attention .
Spatial pooling
其过程为:
这里有个 1*1 的卷积操作FC , 其大小为 C*d*1*1 , C 为类别数, 如果直接使用矩阵乘法计算, FC(X) 后的大小为 C*h*w .
但文章中的公式是将其展开为对每个空间点单独计算, 其中 $\pmb{m_i}$? 为 FC 第i 个类别的参数, 其大小为 d*1*1, 计算得到的 $s^i_j$??? 为第 i 个类别在第 j 个位置的概率, $\pmb{a^i}$???? 为第 i 个类别的特征, 其大小为 d*1 .
如果, $\pmb{m_i}$ 和 $\pmb{a^i}$ 计算就可以得到第 i 个类别的概率. 这样就可以用到每个空间点的特征, 有利于不同目标不同类别物体的分类识别.
公式中有个温度参数 T 用来控制 $s^i_j$???? 的大小, 当 T 趋于无穷时, spatial pooling 就变成了 max pooling
Average pooling
其过程为:
上式其实就是一般分类模型的做法, 全局均值池化.
Residual Attention
如下所示, 将上述2个过程进行加权融合:
其中, $\pmb{f^i}$ 大小为 d*1, $\pmb{m_i}^T \pmb{f^i}$ 为第 i 个类别的概率.
至于为什么叫 Residual Attention , 文章中的说法是:
the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.
我的理解是, 公式5形式有点像 residual 形式.
文章实验结果
多标签
如下表所示为作者对多个数据集的测试, 除了ImageNet 为单标签外, 其它都为多标签. 可以看到多标签提升还是不错的.
热力图
由于利用到了不同位置空间点的信息, 获得的 heatmap 会更加准确, 文章中给出了一张结果, 如下:
我觉得这里有个遗憾的是, 文中没有进行对比.
个人理解
关于原理
根据流程图, 结合文中作者给出的核心代码, 其基本原理就是 average pooling + max pooling.
上述代码中: y_avg 大小为 C*1, 为 average pooling ; y_max 大小为 C*1, 为 max pooling .
下面是上述代码的一个例子, y_raw 的大小为 1*3*9 , B=1, C=3, H3H, W=3:
可以看到, y_avg 刚好为 average pooling , y_max 刚好为 max pooling .
关于公式
公式中的温度参数 T 用于调整参数大小, 而给出的核心代码中, 只有T趋于无穷的情况(等价于max pooling), 对于多个 Head 的情况, T=2,3,4,5 等, 代码中是如何体现出来的?
关于效果
对于 multi-label , 使用了 spatial pooling 和 multi-head 来提高效果, 从实验结果来看, 确实有效果, 但对于单标签情况, max pooling 应该改善不大, 从实验结果上看也确实可以看到, 单标签数据集上, 最高提升了0.02个百分点.
测试代码
测试代码如下, 可以参考这里[2].
import torch
from torch import nn
class ResidualAttention(nn.Module):
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)
def forward(self, x):
y_raw = self.fc(x).flatten(2) # b, num_class, h*w
y_avg = torch.mean(y_raw, dim=2) # b, num_class
y_max = torch.max(y_raw, dim=2)[0] # b, num_class
score = y_avg + self.la * y_max
return score
if __name__ == '__main__':
channel = 4
num_class = 3
batchsize = 1
input = torch.randn(batchsize, channel, 3, 3)
resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
output = resatt(input)
print(output.shape)
引用链接
[1] Residual Attention: A Simple but Effective Method for Multi-Label Recognition: http://arxiv.org/abs/2108.02456
[2] 这里: https://github.com/xmu-xiaoma666/External-Attention-pytorch#23-Residual-Attention-Usage
猜你喜欢
- 2024-10-12 「AAAI oral」阿里北大提出新attention建模框架
- 2024-10-12 CVPR 2020 | 港中文、上交大、商汤联合提出两种轨迹预测新方法
- 2024-10-12 东北石油大学研究者提出电能质量扰动识别的新方法
- 2024-10-12 「独家解读」谷歌会议app背景模糊和替换技术解析
- 2024-10-12 如何在深度学习模型内部做特征选择?
- 2024-10-12 深度时空网络、记忆网络与特征表达学习在 CTR 预估中的应用
- 2024-10-12 揭秘 BERT 火爆背后的功臣——Attention
- 2024-10-12 MViT:性能杠杠的多尺度ViT | ICCV 2021
- 2024-10-12 CTR预估系列(5)–阿里Deep Interest Network理论
- 2024-10-12 MeMViT:记忆增强的多尺度ViT做长时视频识别
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)