2025-08-20
研究论文
0

目录

论文核心思想
概念定义
蒸馏过程
蒸馏损失
学生损失
论文关于蒸馏的数学推导部分
论文的缺点和局限

Geoffrey Hinton、Oriol Vinyals 和 Jeff Dean 于 2015 年发表的论文《Distilling the Knowledge in a Neural Network》 是深度学习领域一篇具有里程碑意义的著作,它开创了一种强大的模型压缩和知识迁移范式——知识蒸馏(Knowledge Distillation)。该技术旨在将一个大型、复杂的“教师”神经网络模型的知识,迁移到一个更小、更轻便的“学生”模型中,使得学生模型在保持较低计算复杂度的同时,能够取得与教师模型相近甚至更好的性能。最关键的是:一个小规模的神经网络通过知识蒸馏的方式训练上限是可以突破直接使用数据集“硬标签”训练的上限。

论文核心思想

训练一个小型网络(学生模型),使其模仿一个大型网络(教师模型)的行为,而不仅仅是学习训练数据中的真实标签。

将臃肿模型(教师)学到的“知识”迁移到一个紧凑、高效的模型(学生)中,从而在保持高性能的同时,实现快速推理。

这里的关键在于,“模仿行为”比“学习标签”能提供更多的信息。一个硬标签(Hard Label)只告诉你“这张图片是猫”,而教师模型的输出(一个概率分布)会告诉你“这张图片90%是猫,5%是狗,1%是老虎...”,这种更丰富的信息被称为“暗知识”(Dark Knowledge),它揭示了模型对类别之间相似性的理解。

概念定义

教师模型 (Teacher Model): 通常是一个或一组预训练好的、性能强大的大模型。它可以是一个单一的深层网络,也可以是多个模型的集成。它的知识是待“蒸馏”的源泉。

学生模型 (Student Model): 一个规模较小、结构更简单的轻量级网络。它将是最终被部署用于推理的模型。

直接使用教师模型的标准 softmax 输出作为软目标,效果可能并不好。因为对于一个训练得很好的模型,它对正确类别的预测概率会非常接近 1,而其他错误类别的概率会非常接近 0。这样的“软目标”其实已经很“硬”了,暗知识被极大地压缩了。

为了解决这个问题,Hinton 引入了 “温度”(Temperature, T) 的概念,对标准的 softmax 函数进行了修改:

qi=exp(zi/T)jexp(zj/T)q_{i}=\cfrac{\operatorname{e x p} ( z_{i} / T )} {\sum_{j} \operatorname{e x p} ( z_{j} / T )}

其中i表示类别,T表示温度,温度越高,概率分布越平滑,T=1就是标准大的softmax函数。

蒸馏过程

学生模型的训练目标是双重的,因此其损失函数由两部分组成: 1.蒸馏损失 2.学生损失

总的损失便是在这个基础上添加一个权重系数α\alpha 即:

L=αLsoft+(1α)LhardL=\alpha\cdot L_{s o f t}+\left( 1-\alpha\right) \cdot L_{h a r d}

简单实现:https://colab.research.google.com/drive/1ig8BS1RniVHhcYf32ObiA4LFGqzSpjYs?usp=sharing

蒸馏损失

目的是让学生模型的“软预测”尽可能地接近教师模型的“软目标”。计算方式:

  1. 用一个较高的温度T计算教师模型的软目标 pTp^T
  2. 再用相同温度计算学生模型的软目标qTq^T
  3. 计算两个软目标的交叉损失熵得到蒸馏损失
Lsoft=ipiTlog(qiT)L_{s o f t}=-\sum_{i} p_{i}^{T} \operatorname{l o g} ( q_{i}^{T} )

学生损失

学生模型自身也要能准确预测真实标签。这部分确保学生不会被一个可能犯错的老师“带偏”,同时也能从真实数据中学习。

用标准的softmax计算学生的预测分布 q1q^{1}% 【:表示温度为1】,然后计算与真实标签的交叉损失熵:

Lhard=icilog(qi1)L_{h a r d}=-\sum_{i} c_{i} \operatorname{l o g} ( q_{i}^{1} )

论文关于蒸馏的数学推导部分

这一部分内容实际上是最能体现论文理论基础的部分,理论上证明了Caruana等人早期工作中使用的“匹配 logits”方法,实际上是 Hinton 提出的知识蒸馏方法在高温极限下的一个特例。这为知识蒸馏提供了更坚实的理论基础,并将其与先前的工作联系起来。

这部分一共3个公式:

Czi=1T(qipi)=1T(ezi/Tjezj/Tevi/Tjevj/T)\frac{\partial C} {\partial z_{i}}=\frac{1} {T} \left( q_{i}-p_{i} \right)=\frac{1} {T} \left( \frac{e^{z_{i} / T}} {\sum_{j} e^{z_{j} / T}}-\frac{e^{v_{i} / T}} {\sum_{j} e^{v_{j} / T}} \right)

第一个公式是推导出蒸馏损失的梯度,具体推导可以参考 交叉熵损失函数梯度推导

Czi1T(1+zi/TN+jzj/T1+vi/TN+jvj/T)\frac{\partial C} {\partial z_{i}} \approx\frac{1} {T} \left( \frac{1+z_{i} / T} {N+\sum_{j} z_{j} / T}-\frac{1+v_{i} / T} {N+\sum_{j} v_{j} / T} \right)

第二个公式,利用了T>T-> \infin的时候的近似,利用ex=1+xe^x = 1 + x 泰勒展开,忽略高阶无穷小。

Czi1NT2(zivi)\frac{\partial C} {\partial z_{i}} \approx\frac{1} {N T^{2}} \left( z_{i}-v_{i} \right)

第三个公式,基于两个模型在softmax之前,对logit进行norm操作,即均值为0(即:jzj=jvj=0\sum_{j} z_{j}=\sum_{j} v_{j}=0),然后推导出来的。

我们对比Caruana等人的蒸馏损失梯度计算:

LMSE=12i(zivi)2L_{M S E}=\frac{1} {2} \sum_{i} ( z_{i}-v_{i} )^{2}
LMSEzi=zivi\cfrac{\partial L_{M S E}} {\partial z_{i}}=z_{i}-v_{i}

经过对比发现 :Hinton等人提出的蒸馏损失梯度只是多了一个常数缩放因子(1NT2\frac{1} {N T^{2}})不影响优化的方向(只影响学习率的有效大小),因此这两种优化目标是等价的。

总而言之,这段推导精辟地论证了,知识蒸馏是一种更通用、更强大的知识转移框架,而直接匹配 logits 只是这个框架在线性近似下的一种特殊情况。 它告诉我们,通过调节温度 T,我们可以在“硬标签”的精确性和“logits 匹配”的线性关系之间进行权衡,找到一个最佳的平衡点来传递教师模型中最有价值的知识。

最后为了平衡软标签损失梯度和硬标签损失梯度,除了利用α\alpha权重系数参数调整,还需要将软标签的梯度乘以T2T^2,保证二者梯度在一个数量级,如果不乘以 T2T^2,在温度 T 较高时,软标签产生的梯度会变得非常小,导致模型几乎只从硬标签学习,知识蒸馏就失效了,至于为什么是T2T^2,仔细推导第三个公式就能明白。

论文的缺点和局限

超参数敏感: 温度 T 和损失函数的权重等超参数的选择对蒸馏效果有较大影响,需要经验性地调整。

本文作者:James

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!