Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

ExecutableWorld, designed to also work with BatchWorld #170

Merged
merged 10 commits into from
Jun 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions parlai/agents/repeat_label/repeat_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(self, opt, shared=None):

def act(self):
obs = self.observation
if obs is None:
return { 'text': "Nothing to repeat yet." }
reply = {}
reply['id'] = self.getID()
if ('labels' in obs and obs['labels'] is not None
Expand Down
76 changes: 69 additions & 7 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def shutdown(self):

class MultiAgentDialogWorld(World):
"""Basic world where each agent gets a turn in a round-robin fashion,
recieving as input the actions of all other agents since that agent last
receiving as input the actions of all other agents since that agent last
acted.
"""
def __init__(self, opt, agents=None, shared=None):
Expand Down Expand Up @@ -315,6 +315,53 @@ def shutdown(self):
a.shutdown()


class ExecutableWorld(MultiAgentDialogWorld):
    """A world where messages from agents can be interpreted as _actions_ in
    the world which result in changes in the environment (are executed).
    Hence a grounded simulation can be implemented rather than just dialogue
    between agents.

    Subclasses customize the simulation through three hooks:

    * ``init_world`` -- called once at the end of ``__init__``.
    * ``execute`` -- called with every action right after an agent acts.
    * ``observe`` -- decides what each agent gets to see of an action.

    With the default hooks, ``parley`` simply forwards every action to all
    agents except the one that produced it.
    """

    def __init__(self, opt, agents=None, shared=None):
        """Build the underlying dialog world, then set up the world state.

        The arguments are passed through unchanged to the parent constructor;
        ``init_world`` runs only after the parent has finished initializing.
        """
        super().__init__(opt, agents, shared)
        self.init_world()

    def init_world(self):
        """An executable world class should implement this function to set up
        the initial state of the environment. The default does nothing, so
        the world starts out with no state of its own.
        """
        pass

    def execute(self, agent, act):
        """An executable world class should implement this function, otherwise
        the actions do not do anything (and it is the same as
        MultiAgentDialogWorld).

        ``agent`` is the agent that just acted and ``act`` is the message it
        returned from ``act()``. The return value is ignored by ``parley``.
        """
        pass

    def observe(self, agent, act):
        """An executable world class should implement this function, otherwise
        the observations for each agent are just the messages from other
        agents and not conditioned on the world at all (and it is thus the
        same as MultiAgentDialogWorld).

        Returns the message ``agent`` should observe as a result of ``act``,
        or ``None`` if ``agent`` should not observe anything (by default: when
        ``agent`` is the author of ``act``).
        """
        # Not every message carries an 'id' (e.g. a bare {'text': ...} reply),
        # so guard the lookup instead of raising KeyError. A message without
        # an 'id' cannot be attributed to anyone and is shown to every agent.
        if 'id' in act and agent.id == act['id']:
            # NOTE(review): a reviewer asked "None or {}?" here. None is kept
            # because parley() below tests `is not None` to skip delivery.
            return None
        return act

    def parley(self):
        """For each agent: act, execute and observe actions in world.
        """
        # self.acts is not assigned in this class; it is expected to be
        # provided by the parent class.
        acts = self.acts
        for index, agent in enumerate(self.agents):
            # The agent acts.
            acts[index] = agent.act()
            # We execute this action in the world.
            self.execute(agent, acts[index])
            # All agents (might) observe the results.
            for other_agent in self.agents:
                obs = self.observe(other_agent, acts[index])
                # NOTE(review): if observe() is ever changed to return {}
                # instead of None, this should become a truthiness test
                # (`if obs:`); either form was acceptable to the reviewer.
                if obs is not None:
                    other_agent.observe(obs)


class MultiWorld(World):
"""Container for a set of worlds where each world gets a turn
in a round-robin fashion. The same user_agents are placed in each,
Expand Down Expand Up @@ -457,7 +504,7 @@ class BatchWorld(World):
"""Creates a separate world for each item in the batch, sharing
the parameters for each.
The underlying world(s) it is batching can be either ``DialogPartnerWorld``,
``MultiAgentWorld`` or ``MultiWorld``.
``MultiAgentWorld``, ``ExecutableWorld`` or ``MultiWorld``.
"""

def __init__(self, opt, world):
Expand All @@ -481,11 +528,20 @@ def __next__(self):
if self.epoch_done():
raise StopIteration()

def batch_observe(self, index, batch_actions):
def batch_observe(self, index, batch_actions, index_acting):
batch_observations = []
for i, w in enumerate(self.worlds):
agents = w.get_agents()
observation = agents[index].observe(validate(batch_actions[i]))
observation = None
if hasattr(w, 'observe'):
# The world has its own observe function, which the action
# first goes through (agents receive messages via the world,
# not from each other).
observation = w.observe(agents[index], validate(batch_actions[i]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this have a chance of returning None?

else:
if index == index_acting: return None # don't observe yourself talking
observation = validate(batch_actions[i])
observation = agents[index].observe(observation)
if observation is None:
raise ValueError('Agents should return what they observed.')
batch_observations.append(observation)
Expand Down Expand Up @@ -523,11 +579,17 @@ def parley(self):
w.parley_init()

for index in range(num_agents):
# The agent acts.
batch_act = self.batch_act(index, batch_observations[index])
# We possibly execute this action in the world.
for i, w in enumerate(self.worlds):
if hasattr(w, 'execute'):
w.execute(w.agents[i], batch_act[i])
# All agents (might) observe the results.
for other_index in range(num_agents):
if index != other_index:
batch_observations[other_index] = (
self.batch_observe(other_index, batch_act))
obs = self.batch_observe(other_index, batch_act, index)
if obs is not None:
batch_observations[other_index] = obs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if you need to always overwrite this, to make sure that you don't accidentally view a stale message? Maybe better to fill it with {}.


def display(self):
s = ("[--batchsize " + str(len(self.worlds)) + "--]\n")
Expand Down