Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

remote agent fixes, switch model param to default None #166

Merged
merged 4 commits into from
Jun 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/display_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def main():
parser = ParlaiParser(True, True)
parser.add_argument('-n', '--num-examples', default=10)
opt = parser.parse_args()

# Create model and assign it to the specified task
agent = create_agent(opt)
world = create_task(opt, agent)
Expand Down
34 changes: 16 additions & 18 deletions examples/memnn_luatorch_cpu/full_task_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,22 @@ def main():
if not opt.get('dict_file'):
# build dictionary since we didn't load it
ordered_opt = copy.deepcopy(opt)
for datatype in ['train:ordered', 'valid']:
# we use train and valid sets to build dictionary
ordered_opt['datatype'] = datatype
ordered_opt['numthreads'] = 1
world_dict = create_task(ordered_opt, dictionary)

print('Dictionary building on {} data.'.format(datatype))
cnt = 0
# pass examples to dictionary
for _ in world_dict:
cnt += 1
if cnt > opt['dict_max_exs'] and opt['dict_max_exs'] > 0:
print('Processed {} exs, moving on.'.format(
opt['dict_max_exs']))
# don't wait too long...
break

world_dict.parley()
ordered_opt['datatype'] = 'train:ordered'
ordered_opt['numthreads'] = 1
world_dict = create_task(ordered_opt, dictionary)

print('Dictionary building on training data.')
cnt = 0
# pass examples to dictionary
for _ in world_dict:
cnt += 1
if cnt > opt['dict_max_exs'] and opt['dict_max_exs'] > 0:
print('Processed {} exs, moving on.'.format(
opt['dict_max_exs']))
# don't wait too long...
break

world_dict.parley()

# we need to save the dictionary to load it in memnn (sort it by freq)
dictionary.save('/tmp/dict.txt', sort=True)
Expand Down
52 changes: 52 additions & 0 deletions examples/remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2004-present Facebook. All Rights Reserved.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Simple loop which sets up a remote connection. The paired agent can run this
same loop but with the '--remote-host' flag set. For example...

Agent 1:
python remote.py

Agent 2:
python remote.py --remote-host --remote-address '*'

Now humans connected to each agent can communicate over that thread.
If you would like to use a model instead, merely set the '--model' flag:

Either Agent (or both):
python remote.py -m seq2seq
"""

from parlai.agents.remote_agent.remote_agent import RemoteAgentAgent
from parlai.agents.local_human.local_human import LocalHumanAgent
from parlai.core.params import ParlaiParser
from parlai.core.agents import create_agent
from parlai.core.worlds import DialogPartnerWorld

import random

def main():
random.seed(42)

# Get command line arguments
parser = ParlaiParser(True, True)
RemoteAgentAgent.add_cmdline_args(parser)
opt = parser.parse_args()

if opt.get('model'):
local = create_agent(opt)
else:
local = LocalHumanAgent(opt)
remote = RemoteAgentAgent(opt)
agents = [local, remote] if not opt['remote_host'] else [remote, local]
world = DialogPartnerWorld(opt, agents)

# Talk to the remote agent
with world:
while True:
world.parley()

if __name__ == '__main__':
main()
69 changes: 44 additions & 25 deletions parlai/agents/remote_agent/remote_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@
import zmq


class RemoteAgent(Agent):
class RemoteAgentAgent(Agent):
"""Agent which connects over ZMQ to a paired agent. The other agent is
launched using the command line options set via `add_cmdline_args`."""

@staticmethod
def add_cmdline_args(argparser):
argparser.add_arg(
argparser.add_argument(
'--port', default=5555,
help='first port to connect to for remote agents')
argparser.add_arg(
'--remote-cmd', required=True,
help='command to launch paired agent')
argparser.add_arg(
argparser.add_argument(
'--remote-address', default='localhost',
help='address to connect to')
argparser.add_argument(
'--remote-host', action='store_true',
help='whether or not this connection is the host or the client')
argparser.add_argument(
'--remote-cmd',
help='command to launch paired agent, if applicable')
argparser.add_argument(
'--remote-args',
help='optional arguments to pass to paired agent')

Expand All @@ -35,42 +41,55 @@ def __init__(self, opt, shared=None):
the multithreading effectively in their environment. (We don't run
subprocess.Popen for each thread.)
"""

self.opt = copy.deepcopy(opt)
self.address = opt['remote_address']
self.socket_type = zmq.REP if opt['remote_host'] else zmq.REQ
if shared and 'port' in shared:
# for multithreading, use specified port
self.port = shared['port']
self.opt = copy.deepcopy(shared['opt'])
else:
if 'port' in opt:
self.port = opt['port']
else:
raise RuntimeError('You need to run RemoteAgent.' +
'add_cmdline_args(argparser) before ' +
'calling this class to set up options.')
self.process = subprocess.Popen(
'{cmd} {port} {numthreads} {args}'.format(
cmd=opt['remote_cmd'], port=opt['port'],
numthreads=opt['numthreads'],
args=opt.get('remote_args', '')
).split()
)
self.opt = copy.deepcopy(opt)
if opt.get('remote_cmd'):
# if available, command to launch partner instance, passing on
# some shared parameters from ParlAI
# useful especially if "remote" agent is running locally, e.g.
# in a different language than python
self.process = subprocess.Popen(
'{cmd} {port} {numthreads} {args}'.format(
cmd=opt['remote_cmd'], port=opt['port'],
numthreads=opt['numthreads'],
args=opt.get('remote_args', '')
).split()
)
self.connect()
super().__init__(opt, shared)

def connect(self):
"""Connect to ZMQ socket as client. Requires package zmq."""
context = zmq.Context()
self.socket = context.socket(zmq.REQ)
self.socket = context.socket(self.socket_type)
self.socket.setsockopt(zmq.LINGER, 1)
self.socket.connect('tcp://localhost:{0}'.format(self.port))
print('python thread connected to ' +
'tcp://localhost:{0}'.format(self.port))
host = 'tcp://{}:{}'.format(self.address, self.port)
if self.socket_type == zmq.REP:
self.socket.bind(host)
else:
self.socket.connect(host)
print('python thread connected to ' + host)

def act(self):
"""Send message to paired agent listening over zmq."""
if 'image' in self.observation:
# can't json serialize images
self.observation.pop('image', None)
text = json.dumps(self.observation)
self.socket.send_unicode(text)
if self.observation is not None:
if 'image' in self.observation:
# can't json serialize images
self.observation.pop('image', None)
text = json.dumps(self.observation)
self.socket.send_unicode(text)
reply = self.socket.recv_unicode()
return json.loads(reply)

Expand Down Expand Up @@ -106,7 +125,7 @@ def shutdown(self):
self.process.kill()


class ParsedRemoteAgent(RemoteAgent):
class ParsedRemoteAgent(RemoteAgentAgent):
"""Same as the regular remote agent, except that this agent converts all
text into vectors using its dictionary before sending them.
"""
Expand Down
2 changes: 2 additions & 0 deletions parlai/agents/seq2seq/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def __init__(self, opt, shared=None):
print('[ Using CUDA ]')
torch.cuda.set_device(opt['gpu'])
if not shared:
# don't enter this loop for shared (ie batch) instantiations
self.dict = DictionaryAgent(opt)
self.id = 'Seq2Seq'
hsz = opt['hiddensize']
self.EOS = self.dict.eos_token
self.observation = {'text': self.EOS, episode_done = True}
self.EOS_TENSOR = torch.LongTensor(self.dict.parse(self.EOS))
self.hidden_size = hsz
self.num_layers = opt['numlayers']
Expand Down
7 changes: 5 additions & 2 deletions parlai/core/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,11 @@ def create_agent(opt):
(i.e. the path followed by the class name) or else just ``ir_baseline`` which
assumes the path above, and a class name suffixed with 'Agent'.
"""
model_class = get_agent_module(opt['model'])
return model_class(opt)
if opt.get('model'):
model_class = get_agent_module(opt['model'])
return model_class(opt)
else:
raise RuntimeError('Need to set `model` argument to use create_agent.')

# Helper functions to create agent/agents given shared parameters
# returned from agent.share(). Useful for parallelism, sharing params, etc.
Expand Down
2 changes: 1 addition & 1 deletion parlai/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def add_parlai_args(self):
def add_model_args(self, args=None):
model_args = self.add_argument_group('ParlAI Model Arguments')
model_args.add_argument(
'-m', '--model', default='repeat_label',
'-m', '--model', default=None,
help='the model class name, should match parlai/agents/<model>')
model_args.add_argument(
'-mf', '--model-file', default=None,
Expand Down
3 changes: 3 additions & 0 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,9 @@ def create_task(opt, user_agents):
see ``parlai/tasks/tasks.py`` and see ``parlai/tasks/task_list.py``
for list of tasks.
"""
if not opt.get('task'):
raise RuntimeError('No task specified. Please select a task with ' +
'--task {task_name}.')
if type(user_agents) != list:
user_agents = [user_agents]

Expand Down