代码片段记录9

1️⃣[collate_fn]

将不等长的句子（token id 序列）用 PAD 填充到该 batch 内的最大长度，组合成一个 batch，并同时生成对应的位置索引。

def collate_fn(insts, *, pad=None):
    """Pad a batch of variable-length token-id sequences to the batch's max length.

    Args:
        insts: Iterable of token-id sequences (lists, tuples, ...), one per
            sentence. Sequences may have different lengths.
        pad: Padding id. When omitted, ``Constants.PAD`` is used; that constant
            is defined elsewhere in the project and is looked up at call time.

    Returns:
        A ``(batch_seq, batch_pos)`` pair of int64 tensors, both of shape
        ``(len(insts), max_len)``:

        * ``batch_seq`` -- each sequence right-padded with ``pad``.
        * ``batch_pos`` -- 1-based position index of each slot, and 0 wherever
          ``batch_seq`` equals ``pad``.

        An empty batch yields two tensors of shape ``(0, 0)``.
    """
    if pad is None:
        # Resolved lazily so the default stays tied to the project constant
        # without needing it at definition/import time.
        pad = Constants.PAD

    # Materialize once: we need len() and two passes, and inputs may be
    # generators or tuples rather than lists.
    insts = list(insts)
    if not insts:
        # max() over an empty sequence would raise; return empty batches.
        return (torch.zeros((0, 0), dtype=torch.long),
                torch.zeros((0, 0), dtype=torch.long))

    max_len = max(len(inst) for inst in insts)

    # Build as int64 explicitly so the tensor dtype does not depend on NumPy's
    # inference (float64 for all-empty rows, int32 default ints on Windows
    # with NumPy < 2).
    batch_seq = np.full((len(insts), max_len), pad, dtype=np.int64)
    for row, inst in enumerate(insts):
        batch_seq[row, :len(inst)] = inst

    # Position information: 1..max_len for real tokens, 0 for padding slots.
    positions = np.arange(1, max_len + 1, dtype=np.int64)
    batch_pos = np.where(batch_seq != pad, positions, 0)

    return torch.from_numpy(batch_seq), torch.from_numpy(batch_pos)