问题的提出
论文:Distilling the Knowledge in a Neural Network(author:Hinton神)
问题的提出:模型的训练与部署的不一致性。训练过程中使用的模型复杂,存在着推理速度慢,对部署资源的要求较高等问题。所以说,模型压缩(保证性能前提下减少模型参数量)是一个很重要的问题,而知识蒸馏(Knowledge Distillation)就是一种压缩模型的方法。顾名思义,知识蒸馏是将已经训练好的模型中包含的知识(Knowledge)蒸馏(distlll)到另一个模型里面。
具体介绍
思想介绍
知识蒸馏这种方法一般用于分类问题上,或者是其他本质问题上是分类问题的问题。 知识蒸馏使用一种teacher-studuent网络。teacher为知识的输出者,student为知识的接受者。 teacher一般是比较复杂的模型,对于输入X,teacher能够输出Y,Y经过softmax函数得到对应类别的概率。student模型一般是参数量较小,结构简单的模型。对于输入X,student也能够输出Y,Y经过softmax函数得到对应类别的概率。 那应该怎么将teacher的知识传递给student呢? 机器学习的目的其实是要训练出在某个问题上泛化能力强的模型。而经过训练的teacher模型,已经具备较强的泛化能力。最直接的方法就是用teacher输出的结果,用于指导student的训练,让student的表现逐渐地向teacher靠拢。
下面,将讲一下为什么这种方法可以让student学习到teacher的泛化能力。 首先,要从one-hot编码开始讲起。one-hot 编码其实属于一种hard target(概率值非零即1),对应的soft target其实就是每个类别都有一定的概率值,并非那么强硬非0即1。如下图1所示。 图1. hard target 和soft target对比
而使用hard target有一定的缺陷。假设误差函数为交叉熵(如图2所示)。若以one-hot进行编码,其实对误差有贡献的只有值为1的那个类别,其他类别对误差的贡献是0,而深度学习在训练时其实是在通过用梯度下降法来最小化误差。为了最小化误差,那么模型预测的softmax结果会倾向于将那个正确类别的值向1逼近,那么其他类别的概率值会逼近于0(softmax 的特性)。如果训练数据不充分,经过反复训练,这就会导致模型在训练集过于自信,即过拟合现象。
图2. 交叉熵误差函数
针对这一问题,label smooth方法被提出来了。 方法很简单,如图3所示: 图3. label smooth
对hard target 进行一定的扰动,防止模型在训练时过度地自信,可以作为一种正则化方法。不过label smooth单纯地添加随机噪音,因此对模型的提升有限,甚至有欠拟合的风险。 现在开始聊teacher的输出。如果直接使用softmax层的输出值作为soft target,其实还是有正确类别的值向1逼近,那么其他类别的概率值会逼近于0这个问题。因此,知识蒸馏中,加入了温度T这个超参数。softmax函数变成如下图4所示:
图4.加入T的softmax
T其实反映的是训练过程中对负标签的关注程度。T越大,softmax输出的概率分布其实会更加的平滑,负标签的关注程度会被相对的放大,即在误差函数中负标签的占比被放大。 下面讲一下我对soft target的理解 soft target能考虑到类间相似性。我感觉这本质其实也是引入一种有益的归纳偏差。 经过训练的模型,soft target中,错误标签中与真实标签相似的类别的的概率会尽可能相近。比如一张图像标签是猫,那么喂入teacher后,输出的概率中,与猫相似的狮子,老虎的概率也要变大。感觉label smooth感觉是超低配版的知识,因为它是错误标签设置相同类别概率,没有考虑到这种类间的相似性。而且soft target中,包含的信息量比hard target更多(soft target负标签带有着大量的信息),所以说,用蒸馏过的知识来对student进行训练,相比于用hard target,拥有更好的泛化能力。
具体实现
第一步是训练teacher,然后在温度T下对teacher的知识进行蒸馏,用于训练student。如下图5(图源:https://intellabs.github.io/distiller/knowledge_distillation.html) 图5.知识蒸馏过程 下面讲一下student的训练误差。训练误差由两部分构成:一部分是教师输出和 student输出,二者进行温度t的蒸馏后得到soft label,进行误差计算;另一部分是student的输出在t=1蒸馏后的soft label和ground truth进行误差计算。将二者加权求和即为总体误差。然后,在实际用于推理时,要把T设为1。 第二部分误差的作用可以如下理解:老师传授的知识其实也是会有一定的错误的,同时使用ground truth(标准答案)其实可以降低被老师偶尔的错误带偏的可能。第二部分通常设置占比较小,这是经验结论。
欢迎同学们来学术沙龙。
Reference
【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 标签平滑 - Label Smoothing概述(转载)