Adam(Adaptive Moment Estimation)是一种用于训练深度学习模型的优化算法。它在 2014 年由 Diederik P. Kingma 和 Jimmy Ba 提出,并迅速成为深度学习领域最流行和最常用的优化器之一。
Intuition 理解
举一个例子:
想象一个球从山坡上滚下来。在重力的作用下,它会越滚越快,并且由于惯性(动量),它会保持当前的前进方向。即使遇到一些小的颠簸(梯度方向的轻微改变),它也能凭借惯性冲过去,而不是轻易改变方向。
Momentum 算法就是借鉴了这一思想。它不仅仅考虑当前参数的梯度方向,还引入了一个“动量”项,这个动量是过去所有梯度方向的指数加权移动平均值。
这样做的好处是:
存在的问题:
因为要为每一个梯度都维持额外的参数状态,会导致显存随着模型急剧增大。尤其是大模型训练,优化器维护的状态会占据大量的显存。
移动平均(Exponentially Moving Average, EMA)
移动平均,又叫指数平滑。可以用来做一些简单的预测,其核心公式如下:
yt=βyt−1+(1−β)xt,y0=0
将公式看作一个预测系统, y是系统预测输出,x是观察输入,t表示时间step步数。β∈[0~1]参数表示对上一次系统输入的信任度,β越大,越信任上一次的输出yt−1,反之越信任本次的观测xt。
将公式递推展开:
yt=βyt−1+(1−β)xt=(1−β)(βt−1x1+…+β1xt−1+β0xt)
可以观察到移动平均的另一个特性,β0>βi>βt−1 [i∈(0,t−1)]: 该系统更重视离t更近的观测值,更信任最近的观测数据,时间越久,重要性呈现指数衰减。
Adam算法原理
核心思想: Adam 结合了两种经典优化算法的优点:
动量(Momentum): 借鉴了物理学中动量的概念,通过累积过去的梯度来加速收敛,并减少在梯度方向变化剧烈时的震荡。
RMSprop(Root Mean Square Propagation): 为每个参数独立地调整学习率。它通过使用梯度的平方的移动平均值来缩放学习率,使得在梯度较大的方向上学习率减小,在梯度较小的方向上学习率增大。
Adam 通过计算梯度的一阶矩估计(First Moment Estimate,即动量项) 和二阶矩估计(Second Moment Estimate,即 RMSprop 项),并对它们进行偏差校正,从而为每个参数设计出独立的、自适应的学习率。
Adam 算法的详细步骤
假设我们要优化的参数是 θ,损失函数是 J(θ)。算法的执行过程如下:
1. 初始化参数
- 学习率(Learning Rate):α (通常建议值为 0.001)
- 一阶矩估计的指数衰减率:β1 (通常建议值为 0.9)
- 二阶矩估计的指数衰减率:β2 (通常建议值为 0.999)
- 数值稳定性常数:ϵ (通常建议值为 10−8)
- 参数:θ0 (随机初始化)
- 一阶矩向量:m0=0
- 二阶矩向量:v0=0
- 时间步(Time step):t=0
2. 训练循环
在每一次迭代(时间步 t)中,执行以下计算:
步骤一:计算当前梯度
首先,计算损失函数 J(θ) 关于当前参数 θt−1 的梯度:
gt=∇θJ(θt−1)
步骤二:更新有偏一阶矩估计(动量项)
计算梯度 gt 的指数移动平均值,这被称为一阶矩估计 mt。它保留了过去梯度的方向信息。
mt=β1⋅mt−1+(1−β1)⋅gt
这里的 mt 是对梯度均值(一阶矩)的估计。
步骤三:更新有偏二阶矩估计(RMSprop 项)
计算梯度平方 gt2 的指数移动平均值,这被称为二阶矩估计 vt。它保留了过去梯度的大小信息。注意,gt2 是逐元素(element-wise)的平方。
vt=β2⋅vt−1+(1−β2)⋅gt2
这里的 vt 是对梯度未中心化方差(二阶矩)的估计。
步骤四:计算偏差校正后的一阶矩估计
由于 m0 和 v0 初始化为 0,在训练初期,mt 和 vt 会偏向于 0。为了修正这种偏差,Adam 引入了偏差校正(Bias Correction)。
m^t=1−β1tmt
当 t 很小时,(1−β1t) 也很小,这会放大 mt 的值,从而缓解其偏向 0 的问题。随着 t 的增大,(1−β1t) 趋近于 1,偏差校正的作用会逐渐消失。
步骤五:计算偏差校正后的二阶矩估计
同样地,对二阶矩估计 vt 进行偏差校正:
v^t=1−β2tvt
步骤六:更新模型参数
最后,使用经过校正的一阶和二阶矩估计来更新参数 θ。
θt=θt−1−α⋅v^t+ϵm^t
- α:全局学习率,控制整体的更新步长。
- m^t:提供了梯度的方向和大小,类似于动量。
- v^t:作为分母,为每个参数提供了自适应的学习率。
- 如果某个参数的梯度历史值一直很大(v^t 很大),那么其实际学习率会变小,从而减缓更新,防止步子迈得太大。
- 如果某个参数的梯度历史值一直很小(v^t 很小),那么其实际学习率会变大,从而加速更新。
- ϵ:一个非常小的数,用于防止分母为零,保证数值稳定性。设为1e-7或1e-8即可。