-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add image feature extraction modules and fix minor bugs. #169
Changes from 4 commits
c1612f7
9c6ea32
e08a125
dc9b633
cfbd016
a5f82dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. An additional grant | ||
# of patent rights can be found in the PATENTS file in the same directory. | ||
"""Basic example which iterates through the specified tasks and loads/extracts the | ||
image features. | ||
|
||
For example, to extract the image feature of COCO images: | ||
`python examples/extract_image_feature.py -t vqa_v1 -im resnet152`. | ||
|
||
The CNN model and layer are specified via `--image-cnntype` and `--image-layernum`, | ||
which are defined in `parlai.core.image_featurizers`. | ||
|
||
For more options, check `parlai.core.image_featurizers` | ||
""" | ||
|
||
from parlai.core.params import ParlaiParser | ||
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent | ||
from parlai.core.worlds import create_task | ||
from parlai.core.image_featurizers import ImageLoader | ||
|
||
import random | ||
|
||
def main():
    """Iterate over a task so that its images are loaded / featurized.

    Parses the command line (including the options registered by
    ``ImageLoader.add_cmdline_args``), builds a ``RepeatLabelAgent`` on the
    requested task and runs at most ``--num-examples`` parleys, printing
    each exchange. Stops early when the world reports the epoch is done.
    """
    random.seed(42)

    # Get command line arguments
    parser = ParlaiParser()
    # type=int so a value given on the command line is validated by the
    # parser instead of arriving as a string.
    parser.add_argument('-n', '--num-examples', type=int, default=10,
                        help='maximum number of examples to iterate over')
    ImageLoader.add_cmdline_args(parser)
    # NOTE(review): a reviewer suggested
    # parser.set_defaults(datatype='train:ordered') so each image is hit
    # only once; not applied here.
    opt = parser.parse_args()

    # ImageLoader.init_cnn reads these two keys to decide whether and where
    # to use CUDA; this script always asks for GPU 0 when one is available.
    opt['no_cuda'] = False
    opt['gpu'] = 0

    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    # Show some example dialogs.
    with world:
        for _ in range(opt['num_examples']):
            world.parley()
            print(world.display() + '\n~~')
            if world.epoch_done():
                print('EPOCH DONE')
                break
|
||
if __name__ == '__main__': | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
from .agents import Teacher | ||
|
||
from .image_featurizers import ImageLoader | ||
from PIL import Image | ||
import random | ||
import os | ||
|
@@ -185,6 +186,7 @@ def __init__(self, opt, data_loader, cands=None): | |
self._load(data_loader) | ||
self.cands = None if cands == None else set(sys.intern(c) for c in cands) | ||
self.addedCands = [] | ||
self.image_loader = None | ||
|
||
def __len__(self): | ||
"""Returns total number of entries available. Each episode has at least | ||
|
@@ -240,6 +242,7 @@ def _load(self, data_loader): | |
new_entry.append(None) | ||
if len(entry) > 4 and entry[4] is not None: | ||
new_entry.append(sys.intern(entry[4])) | ||
self.image_loader = ImageLoader(opt) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this will create a new image loader on every example with an image--not good. since the image initializer is lazy anyways, let's move this to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ahh, right. |
||
|
||
episode.append(tuple(new_entry)) | ||
|
||
|
@@ -272,7 +275,7 @@ def get(self, episode_idx, entry_idx=0): | |
if entry[3] is not None: | ||
table['label_candidates'] = entry[3] | ||
if len(entry) > 4 and entry[4] is not None: | ||
img = load_image(self.opt, entry[4]) | ||
img = self.image_loader.load(entry[4]) | ||
if img is not None: | ||
table['image'] = img | ||
|
||
|
@@ -297,42 +300,3 @@ def get(self, episode_idx, entry_idx=0): | |
table['episode_done'] = episode_done | ||
return table, end_of_data | ||
|
||
|
||
_greyscale = ' .,:;crsA23hHG#98&@' | ||
|
||
|
||
def img_to_ascii(path):
    """Render the image at ``path`` as ASCII art.

    The image is shrunk in place to fit within 60x40, converted to
    greyscale, and every pixel is mapped to a character of ``_greyscale``
    (inverted luminance, so dark pixels pick characters further along the
    ramp). Rows are terminated with a newline.
    """
    # NOTE(review): reviewers asked for this helper to live in a shared
    # module; the reply says it was moved to image_featurizers.py.
    grey = Image.open(path)
    grey.thumbnail((60, 40), Image.BICUBIC)
    grey = grey.convert('L')

    width, height = grey.size
    levels = len(_greyscale)
    rows = []
    for row in range(height):
        chars = [
            _greyscale[(255 - grey.getpixel((col, row))) * levels // 256]
            for col in range(width)
        ]
        rows.append(''.join(chars) + '\n')
    return ''.join(rows)
|
||
|
||
def load_image(opt, path):
    """Load the image at ``path`` in the form selected by ``opt['image_mode']``.

    Modes:
      * ``None`` / ``'none'`` -- images are not needed; returns ``None``.
      * ``'raw'`` (default)   -- returns the image converted to RGB.
      * ``'ascii'``           -- returns the ASCII-art string from
        ``img_to_ascii``.
      * anything else         -- treated as the name of a sub-directory next
        to the image holding a preprocessed file with the same file name.

    Raises:
        NotImplementedError: if the preprocessed file for a custom mode does
            not exist (only already-downloaded preprocessing is supported).
    """
    mode = opt.get('image_mode', 'raw')
    if mode is None or mode == 'none':
        # don't need to load images
        return None
    if mode == 'raw':
        # raw just returns RGB values
        return Image.open(path).convert('RGB')
    if mode == 'ascii':
        # convert images to ascii
        return img_to_ascii(path)

    # otherwise, look for a preprocessed version under the 'mode' directory
    prepath, imagefn = os.path.split(path)
    new_path = os.path.join(prepath, mode, imagefn)
    if not os.path.isfile(new_path):
        # currently only supports *downloaded* preprocessing
        # TODO: generate preprocessed images if not available
        raise NotImplementedError(
            'image preprocessing mode {} not supported yet'.format(mode))
    # Open the preprocessed file that was just checked for, not the
    # original image.
    return Image.open(new_path)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. An additional grant | ||
# of patent rights can be found in the PATENTS file in the same directory. | ||
|
||
import parlai.core.build_data as build_data | ||
|
||
import torch | ||
from torch.autograd import Variable | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
import torch.nn as nn | ||
|
||
import os | ||
import copy | ||
import numpy as np | ||
from PIL import Image | ||
|
||
_greyscale = ' .,:;crsA23hHG#98&@' | ||
|
||
class ImageLoader():
    """Load images in the representation selected by ``opt['image_mode']``.

    Besides the ``'none'``, ``'raw'`` and ``'ascii'`` modes, any other mode
    name is treated as a feature cache directory: features are extracted
    with a pretrained torchvision CNN the first time an image is requested
    and saved as ``<image dir>/<mode>/<image name>.npy`` for later loads.
    The CNN is only built when a feature actually has to be extracted.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        """Register the image feature extraction options on ``argparser``."""
        # NOTE(review): reviewers discussed replacing --image-cnntype and
        # --image-layernum with a restricted set of --image-mode values
        # (architecture + global/spatial feature); both flags are kept here.
        argparser.add_arg(
            '--image-cnntype', type=str, default='resnet152',
            help='which CNN architecture to use to extract the image '
                 'feature; current pretrained options can be found at '
                 'https://github.com/pytorch/vision')
        argparser.add_arg(
            '--image-layernum', type=int, default=-1,
            help='which CNN layer of feature to extract.')
        argparser.add_arg(
            '--image-size', type=int, default=256,
            help='size the image is scaled to before cropping')
        argparser.add_arg(
            '--image-cropsize', type=int, default=224,
            help='side length of the center crop fed to the CNN')

    def __init__(self, opt):
        """Keep a private deep copy of ``opt``; the CNN is created lazily."""
        self.opt = copy.deepcopy(opt)
        # NOTE(review): the review thread agreed these three placeholders
        # could be relocated; they are left here so `extract` can test
        # whether the CNN has been built yet.
        self.netCNN = None
        self.transform = None
        self.xs = None

    def init_cnn(self):
        """Lazy initialization of preprocessor model in case we don't need
        any image preprocessing.

        Builds the truncated pretrained CNN, the input transform and the
        single-image input buffer, moving them to the GPU when CUDA is
        enabled and available.
        """
        opt = self.opt
        self.cnn_type = opt['image_cnntype']
        self.layer_num = opt['image_layernum']
        self.image_size = opt['image_size']
        self.crop_size = opt['image_cropsize']
        self.datatype = opt.get('datatype')

        # 'no_cuda' / 'gpu' may be absent when the loader is created from an
        # opt without model arguments; default to GPU 0 when available.
        opt['cuda'] = (not opt.get('no_cuda', False)
                       and torch.cuda.is_available())
        self.use_cuda = opt['cuda']

        if self.use_cuda:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt.get('gpu', 0))

        # initialize the pretrained CNN using pytorch.
        cnn_factory = getattr(torchvision.models, self.cnn_type)

        # cut off the additional layer(s): keep children [:layer_num].
        layers = list(cnn_factory(pretrained=True).children())
        self.netCNN = nn.Sequential(*layers[:self.layer_num])
        # Feature extraction is inference only: switch batch-norm / dropout
        # layers to evaluation behaviour so a given image always produces
        # the same feature.
        self.netCNN.eval()

        # `Scale` was renamed `Resize` in newer torchvision releases; use
        # whichever this installation provides.
        resize = getattr(transforms, 'Resize', None)
        if resize is None:
            resize = transforms.Scale

        # initialize the transform function using torch vision.
        self.transform = transforms.Compose([
            resize(self.image_size),
            transforms.CenterCrop(self.crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # container for single image
        self.xs = torch.FloatTensor(
            1, 3, self.crop_size, self.crop_size).fill_(0)

        if self.use_cuda:
            self.cuda()
            self.xs = self.xs.cuda()

        # make self.xs variable.
        self.xs = Variable(self.xs)

    def cuda(self):
        """Move the CNN to the current CUDA device."""
        self.netCNN.cuda()

    def save(self, feature, path):
        """Write ``feature`` to ``path`` as a numpy ``.npy`` file."""
        feature = feature.cpu().data.numpy()
        np.save(path, feature)

    def extract(self, image, path):
        """Run ``image`` through the CNN, save the feature to ``path`` and
        return the feature as produced by the network.
        """
        # build the CNN on first use. Compare with None explicitly: an
        # nn.Sequential defines __len__, so truthiness is not a safe test.
        if self.netCNN is None:
            self.init_cnn()

        self.xs.data.copy_(self.transform(image))
        # extract the image feature
        feature = self.netCNN(self.xs)
        # save the feature
        self.save(feature, path)
        return feature

    def img_to_ascii(self, path):
        """Render the image at ``path`` as ASCII art (at most 60x40 chars)."""
        im = Image.open(path)
        im.thumbnail((60, 40), Image.BICUBIC)
        im = im.convert('L')
        width, height = im.size
        asc = []
        for y in range(height):
            for x in range(width):
                # inverted luminance: darker pixels map further along the
                # `_greyscale` ramp
                lum = 255 - im.getpixel((x, y))
                asc.append(_greyscale[lum * len(_greyscale) // 256])
            asc.append('\n')
        return ''.join(asc)

    def load(self, path):
        """Return the image at ``path`` in the form given by
        ``opt['image_mode']`` (default ``'raw'``).

        Returns ``None`` for mode ``None`` / ``'none'``, an RGB image for
        ``'raw'``, an ASCII string for ``'ascii'``; for any other mode the
        cached feature is loaded from ``<dir>/<mode>/<file>.npy`` or, if
        missing, extracted and saved there first.
        """
        mode = self.opt.get('image_mode', 'raw')
        if mode is None or mode == 'none':
            # don't need to load images
            return None
        if mode == 'raw':
            # raw just returns RGB values
            return Image.open(path).convert('RGB')
        if mode == 'ascii':
            # convert images to ascii
            return self.img_to_ascii(path)

        # otherwise, looks for preprocessed version under 'mode' directory
        prepath, imagefn = os.path.split(path)
        dpath = os.path.join(prepath, mode)
        if not os.path.exists(dpath):
            build_data.make_dir(dpath)

        new_path = os.path.join(dpath, imagefn + '.npy')
        if not os.path.isfile(new_path):
            # NOTE(review): this path returns the network output while the
            # cached path below returns a numpy array -- confirm callers
            # accept both.
            return self.extract(Image.open(path).convert('RGB'), new_path)
        return np.load(new_path)
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(should help with hitting each image only once)
parser.set_defaults(datatype='train:ordered')