Skip to content

Commit

Permalink
enable probability estimates for flexsurv for #10
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Mar 3, 2021
1 parent 0167f89 commit 0ed914a
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 15 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ S3method(update,cox_reg)
S3method(update,survival_reg)
export(blackboost_train)
export(cox_reg)
export(flexsurv_probs)
export(glmnet_fit_wrapper)
export(survival_reg)
export(survreg_hazard_probs)
Expand Down
30 changes: 21 additions & 9 deletions R/surv_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,24 +201,36 @@ flexsurv_quant <- function(results, object) {
results <- purrr::map(results, setNames, c(".quantile", ".pred", ".pred_lower", ".pred_upper"))
}

#' Internal function helps for parametric survival models
#' @param object A `survreg` or `flexsurvreg` object.
#' @param new_data A data frame.
#' @param .time A vector of time points
#' @return A nested tibble with column name `.pred`
#' @keywords internal
#' @export
flexsurv_probs <- function(object, new_data, .time, type = "survival") {
type <- rlang::arg_match(type, c("survival", "hazard"))
res <- summary(object, newdata = new_data, type = type, t = .time)
res <- unname(res)
col_name <- rlang::sym(paste0(".pred_", type))
res <- purrr::map(res, ~ dplyr::select(.x, time, est))
res <- purrr::map(res, ~ setNames(.x, c(".time", col_name)))
tibble::tibble(.pred = res)
}

# ------------------------------------------------------------------------------
# helpers for survreg prediction

survreg_survival <- function(location, object, time, scale = object$scale, .time, ...) {
distr <- object$dist
tibble::tibble(
.time = .time,
.prob_survival = 1 - survival::psurvreg(.time, location, distribution = distr, scale, ...)
.pred_survival = 1 - survival::psurvreg(.time, location, distribution = distr, scale, ...)
)
}

#' Internal function helps for parameteric survival models
#' @param object A `survreg` object.
#' @param new_data A data frame.
#' @param .time A vector of time points
#' @return A nested tibble with column name `.pred`
#' @keywords internal
#' @export
#' @rdname flexsurv_probs
survreg_survival_probs <- function(object, new_data, .time) {
lp_pred <- predict(object, new_data, type = "lp")
res <- purrr::map(lp_pred, survreg_survival, object = object, .time = .time)
Expand All @@ -232,12 +244,12 @@ survreg_hazard <- function(location, object, scale = object$scale, .time, ...) {
(1 - survival::psurvreg(.time, location, distribution = distr, scale, ...))
tibble::tibble(
.time = .time,
.prob_hazard = prob
.pred_hazard = prob
)
}

#' @export
#' @rdname survreg_survival_probs
#' @rdname flexsurv_probs
survreg_hazard_probs <- function(object, new_data, .time) {
lp_pred <- predict(object, new_data, type = "lp")
res <- purrr::map(lp_pred, survreg_hazard, object = object, .time = .time)
Expand Down
40 changes: 38 additions & 2 deletions R/surv_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ make_surv_reg_survival <- function() {
model = "survival_reg",
eng = "survival",
mode = "censored regression",
type = "numeric",
type = "time",
value = list(
pre = NULL,
post = NULL,
Expand Down Expand Up @@ -167,7 +167,7 @@ make_surv_reg_flexsurv <- function() {
model = "survival_reg",
eng = "flexsurv",
mode = "censored regression",
type = "numeric",
type = "time",
value = list(
pre = NULL,
post = flexsurv_mean,
Expand Down Expand Up @@ -199,6 +199,42 @@ make_surv_reg_flexsurv <- function() {
)
)
)

parsnip::set_pred(
model = "survival_reg",
eng = "flexsurv",
mode = "censored regression",
type = "hazard",
value = list(
pre = NULL,
post = NULL,
func = c(pkg = "censored", fun = "flexsurv_probs"),
args =
list(
object = expr(object$fit),
new_data = expr(new_data),
type = "hazard"
)
)
)

parsnip::set_pred(
model = "survival_reg",
eng = "flexsurv",
mode = "censored regression",
type = "survival",
value = list(
pre = NULL,
post = NULL,
func = c(pkg = "censored", fun = "flexsurv_probs"),
args =
list(
object = expr(object$fit),
new_data = expr(new_data),
type = "survival"
)
)
)
}

# nocov end
11 changes: 7 additions & 4 deletions man/survreg_survival_probs.Rd → man/flexsurv_probs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0ed914a

Please sign in to comment.