Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Survival censoring weights #897

Merged
merged 20 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.0.4.9003
Version: 1.0.4.9004
Authors@R: c(
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method(.censoring_weights_graf,default)
S3method(.censoring_weights_graf,model_fit)
S3method(.censoring_weights_graf,workflow)
S3method(augment,model_fit)
S3method(autoplot,glmnet)
S3method(autoplot,model_fit)
Expand Down Expand Up @@ -144,6 +147,7 @@ S3method(varying_args,model_spec)
S3method(varying_args,recipe)
S3method(varying_args,step)
export("%>%")
export(.censoring_weights_graf)
export(.check_glmnet_penalty_fit)
export(.check_glmnet_penalty_predict)
export(.cols)
Expand Down
199 changes: 199 additions & 0 deletions R/ipcw.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ trunc_probs <- function(probs, trunc = 0.01) {
}

.filter_eval_time <- function(eval_time, fail = TRUE) {
if (!is.null(eval_time)) {
eval_time <- as.numeric(eval_time)
}
# will still propagate nulls:
eval_time <- eval_time[!is.na(eval_time)]
eval_time <- unique(eval_time)
Expand All @@ -32,3 +35,199 @@ trunc_probs <- function(probs, trunc = 0.01) {
}
eval_time
}

# Attach a `.row` index column to `dat`.
#
# `dat`:  a data frame (or tibble) of weights.
# `rows`: optional index vector; when supplied it must have one entry per row
#         of `dat` and is stored as-is in `dat$.row`. When `NULL`, the index is
#         delegated to `add_rowindex()`.
#
# Returns `dat` with the index attached; aborts if `rows` has the wrong length.
add_dot_row_to_weights <- function(dat, rows = NULL) {
  # No index supplied: let add_rowindex() create one.
  if (is.null(rows)) {
    return(add_rowindex(dat))
  }

  # `m` and `n` are referenced by name inside the glue template below.
  m <- length(rows)
  n <- nrow(dat)
  if (m != n) {
    msg <- glue::glue(
      "The length of 'rows' ({m}) should be equal to the number of rows in 'data' ({n})"
    )
    rlang::abort(msg)
  }

  dat$.row <- rows
  dat
}

# Verify that a fitted object carries a `censor_probs` element.
#
# `x`: a named list-like object (here, a fitted model).
#
# Aborts with a "please refit" message when no element is named
# `censor_probs`; otherwise returns `NULL` invisibly.
.check_censor_model <- function(x) {
  lacks_censor_model <- all(names(x) != "censor_probs")
  if (lacks_censor_model) {
    rlang::abort("Please refit the model with parsnip version 1.0.4 or greater.")
  }
  invisible(NULL)
}

# nocov start
# these are tested in extratests
# ------------------------------------------------------------------------------
# Brier score helpers. Most of this is based off of Graf, E., Schmoor, C.,
# Sauerbrei, W. and Schumacher, M. (1999), Assessment and comparison of
# prognostic classification schemes for survival data. _Statist. Med._, 18:
# 2529-2545.

# We need to use the time of analysis to determine what time to use to evaluate
# the IPCWs.

# Determine, for every observation, the time at which the censoring
# probability should be evaluated for a single evaluation time (Graf 1999).
#
# `surv_obj`:  the outcome object handed to `.extract_surv_time()` and
#              `.extract_surv_status()` (presumably a `survival::Surv` object;
#              confirm against callers).
# `eval_time`: the evaluation time; the caller in this file supplies one value
#              per call.
# `rows`:      optional index passed on to `add_dot_row_to_weights()`.
# `eps`:       small amount subtracted from the chosen time so that the
#              probability is evaluated just before that time.
#
# Returns a tibble with columns `surv`, `weight_time` and `eval_time`, plus the
# row index added by `add_dot_row_to_weights()`. `weight_time` is always a
# double vector and is `NA` for observations that make no contribution.
graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
  event_time <- .extract_surv_time(surv_obj)
  status <- .extract_surv_status(surv_obj)

  is_event_before_t <- event_time <= eval_time & status == 1
  # Observed time (event or censoring) is after the evaluation time.
  is_observed_after_t <- event_time > eval_time

  # Three possible contributions to the statistic from Graf 1999

  # Censoring time before eval_time, no contribution (Graf category 3)
  weight_time <- rep(NA_real_, length(event_time))

  # A real event prior to eval_time (Graf category 1)
  weight_time[is_event_before_t] <- event_time[is_event_before_t] - eps

  # Observed time greater than eval_time (Graf category 2)
  weight_time[is_observed_after_t] <- eval_time - eps

  # Clip negative times (e.g. a time of zero minus eps) to zero. pmax() is used
  # rather than ifelse() because ifelse() takes its result type from the test:
  # when every value is NA (or the input is empty) it would return a logical
  # vector, whereas pmax() always keeps the column double. NA values stay NA.
  weight_time <- pmax(weight_time, 0)

  res <- tibble::tibble(
    surv = surv_obj,
    weight_time = weight_time,
    eval_time = eval_time
  )
  add_dot_row_to_weights(res, rows)
}

# ------------------------------------------------------------------------------
#' Calculations for inverse probability of censoring weights (IPCW)
#'
#' The method of Graf _et al_ (1999) is used to compute weights at specific
#' evaluation times that can be used to help measure a model's time-dependent
#' performance (e.g. the time-dependent Brier score or the area under the ROC
#' curve).
#' @param data A data frame with a column containing a [survival::Surv()] object.
#' @param predictors Not currently used. A potential future slot for models with
#' informative censoring based on columns in `data`.
#' @param rows An optional integer vector with length equal to the number of
#' rows in `data` that is used to index the original data. The default is to
#' use a fresh index on data (i.e. `1:nrow(data)`).
#' @param eval_time A vector of finite, non-negative times at which to
#' compute the probability of censoring and the corresponding weights.
#' @param object A fitted parsnip model object or fitted workflow with a mode
#' of "censored regression".
#' @param trunc A potential lower bound for the probability of censoring to avoid
#' very large weight values.
#' @param eps A small value that is subtracted from the evaluation time when
#' computing the censoring probabilities. See Details below.
#' @param ... Not currently used. The `model_fit` method calls
#' `rlang::check_dots_empty()` on these arguments.
#' @return A tibble with columns `.row`, `eval_time`, `.prob_cens` (the
#' probability of being censored just prior to the evaluation time), and
#' `.weight_cens` (the inverse probability of censoring weight).
#' @details
#'
#' A probability that the data are censored immediately prior to a specific
#' time is computed. To do this, we must determine the time at which to
#' make the prediction. There are two time values for each row of the data set:
#' the observed time (either censored or not) and the time that the model is
#' being evaluated at (e.g. the survival function prediction at some time point),
#' which is constant across rows.
#'
#' From Graf _et al_ (1999) there are three cases:
#'
#' - If the observed time is a censoring time and that is before the
#' evaluation time, the data point should make no contribution to the
#' performance metric (their "category 3"). These values have a missing
#' value for their probability estimate (and also for their weight column).
#'
#' - If the observed time corresponds to an actual event, and that time is
#' prior to the evaluation time (category 1), the probability of being
#' censored is predicted at the observed time (minus an epsilon).
#'
#' - If the observed time is _after_ the evaluation time (category 2), regardless of
#' the status, the probability of being censored is predicted at the evaluation
#' time (minus an epsilon).
#'
#' The epsilon is used since we would not have actual information at time `t`
#' for a data point being predicted at time `t` (only data prior to time `t`
#' should be available).
#'
#' After the censoring probability is computed, the `trunc` option is used to
#' avoid using numbers pathologically close to zero. After this, the weight is
#' computed by inverting the censoring probability.
#'
#' The `eps` argument is used to avoid information leakage when computing the
#' censoring probability. Subtracting a small number avoids using data that
#' would not be known at the time of prediction. For example, if we are making
#' survival probability predictions at `eval_time = 3.0`, we would not know
#' about the probability of being censored at that exact time (since it has not
#' occurred yet).
#'
#' Note that if there are `n` rows in `data` and `t` time points, the resulting
#' data has `n * t` rows. Computations will not scale well as `t` becomes
#' large.
#' @references Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999),
#' Assessment and comparison of prognostic classification schemes for survival
#' data. _Statist. Med._, 18: 2529-2545.
#' @export
#' @name censoring_weights
#' @keywords internal
.censoring_weights_graf <- function(object, ...) {
UseMethod(".censoring_weights_graf")
}

#' @export
#' @rdname censoring_weights
.censoring_weights_graf.default <- function(object, ...) {
  # Fallback method: always aborts, listing the quoted class(es) of `object`.
  quoted_classes <- toString(paste0("'", class(object), "'"))
  rlang::abort(
    paste(
      "There is no `.censoring_weights_graf()` method for objects with class(es):",
      quoted_classes
    )
  )
}


#' @export
#' @rdname censoring_weights
.censoring_weights_graf.workflow <- function(object,
                                             data,
                                             eval_time,
                                             rows = NULL,
                                             predictors = NULL,
                                             trunc = 0.05, eps = 10^-10, ...) {
  # The workflow keeps its fitted model in `object$fit$fit`; without it there
  # is nothing to compute weights from.
  inner_fit <- object$fit$fit
  if (is.null(inner_fit)) {
    rlang::abort("The workflow does not have a model fit object.", call = FALSE)
  }

  # Re-dispatch the generic on the inner fit. Arguments captured in `...` are
  # not passed on.
  .censoring_weights_graf(
    inner_fit,
    data = data,
    eval_time = eval_time,
    rows = rows,
    predictors = predictors,
    trunc = trunc,
    eps = eps
  )
}

#' @export
#' @rdname censoring_weights
.censoring_weights_graf.model_fit <- function(object,
                                              data,
                                              eval_time,
                                              rows = NULL,
                                              predictors = NULL,
                                              trunc = 0.05, eps = 10^-10, ...) {
  # --- argument checks --------------------------------------------------------
  rlang::check_dots_empty()
  .check_censor_model(object)
  if (!is.null(predictors)) {
    rlang::warn("The 'predictors' argument to the survival weighting function is not currently used.", call = FALSE)
  }
  eval_time <- .filter_eval_time(eval_time)

  # --- locate the outcome column ----------------------------------------------
  outcome_name <- object$preproc$y_var
  if (length(outcome_name) != 1) {
    # check_outcome() tests that the outcome column is a Surv object
    rlang::abort("The event time data should be in a single column with class 'Surv'", call = FALSE)
  }
  surv_data <- dplyr::select(data, dplyr::all_of(!!outcome_name))
  surv_data <- setNames(surv_data, "surv")
  .check_censored_right(surv_data$surv)

  # --- one block of weight times per evaluation time, stacked row-wise ---------
  per_time <- purrr::map(
    eval_time,
    function(time_point) {
      graf_weight_time(surv_data$surv, time_point, eps = eps, rows = rows)
    }
  )
  stacked <- purrr::list_rbind(per_time)

  # --- predict censoring probabilities, truncate them, invert to weights ------
  weighted <- dplyr::mutate(
    stacked,
    .prob_cens = predict(object$censor_probs, time = weight_time, as_vector = TRUE),
    .prob_cens = trunc_probs(.prob_cens, trunc),
    .weight_cens = 1 / .prob_cens
  )
  dplyr::select(weighted, .row, eval_time, .prob_cens, .weight_cens)
}

# nocov end
2 changes: 1 addition & 1 deletion R/parsnip-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ utils::globalVariables(
"compute_intercept", "remove_intercept", "estimate", "term",
"call_info", "component", "component_id", "func", "tunable", "label",
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts",
"protect", "s"
"protect", "weight_time", ".prob_cens", ".weight_cens", "s"
)
)

Expand Down
116 changes: 116 additions & 0 deletions man/censoring_weights.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions tests/testthat/helper-objects.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,16 @@ caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE)
# Control object with verbosity set to 0 and `catch = TRUE`.
quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE)

# `compareVersion(a, b)` is positive when `a` is the later version, so this flag
# is TRUE only when the running R version is older than 3.6.0.
# NOTE(review): the name suggests it gates glmnet tests -- confirm that this
# direction of the comparison is the intended one.
run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0

# ------------------------------------------------------------------------------
# for skips

# Report whether a TensorFlow version can be queried (used for test skips).
#
# Returns `FALSE` when `tensorflow::tf_version()` raises an error or returns
# `NULL`, and `TRUE` otherwise.
is_tf_ok <- function() {
  tf_ver <- try(tensorflow::tf_version(), silent = TRUE)
  !inherits(tf_ver, "try-error") && !is.null(tf_ver)
}
Loading