Skip to content

Commit

Permalink
feat: code adaptions to treeshap computation using pre-difned surviva…
Browse files Browse the repository at this point in the history
…l times from explain-object

addresses ModelOriented#75
  • Loading branch information
kapsner committed Apr 8, 2023
1 parent ad886bb commit d2c16a6
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 9 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: survex
Title: Explainable Machine Learning in Survival Analysis
Version: 1.0.0.9001
Version: 1.0.0.9002
Authors@R:
c(
person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")),
Expand Down
9 changes: 2 additions & 7 deletions R/surv_shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,7 @@ surv_shap <- function(explainer,
}

if (calculation_method == "treeshap") {
if (inherits(explainer$model, "ranger")) {
# hack to use rf-model's death times as explainer death times, as
# treeshap::ranger_surv_fun.unify extracts survival time-points directly
# from the ranger object for calculating the predictions
explainer$times <- explainer$model$unique.death.times
} else {
if (!inherits(explainer$model, "ranger")) {
stop("Calculation method `treeshap` is currently only implemented for `ranger`.")
}
}
Expand Down Expand Up @@ -262,7 +257,7 @@ use_treeshap <- function(explainer, new_observation, ...){
# UNIFY_FUN to prepare code for easy Integration of other ml algorithms
# that are supported by treeshap
UNIFY_FUN <- treeshap::ranger_surv.unify
unify_append_args <- list(type = "survival")
unify_append_args <- list(type = "survival", times = explainer$times)
} else {
stop("Support for `treeshap` is currently only implemented for `ranger`.")
}
Expand Down
1 change: 0 additions & 1 deletion tests/testthat/test-predict_parts.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ test_that("survshap explanations work", {
)
plot(parts_ranger_kernelshap)


parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")])
plot(parts_src)

Expand Down

0 comments on commit d2c16a6

Please sign in to comment.