hf_tokenizer.py
"""BPE tokenizer wrapper around the Hugging Face `tokenizers` library."""
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from tokenizers.processors import TemplateProcessing


class LM_Tokenizer:
    def __init__(self, config):
        # Give the BPE model an explicit unknown token so out-of-vocabulary
        # symbols map to <unk> instead of being dropped silently.
        self.tokenizer = Tokenizer(BPE(unk_token='<unk>'))
        self.tokenizer.pre_tokenizer = Whitespace()
        self.trainer = BpeTrainer(special_tokens=['<s>', '</s>', '<unk>', '<pad>', '<mask>'])
        self.config = config

    def train_tokenizer(self, data_files=None, binary_iterator=None, str_iter=None):
        """
        Trains the BPE model from a list of files, or else from an iterator
        over batches of (possibly bytes-encoded) text.
        """
        if data_files is not None:
            # Tokenizer.train takes the files first and the trainer second.
            self.tokenizer.train(data_files, trainer=self.trainer)
        else:
            str_iter = str_iter if str_iter is not None else self.make_str_iter(binary_iterator)
            self.tokenizer.train_from_iterator(str_iter, trainer=self.trainer)
        self.set_up_tokenizer()

    def make_str_iter(self, binary_iterator):
        # Wraps an iterator over bytes batches so it yields decoded string batches.
        def str_iter():
            for batch in binary_iterator:
                yield self.decode_to_str(batch)
        return str_iter()

    def set_up_tokenizer(self):
        # pad_token must match pad_id; the library's default pad token is '[PAD]'.
        self.tokenizer.enable_padding(pad_id=self.tokenizer.token_to_id('<pad>'),
                                      pad_token='<pad>',
                                      length=self.config['max_length'])
        # Cap raw sequences one below max_length; padding brings them back up.
        self.tokenizer.enable_truncation(self.config['max_length'] - 1)
        # The special-token ids are looked up from the trained vocabulary
        # rather than hard-coded. The ':n' suffixes set the type ids, which
        # get_lang_ids() reuses as language/segment ids.
        self.tokenizer.post_processor = TemplateProcessing(
            single="<s>:1 $A:1 </s>:1",
            pair="<s>:1 $A:1 </s>:1 </s>:2 $B:2 </s>:2",
            special_tokens=[('<s>', self.tokenizer.token_to_id('<s>')),
                            ('</s>', self.tokenizer.token_to_id('</s>'))])

    def decode_to_str(self, batch_text):
        """
        Converts a batch of bytes or str items to strings, truncated to
        max_length whitespace-separated words.
        """
        max_len = self.config['max_length']
        return [' '.join(text.decode('utf-8').split()[:max_len] if isinstance(text, bytes)
                         else text.split()[:max_len])
                for text in batch_text]

    def batch_encode_plus(self, batch1, batch2=None):
        """
        Encodes one batch, or a pair of batches whose sequences correspond to
        different types/languages.
        """
        if batch2 is None:
            return self.tokenizer.encode_batch(self.decode_to_str(batch1))
        pairs = list(zip(self.decode_to_str(batch1), self.decode_to_str(batch2)))
        return self.tokenizer.encode_batch(pairs)

    def get_token_ids(self, token_encoding):
        return [elem.ids for elem in token_encoding]

    def get_lang_ids(self, token_encoding):
        return [elem.type_ids for elem in token_encoding]
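

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how the class
# above would typically be trained and queried. The config layout and the
# sample corpus below are illustrative assumptions, not values defined
# elsewhere in this repo.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = {'max_length': 16}  # assumed config: only 'max_length' is read above

    lm_tok = LM_Tokenizer(config)

    # Train from an in-memory iterator of text batches; bytes items are also
    # accepted, since decode_to_str handles both bytes and str.
    sample_batches = [[b'hello world', 'the quick brown fox'],
                      ['jumps over the lazy dog']]
    lm_tok.train_tokenizer(binary_iterator=iter(sample_batches))

    # Encode a pair batch; the type_ids mark which segment/language each
    # token came from, per the TemplateProcessing template.
    encodings = lm_tok.batch_encode_plus([b'hello world'], ['brown fox'])
    print(lm_tok.get_token_ids(encodings))
    print(lm_tok.get_lang_ids(encodings))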