From df034486e46e3b8347d75bca035f4089ad7f74d2 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 2 Apr 2020 03:27:47 -0700 Subject: [PATCH 1/8] Robust regularization of AFT gradient and hessian --- src/common/survival_util.cc | 150 +++++++++++++++++++++---- src/common/survival_util.h | 9 +- tests/cpp/common/test_survival_util.cc | 38 +++++++ tests/cpp/objective/test_aft_obj.cc | 15 ++- 4 files changed, 181 insertions(+), 31 deletions(-) create mode 100644 tests/cpp/common/test_survival_util.cc diff --git a/src/common/survival_util.cc b/src/common/survival_util.cc index 58c5a7946af7..be4ca98c5164 100644 --- a/src/common/survival_util.cc +++ b/src/common/survival_util.cc @@ -18,6 +18,109 @@ https://github.com/avinashbarnwal/GSOC-2019/blob/master/doc/Accelerated_Failure_Time.pdf */ +namespace { + +// Allowable range for gradient and hessian. Used for regularization +constexpr double kMinGradient = -15.0; +constexpr double kMaxGradient = 15.0; +constexpr double kMinHessian = 1e-16; // Ensure that no data point gets zero hessian +constexpr double kMaxHessian = 15.0; + +constexpr double kEps = 1e-12; // A denomitor in a fraction should not be too small + +// Clip (limit) x to fit range [x_min, x_max]. +// If x < x_min, return x_min; if x > x_max, return x_max; if x_min <= x <= x_max, return x. +// This function assumes x_min < x_max; behavior is undefined if this assumption does not hold. +inline double Clip(double x, double x_min, double x_max) { + if (x < x_min) { + return x_min; + } + if (x > x_max) { + return x_max; + } + return x; +} + +using xgboost::common::ProbabilityDistributionType; + +enum class CensoringType : uint8_t { + kUncensored, kRightCensored, kLeftCensored, kIntervalCensored +}; + +struct GradHessPair { + double gradient; + double hessian; +}; + +inline GradHessPair GetLimitAtInfPred(ProbabilityDistributionType dist_type, + CensoringType censor_type, + double sign, double sigma) { + switch (censor_type) { + case CensoringType::kUncensored: + switch (dist_type) { + case ProbabilityDistributionType::kNormal: + return sign ? GradHessPair{ kMinGradient, 1.0 / (sigma * sigma) } + : GradHessPair{ kMaxGradient, 1.0 / (sigma * sigma) }; + case ProbabilityDistributionType::kLogistic: + return sign ? GradHessPair{ -1.0 / sigma, kMinHessian } + : GradHessPair{ 1.0 / sigma, kMinHessian }; + case ProbabilityDistributionType::kExtreme: + return sign ? GradHessPair{ kMinGradient, kMaxHessian } + : GradHessPair{ 1.0 / sigma, kMinHessian }; + default: + LOG(FATAL) << "Unknown distribution type"; + } + case CensoringType::kRightCensored: + switch (dist_type) { + case ProbabilityDistributionType::kNormal: + return sign ? GradHessPair{ kMinGradient, 1.0 / (sigma * sigma) } + : GradHessPair{ 0.0, kMinHessian }; + case ProbabilityDistributionType::kLogistic: + return sign ? GradHessPair{ -1.0 / sigma, kMinHessian } + : GradHessPair{ 0.0, kMinHessian }; + case ProbabilityDistributionType::kExtreme: + return sign ? GradHessPair{ kMinGradient, kMaxHessian } + : GradHessPair{ 0.0, kMinHessian }; + default: + LOG(FATAL) << "Unknown distribution type"; + } + case CensoringType::kLeftCensored: + switch (dist_type) { + case ProbabilityDistributionType::kNormal: + return sign ? GradHessPair{ 0.0, kMinHessian } + : GradHessPair{ kMaxGradient, 1.0 / (sigma * sigma) }; + case ProbabilityDistributionType::kLogistic: + return sign ? GradHessPair{ 0.0, kMinHessian } + : GradHessPair{ 1.0 / sigma, kMinHessian }; + case ProbabilityDistributionType::kExtreme: + return sign ? GradHessPair{ 0.0, kMinHessian } + : GradHessPair{ 1.0 / sigma, kMinHessian }; + default: + LOG(FATAL) << "Unknown distribution type"; + } + case CensoringType::kIntervalCensored: + switch (dist_type) { + case ProbabilityDistributionType::kNormal: + return sign ? GradHessPair{ kMinGradient, 1.0 / (sigma * sigma) } + : GradHessPair{ kMaxGradient, 1.0 / (sigma * sigma) }; + case ProbabilityDistributionType::kLogistic: + return sign ? GradHessPair{ -1.0 / sigma, kMinHessian } + : GradHessPair{ 1.0 / sigma, kMinHessian }; + case ProbabilityDistributionType::kExtreme: + return sign ? GradHessPair{ kMinGradient, kMaxHessian } + : GradHessPair{ 1.0 / sigma, kMinHessian }; + default: + LOG(FATAL) << "Unknown distribution type"; + } + default: + LOG(FATAL) << "Unknown censoring type"; + } + + return { 0.0, 0.0 }; +} + +} // anonymous namespace + namespace xgboost { namespace common { @@ -26,14 +129,14 @@ DMLC_REGISTER_PARAMETER(AFTParam); double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma) { const double log_y_lower = std::log(y_lower); const double log_y_upper = std::log(y_upper); - const double eps = 1e-12; + double cost; if (y_lower == y_upper) { // uncensored const double z = (log_y_lower - y_pred) / sigma; const double pdf = dist_->PDF(z); // Regularize the denominator with eps, to avoid INF or NAN - cost = -std::log(std::max(pdf / (sigma * y_lower), eps)); + cost = -std::log(std::max(pdf / (sigma * y_lower), kEps)); } else { // censored; now check what type of censorship we have double z_u, z_l, cdf_u, cdf_l; if (std::isinf(y_upper)) { // right-censored @@ -49,7 +152,7 @@ double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma cdf_l = dist_->CDF(z_l); } // Regularize the denominator with eps, to avoid INF or NAN - cost = -std::log(std::max(cdf_u - cdf_l, eps)); + cost = -std::log(std::max(cdf_u - cdf_l, kEps)); } return cost; @@ -59,19 +162,19 @@ double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double s const double log_y_lower = std::log(y_lower); const double log_y_upper = std::log(y_upper); double gradient; - const double eps = 1e-12; if (y_lower == y_upper) { // uncensored const double z = (log_y_lower - y_pred) / sigma; const double pdf = dist_->PDF(z); const double grad_pdf = dist_->GradPDF(z); - // Regularize the denominator with eps, so that gradient doesn't get too big - gradient = grad_pdf / (sigma * std::max(pdf, eps)); + gradient = grad_pdf / (sigma * pdf); } else { // censored; now check what type of censorship we have - double z_u, z_l, pdf_u, pdf_l, cdf_u, cdf_l; + double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l; + CensoringType censor_type = CensoringType::kIntervalCensored; if (std::isinf(y_upper)) { // right-censored pdf_u = 0; cdf_u = 1; + censor_type = CensoringType::kRightCensored; } else { // interval-censored or left-censored z_u = (log_y_upper - y_pred) / sigma; pdf_u = dist_->PDF(z_u); @@ -80,22 +183,27 @@ double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double s if (std::isinf(y_lower)) { // left-censored pdf_l = 0; cdf_l = 0; + censor_type = CensoringType::kLeftCensored; } else { // interval-censored or right-censored z_l = (log_y_lower - y_pred) / sigma; pdf_l = dist_->PDF(z_l); cdf_l = dist_->CDF(z_l); } - // Regularize the denominator with eps, so that gradient doesn't get too big - gradient = (pdf_u - pdf_l) / (sigma * std::max(cdf_u - cdf_l, eps)); + + const double numerator = pdf_u - pdf_l; + const double denominator = sigma * (cdf_u - cdf_l); + gradient = numerator / denominator; + if (denominator < kEps && (std::isnan(gradient) || std::isinf(gradient))) { + gradient = GetLimitAtInfPred(dist_type_, censor_type, (z_u > 0 || z_l > 0), sigma).gradient; + } } - return gradient; + return Clip(gradient, kMinGradient, kMaxGradient); } double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) { const double log_y_lower = std::log(y_lower); const double log_y_upper = std::log(y_upper); - const double eps = 1e-12; double hessian; if (y_lower == y_upper) { // uncensored @@ -103,15 +211,16 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si const double pdf = dist_->PDF(z); const double grad_pdf = dist_->GradPDF(z); const double hess_pdf = dist_->HessPDF(z); - // Regularize the denominator with eps, so that gradient doesn't get too big - hessian = -(pdf * hess_pdf - std::pow(grad_pdf, 2)) - / (std::pow(sigma, 2) * std::pow(std::max(pdf, eps), 2)); + hessian = -(pdf * hess_pdf - grad_pdf * grad_pdf) + / (sigma * sigma * pdf * pdf); } else { // censored; now check what type of censorship we have - double z_u, z_l, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l; + double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l; + CensoringType censor_type = CensoringType::kIntervalCensored; if (std::isinf(y_upper)) { // right-censored pdf_u = 0; cdf_u = 1; grad_pdf_u = 0; + censor_type = CensoringType::kRightCensored; } else { // interval-censored or left-censored z_u = (log_y_upper - y_pred) / sigma; pdf_u = dist_->PDF(z_u); @@ -122,6 +231,7 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si pdf_l = 0; cdf_l = 0; grad_pdf_l = 0; + censor_type = CensoringType::kLeftCensored; } else { // interval-censored or right-censored z_l = (log_y_lower - y_pred) / sigma; pdf_l = dist_->PDF(z_l); @@ -131,15 +241,17 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si const double cdf_diff = cdf_u - cdf_l; const double pdf_diff = pdf_u - pdf_l; const double grad_diff = grad_pdf_u - grad_pdf_l; - // Regularize the denominator with eps, so that gradient doesn't get too big - const double cdf_diff_thresh = std::max(cdf_diff, eps); const double numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff); - const double sqrt_denominator = sigma * cdf_diff_thresh; + const double sqrt_denominator = sigma * cdf_diff; const double denominator = sqrt_denominator * sqrt_denominator; + hessian = numerator / denominator; + if (denominator < kEps && (std::isnan(hessian) || std::isinf(hessian))) { + hessian = GetLimitAtInfPred(dist_type_, censor_type, (z_u > 0 || z_l > 0), sigma).hessian; + } } - return hessian; + return Clip(hessian, kMinHessian, kMaxHessian); } } // namespace common diff --git a/src/common/survival_util.h b/src/common/survival_util.h index baae99b34e00..feb582db5e2d 100644 --- a/src/common/survival_util.h +++ b/src/common/survival_util.h @@ -42,15 +42,16 @@ struct AFTParam : public XGBoostParameter { class AFTLoss { private: std::unique_ptr dist_; + ProbabilityDistributionType dist_type_; public: /*! * \brief Constructor for AFT loss function - * \param dist Choice of probability distribution for the noise term in AFT + * \param dist_type Choice of probability distribution for the noise term in AFT */ - explicit AFTLoss(ProbabilityDistributionType dist) { - dist_.reset(ProbabilityDistribution::Create(dist)); - } + explicit AFTLoss(ProbabilityDistributionType dist_type) + : dist_(ProbabilityDistribution::Create(dist_type)), + dist_type_(dist_type) {} public: /*! diff --git a/tests/cpp/common/test_survival_util.cc b/tests/cpp/common/test_survival_util.cc new file mode 100644 index 000000000000..d19d93c9a5f8 --- /dev/null +++ b/tests/cpp/common/test_survival_util.cc @@ -0,0 +1,38 @@ +/*! + * Copyright (c) by Contributors 2020 + */ +#include + +#include "../../../src/common/survival_util.h" + +namespace xgboost { +namespace common { + +TEST(AFTLoss, RobustGradientPair) { // Ensure that INF and NAN don't show up + const double y_lower = 16.0; + const double y_upper = 200.0; + const double sigma = 2.0; + + for (auto dist_type : { ProbabilityDistributionType::kNormal, + ProbabilityDistributionType::kLogistic, + ProbabilityDistributionType::kExtreme }) { + AFTLoss loss(dist_type); + for (int i = 50; i >= -50; --i) { + const double y_pred = std::pow(10.0, static_cast(i)); + const double z = (std::log(y_lower) - std::log(y_pred)) / sigma; + const double gradient = loss.Gradient(y_lower, y_upper, std::log(y_pred), sigma); + const double hessian = loss.Hessian(y_lower, y_upper, std::log(y_pred), sigma); + ASSERT_FALSE(std::isnan(gradient)) << "z = " << z << ", y_pred = " << y_pred + << ", dist = " << static_cast(dist_type); + ASSERT_FALSE(std::isinf(gradient)) << "z = " << z << ", y_pred = " << y_pred + << ", dist = " << static_cast(dist_type); + ASSERT_FALSE(std::isnan(hessian)) << "z = " << z << ", y_pred = " << y_pred + << ", dist = " << static_cast(dist_type); + ASSERT_FALSE(std::isinf(hessian)) << "z = " << z << ", y_pred = " << y_pred + << ", dist = " << static_cast(dist_type); + } + } +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/objective/test_aft_obj.cc b/tests/cpp/objective/test_aft_obj.cc index 01e965df8d15..36f80756a2c1 100644 --- a/tests/cpp/objective/test_aft_obj.cc +++ b/tests/cpp/objective/test_aft_obj.cc @@ -93,10 +93,10 @@ TEST(Objective, AFTObjGPairUncensoredLabels) { { 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f, 0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f }); CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme", - { -0.0000f, -29.0026f, -17.0031f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f, + { -15.0000f, -15.0000f, -15.0000f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f, 0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f, 0.9969f }, - { 0.0000f, 30.0026f, 18.0031f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f, + { 15.0000f, 15.0000f, 15.0000f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f, 0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f }); } @@ -106,10 +106,9 @@ TEST(Objective, AFTObjGPairLeftCensoredLabels) { CheckGPairOverGridPoints(obj.get(), -std::numeric_limits::infinity(), 20.0f, "normal", { 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f, - 3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 0.5072f }, + 3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 7.5326f }, { 0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f, - 0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9812f, 0.0045f }, - 2e-4); + 0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9813f, 0.9877f }); CheckGPairOverGridPoints(obj.get(), -std::numeric_limits::infinity(), 20.0f, "logistic", { 0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f, 0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f }, @@ -139,10 +138,10 @@ TEST(Objective, AFTObjGPairRightCensoredLabels) { { 0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f, 0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f }); CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits::infinity(), "extreme", - { -2.8073f, -18.0015f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f, + { -15.0000f, -15.0000f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f, -0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, -0.0235f, -0.0141f, -0.0085f, -0.0051f, - -0.0031f, -0.0018f }, - { 0.2614f, 18.0015f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f, + -0.0031f, -0.0018f }, + { 15.0000f, 15.0000f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f, 0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f }); } From 110498a6e0f0de4d0882860a4e35c005a17549ae Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 2 Apr 2020 03:37:46 -0700 Subject: [PATCH 2/8] Fix AFT doc; expose it to tutorial TOC --- demo/aft_survival/aft_survival_demo.py | 2 +- demo/aft_survival/aft_survival_demo_with_optuna.py | 2 +- doc/tutorials/aft_survival_analysis.rst | 2 +- doc/tutorials/index.rst | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/demo/aft_survival/aft_survival_demo.py b/demo/aft_survival/aft_survival_demo.py index 6b8181cf1060..3cdccc1c21af 100644 --- a/demo/aft_survival/aft_survival_demo.py +++ b/demo/aft_survival/aft_survival_demo.py @@ -51,4 +51,4 @@ print(df[np.isinf(df['Label (upper bound)'])]) # Save trained model -bst.save_model('aft_model.json') \ No newline at end of file +bst.save_model('aft_model.json') diff --git a/demo/aft_survival/aft_survival_demo_with_optuna.py b/demo/aft_survival/aft_survival_demo_with_optuna.py index 998afc4816b6..117be8ba1be2 100644 --- a/demo/aft_survival/aft_survival_demo_with_optuna.py +++ b/demo/aft_survival/aft_survival_demo_with_optuna.py @@ -75,4 +75,4 @@ def objective(trial): print(df[np.isinf(df['Label (upper bound)'])]) # Save trained model -bst.save_model('aft_best_model.json') \ No newline at end of file +bst.save_model('aft_best_model.json') diff --git a/doc/tutorials/aft_survival_analysis.rst b/doc/tutorials/aft_survival_analysis.rst index 4f06ce54c331..237a5392f4ac 100644 --- a/doc/tutorials/aft_survival_analysis.rst +++ b/doc/tutorials/aft_survival_analysis.rst @@ -68,7 +68,7 @@ Note that this model is a generalized form of a linear regression model :math:`Y \ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z -where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathbf{x}`. +where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathcal{T}(\mathbf{x})`. ********** How to use diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst index 0334385a605c..bcc34284a4d9 100644 --- a/doc/tutorials/index.rst +++ b/doc/tutorials/index.rst @@ -18,6 +18,7 @@ See `Awesome XGBoost `_ for mo monotonic rf feature_interaction_constraint + aft_survival_analysis input_format param_tuning external_memory From 7178f7233f7f2bf08851998a867e0808e3997fd7 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 3 Apr 2020 02:05:55 -0700 Subject: [PATCH 3/8] Apply robust regularization to uncensored case too --- src/common/survival_util.cc | 51 +++++++++++++++----------- tests/cpp/common/test_survival_util.cc | 50 +++++++++++++++---------- 2 files changed, 60 insertions(+), 41 deletions(-) diff --git a/src/common/survival_util.cc b/src/common/survival_util.cc index be4ca98c5164..278f2b5af98e 100644 --- a/src/common/survival_util.cc +++ b/src/common/survival_util.cc @@ -161,16 +161,21 @@ double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double sigma) { const double log_y_lower = std::log(y_lower); const double log_y_upper = std::log(y_upper); - double gradient; + double numerator, denominator, gradient; // numerator and denominator of gradient + CensoringType censor_type; + bool z_sign; // sign of z-score if (y_lower == y_upper) { // uncensored const double z = (log_y_lower - y_pred) / sigma; const double pdf = dist_->PDF(z); const double grad_pdf = dist_->GradPDF(z); - gradient = grad_pdf / (sigma * pdf); + censor_type = CensoringType::kUncensored; + numerator = grad_pdf; + denominator = sigma * pdf; + z_sign = (z > 0); } else { // censored; now check what type of censorship we have double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l; - CensoringType censor_type = CensoringType::kIntervalCensored; + censor_type = CensoringType::kIntervalCensored; if (std::isinf(y_upper)) { // right-censored pdf_u = 0; cdf_u = 1; @@ -189,13 +194,13 @@ double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double s pdf_l = dist_->PDF(z_l); cdf_l = dist_->CDF(z_l); } - - const double numerator = pdf_u - pdf_l; - const double denominator = sigma * (cdf_u - cdf_l); - gradient = numerator / denominator; - if (denominator < kEps && (std::isnan(gradient) || std::isinf(gradient))) { - gradient = GetLimitAtInfPred(dist_type_, censor_type, (z_u > 0 || z_l > 0), sigma).gradient; - } + z_sign = (z_u > 0 || z_l > 0); + numerator = pdf_u - pdf_l; + denominator = sigma * (cdf_u - cdf_l); + } + gradient = numerator / denominator; + if (denominator < kEps && (std::isnan(gradient) || std::isinf(gradient))) { + gradient = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).gradient; } return Clip(gradient, kMinGradient, kMaxGradient); @@ -204,18 +209,22 @@ double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double s double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) { const double log_y_lower = std::log(y_lower); const double log_y_upper = std::log(y_upper); - double hessian; + double numerator, denominator, hessian; // numerator and denominator of hessian + CensoringType censor_type; + bool z_sign; // sign of z-score if (y_lower == y_upper) { // uncensored const double z = (log_y_lower - y_pred) / sigma; const double pdf = dist_->PDF(z); const double grad_pdf = dist_->GradPDF(z); const double hess_pdf = dist_->HessPDF(z); - hessian = -(pdf * hess_pdf - grad_pdf * grad_pdf) - / (sigma * sigma * pdf * pdf); + censor_type = CensoringType::kUncensored; + numerator = -(pdf * hess_pdf - grad_pdf * grad_pdf); + denominator = sigma * sigma * pdf * pdf; + z_sign = (z > 0); } else { // censored; now check what type of censorship we have double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l; - CensoringType censor_type = CensoringType::kIntervalCensored; + censor_type = CensoringType::kIntervalCensored; if (std::isinf(y_upper)) { // right-censored pdf_u = 0; cdf_u = 1; @@ -241,14 +250,14 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si const double cdf_diff = cdf_u - cdf_l; const double pdf_diff = pdf_u - pdf_l; const double grad_diff = grad_pdf_u - grad_pdf_l; - const double numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff); const double sqrt_denominator = sigma * cdf_diff; - const double denominator = sqrt_denominator * sqrt_denominator; - - hessian = numerator / denominator; - if (denominator < kEps && (std::isnan(hessian) || std::isinf(hessian))) { - hessian = GetLimitAtInfPred(dist_type_, censor_type, (z_u > 0 || z_l > 0), sigma).hessian; - } + z_sign = (z_u > 0 || z_l > 0); + numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff); + denominator = sqrt_denominator * sqrt_denominator; + } + hessian = numerator / denominator; + if (denominator < kEps && (std::isnan(hessian) || std::isinf(hessian))) { + hessian = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).hessian; } return Clip(hessian, kMinHessian, kMaxHessian); diff --git a/tests/cpp/common/test_survival_util.cc b/tests/cpp/common/test_survival_util.cc index d19d93c9a5f8..29318ce8a429 100644 --- a/tests/cpp/common/test_survival_util.cc +++ b/tests/cpp/common/test_survival_util.cc @@ -8,30 +8,40 @@ namespace xgboost { namespace common { -TEST(AFTLoss, RobustGradientPair) { // Ensure that INF and NAN don't show up +inline static void RobustTestSuite(ProbabilityDistributionType dist_type, + double y_lower, double y_upper, double sigma) { + AFTLoss loss(dist_type); + for (int i = 50; i >= -50; --i) { + const double y_pred = std::pow(10.0, static_cast(i)); + const double z = (std::log(y_lower) - std::log(y_pred)) / sigma; + const double gradient = loss.Gradient(y_lower, y_upper, std::log(y_pred), sigma); + const double hessian = loss.Hessian(y_lower, y_upper, std::log(y_pred), sigma); + ASSERT_FALSE(std::isnan(gradient)) << "z = " << z << ", y \\in [" + << y_lower << ", " << y_upper << "], y_pred = " << y_pred + << ", dist = " << static_cast(dist_type); + ASSERT_FALSE(std::isinf(gradient)) << "z = " << z << ", y \\in [" + << y_lower << ", " << y_upper << "], y_pred = " << y_pred + << ", dist = " << static_cast(dist_type); + ASSERT_FALSE(std::isnan(hessian)) << "z = " << z << ", y \\in [" + << y_lower << ", " << y_upper << "], y_pred = " << y_pred + << ", dist = " << static_cast(dist_type); + ASSERT_FALSE(std::isinf(hessian)) << "z = " << z << ", y \\in [" + << y_lower << ", " << y_upper << "], y_pred = " << y_pred + << ", dist = " << static_cast(dist_type); + } +} + +TEST(AFTLoss, RobustGradientPair) { // Ensure that INF and NAN don't show up in gradient pair const double y_lower = 16.0; const double y_upper = 200.0; const double sigma = 2.0; - for (auto dist_type : { ProbabilityDistributionType::kNormal, - ProbabilityDistributionType::kLogistic, - ProbabilityDistributionType::kExtreme }) { - AFTLoss loss(dist_type); - for (int i = 50; i >= -50; --i) { - const double y_pred = std::pow(10.0, static_cast(i)); - const double z = (std::log(y_lower) - std::log(y_pred)) / sigma; - const double gradient = loss.Gradient(y_lower, y_upper, std::log(y_pred), sigma); - const double hessian = loss.Hessian(y_lower, y_upper, std::log(y_pred), sigma); - ASSERT_FALSE(std::isnan(gradient)) << "z = " << z << ", y_pred = " << y_pred - << ", dist = " << static_cast(dist_type); - ASSERT_FALSE(std::isinf(gradient)) << "z = " << z << ", y_pred = " << y_pred - << ", dist = " << static_cast(dist_type); - ASSERT_FALSE(std::isnan(hessian)) << "z = " << z << ", y_pred = " << y_pred - << ", dist = " << static_cast(dist_type); - ASSERT_FALSE(std::isinf(hessian)) << "z = " << z << ", y_pred = " << y_pred - << ", dist = " << static_cast(dist_type); - } - } + RobustTestSuite(ProbabilityDistributionType::kNormal, 16.0, 200.0, 2.0); + RobustTestSuite(ProbabilityDistributionType::kLogistic, 16.0, 200.0, 2.0); + RobustTestSuite(ProbabilityDistributionType::kExtreme, 16.0, 200.0, 2.0); + RobustTestSuite(ProbabilityDistributionType::kNormal, 100.0, 100.0, 2.0); + RobustTestSuite(ProbabilityDistributionType::kLogistic, 100.0, 100.0, 2.0); + RobustTestSuite(ProbabilityDistributionType::kExtreme, 100.0, 100.0, 2.0); } } // namespace common From 12da24f9b0b8d946ae7af68fdf559d8ecd3ee894 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 3 Apr 2020 02:06:16 -0700 Subject: [PATCH 4/8] Revise unit test slightly --- tests/python/test_survival.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/python/test_survival.py b/tests/python/test_survival.py index 12c79ed4bb33..b50cfe50f15c 100644 --- a/tests/python/test_survival.py +++ b/tests/python/test_survival.py @@ -85,6 +85,7 @@ def test_aft_survival_demo_data(): # AFT metric (negative log likelihood) improve monotonically assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][:1])) # For this data, normal distribution works the best - assert nloglik_rec['normal'][-1] < 5.0 - assert nloglik_rec['logistic'][-1] > 5.0 - assert nloglik_rec['extreme'][-1] > 5.0 + print (nloglik_rec['normal'][-1], nloglik_rec['logistic'][-1], nloglik_rec['extreme'][-1]) + assert nloglik_rec['normal'][-1] < 4.9 + assert nloglik_rec['logistic'][-1] > 4.9 + assert nloglik_rec['extreme'][-1] > 4.9 From e356dde8d6e5af6ce53e91366e747e5b80c3eeb7 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 3 Apr 2020 02:21:48 -0700 Subject: [PATCH 5/8] Fix lint --- src/common/survival_util.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common/survival_util.h b/src/common/survival_util.h index feb582db5e2d..50c6aab5109b 100644 --- a/src/common/survival_util.h +++ b/src/common/survival_util.h @@ -49,9 +49,9 @@ class AFTLoss { * \brief Constructor for AFT loss function * \param dist_type Choice of probability distribution for the noise term in AFT */ - explicit AFTLoss(ProbabilityDistributionType dist_type) - : dist_(ProbabilityDistribution::Create(dist_type)), - dist_type_(dist_type) {} + explicit AFTLoss(ProbabilityDistributionType dist_type) + : dist_(ProbabilityDistribution::Create(dist_type)), + dist_type_(dist_type) {} public: /*! From 14ab9928f63521ae9802ba71832cd1d34c30c225 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Fri, 3 Apr 2020 02:28:05 -0700 Subject: [PATCH 6/8] Update test_survival.py --- tests/python/test_survival.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/test_survival.py b/tests/python/test_survival.py index b50cfe50f15c..67e1a4da1c7f 100644 --- a/tests/python/test_survival.py +++ b/tests/python/test_survival.py @@ -85,7 +85,6 @@ def test_aft_survival_demo_data(): # AFT metric (negative log likelihood) improve monotonically assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][:1])) # For this data, normal distribution works the best - print (nloglik_rec['normal'][-1], nloglik_rec['logistic'][-1], nloglik_rec['extreme'][-1]) assert nloglik_rec['normal'][-1] < 4.9 assert nloglik_rec['logistic'][-1] > 4.9 assert nloglik_rec['extreme'][-1] > 4.9 From f907878609d73e5e578e669b803f8658e2a67636 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 4 Apr 2020 10:45:00 -0700 Subject: [PATCH 7/8] Use GradientPairPrecise --- src/common/survival_util.cc | 63 ++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/src/common/survival_util.cc b/src/common/survival_util.cc index 278f2b5af98e..5630da35593b 100644 --- a/src/common/survival_util.cc +++ b/src/common/survival_util.cc @@ -47,68 +47,65 @@ enum class CensoringType : uint8_t { kUncensored, kRightCensored, kLeftCensored, kIntervalCensored }; -struct GradHessPair { - double gradient; - double hessian; -}; +using xgboost::GradientPairPrecise; -inline GradHessPair GetLimitAtInfPred(ProbabilityDistributionType dist_type, - CensoringType censor_type, - double sign, double sigma) { +inline GradientPairPrecise GetLimitAtInfPred(ProbabilityDistributionType dist_type, + CensoringType censor_type, + double sign, double sigma) { switch (censor_type) { case CensoringType::kUncensored: switch (dist_type) { case ProbabilityDistributionType::kNormal: - return sign ? GradHessPair{ kMinGradient, 1.0 / (sigma * sigma) } - : GradHessPair{ kMaxGradient, 1.0 / (sigma * sigma) }; + return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) } + : GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) }; case ProbabilityDistributionType::kLogistic: - return sign ? GradHessPair{ -1.0 / sigma, kMinHessian } - : GradHessPair{ 1.0 / sigma, kMinHessian }; + return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian } + : GradientPairPrecise{ 1.0 / sigma, kMinHessian }; case ProbabilityDistributionType::kExtreme: - return sign ? GradHessPair{ kMinGradient, kMaxHessian } - : GradHessPair{ 1.0 / sigma, kMinHessian }; + return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian } + : GradientPairPrecise{ 1.0 / sigma, kMinHessian }; default: LOG(FATAL) << "Unknown distribution type"; } case CensoringType::kRightCensored: switch (dist_type) { case ProbabilityDistributionType::kNormal: - return sign ? GradHessPair{ kMinGradient, 1.0 / (sigma * sigma) } - : GradHessPair{ 0.0, kMinHessian }; + return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) } + : GradientPairPrecise{ 0.0, kMinHessian }; case ProbabilityDistributionType::kLogistic: - return sign ? GradHessPair{ -1.0 / sigma, kMinHessian } - : GradHessPair{ 0.0, kMinHessian }; + return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian } + : GradientPairPrecise{ 0.0, kMinHessian }; case ProbabilityDistributionType::kExtreme: - return sign ? GradHessPair{ kMinGradient, kMaxHessian } - : GradHessPair{ 0.0, kMinHessian }; + return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian } + : GradientPairPrecise{ 0.0, kMinHessian }; default: LOG(FATAL) << "Unknown distribution type"; } case CensoringType::kLeftCensored: switch (dist_type) { case ProbabilityDistributionType::kNormal: - return sign ? GradHessPair{ 0.0, kMinHessian } - : GradHessPair{ kMaxGradient, 1.0 / (sigma * sigma) }; + return sign ? GradientPairPrecise{ 0.0, kMinHessian } + : GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) }; case ProbabilityDistributionType::kLogistic: - return sign ? GradHessPair{ 0.0, kMinHessian } - : GradHessPair{ 1.0 / sigma, kMinHessian }; + return sign ? GradientPairPrecise{ 0.0, kMinHessian } + : GradientPairPrecise{ 1.0 / sigma, kMinHessian }; case ProbabilityDistributionType::kExtreme: - return sign ? GradHessPair{ 0.0, kMinHessian } - : GradHessPair{ 1.0 / sigma, kMinHessian }; + return sign ? GradientPairPrecise{ 0.0, kMinHessian } + : GradientPairPrecise{ 1.0 / sigma, kMinHessian }; default: LOG(FATAL) << "Unknown distribution type"; } case CensoringType::kIntervalCensored: switch (dist_type) { case ProbabilityDistributionType::kNormal: - return sign ? GradHessPair{ kMinGradient, 1.0 / (sigma * sigma) } - : GradHessPair{ kMaxGradient, 1.0 / (sigma * sigma) }; + return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) } + : GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) }; case ProbabilityDistributionType::kLogistic: - return sign ? GradHessPair{ -1.0 / sigma, kMinHessian } - : GradHessPair{ 1.0 / sigma, kMinHessian }; + return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian } + : GradientPairPrecise{ 1.0 / sigma, kMinHessian }; case ProbabilityDistributionType::kExtreme: - return sign ? GradHessPair{ kMinGradient, kMaxHessian } - : GradHessPair{ 1.0 / sigma, kMinHessian }; + return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian } + : GradientPairPrecise{ 1.0 / sigma, kMinHessian }; default: LOG(FATAL) << "Unknown distribution type"; } @@ -200,7 +197,7 @@ double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double s } gradient = numerator / denominator; if (denominator < kEps && (std::isnan(gradient) || std::isinf(gradient))) { - gradient = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).gradient; + gradient = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).GetGrad(); } return Clip(gradient, kMinGradient, kMaxGradient); @@ -257,7 +254,7 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si } hessian = numerator / denominator; if (denominator < kEps && (std::isnan(hessian) || std::isinf(hessian))) { - hessian = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).hessian; + hessian = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).GetHess(); } return Clip(hessian, kMinHessian, kMaxHessian); From 3b2d96cfb95f32d0d5540dfc1e530038f0c629e3 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 4 Apr 2020 10:49:05 -0700 Subject: [PATCH 8/8] Remove unused variables --- tests/cpp/common/test_survival_util.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/cpp/common/test_survival_util.cc b/tests/cpp/common/test_survival_util.cc index 29318ce8a429..53faaac6cd62 100644 --- a/tests/cpp/common/test_survival_util.cc +++ b/tests/cpp/common/test_survival_util.cc @@ -32,10 +32,6 @@ inline static void RobustTestSuite(ProbabilityDistributionType dist_type, } TEST(AFTLoss, RobustGradientPair) { // Ensure that INF and NAN don't show up in gradient pair - const double y_lower = 16.0; - const double y_upper = 200.0; - const double sigma = 2.0; - RobustTestSuite(ProbabilityDistributionType::kNormal, 16.0, 200.0, 2.0); RobustTestSuite(ProbabilityDistributionType::kLogistic, 16.0, 200.0, 2.0); RobustTestSuite(ProbabilityDistributionType::kExtreme, 16.0, 200.0, 2.0);