Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Added CLEVR task #233

Merged
merged 3 commits into from
Jul 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions parlai/tasks/clevr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
59 changes: 59 additions & 0 deletions parlai/tasks/clevr/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from parlai.core.dialog_teacher import DialogTeacher
from .build import build

import json
import os


def _path(opt):
    """Return the CLEVR questions file and images directory for ``opt``.

    Calls ``build(opt)`` first, then derives the split name from the part of
    ``opt['datatype']`` that precedes the first ``':'``.

    Returns:
        A ``(questions_path, images_path)`` tuple rooted at
        ``<opt['datapath']>/CLEVR/CLEVR_v1.0``.

    Raises:
        RuntimeError: if the split is not ``train``, ``valid`` or ``test``.
    """
    build(opt)

    split = opt['datatype'].split(':')[0]
    if split == 'valid':
        # File and directory names on disk use 'val' rather than 'valid'.
        split = 'val'
    elif split not in ('train', 'test'):
        raise RuntimeError('Not valid datatype.')

    root = os.path.join(opt['datapath'], 'CLEVR', 'CLEVR_v1.0')
    questions_file = 'CLEVR_' + split + '_questions.json'
    return (
        os.path.join(root, 'questions', questions_file),
        os.path.join(root, 'images', split),
    )


class DefaultTeacher(DialogTeacher):
    """Teacher for the CLEVR task.

    Consecutive questions that share the same ``image_filename`` are emitted
    as one episode; the image path is attached only to the first question of
    each episode.
    """

    def __init__(self, opt, shared=None):
        """Resolve the data locations via ``_path`` and initialise the base.

        ``opt['datafile']`` is set to the questions file before the parent
        constructor runs.
        """
        self.datatype = opt['datatype']
        questions_file, images_dir = _path(opt)
        self.images_path = images_dir
        opt['datafile'] = questions_file
        self.id = 'clevr'

        super().__init__(opt, shared)

    def setup_data(self, path):
        """Yield ``((question, answer, None, None, image), new_episode)`` pairs.

        ``path`` is a JSON file whose top-level ``'questions'`` entry is
        iterated in order. ``answer`` is a one-element list, or ``None`` when
        the entry's ``'split'`` is ``'test'``. ``image`` is a path under
        ``self.images_path`` on the first question of an episode and ``None``
        otherwise.
        """
        print('loading: ' + path)
        with open(path) as handle:
            payload = json.load(handle)

        previous_image = None
        for entry in payload['questions']:
            current_image = entry['image_filename']
            # A new episode starts on the first entry or when the image changes.
            starts_episode = current_image != previous_image
            previous_image = current_image

            # The image is only shown at the beginning of an episode.
            shown_image = (
                os.path.join(self.images_path, current_image)
                if starts_episode
                else None
            )

            text = entry['question']
            if entry['split'] != 'test':
                labels = [entry['answer']]
            else:
                labels = None
            # TODO cands?
            yield (text, labels, None, None, shown_image), starts_episode
33 changes: 33 additions & 0 deletions parlai/tasks/clevr/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
# Download and build the data if it does not exist.

import parlai.core.build_data as build_data
import os

from parlai.tasks.vqa_v1.build import buildImage


def build(opt):
    """Fetch the CLEVR v1.0 archive into ``<opt['datapath']>/CLEVR`` if needed.

    Does nothing when ``build_data.built`` already reports the directory as
    built for version ``'v1.0'``. Otherwise the directory is (re)created, the
    archive is passed to ``build_data.download`` and ``build_data.untar``, and
    the directory is marked done with the version string.
    """
    version = 'v1.0'
    dpath = os.path.join(opt['datapath'], 'CLEVR')

    if build_data.built(dpath, version_string=version):
        return

    print('[building data: ' + dpath + ']')
    if build_data.built(dpath):
        # Built under a different version string: discard the outdated files.
        build_data.remove_dir(dpath)
    build_data.make_dir(dpath)

    # Download the data.
    fname = 'CLEVR_v1.0.zip'
    url = 'https://s3-us-west-1.amazonaws.com/clevr/'
    build_data.download(url + fname, dpath, fname)
    build_data.untar(dpath, fname)

    # Mark the data as built.
    build_data.mark_done(dpath, version_string=version)