From 25c24c9708389f68aa0ad53b88602cfdf5070723 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 14:30:21 -0400 Subject: [PATCH 01/10] small --- parlai/agents/remote_agent/remote_agent.py | 4 ++-- parlai/agents/repeat_label/repeat_label.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/parlai/agents/remote_agent/remote_agent.py b/parlai/agents/remote_agent/remote_agent.py index 6dbeafeac21..399417c166b 100644 --- a/parlai/agents/remote_agent/remote_agent.py +++ b/parlai/agents/remote_agent/remote_agent.py @@ -11,7 +11,7 @@ import zmq -class RemoteAgent(Agent): +class RemoteAgentAgent(Agent): """Agent which connects over ZMQ to a paired agent. The other agent is launched using the command line options set via `add_cmdline_args`.""" @@ -105,7 +105,7 @@ def shutdown(self): self.process.kill() -class ParsedRemoteAgent(RemoteAgent): +class ParsedRemoteAgent(RemoteAgentAgent): """Same as the regular remote agent, except that this agent converts all text into vectors using its dictionary before sending them. """ diff --git a/parlai/agents/repeat_label/repeat_label.py b/parlai/agents/repeat_label/repeat_label.py index 7becab203fe..1559c168eee 100644 --- a/parlai/agents/repeat_label/repeat_label.py +++ b/parlai/agents/repeat_label/repeat_label.py @@ -33,6 +33,8 @@ def __init__(self, opt, shared=None): def act(self): obs = self.observation + if obs is None: + return { 'text': "Nothing to repeat yet." } reply = {} reply['id'] = self.getID() if ('labels' in obs and obs['labels'] is not None From 38d4324b568cf50231cc8068ec159c1a0fbb17ad Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 16:05:07 -0400 Subject: [PATCH 02/10] exec world --- parlai/agents/remote_agent/remote_agent.py | 4 +- parlai/core/worlds.py | 115 ++++++++++++++++++--- 2 files changed, 103 insertions(+), 16 deletions(-) diff --git a/parlai/agents/remote_agent/remote_agent.py b/parlai/agents/remote_agent/remote_agent.py index 399417c166b..6dbeafeac21 100644 --- a/parlai/agents/remote_agent/remote_agent.py +++ b/parlai/agents/remote_agent/remote_agent.py @@ -11,7 +11,7 @@ import zmq -class RemoteAgentAgent(Agent): +class RemoteAgent(Agent): """Agent which connects over ZMQ to a paired agent. The other agent is launched using the command line options set via `add_cmdline_args`.""" @@ -105,7 +105,7 @@ def shutdown(self): self.process.kill() -class ParsedRemoteAgent(RemoteAgentAgent): +class ParsedRemoteAgent(RemoteAgent): """Same as the regular remote agent, except that this agent converts all text into vectors using its dictionary before sending them. """ diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index 49625bfc9bd..717173bc347 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -268,7 +268,7 @@ def shutdown(self): class MultiAgentDialogWorld(World): """Basic world where each agent gets a turn in a round-robin fashion, - recieving as input the actions of all other agents since that agent last + receiving as input the actions of all other agents since that agent last acted. """ def __init__(self, opt, agents=None, shared=None): @@ -315,6 +315,87 @@ def shutdown(self): a.shutdown() +class ExecutableWorld(World): + """A world where messages from agents can be interpreted as _actions_ in the + world which result in changes in the environment (are executed). Hence a grounded + simulation can be implemented rather than just dialogue between agents. + """ + def __init__(self, opt, agents=None, shared=None): + super().__init__(opt) + if shared: + # Create agents based on shared data. + self.agents = create_agents_from_shared(shared['agents']) + else: + # Add passed in agents directly. + self.agents = agents + self.acts = [None] * len(agents) + super().__init__(opt, agents, shared) + self.init_world() + + def init_world(self): + """An executable world class should implement this function, otherwise + the world still works, but actions do not do anything. + """ + pass + + def gen_observe(self, agent): + """ Generate an observation of the world for an agent + given the current state of the (executable) world. + This is differentiated from a message from another dialogue agent, + and hence has the id 'world' in the message. + """ + msg = {} + msg['text'] = '' # By default the world does nothing. + msg['id'] = 'world' + agent.observe(validate(msg)) + return msg + + def execute(self, agent, act): + # Execute action from agent. We also send an update to all other agents + # that can observe the change. + if 'text' in act: + valid = self.g.parse_exec(agent.id, act['text']) + if not valid: + agent.observe({'id':'world', 'text':'invalid action'}) + for index, agent in enumerate(self.agents): + acts[index] = agent.act() + for other_agent in self.agents: + if other_agent != agent: + other_agent.observe(validate(acts[index])) + + + + def parley(self): + """For each agent: observe, act, execute action in world + """ + acts, agents = self.acts, self.agents + for index, agent in enumerate(agents): + acts[index] = agent.act() + # execute action in environment + self.execute(agent, acts[index]) + + def epoch_done(self): + done = False + for a in self.agents: + if a.epoch_done(): + done = True + return done + + def episode_done(self): + done = False + for a in self.agents: + if a.episode_done(): + done = True + return done + + def report(self): + return self.agents[0].report() + + def shutdown(self): + for a in self.agents: + a.shutdown() + + class MultiWorld(World): """Container for a set of worlds where each world gets a turn in a round-robin fashion. The same user_agents are placed in each, @@ -481,16 +562,6 @@ def __next__(self): if self.epoch_done(): raise StopIteration() - def batch_observe(self, index, batch_actions): - batch_observations = [] - for i, w in enumerate(self.worlds): - agents = w.get_agents() - observation = agents[index].observe(validate(batch_actions[i])) - if observation is None: - raise ValueError('Agents should return what they observed.') - batch_observations.append(observation) - return batch_observations - def batch_act(self, index, batch_observation): # Given batch observation, do update for agents[index]. # Call update on agent @@ -512,6 +583,16 @@ def batch_act(self, index, batch_observation): batch_actions.append(acts[index]) return batch_actions + def batch_observe(self, index, batch_actions, index_acting): + batch_observations = [] + for i, w in enumerate(self.worlds): + agents = w.get_agents() + observation = agents[index].observe(validate(batch_actions[i])) + if observation is None: + raise ValueError('Agents should return what they observed.') + batch_observations.append(observation) + return batch_observations + def parley(self): # Collect batch together for each agent, and do update. # Assumes DialogPartnerWorld, MultiAgentWorld, or MultiWorlds of them. @@ -523,11 +604,17 @@ def parley(self): w.parley_init() for index in range(num_agents): + # The agent acts. batch_act = self.batch_act(index, batch_observations[index]) + # We possibly execute this action in the world. + for i, w in enumerate(self.worlds): + if hasattr(w, 'execute'): + w.execute(batch_actions[i]) + # All agents observe the results. for other_index in range(num_agents): - if index != other_index: - batch_observations[other_index] = ( - self.batch_observe(other_index, batch_act)) + batch_observations[other_index] = ( + self.batch_observe(other_index, batch_act, index)) + def display(self): s = ("[--batchsize " + str(len(self.worlds)) + "--]\n") From 4d0380d4bbe3c25967aac21a486fcaa865d64c8d Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 16:14:05 -0400 Subject: [PATCH 03/10] small --- parlai/core/worlds.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index 717173bc347..af1d709e2b9 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -562,6 +562,16 @@ def __next__(self): if self.epoch_done(): raise StopIteration() + def batch_observe(self, index, batch_actions, index_acting): + batch_observations = [] + for i, w in enumerate(self.worlds): + agents = w.get_agents() + observation = agents[index].observe(validate(batch_actions[i])) + if observation is None: + raise ValueError('Agents should return what they observed.') + batch_observations.append(observation) + return batch_observations + def batch_act(self, index, batch_observation): # Given batch observation, do update for agents[index]. # Call update on agent @@ -583,16 +593,6 @@ def batch_act(self, index, batch_observation): batch_actions.append(acts[index]) return batch_actions - def batch_observe(self, index, batch_actions, index_acting): - batch_observations = [] - for i, w in enumerate(self.worlds): - agents = w.get_agents() - observation = agents[index].observe(validate(batch_actions[i])) - if observation is None: - raise ValueError('Agents should return what they observed.') - batch_observations.append(observation) - return batch_observations - def parley(self): # Collect batch together for each agent, and do update. # Assumes DialogPartnerWorld, MultiAgentWorld, or MultiWorlds of them. From 3666e573c487ba507e25807a80b3cefc835b789a Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 16:21:18 -0400 Subject: [PATCH 04/10] blah --- parlai/core/worlds.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index af1d709e2b9..07bee82aca8 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -562,11 +562,14 @@ def __next__(self): if self.epoch_done(): raise StopIteration() - def batch_observe(self, index, batch_actions, index_acting): + def batch_observe(self, index, batch_actions) batch_observations = [] for i, w in enumerate(self.worlds): agents = w.get_agents() - observation = agents[index].observe(validate(batch_actions[i])) + if hasattr(w, 'observe'): + observation = w.observe(agents[index], validate(batch_actions[i])) + else: + observation = agents[index].observe(validate(batch_actions[i])) if observation is None: raise ValueError('Agents should return what they observed.') batch_observations.append(observation) @@ -610,10 +613,10 @@ def parley(self): for i, w in enumerate(self.worlds): if hasattr(w, 'execute'): w.execute(batch_actions[i]) - # All agents observe the results. + # All agents (might) observe the results. for other_index in range(num_agents): batch_observations[other_index] = ( - self.batch_observe(other_index, batch_act, index)) + self.batch_observe(other_index, batch_act)) def display(self): From 3b0dfaab3af9de25df00552d28c9c1a5e7b2615d Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 16:22:57 -0400 Subject: [PATCH 05/10] mm --- parlai/core/worlds.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index 07bee82aca8..ef0029835ff 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -615,9 +615,9 @@ def parley(self): w.execute(batch_actions[i]) # All agents (might) observe the results. for other_index in range(num_agents): - batch_observations[other_index] = ( - self.batch_observe(other_index, batch_act)) - + obs = self.batch_observe(other_index, batch_act)) + if obs is not None: + batch_observations[other_index] = obs def display(self): s = ("[--batchsize " + str(len(self.worlds)) + "--]\n") From a26101f94a957f0d5c55d97283278bdad375db48 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 16:25:46 -0400 Subject: [PATCH 06/10] index --- parlai/core/worlds.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index ef0029835ff..d6a9d758365 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -562,16 +562,18 @@ def __next__(self): if self.epoch_done(): raise StopIteration() - def batch_observe(self, index, batch_actions) + def batch_observe(self, index, batch_actions, index_acting) batch_observations = [] for i, w in enumerate(self.worlds): agents = w.get_agents() if hasattr(w, 'observe'): observation = w.observe(agents[index], validate(batch_actions[i])) else: + if index == index_acting: + return None observation = agents[index].observe(validate(batch_actions[i])) - if observation is None: - raise ValueError('Agents should return what they observed.') + if observation is None: + raise ValueError('Agents should return what they observed.') batch_observations.append(observation) return batch_observations @@ -615,7 +617,7 @@ def parley(self): w.execute(batch_actions[i]) # All agents (might) observe the results. for other_index in range(num_agents): - obs = self.batch_observe(other_index, batch_act)) + obs = self.batch_observe(other_index, batch_act, index)) if obs is not None: batch_observations[other_index] = obs From 9844660b990bbb6ef05a405e3d830b4c865d861e Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 16:34:52 -0400 Subject: [PATCH 07/10] index --- parlai/core/worlds.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index d6a9d758365..b682f772907 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -562,7 +562,7 @@ def __next__(self): if self.epoch_done(): raise StopIteration() - def batch_observe(self, index, batch_actions, index_acting) + def batch_observe(self, index, batch_actions, index_acting): batch_observations = [] for i, w in enumerate(self.worlds): agents = w.get_agents() @@ -617,7 +617,7 @@ def parley(self): w.execute(batch_actions[i]) # All agents (might) observe the results. for other_index in range(num_agents): - obs = self.batch_observe(other_index, batch_act, index)) + obs = self.batch_observe(other_index, batch_act, index) if obs is not None: batch_observations[other_index] = obs From 46104957223b0e8052326fe9fd6efec46f2c1416 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 17:53:16 -0400 Subject: [PATCH 08/10] index --- parlai/core/worlds.py | 102 +++++++++++++++--------------------------- 1 file changed, 37 insertions(+), 65 deletions(-) diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index b682f772907..a5300570fa9 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -315,85 +315,51 @@ def shutdown(self): a.shutdown() -class ExecutableWorld(World): +class ExecutableWorld(MultiAgentDialogWorld): """A world where messages from agents can be interpreted as _actions_ in the world which result in changes in the environment (are executed). Hence a grounded simulation can be implemented rather than just dialogue between agents. """ def __init__(self, opt, agents=None, shared=None): - super().__init__(opt) - if shared: - # Create agents based on shared data. - self.agents = create_agents_from_shared(shared['agents']) - else: - # Add passed in agents directly. - self.agents = agents - self.acts = [None] * len(agents) super().__init__(opt, agents, shared) self.init_world() def init_world(self): """An executable world class should implement this function, otherwise - the world still works, but actions do not do anything. + the actions do not do anything (and it is the same as MultiAgentDialogWorld). """ pass - def gen_observe(self, agent): - """ Generate an observation of the world for an agent - given the current state of the (executable) world. - This is differentiated from a message from another dialogue agent, - and hence has the id 'world' in the message. - """ - msg = {} - msg['text'] = '' # By default the world does nothing. - msg['id'] = 'world' - agent.observe(validate(msg)) - return msg - def execute(self, agent, act): - # Execute action from agent. We also send an update to all other agents - # that can observe the change. - if 'text' in act: - valid = self.g.parse_exec(agent.id, act['text']) - if not valid: - agent.observe({'id':'world', 'text':'invalid action'}) - for index, agent in enumerate(self.agents): - acts[index] = agent.act() - for other_agent in self.agents: - if other_agent != agent: - other_agent.observe(validate(acts[index])) - + """An executable world class should implement this function, otherwise + the actions do not do anything (and it is the same as MultiAgentDialogWorld). + """ + pass + def observe(self, agent, act): + """An executable world class should implement this function, otherwise + the observations for each agent are just the messages from other agents + and not confitioned on the world at all (and it is thus the same as + MultiAgentDialogWorld). """ + if agent.id == act['id']: + return None + else: + return act def parley(self): - """For each agent: observe, act, execute action in world + """For each agent: act, execute and observe actions in world """ - acts, agents = self.acts, self.agents - for index, agent in enumerate(agents): + acts = self.acts + for index, agent in enumerate(self.agents): + # The agent acts. acts[index] = agent.act() - # execute action in environment + # We execute this action in the world. self.execute(agent, acts[index]) - - def epoch_done(self): - done = False - for a in self.agents: - if a.epoch_done(): - done = True - return done - - def episode_done(self): - done = False - for a in self.agents: - if a.episode_done(): - done = True - return done - - def report(self): - return self.agents[0].report() - - def shutdown(self): - for a in self.agents: - a.shutdown() + # All agents (might) observe the results. + for other_agent in self.agents: + obs = self.observe(other_agent, acts[index]) + if obs is not None: + other_agent.observe(obs) class MultiWorld(World): @@ -538,7 +504,7 @@ class BatchWorld(World): """Creates a separate world for each item in the batch, sharing the parameters for each. The underlying world(s) it is batching can be either ``DialogPartnerWorld``, - ``MultiAgentWorld`` or ``MultiWorld``. + ``MultiAgentWorld``, ``ExecutableWorld`` or ``MultiWorld``. """ def __init__(self, opt, world): @@ -567,13 +533,19 @@ def batch_observe(self, index, batch_actions, index_acting): for i, w in enumerate(self.worlds): agents = w.get_agents() if hasattr(w, 'observe'): + # The world has its own observe function, which the action + # first goes through (agents do not directly receive messages + # from each other). observation = w.observe(agents[index], validate(batch_actions[i])) else: - if index == index_acting: + observation = validate(batch_actions[i]) + # An agent does not send a message to itself, but we do allow + # the world to send a message to it after the agent acts. + if index == index_acting and observation.id != 'world': return None - observation = agents[index].observe(validate(batch_actions[i])) - if observation is None: - raise ValueError('Agents should return what they observed.') + observation = agents[index].observe(observation) + if observation is None: + raise ValueError('Agents should return what they observed.') batch_observations.append(observation) return batch_observations @@ -614,7 +586,7 @@ def parley(self): # We possibly execute this action in the world. for i, w in enumerate(self.worlds): if hasattr(w, 'execute'): - w.execute(batch_actions[i]) + w.execute(w.agents[i], batch_act[i]) # All agents (might) observe the results. for other_index in range(num_agents): obs = self.batch_observe(other_index, batch_act, index) From d874e6106ba8d92e1b9109294d140e89e89756bd Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 18:03:46 -0400 Subject: [PATCH 09/10] small batch fixes --- parlai/core/worlds.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index a5300570fa9..0d59b776ae2 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -532,20 +532,18 @@ def batch_observe(self, index, batch_actions, index_acting): batch_observations = [] for i, w in enumerate(self.worlds): agents = w.get_agents() + observation = None if hasattr(w, 'observe'): # The world has its own observe function, which the action - # first goes through (agents do not directly receive messages - # from each other). + # first goes through (agents receive messages via the world, + # not from each other). observation = w.observe(agents[index], validate(batch_actions[i])) else: + if index == index_acting: return None # don't observe yourself talking observation = validate(batch_actions[i]) - # An agent does not send a message to itself, but we do allow - # the world to send a message to it after the agent acts. - if index == index_acting and observation.id != 'world': - return None + if observation is None: + raise ValueError('Agents should return what they observed.') observation = agents[index].observe(observation) - if observation is None: - raise ValueError('Agents should return what they observed.') batch_observations.append(observation) return batch_observations From 2b2200196523c8bdba9b4a4ff28e467b38876be1 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Mon, 26 Jun 2017 18:05:39 -0400 Subject: [PATCH 10/10] small batch fixes --- parlai/core/worlds.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index 0d59b776ae2..a88d1bfa930 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -541,9 +541,9 @@ def batch_observe(self, index, batch_actions, index_acting): else: if index == index_acting: return None # don't observe yourself talking observation = validate(batch_actions[i]) - if observation is None: - raise ValueError('Agents should return what they observed.') observation = agents[index].observe(observation) + if observation is None: + raise ValueError('Agents should return what they observed.') batch_observations.append(observation) return batch_observations