Code Snippet Notes 17

[PyTorch restart pattern (resume from a saved checkpoint)]

if args.restart:
    # reload the full model object that was previously persisted with torch.save
    with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
        model = torch.load(f)
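For completeness, a minimal sketch of the matching save side (my addition, not part of the original snippet; it assumes the same args.restart_dir / model.pt layout as above):

# sketch: persist the whole model object so the restart branch above can reload it
with open(os.path.join(args.restart_dir, 'model.pt'), 'wb') as f:
    torch.save(model, f)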

[PyTorch: getting the number of model parameters]

args.n_all_param = sum([p.nelement() for p in model.parameters()])
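A common variant (my addition, not from the original; the attribute name n_trainable_param is just for illustration) counts only the trainable parameters by filtering on requires_grad:

# assumption: count only parameters that the optimizer will actually update
args.n_trainable_param = sum(p.nelement() for p in model.parameters() if p.requires_grad)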

[PyTorch: saving data as a binary cache for fast loading]

# example from transformer-xl

def get_lm_corpus(datadir, dataset):
    fn = os.path.join(datadir, 'cache.pt')
    if os.path.exists(fn):
        print('Loading cached dataset...')
        corpus = torch.load(fn)
    else:
        print('Producing dataset {}...'.format(dataset))
        kwargs = {}
        if dataset in ['wt103', 'wt2']:
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = False
        elif dataset == 'ptb':
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = True
        elif dataset == 'lm1b':
            kwargs['special'] = []
            kwargs['lower_case'] = False
            kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
        elif dataset in ['enwik8', 'text8']:
            pass

        corpus = Corpus(datadir, dataset, **kwargs)
        torch.save(corpus, fn)

    return corpus
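The point of the pattern is that the expensive text processing runs only once; every later run deserializes cache.pt directly. A hedged usage sketch (the data directory below is a placeholder, and Corpus is the transformer-xl class assumed to be importable):

# first call builds the Corpus and writes cache.pt; subsequent calls load the binary cache
corpus = get_lm_corpus('./data/wikitext-103', 'wt103')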

[Implementing an inverse-sqrt lr schedule with PyTorch's built-in API]

# from transformer-xl

# originally used for Transformer (in Attention is all you need)
def lr_lambda(step):
    # return a multiplier instead of a learning rate
    if step == 0 and args.warmup_step == 0:
        return 1.
    else:
        return 1. / (step ** 0.5) if step > args.warmup_step \
            else step / (args.warmup_step ** 1.5)

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
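A hedged sketch of how this scheduler is driven during training (the loop below is my addition; args.max_step stands for the total number of training steps and is an assumption). The lr passed to the optimizer is the base value that the multiplier returned by lr_lambda scales:

# placeholder training loop reusing the scheduler defined above
for step in range(1, args.max_step + 1):
    # ... forward pass, loss.backward(), optimizer.step() omitted ...
    scheduler.step()   # advance the step counter that LambdaLR feeds into lr_lambda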