Skip to content

Commit

Permalink
Increase mlr3proba support (#18)
Browse files Browse the repository at this point in the history
* add mlr3proba metrics

* fix plotting functions - non time-dependent loss functions in permutational variable importance

* add plotting for mlr3 measures

* add time-dependent plotting back

* Add reverse parameter and documentation for the loss_adapt_mlr3proba

* Fix R CMD check note

* cleanup and fix tests

* vignette draft

* fix plotting model_parts (DALEX, ingredients conflict)

* add `mlr3proba` vignette

* Update NEWS.md

* add linewidth instead of size for ggplot2

* replace `aes_string()` with `aes()`

* Add mlr3 installation in vignette

* add dependencies for mlr3proba vignette

* Update example of plot.feature_importance_explainer

* Update NEWS.md

Co-authored-by: krzyzinskim <mateusz_krzyzinski@wp.pl>
  • Loading branch information
mikolajsp and krzyzinskim authored Nov 15, 2022
1 parent e107cd2 commit facbf3f
Show file tree
Hide file tree
Showing 25 changed files with 640 additions and 179 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ survex.Rproj
LICENSE
^misc$
^cran-comments\.md$
^vignettes/articles$
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ S3method(model_performance,default)
S3method(model_performance,surv_explainer)
S3method(model_profile,default)
S3method(model_profile,surv_explainer)
S3method(plot,feature_importance_explainer)
S3method(plot,model_parts_survival)
S3method(plot,model_performance_survival)
S3method(plot,model_profile_survival)
Expand Down Expand Up @@ -48,6 +49,7 @@ export(explain)
export(explain_survival)
export(integrated_brier_score)
export(integrated_cd_auc)
export(loss_adapt_mlr3proba)
export(loss_brier_score)
export(loss_integrated_brier_score)
export(loss_one_minus_c_index)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
* added references to used methods ([#5](https://github.com/ModelOriented/survex/issues/5))
* changed the package used to draw complex plots from `gridExtra` to `patchwork` ([#7](https://github.com/ModelOriented/survex/pull/7))
* fixed subtitles in plots ([#11](https://github.com/ModelOriented/survex/issues/11))
* added wrapper function for measures provided by `mlr3proba` ([#10](https://github.com/ModelOriented/survex/issues/10))
* created vignette showing how to use `mlr3proba` with `survex`
* fixed incompatibility with new ggplot2 version 3.4

# survex 0.1.1
* The `survex` package is now public
Expand Down
63 changes: 62 additions & 1 deletion R/metrics.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
utils::globalVariables(c("PredictionSurv"))
#' Compute the Harrell's Concordance index
#'
#' A function to compute the Harrell's concordance index of a survival model.
Expand Down Expand Up @@ -52,6 +53,8 @@ c_index <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) {

}
attr(c_index, "loss_name") <- "C-index"
attr(c_index, "loss_type") <- "risk-based"


#' Calculate the Concordance index loss
#'
Expand Down Expand Up @@ -87,7 +90,7 @@ loss_one_minus_c_index <- function(y_true = NULL, risk = NULL, surv = NULL, time
1 - c_index(y_true = y_true, risk = risk, surv = surv, times = times)
}
attr(loss_one_minus_c_index, "loss_name") <- "One minus C-Index"

attr(loss_one_minus_c_index, "loss_type") <- "risk-based"

#' Calculate Brier score
#'
Expand Down Expand Up @@ -156,11 +159,13 @@ brier_score <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) {

}
attr(brier_score, "loss_name") <- "Brier score"
attr(brier_score, "loss_type") <- "time-dependent"

#' @rdname brier_score
#' @export
loss_brier_score <- brier_score
# `loss_brier_score` is an alias created by copying `brier_score`; the
# "loss_name" and "loss_type" attributes are assigned on the copy explicitly
# so the alias carries its own metadata.
attr(loss_brier_score, "loss_name") <- "Brier score"
attr(loss_brier_score, "loss_type") <- "time-dependent"

#' Calculate Cumulative/Dynamic AUC
#'
Expand Down Expand Up @@ -235,6 +240,7 @@ cd_auc <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) {

}
attr(cd_auc, "loss_name") <- "C/D AUC"
attr(cd_auc, "loss_type") <- "time-dependent"


#' Calculate Cumulative/Dynamic AUC loss
Expand Down Expand Up @@ -270,6 +276,7 @@ loss_one_minus_cd_auc <- function(y_true = NULL, risk = NULL, surv = NULL, times
1 - cd_auc(y_true = y_true, risk = risk, surv = surv, times = times)
}
attr(loss_one_minus_cd_auc, "loss_name") <- "One minus C/D AUC"
attr(loss_one_minus_cd_auc, "loss_type") <- "time-dependent"

#' Calculate integrated C/D AUC
#'
Expand Down Expand Up @@ -329,6 +336,8 @@ integrated_cd_auc <- function(y_true = NULL, risk = NULL, surv = NULL, times = N
cumsum(c(0, iauc))[length(cumsum(c(0, iauc)))] / (max(times) - min(times))
}
attr(integrated_cd_auc, "loss_name") <- "integrated C/D AUC"
attr(integrated_cd_auc, "loss_type") <- "integrated"



#' Calculate integrated C/D AUC loss
Expand Down Expand Up @@ -371,6 +380,8 @@ loss_one_minus_integrated_cd_auc <- function(y_true = NULL, risk = NULL, surv =
1 - integrated_cd_auc(y_true = y_true, risk = risk, surv = surv, times = times, auc = auc)
}
attr(loss_one_minus_integrated_cd_auc, "loss_name") <- "One minus integrated C/D AUC"
attr(loss_one_minus_integrated_cd_auc, "loss_type") <- "integrated"



#' Calculate integrated Brier score
Expand Down Expand Up @@ -438,3 +449,53 @@ attr(integrated_brier_score, "loss_name") <- "integrated Brier score"
#' @export
loss_integrated_brier_score <- integrated_brier_score
# `loss_integrated_brier_score` is an alias created by copying
# `integrated_brier_score`; the "loss_name" and "loss_type" attributes are
# assigned on the copy explicitly. `model_parts()` reads "loss_type" to pick
# the integrated feature-importance routine.
attr(loss_integrated_brier_score, "loss_name") <- "integrated Brier score"
attr(loss_integrated_brier_score, "loss_type") <- "integrated"

#' Adapt mlr3proba measures for use with survex
#'
#' This function allows for usage of standardized measures from the mlr3proba package with `survex`.
#'
#' @param measure - a `MeasureSurv` object from the `mlr3proba` package, the object to adapt
#' @param reverse - boolean, FALSE by default, whether the metric should be reversed in order to be treated as loss (for permutational variable importance we need functions with lower values indicating better performance). If TRUE, the new metric value will be (1 - metric_value)
#' @param ... - other parameters, currently ignored
#'
#' @return a function with standardized parameters (`y_true`, `risk`, `surv`, `times`) that can be used to calculate loss. The returned function carries two attributes: `"loss_name"` (the measure id, prefixed with "one minus" when `reverse = TRUE`) and `"loss_type"` (always `"integrated"`).
#'
#' @examples
#' if (FALSE) {
#'   measure <- msr("surv.calib_beta")
#'   mlr_measure <- loss_adapt_mlr3proba(measure)
#' }
#'
#' @export
loss_adapt_mlr3proba <- function(measure, reverse = FALSE, ...) {

    loss_function <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) {

        # Label the columns of the survival matrix with the evaluation times.
        # Skipped when no survival matrix is supplied, because assigning
        # colnames to NULL raises an error.
        if (!is.null(surv)) {
            colnames(surv) <- times
        }

        # `PredictionSurv` is looked up at call time (it is declared via
        # utils::globalVariables in this file), so mlr3proba has to be attached.
        surv_pred <- PredictionSurv$new(
            # seq_len() gives an empty id vector for empty input, whereas
            # 1:length(y_true) would give c(1, 0)
            row_ids = seq_len(length(y_true)),
            truth = y_true,
            crank = risk,
            distr = surv,
            task = list(truth = y_true)
        )

        output <- surv_pred$score(measure)
        # drop the measure id that comes back as the name of the score
        names(output) <- NULL

        if (reverse) {
            output <- 1 - output
        }

        output
    }

    # Metadata read by the rest of the package (plot labels, dispatch in model_parts)
    attr(loss_function, "loss_name") <- if (reverse) {
        paste("one minus", measure$id)
    } else {
        measure$id
    }
    attr(loss_function, "loss_type") <- "integrated"

    loss_function
}




64 changes: 32 additions & 32 deletions R/model_parts.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,39 +65,39 @@ model_parts.surv_explainer <- function(explainer,
if (type == "variable_importance") type <- "raw" # it's an alias

switch(output_type,
"risk" = DALEX::model_parts(
explainer = explainer,
loss_function = loss_function,
... = ...,
type = type,
N = N
),
"survival" = {
test_explainer(explainer, has_data = TRUE, has_y = TRUE, has_survival = TRUE, function_name = "model_parts")
"risk" = DALEX::model_parts(
explainer = explainer,
loss_function = loss_function,
... = ...,
type = type,
N = N
),
"survival" = {
test_explainer(explainer, has_data = TRUE, has_y = TRUE, has_survival = TRUE, function_name = "model_parts")

if (attr(loss_function, "loss_name") %in% c("integrated Brier score", "One minus integrated C/D AUC", "One minus C-Index")) {
res <- surv_integrated_feature_importance(
x = explainer,
loss_function = loss_function,
type = type,
N = N,
...
)
class(res) <- c("model_parts", class(res))
return(res)
} else {
res <- surv_feature_importance(
x = explainer,
loss_function = loss_function,
type = type,
N = N,
...
)
class(res) <- c("model_parts_survival", class(res))
res
}
},
stop("Type should be either `survival` or `risk`")
if (attr(loss_function, "loss_type") == "integrated") {
res <- surv_integrated_feature_importance(
x = explainer,
loss_function = loss_function,
type = type,
N = N,
...
)
class(res) <- c("model_parts_survival", class(res))
return(res)
} else {
res <- surv_feature_importance(
x = explainer,
loss_function = loss_function,
type = type,
N = N,
...
)
class(res) <- c("model_parts_survival", class(res))
res
}
},
stop("Type should be either `survival` or `risk`")
)
}

Expand Down
9 changes: 7 additions & 2 deletions R/model_performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#' @param explainer an explainer object - model preprocessed by the `explain()` function
#' @param ... other parameters, currently ignored
#' @param type character, either `"metrics"` or `"roc"`. If `"metrics"` then performance metrics are calculated, if `"roc"` ROC curves for selected time points are calculated.
#' @param metrics a named vector containing the metrics to be calculated. The values should be standardized loss functions. The functions can be supplied manually but must have these named parameters (`y_true`, `risk`, `surv`, `times`), where `y_true` represents the `survival::Surv` object with observed times and statuses, `risk` is the risk score calculated by the model, and `surv` is the survival function for each observation evaluated at `times`.
#' @param times a numeric vector of times. If `type == "metrics"` then the survival function is evaluated at these times, if `type == "roc"` then the ROC curves are calculated at these times.
#'
#' @return An object of class `"model_performance_survival"`. It's a list of metric values calculated for the model. It contains:
Expand Down Expand Up @@ -62,10 +63,14 @@ model_performance <- function(explainer, ...) UseMethod("model_performance", exp

#' @rdname model_performance.surv_explainer
#' @export
model_performance.surv_explainer <- function(explainer, ..., type = "metrics", times = NULL) {
model_performance.surv_explainer <- function(explainer, ..., type = "metrics", metrics = c("C-index" = c_index,
"Integrated Brier score" = loss_integrated_brier_score,
"Integrated C/D AUC" = integrated_cd_auc,
"Brier score" = brier_score,
"C/D AUC" = cd_auc), times = NULL) {
test_explainer(explainer, "model_performance", has_data = TRUE, has_y = TRUE, has_survival = TRUE, has_predict = TRUE)

res <- surv_model_performance(explainer, ..., type = type, times = times)
res <- surv_model_performance(explainer, ..., type = type, metrics = metrics, times = times)

class(res) <- c("model_performance_survival", class(res))
res
Expand Down
134 changes: 134 additions & 0 deletions R/plot_feature_importance.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#' Plots Feature Importance
#'
#' This function plots variable importance calculated as changes in the loss function after variable drops.
#' It uses output from \code{feature_importance} function that corresponds to
#' permutation based measure of variable importance.
#' Variables are sorted in the same order in all panels.
#' The order depends on the average drop out loss.
#' In different panels variable contributions may not look like sorted if variable
#' importance is different in different models.
#'
#' Find more details in the \href{https://ema.drwhy.ai/featureImportance.html}{Feature Importance Chapter}.
#'
#' @param x a feature importance explainer produced with the \code{feature_importance()} function
#' @param ... other explainers that shall be plotted together
#' @param max_vars maximum number of variables that shall be presented for each model.
#' By default \code{NULL} which means all variables
#' @param show_boxplots logical if \code{TRUE} (default) boxplot will be plotted to show permutation data.
#' @param bar_width width of bars. By default \code{10}
#' @param desc_sorting logical. Should the bars be sorted descending? By default TRUE
#' @param title the plot's title, by default \code{'Feature Importance'}
#' @param subtitle the plot's subtitle. By default - \code{NULL}, which means
#' the subtitle will be 'created for the XXX model', where XXX is the label of explainer(s)
#'
#' @importFrom stats model.frame reorder
#' @importFrom utils head tail
#'
#' @return a \code{ggplot2} object
#'
#' @references Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. \url{https://ema.drwhy.ai/}
#'
#' @examples
#' library(survex)
#' library(randomForestSRC)
#' library(survival)
#'
#' model <- rfsrc(Surv(time, status) ~., data = veteran)
#' explainer <- explain(model)
#'
#' mp <- model_parts(explainer, loss = loss_one_minus_c_index, output_type = "risk")
#' plot(mp)
#'
#' @export
plot.feature_importance_explainer <- function(x, ..., max_vars = NULL, show_boxplots = TRUE, bar_width = 10,
                                              desc_sorting = TRUE, title = "Feature Importance", subtitle = NULL) {

  if (!is.logical(desc_sorting)) {
    stop("desc_sorting is not logical")
  }

  dfl <- c(list(x), list(...))

  # add boxplot data
  if (show_boxplots) {
    dfl <- lapply(dfl, function(x) {
      # five-number summary of the dropout loss per variable, over all permutations;
      # quantile/median are namespace-qualified because only model.frame and
      # reorder are imported from stats
      result <- data.frame(
        min = tapply(x$dropout_loss, x$variable, min, na.rm = TRUE),
        q1 = tapply(x$dropout_loss, x$variable, stats::quantile, 0.25, na.rm = TRUE),
        median = tapply(x$dropout_loss, x$variable, stats::median, na.rm = TRUE),
        q3 = tapply(x$dropout_loss, x$variable, stats::quantile, 0.75, na.rm = TRUE),
        max = tapply(x$dropout_loss, x$variable, max, na.rm = TRUE)
      )

      # tapply returns 1-d arrays; flatten them to plain numeric columns
      result$min <- as.numeric(result$min)
      result$q1 <- as.numeric(result$q1)
      result$median <- as.numeric(result$median)
      result$q3 <- as.numeric(result$q3)
      result$max <- as.numeric(result$max)

      # keep only the aggregated rows (permutation == 0) and attach the summary
      merge(x[x$permutation == 0, ], cbind(rownames(result), result), by.x = "variable", by.y = "rownames(result)")
    })
  } else {
    dfl <- lapply(dfl, function(x) {
      x[x$permutation == 0, ]
    })
  }

  # combine all explainers in a single frame
  expl_df <- do.call(rbind, dfl)

  # add an additional column that serve as a baseline
  bestFits <- expl_df[expl_df$variable == "_full_model_", ]
  ext_expl_df <- merge(expl_df, bestFits[, c("label", "dropout_loss")], by = "label")

  # set the order of variables depending on their contribution
  ext_expl_df$variable <- reorder(ext_expl_df$variable,
                                  (ext_expl_df$dropout_loss.x - ext_expl_df$dropout_loss.y) * ifelse(desc_sorting, 1, -1),
                                  mean)

  # remove rows that starts with _
  ext_expl_df <- ext_expl_df[!(substr(ext_expl_df$variable, 1, 1) == "_"), ]

  # for each model leave only max_vars
  if (!is.null(max_vars)) {
    trimmed_parts <- lapply(unique(ext_expl_df$label), function(label) {
      tmp <- ext_expl_df[ext_expl_df$label == label, ]
      tmp[tail(order(tmp$dropout_loss.x), max_vars), ]
    })
    ext_expl_df <- do.call(rbind, trimmed_parts)
  }

  # dummy bindings for the column names used inside aes() below
  variable <- q1 <- q3 <- dropout_loss.x <- dropout_loss.y <- label <- dropout_loss <- NULL
  nlabels <- length(unique(bestFits$label))

  # extract labels for plot's subtitle
  if (is.null(subtitle)) {
    glm_labels <- paste0(unique(ext_expl_df$label), collapse = ", ")
    subtitle <- paste0("created for the ", glm_labels, " model")
  }

  # plot it
  # `linewidth` replaces `size` for line geoms since ggplot2 3.4.0
  # (`size` triggers a deprecation warning there)
  pl <- ggplot(ext_expl_df, aes(variable, ymin = dropout_loss.y, ymax = dropout_loss.x, color = label)) +
    geom_hline(data = bestFits, aes(yintercept = dropout_loss, color = label), lty = 3) +
    geom_linerange(linewidth = bar_width)

  if (show_boxplots) {
    pl <- pl +
      geom_boxplot(aes(ymin = min, lower = q1, middle = median, upper = q3, ymax = max),
                   stat = "identity", fill = "#371ea3", color = "#371ea3", width = 0.25)
  }

  if (!is.null(attr(x, "loss_name"))) {
    y_lab <- paste(attr(x, "loss_name")[1], "loss after permutations")
  } else {
    y_lab <- "Loss function after variable's permutations"
  }
  # facets have fixed space, can be resolved with ggforce https://github.com/tidyverse/ggplot2/issues/2933
  pl + coord_flip() +
    scale_color_manual(values = DALEX::colors_discrete_drwhy(nlabels)) +
    facet_wrap(~label, ncol = 1, scales = "free_y") + DALEX::theme_drwhy_vertical() +
    ylab(y_lab) + xlab("") +
    labs(title = title, subtitle = subtitle) +
    theme(legend.position = "none")

}
Loading

0 comments on commit facbf3f

Please sign in to comment.