代码片段记录1

1️⃣ get_batch

注意到shuffle的标准做法

1
2
3
4
5
6
7
8
9
10
11
12
def get_batch(self, data, batch_size=32, is_shuffle=False):
    """Split *data* into a list of mini-batches.

    Args:
        data: a list (or other sliceable sequence) of samples.
        batch_size: number of samples per batch (default 32).
        is_shuffle: when True, shuffle *data* in place first.
            (The original declared `is_shuffle` without a default after a
            defaulted parameter — a SyntaxError; False keeps the old
            "no shuffle" behaviour as the default.)

    Returns:
        A list of batches. When len(data) is not a multiple of
        batch_size, the final batch is shorter and is kept.
    """
    n = len(data)
    if is_shuffle:
        rng = random.Random()
        rng.seed()
        rng.shuffle(data)  # in-place: mutates the caller's list
    # Slicing past the end of a sequence is safe in Python, so this
    # comprehension already yields the short final batch. The original
    # remainder branch appended a duplicate last batch and referenced an
    # undefined variable `temp` (NameError); it is removed here.
    return [data[k:k + batch_size] for k in range(0, n, batch_size)]

2️⃣使用gensim将GloVe读入

实际上这份代码有点问题,在使用过程中,发现glove文件需要放在gensim的文件夹下才能被读到(7.20 updated,应该使用绝对地址),并不好。

教程地址:gensim: scripts.glove2word2vec – Convert glove format to word2vec

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 1. Load a word2vec binary model with gensim

model = gensim.models.KeyedVectors.load_word2vec_format(
    fname='GoogleNews-vectors-negative300-SLIM.bin', binary=True)
words = model.vocab  # vocabulary mapping; NOTE(review): `.vocab` exists in gensim < 4.0 — confirm version (4.x renamed it to `key_to_index`)
vector= model[word]  # embedding lookup; `word` must be an element of `words` (undefined in this snippet)

# 2. Load GloVe vectors with gensim by first converting them to word2vec format


from gensim.models import KeyedVectors
from gensim.test.utils import datapath, get_tmpfile
from gensim.scripts.glove2word2vec import glove2word2vec
glove_file=datapath('glove.txt')  # prefer an absolute path here (datapath resolves relative names inside gensim's own test-data folder)
tmp_file=get_tmpfile('word2vec.txt')  # temp file to hold the converted word2vec-format vectors
glove2word2vec(glove_file,tmp_file)  # convert: prepends the "<vocab_size> <dim>" header word2vec format requires
model=KeyedVectors.load_word2vec_format(tmp_file)
# from here on, usage is identical to the word2vec case above


3️⃣data_split方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def data_split(seed=1, proportion=0.7):
    """Deterministically split the corpus into train and test sets.

    Args:
        seed: seed for the private RNG, so the split is reproducible.
        proportion: fraction of sentence ids assigned to training.

    Returns:
        (train_data, test_data) — two lists partitioning the corpus.
    """
    data = list(iter_corpus())
    ids = list(range(len(data)))

    n_train = int(len(ids) * proportion)  # number of training ids

    # A dedicated Random instance keeps this independent of the
    # global random state.
    shuffler = random.Random(seed)
    shuffler.shuffle(ids)
    held_out = set(ids[n_train:])  # ids reserved for the test set

    train_data, test_data = [], []
    for sample in data:
        # sample[1] is the sentence id
        bucket = test_data if sample[1] in held_out else train_data
        bucket.append(sample)

    return train_data, test_data

4️⃣对string预处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def clean_str(string):
    """Clean and tokenize-normalize an English string for NLP preprocessing.

    Strips @mentions and URLs, replaces disallowed characters with spaces,
    splits common contractions ("you've" -> "you 've") and punctuation into
    separate tokens, collapses whitespace, and lowercases.

    Fixes over the original:
    - Mention/URL removal now runs FIRST. The original ran those two
      patterns last, after the character filter had already deleted every
      space (and ':' '/'), so they could never match.
    - The character filter replaces disallowed characters with " " instead
      of "" (the original deleted all whitespace and glued adjacent words
      together) and keeps ',' in the allowed set so the later ", " rule
      has something to match.

    Returns:
        The cleaned, lowercased string.
    """
    string = re.sub(r"\@.*?[\s\n]", "", string)      # drop @mentions up to the next whitespace
    string = re.sub(r"https*://.+[\s]", "", string)  # drop URLs followed by whitespace
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'m", " \'m", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)          # collapse whitespace runs last
    return string.strip().lower()

5️⃣collate_fn(batch)

重写collate_fn组建mini-batch,在NLP中常用,用于处理句子的不等长问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def collate_fn(batch):  # rewrite collate_fn to form a mini-batch
    """Collate variable-length sentence samples into one padded, packed batch.

    Each element of *batch* is a dict with 'sentence' (a sequence of token
    ids) and 'sentiment' (a label). Sentences are sorted by length in
    descending order — pack_padded_sequence requires this — zero-padded to
    the batch maximum, transposed to (seq_len, batch), and packed.

    NOTE(review): relies on module-level names not visible here — `config`
    (for `config.use_cuda`) and a bare `Variable` (presumably
    `from torch.autograd import Variable`); confirm those imports exist.
    `Variable` is a deprecated no-op wrapper on modern PyTorch.

    Returns:
        dict with 'sentence' (PackedSequence) and 'sentiment' (LongTensor),
        both reordered by descending sentence length.
    """
    lengths = np.array([len(data['sentence']) for data in batch])
    sorted_index = np.argsort(-lengths)
    lengths = lengths[sorted_index]  # descending order

    max_length = lengths[0]
    batch_size = len(batch)
    # zero-filled (batch, max_len) matrix: zeros act as padding
    sentence_tensor = torch.LongTensor(batch_size, int(max_length)).zero_()

    for i, index in enumerate(sorted_index):
        # copy the i-th longest sentence into row i; [:max_length] is a
        # no-op guard since no sentence exceeds the batch maximum
        sentence_tensor[i][:lengths[i]] = torch.LongTensor(batch[index]['sentence'][:max_length])

    # labels reordered to stay aligned with the sorted sentences
    sentiments = torch.autograd.Variable(torch.LongTensor([batch[i]['sentiment'] for i in sorted_index]))
    if config.use_cuda:
        packed_sequences = torch.nn.utils.rnn.pack_padded_sequence(Variable(sentence_tensor.t()).cuda(), lengths)  # remember to transpose
        sentiments = sentiments.cuda()
    else:
        packed_sequences = torch.nn.utils.rnn.pack_padded_sequence(Variable(sentence_tensor.t()),lengths)  # remember to transpose
    return {'sentence': packed_sequences, 'sentiment': sentiments}

## Pass the rewritten collate_fn(batch) to a DataLoader like this:

train_dataloader=DataLoader(train_data,batch_size=32,shuffle=True,collate_fn=collate_fn)

## train_dataloader can then be iterated over directly:
for data in train_dataloader:
    ...


6️⃣使用yield获得数据的generator

yield的用法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def get_dataset(txt_file):
    """Lazily read *txt_file*, yielding one sample per non-empty line.

    Each sample is the line split into individual characters, with an
    '<eos>' marker appended. Blank lines are skipped. Returns a generator,
    so the file is streamed rather than loaded whole.
    """
    with open(txt_file, 'r') as fp:
        for raw_line in fp:
            stripped = raw_line.strip()
            if not stripped:  # skip blank lines
                continue
            yield list(stripped) + ['<eos>']

# Usage: iterate the generator lazily
dataset=get_dataset(txt_file)
for d in dataset:
    pass

# Or materialize it into a list when random access is needed
dataset=list(get_dataset(txt_file))


7️⃣动态创建RNN实例

根据rnn_type动态创建对象实例,使用了getattr

1
2
3
# self.rnn_type in ['GRU','LSTM','RNN']: look the class up on torch.nn by
# name with getattr and instantiate it, avoiding an if/elif chain.
self.rnn = getattr(nn, self.rnn_type)(self.embedding_dim, self.hidden_dim, self.num_layers, dropout=self.dropout)