今天带来的论文是2019年由Ilya Loshchilov等人发表的《Decoupled Weight Decay Regularization》,在这篇论文中提出了AdamW算法,也是目前transformer模型中主要使用的优化器算法。这篇论文之前被拒两次,但是最终还是极大影响了后来llm优化器的选择。这篇论文纠正了过去的一种正则化思想:在Adam中的梯度加入权重衰减系数等价于对损失函数做L2正则化。 实际上二者在SGD算法中是等价的,但是Adam算法中并不等价。
首先我们来看L2正则化公式:
对于一个模型的损失函数 L(w)(如均方误差或交叉熵),L2 正则化后的损失函数为:
Lreg(θ)=L(θ)+2λi∑θi2
计算其梯度:
∇Lreg(θ)=∇Lt(θ)+λθ
但是将所有权重相加比较麻烦,所以选择直接在每个参数的梯度上进行衰减【正常来讲二者是等价的】:
θt+1=θt−α∇Lreg
简化后如下:
θt+1=(1−λ)θt−α∇Lt(θt),
根据梯度下降法,二者确实等价,不用每次这样:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = criterion(output, target) + l2_lambda * l2_penalty
AdamW算法原理
先参考原始的Adam算法,算法的具体原来参考https://www.jamesblog.top/post/48
Adam 算法的更新规则是(忽略偏置修正项):
θt+1=θt−ηvt+ϵmt
其中:
- mt=β1mt−1+(1−β1)gt
- vt=β2vt−1+(1−β2)gt2
- gt=∇θL(θt)
如果我们直接在梯度上进行修改,添加权重衰减系数【传统错误的方式】:
gtwd=∇L(θt)+λθt
原来的gt梯度变成了添加衰减系数的梯度gtwd, 由mt和vt公式可知,动量和速度都受到衰减系数的影响。这是高度耦合的:
θt+1=θt−ηvt+ϵmt,mt=β1mt−1+(1−β1)gtwd
现在我们来分析耦合带来的影响:
- Adam的更新步长由 vt+ϵη决定。这个分母项 vt本意是用来根据梯度的历史平方大小来调整每个参数的学习率的:梯度大的参数,学习率变小;梯度小的参数,学习率变大。
- 当我们将权重衰减系数λ加入梯度后,如果某个权重θi的值很大(这正是我们希望通过正则化来惩罚的),那么λθi这一项也会很大。
- 对应的vt会被放大,vt在分母,导致等效学习率会变小。本来给较大权重施加的“惩罚”效果反而被削弱了,导致正则化效果变差。
L2正则化的目的是将大权重“拉回”到零。但在耦合的Adam中,这个“拉回”的力 λθi同时又不成比例地增加了vt,减小了整体的更新步长,从而削弱了包括正则化项在内的整个梯度更新的效果。
简单来说,权重越大,它受到的正则化惩罚本应越强,但Adam的自适应机制反而会因为这个大权重而降低对它的更新力度,使得正则化效果大打折扣。正则化项在“自我削弱”。
这篇论文通过解耦来优化这种削弱,提出算法AdamW,AdamW 提出的关键思想:将权重衰减从梯度更新中解耦出来,直接在参数更新时进行衰减:
θt+1=θt−ηvt+ϵmt−ηλθt
- ηvt+ϵmt 这里等效梯度用Adam原始算法的等效梯度,计算mt和vt的时候采用gt,而不是gwd. 这样mt和vt不受权重衰减系数的影响,实现了解耦。
- ηλθt 是独立的权重衰减
这样权重衰减与自适应梯度无关,能更稳定地控制参数大小。