每周论文39

本周论文:

  1. Understanding the Difficulty of Training Transformers

[Understanding the Difficulty of Training Transformers]

这篇论文很有意思并且很有启发性。本文的核心观点是:训练深层Transformer的难点不在于梯度消失(gradient vanishing),而是在训练初期对residual branch(非线性那部分)的依赖过重,导致小的扰动会带来很大的变化,从而引起模型的不稳定。作者提出Admin,一种两段式的初始化方法来解决此问题。同时,作者认为post-norm比pre-norm潜力更大,并将Aadmin应用于post-norm的Transformer上,并取得明显的效果。

背景

Post-norm:对于每层layer而言,先做非线性(即multihead和ffn),与残差相加后再做normalization;
Pre-norm:先对输入进行normalization,再做非线性,最后与残差相加。

导致深层Transformer无法训动的原因

过去解决训练深层Transformer困难的一大方法就是使用pre-norm。作者先对比pre-norm与post-norm在不同超参下的结果。

可以看到,pre-norm在所有的超参下都能够收敛,而post-norm则在15次中出现了7次不收敛的情况。这说明pre-norm比post-norm更鲁棒。

但是,要注意到,当post-norm收敛时,基本上比相同超参下的pre-norm要好。这说明了post-norm的潜力比pre-norm强。

接着,作者以主流的梯度消失的观点做了一些统计,具体来说,统计了pre-norm和post-norm的encoder和decoder的每层的梯度norm的相对值 $\frac{\left|\Delta \mathbf{x}_{i}^{(\cdot)}\right|_{2}}{\max _{j}\left|\Delta \mathbf{x}_{j}^{(\cdot)}\right|_{2}}$。

上图可以得到几个结论:

  • pre-norm的encoder和decoder都没有梯度消失的问题;post-norm的encoder没有梯度消失问题,而decoder则有问题。
  • (上图的下半部分要从右往左看)当decoder的梯度经过encoder-attention时,会有明显的减少。这说明了梯度消失是由encoder-attention(即cross-attention造成的)。

既然发现了post-norm的decoder有梯度消失的问题,那是否可以仅解决decoder的问题,也即可以将post-norm的decoder换成pre-norm的decoder,构造一个混合模型。

上图可以看到,仅仅解决梯度消失的问题是不够的。因此存在其他原因导致模型训练不稳定。

放大效应(Amplification effect)是造成模型训练不稳定的原因

引入一些notation:
记$\widehat{\mathbf{a}}_{i}=\frac{\mathbf{a}_{i}}{\sqrt{\operatorname{Var} \mathbf{a}_{i}}}$ 是第i层的residual branch(residual branch这里指的是ffn和multihead)的normalized输出;记$\widehat{\mathbf{x}}_{i}=\frac{\mathbf{x}_{i}}{\sqrt{\operatorname{Var} \mathbf{x}_{i}}}$是第i层的总的normalized输出。我们可以将$\widehat{\mathbf{x}}$看做是所有过去$\widehat{\mathbf{a}}$的加权平均(很好理解,因为每个$\widehat{\mathbf{a}}$都输入到下一层了)。那么我们就有:

$\beta_{i, j} $代表了$\widehat{\mathbf{x}}$对第j层的依赖程度。

我们可以将pre-norm和post-norm换个角度看,如图所示。

可以看到,在pre-norm中,每个residual output只被norm一次;而post-norm则不止一次,并且,越远的层被norm的次数将会越多,因此对于层i的输出来说,它对近的层的依赖更强。也即,我们可以将norm看做是对residual branch的输出进行一次赋权。对于post-norm而言,他对更近的层的residual branch更倚重。

作者不仅从公式中推导出这一结论,还对$\beta$做了一些统计。

可以发现:pre-norm从初始化开始到训练结束,每一层对底层的依赖都比较平均,虽然在训练过程中也会逐渐加强对近的层的依赖,但相对没那么严重。而post-norm从训练开始对近层的依赖就很严重,而随着训练过程的进行,变得越来越严重。

因此,我们可以认为,pre-norm实质上是一种正则化方法

  • 对于pre-norm而言,模型被限制过重依赖近层的输出,但因此也限制了模型的表示空间和表示能力。从公式上我们也可以看出,$\beta_{i, i}=\frac{\sqrt{\operatorname{Var}\left[\mathbf{a}_{i}\right]}}{\sqrt{\operatorname{Var}\left[\mathbf{x}_{i-1}^{(p \cdot)}+\mathbf{a}_{i}\right]}}$代表了当前层输出对自己的residual branch输出的依赖程度,由于$\operatorname{Var}\left[\mathbf{x}_{i}^{(p \cdot)}\right]$会随着层数加深而逐渐变大(很好理解,因为过了那么多非线性以及各种相加的操作,层i相比底层的输出的方差更大),则$\beta_{i, i}$在越高层则会变得越小,也即pre-norm限制了对自身residual branch输出的依赖程度。
  • 由于pre-norm限制了模型不能太过于依赖近的branch,当层数非常深时,第i层和第i+1层的输出很可能没太大区别(因为内容被前面层填满了)。这也是为什么pre-norm模型再加深反而性能会下降的原因。同时,这也解释了,为什么post-norm的潜力大于pre-norm。

作者从这开始,从理论和统计都证明了,对于pre-norm而言,输入/参数扰动造成的影响是O(logn)级别的,而post-norm则是O(n)级别的。其中n表示模型深度。(具体证明见论文)

从图中可以看到,pre-norm和post-norm面对扰动造成的输出变化是基本符合推论的。

作者同时还证明了,随着训练的进行,这种扰动造成的输出偏移会逐渐变小。(证明见论文)

根据这一结论以及上述结论,作者提出一种两段式的初始化方法,其目的是为了使post-norm模型在训练初期稳定,而在训练达到稳定后,再使用普通post-norm使其能够释放潜能。

方法

  • 添加一个新的参数$\omega$,原公式变为$\mathbf{x}_{i}=f_{\mathrm{LN}}\left(\mathbf{x}_{i-1} \cdot \omega_{i}+f_{i}\left(\mathbf{x}_{i-1}\right)\right)$
  • 两段式训练:Profiling-Initialization
    • Profiling:先初始化$\omega=1$,也即退化成一个普通的post-norm。该阶段不训练,只记录输出的方差。
    • Initialization:在记录后,设$\boldsymbol{\omega}_{i}=\sqrt{\sum_{j<i} \operatorname{Var}\left[f_{j}\left(\mathbf{x}_{j-1}\right)\right]}$,然后从头开始训练。
    • 当模型训练稳定后,去掉$\omega$,以释放post-norm的全部潜能。

实验

以上表格说明,相对post-norm,Admin能够获得稳定的提升。

同时,作者还统计了使用Admin后,$\beta$的变化。

可以看到,使用Admin后,模型在一开始没那么依赖近的branch,因此能够稳步训练;而在训练过程中则逐渐加深对近的branch的依赖。

结论与思考

过去传统的梯度消失问题并不是真正带来训练深层Transformer困难的原因,模型输出对扰动过于敏感可能才是真正的问题。实际上过去的工作虽然声称是解决了梯度消失的问题,但他们真正解决的是扰动敏感的问题(我会稍后写一篇新的博客来讨论过去的工作是怎么不小心/隐式地解决了这个问题的),只不过梯度消失是扰动敏感这一问题的表现形式之一罢了,他们解决了扰动敏感,实验或统计表现出来的则是梯度消失被解决了(所以这里面倒是一个表象与本质之间的关系问题)。

pre-norm虽好,但带来了限制,影响了Transformer的模型表示能力。而我们一直认为不好的post-norm相反则有更强的能力,只不过我们以前一直没有训动而已。

pre-norm与post-norm之外是否能够再找到一个新的变体?作者虽然解决/缓解了训练深层Transformer的问题,但这种方法仍然不够简洁。我们是否能够end2end地解决这个问题?或者我们是否能提出一个新的模型改动使得这个问题在新的架构下不存在?

作者能够打破过去主流的梯度消失的观点,并且提出了自己的观点。论文循序渐进,逻辑严谨,行文如行云流水,不仅有理论推导,还有定量统计,最后再辅以实验验证,使人信服。确实是不可多得的好论文,刷新了我的认知。