Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep IPCW results in the list column format predicted by the predict() methods #937

Merged
merged 19 commits into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.0.4.9005
Version: 1.0.4.9006
Authors@R: c(
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),
Expand Down Expand Up @@ -36,7 +36,7 @@ Imports:
tibble (>= 2.1.1),
tidyr (>= 1.3.0),
utils,
vctrs (>= 0.4.1),
vctrs (>= 0.6.0),
withr
Suggests:
C50,
Expand Down
169 changes: 103 additions & 66 deletions R/ipcw.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,36 +22,41 @@ trunc_probs <- function(probs, trunc = 0.01) {
if (!is.null(eval_time)) {
eval_time <- as.numeric(eval_time)
}
eval_time_0 <- eval_time
# will still propagate nulls:
eval_time <- eval_time[!is.na(eval_time)]
eval_time <- unique(eval_time)
topepo marked this conversation as resolved.
Show resolved Hide resolved
eval_time <- sort(eval_time)
eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)]
eval_time <- unique(eval_time)
if (fail && identical(eval_time, numeric(0))) {
rlang::abort(
"There were no usable evaluation times (finite, non-missing, and >= 0).",
call = NULL
)
}
if (!identical(eval_time, eval_time_0)) {
diffs <- setdiff(eval_time_0, eval_time)
msg <-
cli::pluralize(
"There {?was/were} {length(diffs)} inappropriate evaluation time point{?s} that {?was/were} removed.")
rlang::warn(msg)
}
eval_time
}

add_dot_row_to_weights <- function(dat, rows = NULL) {
if (is.null(rows)) {
dat <- add_rowindex(dat)
} else {
m <- length(rows)
n <- nrow(dat)
if (m != n) {
rlang::abort(
glue::glue(
"The length of 'rows' ({m}) should be equal to the number of rows in 'data' ({n})"
)
)
}
dat$.row <- rows
.check_pred_col <- function(x, call = rlang::env_parent()) {
if (!any(names(x) == ".pred")) {
rlang::abort("The input should have a list column called `.pred`.", call = call)
}
if (!is.list(x$.pred)) {
rlang::abort("The input should have a list column called `.pred`.", call = call)
}
dat
req_cols <- c(".eval_time", ".pred_survival")
if (!all(req_cols %in% names(x$.pred[[1]]))) {
msg <- paste0("The `.pred` tibbles should have columns: ",
paste0("'", req_cols, "'", collapse = ", "))
rlang::abort(msg, call = call)
}
invisible(NULL)
}

.check_censor_model <- function(x) {
Expand All @@ -73,7 +78,7 @@ add_dot_row_to_weights <- function(dat, rows = NULL) {
# We need to use the time of analysis to determine what time to use to evaluate
# the IPCWs.

graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
graf_weight_time_vec <- function(surv_obj, eval_time, eps = 10^-10) {
event_time <- .extract_surv_time(surv_obj)
status <- .extract_surv_status(surv_obj)
is_event_before_t <- event_time <= eval_time & status == 1
Expand All @@ -85,15 +90,14 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
weight_time <- rep(NA_real_, length(event_time))

# A real event prior to eval_time (Graf category 1)
weight_time[is_event_before_t] <- event_time[is_event_before_t] - eps
weight_time <- ifelse(is_event_before_t, event_time - eps, weight_time)

# Observed time greater than eval_time (Graf category 2)
weight_time[is_censored] <- eval_time - eps
weight_time <- ifelse(is_censored, eval_time - eps, weight_time)

weight_time <- ifelse(weight_time < 0, 0, weight_time)

res <- tibble::tibble(surv = surv_obj, weight_time = weight_time, eval_time)
add_dot_row_to_weights(res, rows)
weight_time
}

# ------------------------------------------------------------------------------
Expand All @@ -102,24 +106,28 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
#' The method of Graf _et al_ (1999) is used to compute weights at specific
#' evaluation times that can be used to help measure a model's time-dependent
#' performance (e.g. the time-dependent Brier score or the area under the ROC
#' curve).
#' @param data A data frame with a column containing a [survival::Surv()] object.
#' @param predictors Not currently used. A potential future slot for models with
#' informative censoring based on columns in `data`.
#' @param rows An optional integer vector with length equal to the number of
#' rows in `data` that is used to index the original data. The default is to
#' use a fresh index on data (i.e. `1:nrow(data)`).
#' @param eval_time A vector of finite, non-negative times at which to
#' compute the probability of censoring and the corresponding weights.
#' curve). This is an internal function.
#'
#' @param predictions A data frame with a column containing a [survival::Surv()]
topepo marked this conversation as resolved.
Show resolved Hide resolved
#' object as well as a list column called `.pred` that contains the data
#' structure produced by [predict.model_fit()].
#' @param cens_predictors Not currently used. A potential future slot for models with
#' informative censoring based on columns in `predictions`.
#' @param object A fitted parsnip model object or fitted workflow with a mode
#' of "censored regression".
#' @param trunc A potential lower bound for the probability of censoring to avoid
#' very large weight values.
#' @param eps A small value that is subtracted from the evaluation time when
#' computing the censoring probabilities. See Details below.
#' @return A tibble with columns `.row`, `eval_time`, `.prob_cens` (the
#' probability of being censored just prior to the evaluation time), and
#' `.weight_cens` (the inverse probability of censoring weight).
#' @return The same data are returned with the `pred` tibbles containing
#' several new columns:
#'
#' - `.weight_time`: the time at which the inverse censoring probability weights
topepo marked this conversation as resolved.
Show resolved Hide resolved
#' are computed. This is a function of the observed time and the time of
#' analysis (i.e., `eval_time`). See Details for more information.
#' - `.pred_censored`: the probability of being censored at `.weight_time`.
#' - `.weight_censored`: The inverse of the censoring probability.
#'
#' @details
#'
#' A probability that the data are censored immediately prior to a specific
Expand Down Expand Up @@ -155,13 +163,21 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
#' The `eps` argument is used to avoid information leakage when computing the
#' censoring probability. Subtracting a small number avoids using data that
#' would not be known at the time of prediction. For example, if we are making
#' survival probability predictions at `eval_time = 3.0`, we would not know the
#' survival probability predictions at `eval_time = 3.0`, we would _not_ know the
#' about the probability of being censored at that exact time (since it has not
#' occurred yet).
#'
#' When creating weights by inverting probabilities, there is the risk that a few
#' cases will have severe outliers due to probabilities close to zero. To
#' mitigate this, the `trunc` argument can be used to put a cap on the weights.
#' If the smallest probability is greater than `trunc`, the probabilities with
#' values less than `trunc` are given that value. Otherwise, `trunc` is
#' adjusted to be half of the smallest probability and that value is used as the
#' lower bound..
#'
#' Note that if there are `n` rows in `data` and `t` time points, the resulting
#' data has `n * t` rows. Computations will not easily scale well as `t` becomes
#' large.
#' data, once unnested, has `n * t` rows. Computations will not easily scale
#' well as `t` becomes very large.
#' @references Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999),
#' Assessment and comparison of prognostic classification schemes for survival
#' data. _Statist. Med._, 18: 2529-2545.
Expand All @@ -185,49 +201,70 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
#' @export
#' @rdname censoring_weights
.censoring_weights_graf.workflow <- function(object,
data,
eval_time,
rows = NULL,
predictors = NULL,
predictions,
cens_predictors = NULL,
trunc = 0.05, eps = 10^-10, ...) {
if (is.null(object$fit$fit)) {
rlang::abort("The workflow does not have a model fit object.", call = FALSE)
rlang::abort("The workflow does not have a model fit object.")
}
.censoring_weights_graf(object$fit$fit, data, eval_time, rows, predictors, trunc, eps)
.censoring_weights_graf(object$fit$fit, predictions, cens_predictors, trunc, eps)
}

#' @export
#' @rdname censoring_weights
.censoring_weights_graf.model_fit <- function(object,
data,
eval_time,
rows = NULL,
predictors = NULL,
predictions,
cens_predictors = NULL,
trunc = 0.05, eps = 10^-10, ...) {
rlang::check_dots_empty()
.check_censor_model(object)
if (!is.null(predictors)) {
rlang::warn("The 'predictors' argument to the survival weighting function is not currently used.", call = FALSE)
truth <- .find_surv_col(predictions)
.check_censored_right(predictions[[truth]])
.check_pred_col(predictions)

if (!is.null(cens_predictors)) {
msg <- "The 'cens_predictors' argument to the survival weighting function is not currently used."
rlang::warn(msg)
}
eval_time <- .filter_eval_time(eval_time)
topepo marked this conversation as resolved.
Show resolved Hide resolved
predictions$.pred <-
add_graf_weights_vec(object,
predictions$.pred,
predictions[[truth]],
trunc = trunc,
eps = eps)
predictions
}

# ------------------------------------------------------------------------------
# Helpers

add_graf_weights_vec <- function(object, .pred, surv_obj, trunc = 0.05, eps = 10^-10) {
# Expand the list column to one data frame
n <- length(.pred)
num_times <- vctrs::list_sizes(.pred)
y <- vctrs::list_unchop(.pred)
y$surv_obj <- vctrs::vec_rep_each(surv_obj, times = num_times)
names(y)[names(y) == ".time"] <- ".eval_time" # Temporary
# Compute the actual time of evaluation
y$.weight_time <- graf_weight_time_vec(y$surv_obj, y$.eval_time, eps = eps)
# Compute the corresponding probability of being censored
y$.pred_censored <- predict(object$censor_probs, time = y$.weight_time, as_vector = TRUE)
y$.pred_censored <- trunc_probs(y$.pred_censored, trunc = trunc)
# Invert the probabilities to create weights
y$.weight_censored = 1 / y$.pred_censored
# Convert back the list column format
y$surv_obj <- NULL
vctrs::vec_chop(y, sizes = num_times)
}

truth <- object$preproc$y_var
if (length(truth) != 1) {
# check_outcome() tests that the outcome column is a Surv object
rlang::abort("The event time data should be in a single column with class 'Surv'", call = FALSE)
.find_surv_col <- function(x, call = rlang::env_parent()) {
is_lst_col <- purrr::map_lgl(x, purrr::is_list)
is_surv <- purrr::map_lgl(x[!is_lst_col], .is_surv, fail = FALSE)
num_surv <- sum(is_surv)
if (num_surv != 1) {
rlang::abort("There should be a single column of class `Surv`", call = call)
}
surv_data <- dplyr::select(data, dplyr::all_of(!!truth)) %>% setNames("surv")
.check_censored_right(surv_data$surv)

purrr::map(eval_time,
~ graf_weight_time(surv_data$surv, .x, eps = eps, rows = rows)) %>%
purrr::list_rbind() %>%
dplyr::mutate(
.prob_cens = predict(object$censor_probs, time = weight_time, as_vector = TRUE),
.prob_cens = trunc_probs(.prob_cens, trunc),
.weight_cens = 1 / .prob_cens
) %>%
dplyr::select(.row, eval_time, .prob_cens, .weight_cens)
names(is_surv)[is_surv]
}

# nocov end
10 changes: 7 additions & 3 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@
#' ## Censored regression predictions
#'
#' For censored regression, a numeric vector for `eval_time` is required when
#' survival or hazard probabilities are requested. Also, when
#' `type = "linear_pred"`, censored regression models will by default be
#' formatted such that the linear predictor _increases_ with time. This may
#' survival or hazard probabilities are requested. The time values are required
#' to be unique, finite, non-missing, and non-negative. The `predict()`
#' functions will adjust the values to fit this specification by removing
#' offending points (with a warning).
#'
#' Also, when `type = "linear_pred"`, censored regression models will by default
#' be formatted such that the linear predictor _increases_ with time. This may
#' have the opposite sign as what the underlying model's `predict()` method
#' produces. Set `increasing = FALSE` to suppress this behavior.
#'
Expand Down
1 change: 1 addition & 0 deletions R/predict_hazard.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ predict_hazard.model_fit <- function(object,
)
eval_time <- time
}
eval_time <- .filter_eval_time(eval_time)

check_spec_pred_type(object, "hazard")

Expand Down
1 change: 1 addition & 0 deletions R/predict_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ predict_survival.model_fit <- function(object,
)
eval_time <- time
}
eval_time <- .filter_eval_time(eval_time)

check_spec_pred_type(object, "survival")

Expand Down
Loading