postprocess_wikiText.py

"""
Script to post-process WikiText files created with `create_wikitext.py`.
Creates additional files where words not in the training data are replaced
with <UNK> and numbers are modified with a regex.
"""
import re
import argparse
from collections import Counter
from pathlib import Path

number_split_re = re.compile(r'([,.])')
number_match_re = re.compile(r'^([0-9]+[,.]?)+$')
UNK = '<unk>'


def replace_number(token):
    """Pad the separators in a number token with '@' markers and return the
    (possibly multi-token) string; non-number tokens are returned unchanged."""
    if number_match_re.match(token):
        return number_split_re.sub(r' @\1@ ', token)
    return token
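
# Illustrative behaviour of replace_number (example tokens, not from the original script):
#   replace_number('230,000') -> '230 @,@ 000'
#   replace_number('3.14')    -> '3 @.@ 14'
#   replace_number('cat')     -> 'cat'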


def build_vocab(file_path, cutoff=3):
    """Count tokens in the training file and keep those occurring at least `cutoff` times."""
    counter = Counter()
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            tokens = line.strip().split(' ') + ['<eos>']
            counter.update(tokens)
    vocab = {}
    in_vocab_count = 0
    OOV_count = 0
    for token, count in counter.most_common():
        if count >= cutoff:
            vocab[token] = count
            in_vocab_count += count
        else:
            OOV_count += count
    print('OOV ratio: %.4f.' % (OOV_count / (in_vocab_count + OOV_count)))
    return vocab
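
# Illustrative example (hypothetical counts, not from the original script): given
#   counter = {'the': 5, '<eos>': 4, 'cat': 3, 'dormouse': 1}
# the default cutoff of 3 keeps {'the': 5, '<eos>': 4, 'cat': 3}, while 'dormouse'
# is dropped, giving an OOV ratio of 1 / 13 ≈ 0.0769.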


def limit_vocab(unk_path, vocab):
    """
    Replace every token in the file at `unk_path` that is not in `vocab` with <unk>,
    rewriting the file in place. Adapted from:
    https://gist.github.com/Smerity/94af5902aa9498817c92d1e71eb2f87b#file-limit_vocab-py
    :param unk_path: path of the tokenized file to rewrite
    :param vocab: mapping of in-vocabulary tokens to their counts
    """
    temp_file_path = unk_path.with_name(unk_path.name + '.temp')
    total_num_tokens = 0
    print(f'Limiting vocab in {unk_path}. Writing to {temp_file_path}.')
    with open(unk_path, 'r', encoding='utf-8') as f_in, open(temp_file_path, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            tokens = [x for x in line.strip().split(' ') if x]
            tokens = [token if token in vocab else UNK for token in tokens]
            # Ensure there is a space between all tokens, including between the last
            # word of a line, the newline, and the first word of the next line
            tokens = tokens + ['\n']
            total_num_tokens += len(tokens)
            tokens = [''] + tokens
            line = ' '.join(tokens)
            f_out.write(line)
    print(f'{unk_path.name}. # of tokens: {total_num_tokens}')
    temp_file_path.replace(unk_path)
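
# Sketch of the transformation (illustrative vocab and line, not from the original script):
#   with vocab = {'the': 10, 'cat': 4}, the line 'the cat yawns' is rewritten as
#   ' the cat <unk> \n', and 4 tokens are counted for that line.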


def replace_numbers(file_path, unk_path):
    """
    Replace separators inside numbers, as in Smerity's script:
    https://gist.github.com/Smerity/94af5902aa9498817c92d1e71eb2f87b#file-post_process-py
    :param file_path: path of the tokenized input file
    :param unk_path: path of the output file
    """
    print(f'Replacing numbers in {file_path}. Writing to {unk_path}.')
    with open(file_path, 'r', encoding='utf-8') as f_in, open(unk_path, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            raw_tokens = line.strip().split(' ')
            tokens = [replace_number(token) for token in raw_tokens]
            # Each line has to start with a blank token: some systems replace \n
            # with <eos> and assume, as in PTB, that everything is space-separated
            tokens = [''] + tokens + ['\n']
            line = ' '.join(tokens)
            f_out.write(line)
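
# Sketch of the line-level output (illustrative input, not from the original script):
#   input : 'The population was 230,000 in 2010 .'
#   output: ' The population was 230 @,@ 000 in 2010 . \n'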


def main(args):
    input_path = Path(args.input)
    assert input_path.exists(), f'Error: {input_path} does not exist.'
    sml_wiki = input_path / f'{args.lang}-2'
    lrg_wiki = input_path / f'{args.lang}-100'
    assert sml_wiki.exists(), f'Error: {sml_wiki} does not exist.'
    assert lrg_wiki.exists(), f'Error: {lrg_wiki} does not exist.'
    splits = ['train', 'valid', 'test']

    # replace numbers with placeholders in every split
    for wiki in [sml_wiki, lrg_wiki]:
        for split in splits:
            file_path = wiki / f'{args.lang}.wiki.{split}.tokens'
            unk_path = wiki / f'{args.lang}.wiki.{split}.tokens.unk'
            replace_numbers(file_path, unk_path)

    # build the vocabularies from the training splits only
    sml_wiki_train = sml_wiki / f'{args.lang}.wiki.train.tokens'
    lrg_wiki_train = lrg_wiki / f'{args.lang}.wiki.train.tokens'
    sml_vocab = build_vocab(sml_wiki_train)
    print(f'{args.lang}-2 vocab size: {len(sml_vocab)}')
    lrg_vocab = build_vocab(lrg_wiki_train)
    print(f'{args.lang}-100 vocab size: {len(lrg_vocab)}')

    # replace words not in the vocab with <unk>
    for wiki, vocab in zip([sml_wiki, lrg_wiki], [sml_vocab, lrg_vocab]):
        for split in splits:
            unk_path = wiki / f'{args.lang}.wiki.{split}.tokens.unk'
            limit_vocab(unk_path, vocab)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', required=True,
                        help='the directory of the WikiText files')
    parser.add_argument('-l', '--lang', required=True,
                        help='the ISO code of the language of the Wikipedia '
                             'documents, e.g. en, fr, de, etc.')
    args = parser.parse_args()
    main(args)
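
# Example invocation (the input directory name is illustrative):
#   python postprocess_wikiText.py --input data/wikitext --lang de
# This expects data/wikitext/de-2 and data/wikitext/de-100 to contain the
# de.wiki.{train,valid,test}.tokens files produced by create_wikitext.py.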