diff --git a/examples/memnn_luatorch_cpu/full_task_train.py b/examples/memnn_luatorch_cpu/full_task_train.py index d4753a16775..8b01aa518d4 100644 --- a/examples/memnn_luatorch_cpu/full_task_train.py +++ b/examples/memnn_luatorch_cpu/full_task_train.py @@ -74,12 +74,13 @@ def main(): world_dict.parley() # we need to save the dictionary to load it in memnn (sort it by freq) + dictionary.sort() dictionary.save('/tmp/dict.txt', sort=True) print('Dictionary ready, moving on to training.') opt['datatype'] = 'train' - agent = ParsedRemoteAgent(opt, {'dictionary': dictionary}) + agent = ParsedRemoteAgent(opt, {'dictionary_shared': dictionary.share()}) world_train = create_task(opt, agent) opt['datatype'] = 'valid' world_valid = create_task(opt, agent) diff --git a/parlai/agents/remote_agent/remote_agent.py b/parlai/agents/remote_agent/remote_agent.py index 3483b36cd4b..e471bcddf93 100644 --- a/parlai/agents/remote_agent/remote_agent.py +++ b/parlai/agents/remote_agent/remote_agent.py @@ -5,6 +5,7 @@ # of patent rights can be found in the PATENTS file in the same directory. from parlai.core.agents import Agent, create_agent_from_shared from parlai.core.dict import DictionaryAgent +import argparse import copy import numpy as np import json @@ -27,20 +28,21 @@ class RemoteAgentAgent(Agent): @staticmethod def add_cmdline_args(argparser): - argparser.add_argument( + remote = argparser.add_argument_group('Remote Agent Args') + remote.add_argument( '--port', default=5555, help='first port to connect to for remote agents') - argparser.add_argument( + remote.add_argument( '--remote-address', default='localhost', help='address to connect to, defaults to localhost for ' + 'connections, overriden with `*` if remote-host is set') - argparser.add_argument( + remote.add_argument( '--remote-host', action='store_true', help='whether or not this connection is the host or the client') - argparser.add_argument( + remote.add_argument( '--remote-cmd', help='command to launch paired agent, if applicable') - argparser.add_argument( + remote.add_argument( '--remote-args', help='optional arguments to pass to paired agent') @@ -140,8 +142,12 @@ class ParsedRemoteAgent(RemoteAgentAgent): @staticmethod def add_cmdline_args(argparser): - super().add_cmdline_args(argparser) - ParsedRemoteAgent.dictionary_class().add_cmdline_args(argparser) + RemoteAgentAgent.add_cmdline_args(argparser) + try: + ParsedRemoteAgent.dictionary_class().add_cmdline_args(argparser) + except argparse.ArgumentError: + # don't freak out if the dictionary has already been added + pass @staticmethod def dictionary_class(): diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index 1f985ee0d81..2c04b310c97 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -357,7 +357,7 @@ def parley(self): for index, agent in enumerate(self.agents): # The agent acts. acts[index] = agent.act() - # We execute this action in the world. + # We execute this action in the world. self.execute(agent, acts[index]) # All agents (might) observe the results. for other_agent in self.agents: