现代 LLM 的核心架构设计其一:RMSNorm

现代 LLM 的核心架构设计其一:RMSNorm 1. 归一化的发展历程#我们在很早之前就展开过归一化的内容归一化它的基本逻辑是通过调整数据尺度来加速模型训练帮助优化从而让模型更快的收敛这里再简单复述一下一个深层网络在训练时每层的输入分布都会随着前一层参数的变化而不断变化。这意味着后一层需要持续适应前一层输出的分布漂移训练也就变得更困难甚至可能导致梯度现象。当时的研究者将这一问题归因于内部协变量偏移Internal Covariate ShiftICS并尝试通过归一化稳定中间层分布。最初的方案是 15 年的 Batch Normalization对每一层的激活值做归一化减去 batch 均值除以 batch 标准差。这个操作让网络可以使用更大的学习率、更少的 epoch 就收敛也因此迅速成了 CNN 的标配。但 Batch Normalization 的局限在于依赖 batch 维度的统计量。在序列长度变化、 batch size 不一定大的 NLP 任务中其效果会大打折扣。而且在 Transformer 这种结构中BN 的 “竖着统计”和与自注意力的 “横着注意” 在同时应用时也存在逻辑上的不协调。于是 16 年的 Layer Normalization 换了一个思路不沿 batch 维度归一化而是沿特征维度归一化。这种方式针对单个样本计算它所有特征的均值和方差然后做归一化。既不依赖 batch size也不受序列长度影响天然适合序列模型。此后LayerNorm 就成了 Transformer 的标准配置一直到今天的少部分大模型也仍在使用。因此在进入 RMSNorm 前我们再回顾一下 LN 本身的逻辑。2. LayerNorm 有什么问题#我们在之前的 Transformer Block 里也介绍过 LN先看一个标准的 LayerNorm 公式LayerNorm()−2⊙其中1∑121∑1(−)2 和 是可学习的参数维度与 相同。 是防止除零的极小常数。总结来看LN 做了三件事居中centering减去均值 让分布中心归零。缩放scaling除以标准差 把方差归一化到 1 。恢复affine transformation用 和 恢复模型需要的尺度和偏移。前两步是固定的归一化操作第三步是把自由度还给模型。因为强制把所有激活值变成零均值单位方差不一定是最优的模型可能希望某些层保留特定的分布特征。看这个结构感觉每一步都很合理。但科研往往打破常规19 年的论文 Root Mean Square Layer Normalization 提出了问题这三步里真的每一步都是必要的吗特别是第一步居中。以此我们来展开 RMSNorm 的思路3. RMSNorm 的内容#RMSNorm 提出了一个很直接的实验把 LayerNorm 的居中部分去掉只保留缩放看看效果如何。去掉之后的公式长这样RMSNorm()1∑12⊙对比一下操作LayerNormRMSNorm计算均值1∑无减去均值−无除以标准差1∑(−)21∑2可学习偏置无可学习增益很明显RMSNorm 简化了两处内容取消中心化直接使用原始输入的均方根来做归一化的分母而不是用标准差。去掉偏置参数 这是因为不做中心化后的分布中心不再是零就不需要专门用一个可学习参数去专门调整了。4. 为什么可以去掉中心化#值得一提的是这一问题的答案并不在 RMSNorm 论文本身而是在后续大模型实践中逐渐清晰的其关键在于我们之前提到的Pre-Norm 结构。我们在 Transformer Block那里展开过 Post-Norm 和 Pre-Norm 两种结构Post-Norm归一化放在子层之后LayerNorm(Sublayer())Pre-Norm归一化放在子层之前Sublayer(LayerNorm())已知 Post-Norm 是原版 Transformer 的设计但深层训练不稳定。于是 Pre-Norm 把归一化放在残差分支的入口处保证了主干的信号流通更加顺畅训练更稳定从而成为事实标准。后续大量实践表明Transformer 的训练稳定性主要依赖于尺度控制而非严格的零均值约束。所以 Pre-Norm 结构改善了梯度传播而 RMSNorm 保留了最关键的尺度归一化功能。两者结合后即使不再执行中心化操作模型依然能够稳定训练并保持性能。但毕竟是黑箱还有一种说法是Pre-Norm 结构和残差的组合可以特征值保持较小的均值让中心化不再那么必要。总之在主流认知中的共同点都有Pre-Norm的重要性但验证其效果后剩下的其实也就是怎么讲故事的问题了。我们来重点看看其效果5. RMSNorm 的表现#首先RMSNorm 论文在多个任务上做了对比。核心结论是去掉均值中心化之后效果与 LayerNorm 基本持平甚至在某些设定下略有提升。但更重要的是计算效率。RMSNorm 省去了两个操作前向不需要计算均值 均方根的计算也比方差计算少一次减法。反向不需要对 求梯度。原论文在多种网络结构上进行了实验在保持性能接近的情况下整体运行时间可减少约 7%64%。对于 Transformer 模型加速通常在 7%15% 左右。在单层上这点差异微乎其微但在 70B、上百层的模型中累积起来就是可观的训练加速。但同样要说明的是今天的大模型层数更深使用的技术更广Norm 层的占比变小实际上的收益已经不会像上面一样那么可观了。从 LLaMA 开始RMSNorm 几乎成了所有现代大模型的默认选择模型系列归一化方式LLaMA 1/2/3RMSNormMistral / MixtralRMSNormQwen 系列RMSNormGemmaRMSNormDeepSeekRMSNormGPT-NeoX / PythiaLayerNorm原始 TransformerLayerNorm