diff --git a/examples/extract_image_feature.py b/examples/extract_image_feature.py
new file mode 100644
index 00000000000..2198956ebfd
--- /dev/null
+++ b/examples/extract_image_feature.py
@@ -0,0 +1,51 @@
+# All rights reserved.
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+"""Basic example which iterates through the specified tasks and loads/extracts
+the image features.
+
+For example, to extract the image features of COCO images:
+`python examples/extract_image_feature.py -t vqa_v1 -im resnet152`.
+
+The CNN model and output layer are selected via the `--image-mode` (`-im`) flag,
+e.g. `resnet152` or `resnet152_spatial`; see `parlai.core.image_featurizers`.
+
+For more options, check `parlai.core.image_featurizers`.
+"""
+
+from parlai.core.params import ParlaiParser
+from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
+from parlai.core.worlds import create_task
+from parlai.core.image_featurizers import ImageLoader
+
+import random
+
+def main():
+    random.seed(42)
+
+    # Get command line arguments
+    parser = ParlaiParser()
+    parser.add_argument('-n', '--num-examples', default=10)
+    parser.set_defaults(datatype='train:ordered')
+
+    ImageLoader.add_cmdline_args(parser)
+    opt = parser.parse_args()
+
+    opt['no_cuda'] = False
+    opt['gpu'] = 0
+    # create repeat label agent and assign it to the specified task
+    agent = RepeatLabelAgent(opt)
+    world = create_task(opt, agent)
+
+    # Show some example dialogs.
+    with world:
+        for k in range(int(opt['num_examples'])):
+            world.parley()
+            print(world.display() + '\n~~')
+            if world.epoch_done():
+                print('EPOCH DONE')
+                break
+
+if __name__ == '__main__':
+    main()
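To make the script's output concrete: `ImageLoader` (added below) writes each extracted feature as a `.npy` file in a subdirectory named after the image mode, next to the source image. A hedged sketch of inspecting one such file; the path is hypothetical and depends on where the task stores its images.

```python
# Hypothetical cache path: '<image dir>/<image-mode>/<image filename>.npy'.
import numpy as np

feat = np.load('data/COCO-IMG/train2014/resnet152/COCO_train2014_000000000009.jpg.npy')
print(feat.shape)  # (1, 2048, 1, 1): resnet152 with the final fc layer cut off
```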
diff --git a/parlai/core/dialog_teacher.py b/parlai/core/dialog_teacher.py
index a7ef472b66b..eae0751af99 100644
--- a/parlai/core/dialog_teacher.py
+++ b/parlai/core/dialog_teacher.py
@@ -6,6 +6,7 @@
 
 from .agents import Teacher
+from .image_featurizers import ImageLoader
 from PIL import Image
 import random
 import os
 
@@ -185,6 +186,7 @@ def __init__(self, opt, data_loader, cands=None):
         self._load(data_loader)
         self.cands = None if cands == None else set(sys.intern(c) for c in cands)
         self.addedCands = []
+        self.image_loader = ImageLoader(opt)
 
     def __len__(self):
         """Returns total number of entries available. Each episode has at least
@@ -272,7 +274,7 @@ def get(self, episode_idx, entry_idx=0):
         if entry[3] is not None:
             table['label_candidates'] = entry[3]
         if len(entry) > 4 and entry[4] is not None:
-            img = load_image(self.opt, entry[4])
+            img = self.image_loader.load(entry[4])
             if img is not None:
                 table['image'] = img
 
@@ -297,42 +299,3 @@ def get(self, episode_idx, entry_idx=0):
 
         table['episode_done'] = episode_done
         return table, end_of_data
-
-_greyscale = ' .,:;crsA23hHG#98&@'
-
-
-def img_to_ascii(path):
-    im = Image.open(path)
-    im.thumbnail((60, 40), Image.BICUBIC)
-    im = im.convert('L')
-    asc = []
-    for y in range(0, im.size[1]):
-        for x in range(0, im.size[0]):
-            lum = 255 - im.getpixel((x, y))
-            asc.append(_greyscale[lum * len(_greyscale) // 256])
-        asc.append('\n')
-    return ''.join(asc)
-
-
-def load_image(opt, path):
-    mode = opt.get('image_mode', 'raw')
-    if mode is None or mode == 'none':
-        # don't need to load images
-        return None
-    elif mode == 'raw':
-        # raw just returns RGB values
-        return Image.open(path).convert('RGB')
-    elif mode == 'ascii':
-        # convert images to ascii ¯\_(ツ)_/¯
-        return img_to_ascii(path)
-    else:
-        # otherwise, looks for preprocessed version under 'mode' directory
-        prepath, imagefn = os.path.split(path)
-        new_path = os.path.join(prepath, mode, imagefn)
-        if not os.path.isfile(new_path):
-            # currently only supports *downloaded* preprocessing
-            # TODO: generate preprocessed images if not available
-            raise NotImplementedError('image preprocessing mode' +
-                                      '{} not supported yet'.format(mode))
-        else:
-            return Image.open(path)
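This change swaps the module-level `load_image`/`img_to_ascii` helpers for a per-teacher `ImageLoader` instance. A minimal sketch of using the class directly, assuming a 'raw' mode opt and a placeholder image path:

```python
from parlai.core.image_featurizers import ImageLoader

# 'raw' mode never touches the CNN, so only 'image_mode' is required here.
opt = {'image_mode': 'raw'}
loader = ImageLoader(opt)
img = loader.load('/tmp/images/example.jpg')  # PIL.Image in RGB
print(img.size)
```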
diff --git a/parlai/core/image_featurizers.py b/parlai/core/image_featurizers.py
new file mode 100644
index 00000000000..ac8072e377a
--- /dev/null
+++ b/parlai/core/image_featurizers.py
@@ -0,0 +1,158 @@
+# All rights reserved.
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+
+import parlai.core.build_data as build_data
+
+import torch
+from torch.autograd import Variable
+import torchvision
+import torchvision.transforms as transforms
+import torch.nn as nn
+
+import os
+import copy
+import numpy as np
+from PIL import Image
+
+_greyscale = ' .,:;crsA23hHG#98&@'
+
+class ImageLoader():
+    """Extracts image features using a pretrained CNN.
+    """
+    @staticmethod
+    def add_cmdline_args(argparser):
+        argparser.add_arg('--image-size', type=int, default=256,
+                          help='resize the shorter image side to this size')
+        argparser.add_arg('--image-cropsize', type=int, default=224,
+                          help='center-crop size fed to the CNN')
+
+    def __init__(self, opt):
+        self.opt = copy.deepcopy(opt)
+        self.netCNN = None
+
+    def init_cnn(self):
+        """Lazily initialize the CNN, so nothing is loaded when no image preprocessing is needed."""
+        opt = self.opt
+        self.image_size = opt['image_size']
+        self.crop_size = opt['image_cropsize']
+        self.datatype = opt['datatype']
+        self.image_mode = opt['image_mode']
+
+        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
+        self.use_cuda = opt['cuda']
+
+        if self.use_cuda:
+            print('[ Using CUDA ]')
+            torch.cuda.set_device(opt['gpu'])
+
+        cnn_type, layer_num = self.image_mode_switcher()
+
+        # initialize the pretrained CNN using pytorch.
+        CNN = getattr(torchvision.models, cnn_type)
+
+        # cut off the final layer(s): -1 keeps pooled features, -2 keeps the spatial map.
+        self.netCNN = nn.Sequential(*list(CNN(pretrained=True).children())[:layer_num])
+
+        # initialize the transform pipeline using torchvision.
+        self.transform = transforms.Compose([
+            transforms.Scale(self.image_size),
+            transforms.CenterCrop(self.crop_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                 std=[0.229, 0.224, 0.225])
+        ])
+
+        # reusable container for a single image
+        self.xs = torch.FloatTensor(1, 3, self.crop_size, self.crop_size).fill_(0)
+
+        if self.use_cuda:
+            self.cuda()
+            self.xs = self.xs.cuda()
+
+        # wrap self.xs in a Variable for the forward pass.
+        self.xs = Variable(self.xs)
+
+    def cuda(self):
+        self.netCNN.cuda()
+
+    def save(self, feature, path):
+        # return the numpy array so fresh and cached loads yield the same type.
+        feature = feature.cpu().data.numpy()
+        np.save(path, feature)
+        return feature
+
+    def image_mode_switcher(self):
+        switcher = {
+            'resnet152': ['resnet152', -1],
+            'resnet101': ['resnet101', -1],
+            'resnet50': ['resnet50', -1],
+            'resnet34': ['resnet34', -1],
+            'resnet18': ['resnet18', -1],
+            'resnet152_spatial': ['resnet152', -2],
+            'resnet101_spatial': ['resnet101', -2],
+            'resnet50_spatial': ['resnet50', -2],
+            'resnet34_spatial': ['resnet34', -2],
+            'resnet18_spatial': ['resnet18', -2],
+        }
+
+        if self.image_mode not in switcher:
+            raise NotImplementedError('image preprocessing mode {} '
+                                      'not supported yet'.format(self.image_mode))
+
+        return switcher.get(self.image_mode)
+
+    def extract(self, image, path):
+        # lazily initialize the CNN on the first extraction.
+        if self.netCNN is None:
+            self.init_cnn()
+
+        self.xs.data.copy_(self.transform(image))
+        # extract the image feature
+        feature = self.netCNN(self.xs)
+        # save the feature and return it
+        return self.save(feature, path)
+
+    def img_to_ascii(self, path):
+        im = Image.open(path)
+        im.thumbnail((60, 40), Image.BICUBIC)
+        im = im.convert('L')
+        asc = []
+        for y in range(0, im.size[1]):
+            for x in range(0, im.size[0]):
+                lum = 255 - im.getpixel((x, y))
+                asc.append(_greyscale[lum * len(_greyscale) // 256])
+            asc.append('\n')
+        return ''.join(asc)
+
+    def load(self, path):
+        opt = self.opt
+        mode = opt.get('image_mode', 'raw')
+        if mode is None or mode == 'none':
+            # don't need to load images
+            return None
+        elif mode == 'raw':
+            # raw just returns RGB values
+            return Image.open(path).convert('RGB')
+        elif mode == 'ascii':
+            # convert images to ascii ¯\_(ツ)_/¯
+            return self.img_to_ascii(path)
+        else:
+            # otherwise, look for a preprocessed version under the 'mode' directory
+            prepath, imagefn = os.path.split(path)
+
+            dpath = os.path.join(prepath, mode)
+
+            if not os.path.exists(dpath):
+                build_data.make_dir(dpath)
+
+            imagefn = imagefn + '.npy'
+            new_path = os.path.join(dpath, imagefn)
+
+            if not os.path.isfile(new_path):
+                return self.extract(Image.open(path).convert('RGB'), new_path)
+            else:
+                return np.load(new_path)
+
+
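For intuition about `image_mode_switcher`: `layer_num` slices the pretrained network's children, so `-1` drops only the classifier (leaving globally pooled features) while `-2` also drops the average pool (leaving a spatial map). A sketch with resnet18, whose weights download on first use; the input size matches the 224 center crop configured above:

```python
import torch
import torch.nn as nn
import torchvision
from torch.autograd import Variable

cnn = torchvision.models.resnet18(pretrained=True)
pooled = nn.Sequential(*list(cnn.children())[:-1])   # 'resnet18' mode
spatial = nn.Sequential(*list(cnn.children())[:-2])  # 'resnet18_spatial' mode

x = Variable(torch.randn(1, 3, 224, 224))
print(pooled(x).size())   # (1, 512, 1, 1): pooled features
print(spatial(x).size())  # (1, 512, 7, 7): spatial feature map
```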
diff --git a/parlai/tasks/vqa_v1/agents.py b/parlai/tasks/vqa_v1/agents.py
index 7de400cf223..2db211ce646 100644
--- a/parlai/tasks/vqa_v1/agents.py
+++ b/parlai/tasks/vqa_v1/agents.py
@@ -5,7 +5,7 @@
 # of patent rights can be found in the PATENTS file in the same directory.
 
 from parlai.core.agents import Teacher
-from parlai.core.dialog_teacher import load_image
+from parlai.core.image_featurizers import ImageLoader
 from .build import build, buildImage
 
 import json
@@ -29,7 +29,7 @@ def _path(opt):
     elif dt == 'test':
         ques_suffix = 'MultipleChoice_mscoco_test2015'
         annotation_suffix = 'None'
-        img_suffix = os.path.join('test2014', 'COCO_test2014_')
+        img_suffix = os.path.join('test2015', 'COCO_test2015_')
     else:
         raise RuntimeError('Not valid datatype.')
 
@@ -66,7 +66,7 @@ def __init__(self, opt, shared=None):
         # size so they all process disparate sets of the data
         self.step_size = opt.get('batchsize', 1)
         self.data_offset = opt.get('batchindex', 0)
-
+        self.image_loader = ImageLoader(opt)
         self.reset()
 
     def __len__(self):
@@ -101,7 +101,7 @@ def act(self):
         img_path = self.image_path + '%012d.jpg' % (image_id)
 
         action = {
-            'image': load_image(self.opt, img_path),
+            'image': self.image_loader.load(img_path),
             'text': question,
             'episode_done': True
         }
diff --git a/parlai/tasks/vqa_v1/build.py b/parlai/tasks/vqa_v1/build.py
index ab274a06b92..dfe8a21a051 100644
--- a/parlai/tasks/vqa_v1/build.py
+++ b/parlai/tasks/vqa_v1/build.py
@@ -20,13 +20,14 @@ def buildImage(opt):
         # download the image data.
         fname1 = 'train2014.zip'
         fname2 = 'val2014.zip'
-        fname3 = 'test2014.zip'
+        fname3 = 'test2015.zip'
 
-        url = 'http://msvocds.blob.core.windows.net/coco2014/'
+        url1 = 'http://msvocds.blob.core.windows.net/coco2014/'
+        url2 = 'http://msvocds.blob.core.windows.net/coco2015/'
 
-        build_data.download(url + fname1, dpath, fname1)
-        build_data.download(url + fname2, dpath, fname2)
-        build_data.download(url + fname3, dpath, fname3)
+        build_data.download(url1 + fname1, dpath, fname1)
+        build_data.download(url1 + fname2, dpath, fname2)
+        build_data.download(url2 + fname3, dpath, fname3)
 
         build_data.untar(dpath, fname1)
         build_data.untar(dpath, fname2)
diff --git a/parlai/tasks/vqa_v2/agents.py b/parlai/tasks/vqa_v2/agents.py
index 51e19c3689f..9beebcdf67e 100644
--- a/parlai/tasks/vqa_v2/agents.py
+++ b/parlai/tasks/vqa_v2/agents.py
@@ -5,7 +5,7 @@
 # of patent rights can be found in the PATENTS file in the same directory.
 
 from parlai.core.agents import Teacher
-from parlai.core.dialog_teacher import load_image
+from parlai.core.image_featurizers import ImageLoader
 from .build import build, buildImage
 
 import json
@@ -29,7 +29,7 @@ def _path(opt):
     elif dt == 'test':
         ques_suffix = 'v2_OpenEnded_mscoco_test2015'
         annotation_suffix = 'None'
-        img_suffix = os.path.join('test2014', 'COCO_test2014_')
+        img_suffix = os.path.join('test2015', 'COCO_test2015_')
     else:
         raise RuntimeError('Not valid datatype.')
 
@@ -67,6 +67,7 @@ def __init__(self, opt, shared=None):
         # size so they all process disparate sets of the data
         self.step_size = opt.get('batchsize', 1)
         self.data_offset = opt.get('batchindex', 0)
+        self.image_loader = ImageLoader(opt)
 
         self.reset()
 
@@ -95,12 +96,12 @@ def act(self):
         qa = self.ques['questions'][self.episode_idx]
         question = qa['question']
         image_id = qa['image_id']
 
         img_path = self.image_path + '%012d.jpg' % (image_id)
 
         action = {
-            'image': load_image(self.opt, img_path),
+            'image': self.image_loader.load(img_path),
             'text': question,
             'episode_done': True
         }
 
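The `test2014` to `test2015` renames above matter because the teachers assemble image paths from `img_suffix` plus a zero-padded id; COCO filenames pad the id to 12 digits. A quick check of that convention:

```python
# COCO filename convention used by img_suffix + '%012d.jpg' in the teachers.
image_id = 9
print('COCO_test2015_' + '%012d.jpg' % image_id)
# COCO_test2015_000000000009.jpg
```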
diff --git a/parlai/tasks/vqa_v2/build.py b/parlai/tasks/vqa_v2/build.py
index 76a0784ea0d..86e75fc7f21 100644
--- a/parlai/tasks/vqa_v2/build.py
+++ b/parlai/tasks/vqa_v2/build.py
@@ -20,13 +20,14 @@ def buildImage(opt):
         # download the image data.
         fname1 = 'train2014.zip'
         fname2 = 'val2014.zip'
-        fname3 = 'test2014.zip'
+        fname3 = 'test2015.zip'
 
-        url = 'http://msvocds.blob.core.windows.net/coco2014/'
+        url1 = 'http://msvocds.blob.core.windows.net/coco2014/'
+        url2 = 'http://msvocds.blob.core.windows.net/coco2015/'
 
-        build_data.download(url + fname1, dpath, fname1)
-        build_data.download(url + fname2, dpath, fname2)
-        build_data.download(url + fname3, dpath, fname3)
+        build_data.download(url1 + fname1, dpath, fname1)
+        build_data.download(url1 + fname2, dpath, fname2)
+        build_data.download(url2 + fname3, dpath, fname3)
 
         build_data.untar(dpath, fname1)
         build_data.untar(dpath, fname2)
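End to end, `ImageLoader.load` extracts and caches on the first call and reads the `.npy` file back on subsequent calls. A hedged sketch of that round trip; the opt values and image path are assumptions, and the first call downloads resnet18 weights:

```python
import os
from parlai.core.image_featurizers import ImageLoader

# assumed opt: every key init_cnn reads, with CUDA disabled for portability.
opt = {'image_mode': 'resnet18', 'image_size': 256, 'image_cropsize': 224,
       'datatype': 'train:ordered', 'no_cuda': True, 'gpu': 0}
loader = ImageLoader(opt)

feat = loader.load('/tmp/images/example.jpg')        # runs the CNN, writes the cache
assert os.path.isfile('/tmp/images/resnet18/example.jpg.npy')
feat_again = loader.load('/tmp/images/example.jpg')  # served from the .npy cache
```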