diff --git a/examples/display_data.py b/examples/display_data.py
index aab08dda37f..677359ca169 100644
--- a/examples/display_data.py
+++ b/examples/display_data.py
@@ -17,15 +17,7 @@
 import random
 
 
-
-def main():
-    random.seed(42)
-
-    # Get command line arguments
-    parser = ParlaiParser()
-    parser.add_argument('-n', '--num-examples', default=10, type=int)
-    opt = parser.parse_args()
-
+def display_data(opt):
     # create repeat label agent and assign it to the specified task
     agent = RepeatLabelAgent(opt)
     world = create_task(opt, agent)
@@ -40,5 +32,15 @@ def main():
                 break
 
 
+def main():
+    random.seed(42)
+
+    # Get command line arguments
+    parser = ParlaiParser()
+    parser.add_argument('-n', '--num-examples', default=10, type=int)
+    opt = parser.parse_args()
+
+    display_data(opt)
+
 if __name__ == '__main__':
     main()
diff --git a/examples/eval_model.py b/examples/eval_model.py
index 95a02c5ebd6..f79ca2f3b1c 100644
--- a/examples/eval_model.py
+++ b/examples/eval_model.py
@@ -17,24 +17,18 @@
 import random
 
 
-def main():
-    random.seed(42)
-    # Get command line arguments
-    parser = ParlaiParser(True, True)
-    parser.add_argument('-n', '--num-examples', default=100000000)
-    parser.add_argument('-d', '--display-examples', type='bool', default=False)
-    parser.set_defaults(datatype='valid')
-    opt = parser.parse_args(print_args=False)
+def eval_model(opt, parser, printargs=True):
     # Create model and assign it to the specified task
     agent = create_agent(opt)
     world = create_task(opt, agent)
 
     # Show arguments after loading model
     parser.opt = agent.opt
-    parser.print_args()
+    if (printargs):
+        parser.print_args()
 
     # Show some example dialogs:
-    for k in range(int(opt['num_examples'])):
+    for _ in range(int(opt['num_examples'])):
         world.parley()
         print("---")
         if opt['display_examples']:
@@ -45,5 +39,17 @@ def main():
             break
     world.shutdown()
 
+def main():
+    random.seed(42)
+
+    # Get command line arguments
+    parser = ParlaiParser(True, True)
+    parser.add_argument('-n', '--num-examples', default=100000000)
+    parser.add_argument('-d', '--display-examples', type='bool', default=False)
+    parser.set_defaults(datatype='valid')
+    opt = parser.parse_args(print_args=False)
+
+    eval_model(opt, parser)
+
 if __name__ == '__main__':
     main()
diff --git a/parlai/agents/repeat_label/repeat_label.py b/parlai/agents/repeat_label/repeat_label.py
index bce9c23b226..f83e3e15f6d 100644
--- a/parlai/agents/repeat_label/repeat_label.py
+++ b/parlai/agents/repeat_label/repeat_label.py
@@ -36,9 +36,8 @@ def act(self):
             return {'text': 'Nothing to repeat yet.'}
         reply = {}
         reply['id'] = self.getID()
-        if ('labels' in obs and obs['labels'] is not None
-                and len(obs['labels']) > 0):
-            labels = obs['labels']
+        labels = obs.get('labels', obs.get('eval_labels', None))
+        if labels:
             if random.random() >= self.cantAnswerPercent:
                 if self.returnOneRandomAnswer:
                     reply['text'] = labels[random.randrange(len(labels))]
diff --git a/tests/run_tests_short.sh b/tests/run_tests_short.sh
index ecbad8af794..64b5ef2e4cb 100755
--- a/tests/run_tests_short.sh
+++ b/tests/run_tests_short.sh
@@ -13,3 +13,5 @@ python3 test_dict.py
 python3 test_tasklist.py
 python3 test_threadutils.py
 python3 test_utils.py
+python3 test_display_data.py
+python3 test_eval_model.py
diff --git a/tests/test_display_data.py b/tests/test_display_data.py
new file mode 100644
index 00000000000..3703366bdcc
--- /dev/null
+++ b/tests/test_display_data.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+from examples.display_data import display_data
+from parlai.core.params import ParlaiParser
+
+import sys
+import unittest
+
+
+class TestDisplayData(unittest.TestCase):
+    """Basic tests on the display_data.py example."""
+
+    args = [
+        '--task', 'babi:task1k:1',
+    ]
+    parser = ParlaiParser()
+    opt = parser.parse_args(args, print_args=False)
+    opt['num_examples'] = 1
+
+    def test_output(self):
+        """Does display_data reach the end of the loop?"""
+
+        class display_output(object):
+            def __init__(self):
+                self.data = []
+
+            def write(self, s):
+                self.data.append(s)
+
+            def __str__(self):
+                return "".join(self.data)
+
+        old_out = sys.stdout
+        output = display_output()
+        try:
+            sys.stdout = output
+            display_data(self.opt)
+        finally:
+            # restore sys.stdout
+            sys.stdout = old_out
+
+        str_output = str(output)
+        self.assertTrue(len(str_output) > 0, "Output is empty")
+        self.assertTrue("[babi:task1k:1]:" in str_output,
+                        "Babi task did not print")
+        self.assertTrue("~~" in str_output, "Example output did not complete")
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_eval_model.py b/tests/test_eval_model.py
new file mode 100644
index 00000000000..94c32bb69fc
--- /dev/null
+++ b/tests/test_eval_model.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+from examples.eval_model import eval_model
+from parlai.core.params import ParlaiParser
+
+import ast
+import unittest
+import sys
+
+
+class TestEvalModel(unittest.TestCase):
+    """Basic tests on the eval_model.py example."""
+
+    args = [
+        '--task', '#moviedd-reddit',
+        '--datatype', 'valid',
+    ]
+
+    parser = ParlaiParser()
+    parser.set_defaults(datatype='valid')
+    opt = parser.parse_args(args, print_args=False)
+    opt['model'] = 'repeat_label'
+    opt['num_examples'] = 5
+    opt['display_examples'] = False
+
+    def test_output(self):
+        """Test output of running eval_model"""
+        class display_output(object):
+            def __init__(self):
+                self.data = []
+
+            def write(self, s):
+                self.data.append(s)
+
+            def __str__(self):
+                return "".join(self.data)
+
+        old_out = sys.stdout
+        output = display_output()
+        try:
+            sys.stdout = output
+            eval_model(self.opt, self.parser, printargs=False)
+        finally:
+            # restore sys.stdout
+            sys.stdout = old_out
+
+        str_output = str(output)
+        self.assertTrue(len(str_output) > 0, "Output is empty")
+
+        # decode the output
+        scores = str_output.split("\n---\n")
+        for i in range(1, len(scores)):
+            score = ast.literal_eval(scores[i])
+            # check totals
+            self.assertTrue(score['total'] == i,
+                            "Total is incorrect")
+            # accuracy should be one
+            self.assertTrue(score['accuracy'] == 1,
+                            "accuracy != 1")
+
+if __name__ == '__main__':
+    unittest.main()
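
Aside, not part of the patch: a minimal sketch of how the extracted display_data() entry point can be driven programmatically, mirroring what tests/test_display_data.py above does. It assumes ParlAI is importable from the working directory and that the bAbI task data can be downloaded; the argument names and calls come straight from the new test.

# Illustration only (not part of the patch): call the refactored
# display_data() directly, the same way tests/test_display_data.py does.
# Assumes ParlAI is on the path and the babi:task1k:1 data is available.
from examples.display_data import display_data
from parlai.core.params import ParlaiParser

parser = ParlaiParser()
opt = parser.parse_args(['--task', 'babi:task1k:1'], print_args=False)
opt['num_examples'] = 1
display_data(opt)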