Training data efficient image transformers & distillation through attention
总体介绍
在ImageNet分类上,ViT是通过使用大数据集进行训练,再进行fine-tune才超过CNNs,单纯使用ImageNet数据集效果并不是非常理想。(ViT原文中说明:由于没有像卷积神经网络天然的拥有平移等效性和局部性这样的归纳偏差,需要使用大量数据进行训练才能超越CNNs) DeiT此篇工作,贡献如下: ①不引入额外大数据集,只使用ImageNet进行训练,只使用ImageNet进行训练(指标图如下所示),消耗的资源更少,取得比ViT更好的效果。 ②为Transformer引入了一种新的知识蒸馏(Knowledge Distillation)方法(我将在另一篇博客对KD进行梳理) 图1.DeiT与其他的模型指标对比,横坐标为每秒处理图片的数量,Ours没带⚗ 表示跟ViT相同,但是使用了各种数据增强和正则化;带⚗表示是使用知识蒸馏的DeiT。
关键介绍
我们知道:ViT的操作是拿class token 过模型得到的embedding来和one-hot标签计算交叉熵误差,DeiT其实就是多引入了一个跟class token完全相同的distillation token,与class token 同时过网络,得到输出后,只不过计算误差时是和teacher的label,DeiT如图2。(原始的KD方法是拿teacher 的soft label,这里文中提出使用一种hard label,先将teacher的soft label 使用argmax,然后再使用label smooth。为什么这么做?因为这里作者做了消融,发现后者指标比前者好,见图3)。再将两部分误差各设0.5权重作为总体误差。 然后作者发现了一个有趣的现象,他把两个token过网络的embedding计算余弦相似度(二者开始是0.6,最后一层是0.93),发现二者逐渐相似但不同,然后做了消融实验,把distillation token换成class token的误差计算方式,发现余弦相似度(0.999)而且模型性能没有提高。这表明distillation token其实给模型带来了一些新的东西。 另外,由于ViT并没有像CNNs有归纳偏差,其实需要大量的训练数据。在训练过程中,使用了Auto-Augment , Rand-Augment 和 random erasing等方法来提指标,然后用Mixup和Cutmix等方法来做regularization,这些也都是key ingredients。(吐槽一下:看篇DeiT我还得看这么多QAQ,真的看不动了;炼丹还有这么多手法,忒难了啊) 图2. DeiT结构图 图3. teacher标签的设计方法和embedding设计策略的指标对比,可以发现,使用作者提出的hard label 效果比通常在知识蒸馏使用的的soft label 效果好,然后使用class+distillation的效果会比单纯只使用一种效果好。
我相信,现在有一些问题你没弄明白。
①既然是知识蒸馏,那这个teacher怎么选取? 文中进行了消融实验,对比了不同teacher下,DeiT的结果,如图4。 图4. 不同teacher下,DeiT的结果
可以发现CNNs作为teacher的效果比Transformer好,文中认为可能是DeiT学习到了一些卷积的一些归纳偏差。接着,作者拿teacher、student的决策差异度进行比对,如下图5。 ConvNet做teacher的DeiT(第二行最后三列,可以看到只使用distillation;class和distillation同时使用的模型,在与convnet决策差异度上都比只使用class 的要小,这说明前者其实是学习到ConvNet的一些决策方法的,但是没办法直接说明是学习到了一些归纳偏置)) 图5. 多组模型的决策差异性对比
②有两个token,那最后模型怎么输出? 作者选取的是两个输出的sotfmax取平均,再argmax。
总结
该篇工作的重点其实是在ViT中引入了知识蒸馏的思想,不使用额外的大数据集,加速模型的训练,而且取得了比ViT更好的效果,同时也说明了data augmentation和regularization等训练tricks的重要性。