Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

support cmd line args for pytorch datasets #1116

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions parlai/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from parlai.core.agents import get_agent_module, get_task_module
from parlai.tasks.tasks import ids_to_tasks
from parlai.core.build_data import modelzoo_path
from parlai.core.pytorch_data_teacher import get_dataset_classes


def get_model_name(opt):
Expand Down Expand Up @@ -395,6 +396,17 @@ def add_task_args(self, task):
# already added
pass

def add_pyt_dataset_args(self, opt):
"""Add arguments specific to specified pytorch dataset"""
dataset_classes = get_dataset_classes(opt)
for dataset, _, _ in dataset_classes:
try:
if hasattr(dataset, 'add_cmdline_args'):
dataset.add_cmdline_args(self)
except argparse.ArgumentError:
# already added
pass

def add_image_args(self, image_mode):
"""Add additional arguments for handling images."""
try:
Expand Down Expand Up @@ -424,6 +436,16 @@ def add_extra_args(self, args=None):
if evaltask is not None:
self.add_task_args(evaltask)

# find pytorch teacher task if specified, add its specific arguments
pytorch_teacher_task = parsed.get('pytorch_teacher_task', None)
if pytorch_teacher_task is not None:
self.add_task_args(pytorch_teacher_task)

# find pytorch dataset if specified, add its specific arguments
pytorch_teacher_dataset = parsed.get('pytorch_teacher_dataset', None)
if pytorch_teacher_dataset is not None:
self.add_pyt_dataset_args(parsed)

# find which model specified if any, and add its specific arguments
model = get_model_name(parsed)
if model is not None:
Expand Down
2 changes: 1 addition & 1 deletion parlai/scripts/build_pytorch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
are used in a flattened episode.
"""
from parlai.core.agents import create_agent
from parlai.core.params import ParlaiParser
from parlai.core.worlds import create_task
from parlai.core.utils import ProgressLogger
import copy
Expand All @@ -26,6 +25,7 @@


def setup_args():
from parlai.core.params import ParlaiParser
return ParlaiParser(True, True)


Expand Down
9 changes: 8 additions & 1 deletion parlai/tasks/coco_caption/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def __init__(self, opt, shared=None, version='2017'):
self.num_cands = opt.get('num_cands', -1)
self.include_rest_val = opt.get('include_rest_val', False)
test_info_path, annotation_path, self.image_path = _path(opt, version)
self.test_split = opt['test_split']

if shared:
# another instance was set up already, just reference its data
if 'annotation' in shared:
Expand All @@ -276,7 +278,6 @@ def __init__(self, opt, shared=None, version='2017'):
# need to set up data from scratch
self._setup_data(test_info_path, annotation_path, opt)
self.image_loader = ImageLoader(opt)

self.reset()

@staticmethod
Expand All @@ -294,6 +295,10 @@ def add_cmdline_args(argparser):
agent.add_argument('--include_rest_val', type='bool',
default=False,
help='Include unused validation images in training')
agent.add_argument('--test-split', type=int, default=-1,
choices=[-1, 0, 1, 2, 3, 4],
help='Which 1k image split of dataset to use for candidates'
'if -1, use all 5k test images')

def reset(self):
super().reset() # call parent reset so other fields can be set up
Expand Down Expand Up @@ -420,6 +425,8 @@ def _setup_data(self, test_info_path, annotation_path, opt):
self.cands = [l for d in self.annotation for l in [s['raw'] for s in d['sentences']]]
else:
self.annotation = [d for d in raw_data if d['split'] == 'test']
if self.test_split != -1:
self.annotation = self.annotation[self.test_split*1000:(self.test_split+1)*1000]
self.cands = [l for d in self.annotation for l in [s['raw'] for s in d['sentences']]]
else:
if not self.datatype.startswith('test'):
Expand Down
4 changes: 4 additions & 0 deletions parlai/tasks/flickr30k/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def __init__(self, opt, shared=None):
self._setup_data(data_path, opt.get('unittest', False))
self.dict_agent = DictionaryAgent(opt)

@staticmethod
def add_cmdline_args(argparser):
DefaultTeacher.add_cmdline_args(argparser)

def __getitem__(self, index):
cap = self.data[index]
image_id = int(cap['filename'].replace('.jpg', ''))
Expand Down