Skip to content

Commit

Permalink
set input preprocessing per blob in python
Browse files Browse the repository at this point in the history
  • Loading branch information
shelhamer committed May 14, 2014
1 parent 56ca978 commit 96cd02d
Showing 1 changed file with 64 additions and 22 deletions.
86 changes: 64 additions & 22 deletions python/caffe/pycaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
# automatically have the improved interface

# Shortcuts to the first and last entries of self.blobs.
# NOTE(review): indexing .values() directly needs a sequence (true for a
# Python 2 dict/OrderedDict) -- confirm the interpreter this must run on.
Net.input = property(lambda self: self.blobs.values()[0])
# Superseded: Net.input_scale is rebound to a dict a few statements below,
# so this None default never survives module import.
Net.input_scale = None # for a model that expects data = input * input_scale

Net.output = property(lambda self: self.blobs.values()[-1])

# Superseded: Net.mean is rebound to a dict two statements below.
Net.mean = None # image mean (ndarray, input dimensional or broadcastable)
# Input preprocessing
# Per-input settings; the formatting functions look them up with
# .get(input_), so a missing entry means "skip that step".
# These dicts are class attributes: item assignment through any instance is
# visible to every Net instance that has not rebound the attribute itself.
Net.mean = {} # image mean (ndarray, input dimensional or broadcastable)
Net.input_scale = {} # for a model that expects data = input * input_scale
Net.channel_swap = {} # for RGB -> BGR and the like


@property
Expand All @@ -44,33 +42,69 @@ def _Net_params(self):
# Bind _Net_params (defined just above) to the Net class as `params`.
Net.params = _Net_params


def _Net_set_mean(self, input_, mean_f, mode='image'):
    """
    Set the mean to subtract from one input for data centering.

    Take
    input_: which input to assign this mean.
    mean_f: path to mean .npy
    mode: image = use the whole-image mean (and check dimensions)
          channel = channel constant (i.e. mean pixel instead of mean image)

    Raise
    Exception if input_ is not one of self.inputs, if mode is not 'image' or
    'channel', or if an image-mode mean does not match the shape of the
    input blob (ignoring the leading batch axis).
    """
    if input_ not in self.inputs:
        raise Exception('Input not in {}'.format(self.inputs))
    mean = np.load(mean_f)
    if mode == 'image':
        # Validate against the blob this mean is for, not whichever blob
        # happens to come first in the net.
        # NOTE(review): assumes self.blobs is keyed by blob name -- confirm.
        if mean.shape != self.blobs[input_].data.shape[1:]:
            raise Exception('The mean shape does not match the input shape.')
    elif mode == 'channel':
        # Collapse both spatial axes of the K x H x W mean -> one value per K.
        mean = mean.mean(1).mean(1)
    else:
        raise Exception('Mode not in {}'.format(['image', 'channel']))
    # Net.mean starts out as a dict on the class; writing into it in place
    # would leak this mean into every other Net. Rebind a private copy.
    means = dict(self.mean)
    means[input_] = mean
    self.mean = means

# Expose _Net_set_mean as the Net.set_mean method.
Net.set_mean = _Net_set_mean


def _Net_set_input_scale(self, input_, scale):
    """
    Set the input feature scaling factor s.t. input blob = input * scale.

    Take
    input_: which input to assign this scale factor
    scale: scale coefficient

    Raise
    Exception if input_ is not one of self.inputs.
    """
    if input_ not in self.inputs:
        raise Exception('Input not in {}'.format(self.inputs))
    # Net.input_scale starts out as a dict on the class; writing into it in
    # place would change the scale of every other Net. Rebind a private copy.
    scales = dict(self.input_scale)
    scales[input_] = scale
    self.input_scale = scales

# Expose _Net_set_input_scale as the Net.set_input_scale method.
Net.set_input_scale = _Net_set_input_scale


def _Net_set_channel_swap(self, input_, order):
    """
    Set the input channel order for e.g. RGB to BGR conversion
    as needed for the reference ImageNet model.

    Take
    input_: which input to assign this channel order
    order: the order to take the channels. (2,1,0) maps RGB to BGR for example.

    Raise
    Exception if input_ is not one of self.inputs.
    """
    if input_ not in self.inputs:
        raise Exception('Input not in {}'.format(self.inputs))
    # Net.channel_swap starts out as a dict on the class; writing into it in
    # place would reorder channels for every other Net. Rebind a private copy.
    orders = dict(self.channel_swap)
    orders[input_] = order
    self.channel_swap = orders

# Expose _Net_set_channel_swap as the Net.set_channel_swap method.
Net.set_channel_swap = _Net_set_channel_swap


def _Net_format_image(self, input_, image):
    """
    Format image for input to Caffe:
    - convert to single
    - scale feature
    - reorder channels (for instance color to BGR)
    - subtract mean
    - reshape to 1 x K x H x W

    Each step other than the conversion and reshape is skipped when no
    setting is stored for input_.

    Take
    input_: which input's preprocessing settings to apply
    image: (H x W x K) ndarray

    Give
    image: (1 x K x H x W) float32 ndarray; the argument is left untouched
    """
    caf_image = image.astype(np.float32)  # astype copies, so edits stay local
    input_scale = self.input_scale.get(input_)
    channel_order = self.channel_swap.get(input_)
    mean = self.mean.get(input_)
    # Compare against None: the truth value of a multi-element ndarray raises
    # ValueError, so `if mean:` fails for every real mean.
    if input_scale is not None:
        caf_image *= input_scale
    if channel_order is not None:
        caf_image = caf_image[:, :, list(channel_order)]
    if mean is not None:
        mean = np.asarray(mean)
        height, width, channels = caf_image.shape
        # A whole-image mean is kept in blob layout (K x H x W) while the
        # image is still H x W x K here; bring it into the image's layout.
        # When both layouts have the same shape the mean is used as given.
        if (mean.shape == (channels, height, width)
                and mean.shape != caf_image.shape):
            mean = mean.transpose((1, 2, 0))
        caf_image -= mean
    caf_image = caf_image.transpose((2, 0, 1))
    caf_image = caf_image[np.newaxis, :, :, :]
    return caf_image

# Expose _Net_format_image as the Net.format_image method.
Net.format_image = _Net_format_image


def _Net_decaffeinate_image(self, input_, image):
    """
    Invert Caffe formatting; see _Net_format_image().

    Undoes the steps in reverse order: drop the batch axis, move channels
    last, add the mean back, restore the channel order, undo the scaling.
    Each step is skipped when no setting is stored for input_.

    Take
    input_: which input's preprocessing settings to invert
    image: (1 x K x H x W) or (K x H x W) ndarray

    Give
    image: (H x W x K) ndarray; the argument is left untouched
    """
    decaf_image = np.asarray(image)
    if decaf_image.ndim == 4 and decaf_image.shape[0] == 1:
        # Drop only the batch axis so that single-channel (K == 1) or
        # one-pixel-wide inputs keep their K, H and W axes.
        decaf_image = decaf_image[0]
    elif decaf_image.ndim != 3:
        decaf_image = decaf_image.squeeze()
    # Indexing and transpose give views of the caller's array; copy before
    # the in-place arithmetic so the source blob is not modified.
    decaf_image = decaf_image.transpose((1, 2, 0)).copy()
    input_scale = self.input_scale.get(input_)
    channel_order = self.channel_swap.get(input_)
    mean = self.mean.get(input_)
    # Compare against None: the truth value of a multi-element ndarray raises
    # ValueError, so `if mean:` fails for every real mean.
    if mean is not None:
        mean = np.asarray(mean)
        height, width, channels = decaf_image.shape
        # A whole-image mean is kept in blob layout (K x H x W); bring it
        # into the H x W x K layout used here. When both layouts have the
        # same shape the mean is used as given.
        if (mean.shape == (channels, height, width)
                and mean.shape != decaf_image.shape):
            mean = mean.transpose((1, 2, 0))
        decaf_image += mean
    if channel_order is not None:
        # The inverse of a permutation is its argsort. Reversing the tuple
        # is not: (2,1,0)[::-1] is the identity and would undo nothing.
        decaf_image = decaf_image[:, :, np.argsort(channel_order)]
    if input_scale is not None:
        decaf_image /= input_scale
    return decaf_image

# Expose _Net_decaffeinate_image as the Net.decaffeinate_image method.
Net.decaffeinate_image = _Net_decaffeinate_image
Expand Down

0 comments on commit 96cd02d

Please sign in to comment.