-
Notifications
You must be signed in to change notification settings - Fork 13
/
utils.py
59 lines (48 loc) · 2.14 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import re
from collections import Counter

import numpy as np
def preprocess(text, threshold=5):
    """
    Tokenize text, turning punctuation into named tokens and dropping rare words
    :param text: Raw input text
    :param threshold: Words occurring this many times or fewer are removed
    :return: A list of lower-cased words and punctuation tokens, in original order
    """
    # Lower-case first so the upper-case punctuation tokens inserted below stay intact
    text = text.lower()

    # Replace punctuation with tokens so we can use them in our model.
    # Each token is padded with spaces so split() isolates it as its own word.
    # Newlines are not tokenized; split() treats them as ordinary whitespace.
    punctuation_tokens = (
        ('.', '<PERIOD>'),
        (',', '<COMMA>'),
        ('"', '<QUOTATION_MARK>'),
        (';', '<SEMICOLON>'),
        ('!', '<EXCLAMATION_MARK>'),
        ('?', '<QUESTION_MARK>'),
        ('(', '<LEFT_PAREN>'),
        (')', '<RIGHT_PAREN>'),
        ('--', '<HYPHENS>'),
        (':', '<COLON>'),
    )
    for symbol, token in punctuation_tokens:
        text = text.replace(symbol, ' ' + token + ' ')
    words = text.split()

    # Remove all words with `threshold` or fewer occurrences
    word_counts = Counter(words)
    return [word for word in words if word_counts[word] > threshold]
def get_batches(int_text, batch_size, seq_length):
    """
    Return batches of input and target
    :param int_text: Text with the words replaced by their ids
    :param batch_size: The size of batch
    :param seq_length: The length of sequence
    :return: A list where each item is a tuple of (batch of input, batch of target).
             Both arrays have shape (batch_size, seq_length); the target is the
             input shifted one position ahead. The list is empty when int_text
             is too short to fill a single batch.
    """
    words_per_batch = batch_size * seq_length

    # Targets are the inputs shifted by one, so one extra id is needed past the
    # last input. Reserving it here keeps the target slice full-length even when
    # len(int_text) is an exact multiple of words_per_batch.
    n_batches = (len(int_text) - 1) // words_per_batch
    if n_batches <= 0:
        return []

    # Drop the last few ids to make only full batches
    n_used = n_batches * words_per_batch
    xdata = np.array(int_text[:n_used])
    ydata = np.array(int_text[1:n_used + 1])

    # One row per batch element, then cut the columns into seq_length-wide pieces
    x_batches = np.split(xdata.reshape(batch_size, -1), n_batches, 1)
    y_batches = np.split(ydata.reshape(batch_size, -1), n_batches, 1)
    return list(zip(x_batches, y_batches))
def create_lookup_tables(words):
    """
    Build the word <-> id mappings for a vocabulary
    :param words: Input list of words
    :return: A tuple (vocab_to_int, int_to_vocab). Ids are assigned by
             descending word frequency, so the most frequent word gets id 0;
             equally frequent words keep their first-seen order.
    """
    frequencies = Counter(words)
    vocab_to_int = {}
    int_to_vocab = {}
    # most_common() yields (word, count) pairs from most to least frequent
    for index, (token, _count) in enumerate(frequencies.most_common()):
        int_to_vocab[index] = token
        vocab_to_int[token] = index
    return vocab_to_int, int_to_vocab