Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matcaffe2 #501

Closed
wants to merge 63 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
c876a49
Extended matcaffe, splited init in init_net and load_modl, also to ge…
sguada Jun 10, 2014
d9725e8
Convert int to mxCreateDoubleScalar
sguada Jun 10, 2014
345fe4d
Fixed do_get_blobs_info
sguada Jun 10, 2014
c9eae73
Rename load_model to load_net
sguada Jun 10, 2014
3f9b2d9
Added count to blobs, changed name of blobs to weights for layers
sguada Jun 10, 2014
4a10c70
Changed name of blobs to weights for layers
sguada Jun 10, 2014
adedeba
Added get_layer_weights
sguada Jun 10, 2014
5a8f7d3
Added get_layer_weights to interface
sguada Jun 10, 2014
677668b
Added set_layer_weights
sguada Jun 10, 2014
7d41603
# This is a combination of 4 commits.
sguada Jun 10, 2014
30f23fc
Added save_net to be able to save a net to a proto binary file
sguada Aug 27, 2014
b3c3a40
Added check that layer_blobs[j]->count() == numel(elem)
sguada Jun 11, 2014
664947c
Added set_weights to be able to set multiple layers weights
sguada Jun 11, 2014
1eba9d8
Added get_blob_data
sguada Jun 11, 2014
dc05410
Added get_blob_diff to get the diff in blob
sguada Jun 11, 2014
f93c146
Added get_all_data and get_all_diff
sguada Jun 11, 2014
9d66670
Return [output,loss] in the forward pass
sguada Jun 11, 2014
28eb111
Re-organize mx_loss argument for do_foward
sguada Jun 11, 2014
7595530
Added forward_prefilled
sguada Jun 11, 2014
c9a6301
Added logs to track the loss
sguada Jun 11, 2014
d150a56
Added backward_prefilled
sguada Aug 27, 2014
26344fa
Fix check number of arguments for backwarda and backward_prefilled
sguada Jun 11, 2014
ac199b4
Remove set_mode_test from matcaffe_init.m
sguada Jun 11, 2014
83156ec
Added tools to display weights as filters
sguada Jun 11, 2014
86ee662
Now load_net return a new init_key
sguada Jun 12, 2014
7467d58
Added Class CaffeNet to wrap all the calls to caffe
sguada Jun 12, 2014
e996763
Added get_mode get_phase and get_device
sguada Jun 12, 2014
481016c
Added setters and getters to handle interactions with caffe
sguada Jun 13, 2014
34b7f48
Added GetDevice to commom Caffe and fixed get mode and phase in matcaffe
sguada Jun 13, 2014
db0d259
added set_input_blobs and set_output_blobs
sguada Aug 27, 2014
55120f7
Fixed device_id in MatCaffe
sguada Jun 13, 2014
18e0004
Fixed reference to weights_changed and initialize mode, phase and dev…
sguada Jun 13, 2014
9899661
Added set_input_blobs and set_output_blobs, avoided set when values d…
sguada Jun 13, 2014
5809489
Removed AbortSet when values don't change
sguada Jun 13, 2014
89c5294
Only load model and not init again
sguada Jun 13, 2014
69c60aa
Don't get_weights during initialization
sguada Jun 13, 2014
cf6468f
Make weights a dependent property
sguada Jun 13, 2014
343d085
Reorganize Constractor and Instance
sguada Jun 13, 2014
ec6cae8
Now initialize without parameters loads the default imagenet_deploy a…
sguada Jun 13, 2014
9507c91
Fixed if and conditions
sguada Jun 13, 2014
90e79f0
Added init(model_def_file,model_file) and delete to free up memory)
sguada Jun 14, 2014
42f1630
Check that the net is initialized before doing operations
sguada Jun 14, 2014
ce5d3e6
Make reset to remove self
sguada Jun 14, 2014
bdd2b8a
Restore reset to just call caffe('reset') and empyt info
sguada Jun 14, 2014
c6ca902
Clear weights_store when reset or delete
sguada Jun 14, 2014
a560523
Adapted matcaffe_demo to use CaffeNet
sguada Jun 14, 2014
139c42d
Added forgotten self
sguada Jun 14, 2014
d2562a5
Now matcaffe_demo([],1) runs in gpu mode with peppers.png
sguada Jun 14, 2014
7cd79a8
Added demo of how to extract weights from the default network
sguada Jun 14, 2014
907fbd4
Simplify matcaffe_demo_backward.m
sguada Jun 14, 2014
536df00
Adjusted demos to use CaffeNet
sguada Sep 20, 2014
87c19c5
Replace condition checking by assert(is_initialized)
sguada Jun 14, 2014
52f09ce
Add caffe_safe to catch caffe exceptions
sguada Jun 20, 2014
ab12475
Redefine CHECK_EQ to avoid core_dump in Matlab
sguada Jun 21, 2014
41f3f1c
Fixed varargin in caffe_safe.m
sguada Jun 20, 2014
d508d75
Reduce redundancy of code by creating auxiliary methods
sguada Aug 27, 2014
0ebb0be
Added mxarray_to_blob_data and mxarray_to_blob_diff
sguada Aug 27, 2014
8cf6dc1
Added comments to demo_backward
sguada Aug 27, 2014
a52e701
Create CHECK_EQ_MEX
sguada Aug 29, 2014
364f1a8
Move mxArray<->Blobs to the begining
sguada Aug 29, 2014
21b70bd
Added share_ptr
sguada Aug 29, 2014
f06fcac
Added share_ptr
sguada Aug 29, 2014
790abca
Make lint happy
sguada Sep 2, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/caffe/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class Caffe {
// Sets the device. Since we have cublas and curand stuff, set device also
// requires us to reset those values.
static void SetDevice(const int device_id);
// Gets current device_id
static int GetDevice();
// Prints the current GPU status.
static void DeviceQuery();

Expand Down
252 changes: 252 additions & 0 deletions matlab/caffe/CaffeNet.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
classdef CaffeNet < handle
properties (SetAccess = private)
model_def_file = '../../examples/imagenet/imagenet_deploy.prototxt';
model_file = '../../examples/imagenet/caffe_reference_imagenet_model';
layers_info
blobs_info
end
properties
mode
phase
device_id
input_blobs
output_blobs
end
properties (Dependent)
weights
end
properties (Access = private)
init_key
weights_changed = true;
weights_store
end
methods (Access=private)
function self = CaffeNet(model_def_file, model_file)
if nargin == 0
init(self, self.model_def_file, self.model_file);
end
if nargin > 0
init_net(self, model_def_file);
end
if nargin > 1
load_net(self, model_file);
end
assert(is_initialized(self))
self.mode = caffe_safe('get_mode');
self.phase = caffe_safe('get_phase');
self.device_id = caffe_safe('get_device');
end
end
methods (Static)
function obj = instance(model_def_file, model_file)
persistent self
if isempty(self)
switch nargin
case 2
self = CaffeNet(model_def_file, model_file);
case 1
self = CaffeNet(model_def_file);
case 0
self = CaffeNet();
end
else
if nargin > 0 && ~isempty(model_def_file)
init_net(self,model_def_file);
end
if nargin > 1 && ~isempty(model_file)
load_net(self,model_file);
end
end
obj = self;
end
end
methods
function weights = get.weights(self)
assert(is_initialized(self))
if (self.weights_changed)
self.weights_store = caffe_safe('get_weights');
self.weights_changed = false;
end
weights = self.weights_store;
end
function set.weights(self,weights)
assert(is_initialized(self))
caffe_safe('set_weights', weights);
self.weights_store = weights;
self.weights_changed = false;
end
function set.mode(self,mode)
% mode = {'CPU' 'GPU'}
switch mode
case 'CPU'
caffe_safe('set_mode_cpu');
self.mode = mode;
case 'GPU'
caffe_safe('set_mode_gpu');
self.mode = mode;
otherwise
fprintf('Mode unknown choose between CPU and GPU\n');
error('Mode unknown');
end
end
function set.phase(self, phase)
% phase = {'TRAIN' 'TEST'}
switch phase
case 'TRAIN'
caffe_safe('set_phase_train');
self.phase = phase;
case {'test','TEST'}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why upper and lower here and just upper above? Maybe just switch on lower(phase) here and in other places.

caffe_safe('set_phase_test');
self.phase = phase;
otherwise
fprintf('Phase unknown choose between TRAIN and TEST')
error('Phase unknown');
end
end
function set.device_id(self, device_id)
caffe_safe('set_device', device_id);
self.device_id = device_id;
end
function set.input_blobs(self, input_blobs)
assert(is_initialized(self))
caffe_safe('set_input_blobs', input_blobs);
self.input_blobs = input_blobs;
end
function set.output_blobs(self, output_blobs)
assert(is_initialized(self))
caffe_safe('set_output_blobs', output_blobs);
self.output_blobs = output_blobs;
end
end
methods
function res = forward(self,input)
assert(is_initialized(self))
if nargin < 2
res = caffe_safe('forward');
else
res = caffe_safe('forward',input);
end
end
function res = backward(self,diff)
assert(is_initialized(self))
if nargin < 2
res = caffe_safe('backward');
else
res = caffe_safe('backward',diff);
end
end
function res = forward_prefilled(self)
assert(is_initialized(self))
res = caffe_safe('forward_prefilled');
end
function res = backward_prefilled(self)
assert(is_initialized(self))
res = caffe_safe('backward_prefilled');
end
function res = init(self, model_def_file, model_file)
self.init_key = caffe_safe('init',model_def_file, model_file);
assert(is_initialized(self))
self.model_def_file = model_def_file;
self.model_file = model_file;
self.layers_info = caffe_safe('get_layers_info');
self.blobs_info = caffe_safe('get_blobs_info');
self.weights_changed = true;
res = self.init_key;
end
function res = init_net(self, model_def_file)
self.init_key = caffe_safe('init_net',model_def_file);
assert(is_initialized(self))
self.model_def_file = model_def_file;
self.model_file = [];
self.layers_info = caffe_safe('get_layers_info');
self.blobs_info = caffe_safe('get_blobs_info');
self.weights_changed = true;
res = self.init_key;
end
function res = load_net(self, model_file)
assert(is_initialized(self))
self.init_key = caffe_safe('load_net',model_file);
self.model_file = model_file;
self.weights_changed = true;
res = self.init_key;
end
function res = save_net(self, model_file)
assert(is_initialized(self))
res = caffe_safe('save_net', model_file);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think caffe_safe('save_net',...) shouldn't return anything, since the corresponding matcaffe function doesn't return anything.

end
function res = is_initialized(~)
res = caffe_safe('is_initialized')>0;
end
function set_mode_cpu(self)
self.mode = 'CPU';
end
function set_mode_gpu(self)
self.mode = 'GPU';
end
function set_phase_train(self)
self.phase = 'TRAIN';
end
function set_phase_test(self)
self.phase = 'TEST';
end
function set_device(self, device_id)
self.device_id = device_id;
end
function res = get_weights(self)
assert(is_initialized(self))
res = self.weights;
end
function set_weights(self, weights)
self.weights = weights;
end
function res = get_layer_weights(self, layer_name)
assert(is_initialized(self))
res = caffe_safe('get_layer_weights', layer_name);
end
function res = set_layer_weights(self, layer_name, weights)
assert(is_initialized(self))
res = caffe_safe('set_layer_weights', layer_name, weights);
self.weights = caffe_safe('get_weights');
end
function res = get_layers_info(self)
assert(is_initialized(self))
res = self.layers_info;
end
function res = get_blobs_info(self)
assert(is_initialized(self))
res = self.blobs_info;
end
function res = get_blob_data(self, blob_name)
assert(is_initialized(self))
res = caffe_safe('get_blob_data', blob_name);
end
function res = get_blob_diff(self, blob_name)
assert(is_initialized(self))
res = caffe_safe('get_blob_diff', blob_name);
end
function res = get_all_data(self)
assert(is_initialized(self))
res = caffe_safe('get_all_data');
end
function res = get_all_diff(self)
assert(is_initialized(self))
res = caffe_safe('get_all_diff');
end
function res = get_init_key(self)
self.init_key = caffe_safe('get_init_key');
res = self.init_key;
end
function reset(self)
caffe_safe('reset');
self.init_key = caffe_safe('get_init_key');
self.layers_info = [];
self.blobs_info = [];
self.weights_store = [];
end
function delete(self)
self.weights_store = [];
caffe_safe('reset');
clear caffe;
end
end
end
15 changes: 15 additions & 0 deletions matlab/caffe/caffe_safe.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
function varargout = caffe_safe(varargin)
try
switch nargout
case 0
caffe(varargin{:});
varargout={};
case 1
varargout{1} = caffe(varargin{:});
case 2
[varargout{1} varargout{2}] = caffe(varargin{:});
end
catch
error('Exception in caffe');
end
end
61 changes: 61 additions & 0 deletions matlab/caffe/displayData.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
function [h, display_array] = displayData(X, example_width, display_cols)
%DISPLAYDATA Display 2D data in a nice grid
% [h, display_array] = DISPLAYDATA(X, example_width) displays 2D data
% stored in X in a nice grid. It returns the figure handle h and the
% displayed array if requested.

% Set example_width automatically if not passed in
if ~exist('example_width', 'var') || isempty(example_width)
example_width = round(sqrt(size(X, 2)));
end

% Gray Image
colormap(gray);

% Compute rows, cols
[m n] = size(X);
example_height = (n / example_width);

% Compute number of items to display
if ~exist('display_cols', 'var')
display_cols = floor(sqrt(m));
end
display_rows = ceil(m / display_cols);

% Between images padding
pad = 1;

% Setup blank display
display_array = - ones(pad + display_rows * (example_height + pad), ...
pad + display_cols * (example_width + pad));

% Copy each example into a patch on the display array
curr_ex = 1;
for j = 1:display_rows
for i = 1:display_cols
if curr_ex > m,
break;
end
% Copy the patch

% Get the max value of the patch
max_val = max(abs(X(curr_ex, :)));
display_array(pad + (j - 1) * (example_height + pad) + (1:example_height), ...
pad + (i - 1) * (example_width + pad) + (1:example_width)) = ...
reshape(X(curr_ex, :), example_height, example_width) / max_val;
curr_ex = curr_ex + 1;
end
if curr_ex > m,
break;
end
end

% Display Image
h = imagesc(display_array, [-1 1]);

% Do not show axis
axis image off

drawnow;

end
31 changes: 31 additions & 0 deletions matlab/caffe/displayFilters.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
function displayFilters(X, display_cols)
%displayFilters Display [height * width * num_channels * num_filters] filters
% in a nice grid using display_cols, at most it displays 32 filters
% [h, display_array] = displayFilters(X, display_cols)

% Gray Image
colormap(gray);

% Compute rows, cols
[example_height example_width num_channels num_filters] = size(X);


% Compute number of items to display
if ~exist('display_cols', 'var')
display_cols = floor(sqrt(num_filters));
end
display_rows = ceil(num_filters / display_cols);
colimage = [];
for n = 1:min(32,num_filters)
if mod(n,8)== 1
figure(ceil(n/8))
colimage = [];
end
filter = reshape(X(:,:,:,n),[],num_channels)';
[~,dsp] = displayData(filter,example_width,num_channels/4);
colimage = cat(1,colimage,dsp,ones(1,size(dsp,2)));
if mod(n,8)== 0
imagesc(colimage, [-1 1]), axis off; drawnow;
end
end

Loading