diff --git a/NEWS.md b/NEWS.md index 64413d0d..87a0b6c9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,10 +4,12 @@ * added references to used methods ([#5](https://github.com/ModelOriented/survex/issues/5)) * changed the package used to draw complex plots from `gridExtra` to `patchwork` ([#7](https://github.com/ModelOriented/survex/pull/7)) * fixed subtitles in plots ([#11](https://github.com/ModelOriented/survex/issues/11)) +* fixed calculating of roc curves for classification problems * added wrapper function for measures provided by `mlr3proba` ([#10](https://github.com/ModelOriented/survex/issues/10)) * created vignette showing how to use `mlr3proba` with `survex` * fixed incompatibility with new ggplot2 version 3.4 + # survex 0.1.1 * The `survex` package is now public * `model_parts`, `model_profile`, `predict_parts`, `predict_profile` explanations implemented diff --git a/R/model_performance.R b/R/model_performance.R index fa2677f2..011b289c 100644 --- a/R/model_performance.R +++ b/R/model_performance.R @@ -54,7 +54,7 @@ #' #' plot(rsf_ranger_model_performance, cph_model_performance, rsf_src_model_performance) #' -#' cph_model_performance_roc <- model_performance(cph_exp, type = "roc", times = c(100, 500, 1200)) +#' cph_model_performance_roc <- model_performance(cph_exp, type = "roc", times = c(100, 250, 500)) #' plot(cph_model_performance_roc) #' #' @rdname model_performance.surv_explainer diff --git a/R/surv_model_performance.R b/R/surv_model_performance.R index 134b2c23..15389c34 100644 --- a/R/surv_model_performance.R +++ b/R/surv_model_performance.R @@ -31,12 +31,19 @@ surv_model_performance <- function(explainer, ..., times = NULL, type = "metrics else { if (is.null(times)) stop("Times cannot be NULL for type `roc`") rocs <- lapply(times, function(time) { - labels <- 1 - explainer$y[, 2] - scores <- explainer$predict_survival_function(explainer$model, newdata, time) - labels <- labels[order(scores, decreasing = TRUE)] - cbind(time = time, data.frame(TPR = cumsum(labels) / sum(labels), FPR = cumsum(!labels) / sum(!labels), labels)) + censored_earlier_mask <- (explainer$y[, 1] < time & explainer$y[, 2] == 0) + event_later_mask <- explainer$y[, 1] > time + newdata_t <- newdata[!censored_earlier_mask, ] + labels <- explainer$y[,2] + labels[event_later_mask] <- 0 + labels <- labels[!censored_earlier_mask] + scores <- explainer$predict_survival_function(explainer$model, newdata_t, time) + labels <- labels[order(scores, decreasing = FALSE)] + cbind(time = time, data.frame(TPR = cumsum(labels) / sum(labels), + FPR = cumsum(!labels) / sum(!labels), labels)) }) + rocs_df <- do.call(rbind, rocs) class(rocs_df) <- c("surv_model_performance_rocs", class(rocs_df)) diff --git a/man/model_performance.surv_explainer.Rd b/man/model_performance.surv_explainer.Rd index ee9b9534..15c29b9f 100644 --- a/man/model_performance.surv_explainer.Rd +++ b/man/model_performance.surv_explainer.Rd @@ -84,7 +84,7 @@ plot(rsf_ranger_model_performance, cph_model_performance, plot(rsf_ranger_model_performance, cph_model_performance, rsf_src_model_performance) -cph_model_performance_roc <- model_performance(cph_exp, type = "roc", times = c(100, 500, 1200)) +cph_model_performance_roc <- model_performance(cph_exp, type = "roc", times = c(100, 250, 500)) plot(cph_model_performance_roc) }