diff --git a/docs/api.md b/docs/api.md
index 77d07db91..2d5b590a2 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -690,6 +690,8 @@ iterations | int | yes | N/A | Max number of solver'
 snapshot | int | yes | N/A | Iterations between model snapshots
 snapshot_prefix | string | yes | empty | Prefix to snapshot file, supports repository
 solver_type | string | yes | SGD | from "SGD", "ADAGRAD", "NESTEROV", "RMSPROP", "ADADELTA", "ADAM", "AMSGRAD", "RANGER", "RANGER_PLUS", "ADAMW", "SGDW", "AMSGRADW" (*W version for decoupled weight decay, RANGER_PLUS is ranger + adabelief + centralized_gradient)
+clip | bool | yes | false (true if RANGER* selected) | whether to clip gradients; implemented only for RANGER
+clip_value | real | yes | 5.0 | value to clip gradients to (used only by RANGER)
 rectified | bool | yes | false | rectified momentum variance ie https://arxiv.org/abs/1908.03265 valid for ADAM[W] and AMSGRAD[W]
 adabelief | bool | yes | false | adabelief mod for ADAM https://arxiv.org/abs/2010.07468
 gradient_centralization | bool | yes | false | centralized gradient mod for ADAM ie https://arxiv.org/abs/2004.01461v2
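For illustration (commentary, not part of the patch): a minimal sketch of how these two parameters would be passed in a training call, assuming the usual `parameters.mllib.solver` layout of the DeepDetect train API; the values shown are placeholders.

```json
{
  "parameters": {
    "mllib": {
      "solver": {
        "solver_type": "RANGER",
        "iterations": 10000,
        "clip": true,
        "clip_value": 5.0
      }
    }
  }
}
```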
diff --git a/src/backends/torch/optim/ranger.cc b/src/backends/torch/optim/ranger.cc
index 5aaaa60a8..bd43fc2d4 100644
--- a/src/backends/torch/optim/ranger.cc
+++ b/src/backends/torch/optim/ranger.cc
@@ -47,7 +47,10 @@ namespace dd
            && (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas()))
            && (lhs.eps() == rhs.eps())
            && (lhs.weight_decay() == rhs.weight_decay())
+           && (lhs.clip() == rhs.clip())
+           && (lhs.clip_value() == rhs.clip_value())
            && (lhs.rectified() == rhs.rectified())
+           && (lhs.decoupled_wd() == rhs.decoupled_wd())
            && (lhs.lookahead() == rhs.lookahead())
            && (lhs.adabelief() == rhs.adabelief())
            && (lhs.gradient_centralization() == rhs.gradient_centralization())
@@ -60,6 +63,8 @@ namespace dd
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
+    _TORCH_OPTIM_SERIALIZE_TORCH_ARG(clip);
+    _TORCH_OPTIM_SERIALIZE_TORCH_ARG(clip_value);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(decoupled_wd);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(rectified);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lookahead);
@@ -75,6 +80,8 @@ namespace dd
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
+    _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, clip);
+    _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, clip_value);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, decoupled_wd);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, rectified);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, lookahead);
@@ -126,8 +133,9 @@ namespace dd
               continue;
             }
           auto grad = p.grad();
-          TORCH_CHECK(
-              !grad.is_sparse(), "Ranger does not support sparse gradients" /*, please consider SparseRanger instead*/);
+
+          TORCH_CHECK(!grad.is_sparse(),
+                      "Ranger does not support sparse gradients");
           auto param_state
               = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
           auto &options = static_cast<RangerOptions &>(group.options());
@@ -174,6 +182,9 @@ namespace dd
               grad.add_(-grad.mean(torch::IntArrayRef(dim), true));
             }
 
+          if (options.clip())
+            grad.clamp_(-options.clip_value(), options.clip_value());
+
           exp_avg.mul_(beta1).add_(grad, 1 - beta1); // m_t
 
           if (options.adabelief())
diff --git a/src/backends/torch/optim/ranger.h b/src/backends/torch/optim/ranger.h
index 40337caf9..08f005d8d 100644
--- a/src/backends/torch/optim/ranger.h
+++ b/src/backends/torch/optim/ranger.h
@@ -50,6 +50,8 @@ namespace dd
     TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999);
     TORCH_ARG(double, eps) = 1e-8;
     TORCH_ARG(double, weight_decay) = 0.0;
+    TORCH_ARG(bool, clip) = false;
+    TORCH_ARG(double, clip_value) = 5.0;
     TORCH_ARG(bool, decoupled_wd) = false;
     TORCH_ARG(bool, rectified) = true;
     TORCH_ARG(bool, lookahead) = true;
@@ -103,6 +105,8 @@ namespace dd
                   "Invalid beta parameter at index 1: ", std::get<1>(betas));
       TORCH_CHECK(defaults.weight_decay() >= 0,
                   "Invalid weight_decay value: ", defaults.weight_decay());
+      TORCH_CHECK(!defaults.clip() || defaults.clip_value() >= 0,
+                  "Invalid clip value: ", defaults.clip_value());
       TORCH_CHECK(defaults.lsteps() >= 0,
                   "Invalid lookahead steps: ", defaults.lsteps());
       TORCH_CHECK(defaults.lalpha() >= 0,
diff --git a/src/backends/torch/torchsolver.cc b/src/backends/torch/torchsolver.cc
index 1e60e49a7..43d2a290c 100644
--- a/src/backends/torch/torchsolver.cc
+++ b/src/backends/torch/torchsolver.cc
@@ -29,6 +29,9 @@ namespace dd
     if (ad_solver.has("solver_type"))
       _solver_type = ad_solver.get("solver_type").get<std::string>();
 
+    if (_solver_type == "RANGER" || _solver_type == "RANGER_PLUS")
+      _clip = true;
+
     if (_solver_type == "RANGER_PLUS")
       {
         _adabelief = true;
@@ -41,6 +44,10 @@ namespace dd
       _beta1 = ad_solver.get("beta1").get<double>();
     if (ad_solver.has("beta2"))
       _beta2 = ad_solver.get("beta2").get<double>();
+    if (ad_solver.has("clip"))
+      _clip = ad_solver.get("clip").get<bool>();
+    if (ad_solver.has("clip_value"))
+      _clip_value = ad_solver.get("clip_value").get<double>();
     if (ad_solver.has("rectified"))
       _rectified = ad_solver.get("rectified").get<bool>();
     if (ad_solver.has("lookahead"))
@@ -102,10 +109,15 @@ namespace dd
                            .adabelief(_adabelief)
                            .gradient_centralization(_gc)
                            .lsteps(_lsteps)
-                           .lalpha(_lalpha)));
+                           .lalpha(_lalpha)
+                           .clip(_clip)
+                           .clip_value(_clip_value)));
     this->_logger->info("base_lr: {}", _base_lr);
     this->_logger->info("beta_1: {}", _beta1);
     this->_logger->info("beta_2: {}", _beta2);
+    this->_logger->info("clip: {}", _clip);
+    if (_clip)
+      this->_logger->info("clip_value: {}", _clip_value);
     this->_logger->info("weight_decay: {}", _weight_decay);
     this->_logger->info("rectified: {}", _rectified);
     this->_logger->info("lookahead: {}", _lookahead);
diff --git a/src/backends/torch/torchsolver.h b/src/backends/torch/torchsolver.h
index fcf357629..a0e9ea72a 100644
--- a/src/backends/torch/torchsolver.h
+++ b/src/backends/torch/torchsolver.h
@@ -105,6 +105,8 @@ namespace dd
     int _lsteps = 5;      /**< for RANGER, if lookahead: number of lookahead
                              steps */
     double _lalpha = 0.5; /**< for RANGER, if lookahead: weight of lookahead */
+    bool _clip = false;   /**< for RANGER: whether to clip gradients */
+    double _clip_value = 5.0; /**< for RANGER: value to clip gradients to */
     double _weight_decay = 0.0; /**< weight decay value*/
     bool _decoupled_wd = false; /**< for RANGER : use decoupled weight decay,
                                    NOT YET IMPLEMENTED */
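Aside (commentary, not part of the patch): the clipping performed in `Ranger::step()` above is elementwise value clipping, not norm clipping; every gradient component is clamped into `[-clip_value, clip_value]` before the first- and second-moment updates. A self-contained libtorch sketch of just that operation, with an illustrative value standing in for `RangerOptions::clip_value()`:

```cpp
#include <torch/torch.h>
#include <iostream>

int main()
{
  // Illustrative value; in the patch this comes from options.clip_value().
  const double clip_value = 5.0;

  // A stand-in gradient tensor with some out-of-range components.
  torch::Tensor grad = torch::tensor({-12.0, -3.0, 0.5, 7.0});

  // In-place elementwise clamp into [-clip_value, clip_value], mirroring
  // grad.clamp_(-options.clip_value(), options.clip_value()) in step().
  grad.clamp_(-clip_value, clip_value);

  std::cout << grad << std::endl; // -5, -3, 0.5, 5
  return 0;
}
```

Because the clamp happens after gradient centralization but before the moment accumulators are updated, the clipped values are what feed `exp_avg` and `exp_avg_sq`.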