demo_AP.m
function paths = demo_AP(dataset, nbits, modelType, varargin)
% This is the main function for running deep learning experiments with the
% mutual-information-based hashing method described in the papers below.
%
% Please cite these papers if you use this code.
%
% 1. "Hashing with Mutual Information",
% Fatih Cakir*, Kun He*, Sarah A. Bargal, Stan Sclaroff
% arXiv:1803.00974 2018
%
% 2. "MIHash: Online Hashing with Mutual Information",
% Fatih Cakir*, Kun He*, Sarah A. Bargal, Stan Sclaroff
% International Conference on Computer Vision (ICCV) 2017
% (* equal contribution)
%
% INPUTS
% dataset - (string) in {'cifar', 'nuswide', 'labelme'}
% nbits - (int) length of binary code
% modelType - (string) in {'fc1', 'vggf', 'vggf_ft'}, among others, corresponding
%             to the models defined under the '+models' folder.
% varargin - key-value argument pairs, see get_opts.m for details
%
% OUTPUTS
% paths (struct)
% .expfolder - (string) Path to the experiments folder
% .diary - (string) Path to the experiment log (MATLAB diary)
%
% EXAMPLE COMMANDS
% refer to the GitHub page, e.g.:
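%   (illustrative only; valid option names and defaults are defined in get_opts.m)
%   demo_AP('cifar', 32, 'vggf', 'gpus', 1);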
%
% -----------------------------------------------------------------------------
% initialize opts
% -----------------------------------------------------------------------------
opts = get_opts(dataset, nbits, modelType, varargin{:});
finishup = onCleanup(@cleanup);  % run cleanup() when this function exits or errors
rng(opts.randseed, 'twister'); % set global random stream
% -----------------------------------------------------------------------------
% post-parsing
% -----------------------------------------------------------------------------
opts = process_opts(opts); % carry out all post-processing on opts
% -----------------------------------------------------------------------------
% print info
% -----------------------------------------------------------------------------
opts  % display the full options struct
myLogInfo(opts.methodID);
myLogInfo(opts.identifier);
% -----------------------------------------------------------------------------
% get neural net model
% -----------------------------------------------------------------------------
[net, opts] = get_model(opts);
% -----------------------------------------------------------------------------
% get dataset
% -----------------------------------------------------------------------------
global imdb  % imdb is kept global so repeated runs in the same session can reuse it
[imdb, opts, net] = get_imdb(imdb, opts, net);
% -----------------------------------------------------------------------------
% set batch sampling function
% -----------------------------------------------------------------------------
batchFunc = get_batchFunc(imdb, opts, net);
% -----------------------------------------------------------------------------
% set learning rate vector
% -----------------------------------------------------------------------------
lrvec = set_lr(opts);
% -----------------------------------------------------------------------------
% set model save checkpoints
% -----------------------------------------------------------------------------
saveps = set_saveps(opts);
% -----------------------------------------------------------------------------
% train
% -----------------------------------------------------------------------------
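% train_simplenn (a MatConvNet SimpleNN-style trainer) runs the optimization.
% Notable options below: 'saveEpochs' sets the checkpoint epochs, 'val' marks
% images with set == 3 as the validation split, and 'epochCallback' invokes the
% epoch_callback hook (defined below) after every epoch.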
[net, info] = train_simplenn(net, imdb, batchFunc, ...
    'continue',       opts.continue, ...
    'debug',          opts.debug, ...
    'plotStatistics', false, ...
    'expDir',         opts.expDir, ...
    'batchSize',      opts.batchSize, ...
    'numEpochs',      opts.epoch, ...
    'saveEpochs',     unique(saveps), ...
    'learningRate',   lrvec, ...
    'weightDecay',    opts.wdecay, ...
    'backPropDepth',  opts.bpdepth, ...
    'val',            find(imdb.images.set == 3), ...
    'gpus',           opts.gpus, ...
    'errorFunction',  'none', ...
    'epochCallback',  @epoch_callback) ;
% -----------------------------------------------------------------------------
% return
% -----------------------------------------------------------------------------
paths.diary = opts.diary_path;
paths.expfolder = opts.expDir;
end
% -----------------------------------------------------------------------------
% postprocessing after each epoch
% -----------------------------------------------------------------------------
function net = epoch_callback(epoch, net, imdb, batchFunc, netopts)
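% Epoch hook: moves the model onto the GPU when GPUs are in use, logs the
% method/experiment identifiers, and periodically evaluates the model via
% opts.testFunc with opts.metrics (defaults: test_supervised, {'AP'}).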
opts = net.layers{end}.opts;
if numel(opts.gpus) >= 1
    net = vl_simplenn_move(net, 'gpu');  % ensure the model is on the GPU
end
% disp
myLogInfo('[%s]', opts.methodID);
myLogInfo('[%s]', opts.identifier);
% test?
if ~isfield(opts, 'testFunc'), opts.testFunc = @test_supervised; end
if ~isfield(opts, 'testInterval'), opts.testInterval = 10; end
if ~isfield(opts, 'metrics'), opts.metrics = {'AP'}; end
% evaluate every 'testInterval' epochs, optionally at epoch 1 (opts.ep1),
% and always at the final epoch
if ~mod(epoch, opts.testInterval) ...
        || (isfield(opts, 'ep1') && opts.ep1 && epoch == 1) ...
        || (epoch == opts.epoch)
    opts.testFunc(net, imdb, batchFunc, opts, opts.metrics);
end
diary off, diary on  % flush the log file to disk
end