diff --git a/parlai/core/params.py b/parlai/core/params.py
index 2ff5d7606a0..6229b4500a4 100644
--- a/parlai/core/params.py
+++ b/parlai/core/params.py
@@ -16,6 +16,7 @@
 from parlai.core.agents import get_agent_module, get_task_module
 from parlai.tasks.tasks import ids_to_tasks
 from parlai.core.build_data import modelzoo_path
+from parlai.core.pytorch_data_teacher import get_dataset_classes
 
 
 def get_model_name(opt):
@@ -395,6 +396,17 @@ def add_task_args(self, task):
             # already added
             pass
 
+    def add_pyt_dataset_args(self, opt):
+        """Add arguments specific to the specified pytorch dataset."""
+        dataset_classes = get_dataset_classes(opt)
+        for dataset, _, _ in dataset_classes:
+            try:
+                if hasattr(dataset, 'add_cmdline_args'):
+                    dataset.add_cmdline_args(self)
+            except argparse.ArgumentError:
+                # already added
+                pass
+
     def add_image_args(self, image_mode):
         """Add additional arguments for handling images."""
         try:
@@ -424,6 +436,16 @@ def add_extra_args(self, args=None):
         if evaltask is not None:
             self.add_task_args(evaltask)
 
+        # find pytorch teacher task if specified, add its specific arguments
+        pytorch_teacher_task = parsed.get('pytorch_teacher_task', None)
+        if pytorch_teacher_task is not None:
+            self.add_task_args(pytorch_teacher_task)
+
+        # find pytorch dataset if specified, add its specific arguments
+        pytorch_teacher_dataset = parsed.get('pytorch_teacher_dataset', None)
+        if pytorch_teacher_dataset is not None:
+            self.add_pyt_dataset_args(parsed)
+
         # find which model specified if any, and add its specific arguments
         model = get_model_name(parsed)
         if model is not None:
diff --git a/parlai/scripts/build_pytorch_data.py b/parlai/scripts/build_pytorch_data.py
index d56ace9d9fe..4ba72d9e22d 100644
--- a/parlai/scripts/build_pytorch_data.py
+++ b/parlai/scripts/build_pytorch_data.py
@@ -13,7 +13,6 @@
 are used in a flattened episode.
 """
 from parlai.core.agents import create_agent
-from parlai.core.params import ParlaiParser
 from parlai.core.worlds import create_task
 from parlai.core.utils import ProgressLogger
 import copy
@@ -26,6 +25,7 @@
 
 
 def setup_args():
+    from parlai.core.params import ParlaiParser
     return ParlaiParser(True, True)
 
 
diff --git a/parlai/tasks/coco_caption/agents.py b/parlai/tasks/coco_caption/agents.py
index 08341ba1e0b..e53cdcdc794 100644
--- a/parlai/tasks/coco_caption/agents.py
+++ b/parlai/tasks/coco_caption/agents.py
@@ -265,6 +265,8 @@ def __init__(self, opt, shared=None, version='2017'):
         self.num_cands = opt.get('num_cands', -1)
         self.include_rest_val = opt.get('include_rest_val', False)
         test_info_path, annotation_path, self.image_path = _path(opt, version)
+        self.test_split = opt.get('test_split', -1)
+
         if shared:
             # another instance was set up already, just reference its data
             if 'annotation' in shared:
@@ -276,7 +278,6 @@ def __init__(self, opt, shared=None, version='2017'):
             # need to set up data from scratch
             self._setup_data(test_info_path, annotation_path, opt)
             self.image_loader = ImageLoader(opt)
-
         self.reset()
 
     @staticmethod
@@ -294,6 +295,10 @@ def add_cmdline_args(argparser):
         agent.add_argument('--include_rest_val', type='bool',
                            default=False,
                            help='Include unused validation images in training')
+        agent.add_argument('--test-split', type=int, default=-1,
+                           choices=[-1, 0, 1, 2, 3, 4],
+                           help='Which 1k image split of dataset to use for candidates; '
+                           'if -1, use all 5k test images')
 
     def reset(self):
         super().reset()  # call parent reset so other fields can be set up
@@ -420,6 +425,8 @@ def _setup_data(self, test_info_path, annotation_path, opt):
                 self.cands = [l for d in self.annotation for l in [s['raw'] for s in d['sentences']]]
             else:
                 self.annotation = [d for d in raw_data if d['split'] == 'test']
+                if self.test_split != -1:
+                    self.annotation = self.annotation[self.test_split*1000:(self.test_split+1)*1000]
                 self.cands = [l for d in self.annotation for l in [s['raw'] for s in d['sentences']]]
         else:
             if not self.datatype.startswith('test'):
diff --git a/parlai/tasks/flickr30k/agents.py b/parlai/tasks/flickr30k/agents.py
index fad375e7568..410afbdf339 100644
--- a/parlai/tasks/flickr30k/agents.py
+++ b/parlai/tasks/flickr30k/agents.py
@@ -45,6 +45,10 @@ def __init__(self, opt, shared=None):
             self._setup_data(data_path, opt.get('unittest', False))
         self.dict_agent = DictionaryAgent(opt)
 
+    @staticmethod
+    def add_cmdline_args(argparser):
+        DefaultTeacher.add_cmdline_args(argparser)
+
     def __getitem__(self, index):
         cap = self.data[index]
         image_id = int(cap['filename'].replace('.jpg', ''))