-
Notifications
You must be signed in to change notification settings - Fork 10
/
surv_model_performance.R
58 lines (48 loc) · 2.2 KB
/
surv_model_performance.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#' Compute performance metrics or ROC curves for a survival explainer
#'
#' Internal helper used by `model_performance.R`. Depending on `type` it either
#' evaluates a set of performance measures (Brier score, C/D AUC, C-index and
#' their integrated versions) or builds ROC curve coordinates at the requested
#' time points.
#'
#' @param explainer an explainer object - model preprocessed by the `explain()` function
#' @param ... other parameters, currently ignored
#' @param times a numeric vector, time points at which ROC curves are calculated if `type == "roc"` or at which metrics are calculated if `type == "metrics"`. If `NULL` and `type == "metrics"`, `explainer$times` is used. Note: if `type == "roc"` this parameter is obligatory
#' @param type character, either `"metrics"` which calculates performance metrics or `"roc"` which calculates ROC curves at specific time points. Any other value is an error.
#'
#' @return If `type == "metrics"`, a list of class `surv_model_performance` with
#'   elements `eval_times`, `brier_score`, `auc`, `cindex`, `iauc` and
#'   `integrated_brier_score`. If `type == "roc"`, a data.frame of class
#'   `surv_model_performance_rocs` with columns `time`, `TPR`, `FPR` and
#'   `labels`, holding one block of rows per requested time point. In both cases
#'   the explainer's label is attached as the `"label"` attribute.
#'
#' @keywords internal
surv_model_performance <- function(explainer, ..., times = NULL, type = "metrics") {
  # Reject unknown `type` values up front. Previously every value other than
  # "metrics" (typos included) silently fell through to the ROC branch.
  type <- match.arg(type, c("metrics", "roc"))
  newdata <- explainer$data

  if (type == "metrics") {
    if (is.null(times)) times <- explainer$times
    sf <- explainer$predict_survival_function(explainer$model, newdata, times)
    risk <- explainer$predict_function(explainer$model, newdata)
    y <- explainer$y

    # Named `brier` (not `brier_score`) so the local value does not shadow the
    # `brier_score()` function it is computed with.
    brier <- brier_score(y, risk, sf, times)
    auc <- cd_auc(y, risk, sf, times)
    cindex <- c_index(y, risk)
    iauc <- integrated_cd_auc(auc = auc, times = times)
    ibs <- integrated_brier_score(times = times, brier = brier)

    ret_list <- list(
      eval_times = times,
      brier_score = brier,
      auc = auc,
      cindex = cindex,
      iauc = iauc,
      integrated_brier_score = ibs
    )
    class(ret_list) <- c("surv_model_performance", class(ret_list))
    attr(ret_list, "label") <- explainer$label
    return(ret_list)
  }

  # type == "roc"
  if (is.null(times)) stop("Times cannot be NULL for type `roc`")

  # Labels do not depend on the time point, so compute them once.
  # NOTE(review): presumably `explainer$y` is a survival outcome whose second
  # column is the event status, making a label of 1 mean "no event recorded";
  # censoring before `time` is not treated specially here - confirm intended.
  base_labels <- 1 - explainer$y[, 2]

  rocs <- lapply(times, function(time) {
    scores <- explainer$predict_survival_function(explainer$model, newdata, time)
    # Rank observations from highest to lowest predicted survival at `time`,
    # then accumulate the share of 1-labels (TPR) and 0-labels (FPR).
    ranked <- base_labels[order(scores, decreasing = TRUE)]
    cbind(
      time = time,
      data.frame(
        TPR = cumsum(ranked) / sum(ranked),
        FPR = cumsum(!ranked) / sum(!ranked),
        labels = ranked
      )
    )
  })

  rocs_df <- do.call(rbind, rocs)
  class(rocs_df) <- c("surv_model_performance_rocs", class(rocs_df))
  attr(rocs_df, "label") <- explainer$label
  rocs_df
}