关于Pytorch中index_copy_及其思考

前几日因为in-place操作的问题,debug了好几天,最终才发现问题。

1
2
output,_=pad_packed_sequence(output,batch_first=True)
output=output.index_copy(0,torch.tensor(sorted_index),output)

因为Pytorch中pack_sequence需要将batch按长度排列,我在过完GRU后需要将其顺序还原,在这边sorted_index即是记录原来index映射。

然而我在写的时候,参考的是官方的example:

1
2
3
4
5
6
7
8
9
>>> x = torch.zeros(5, 3)
>>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
>>> index = torch.tensor([0, 4, 2])
>>> x.index_copy_(0, index, t)
tensor([[ 1., 2., 3.],
[ 0., 0., 0.],
[ 7., 8., 9.],
[ 0., 0., 0.],
[ 4., 5., 6.]])

因此我也不假思索地写:

1
2
output,_=pad_packed_sequence(output,batch_first=True)
output=output.index_copy_(0,torch.tensor(sorted_index),output)

就因为多了一个_,导致逻辑和我想象中的不一样。

一个简单的例子展示为什么这么是错的:

1
2
3
4
5
6
7
8
9
10
11
import torch

x=torch.Tensor([21,42,45,59])

print(x) # tensor([21., 42., 45., 59.])

index=torch.tensor([1,2,0,3])

x=x.index_copy_(0,index,x)

print(x) # tensor([21., 21., 21., 59.])

由于是in-place操作,第一步,将index=0的数值(也即21)复制到index=1的地方,此时变成[21,21,45,59];接着将index=1的数值复制到index=2的位置上,注意到之前已经是in-place操作,因此此时取的不是想象中的42,而是已经被替换的21。后面的也是如此。

正确的做法只需要去掉in-place即可。


已经好几次遇到in-place的问题了,在每次做in-place操作时,都要警惕。应尽可能避免in-place操作。实际上Pytorch官方也不建议使用in-place操作。