Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add image feature extraction modules and fix minor bugs. #169

Merged
merged 6 commits into from
Jun 29, 2017
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions examples/extract_image_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Basic example which iterates through the tasks specified and load/extract the
image features.

For example, to extract the image feature of COCO images:
`python examples/extract_image_feature.py -t vqa_v1 -im resnet152`.

The CNN model and layer is specified at `--image-cnntype` and `--image-layernum`
in `parlai.core.image_featurizers`.

For more options, check `parlai.core.image_featurizers`
"""

from parlai.core.params import ParlaiParser
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.worlds import create_task
from parlai.core.image_featurizers import ImageLoader

import random

def main():
random.seed(42)

# Get command line arguments
parser = ParlaiParser()
parser.add_argument('-n', '--num-examples', default=10)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(should help with hitting each image only once)
parser.set_defaults(datatype='train:ordered')

ImageLoader.add_cmdline_args(parser)
opt = parser.parse_args()

opt['no_cuda'] = False
opt['gpu'] = 0
# create repeat label agent and assign it to the specified task
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)

# Show some example dialogs.
with world:
for k in range(int(opt['num_examples'])):
world.parley()
print(world.display() + '\n~~')
if world.epoch_done():
print('EPOCH DONE')
break

if __name__ == '__main__':
main()
44 changes: 4 additions & 40 deletions parlai/core/dialog_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .agents import Teacher

from .image_featurizers import ImageLoader
from PIL import Image
import random
import os
Expand Down Expand Up @@ -185,6 +186,7 @@ def __init__(self, opt, data_loader, cands=None):
self._load(data_loader)
self.cands = None if cands == None else set(sys.intern(c) for c in cands)
self.addedCands = []
self.image_loader = None

def __len__(self):
"""Returns total number of entries available. Each episode has at least
Expand Down Expand Up @@ -240,6 +242,7 @@ def _load(self, data_loader):
new_entry.append(None)
if len(entry) > 4 and entry[4] is not None:
new_entry.append(sys.intern(entry[4]))
self.image_loader = ImageLoader(opt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will create a new image loader on every example with an image--not good. since the image initializer is lazy anyways, let's move this to __init__

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, right.


episode.append(tuple(new_entry))

Expand Down Expand Up @@ -272,7 +275,7 @@ def get(self, episode_idx, entry_idx=0):
if entry[3] is not None:
table['label_candidates'] = entry[3]
if len(entry) > 4 and entry[4] is not None:
img = load_image(self.opt, entry[4])
img = self.image_loader.load(entry[4])
if img is not None:
table['image'] = img

Expand All @@ -297,42 +300,3 @@ def get(self, episode_idx, entry_idx=0):
table['episode_done'] = episode_done
return table, end_of_data


_greyscale = ' .,:;crsA23hHG#98&@'


def img_to_ascii(path):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this stuff to utils or something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those stuff are move to image_featurizers.py

im = Image.open(path)
im.thumbnail((60, 40), Image.BICUBIC)
im = im.convert('L')
asc = []
for y in range(0, im.size[1]):
for x in range(0, im.size[0]):
lum = 255 - im.getpixel((x, y))
asc.append(_greyscale[lum * len(_greyscale) // 256])
asc.append('\n')
return ''.join(asc)


def load_image(opt, path):
mode = opt.get('image_mode', 'raw')
if mode is None or mode == 'none':
# don't need to load images
return None
elif mode == 'raw':
# raw just returns RGB values
return Image.open(path).convert('RGB')
elif mode == 'ascii':
# convert images to ascii ¯\_(ツ)_/¯
return img_to_ascii(path)
else:
# otherwise, looks for preprocessed version under 'mode' directory
prepath, imagefn = os.path.split(path)
new_path = os.path.join(prepath, mode, imagefn)
if not os.path.isfile(new_path):
# currently only supports *downloaded* preprocessing
# TODO: generate preprocessed images if not available
raise NotImplementedError('image preprocessing mode' +
'{} not supported yet'.format(mode))
else:
return Image.open(path)
145 changes: 145 additions & 0 deletions parlai/core/image_featurizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

import parlai.core.build_data as build_data

import torch
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

import os
import copy
import numpy as np
from PIL import Image

_greyscale = ' .,:;crsA23hHG#98&@'

class ImageLoader():
"""Extract image feature using pretrained CNN network.
"""
@staticmethod
def add_cmdline_args(argparser):
argparser.add_arg('--image-cnntype', type=str, default='resnet152',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually: --image-mode can cover this, unless you can think of a case where image-mode might be set to something that would clash? otherwise let's delete this param

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are two parameters used for image feature extraction. --image_cnntype (which specify the CNN type), and --image_layernum (which specify which layer of feature we need). Usually, we use output of fully connected layer as global image feature, or output of last convolutional layer as spatial image feature or others. Previously, I'm thinking we can let user to set the --image_cnntype or --image_layernum, and --image-mode just specify which folder to store the feature.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can limit this option --image-mode (Different ResNet architecture + Global or Spatial). Then we can delete the --image-layernum and --image--cnntype option.

help='which CNN archtecture to use to extract the image feature'+
'current pretrained option can be found https://github.com/pytorch/vision')
argparser.add_arg('--image-layernum', type=int, default=-1,
help='which CNN layer of feature to extract.')
argparser.add_arg('--image-size', type=int, default=256,
help='')
argparser.add_arg('--image-cropsize', type=int, default=224,
help='')

def __init__(self, opt):

self.opt = copy.deepcopy(opt)
self.netCNN = None
self.transform = None
self.xs = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we move these three into init_cnn?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, let's do this.


def init_cnn(self):
"""Lazy initialization of preprocessor model in case we don't need any image preprocessing."""
opt = self.opt
self.cnn_type = opt['image_cnntype']
self.layer_num = opt['image_layernum']
self.image_size = opt['image_size']
self.crop_size = opt['image_cropsize']
self.datatype = opt['datatype']

opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
self.use_cuda = opt['cuda']

if self.use_cuda:
print('[ Using CUDA ]')
torch.cuda.set_device(opt['gpu'])

# initialize the pretrained CNN using pytorch.
CNN = getattr(torchvision.models, self.cnn_type)

# cut off the additional layer.
self.netCNN = nn.Sequential(*list(CNN(pretrained=True).children())[:self.layer_num])

# initialize the transform function using torch vision.
self.transform = transforms.Compose([
transforms.Scale(self.image_size),
transforms.CenterCrop(self.crop_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])

# container for single image
self.xs = torch.FloatTensor(1, 3, self.crop_size, self.crop_size).fill_(0)

if self.use_cuda:
self.cuda()
self.xs = self.xs.cuda()

# make self.xs variable.
self.xs = Variable(self.xs)

def cuda(self):
self.netCNN.cuda()

def save(self, feature, path):
feature = feature.cpu().data.numpy()
np.save(path, feature)

def extract(self, image, path):
# check whether initlize CNN network.
if not self.netCNN:
self.init_cnn()

self.xs.data.copy_(self.transform(image))
# extract the image feature
feature = self.netCNN(self.xs)
# save the feature
self.save(feature, path)
return feature

def img_to_ascii(self, path):
im = Image.open(path)
im.thumbnail((60, 40), Image.BICUBIC)
im = im.convert('L')
asc = []
for y in range(0, im.size[1]):
for x in range(0, im.size[0]):
lum = 255 - im.getpixel((x, y))
asc.append(_greyscale[lum * len(_greyscale) // 256])
asc.append('\n')
return ''.join(asc)

def load(self, path):
opt = self.opt
mode = opt.get('image_mode', 'raw')
if mode is None or mode == 'none':
# don't need to load images
return None
elif mode == 'raw':
# raw just returns RGB values
return Image.open(path).convert('RGB')
elif mode == 'ascii':
# convert images to ascii ¯\_(ツ)_/¯
return self.img_to_ascii(path)
else:
# otherwise, looks for preprocessed version under 'mode' directory
prepath, imagefn = os.path.split(path)

dpath = os.path.join(prepath, mode)

if not os.path.exists(dpath):
build_data.make_dir(dpath)

imagefn = imagefn + '.npy'
new_path = os.path.join(prepath, mode, imagefn)

if not os.path.isfile(new_path):
return self.extract(Image.open(path).convert('RGB'), new_path)
else:
return np.load(new_path)



8 changes: 4 additions & 4 deletions parlai/tasks/vqa_v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# of patent rights can be found in the PATENTS file in the same directory.

from parlai.core.agents import Teacher
from parlai.core.dialog_teacher import load_image
from parlai.core.image_featurizers import ImageLoader
from .build import build, buildImage

import json
Expand All @@ -29,7 +29,7 @@ def _path(opt):
elif dt == 'test':
ques_suffix = 'MultipleChoice_mscoco_test2015'
annotation_suffix = 'None'
img_suffix = os.path.join('test2014', 'COCO_test2014_')
img_suffix = os.path.join('test2015', 'COCO_test2015_')
else:
raise RuntimeError('Not valid datatype.')

Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(self, opt, shared=None):
# size so they all process disparate sets of the data
self.step_size = opt.get('batchsize', 1)
self.data_offset = opt.get('batchindex', 0)

self.image_loader = ImageLoader(opt)
self.reset()

def __len__(self):
Expand Down Expand Up @@ -101,7 +101,7 @@ def act(self):
img_path = self.image_path + '%012d.jpg' % (image_id)

action = {
'image': load_image(self.opt, img_path),
'image': self.image_loader.load(img_path),
'text': question,
'episode_done': True
}
Expand Down
11 changes: 6 additions & 5 deletions parlai/tasks/vqa_v1/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ def buildImage(opt):
# download the image data.
fname1 = 'train2014.zip'
fname2 = 'val2014.zip'
fname3 = 'test2014.zip'
fname3 = 'test2015.zip'

url = 'http://msvocds.blob.core.windows.net/coco2014/'
url1 = 'http://msvocds.blob.core.windows.net/coco2014/'
url2 = 'http://msvocds.blob.core.windows.net/coco2015/'

build_data.download(url + fname1, dpath, fname1)
build_data.download(url + fname2, dpath, fname2)
build_data.download(url + fname3, dpath, fname3)
build_data.download(url1 + fname1, dpath, fname1)
build_data.download(url1 + fname2, dpath, fname2)
build_data.download(url2 + fname3, dpath, fname3)

build_data.untar(dpath, fname1)
build_data.untar(dpath, fname2)
Expand Down
10 changes: 4 additions & 6 deletions parlai/tasks/vqa_v2/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# of patent rights can be found in the PATENTS file in the same directory.

from parlai.core.agents import Teacher
from parlai.core.dialog_teacher import load_image
from parlai.core.image_featurizers import ImageLoader
from .build import build, buildImage

import json
Expand All @@ -29,7 +29,7 @@ def _path(opt):
elif dt == 'test':
ques_suffix = 'v2_OpenEnded_mscoco_test2015'
annotation_suffix = 'None'
img_suffix = os.path.join('test2014', 'COCO_test2014_')
img_suffix = os.path.join('test2015', 'COCO_test2015_')
else:
raise RuntimeError('Not valid datatype.')

Expand Down Expand Up @@ -67,6 +67,7 @@ def __init__(self, opt, shared=None):
# size so they all process disparate sets of the data
self.step_size = opt.get('batchsize', 1)
self.data_offset = opt.get('batchindex', 0)
self.image_loader = ImageLoader(opt)

self.reset()

Expand Down Expand Up @@ -95,12 +96,9 @@ def act(self):

qa = self.ques['questions'][self.episode_idx]
question = qa['question']
image_id = qa['image_id']

img_path = self.image_path + '%012d.jpg' % (image_id)

action = {
'image': load_image(self.opt, img_path),
'image': self.image_loader.load(img_path),
'text': question,
'episode_done': True
}
Expand Down
11 changes: 6 additions & 5 deletions parlai/tasks/vqa_v2/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ def buildImage(opt):
# download the image data.
fname1 = 'train2014.zip'
fname2 = 'val2014.zip'
fname3 = 'test2014.zip'
fname3 = 'test2015.zip'

url = 'http://msvocds.blob.core.windows.net/coco2014/'
url1 = 'http://msvocds.blob.core.windows.net/coco2014/'
url2 = 'http://msvocds.blob.core.windows.net/coco2015/'

build_data.download(url + fname1, dpath, fname1)
build_data.download(url + fname2, dpath, fname2)
build_data.download(url + fname3, dpath, fname3)
build_data.download(url1 + fname1, dpath, fname1)
build_data.download(url1 + fname2, dpath, fname2)
build_data.download(url2 + fname3, dpath, fname3)

build_data.untar(dpath, fname1)
build_data.untar(dpath, fname2)
Expand Down