@@ -0,0 +1,97 @@
# Copyright (c) 2017-present, Facebook, Inc.
I'm confused: it is called rnn_baselines (plural), but then you just have train.py and agents.py. I guess they could be moved if we add more RNN baselines? Or we could just give them more specific names now.
I'll update these! Haven't yet.
@@ -0,0 +1,97 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
We need to start making a general train program now. Let's do it after you get this PR in.
class Seq2SeqAgent(Agent):
"""Simple agent which uses an LSTM to process incoming text observations."""
Please explain the architecture a bit more (GRUs, layers, etc.).
parlai/core/dict.py
Outdated
self.freq = SharedTable(self.freq)
self.tok2ind = SharedTable(self.tok2ind)
self.ind2tok = SharedTable(self.ind2tok)
# self.freq = SharedTable(self.freq)
what happened here?
parlai/core/dict.py
Outdated
@@ -294,8 +306,10 @@ def txt2vec(self, text, vec_type=np.ndarray):
(self[token] for token in self.tokenize(str(text))),
np.int
)
else:
elif vec_type == list:
We need to make sure you are not breaking anything else that uses the dict, e.g. the IR baseline.
Actually, it looks like nothing was using this function except the remote agent.
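For reference, the kind of change discussed here (replacing a catch-all `else:` with explicit checks on `vec_type`) could look roughly like the sketch below. `ToyDict` is a hypothetical stand-in, not the ParlAI dictionary class, and the `np.ndarray` branch from the diff is left out so the example has no dependencies. The point is only that an unsupported type fails loudly instead of silently taking the fallback path.

```python
class ToyDict:
    """Hypothetical stand-in for a dictionary agent; not the ParlAI class."""

    def __init__(self, tokens, unk_token='__UNK__'):
        self.unk_token = unk_token
        # Index 0 is reserved for unknown tokens.
        self.tok2ind = {unk_token: 0}
        for tok in tokens:
            self.tok2ind.setdefault(tok, len(self.tok2ind))

    def __getitem__(self, token):
        # Unknown tokens map to the index of the unknown-token marker.
        return self.tok2ind.get(token, self.tok2ind[self.unk_token])

    def tokenize(self, text):
        # Whitespace tokenization, assumed here for simplicity.
        return text.split()

    def txt2vec(self, text, vec_type=list):
        indices = (self[tok] for tok in self.tokenize(str(text)))
        if vec_type in (list, tuple):
            return vec_type(indices)
        # Explicit failure instead of a catch-all else branch.
        raise TypeError('unsupported vec_type: {}'.format(vec_type))


d = ToyDict(['hello', 'world'])
print(d.txt2vec('hello there world'))  # [1, 0, 2]
```

With explicit branches, any caller that relied on the old fallback would surface as a `TypeError` rather than a silent behavior change, which makes the "is anything else using this?" question easier to answer.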
argparser.add_arg('--gpu', type=int, default=-1,
help='which GPU device to use')

def __init__(self, opt, shared=None):
It'd be cool if this model also ranked candidates, which we need for many of the ParlAI tasks. I guess it doesn't yet?
Yeah, this doesn't do that yet! I'd rather check it in before adding that functionality.
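As a rough illustration of what candidate ranking means here, the sketch below sorts a list of candidate replies by a score. Everything in it is hypothetical: `rank_candidates` and `make_overlap_scorer` are not part of this PR, and the word-overlap score is only a placeholder for whatever score the model would assign to each candidate (for a seq2seq model, one common choice is the decoder's likelihood of the candidate given the context).

```python
def rank_candidates(candidates, score_fn):
    """Return candidates sorted from highest to lowest score."""
    return sorted(candidates, key=score_fn, reverse=True)


def make_overlap_scorer(context):
    """Placeholder scorer: counts words shared with the context.

    A real agent would score candidates with the model instead.
    """
    context_words = set(context.lower().split())

    def score(candidate):
        return len(context_words & set(candidate.lower().split()))

    return score


context = 'where is the milk'
cands = ['the kitchen', 'the milk is in the hallway', 'yes']
ranked = rank_candidates(cands, make_overlap_scorer(context))
print(ranked[0])  # the milk is in the hallway
```

Keeping the scorer as a separate function means the ranking step would not need to change if the placeholder were later swapped for a model-based score.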
parlai/core/params.py
Outdated
default_downloads_path = os.path.join(self.parlai_home, 'downloads')
self.parser.add_argument(
'-t', '--task',
help='ParlAI task(s), e.g. "babi:Task1" or "babi,cbt"')
self.parser.add_argument(
'--logpath', default=default_log_path,
still here?
First draft of simple seq2seq model