This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
ExecutableWorld, designed to also work with BatchWorld #170
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
25c24c9
small
jaseweston 38d4324
exec world
jaseweston 4d0380d
small
jaseweston 3666e57
blah
jaseweston 3b0dfaa
mm
jaseweston a26101f
index
jaseweston 9844660
index
jaseweston 4610495
index
jaseweston d874e61
small batch fixes
jaseweston 2b22001
small batch fixes
jaseweston File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -268,7 +268,7 @@ def shutdown(self): | |
|
||
class MultiAgentDialogWorld(World): | ||
"""Basic world where each agent gets a turn in a round-robin fashion, | ||
recieving as input the actions of all other agents since that agent last | ||
receiving as input the actions of all other agents since that agent last | ||
acted. | ||
""" | ||
def __init__(self, opt, agents=None, shared=None): | ||
|
@@ -315,6 +315,53 @@ def shutdown(self): | |
a.shutdown() | ||
|
||
|
||
class ExecutableWorld(MultiAgentDialogWorld): | ||
"""A world where messages from agents can be interpreted as _actions_ in the | ||
world which result in changes in the environment (are executed). Hence a grounded | ||
simulation can be implemented rather than just dialogue between agents. | ||
""" | ||
def __init__(self, opt, agents=None, shared=None): | ||
super().__init__(opt, agents, shared) | ||
self.init_world() | ||
|
||
def init_world(self): | ||
"""An executable world class should implement this function, otherwise | ||
the actions do not do anything (and it is the same as MultiAgentDialogWorld). | ||
""" | ||
pass | ||
|
||
def execute(self, agent, act): | ||
"""An executable world class should implement this function, otherwise | ||
the actions do not do anything (and it is the same as MultiAgentDialogWorld). | ||
""" | ||
pass | ||
|
||
def observe(self, agent, act): | ||
"""An executable world class should implement this function, otherwise | ||
the observations for each agent are just the messages from other agents | ||
and not confitioned on the world at all (and it is thus the same as | ||
MultiAgentDialogWorld). """ | ||
if agent.id == act['id']: | ||
return None | ||
else: | ||
return act | ||
|
||
def parley(self): | ||
"""For each agent: act, execute and observe actions in world | ||
""" | ||
acts = self.acts | ||
for index, agent in enumerate(self.agents): | ||
# The agent acts. | ||
acts[index] = agent.act() | ||
# We execute this action in the world. | ||
self.execute(agent, acts[index]) | ||
# All agents (might) observe the results. | ||
for other_agent in self.agents: | ||
obs = self.observe(other_agent, acts[index]) | ||
if obs is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
other_agent.observe(obs) | ||
|
||
|
||
class MultiWorld(World): | ||
"""Container for a set of worlds where each world gets a turn | ||
in a round-robin fashion. The same user_agents are placed in each, | ||
|
@@ -457,7 +504,7 @@ class BatchWorld(World): | |
"""Creates a separate world for each item in the batch, sharing | ||
the parameters for each. | ||
The underlying world(s) it is batching can be either ``DialogPartnerWorld``, | ||
``MultiAgentWorld`` or ``MultiWorld``. | ||
``MultiAgentWorld``, ``ExecutableWorld`` or ``MultiWorld``. | ||
""" | ||
|
||
def __init__(self, opt, world): | ||
|
@@ -481,11 +528,20 @@ def __next__(self): | |
if self.epoch_done(): | ||
raise StopIteration() | ||
|
||
def batch_observe(self, index, batch_actions): | ||
def batch_observe(self, index, batch_actions, index_acting): | ||
batch_observations = [] | ||
for i, w in enumerate(self.worlds): | ||
agents = w.get_agents() | ||
observation = agents[index].observe(validate(batch_actions[i])) | ||
observation = None | ||
if hasattr(w, 'observe'): | ||
# The world has its own observe function, which the action | ||
# first goes through (agents receive messages via the world, | ||
# not from each other). | ||
observation = w.observe(agents[index], validate(batch_actions[i])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this have a chance of returning |
||
else: | ||
if index == index_acting: return None # don't observe yourself talking | ||
observation = validate(batch_actions[i]) | ||
observation = agents[index].observe(observation) | ||
if observation is None: | ||
raise ValueError('Agents should return what they observed.') | ||
batch_observations.append(observation) | ||
|
@@ -523,11 +579,17 @@ def parley(self): | |
w.parley_init() | ||
|
||
for index in range(num_agents): | ||
# The agent acts. | ||
batch_act = self.batch_act(index, batch_observations[index]) | ||
# We possibly execute this action in the world. | ||
for i, w in enumerate(self.worlds): | ||
if hasattr(w, 'execute'): | ||
w.execute(w.agents[i], batch_act[i]) | ||
# All agents (might) observe the results. | ||
for other_index in range(num_agents): | ||
if index != other_index: | ||
batch_observations[other_index] = ( | ||
self.batch_observe(other_index, batch_act)) | ||
obs = self.batch_observe(other_index, batch_act, index) | ||
if obs is not None: | ||
batch_observations[other_index] = obs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if you need to override this always to make sure that you don't accidentally view a stale message? maybe better to fill it with |
||
|
||
def display(self): | ||
s = ("[--batchsize " + str(len(self.worlds)) + "--]\n") | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
None or {}?