From 0ed914ae0fb0b58a18907ab9c9cd1a92f4d8374f Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 3 Mar 2021 14:02:08 -0500 Subject: [PATCH] enable probability estimates for flexsurv for #10 --- NAMESPACE | 1 + R/surv_reg.R | 30 +++++++++----- R/surv_reg_data.R | 40 ++++++++++++++++++- ...eg_survival_probs.Rd => flexsurv_probs.Rd} | 11 +++-- 4 files changed, 67 insertions(+), 15 deletions(-) rename man/{survreg_survival_probs.Rd => flexsurv_probs.Rd} (60%) diff --git a/NAMESPACE b/NAMESPACE index 676cf40b..4faf80af 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -9,6 +9,7 @@ S3method(update,cox_reg) S3method(update,survival_reg) export(blackboost_train) export(cox_reg) +export(flexsurv_probs) export(glmnet_fit_wrapper) export(survival_reg) export(survreg_hazard_probs) diff --git a/R/surv_reg.R b/R/surv_reg.R index a3b48cb2..bfddf3d9 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -201,6 +201,23 @@ flexsurv_quant <- function(results, object) { results <- purrr::map(results, setNames, c(".quantile", ".pred", ".pred_lower", ".pred_upper")) } +#' Internal function helps for parametric survival models +#' @param object A `survreg` or `flexsurvreg` object. +#' @param new_data A data frame. +#' @param .time A vector of time points +#' @return A nested tibble with column name `.pred` +#' @keywords internal +#' @export +flexsurv_probs <- function(object, new_data, .time, type = "survival") { + type <- rlang::arg_match(type, c("survival", "hazard")) + res <- summary(object, newdata = new_data, type = type, t = .time) + res <- unname(res) + col_name <- rlang::sym(paste0(".pred_", type)) + res <- purrr::map(res, ~ dplyr::select(.x, time, est)) + res <- purrr::map(res, ~ setNames(.x, c(".time", col_name))) + tibble::tibble(.pred = res) +} + # ------------------------------------------------------------------------------ # helpers for survreg prediction @@ -208,17 +225,12 @@ survreg_survival <- function(location, object, time, scale = object$scale, .time distr <- object$dist tibble::tibble( .time = .time, - .prob_survival = 1 - survival::psurvreg(.time, location, distribution = distr, scale, ...) + .pred_survival = 1 - survival::psurvreg(.time, location, distribution = distr, scale, ...) ) } -#' Internal function helps for parameteric survival models -#' @param object A `survreg` object. -#' @param new_data A data frame. -#' @param .time A vector of time points -#' @return A nested tibble with column name `.pred` -#' @keywords internal #' @export +#' @rdname flexsurv_probs survreg_survival_probs <- function(object, new_data, .time) { lp_pred <- predict(object, new_data, type = "lp") res <- purrr::map(lp_pred, survreg_survival, object = object, .time = .time) @@ -232,12 +244,12 @@ survreg_hazard <- function(location, object, scale = object$scale, .time, ...) { (1 - survival::psurvreg(.time, location, distribution = distr, scale, ...)) tibble::tibble( .time = .time, - .prob_hazard = prob + .pred_hazard = prob ) } #' @export -#' @rdname survreg_survival_probs +#' @rdname flexsurv_probs survreg_hazard_probs <- function(object, new_data, .time) { lp_pred <- predict(object, new_data, type = "lp") res <- purrr::map(lp_pred, survreg_hazard, object = object, .time = .time) diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 3b559473..9d0146aa 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -55,7 +55,7 @@ make_surv_reg_survival <- function() { model = "survival_reg", eng = "survival", mode = "censored regression", - type = "numeric", + type = "time", value = list( pre = NULL, post = NULL, @@ -167,7 +167,7 @@ make_surv_reg_flexsurv <- function() { model = "survival_reg", eng = "flexsurv", mode = "censored regression", - type = "numeric", + type = "time", value = list( pre = NULL, post = flexsurv_mean, @@ -199,6 +199,42 @@ make_surv_reg_flexsurv <- function() { ) ) ) + + parsnip::set_pred( + model = "survival_reg", + eng = "flexsurv", + mode = "censored regression", + type = "hazard", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "censored", fun = "flexsurv_probs"), + args = + list( + object = expr(object$fit), + new_data = expr(new_data), + type = "hazard" + ) + ) + ) + + parsnip::set_pred( + model = "survival_reg", + eng = "flexsurv", + mode = "censored regression", + type = "survival", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "censored", fun = "flexsurv_probs"), + args = + list( + object = expr(object$fit), + new_data = expr(new_data), + type = "survival" + ) + ) + ) } # nocov end diff --git a/man/survreg_survival_probs.Rd b/man/flexsurv_probs.Rd similarity index 60% rename from man/survreg_survival_probs.Rd rename to man/flexsurv_probs.Rd index 18d53376..ce37678a 100644 --- a/man/survreg_survival_probs.Rd +++ b/man/flexsurv_probs.Rd @@ -1,16 +1,19 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/surv_reg.R -\name{survreg_survival_probs} +\name{flexsurv_probs} +\alias{flexsurv_probs} \alias{survreg_survival_probs} \alias{survreg_hazard_probs} -\title{Internal function helps for parameteric survival models} +\title{Internal function helps for parametric survival models} \usage{ +flexsurv_probs(object, new_data, .time, type = "survival") + survreg_survival_probs(object, new_data, .time) survreg_hazard_probs(object, new_data, .time) } \arguments{ -\item{object}{A \code{survreg} object.} +\item{object}{A \code{survreg} or \code{flexsurvreg} object.} \item{new_data}{A data frame.} @@ -20,6 +23,6 @@ survreg_hazard_probs(object, new_data, .time) A nested tibble with column name \code{.pred} } \description{ -Internal function helps for parameteric survival models +Internal function helps for parametric survival models } \keyword{internal}