diff --git a/src/common/survival_util.cc b/src/common/survival_util.cc index 82724e25155b..4ccf75420bd8 100644 --- a/src/common/survival_util.cc +++ b/src/common/survival_util.cc @@ -141,25 +141,28 @@ double AFTExtreme::HessPDF(double z) { } -double AFTLoss::Loss(double y_lower, double y_higher, double y_pred, double sigma) { +double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma) { double pdf; double cdf_u, cdf_l, z_u, z_l; double cost; - if (y_lower == y_higher) { // uncensored - z_l = (y_lower - y_pred) / sigma; + + const double log_y_lower = std::log(y_lower); + const double log_y_upper = std::log(y_upper); + if (y_lower == y_upper) { // uncensored + z_l = (log_y_lower - y_pred) / sigma; pdf = dist_->PDF(z_l); cost = -std::log(pdf / (sigma * y_lower)); } else { // censored; now check what type of censorship we have - if (std::isinf(y_higher)) { // right-censored + if (std::isinf(y_upper)) { // right-censored cdf_u = 1; } else { // left-censored or interval-censored - z_u = (y_higher - y_pred) / sigma; + z_u = (log_y_upper - y_pred) / sigma; cdf_u = dist_->CDF(z_u); } if (std::isinf(y_lower)) { // left-censored cdf_l = 0; } else { // right-censored or interval-censored - z_l = (y_lower - y_pred) / sigma; + z_l = (log_y_lower - y_pred) / sigma; cdf_l = dist_->CDF(z_l); } cost = -std::log(cdf_u - cdf_l); @@ -167,7 +170,7 @@ double AFTLoss::Loss(double y_lower, double y_higher, double y_pred, double sigm return cost; } -double AFTLoss::Gradient(double y_lower, double y_higher, double y_pred, double sigma) { +double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double sigma) { double pdf_l; double pdf_u; double pdf; @@ -180,17 +183,19 @@ double AFTLoss::Gradient(double y_lower, double y_higher, double y_pred, double double gradient; const double eps = 1e-12f; - if (y_lower == y_higher) { // uncensored - z = (y_lower - y_pred) / sigma; + const double log_y_lower = std::log(y_lower); + const double log_y_upper = std::log(y_upper); + if (y_lower == y_upper) { // uncensored + z = (log_y_lower - y_pred) / sigma; pdf = dist_->PDF(z); grad = dist_->GradPDF(z); gradient = grad / (sigma * pdf); } else { // censored; now check what type of censorship we have - if (std::isinf(y_higher)) { // right-censored + if (std::isinf(y_upper)) { // right-censored pdf_u = 0; cdf_u = 1; } else { // interval-censored or left-censored - z_u = (y_higher - y_pred) / sigma; + z_u = (log_y_upper - y_pred) / sigma; pdf_u = dist_->PDF(z_u); cdf_u = dist_->CDF(z_u); } @@ -198,7 +203,7 @@ double AFTLoss::Gradient(double y_lower, double y_higher, double y_pred, double pdf_l = 0; cdf_l = 0; } else { // interval-censored or right-censored - z_l = (y_lower - y_pred) / sigma; + z_l = (log_y_lower - y_pred) / sigma; pdf_l = dist_->PDF(z_l); cdf_l = dist_->CDF(z_l); } @@ -208,7 +213,7 @@ double AFTLoss::Gradient(double y_lower, double y_higher, double y_pred, double return gradient; } -double AFTLoss::Hessian(double y_lower, double y_higher, double y_pred, double sigma) { +double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) { double z; double z_u; double z_l; @@ -232,19 +237,21 @@ double AFTLoss::Hessian(double y_lower, double y_higher, double y_pred, double s double hess_dist; const double eps = 1e-12f; - if (y_lower == y_higher) { // uncensored - z = (y_lower - y_pred) / sigma; + const double log_y_lower = std::log(y_lower); + const double log_y_upper = std::log(y_upper); + if (y_lower == y_upper) { // uncensored + z = (log_y_lower - y_pred) / sigma; pdf = dist_->PDF(z); grad = dist_->GradPDF(z); hess_dist = dist_->HessPDF(z); hessian = -(pdf * hess_dist - std::pow(grad, 2)) / (std::pow(sigma, 2) * std::pow(pdf, 2)); } else { // censored; now check what type of censorship we have - if (std::isinf(y_higher)) { // right-censored + if (std::isinf(y_upper)) { // right-censored pdf_u = 0; cdf_u = 1; grad_u = 0; } else { // interval-censored or left-censored - z_u = (y_higher - y_pred) / sigma; + z_u = (log_y_upper - y_pred) / sigma; pdf_u = dist_->PDF(z_u); cdf_u = dist_->CDF(z_u); grad_u = dist_->GradPDF(z_u); @@ -254,7 +261,7 @@ double AFTLoss::Hessian(double y_lower, double y_higher, double y_pred, double s cdf_l = 0; grad_l = 0; } else { // interval-censored or right-censored - z_l = (y_lower - y_pred) / sigma; + z_l = (log_y_lower - y_pred) / sigma; pdf_l = dist_->PDF(z_l); cdf_l = dist_->CDF(z_l); grad_l = dist_->GradPDF(z_l); diff --git a/src/common/survival_util.h b/src/common/survival_util.h index 46db58b1a263..ca879ac4bc38 100644 --- a/src/common/survival_util.h +++ b/src/common/survival_util.h @@ -92,9 +92,9 @@ class AFTLoss { } public: - double Loss(double y_lower, double y_higher, double y_pred, double sigma); - double Gradient(double y_lower, double y_higher, double y_pred, double sigma); - double Hessian(double y_lower, double y_higher, double y_pred, double sigma); + double Loss(double y_lower, double y_upper, double y_pred, double sigma); + double Gradient(double y_lower, double y_upper, double y_pred, double sigma); + double Hessian(double y_lower, double y_upper, double y_pred, double sigma); }; } // namespace common diff --git a/src/metric/survival_metric.cc b/src/metric/survival_metric.cc index 89eac3bbc7ae..6eff558eeb12 100644 --- a/src/metric/survival_metric.cc +++ b/src/metric/survival_metric.cc @@ -74,7 +74,7 @@ struct EvalAFT : public Metric { for (omp_ulong i = 0; i < nsize; ++i) { // If weights are empty, data is unweighted so we use 1.0 everywhere double w = is_null_weight ? 1.0 : weights[i]; - double loss = loss_->Loss(std::log(y_lower[i]), std::log(y_higher[i]), + double loss = loss_->Loss(y_lower[i], y_higher[i], yhat[i], param_.aft_loss_distribution_scale); nloglik_sum += loss; weight_sum += w; diff --git a/src/objective/aft_obj.cc b/src/objective/aft_obj.cc index fd4eadee5b35..a28ee4f4c54d 100644 --- a/src/objective/aft_obj.cc +++ b/src/objective/aft_obj.cc @@ -55,17 +55,16 @@ class AFTObj : public ObjFunction { << "yhat is too big"; const omp_ulong nsize = static_cast(yhat.size()); double first_order_grad; - double second_order_grad; #pragma omp parallel for schedule(static) for (omp_ulong i = 0; i < nsize; ++i) { // If weights are empty, data is unweighted so we use 1.0 everywhere - double w = is_null_weight ? 1.0 : weights[i]; - first_order_grad = loss_->Gradient(std::log(y_lower[i]), std::log(y_higher[i]), + const double w = is_null_weight ? 1.0 : weights[i]; + const double grad = loss_->Gradient(y_lower[i], y_higher[i], + yhat[i], param_.aft_loss_distribution_scale); + const double hess = loss_->Hessian(y_lower[i], y_higher[i], yhat[i], param_.aft_loss_distribution_scale); - second_order_grad = loss_->Hessian(std::log(y_lower[i]), std::log(y_higher[i]), - yhat[i], param_.aft_loss_distribution_scale); - gpair[i] = GradientPair(first_order_grad * w, second_order_grad * w); + gpair[i] = GradientPair(grad * w, hess * w); } }