Skip to content

Commit

Permalink
Fix model_performance() with type = 'roc' (for binary classificat…
Browse files Browse the repository at this point in the history
…ion problem) (#19)

* change `times` vector in `model_performance()` roc

* fix `model_performance()` roc calculation

* Update NEWS.md

Co-authored-by: Mikołaj Spytek <55801784+mikolajsp@users.noreply.github.com>
  • Loading branch information
krzyzinskim and mikolajsp authored Nov 15, 2022
1 parent facbf3f commit defda28
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
* added references to used methods ([#5](https://github.com/ModelOriented/survex/issues/5))
* changed the package used to draw complex plots from `gridExtra` to `patchwork` ([#7](https://github.com/ModelOriented/survex/pull/7))
* fixed subtitles in plots ([#11](https://github.com/ModelOriented/survex/issues/11))
* fixed calculating of roc curves for classification problems
* added wrapper function for measures provided by `mlr3proba` ([#10](https://github.com/ModelOriented/survex/issues/10))
* created vignette showing how to use `mlr3proba` with `survex`
* fixed incompatibility with new ggplot2 version 3.4


# survex 0.1.1
* The `survex` package is now public
* `model_parts`, `model_profile`, `predict_parts`, `predict_profile` explanations implemented
Expand Down
2 changes: 1 addition & 1 deletion R/model_performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
#'
#' plot(rsf_ranger_model_performance, cph_model_performance, rsf_src_model_performance)
#'
#' cph_model_performance_roc <- model_performance(cph_exp, type = "roc", times = c(100, 500, 1200))
#' cph_model_performance_roc <- model_performance(cph_exp, type = "roc", times = c(100, 250, 500))
#' plot(cph_model_performance_roc)
#'
#' @rdname model_performance.surv_explainer
Expand Down
15 changes: 11 additions & 4 deletions R/surv_model_performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,19 @@ surv_model_performance <- function(explainer, ..., times = NULL, type = "metrics
else {
if (is.null(times)) stop("Times cannot be NULL for type `roc`")
rocs <- lapply(times, function(time) {
labels <- 1 - explainer$y[, 2]
scores <- explainer$predict_survival_function(explainer$model, newdata, time)
labels <- labels[order(scores, decreasing = TRUE)]
cbind(time = time, data.frame(TPR = cumsum(labels) / sum(labels), FPR = cumsum(!labels) / sum(!labels), labels))
censored_earlier_mask <- (explainer$y[, 1] < time & explainer$y[, 2] == 0)
event_later_mask <- explainer$y[, 1] > time
newdata_t <- newdata[!censored_earlier_mask, ]
labels <- explainer$y[,2]
labels[event_later_mask] <- 0
labels <- labels[!censored_earlier_mask]
scores <- explainer$predict_survival_function(explainer$model, newdata_t, time)
labels <- labels[order(scores, decreasing = FALSE)]
cbind(time = time, data.frame(TPR = cumsum(labels) / sum(labels),
FPR = cumsum(!labels) / sum(!labels), labels))
})


rocs_df <- do.call(rbind, rocs)

class(rocs_df) <- c("surv_model_performance_rocs", class(rocs_df))
Expand Down
2 changes: 1 addition & 1 deletion man/model_performance.surv_explainer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit defda28

Please sign in to comment.