Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for flexsurvspline engine for survival_reg model spec #831

Merged
merged 13 commits into from
Nov 3, 2022
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ S3method(translate,survival_reg)
S3method(translate,svm_linear)
S3method(translate,svm_poly)
S3method(translate,svm_rbf)
S3method(tunable,survival_reg)
S3method(type_sum,model_fit)
S3method(type_sum,model_spec)
S3method(update,C5_rules)
Expand Down Expand Up @@ -316,6 +317,7 @@ importFrom(generics,fit_xy)
importFrom(generics,glance)
importFrom(generics,required_pkgs)
importFrom(generics,tidy)
importFrom(generics,tunable)
importFrom(generics,varying_args)
importFrom(ggplot2,autoplot)
importFrom(glue,glue_collapse)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# parsnip (development version)

* Adds documentation and tuning infrastructure for the new `flexsurvspline` engine for the `survival_reg()` model specification from the `censored` package (@mattwarkentin, #831).

* The matrix interface for fitting `fit_xy()` now works for the `"censored regression"` mode (#829).

* The `num_leaves` argument of `boost_tree()`s `lightgbm` engine (via the bonsai package) is now tunable.
Expand Down
2 changes: 1 addition & 1 deletion R/parsnip-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
## usethis namespace: start
#' @importFrom dplyr arrange bind_cols bind_rows collect full_join group_by
#' @importFrom dplyr mutate pull rename select starts_with summarise tally
#' @importFrom generics varying_args
#' @importFrom generics tunable varying_args
#' @importFrom glue glue_collapse
#' @importFrom pillar type_sum
#' @importFrom purrr as_vector imap imap_lgl map map_chr map_dbl map_df map_dfr
Expand Down
11 changes: 11 additions & 0 deletions R/survival_reg_flexsurvspline.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#' Flexible parametric survival regression
#'
#' [flexsurv::flexsurvspline()] fits a flexible parametric survival model.
#'
#' @includeRmd man/rmd/survival_reg_flexsurvspline.md details
#'
#' @name details_survival_reg_flexsurvspline
#' @keywords internal
NULL

# See inst/README-DOCS.md for a description of how these files are processed
20 changes: 20 additions & 0 deletions R/tunable.R
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,17 @@ brulee_multinomial_engine_args <-
brulee_mlp_engine_args %>%
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))

flexsurvspline_engine_args <-
tibble::tibble(
name = c("k"),
mattwarkentin marked this conversation as resolved.
Show resolved Hide resolved
call_info = list(
list(pkg = "dials", fun = "num_knots")
),
source = "model_spec",
component = "survival_reg",
component_id = "engine"
)

# ------------------------------------------------------------------------------

# Lazily registered in .onLoad()
Expand Down Expand Up @@ -324,5 +335,14 @@ tunable_mlp <- function(x, ...) {
res
}

#' @export
tunable.survival_reg <- function(x, ...) {
res <- NextMethod()
if (x$engine == "flexsurvspline") {
res <- add_engine_parameters(res, flexsurvspline_engine_args)
}
res
}

# nocov end

1 change: 1 addition & 0 deletions inst/models.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
"surv_reg" "regression" "flexsurv" NA
"surv_reg" "regression" "survival" NA
"survival_reg" "censored regression" "flexsurv" "censored"
"survival_reg" "censored regression" "flexsurvspline" "censored"
"survival_reg" "censored regression" "survival" "censored"
"svm_linear" "classification" "kernlab" NA
"svm_linear" "classification" "LiblineaR" NA
Expand Down
2 changes: 1 addition & 1 deletion man/details_auto_ml_h2o.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions man/details_survival_reg_flexsurv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 80 additions & 0 deletions man/details_survival_reg_flexsurvspline.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions man/rmd/survival_reg_flexsurv.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ For this engine, stratification cannot be specified via [`strata()`], please see
```{r child = "template-survival-mean.Rmd"}
```

## Case weights

```{r child = "template-uses-case-weights.Rmd"}
```

## Saving fitted model objects

```{r child = "template-butcher.Rmd"}
Expand Down
7 changes: 7 additions & 0 deletions man/rmd/survival_reg_flexsurv.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ For this engine, stratification cannot be specified via [`strata()`], please see

Predictions of type `"time"` are predictions of the mean survival time.

## Case weights


This model can utilize case weights during model fitting. To use them, see the documentation in [case_weights] and the examples on `tidymodels.org`.

The `fit()` and `fit_xy()` arguments have arguments called `case_weights` that expect vectors of case weights.

## Saving fitted model objects


Expand Down
48 changes: 48 additions & 0 deletions man/rmd/survival_reg_flexsurvspline.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
```{r, child = "aaa.Rmd", include = FALSE}
```

`r descr_models("survival_reg", "flexsurvspline")`

## Tuning Parameters

This model has one engine-specific tuning parameter:

* `k`: Number of knots in the spline. The default is `k = 0`.

## Translation from parsnip to the original package

`r uses_extension("survival_reg", "flexsurvspline", "censored regression")`

```{r flexsurvspline-creg}
library(censored)

survival_reg() %>%
set_engine("flexsurvspline") %>%
set_mode("censored regression") %>%
translate()
```

## Other details

The main interface for this model uses the formula method since the model specification typically involved the use of [survival::Surv()].

For this engine, stratification cannot be specified via [`strata()`], please see [flexsurv::flexsurvspline()] for alternative specifications.

```{r child = "template-survival-mean.Rmd"}
hfrick marked this conversation as resolved.
Show resolved Hide resolved
```

## Case weights

```{r child = "template-uses-case-weights.Rmd"}
```


## Saving fitted model objects

```{r child = "template-butcher.Rmd"}
```


## References

- Jackson, C. 2016. `flexsurv`: A Platform for Parametric Survival Modeling in R. _Journal of Statistical Software_, 70(8), 1 - 33.
62 changes: 62 additions & 0 deletions man/rmd/survival_reg_flexsurvspline.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@



For this engine, there is a single mode: censored regression

## Tuning Parameters

This model has one engine-specific tuning parameter:

* `k`: Number of knots in the spline. The default is `k = 0`.

## Translation from parsnip to the original package

The **censored** extension package is required to fit this model.


```r
library(censored)

survival_reg() %>%
set_engine("flexsurvspline") %>%
set_mode("censored regression") %>%
translate()
```

```
## Parametric Survival Regression Model Specification (censored regression)
##
## Computational engine: flexsurvspline
##
## Model fit template:
## flexsurv::flexsurvspline(formula = missing_arg(), data = missing_arg(),
## weights = missing_arg())
```

## Other details

The main interface for this model uses the formula method since the model specification typically involved the use of [survival::Surv()].

For this engine, stratification cannot be specified via [`strata()`], please see [flexsurv::flexsurvspline()] for alternative specifications.



Predictions of type `"time"` are predictions of the mean survival time.

## Case weights


This model can utilize case weights during model fitting. To use them, see the documentation in [case_weights] and the examples on `tidymodels.org`.

The `fit()` and `fit_xy()` arguments have arguments called `case_weights` that expect vectors of case weights.


## Saving fitted model objects


This model object contains data that are not required to make predictions. When saving the model for the purpose of prediction, the size of the saved object might be substantially reduced by using functions from the [butcher](https://butcher.tidymodels.org) package.


## References

- Jackson, C. 2016. `flexsurv`: A Platform for Parametric Survival Modeling in R. _Journal of Statistical Software_, 70(8), 1 - 33.