forked from gknappattack/DRAGN-Town-Quests
-
Notifications
You must be signed in to change notification settings - Fork 1
/
ngram.py
125 lines (108 loc) · 3.69 KB
/
ngram.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import string
import random
import time
from typing import List
# ideally we would use some smart text tokenizer, but for simplicity use this one
def tokenize(text: str) -> List[str]:
    """
    Split a sentence into word/punctuation tokens.

    Every punctuation character becomes its own token; everything else is
    split on whitespace.

    :param text: Takes input sentence
    :return: tokenized sentence
    """
    padded_chars = []
    for ch in text:
        if ch in string.punctuation:
            # surround punctuation with spaces so split() isolates it
            padded_chars.append(' ' + ch + ' ')
        else:
            padded_chars.append(ch)
    return ''.join(padded_chars).split()
def get_ngrams(n: int, tokens: list) -> list:
    """
    Build the list of n-grams for a tokenized sentence.

    The sequence is left-padded with (n - 1) '<START>' markers so the first
    real token also gets a full-length context.

    :param n: n-gram size
    :param tokens: tokenized sentence
    :return: list of ngrams as tuples of the form ((previous words...), target word)
    """
    padded = (n - 1) * ['<START>'] + tokens
    ngrams = []
    for idx in range(n - 1, len(padded)):
        history = tuple(padded[idx - n + 1:idx])
        ngrams.append((history, padded[idx]))
    return ngrams
class NgramModel(object):
    """
    Simple n-gram language model.

    Learns "context -> candidate next words" statistics from sentences fed to
    :meth:`update` and generates text by repeatedly sampling a next token
    proportionally to its observed frequency after the current context.
    """

    def __init__(self, n):
        """
        :param n: n-gram size (the context used for prediction is n - 1 tokens)
        """
        self.n = n
        # dictionary that keeps list of candidate words given context
        # (a word appears once per observation, so len() gives the context count)
        self.context = {}
        # keeps track of how many times ngram has appeared in the text before
        self.ngram_counter = {}

    def update(self, sentence: str) -> None:
        """
        Updates Language Model with one sentence of training text.

        :param sentence: input text
        """
        for ngram in get_ngrams(self.n, tokenize(sentence)):
            # count occurrences of the full (context, target) ngram
            self.ngram_counter[ngram] = self.ngram_counter.get(ngram, 0.0) + 1.0
            prev_words, target_word = ngram
            # record the target as one more candidate continuation of its context
            self.context.setdefault(prev_words, []).append(target_word)

    def prob(self, context, token):
        """
        Calculates probability of a candidate token to be generated given a context.

        :param context: tuple of the previous n - 1 tokens
        :param token: candidate next token
        :return: conditional probability; 0.0 for unseen (context, token) pairs
        """
        try:
            count_of_token = self.ngram_counter[(context, token)]
            count_of_context = float(len(self.context[context]))
            return count_of_token / count_of_context
        except KeyError:
            # context or (context, token) never observed during training
            return 0.0

    def random_token(self, context):
        """
        Given a context we "semi-randomly" select the next word to append in a sequence.

        :param context: tuple of the previous n - 1 tokens
        :return: a token observed after this context in the training data
        :raises KeyError: if the context was never seen during training
        """
        r = random.random()
        candidates = self.context[context]
        map_to_probs = {token: self.prob(context, token) for token in candidates}
        summ = 0.0
        for token in sorted(map_to_probs):
            summ += map_to_probs[token]
            if summ > r:
                return token
        # The probabilities sum to 1.0, but float rounding can leave summ
        # marginally below r; fall back to the last candidate instead of
        # returning None (which would crash ' '.join in generate_text).
        return max(map_to_probs)

    def generate_text(self, token_count: int, input=None):
        """
        Generate text by repeatedly sampling next tokens from the model.

        :param token_count: number of words to be produced
        :param input: optional seed text used to prime the context
                      (parameter name shadows the builtin; kept for
                      backward compatibility with keyword callers)
        :return: generated text as a space-joined string (seed tokens included)
        """
        result = []
        n = self.n
        # empty context for n == 1, otherwise n - 1 '<START>' markers
        context_queue = (n - 1) * ['<START>']
        if input is not None:
            input_token = tokenize(input)
            result.extend(input_token)
            if len(input_token) == 1:
                context_queue[-1] = input_token[0]
            else:
                # prime the context with the last n - 1 seed tokens
                context_queue = input_token[-(n - 1):]
        for _ in range(token_count):
            obj = self.random_token(tuple(context_queue))
            result.append(obj)
            if n > 1:
                context_queue.pop(0)
                if obj == '.':
                    # sentence ended: restart generation from the <START> context
                    context_queue = (n - 1) * ['<START>']
                else:
                    context_queue.append(obj)
        return ' '.join(result)