在当今的AI领域,Transformer架构无疑是王者,从自然语言处理到计算机视觉,其身影无处不在。然而,模型的性能提升往往来自于对细节的不断打磨。今天带来的是一篇由Google大脑的传奇人物Noam Shazeer发表的论文——《GLU Variants Improve Transformer》。
这篇论文提出了一个看似微小却影响深远的改动:将Transformer中标准的前馈网络(Feed-Forward Network, FFN)中的ReLU或GELU激活函数,替换为门控线性单元(Gated Linear Unit, GLU)的变体。这一改动被证明可以显著提升模型性能,并已被众多先进的大语言模型(如LLaMA、PaLM)所采纳。
首先来看一个标准的FFN定义:
这里的激活函数通常是ReLU,后来GELU和Swish也因其更平滑的特性而被广泛使用。这个结构通常采用“沙漏”的倒置瓶颈设计,即第一个线性层会将输入的维度扩大4倍(到),经过激活函数后,第二个线性层再将其投影回原始维度
一个标准的FFN代码实现:
pythonclass FeedForwardReLU(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
"""
d_model: 输入和输出维度 (e.g., 512)
d_ff: 中间隐藏层维度 (e.g., 2048)
"""
super().__init__()
self.linear_1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model)
def forward(self, x):
x = self.linear_1(x)
x = F.relu(x)
x = self.dropout(x)
x = self.linear_2(x)
return x
从代码中可以看出,一个FFN实际上是由两个nn.Linear 和 一个激活函数组成的。
门控线性单元(Gated Linear Unit, GLU)最早由Yann Dauphin等人在2017年的论文《Language Modeling with Gated Convolutional Networks》中提出。其核心思想是引入一个“门控”机制,动态地控制信息流。
原始GLU的计算公式为:
其中表示sigmod函数。 表示逐点相乘。
SwiGLU 不是一个简单的激活函数,而是一个“门控线性单元”的变体。定义如下:
注意:这里去掉了bias
实现代码:
python# 1. 定义 SwiGLU 模块
class SwiGLU(nn.Module):
""" SwiGLU Gated Linear Unit """
def __init__(self, in_features, hidden_features=None, out_features=None):
super().__init__()
hidden_features = hidden_features or in_features
# 门控投影
self.w_gate = nn.Linear(in_features, hidden_features, bias=False)
# 数据投影
self.w_proj = nn.Linear(in_features, hidden_features, bias=False)
def forward(self, x):
return F.silu(self.w_proj(x)) * self.w_gate(x)
这个就论文中提到的FFN的变体之一,也是这几个变体中表现效果最好的。
实现代码:
class FeedForwardSwiGLU(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): """ d_model: 输入维度 (e.g., 512) d_ff: 中间隐藏层维度 (e.g., 2048) """ super().__init__() # SwiGLU 模块替换了原来的 linear_1 + relu # 它的输入是 d_model,输出(门控后)是 d_ff self.swiglu_block = SwiGLU(in_features=d_model, hidden_features=d_ff) # 第二个线性层保持不变,它的输入维度需要匹配 SwiGLU 的输出维度 self.linear_2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): x = self.swiglu_block(x) # 直接调用 SwiGLU 模块 x = self.dropout(x) x = self.linear_2(x) return x
这篇论文没有什么公式推理,单纯是因为作者不断的实验得出的效果,在论文结尾,作者表示:
We offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence.
效果为啥好,他也不知道,归功于神的旨意🤷。
我在测试中发现,对于浅层网络来说,看不太出优化,可能更适合深度网络的模块。
本文作者:James
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!