关于transformer-xl中rel-shift实现的解读

背景

transformer-xl中有一步使用相对位置计算attention weight:

$\mathbf{A}_{i, j}^{\mathrm{rel}}=\underbrace{\mathbf{E}_{x_{i}}^{\top} \mathbf{W}_{q}^{\top} \mathbf{W}_{k, E} \mathbf{E}_{x_{j}}}_{(a)}+\underbrace{\mathbf{E}_{x_{i}}^{\top} \mathbf{W}_{q}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{i-j}}_{(b)}+\underbrace{u^{\top} \mathbf{W}_{k, E} \mathbf{E}_{x_{j}}}_{(c)}+\underbrace{v^{\top} \mathbf{W}_{k, R} \mathbf{R}_{i-j}}_{(d)}$

由于相对位置要计算所有的query与key对,因此是平方的复杂度。而在论文的附录中提到可以通过简单的推导将复杂度降为线性。
简单地说,我们希望获得:
$\mathbf{B} = \left[ \begin{array}{cccccc}{q_{0}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{M}} & {\cdots} & {q_{0}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{0}} & {0} & {\cdots} & {0} \\ {q_{1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{M+1}} & {\cdots} & {q_{1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{1}} & {q_{1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{0}} & {\cdots} & {0} \\ {\vdots} & {\vdots} & {\vdots} & {\vdots} & {\ddots} & {\vdots} \\ {q_{L-1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{M+L-1}} & {\cdots} & {q_{L-1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{M+L-1}} & {q_{L-1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{L-1}} & {\cdots} & {q_{L-1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{0}}\end{array}\right] \\ = \left[ \begin{array}{cccccc}{q_{0}^{\top} \mathbf{Q}_{L-1}} & {\cdots} & {q_{0}^{\top} \mathbf{Q}_{M+L-1}} & {0} & {\cdots} & {0} \\ {q_{1}^{\top} \mathbf{Q}_{L-2}} & {\cdots} & {q_{1}^{\top} \mathbf{Q}_{M+L-2}} & {q_{1}^{\top} \mathbf{Q}_{M+L-1}} & {\cdots} & {0} \\ {\vdots} & {\vdots} & {\ddots} & {\vdots} & {\ddots} & {\vdots} \\ {q_{L-1}^{\top} \mathbf{Q}_{0}} & {\cdots} & {q_{L-1}^{\top} \mathbf{Q}_{M}} & {q_{L-1}^{\top} \mathbf{Q}_{M+1}} & {\cdots} & {q_{L-1}^{\top} \mathbf{Q}_{M+L-1}}\end{array}\right]$

其中:
$\mathbf{Q} :=\left[ \begin{array}{c}{\mathbf{R}_{M+L-1}^{\top}} \\ {\mathbf{R}_{M+L-2}^{\top}} \\ {\vdots} \\ {\mathbf{R}_{1}^{\top}} \\ {\mathbf{R}_{0}^{\top}}\end{array}\right] \mathbf{W}_{k, R}^{\top}=\left[ \begin{array}{c}{\left[\mathbf{W}_{k, R} \mathbf{R}_{M+L-1}\right]^{\top}} \\ {\vdots} \\ {\vdots} \\ {\left[\mathbf{W}_{k, R} \mathbf{R}_{1}\right]^{\top}} \\ {\left[\mathbf{W}_{k, R} \mathbf{R}_{0}\right]^{\top}}\end{array}\right] \in \mathbb{R}^{(M+L) \times d}$

而我们可以直接获得的是:
$\tilde{\mathbf{B}}=\mathbf{q} \mathbf{Q}^{\top}=\left[ \begin{array}{cccccc}{q_{0}^{\top} \mathbf{Q}_{0}} & {\cdots} & {q_{0}^{\top} \mathbf{Q}_{M}} & {q_{0}^{\top} \mathbf{Q}_{M+1}} & {\cdots} & {q_{0}^{\top} \mathbf{Q}_{M+L-1}} \\ {q_{1}^{\top} \mathbf{Q}_{0}} & {\cdots} & {q_{1}^{\top} \mathbf{Q}_{M}} & {q_{1}^{\top} \mathbf{Q}_{M+1}} & {\cdots} & {q_{1}^{\top} \mathbf{Q}_{M+L-1}} \\ {\vdots} & {\vdots} & {\ddots} & {\vdots} & {\ddots} & {\vdots} \\ {q_{L-1}^{\top} \mathbf{Q}_{0}} & {\cdots} & {q_{L-1}^{\top} \mathbf{Q}_{M}} & {q_{L-1}^{\top} \mathbf{Q}_{M+1}} & {\cdots} & {q_{L-1}^{\top} \mathbf{Q}_{M+L-1}}\end{array}\right]
$

$\tilde{\mathbf{B}}$与$\mathbf{B}$的区别在于$\mathbf{B}$是$\tilde{\mathbf{B}}$的left-shifted版本,其中第一行左移了L-1,后面每行依次递减左移个数,最后一行则不左移。

方法

抽象地看,我们要做的事情就是,给定一个矩阵,每行都进行左移,而移动的个数随行数递增而递减。

我目前想到的一种方法是使用gather,将想要的index提前定好,然后使用Pytorch的gather就能够实现。

而transformer-xl实现了另一种更好的方法:_rel_shift

1
2
3
4
5
6
7
8
9
10
11
def _rel_shift(self, x, zero_triu=False):
# x: q,k,bs,n_head
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=1)

x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

x = x_padded[1:].view_as(x)

return x

第一步是,将x的第一列填上padding,此时x.size()=q,k+1,bs,n_head,接下来将其重新reshape,则变成了x.size()=k+1,q,bs,n_head,最后将第一行去掉,变成x.size()=k,q,bs,n_head,再将其reshape回x原来的样子。

为什么这么做实现了我们想要的左移的功能?我们应该从一维的角度去理解。因为实际上在内存中所有元素都是按照一维去排列的。

原来的矩阵:

实际上就是有q个key按照一行去排列。

在做完padding之后,则:

实际上就是在每个key前面插入了0。

接下来view,实际上数据的先后顺序还是没有变(因为不是transpose):

实际上只是强行将该行切成一个一个q而已。

那么最后一个操作,将第一行丢掉,实际上就是要把原来的x的第一行强行左移q-1个(因为有padding)。那么为什么后面的行能够左移的个数依次减少?别忘了padding,第一行左移了q-1个,但第二个key前面也有一个padding,所以相当于将其向右推了一格;第三个又有一个padding,就在原来的基础上又推了一格,也即推了两格。因此最后达到了我们想要的目的。

实际上要理解该方法,需要牢牢把握数据存储的本质是一整行。

该方法没有数据的拷贝,全部都是view操作,因此更高效。

不得不佩服想到该方法的人的工程能力,同时也感谢戴宁带我理解该方法的本质,一开始我是死活不理解的。以后或许可以将该思想灵活应用到其他方面。