2025-08-27
算法
00

目录

标准FFN
GLU门控线性单元
SwishGLU
SwishFFN 模块
总结

在当今的AI领域,Transformer架构无疑是王者,从自然语言处理到计算机视觉,其身影无处不在。然而,模型的性能提升往往来自于对细节的不断打磨。今天带来的是一篇由Google大脑的传奇人物Noam Shazeer发表的论文——《GLU Variants Improve Transformer》。

这篇论文提出了一个看似微小却影响深远的改动:将Transformer中标准的前馈网络(Feed-Forward Network, FFN)中的ReLU或GELU激活函数,替换为门控线性单元(Gated Linear Unit, GLU)的变体。这一改动被证明可以显著提升模型性能,并已被众多先进的大语言模型(如LLaMA、PaLM)所采纳。

标准FFN

首先来看一个标准的FFN定义:

F ⁣F ⁣N(x)=max(0,xW1+b1)W2+b2\mathit{F \! F \! N} ( x ) = \mathit{max}( 0 , x W_{1} + b_{1} ) W_{2}+b_{2}

这里的激活函数通常是ReLU,后来GELU和Swish也因其更平滑的特性而被广泛使用。这个结构通常采用“沙漏”的倒置瓶颈设计,即第一个线性层会将输入的维度dmodeld_{model}扩大4倍(到dffd_{ff}),经过激活函数后,第二个线性层再将其投影回原始维度dmodeld_{model}

一个标准的FFN代码实现:

python
class FeedForwardReLU(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): """ d_model: 输入和输出维度 (e.g., 512) d_ff: 中间隐藏层维度 (e.g., 2048) """ super().__init__() self.linear_1 = nn.Linear(d_model, d_ff) self.dropout = nn.Dropout(dropout) self.linear_2 = nn.Linear(d_ff, d_model) def forward(self, x): x = self.linear_1(x) x = F.relu(x) x = self.dropout(x) x = self.linear_2(x) return x

从代码中可以看出,一个FFN实际上是由两个nn.Linear 和 一个激活函数组成的。

GLU门控线性单元

门控线性单元(Gated Linear Unit, GLU)最早由Yann Dauphin等人在2017年的论文《Language Modeling with Gated Convolutional Networks》中提出。其核心思想是引入一个“门控”机制,动态地控制信息流。

原始GLU的计算公式为:

GLU(x)=(xW+b)σ(xV+c)G L U ( x )=( x W+b ) \otimes\sigma( x V+c )

其中σ(x)\sigma(x)表示sigmod函数。 \otimes表示逐点相乘。

SwishGLU

SwiGLU 不是一个简单的激活函数,而是一个“门控线性单元”的变体。定义如下:

F ⁣FNSwiGLU(x)=(Swish1(xW)xV)W2F \! F N_{S w i G L U} ( x )=( S w i s h_{1} ( x W ) \otimes x V ) W_{2}

注意:这里去掉了bias

实现代码:

python
# 1. 定义 SwiGLU 模块 class SwiGLU(nn.Module): """ SwiGLU Gated Linear Unit """ def __init__(self, in_features, hidden_features=None, out_features=None): super().__init__() hidden_features = hidden_features or in_features # 门控投影 self.w_gate = nn.Linear(in_features, hidden_features, bias=False) # 数据投影 self.w_proj = nn.Linear(in_features, hidden_features, bias=False) def forward(self, x): return F.silu(self.w_proj(x)) * self.w_gate(x)

SwishFFN 模块

这个就论文中提到的FFN的变体之一,也是这几个变体中表现效果最好的。

实现代码:

class FeedForwardSwiGLU(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): """ d_model: 输入维度 (e.g., 512) d_ff: 中间隐藏层维度 (e.g., 2048) """ super().__init__() # SwiGLU 模块替换了原来的 linear_1 + relu # 它的输入是 d_model,输出(门控后)是 d_ff self.swiglu_block = SwiGLU(in_features=d_model, hidden_features=d_ff) # 第二个线性层保持不变,它的输入维度需要匹配 SwiGLU 的输出维度 self.linear_2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): x = self.swiglu_block(x) # 直接调用 SwiGLU 模块 x = self.dropout(x) x = self.linear_2(x) return x

总结

这篇论文没有什么公式推理,单纯是因为作者不断的实验得出的效果,在论文结尾,作者表示:

We offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence.

效果为啥好,他也不知道,归功于神的旨意🤷。

我在测试中发现,对于浅层网络来说,看不太出优化,可能更适合深度网络的模块。

本文作者:James

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!