Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Added CLEVR task #233

Merged
merged 3 commits into from
Jul 24, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions parlai/tasks/clevr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
66 changes: 66 additions & 0 deletions parlai/tasks/clevr/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from parlai.core.dialog_teacher import DialogTeacher
from .build import build

import json
import os


def _path(opt):
    """Return the CLEVR questions file and image directory for ``opt``.

    ``build(opt)`` is called first. The split is taken from the part of
    ``opt['datatype']`` before the first ``:``; ``valid`` is mapped to the
    ``val`` name used in the CLEVR paths, while ``train`` and ``test`` are
    used unchanged.

    Returns a ``(questions_path, images_path)`` tuple located under
    ``<opt['datapath']>/CLEVR/CLEVR_v1.0``.

    Raises ``RuntimeError`` for any other datatype.
    """
    build(opt)

    split = opt['datatype'].split(':')[0]
    if split == 'valid':
        split = 'val'
    elif split not in ('train', 'test'):
        raise RuntimeError('Not valid datatype.')

    root = os.path.join(opt['datapath'], 'CLEVR', 'CLEVR_v1.0')
    questions_file = 'CLEVR_{}_questions.json'.format(split)
    return (
        os.path.join(root, 'questions', questions_file),
        os.path.join(root, 'images', split),
    )


class DefaultTeacher(DialogTeacher):
"""
This CLEVR teacher inherits from the core Dialog Teacher, which just
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

visdial?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, removed that

requires it to define an iterator over its data (`setup_data`) in order to
inherit basic metrics and an `act` function, and which enables
Hogwild training with shared memory with no extra work.
"""
def __init__(self, opt, shared=None):
    """Resolve the CLEVR data locations, then defer to the parent class.

    The questions file returned by ``_path`` is stored in
    ``opt['datafile']`` and the image directory is kept on the instance as
    ``images_path``; both are set before the parent initializer runs.
    """
    self.datatype = opt['datatype']

    questions_file, images_dir = _path(opt)
    self.images_path = images_dir
    opt['datafile'] = questions_file

    self.id = 'clevr'

    super().__init__(opt, shared)

def setup_data(self, path):
    """Yield CLEVR questions from the JSON file at ``path``.

    Each yielded item is
    ``((question, answer, None, None, img_path), new_episode)``.

    A new episode starts at the first question and whenever a question's
    ``image_filename`` differs from the previous one. ``img_path`` is set
    only on that first entry of an episode and is ``None`` otherwise.
    ``answer`` is a one-element list, or ``None`` for questions whose
    ``split`` is ``test``.
    """
    print('loading: ' + path)
    # JSON is UTF-8 by specification; don't depend on the locale default.
    with open(path, encoding='utf-8') as data_file:
        clevr = json.load(data_file)

    image_file = None
    for ques in clevr['questions']:
        # a new episode starts at the first question or when the image changes
        new_episode = ques['image_filename'] != image_file
        image_file = ques['image_filename']

        # only show the image at the beginning of an episode
        img_path = None
        if new_episode:
            img_path = os.path.join(self.images_path, image_file)

        question = ques['question']
        answer = [ques['answer']] if ques['split'] != 'test' else None
        # TODO cands?
        yield (question, answer, None, None, img_path), new_episode
33 changes: 33 additions & 0 deletions parlai/tasks/clevr/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
# Download and build the data if it does not exist.

import parlai.core.build_data as build_data
import os

from parlai.tasks.vqa_v1.build import buildImage


def build(opt):
    """Download and unpack CLEVR v1.0 under ``opt['datapath']`` if needed.

    Returns immediately when ``build_data.built`` reports the ``CLEVR``
    directory as built for version ``v1.0``. Otherwise any previously built
    copy is removed, the directory is created, the archive is downloaded and
    unpacked there, and the directory is marked as built for this version.
    """
    version = 'v1.0'
    dpath = os.path.join(opt['datapath'], 'CLEVR')

    if build_data.built(dpath, version_string=version):
        return

    print('[building data: ' + dpath + ']')
    # A build without the current version marker is outdated; start clean.
    if build_data.built(dpath):
        build_data.remove_dir(dpath)
    build_data.make_dir(dpath)

    # Fetch the archive and unpack it in place.
    fname = 'CLEVR_v1.0.zip'
    base_url = 'https://s3-us-west-1.amazonaws.com/clevr/'
    build_data.download(base_url + fname, dpath, fname)
    build_data.untar(dpath, fname)

    # Record the version so later calls skip the download.
    build_data.mark_done(dpath, version_string=version)