-
Notifications
You must be signed in to change notification settings - Fork 12
/
convert2idx.py
71 lines (56 loc) · 2.34 KB
/
convert2idx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
'''
Convert article's text to word indexes.
'''
import h5py
import numpy as np
import nltk
import utils
import os
import parameters as prm
import time
from nltk.tokenize import wordpunct_tokenize
# Valid (lower-cased) values for prm.att_segment_type.
_VALID_SEGMENT_TYPES = ('section', 'subsection', 'sentence', 'word')
_BAD_SEGMENT_TYPE_MSG = (
    'Not a valid value for the attention segment type '
    '(att_segment_type) parameter. Valid options are "section", '
    '"subsection", "sentence", or "word".')


def _split_segments(text, seg_type, tokenizer):
    """Split one article's raw text into attention segments.

    Args:
        text: the raw article text as read from the input 'text' dataset.
        seg_type: lower-cased segment type, one of _VALID_SEGMENT_TYPES.
        tokenizer: punkt sentence tokenizer; only used when
            seg_type == 'sentence' (may be None otherwise).

    Returns:
        A list of segment strings (not yet truncated to max_segs_doc).

    Raises:
        ValueError: if seg_type is not a valid segment type.
    """
    if seg_type in ('section', 'subsection'):
        # The first segment collects any text before the first heading.
        segs = ['']
        for line in text.split('\n'):
            if seg_type == 'section':
                # Remove "===" markers so that subsection headings no longer
                # look like "== ... ==" and therefore do not open a segment.
                line = line.replace('===', '')
            stripped = line.strip()
            if stripped.startswith('==') and stripped.endswith('=='):
                segs.append('')
            segs[-1] += line.lower() + '\n'
        return segs
    if seg_type == 'sentence':
        return tokenizer.tokenize(text.lower().decode('ascii', 'ignore'))
    if seg_type == 'word':
        return wordpunct_tokenize(text.decode('ascii', 'ignore'))
    raise ValueError(_BAD_SEGMENT_TYPE_MSG)


def compute_idx(pages_path_in, pages_path_out, vocab):
    """Convert every article in an HDF5 file to word-index arrays.

    Reads the 'text' dataset of ``pages_path_in`` and writes
    ``pages_path_out`` (replacing any existing file) with two datasets:

      * 'idx'  -- int32, shape (n_articles, prm.max_segs_doc, prm.max_words)
                  when prm.att_doc is set, otherwise
                  (n_articles, prm.max_words).
      * 'mask' -- float32, shape (n_articles,). When prm.att_doc is set it
                  holds the number of index rows written for each article;
                  otherwise it is never written.

    Args:
        pages_path_in: path of the input HDF5 file containing 'text'.
        pages_path_out: path of the output HDF5 file to (re)create.
        vocab: vocabulary passed through to utils.text2idx2.

    Raises:
        ValueError: if prm.att_doc is set and prm.att_segment_type
            (case-insensitive) is not "section", "subsection", "sentence"
            or "word".
    """
    # Normalize the segment type once, so every comparison below agrees
    # (previously some checks were case-sensitive and others were not).
    seg_type = None
    if prm.att_doc:
        seg_type = prm.att_segment_type.lower()
        # Fail fast, before the output file is truncated.
        if seg_type not in _VALID_SEGMENT_TYPES:
            raise ValueError(_BAD_SEGMENT_TYPE_MSG)

    tokenizer = None
    if seg_type == 'sentence':
        nltk.download('punkt')
        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

    # Mode 'w' creates the output file, truncating it if it already exists.
    # The with-statement guarantees both files are closed even on error.
    with h5py.File(pages_path_in, 'r') as f, \
            h5py.File(pages_path_out, 'w') as fout:
        texts = f['text']
        n_articles = texts.shape[0]
        if prm.att_doc:
            shape = (n_articles, prm.max_segs_doc, prm.max_words)
        else:
            shape = (n_articles, prm.max_words)
        idxs = fout.create_dataset('idx', shape=shape, dtype=np.int32)
        mask = fout.create_dataset('mask', shape=(n_articles,),
                                   dtype=np.float32)

        for i, text in enumerate(texts):
            st = time.time()
            if prm.att_doc:
                segs = _split_segments(text, seg_type, tokenizer)
                segs = segs[:prm.max_segs_doc]
                idxs_, _ = utils.text2idx2(segs, vocab, prm.max_words)
                idxs[i, :len(idxs_), :] = idxs_
                mask[i] = len(idxs_)
            else:
                idx, _ = utils.text2idx2([text.lower()], vocab,
                                         prm.max_words)
                idxs[i, :] = idx[0]
            # Single-argument print: identical output on Python 2 and 3.
            print('processing article %s time %s' % (i + 1, time.time() - st))