代码片段记录10

1️⃣[get_sinusoid_encoding_table]

Transformer 的绝对位置编码（正弦/余弦位置编码表）。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    """Build a fixed sinusoidal absolute-position encoding table.

    For position ``pos`` and hidden index ``j`` the angle is
    ``pos / 10000 ** (2 * (j // 2) / d_hid)``. Even hidden indices hold the
    sine of that angle and odd hidden indices hold the cosine, so each
    (even, odd) pair shares one frequency.

    Args:
        n_position: Number of positions (rows) in the table.
        d_hid: Hidden/embedding size (columns) of the table.
        padding_idx: Optional row index to fill with zeros, so that the
            padding position contributes no positional signal.

    Returns:
        A ``torch.FloatTensor`` of shape ``(n_position, d_hid)``.
    """
    # Column vector of positions, broadcast against the per-column divisors.
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    hid_idx = np.arange(d_hid)
    # `hid_idx // 2` makes columns 2i and 2i+1 share the same divisor.
    divisors = np.power(10000, 2 * (hid_idx // 2) / d_hid)
    sinusoid_table = positions / divisors

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even dims: sin
    # Odd dims take the cosine; the original applied cos to the even columns
    # again, which overwrote the sine values and left odd columns as raw angles.
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])

    if padding_idx is not None:
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)  # (n_position, d_hid)