关于fp16混合精度的一点感受和debug经历

严格来说,这应该是关于在fairseq使用混合精度的一点感受。

前两天因为需求需要在fairseq上跑一个transformer large,自己的卡(XP和1080ti和2080ti)又不算很支持(按理说XP和2080ti都有tensor core,应该是可以支持混合精度加速的,但驱动死活装不上)又不大行,恰巧正好借到几张V100,感受了一下,发现fp16确实真的相当快。训练70个epoch只需要1天时间,而使用fp32则需要两天多。

但是baseline跑完了,我需要在transformer的基础上加点东西,具体来说就是在attention上随机mask掉一些tensor,然后重新归一化。

1
2
3
# attn_weight is fp16
attn_weight = attn_weight*mask # attn: bs*n_head,q_len,k_len
attn_weight = attn_weight/(attn_weight.sum(dim=-1,keepdim=True)+1e-10)

只要加了这一个,跑过几个step就一定会出现overflow。

| WARNING: overflow detected, setting loss scale to 0.01
Minimum loss scale reached (0.0001). Your loss is probably exploding. Try lowering the learning rate, using gradient clipping or increasing the batch size.

第一时间搜索issue,发现还是有不少人遇到这个情况的,而也有官方给出建议

这里的解决方案其实就是增加训练的稳定性,包括对overflow的容忍度提升;增加batch或者减小lr使得模型的loss不爆掉。

我用上了这几个方法后虽然爆掉的时间点推后了,但还是爆了。

这个时候旁边的人说,你既然用了fp16,为啥要加1e-10啊,1e-10是不是已经超过了fp16的表达范围了?我查了一下,还真是。

那我就改成1e-6吧。

好像又将模型爆掉的时间点推后了一点点,可是还是爆了。

我尝试关掉这个归一化,诶,好像就稳定了。看来问题就锁定在归一化这条语句上。

我想,是不是fp16有啥特殊写法啊,那我查一下吧。网上的写法多是使用apex的,看了一下,fairseq都是自己写的混合精度训练。没办法,我只好参考fairseq自己是怎么写兼容混合精度训练的代码了。

我翻了一下,无意中在multihead类看到这句:

1
2
3
4
5
6
7

attn_weights = attn_weights.float().masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
).type_as(attn_weights) # FP16 support: cast to float and back

attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

看到注释,诶,原来这里是要先升精度计算完softmax再降精度啊。是不是因为我的就是因为没做这个操作才不work的?

那我改:

1
2
3
4
attn_weight_float=attn_weight.float()
attn_weight_float = attn_weight_float / (attn_weight_float.sum(dim=-1, keepdim=True))

return attn_weight_float.type_as(attn_weight)

我全程计算用fp32,在return的时候再降到fp16。这回没问题了吧。

有。跑了几个回合后最终还是爆掉了。

我想了一下(我对浮点数这块不会算),有没有可能是float在除以另一个float的时候导致尾数很多,而在最后被cast到fp16的时候就出现不稳定的情况,比如有些很小的数干脆就被cast成0了,导致数值的不稳定?

恰好这天郭博来了,就顺口问了一下。他告诉我,fp16确实要特殊写法,比如升降精度这块就得自己把握,但具体要啥时候升啥时候降?不知道。出了问题只能一条一条debug,一点一点尝试,靠经验。

好吧,这至少说明我所想的升降精度还是没错的。

我想了一下,那干脆这样,在做mask的时候,我用fp16,走到归一化这关,我就全程用fp16好了,这样sum出来的也是fp16,不用经历cast的过程。

1
2
3
4
5
attn_weight_float = attn_weight.float()
attn_weight_float = attn_weight_float * mask
attn_weight = attn_weight_float.type_as(attn_weight)
scale = attn_weight.sum(dim=-1, keepdim=True) # fp16
attn_weight = attn_weight / scale

诶,这样好像就没问题了!跑了一晚上,第二天还是没出错。看来被我蒙对了。

而虽然找到了对的方法,但我还想优化。既然进来的时候weight就已经fp16了,为啥我在得到mask之后不直接把mask转换成fp16,然后和attn_weight相乘,fp16之间的相乘应该更快才是。

所以改进版就应该是:

1
2
3
attn_weight = attn_weight * mask.type_as(attn_weight)
scale = attn_weight.sum(dim=-1, keepdim=True) # fp16
attn_weight = attn_weight / scale

emmm,很不巧,我感觉这个方法应该没任何问题的,但还是爆了。或许是矩阵乘法这边又出了啥问题吧。但此时我已经心累了,那就还是按照之前找到的方法来吧。

记录一下debug艰辛过程的思路:

从30日改到1日,横跨一年,终于debug完成。


说了那么多,总结一下:

  1. 当没有改动直接用fairseq跑,出现overflow的问题,建议看看上面提到的issue的解决方案,很多时候loss爆掉就是lr太大了,这个一般可以通过减小lr和增大batch来解决。
  2. 如果是自己加东西或者写模型,要非常注意加减乘除这些操作,要注意范围,比如不要加1e-10这种明显超过表示范围的。在写之前想一下有没有可能性会导致cast的过程中出现数值不稳定;同时要记得升降精度这个操作。关于混合精度计算,有几个operator是需要升降精度的,这个在NVIDIA官方文档中有。
  3. 如果出现loss爆掉,应该逐行排查,其实就是逐行用torch.isnan()以及print出来,看看是哪里导致的nan,使得loss爆掉,然后解决方案就是尝试各种的升降精度,使得逻辑上没啥问题。当然这个还是玄学,似乎有一套指导方案,但我不知道。
  4. 在训练过程中看log的时候,有overflow的WARNING是很正常的,因为fairseq代码会根据overflow的情况自动调整loss scale,除非scale到很小的数值,否则可以静观其变。因为如果在训练过程中长时间没有overflow的风险,模型会自动调高scale。

最后的最后,fp16真的香,如果显卡支持尽量尝试fp16。