From e05077753c9e523ab526003dde9ac4523d4b5ece Mon Sep 17 00:00:00 2001 From: Lht97 Date: Thu, 21 Nov 2024 15:58:46 +0800 Subject: [PATCH] Introduce alpha_threshold into the framework --- src/bds.m | 17 ++++++++++++++++- src/private/get_default_constant.m | 2 ++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/bds.m b/src/bds.m index 7d10a306..48cf0d25 100644 --- a/src/bds.m +++ b/src/bds.m @@ -32,6 +32,10 @@ % Default: 2. % shrink Shrinking factor of step size. A positive number less than 1. % Default: 0.5. +% alpha_threshold The threshold of the step size. If the step size is smaller than +% alpha_threshold, then the step size will be not allowed to shrink below +% alpha_threshold. It should be strictly less than StepTolerance. +% A positive number. Default: eps. % forcing_function The forcing function used for deciding whether the step achieves % a sufficient decrease. A function handle. % Default: @(alpha) alpha^2. See also reduction_factor. @@ -269,6 +273,14 @@ alpha_tol = get_default_constant("StepTolerance"); end +% Set the value of alpha_threshold. If the step size is smaller than alpha_threshold, then the step size +% will be not allowed to shrink below alpha_threshold. +if isfield(options, "alpha_threshold") + alpha_threshold = options.alpha_threshold; +else + alpha_threshold = get_default_constant("alpha_threshold"); +end + % Set the target of the objective function. if isfield(options, "ftarget") ftarget = options.ftarget; @@ -493,7 +505,10 @@ if sub_fopt + reduction_factor(3) * forcing_function(alpha_all(i_real)) < fbase alpha_all(i_real) = expand * alpha_all(i_real); elseif sub_fopt + reduction_factor(2) * forcing_function(alpha_all(i_real)) >= fbase - alpha_all(i_real) = shrink * alpha_all(i_real); + alpha_all(i_real) = max(shrink * alpha_all(i_real), alpha_threshold); + if shrink * alpha_all(i_real) < alpha_threshold + fprintf("The step size of the block %d is smaller than alpha_threshold.\n", i_real); + end end % Record the best function value and point encountered in the i_real-th block. diff --git a/src/private/get_default_constant.m b/src/private/get_default_constant.m index e4665adf..6abc31a9 100644 --- a/src/private/get_default_constant.m +++ b/src/private/get_default_constant.m @@ -17,6 +17,8 @@ constant_value = @(alpha) alpha^2; case {"alpha_init"} constant_value = 1; + case {"alpha_threshold"} + constant_value = eps; case {"StepTolerance"} constant_value = 1e-6; case {"permuting_period"}