This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

unit tests for display_data and eval_model #460

Merged 2 commits on Dec 19, 2017
20 changes: 11 additions & 9 deletions examples/display_data.py
@@ -17,15 +17,7 @@
 import random
 
 
-def main():
-    random.seed(42)
-
-    # Get command line arguments
-    parser = ParlaiParser()
-    parser.add_argument('-n', '--num-examples', default=10, type=int)
-    opt = parser.parse_args()
-
+def display_data(opt):
     # create repeat label agent and assign it to the specified task
     agent = RepeatLabelAgent(opt)
     world = create_task(opt, agent)
@@ -40,5 +32,15 @@ def main():
             break
 
 
+def main():
+    random.seed(42)
+
+    # Get command line arguments
+    parser = ParlaiParser()
+    parser.add_argument('-n', '--num-examples', default=10, type=int)
+    opt = parser.parse_args()
+
+    display_data(opt)
+
 if __name__ == '__main__':
     main()
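Factoring the loop out of main() into display_data(opt) is what makes the new unit test possible: the function can be imported and called with a pre-built options dict instead of going through the command line. A minimal sketch of that programmatic use, mirroring the new test (babi:task1k:1 and num_examples are taken from the test; any installed ParlAI task would work the same way):

    from examples.display_data import display_data
    from parlai.core.params import ParlaiParser

    # Build the options dict the same way the new test does.
    parser = ParlaiParser()
    opt = parser.parse_args(['--task', 'babi:task1k:1'], print_args=False)
    opt['num_examples'] = 1

    display_data(opt)   # prints the example(s) to stdout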
26 changes: 16 additions & 10 deletions examples/eval_model.py
@@ -17,24 +17,18 @@
 import random
 
-def main():
-    random.seed(42)
-
-    # Get command line arguments
-    parser = ParlaiParser(True, True)
-    parser.add_argument('-n', '--num-examples', default=100000000)
-    parser.add_argument('-d', '--display-examples', type='bool', default=False)
-    parser.set_defaults(datatype='valid')
-    opt = parser.parse_args(print_args=False)
+def eval_model(opt, parser, printargs=True):
     # Create model and assign it to the specified task
     agent = create_agent(opt)
     world = create_task(opt, agent)
     # Show arguments after loading model
     parser.opt = agent.opt
-    parser.print_args()
+    if (printargs):
+        parser.print_args()
 
     # Show some example dialogs:
-    for k in range(int(opt['num_examples'])):
+    for _ in range(int(opt['num_examples'])):
         world.parley()
         print("---")
         if opt['display_examples']:
@@ -45,5 +39,17 @@ def main():
             break
     world.shutdown()
 
+def main():
+    random.seed(42)
+
+    # Get command line arguments
+    parser = ParlaiParser(True, True)
+    parser.add_argument('-n', '--num-examples', default=100000000)
+    parser.add_argument('-d', '--display-examples', type='bool', default=False)
+    parser.set_defaults(datatype='valid')
+    opt = parser.parse_args(print_args=False)
+
+    eval_model(opt, parser)
+
 if __name__ == '__main__':
     main()
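eval_model(opt, parser, printargs) applies the same refactor: it takes an already-parsed opt plus the parser (kept only so the arguments can still be echoed after the model loads) and can therefore be called from a test without touching sys.argv. A rough usage sketch under the assumptions noted in the comments (the new test targets the #moviedd-reddit task; babi:task1k:1 below is only a stand-in):

    from examples.eval_model import eval_model
    from parlai.core.params import ParlaiParser

    # Assumed setup: a small task stands in for #moviedd-reddit used by the test.
    parser = ParlaiParser()
    parser.set_defaults(datatype='valid')
    opt = parser.parse_args(['--task', 'babi:task1k:1'], print_args=False)
    opt['model'] = 'repeat_label'        # evaluate the repeat_label baseline
    opt['num_examples'] = 5
    opt['display_examples'] = False

    # printargs=False skips the argument dump, as in the new unit test.
    eval_model(opt, parser, printargs=False)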
5 changes: 2 additions & 3 deletions parlai/agents/repeat_label/repeat_label.py
@@ -36,9 +36,8 @@ def act(self):
             return {'text': 'Nothing to repeat yet.'}
         reply = {}
         reply['id'] = self.getID()
-        if ('labels' in obs and obs['labels'] is not None
-                and len(obs['labels']) > 0):
-            labels = obs['labels']
+        labels = obs.get('labels', obs.get('eval_labels', None))
+        if labels:
             if random.random() >= self.cantAnswerPercent:
                 if self.returnOneRandomAnswer:
                     reply['text'] = labels[random.randrange(len(labels))]
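During evaluation the teacher provides eval_labels rather than labels, so the old check made repeat_label fall silent on valid/test data; the new lookup falls back to eval_labels. A tiny standalone illustration of that lookup on hand-rolled dicts (hypothetical values, not real ParlAI observations):

    def pick_labels(obs):
        # Prefer training labels; fall back to eval_labels (valid/test); else None.
        return obs.get('labels', obs.get('eval_labels', None))

    print(pick_labels({'labels': ('Paris',)}))        # ('Paris',)
    print(pick_labels({'eval_labels': ('Paris',)}))   # ('Paris',)
    print(pick_labels({'text': 'no labels here'}))    # None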
2 changes: 2 additions & 0 deletions tests/run_tests_short.sh
@@ -13,3 +13,5 @@ python3 test_dict.py
 python3 test_tasklist.py
 python3 test_threadutils.py
 python3 test_utils.py
+python3 test_display_data.py
+python3 test_eval_model.py
52 changes: 52 additions & 0 deletions tests/test_display_data.py
@@ -0,0 +1,52 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from examples.display_data import display_data
from parlai.core.params import ParlaiParser

import sys
import unittest


class TestDisplayData(unittest.TestCase):
    """Basic tests on the display_data.py example."""

    args = [
        '--task', 'babi:task1k:1',
    ]
    parser = ParlaiParser()
    opt = parser.parse_args(args, print_args=False)
    opt['num_examples'] = 1

    def test_output(self):
        """Does display_data reach the end of the loop?"""

        class display_output(object):
            def __init__(self):
                self.data = []

            def write(self, s):
                self.data.append(s)

            def __str__(self):
                return "".join(self.data)

        old_out = sys.stdout
        output = display_output()
        try:
            sys.stdout = output
            display_data(self.opt)
        finally:
            # restore sys.stdout
            sys.stdout = old_out

        str_output = str(output)
        self.assertTrue(len(str_output) > 0, "Output is empty")
        self.assertTrue("[babi:task1k:1]:" in str_output,
                        "Babi task did not print")
        self.assertTrue("~~" in str_output, "Example output did not complete")

if __name__ == '__main__':
    unittest.main()
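The test captures stdout by swapping in a small object that collects every write() call. On Python 3.4+ the standard library offers an equivalent; a sketch of that alternative, assuming the same display_data and opt as in the test above:

    import contextlib
    import io

    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        display_data(opt)          # same call as in the test
    str_output = buf.getvalue()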
65 changes: 65 additions & 0 deletions tests/test_eval_model.py
@@ -0,0 +1,65 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from examples.eval_model import eval_model
from parlai.core.params import ParlaiParser

import ast
import unittest
import sys


class TestEvalModel(unittest.TestCase):
    """Basic tests on the eval_model.py example."""

    args = [
        '--task', '#moviedd-reddit',
        '--datatype', 'valid',
    ]

    parser = ParlaiParser()
    parser.set_defaults(datatype='valid')
    opt = parser.parse_args(args, print_args=False)
    opt['model'] = 'repeat_label'
    opt['num_examples'] = 5
    opt['display_examples'] = False

    def test_output(self):
        """Test output of running eval_model"""
        class display_output(object):
            def __init__(self):
                self.data = []

            def write(self, s):
                self.data.append(s)

            def __str__(self):
                return "".join(self.data)

        old_out = sys.stdout
        output = display_output()
        try:
            sys.stdout = output
[Inline review comment from a member here: "sneaky"]

            eval_model(self.opt, self.parser, printargs=False)
        finally:
            # restore sys.stdout
            sys.stdout = old_out

        str_output = str(output)
        self.assertTrue(len(str_output) > 0, "Output is empty")

        # decode the output
        scores = str_output.split("\n---\n")
        for i in range(1, len(scores)):
            score = ast.literal_eval(scores[i])
            # check totals
            self.assertTrue(score['total'] == i,
                            "Total is incorrect")
            # accuracy should be one
            self.assertTrue(score['accuracy'] == 1,
                            "accuracy != 1")

if __name__ == '__main__':
    unittest.main()
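The assertions assume that each chunk following a "---" separator is the printed repr of a metrics dict, which is why ast.literal_eval can turn it back into a dictionary. A small illustration of that decoding step with a hypothetical report string:

    import ast

    chunk = "{'total': 3, 'accuracy': 1.0}"   # hypothetical report chunk
    score = ast.literal_eval(chunk)
    assert score['total'] == 3
    assert score['accuracy'] == 1.0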