网站首页 > 技术文章 正文
选自Github
作者:szagoruyko
机器之心编译
参与:赵华龙、吴攀
本项目是论文《要更加注重注意力:通过注意迁移技术提升卷积神经网络的性能(Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer)》PyTorch 实现。点击文末「阅读原文」可查阅原论文。
项目地址:https://github.com/szagoruyko/attention-transfer
这篇论文已经提交给了 ICLR 2017 会议,正在 review 状态:https://openreview.net/forum?id=Sks9_ajex
到目前为止该代码库里的内容包括:
CIFAR-10 实验的基于激活技术的 AT 代码
ImageNet 实验的代码(ResNet-18-ResNet-34 student-teacher)
即将上线:
基于梯度的 AT
场景和基于 CUB 激活的 AT 代码
预训练的基于激活的 AT ResNet-18
代码使用 PyTorch。原始的实验是用 torch-autograd 做的,我们目前已经验证了 CIFAR-10 实验结果能够完全在 PyTorch 中复现,而且目前正在针对 ImageNet 做类似的工作(由于超参数的原因,PyTorch 的结果有一点点变差)
引用:
@article{Zagoruyko2016AT, author = {Sergey Zagoruyko and Nikos Komodakis}, title = {Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer}, url = {https://arxiv.org/abs/1612.03928}, year = {2016}}
要求
先安装 PyTorch,再安装 torchnet:
git clone https://github.com/pytorch/tntcd tntpython setup.py install
安装 OpenCV 以及 Python 支持包,以及带有 OpenCV 变换的 torchvision:
git clone https://github.com/szagoruyko/visioncd vision; git checkout opencvpython setup.py install
最后,安装其他的 Python 包:
pip install -r requirements.txt
实验
CIFAR-10
这一节讲述如何得到本文中第一个表里的那些结果。
首先,训练老师:
python cifar.py --save logs/resnet_40_1_teacher --depth 40 --width 1python cifar.py --save logs/resnet_16_2_teacher --depth 16 --width 2python cifar.py --save logs/resnet_40_2_teacher --depth 40 --width 2
用基于激活的 AT 来训练:
python cifar.py --save logs/at_16_1_16_2 --teacher_id resnet_16_2_teacher --beta 1e+3
用 KD 来训练:
python cifar.py --save logs/kd_16_1_16_2 --teacher_id resnet_16_2_teacher --alpha 0.9
我们下一步计划增加带有 beta 衰退的 AT+KD 来得到最优的知识转换结果。
ImageNet
预训练模型
我们提供带有基于激活 AT 的 ResNet-18 预训练模型:
从头开始训练
下载 ResNet-34 的预训练权值(functional-zoo 里有更多介绍):
wget https://s3.amazonaws.com/pytorch/h5models/resnet-34-export.hkl
根据 fb.resnet.torch 准备数据,然后进行训练(比如使用 2 个 GPU):
python imagenet.py --imagenetpath ~/ILSVRC2012 --depth 18 --width 1 \ --teacher_params resnet-34-export.hkl --gpu_id 0,1 --ngpu 2 \ --beta 1e+3
猜你喜欢
- 2024-10-15 你认为CNN的归纳偏差,Transformer它没有吗?
- 2024-10-15 Genetic CNN: 经典NAS算法,遗传算法的标准套用 | ICCV 2017
- 2024-10-15 CB Loss:基于有效样本的类别不平衡损失
- 2024-10-15 NVIDIA Jetson Nano 2GB 系列文章(49):智能避撞之现场演示
- 2024-10-15 针对不平衡问题建模的有趣Loss 不平衡指派问题matlab
- 2024-10-15 目标检测RCNN系列总结 目标检测nms
- 2024-10-15 谷歌开源GPipe:单个加速器处理参数3.18亿,速度提升25倍
- 2024-10-15 TensorFlow 模型优化工具包:模型大小减半,精度几乎不变
- 2024-10-15 资源受限场景下的深度学习图像分类:MSDNet多尺度密集网络
- 2024-10-15 ILSVR发展简介 ilsvrc创立者
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)