function [X, fX, i, flag] = minimize_stuff(X, f, length, varargin)
% Minimize a differentiable multivariate function using conjugate gradients.
%
% Usage: [X, fX, i, flag] = minimize_stuff(X, f, length, P1, P2, P3, ... )
%
% X       initial guess; may be of any type, including struct and cell array
% f       the name or pointer to the function to be minimized. The function
%         f must return two arguments, the value of the function, and its
%         partial derivatives wrt the elements of X. The partial derivatives
%         must have the same type as X.
% length  length of the run; if it is positive, it gives the maximum number of
%         line searches, if negative its absolute value gives the maximum
%         allowed number of function evaluations. Optionally, length can have
%         a second component, which will indicate the reduction in function
%         value to be expected in the first line-search (defaults to 1.0).
% P1, P2, ... parameters are passed to the function f.
%
% X       the returned solution
% fX      vector of function values indicating progress made
% i       number of iterations (line searches or function evaluations,
%         depending on the sign of "length") used at termination.
% flag    0 normally; set to 1 if the run was halted because f returned NaN
%         during a line search (a warning is issued in that case).
%
% The function returns when either its length is up, or if no further progress
% can be made (ie, we are at a (local) minimum, or so close that due to
% numerical problems, we cannot get any closer). NOTE: If the function
% terminates within a few iterations, it could be an indication that the
% function values and derivatives are not consistent (ie, there may be a bug in
% the implementation of your "f" function).
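%
% Example (a minimal sketch, not part of the original documentation; "quad" is
% a hypothetical objective assumed to be saved in its own file quad.m):
%
%   function [fval, grad] = quad(x, A, b)
%     fval = 0.5*x'*A*x - b'*x;     % convex quadratic when A is pos. definite
%     grad = A*x - b;               % gradient, same size as x
%   end
%
%   A = [3 1; 1 2]; b = [1; 1];     % extra arguments, passed through to quad
%   [x, fx, n, flag] = minimize_stuff(zeros(2,1), @quad, 50, A, b);
%   % x should approach A\b = [0.2; 0.4]; at most 50 line searches are used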
%
% The Polak-Ribiere flavour of conjugate gradients is used to compute search
% directions, and a line search using quadratic and cubic polynomial
% approximations and the Wolfe-Powell stopping criteria is used together with
% the slope ratio method for guessing initial step sizes. Additionally a bunch
% of checks are made to make sure that exploration is taking place and that
% extrapolation will not be unboundedly large.
%
% See also: checkgrad
%
% Copyright (C) 2001 - 2010 by Carl Edward Rasmussen, 2010-01-03
flag=0; %becomes 1 if algorithm halts due to hyperparams becoming too large
INT = 0.1; % don't reevaluate within 0.1 of the limit of the current bracket
EXT = 3.0; % extrapolate maximum 3 times the current step-size
MAX = 20; % max 20 function evaluations per line search
RATIO = 10; % maximum allowed slope ratio
SIG = 0.1; RHO = SIG/2; % SIG and RHO are the constants controlling the Wolfe-
% Powell conditions. SIG is the maximum allowed absolute ratio between
% previous and new slopes (derivatives in the search direction), thus setting
% SIG to low (positive) values forces higher precision in the line-searches.
% RHO is the minimum allowed fraction of the expected decrease (expected from
% the slope at the initial point in the linesearch). Constants must satisfy
% 0 < RHO < SIG < 1. Tuning of SIG (depending on the nature of the function to
% be optimized) may speed up the minimization; it is probably not worth playing
% much with RHO.

% The code falls naturally into 3 parts, after the initial line search is
% started in the direction of steepest descent. 1) we first enter a while loop
% which uses point 1 (p1) and point 2 (p2) to compute an extrapolation (p3),
% until we have extrapolated far enough (Wolfe-Powell conditions). 2) if
% necessary, we enter the second loop which takes p2, p3 and p4, chooses the
% subinterval containing a (local) minimum, and interpolates it, until an
% acceptable point is found (Wolfe-Powell conditions). Note that points are
% always maintained in order p0 <= p1 <= p2 < p3 < p4. 3) compute a new search
% direction using conjugate gradients (Polak-Ribiere flavour), or revert to
% steepest descent if there was a problem in the previous line-search. Return
% the best value so far if two consecutive line-searches fail, or whenever we
% run out of function evaluations or line-searches. During extrapolation, the
% "f" function may fail either with an error or by returning NaN or Inf, and
% minimize_stuff should handle this gracefully.

if max(size(length)) == 2, red=length(2); length=length(1); else red=1; end
if length>0, S='Linesearch'; else S='Function evaluation'; end
i = 0; % zero the run length counter
ls_failed = 0; % no previous line search has failed
[f0 df0] = feval(f, X, varargin{:}); % get function value and gradient
Z = X; X = unwrap(X); df0 = unwrap(df0);
fprintf('%s %6i; Value %4.6e\r', S, i, f0);
if exist('fflush','builtin') fflush(stdout); end
fX = f0;
i = i + (length<0); % count epochs?!
s = -df0; d0 = -s'*s; % initial search direction (steepest) and slope
x3 = red/(1-d0); % initial step is red/(|s|^2+1), since d0 = -s'*s
while i < abs(length) % while not finished
  i = i + (length>0);  % count iterations?!

  X0 = X; F0 = f0; dF0 = df0;  % make a copy of current values
  if length>0, M = MAX; else M = min(MAX, -length-i); end

  while 1  % keep extrapolating as long as necessary
    x2 = 0; f2 = f0; d2 = d0; f3 = f0; df3 = df0;
    success = 0;
    while ~success && M > 0
      try
        M = M - 1; i = i + (length<0);  % count epochs?!
        [f3 df3] = feval(f, rewrap(Z,X+x3*s), varargin{:});
        df3 = unwrap(df3);
        if isnan(f3) || isinf(f3) || any(isnan(df3)+isinf(df3)), error(' '),end
        success = 1;
      catch  % catch any error which occurred in f
        x3 = (x2+x3)/2;  % bisect and try again
      end
    end
    if f3 < F0, X0 = X+x3*s; F0 = f3; dF0 = df3; end  % keep best values
    if isnan(df3)  % vector test: true only if every element of df3 is NaN
      flag=1;
      break
    end
    d3 = df3'*s;  % new slope
    if (d3 > SIG*d0) || (f3 > f0+x3*RHO*d0) || (M == 0)  % done extrapolating?
      break
    end
    x1 = x2; f1 = f2; d1 = d2;  % move point 2 to point 1
    x2 = x3; f2 = f3; d2 = d3;  % move point 3 to point 2
    A = 6*(f1-f2)+3*(d2+d1)*(x2-x1);  % make cubic extrapolation
    B = 3*(f2-f1)-(2*d1+d2)*(x2-x1);
    x3 = x1-d1*(x2-x1)^2/(B+sqrt(B*B-A*d1*(x2-x1)));  % num. error possible, ok!
    if ~isreal(x3) || isnan(x3) || isinf(x3) || x3 < 0  % num prob | wrong sign?
      x3 = x2*EXT;  % extrapolate maximum amount
    elseif x3 > x2*EXT  % new point beyond extrapolation limit?
      x3 = x2*EXT;  % extrapolate maximum amount
    elseif x3 < x2+INT*(x2-x1)  % new point too close to previous point?
      x3 = x2+INT*(x2-x1);
    end
  end  % end extrapolation
  if flag==1
    break
  end

  while (abs(d3) > -SIG*d0 || f3 > f0+x3*RHO*d0) && M > 0  % keep interpolating
    if d3 > 0 || f3 > f0+x3*RHO*d0  % choose subinterval
      x4 = x3; f4 = f3; d4 = d3;  % move point 3 to point 4
    else
      x2 = x3; f2 = f3; d2 = d3;  % move point 3 to point 2
    end
    if f4 > f0
      x3 = x2-(0.5*d2*(x4-x2)^2)/(f4-f2-d2*(x4-x2));  % quadratic interpolation
    else
      A = 6*(f2-f4)/(x4-x2)+3*(d4+d2);  % cubic interpolation
      B = 3*(f4-f2)-(2*d2+d4)*(x4-x2);
      x3 = x2+(sqrt(B*B-A*d2*(x4-x2)^2)-B)/A;  % num. error possible, ok!
    end
    if isnan(x3) || isinf(x3)
      x3 = (x2+x4)/2;  % if we had a numerical problem then bisect
    end
    x3 = max(min(x3, x4-INT*(x4-x2)),x2+INT*(x4-x2));  % don't accept too close
    [f3 df3] = feval(f, rewrap(Z,X+x3*s), varargin{:});
    if isnan(f3)
      flag=1;
      break
    end
    df3 = unwrap(df3);
    if f3 < F0, X0 = X+x3*s; F0 = f3; dF0 = df3; end  % keep best values
    M = M - 1; i = i + (length<0);  % count epochs?!
    d3 = df3'*s;  % new slope
  end  % end interpolation
  if flag==1
    break
  end

  if abs(d3) < -SIG*d0 && f3 < f0+x3*RHO*d0  % if line search succeeded
    X = X+x3*s; f0 = f3; fX = [fX' f0]';  % update variables
    fprintf('%s %6i; Value %4.6e\r', S, i, f0);
    if exist('fflush','builtin') fflush(stdout); end
    s = (df3'*df3-df0'*df3)/(df0'*df0)*s - df3;  % Polak-Ribiere CG direction
    df0 = df3;  % swap derivatives
    d3 = d0; d0 = df0'*s;
    if d0 > 0  % new slope must be negative
      s = -df0; d0 = -s'*s;  % otherwise use steepest direction
    end
    x3 = x3 * min(RATIO, d3/(d0-realmin));  % slope ratio but max RATIO
    ls_failed = 0;  % this line search did not fail
  else
    X = X0; f0 = F0; df0 = dF0;  % restore best point so far
    if ls_failed || i > abs(length)  % line search failed twice in a row
      break;  % or we ran out of time, so we give up
    end
    s = -df0; d0 = -s'*s;  % try steepest
    x3 = 1/(1-d0);
    ls_failed = 1;  % this line search failed
  end
end
if flag, warning('Optimisation halted due to hyperparams becoming extreme. Try different initialisations'); end
X = rewrap(Z,X);
fprintf('\n'); if exist('fflush','builtin') fflush(stdout); end

function v = unwrap(s)
% Extract the numerical values from "s" into the column vector "v". The
% variable "s" can be of any type, including struct and cell array.
% Non-numerical elements are ignored. See also the reverse rewrap.m.
v = [];
if isnumeric(s)
  v = s(:);  % numeric values are recast to column vector
elseif isstruct(s)
  v = unwrap(struct2cell(orderfields(s)));  % alphabetize, conv to cell, recurse
elseif iscell(s)
  for i = 1:numel(s)  % cell array elements are handled sequentially
    v = [v; unwrap(s{i})];
  end
end  % other types are ignored

function [s v] = rewrap(s, v)
% Map the numerical elements in the vector "v" onto the variables "s" which can
% be of any type. The number of numerical elements must match; on exit "v"
% should be empty. Non-numerical entries are just copied. See also unwrap.m.
if isnumeric(s)
  if numel(v) < numel(s)
    error('The vector for conversion contains too few elements')
  end
  s = reshape(v(1:numel(s)), size(s));  % numeric values are reshaped
  v = v(numel(s)+1:end);  % remaining arguments passed on
elseif isstruct(s)
  [s p] = orderfields(s); p(p) = 1:numel(p);  % alphabetize, store ordering
  [t v] = rewrap(struct2cell(s), v);  % convert to cell, recurse
  s = orderfields(cell2struct(t,fieldnames(s),1),p);  % conv to struct, reorder
elseif iscell(s)
  for i = 1:numel(s)  % cell array elements are handled sequentially
    [s{i} v] = rewrap(s{i}, v);
  end
end  % other types are not processed