Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

remote agent fixes, switch model param to default None #166

Merged
merged 4 commits into from
Jun 26, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/display_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def main():
parser = ParlaiParser(True, True)
parser.add_argument('-n', '--num-examples', default=10)
opt = parser.parse_args()

# Create model and assign it to the specified task
agent = create_agent(opt)
world = create_task(opt, agent)
Expand Down
52 changes: 52 additions & 0 deletions examples/remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2004-present Facebook. All Rights Reserved.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Simple loop which sets up a remote connection. The paired agent can run this
same loop but with the '--remote-host' flag set. For example...

Agent 1:
python remote.py

Agent 2:
python remote.py --remote-host --remote-address '*'

Now humans connected to each agent can communicate over that thread.
If you would like to use a model instead, merely set the '--model' flag:

Either Agent (or both):
python remote.py -m seq2seq
"""

from parlai.agents.remote_agent.remote_agent import RemoteAgentAgent
from parlai.agents.local_human.local_human import LocalHumanAgent
from parlai.core.params import ParlaiParser
from parlai.core.agents import create_agent
from parlai.core.worlds import DialogPartnerWorld

import random

def main():
random.seed(42)

# Get command line arguments
parser = ParlaiParser(True, True)
RemoteAgentAgent.add_cmdline_args(parser)
opt = parser.parse_args()

if opt.get('model'):
local = create_agent(opt)
else:
local = LocalHumanAgent(opt)
remote = RemoteAgentAgent(opt)
agents = [local, remote] if not opt['remote_host'] else [remote, local]
world = DialogPartnerWorld(opt, agents)

# Talk to the remote agent
with world:
while True:
world.parley()

if __name__ == '__main__':
main()
68 changes: 44 additions & 24 deletions parlai/agents/remote_agent/remote_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@
import zmq


class RemoteAgent(Agent):
class RemoteAgentAgent(Agent):
"""Agent which connects over ZMQ to a paired agent. The other agent is
launched using the command line options set via `add_cmdline_args`."""

@staticmethod
def add_cmdline_args(argparser):
argparser.add_arg(
argparser.add_argument(
'--port', default=5555,
help='first port to connect to for remote agents')
argparser.add_arg(
'--remote-cmd', required=True,
help='command to launch paired agent')
argparser.add_arg(
argparser.add_argument(
'--remote-address', default='localhost',
help='address to connect to')
argparser.add_argument(
'--remote-host', action='store_true',
help='whether or not this connection is the host or the client')
argparser.add_argument(
'--remote-cmd',
help='command to launch paired agent, if applicable')
argparser.add_argument(
'--remote-args',
help='optional arguments to pass to paired agent')

Expand All @@ -36,41 +42,55 @@ def __init__(self, opt, shared=None):
subprocess.Popen for each thread.)
"""
if shared and 'port' in shared:
# for multithreading
self.port = shared['port']
self.address = shared['address']
self.opt = copy.deepcopy(shared['opt'])
else:
self.opt = copy.deepcopy(opt)
if 'port' in opt:
self.port = opt['port']
self.address = opt['remote_address']
self.socket_type = zmq.REP if opt['remote_host'] else zmq.REQ
else:
raise RuntimeError('You need to run RemoteAgent.' +
'add_cmdline_args(argparser) before ' +
'calling this class to set up options.')
self.process = subprocess.Popen(
'{cmd} {port} {numthreads} {args}'.format(
cmd=opt['remote_cmd'], port=opt['port'],
numthreads=opt['numthreads'],
args=opt.get('remote_args', '')
).split()
)
self.opt = copy.deepcopy(opt)
if opt.get('remote_cmd'):
# if available, command to launch partner instance, passing on
# some shared parameters from ParlAI
# useful especially if "remote" agent is running locally, e.g.
# in a different language than python
self.process = subprocess.Popen(
'{cmd} {port} {numthreads} {args}'.format(
cmd=opt['remote_cmd'], port=opt['port'],
numthreads=opt['numthreads'],
args=opt.get('remote_args', '')
).split()
)
self.connect()
super().__init__(opt, shared)

def connect(self):
"""Connect to ZMQ socket as client. Requires package zmq."""
context = zmq.Context()
self.socket = context.socket(zmq.REQ)
self.socket = context.socket(self.socket_type)
self.socket.setsockopt(zmq.LINGER, 1)
self.socket.connect('tcp://localhost:{0}'.format(self.port))
print('python thread connected to ' +
'tcp://localhost:{0}'.format(self.port))
host = 'tcp://{}:{}'.format(self.address, self.port)
if self.socket_type == zmq.REP:
self.socket.bind(host)
else:
self.socket.connect(host)
print('python thread connected to ' + host)

def act(self):
"""Send message to paired agent listening over zmq."""
if 'image' in self.observation:
# can't json serialize images
self.observation.pop('image', None)
text = json.dumps(self.observation)
self.socket.send_unicode(text)
if self.observation is not None:
if 'image' in self.observation:
# can't json serialize images
self.observation.pop('image', None)
text = json.dumps(self.observation)
self.socket.send_unicode(text)
reply = self.socket.recv_unicode()
return json.loads(reply)

Expand Down Expand Up @@ -106,7 +126,7 @@ def shutdown(self):
self.process.kill()


class ParsedRemoteAgent(RemoteAgent):
class ParsedRemoteAgent(RemoteAgentAgent):
"""Same as the regular remote agent, except that this agent converts all
text into vectors using its dictionary before sending them.
"""
Expand Down
2 changes: 2 additions & 0 deletions parlai/agents/seq2seq/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def __init__(self, opt, shared=None):
print('[ Using CUDA ]')
torch.cuda.set_device(opt['gpu'])
if not shared:
# don't enter this loop for shared (ie batch) instantiations
self.dict = DictionaryAgent(opt)
self.id = 'Seq2Seq'
hsz = opt['hiddensize']
self.EOS = self.dict.eos_token
self.observation = {'text': self.EOS, episode_done = True}
self.EOS_TENSOR = torch.LongTensor(self.dict.parse(self.EOS))
self.hidden_size = hsz
self.num_layers = opt['numlayers']
Expand Down
7 changes: 5 additions & 2 deletions parlai/core/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,11 @@ def create_agent(opt):
(i.e. the path followed by the class name) or else just ``ir_baseline`` which
assumes the path above, and a class name suffixed with 'Agent'.
"""
model_class = get_agent_module(opt['model'])
return model_class(opt)
if opt.get('model'):
model_class = get_agent_module(opt['model'])
return model_class(opt)
else:
raise RuntimeError('Need to set `model` argument to use create_agent.')

# Helper functions to create agent/agents given shared parameters
# returned from agent.share(). Useful for parallelism, sharing params, etc.
Expand Down
2 changes: 1 addition & 1 deletion parlai/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def add_parlai_args(self):
def add_model_args(self, args=None):
model_args = self.add_argument_group('ParlAI Model Arguments')
model_args.add_argument(
'-m', '--model', default='repeat_label',
'-m', '--model', default=None,
help='the model class name, should match parlai/agents/<model>')
model_args.add_argument(
'-mf', '--model-file', default=None,
Expand Down