bots.py
"""Module with classes for questioner and answerer bots. A generic ChatBotAgent is defined,
which is extended separate by Questioner and Answerer bot classes.
Refer ParlAI docs on general semantics of a ParlAI Agent:
* http://parl.ai/static/docs/basic_tutorial.html#agents
* http://parl.ai/static/docs/agents.html#parlai.core.agents.Agent
"""
from __future__ import absolute_import, division
import math
import torch
from torch import nn
from torch.autograd import Variable
from torch.autograd import backward as autograd_backward
from parlai.core.agents import Agent


def xavier_init(module):
    """Xavier initializer for module parameters."""
    for parameter in module.parameters():
        if len(parameter.data.shape) == 1:
            # 1D vector means bias
            parameter.data.fill_(0)
        else:
            # for nn.Linear weights, size(0) is actually the fan-out and size(1)
            # the fan-in, but Xavier init is symmetric in (fan_in + fan_out)
            fan_in = parameter.data.size(0)
            fan_out = parameter.data.size(1)
            parameter.data.normal_(0, math.sqrt(2 / (fan_in + fan_out)))
    return module
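

# A quick sanity check of the formula above (illustrative only): for an
# nn.Linear(100, 50) layer the weight has shape (50, 100), so
# fan_in + fan_out = 150 and the weights are drawn with
# std = sqrt(2 / 150) ~= 0.115, while the bias is filled with zeros:
#
#     layer = xavier_init(nn.Linear(100, 50))
#     print(layer.weight.data.std())   # close to 0.115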


class ChatBotAgent(Agent, nn.Module):
    """Parent class for both questioner and answerer bots. Extends a ParlAI Agent and a PyTorch
    module. This class is a generic implementation of how a bot acts and observes in a generic
    ParlAI dialog world. It comprises a state tensor, an actions list and an observation dict.

    This parent class provides a ``listen_net`` embedding module to embed the text tokens. The
    ``speak_net`` module picks a token to be given as the action, based on the state of the
    agent. Both questioner and answerer agents observe and act in this generic way.

    Attributes
    ----------
    opt : dict
        Command-line opts passed into the constructor from the world.
    observation : dict
        Observations dict exchanged during dialogs and on starting a fresh episode. Has keys as
        described in the ParlAI docs ('text', 'image', 'episode_done', 'reward').
    actions : list
        List of action tensors by the agent, acted in the current dialog episode.
    h_state, c_state : torch.autograd.Variable
        State of the agent.
    eval_flag : boolean
        Flag indicating whether the agent is in training or evaluation mode.
    listen_net : nn.Embedding
    speak_net : nn.Linear
    softmax : nn.Softmax
    """

    def __init__(self, opt, shared=None):
        super(ChatBotAgent, self).__init__(opt, shared)
        nn.Module.__init__(self)
        self.id = 'ChatBotAgent'
        self.observation = None

        # standard initializations
        self.h_state = torch.Tensor()
        self.c_state = torch.Tensor()
        self.eval_flag = False
        self.actions = []

        # modules (common)
        self.listen_net = nn.Embedding(self.opt['in_vocab_size'], self.opt['embed_size'])
        self.speak_net = nn.Linear(self.opt['hidden_size'], self.opt['out_vocab_size'])
        self.softmax = nn.Softmax()

        # xavier init of listen_net and speak_net
        for module in {self.listen_net, self.speak_net}:
            module = xavier_init(module)

    def observe(self, observation):
        """Given an input token, interact for the next round."""
        self.observation = observation
        if not observation.get('episode_done'):
            # embed and pass through LSTM
            token_embeds = self.listen_net(observation['text'])

            # concat with image representation (valid for abot)
            if 'image' in observation:
                token_embeds = torch.cat((token_embeds, observation['image']), 1)
            # remove all dimensions with size one
            token_embeds = token_embeds.squeeze(1)

            # update agent state using these tokens
            self.h_state, self.c_state = self.rnn(token_embeds, (self.h_state, self.c_state))
        else:
            if observation.get('reward') is not None:
                # REINFORCE: scale the log-prob gradient of every sampled
                # action in the episode by the episode reward
                for action in self.actions:
                    action.reinforce(observation['reward'])
                autograd_backward(self.actions, [None for _ in self.actions], retain_graph=True)

                # clamp all gradients between (-5, 5)
                for parameter in self.parameters():
                    parameter.grad.data.clamp_(min=-5, max=5)

    def act(self):
        """Speak a token."""
        # compute softmax and choose a token
        out_distr = self.softmax(self.speak_net(self.h_state))
        if self.eval_flag:
            _, actions = out_distr.max(1)
            actions = actions.unsqueeze(1)
        else:
            actions = out_distr.multinomial()
            self.actions.append(actions)
        return {'text': actions.squeeze(1), 'id': self.id}

    def reset(self, batch_size=None, retain_actions=False):
        """Reset state and actions. ``opt.batch_size`` is not always used, because the batch
        size changes when the complete data is passed in (for validation)."""
        if batch_size is None:
            batch_size = self.opt['batch_size']
        self.h_state = Variable(torch.zeros(batch_size, self.opt['hidden_size']))
        self.c_state = Variable(torch.zeros(batch_size, self.opt['hidden_size']))
        if self.opt.get('use_gpu'):
            self.h_state, self.c_state = self.h_state.cuda(), self.c_state.cuda()
        if not retain_actions:
            self.actions = []

    def train(self):
        """Switch to training mode."""
        self.eval_flag = False

    def eval(self):
        """Switch to evaluation mode."""
        self.eval_flag = True

    def forward(self):
        """Dummy forward pass."""
        pass
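

# A minimal sketch of how a single agent is driven through an episode
# (illustrative only; 'bot' and the tensors are hypothetical, and the
# Variable / .reinforce() calls assume the pre-0.4 PyTorch APIs used in
# this module):
#
#     bot.reset(batch_size=32)
#     bot.observe({'text': token_variable, 'episode_done': False})
#     reply = bot.act()    # {'text': sampled token ids, 'id': bot.id}
#     ...
#     # at episode end, a reward triggers the REINFORCE backward pass
#     bot.observe({'episode_done': True, 'reward': batch_reward})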


class Questioner(ChatBotAgent):
    """Questioner bot - extending a ParlAI Agent as well as a PyTorch module. The questioner is
    modeled as a combination of a speaker network, a listener LSTM, and a prediction network.

    At the start of a new episode of dialog, a task is observed by the questioner bot, which is
    embedded via the listener LSTM. At each round, the questioner observes the answer and acts
    by modeling the probability of output utterances based on the previous state. After
    observing the reply from the answerer, the listener LSTM updates the state by processing
    both tokens (question/answer) of the dialog exchange. In the final round, the prediction
    LSTM is unrolled twice to produce the questioner's prediction based on the final state and
    the assigned task.

    Attributes
    ----------
    rnn : nn.LSTMCell
        Listener LSTM module. The embedding module before the listener is provided by the base
        class.
    predict_rnn, predict_net : nn.LSTMCell, nn.Linear
        Collectively form the prediction network module.
    task_offset : int
        Offset in the one-hot encoding; task vectors come after a width equal to the question
        and answer vocabularies combined.
    listen_offset : int
        Offset due to listening to the response of the answer bot. Answer token one-hot vectors
        require a width equal to the answer vocabulary - question vectors come after that.
    """

    def __init__(self, opt, shared=None):
        opt['in_vocab_size'] = opt['q_out_vocab'] + opt['a_out_vocab'] + opt['task_vocab']
        opt['out_vocab_size'] = opt['q_out_vocab']
        super(Questioner, self).__init__(opt, shared)
        self.id = 'QBot'

        # always condition on task
        self.rnn = nn.LSTMCell(self.opt['embed_size'], self.opt['hidden_size'])
        # additional prediction network, start token included
        num_preds = sum([len(ii) for ii in self.opt['props'].values()])
        # network for predicting
        self.predict_rnn = nn.LSTMCell(self.opt['embed_size'], self.opt['hidden_size'])
        self.predict_net = nn.Linear(self.opt['hidden_size'], num_preds)

        # xavier init of rnn, predict_rnn, predict_net
        for module in {self.rnn, self.predict_rnn, self.predict_net}:
            module = xavier_init(module)

        # setting offset
        self.task_offset = opt['q_out_vocab'] + opt['a_out_vocab']
        self.listen_offset = opt['a_out_vocab']
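
    # A worked example of the offsets above (hypothetical sizes): with
    # q_out_vocab=3, a_out_vocab=4 and task_vocab=3, the questioner's input
    # vocabulary indexes answer tokens at 0-3, question tokens at 4-6
    # (listen_offset = 4), and task tokens at 7-9 (task_offset = 3 + 4 = 7).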

    def predict(self, tasks, num_tokens):
        """Return a prediction (guess) of attribute values, conditioned on the task."""
        guess_tokens = []
        guess_distr = []

        for _ in range(num_tokens):
            # explicit task dependence
            task_embeds = self.listen_net(tasks)
            task_embeds = task_embeds.squeeze()

            # unroll twice, compute softmax and choose a token
            self.h_state, self.c_state = self.predict_rnn(task_embeds,
                                                          (self.h_state, self.c_state))
            out_distr = self.softmax(self.predict_net(self.h_state))

            # if evaluating
            if self.eval_flag:
                _, actions = out_distr.max(1)
            else:
                actions = out_distr.multinomial()
                # record actions
                self.actions.append(actions)

            # record the guess and distribution
            guess_tokens.append(actions)
            guess_distr.append(out_distr)

        # return prediction
        return guess_tokens, guess_distr
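

# After the final dialog round, the world queries the questioner for its
# guess (a minimal sketch; 'tasks' is a Variable of task-token indices and
# num_tokens=2 corresponds to guessing two attributes per task -- both
# values are hypothetical here):
#
#     guess_tokens, guess_distr = qbot.predict(tasks, num_tokens=2)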


class Answerer(ChatBotAgent):
    """Answerer bot - extending a ParlAI Agent as well as a PyTorch module. The answerer is
    modeled as a combination of a speaker network, a listener LSTM, and an image encoder.

    While observing, it embeds the received question tokens and concatenates them with the
    image embedding, using them to update its state via the listener LSTM. The answerer bot
    acts by choosing a token based on the softmax probabilities obtained after passing the
    state through the speak net. The image encoder embeds each one-hot attribute vector via a
    linear layer and concatenates all three encodings to obtain a unified image instance
    representation.

    Attributes
    ----------
    rnn : nn.LSTMCell
        Listener LSTM module. The embedding module before the listener is provided by the base
        class.
    img_net : nn.Embedding
        Image instance encoder module.
    listen_offset : int
        Offset due to listening to the response of the question bot. Question token one-hot
        vectors require a width equal to the question vocabulary - answer vectors come after
        that.
    """

    def __init__(self, opt, shared=None):
        opt['in_vocab_size'] = opt['q_out_vocab'] + opt['a_out_vocab']
        opt['out_vocab_size'] = opt['a_out_vocab']
        super(Answerer, self).__init__(opt, shared)
        self.id = 'ABot'

        # number of attribute values
        num_attrs = sum([len(ii) for ii in self.opt['props'].values()])
        # number of unique attributes
        num_unique_attrs = len(self.opt['props'])

        # rnn input size
        rnn_input_size = num_unique_attrs * self.opt['img_feat_size'] + self.opt['embed_size']

        self.img_net = nn.Embedding(num_attrs, self.opt['img_feat_size'])
        self.rnn = nn.LSTMCell(rnn_input_size, self.opt['hidden_size'])

        # xavier init of img_net and rnn
        for module in {self.img_net, self.rnn}:
            module = xavier_init(module)

        # set offset
        self.listen_offset = opt['q_out_vocab']
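
    # A worked example (same hypothetical sizes as in the Questioner note):
    # with q_out_vocab=3 and a_out_vocab=4, the answerer's input vocabulary
    # indexes question tokens at 0-2 and answer tokens at 3-6
    # (listen_offset = 3). With three attributes and img_feat_size=20, the
    # listener LSTM input size is 3 * 20 + embed_size.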

    def embed_image(self, image):
        """Embed the image attributes (color, shape and style) into vectors of length
        ``img_feat_size`` each, and concatenate them to form a feature vector representing
        the image."""
        embeds = self.img_net(image)
        features = torch.cat(embeds.transpose(0, 1), 1)
        return features
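

# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of wiring the two bots together for one
# dialog round. This is illustrative only: every option value below is
# hypothetical (chosen to satisfy the shapes the classes expect), and the
# Variable / .multinomial() calls assume the same pre-0.4 PyTorch APIs used
# throughout this module.
if __name__ == '__main__':
    demo_opt = {
        'q_out_vocab': 3, 'a_out_vocab': 4, 'task_vocab': 3,  # hypothetical
        'embed_size': 20, 'hidden_size': 100, 'img_feat_size': 20,
        'batch_size': 2, 'use_gpu': False,
        # three attributes with four values each (hypothetical)
        'props': {'colors': ['red', 'green', 'blue', 'purple'],
                  'shapes': ['square', 'triangle', 'circle', 'star'],
                  'styles': ['dotted', 'solid', 'filled', 'dashed']},
    }
    # each constructor mutates opt (in/out vocab sizes), so pass copies
    qbot = Questioner(demo_opt.copy())
    abot = Answerer(demo_opt.copy())
    qbot.reset(batch_size=2)
    abot.reset(batch_size=2)

    # the questioner first listens to its task token (ids offset past the
    # question and answer vocabularies, i.e. 7-9 here), then asks
    task = Variable(torch.LongTensor([7, 9]))
    qbot.observe({'text': task, 'episode_done': False})
    question = qbot.act()['text']

    # the answerer hears the question together with the image embedding
    image = Variable(torch.LongTensor([[0, 4, 8], [1, 5, 9]]))  # attribute ids
    abot.observe({'text': question, 'image': abot.embed_image(image),
                  'episode_done': False})
    answer = abot.act()['text']
    print('question: %s, answer: %s' % (question.data, answer.data))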