计算机系统应用教程网站

网站首页 > 技术文章 正文

利用LSTM思想来做CNN剪枝,北大提出Gate Decorator

btikc 2024-09-06 17:58:10 技术文章 13 ℃ 0 评论

选自arXiv

作者:Zhonghui You等

机器之心编译

参与:思源、一鸣

利用LSTM基本思想门控机制进行剪枝?让模型自己决定哪些卷积核可以扔。

还记得在理解 LSTM 的时候,我们会发现,它用一种门控机制记住重要的信息而遗忘不重要的信息。在此之后,很多机器学习方法都受到了门控机制的影响,包括 Highway Network 和 GRU 等等。北大的研究者同样也是,它们将门控机制加入到 CNN 剪枝中,让模型自己决定哪些滤波器不太重要,那么它们就可以删除了。

其实对滤波器进行剪枝是一种最为有效的、用于加速和压缩卷积神经网络的方法。在这篇论文中,来自北大的研究者提出了一种全局滤波器剪枝的算法,名为「门装饰器(gate decorator)」。这一算法可以通过将输出和通道方向的尺度因子(门)相乘,进而改变标准的 CNN 模块。当这种尺度因子被设0的时候,就如同移除了对应的滤波器。

研究人员使用了泰勒展开,用于估计因设定了尺度因子为 0 时对损失函数造成的影响,并用这种估计值来给全局滤波器的重要性进行打分排序。接着,研究者移除哪些不重要的滤波器。在剪枝后,研究人员将所有的尺度因子合并到原始的模块中,因此不需要引入特别的运算或架构。此外,为了提升剪枝的准确率,研究者还提出了一种迭代式的剪枝架构—— Tick-Tock。

图 1:滤波器剪枝图示。第 i 个层有4个滤波器(通道)。如果移除其中一个,对应的特征映射就会消失,而输入 i+1 层的通道也会变为3。

扩展实验说明了研究者提出的方法的效果。例如,研究人员在 ResNet-56 上达到了剪枝比例最好的 SOTA,减少了 70% 的每秒浮点运算次数,但没有带来明显的准确率降低。

在 ImageNet 上训练的 ResNet-50 上,研究者减少了 40% 的每秒浮点运算次数,且在 top-1 准确率上超过了基线模型 0.31%。在研究中使用了多种数据,包括 CIFAR-10、CIFAR-100、CUB-200、ImageNet ILSVRC-12 和 PASCAL VOC 2011。

本文的主要贡献包括两个部分:第一部分是「门装饰器」算法,用于解决 GFIR 问题。第二部分是 Tick-Tock 剪枝框架,用于提升剪枝准确率。

具体而言,研究者展示了如何将门装饰器用于批归一化操作,并将这种方法命名为门批归一化(GBN)。给定预训练模型,研究者在剪枝前将归一化模块转换成门批归一化。剪枝结束后,他们将门批归一化还原为批归一化。通过这样的方法,不需要给模型引入特殊的运算或架构。

  • 论文地址:https://arxiv.org/abs/1909.08174
  • 实现地址:https://github.com/youzhonghui/gate-decorator-pruning

门控剪枝到底怎么做

那么到底怎样使用门控机制解决全局滤波器重要性排序呢?研究者表示他们会先将 Gate Decorator 应用到批归一化机制中,然后使用一种名为 Tick-Tock 的迭代剪枝框架来获得更好的剪枝准确率,最后再采用分组剪枝(Group Pruning)技术解决待条件的剪枝问题,例如剪枝带残差连接的网络。

上面简要展示了叙述了门控剪枝三步走,后面会做一个简单的介绍,当然更详细的内容可查阅原论文。

门控批归一化

研究者将 Gate Decorator应用到批归一化中,并将该模块称之为门控批归一化(GBN),门控批归一化如下方程7所示,它和标准批归一化的不同之处在于 φ arrow的门控选择。其中 φ arrow 是 φ 的一个向量,c 是 Z_in 的通道数。

如果 φ arrow 中的元素是零,那么就表示它对应的通道被裁减了。此外,对于不使用BN 的网络,我们也可以直接将 Gate Decorator 应用到卷积运算中,从而达到门控剪枝的效果。

Tick-Tock 剪枝框架

研究者还引进了一种迭代式的剪枝框架,从而提升剪枝准确率,他们将该框架称为Tick-Tok。其中 Tick 阶段会在训练数据的子集上执行,卷积核会被设定为不可更新状态。而 Tock 阶段使用全部训练数据,并将稀疏约束 φ 添加到损失函数中。

图2:Tick-Tock剪枝框架图示。

其中 Tick 阶段主要希望能实现以下三个目标:加速剪枝过程;计算每一个滤波器的重要性分数 Θ;降低前面剪枝引起的内部协变量迁移问题。

在 Tick 阶段中,研究者会在训练数据的子集中训练一个 Epoch,我们仅允许门控 φ 和最终的线性层能更新,这样能大大降低小数据集上的过拟合风险。通过训练后,模型会根据重要性分数 Θ 排序所有的滤波器,并将不那么重要的滤波器移除。

在 Tock 阶段前,Tick 阶段能重复 T 次。Tock 阶段会微调网络以降低总体误差,这些误差可能是由于一处滤波器造成的。此外,Tock 阶段和一般的微调过程有两大不同:微调比 Tock 要训练更多的 Epoch;微调并不会给损失函数加上稀疏性约束。

分组剪枝:解决带约束的剪枝问题

ResNet 和其变体包含残差连接,也就是在两个残差块产生的特征图上执行元素级的加法。如果单独修剪每个层的滤波器,可能会导致残差连接中特征图对不齐。这可以视为一种带约束的剪枝问题,我们希望剪枝是在对齐特征图的条件下完成的。

为了解决无法对齐的问题,作者们提出了分组剪枝:将通过纯残差方式连接的 GBN 分配给同一组。纯残差连接是指在侧分支上没有卷积层的一种方式,如图3所示。

图3:组剪枝展示。同样颜色的GBN属于同一组。

每一组可以视为一个 Virtual GBN,它的所有组成卷积共享了相同的剪枝模式。并且在分组中,滤波器的重要性分数就是成员卷积分数的和。

实验设置和数据集

数据集

研究者使用了多种数据集,包括 CIFAR-10,CIFAR-100,CUB-200, ImageNet ILSVRC-12和 PASCAL VOC 2011。CIFAR-10 数据集包括了50K的训练数据和10K的测试数据。CIFAR-100和CIFAR-10相同,但有100个类别,每个类别有600张图片。CUB-200包括了将近6000张训练图片和5700张测试图片,涵盖了200种鸟类。ImageNet ILSVRC-12有128万训练图像和50K的测试图像,覆盖1000个类别。研究者还使用了PASCAL VOC 2011分割数据集和其扩展数据集SBD,它有20个类别,共8498张训练样本图片和2857张测试样本图片。

被剪枝的模型

研究者使用了三种网络架构进行剪枝:VGGNet、ResNet和FCN。所有的网络都使用SGD进行训练,权重衰减和动量超参数分别设定为10-4和0.9。

研究者使用了多种训练数据和不同的批大小对这些网络进行了训练,同时加入了一些数据增强的方法。

在剪枝阶段,研究者在每个Tick阶段剪去ResNet0.2%的滤波器,在VGG和FCN上减去1%的滤波器。在每10个Tick操作后进行一次Tock操作。

剪枝效果

表1:在 ResNet-56上,使用CIFAR-10训练的模型剪枝后的表现。基线准确率为93.1%。

表 2:在ResNet-50上,使用ImageNe训练的模型剪枝后的表现。P.Top-1、P.Top-5 分别表示 top-1和 top-5剪枝后的模型在验证集上的单中心裁剪准确率。[Top-1] ↓ 和 [Top-5] ↓分别表示剪枝后模型准确率和基线模型相比的下降情况。Global 表示这一剪枝方法是否是全局滤波器剪枝算法。

图4:VGG-16-M在CUB-200数据集上的剪枝效果。

下图5的基线模型是VGG-16-M,他在CIFAR-100上的测试准确率为73.19%。其中「shrunk」版表示将所有卷积层的通道数减半,因此将FLOPs降低到了基线模型的1/4,从头训练后它的测试准确率会降低1.98%。「pruned」版表示采用Tick-Tock框架进行剪枝的结果,它的测试准确率会降低1.3%。

如果我们从头训练「pruned」版模型,那么它的准确率能达到71.02%,相当于降低了2.17%。不过重要的是,「pruned」版模型的参数量只有「shrunk」版模型的1/3。

图5:两种网络的效果和通道数对比,它们有相同的FLOPs。

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表