Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Seq2Seq model #96

Merged
merged 12 commits into from
May 31, 2017
Merged

Seq2Seq model #96

merged 12 commits into from
May 31, 2017

Conversation

alexholdenmiller
Copy link
Member

First draft of simple seq2seq model

@@ -0,0 +1,97 @@
# Copyright (c) 2017-present, Facebook, Inc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm confused it is called rnn_baselines (plural) but then you just have train.py and agents.py, but i guess they could be moved if we add more rnn baselines? or could just give them more specific names now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll update these! haven't yet

@@ -0,0 +1,97 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to start making a general train program now. lets do it after you get this PR in



class Seq2SeqAgent(Agent):
"""Simple agent which uses an LSTM to process incoming text observations."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explain a bit more the architecture please (GRUs, layers etc.)

self.freq = SharedTable(self.freq)
self.tok2ind = SharedTable(self.tok2ind)
self.ind2tok = SharedTable(self.ind2tok)
# self.freq = SharedTable(self.freq)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happened here?

@@ -294,8 +306,10 @@ def txt2vec(self, text, vec_type=np.ndarray):
(self[token] for token in self.tokenize(str(text))),
np.int
)
else:
elif vec_type == list:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to make sure you are not breaking anything else that uses the dict, e.g. IR baseline etc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually nothing was using this function except remote agent, looks like

argparser.add_arg('--gpu', type=int, default=-1,
help='which GPU device to use')

def __init__(self, opt, shared=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

be cool if this model also ranked candidates, which we need for many of the parlAI tasks, i guess it doesnt yet?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this doesn't do that yet! I'd rather check it in before adding that functionality

default_downloads_path = os.path.join(self.parlai_home, 'downloads')
self.parser.add_argument(
'-t', '--task',
help='ParlAI task(s), e.g. "babi:Task1" or "babi,cbt"')
self.parser.add_argument(
'--logpath', default=default_log_path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still here?

@alexholdenmiller alexholdenmiller merged commit 6fe19ce into master May 31, 2017
@alexholdenmiller alexholdenmiller deleted the first_learner branch May 31, 2017 17:57
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants