#8 [2/4]: Model fit helpers #18

Merged
merged 41 commits into from
Aug 29, 2024
Changes from 4 commits
Commits
99272a3
Add `k` argument with documentation
zsusswein Jun 8, 2024
ec173d6
Default basis dimension selector and documentation
zsusswein Jun 10, 2024
80a451d
Implement and test setting k
zsusswein Jun 11, 2024
0b5cd8c
Model formula helper and documentation
zsusswein Jun 12, 2024
59fac3f
Grammar tweaks to `dimensionality_heuristic()`
zsusswein Jun 22, 2024
feda446
Remove `k` documentation from `RtGam()`
zsusswein Jun 22, 2024
dafc329
Better explanation of why use piecewise for `k`
zsusswein Jun 22, 2024
8957487
Wording tweaks and line-length formatting
zsusswein Jun 22, 2024
4a0e5db
dimensionality_heuristic -> smooth_dim_heuristic
zsusswein Jun 22, 2024
c6b76cb
Typo
zsusswein Jun 22, 2024
48da2d6
Refactor penalty basis selector into public func
zsusswein Jun 24, 2024
8783d78
Drop language linking groups to geography
zsusswein Jun 24, 2024
d4010f6
pre-commit
github-actions[bot] Jun 24, 2024
661483a
Small change to trigger GHA
zsusswein Jun 24, 2024
1c10e44
pre-commit
github-actions[bot] Jun 24, 2024
fda5165
Fix WARNING from usage
zsusswein Jun 24, 2024
0e60e5e
Drop obsolete comment
zsusswein Jul 2, 2024
cf29e5c
Implement model fitting with `{mgcv}`
zsusswein Jun 30, 2024
d87f361
Suppress public docs of internal function
zsusswein Jun 30, 2024
afa8102
Add checks and warnings for unwise inputs
zsusswein Jul 2, 2024
24d310c
Refactor to S3 methods for fitting backends
zsusswein Jul 6, 2024
6d841d1
Add doc for missing param
zsusswein Jul 6, 2024
046d055
Explicitly namespace `modifyList()`
zsusswein Jul 6, 2024
19ae8f0
Clarify documentation
zsusswein Jul 6, 2024
55ded03
Test warnings throw for suboptimal params
zsusswein Jul 6, 2024
bb0c13d
Default args in S3 methods w/ user-supplied in ...
zsusswein Jul 10, 2024
318fa9a
Move backend check from input val to S3 dispatch
zsusswein Jul 10, 2024
5b1085d
Whitespace
zsusswein Aug 26, 2024
3962347
Move do.call() outside of fit_model()
zsusswein Aug 26, 2024
8591b8a
Dynamically find methods for `fit_model()`
zsusswein Aug 27, 2024
93f4402
Minimal working print method + RtGam() return
zsusswein Aug 27, 2024
e9f1911
Add some basic diagnostic checks
zsusswein Aug 27, 2024
b19258f
Clean up existing tests
zsusswein Aug 28, 2024
6d2af44
Tests for print and diagnostics
zsusswein Aug 28, 2024
25d452e
Document `check_diagnostics()`
zsusswein Aug 28, 2024
b965903
Update R/RtGam.R
zsusswein Aug 29, 2024
84b8d92
pre-commit
github-actions[bot] Aug 29, 2024
5a12ee8
Move `{withr}` to suggests
zsusswein Aug 29, 2024
0c849f6
DRY diagnostic functionality
zsusswein Aug 29, 2024
6bfb09d
`pkg::func()` -> `func()` in `@examples`
zsusswein Aug 29, 2024
eed7b48
Rename `format_for_return()` -> `new_RtGam()`
zsusswein Aug 29, 2024
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -31,4 +31,6 @@ Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Imports:
cli,
glue,
mgcv,
rlang
1 change: 1 addition & 0 deletions NAMESPACE
@@ -1,4 +1,5 @@
# Generated by roxygen2: do not edit by hand

export(RtGam)
export(dimensionality_heuristic)
importFrom(rlang,abort)
164 changes: 160 additions & 4 deletions R/RtGam.R
@@ -1,23 +1,90 @@
#' Stub: Will become fitting/class constructor
#' Fit a generalized additive model to incident cases
#'
#' # Model specification
#' Incident cases are modeled as a smooth function of time with generalized
#' additive models (GAMs). [RtGam] always fits a GAM, predicting incident
#' cases as a smooth trend of time. However, the model adapts the penalty on
#' wiggliness over time, allowing for changing epidemic dynamics.
#'
#' If three weeks or more of data are available, [RtGam] will fit a GAM with
#' an adaptive spline basis. This basis is so named because it allows the
#' wiggliness penalization to vary over time. Some parts of the fit can be more
#' or less wiggly than other parts. If one part of the timeseries has a sudden
#' change in trend while another part shows a smooth increase, the model can
#' fit both components without smoothing away sharp changes or introducing
#' additional artificial wiggliness.
#'
#' The model introduces an additional penalty basis dimension for each
#' additional 21 days of observed data. A timeseries of 20 or fewer days would
#' have the same penalty over the whole period, a timeseries of 21 to 41 days would
#' smoothly interpolate between two penalties, and so on for each additional
#' 21 day period. This adaptive penalty increases the computational cost of the
#' model, but allows for a single model to adapt to changing epidemic dynamics.
#'
#' In the special case of 20 or fewer observed days, the model will use a single
#' penalty over the whole period and use a thin-plate spline as the smoothing
#' basis. The adaptive spline can only use a P-spline smoothing basis. The thin
#' plate spline generally has better performance and so [RtGam] uses it in this
#' special single-penalty case.
#'
#' # Setting k
#' The argument `k` governs the _total_ basis dimension for the penalized
#' regression spline model used by `RtGam`. The model is composed of one or
#' more smooth predictors, depending on the specifics of the model specification.
#' Each smooth predictor is penalized and has its own basis
#' dimension. The basis dimension controls the maximum degrees of
#' freedom (and by proxy the "wiggliness") of the smooth. [RtGam]'s `k`
#' argument controls the total degrees of freedom available to the different
#' smooth predictors. In a simple model with only one smooth predictor, all the
#' degrees of freedom from `k` would be applied to that single smooth. In a
#' more complex model composed of multiple smooth predictors, the total
#' degrees of freedom made available by `k` would be partitioned between the
#' different smooths.
#'
#' In practice, GAMs penalize the wiggliness of smooth terms, so the fitted
#' model will use fewer effective degrees of freedom than the total available.
#' Although usually harmless to the model fit, excess degrees of freedom can
#' make models slower to fit. However, models with `k` set too low may
#' produce biased estimates or fail to converge.
#'
#' `RtGam` attempts to strike a reasonable balance between these concerns. It
#' uses a rule-of-thumb heuristic to set `k` based on the number of data points
#' provided. GAMs fit through [`mgcv`] (as in [RtGam]) usually fit much more quickly
#' than MCMC-based approaches, so this heuristic leans toward providing a
#' higher `k` than likely needed. This approach is a reasonable first pass, but
#' is not a substitute for the user's expert judgement. If the data exhibit a
#' sharp change in epidemic trend or the initial [RtGam] fit fails to converge,
#' it would be reasonable to fit the model with `k` higher than the default.
#' `k` can be set up to the number of data points, but not higher.
#'
#' Including as a stub function to showcase control flow
#' @param cases A vector of non-negative incident case counts occurring on an
#' associated `reference_date`. Missing values (NAs) are not allowed.
#' @param reference_date The associated date on which the count of incident
#' `cases` occurred. Missing dates are not allowed and dates can only occur
#' once.
#' @param group The geographic grouping for the case/reference-date pair. Not
#' yet implemented and a value other than `NULL` will throw an error.
#' @param k An integer, the _total_ dimension of all the smoothing basis
#' functions. Defaults to `dimensionality_heuristic(length(cases))`, which
#' picks a reasonable estimate based on the number of provided data points.
#' This total dimension is partitioned between the different smooths in the
#' model. In the case of diagnostic issues in a fitted `RtGam` model,
#' increasing the value of `k` above this default and refitting the model is
#' a good first step. See the `Setting k` section
#' and [dimensionality_heuristic()] documentation for more information.
#'
#' @returns Stub function: NULL
#' @seealso [dimensionality_heuristic()] for the default basis dimension and
#' [mgcv::choose.k] for more general guidance on GAMs from `mgcv`
#' @return Stub function: NULL
#' @export
#' @examples
#' cases <- c(1, 2, 3)
#' reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03"))
#' mod <- RtGam::RtGam(cases, reference_date)
RtGam <- function(cases,
reference_date,
group = NULL) {
group = NULL,
k = dimensionality_heuristic(length(cases))) {
check_required_inputs_provided(
cases,
reference_date,
@@ -26,6 +93,95 @@ RtGam <- function(cases,
validate(cases, reference_date, group)

df <- prepare_inputs(cases, reference_date, group)
formula <- formula_creator(
n_timesteps = length(unique(df[["timesteps"]])),
k = k,
is_grouped = !rlang::is_null(group)
)

invisible(NULL)
}
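The `Setting k` notes above say a fitted GAM typically uses fewer effective degrees of freedom than `k` makes available. A minimal sketch of that behavior follows, assuming simulated counts, a Poisson family, and a thin-plate smooth. None of these choices are taken from this PR, and the block skips itself if `mgcv` is not installed.

```r
# Sketch only: the simulated data and the Poisson family are assumptions,
# not the model that RtGam fits.
if (requireNamespace("mgcv", quietly = TRUE)) {
  set.seed(1)
  timesteps <- seq_len(40)
  cases <- rpois(40, lambda = exp(2 + sin(timesteps / 6)))
  df <- data.frame(cases = cases, timesteps = timesteps)

  # Total basis dimension made available to the single smooth
  k <- 20L
  fit <- mgcv::gam(
    cases ~ 1 + s(timesteps, k = k, bs = "tp"),
    family = poisson(),
    data = df
  )

  # Effective degrees of freedom actually used, including the intercept
  edf_used <- sum(fit$edf)
  print(c(k = k, edf_used = edf_used))
}
```

Comparing `edf_used` against `k` is one rough way to see how much of the available flexibility the penalty left unused.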

#' Propose total basis dimensionality from number of data points
#'
#' Guess a reasonable value for the `k` argument of [RtGam] based on the number
#' of data points. This guess **may not work** and almost certainly is not the
#' optimal choice. Rather, it is a _reasonable_ first pass for many situations
#' and hopefully a good enough choice for most use cases. This guess leans
#' toward providing an excess number of degrees of freedom to the model. The
#' consequence is slower model fits, but a better chance of avoiding
#' non-convergence due to insufficient flexibility. See *When to use a different value*
#' for more guidance on use-cases where this heuristic is likely to fail and
#' alternative values may need to be chosen. Note that `k` may be a minimum of 2
#' or a maximum of the number of data points.
#'
#' # When to use a different value
#' ## Model non-convergence
#' When an [RtGam] model does not converge, a reasonable first debugging step
#' is to increase the value of `k` and refit the model. Commonly, GAMs exhibit
#' diagnostic issues when the model does not have enough flexibility to
#' represent the underlying data generating process. Increasing `k` above the
#' default heuristic guess provides more flexibility.
#'
#' However, insufficient flexibility is not the only source of non-convergence.
#' When increasing `k` does not improve the default model diagnostics, manual
#' model checking via [mgcv::gam.check()] may be needed. Also see
#' [mgcv::choose.k] for guidance.
#'
#' ## Slow model fits
#' [RtGam] models usually fit faster when the model has less flexibility (lower
#' values of `k`). The guess from [dimensionality_heuristic()] leans toward
#' providing excess degrees of freedom, so model fits may take a little longer
#' than needed. If models are taking a long time to converge, it would be
#' reasonable to set `k` to a small value, check for convergence, and
#' increase `k` if needed until the model converges. This approach may or
#' may not be faster than simply waiting for a model with a higher `k` to fit.
#'
#' ## Very wiggly data
#' If running models in a setting where the data seem quite wiggly, exhibiting
#' sharp jumps or drops, a model with more flexibility than normal may be
#' needed. `k` should be increased to the maximum possible value. When running
#' pre-set models in production, it would also be reasonable to fix the value
#' of `k` above the default. Because GAMs penalize model wiggliness, the fit to
#' both wiggly and non-wiggly data is likely to be satisfactory, at the cost of
#' increased runtime.
#'
#' # Implementation details
#' The algorithm to pick `k` is a piecewise function. When \eqn{n \le 10}, then
#' the chosen value is \eqn{n}. When \eqn{n > 10}, then the selected value is
#' \eqn{ \lceil \sqrt{10n} \rceil }.
#' This approach is loosely inspired by Ward et al., 2021. As in Ward et al.,
#' the degrees of freedom of the spline is set to a reasonably high value to
#' avoid oversmoothing. The basis dimension increases with the length of the
#' timeseries. The scaled square root of the dimension of the data is used to
#' allow for the higher setup cost of the base model while still increasing the
#' available degrees of freedom when the length of the timeseries increases.
#'
#' @param n An integer, the dimension of the data.
#' @return An integer, the proposed _total_ basis dimensionality available to
#' the [RtGam] model.
#' @references Ward, Thomas, et al. "Growth, reproduction numbers and factors
#' affecting the spread of SARS-CoV-2 novel variants of concern in the UK from
#' October 2020 to July 2021: a modelling analysis." BMJ open 11.11 (2021):
#' e056636.
#' @seealso [RtGam()] for the use-case and additional documentation as well as
#' [mgcv::choose.k] for more general guidance from `mgcv`.
#' @export
#' @examples
#' cases <- 1:10
#' k <- dimensionality_heuristic(length(cases))
dimensionality_heuristic <- function(n) {
# Input checks
rlang::check_required(n, "n", call = rlang::caller_env())
check_vector(n)
check_integer(n)
check_no_missingness(n)
check_elements_above_min(n, "n", min = 1)
check_vector_length(length(n), "n", min = 1, max = 1)

if (n < 10) {
n
} else {
as.integer(ceiling(sqrt(10 * n)))
}
}
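The piecewise rule described under `Implementation details` can be checked by hand. The sketch below is a stand-alone re-implementation without the input checks; `k_heuristic_sketch` is a hypothetical name used only for illustration, not a function in this PR.

```r
# Hypothetical stand-in for the documented rule:
# k = n when n <= 10, otherwise ceiling(sqrt(10 * n)).
k_heuristic_sketch <- function(n) {
  if (n <= 10) {
    as.integer(n)
  } else {
    as.integer(ceiling(sqrt(10 * n)))
  }
}

n_values <- c(5, 10, 11, 40, 100)
k_values <- vapply(n_values, k_heuristic_sketch, integer(1))
print(k_values)
# 5 10 11 20 32
```

Because `sqrt(10 * 10)` is exactly 10, both branches agree at `n = 10`, so the function gives the same result whether the boundary is written as `<` or `<=`.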
38 changes: 32 additions & 6 deletions R/checkers.R
@@ -1,3 +1,29 @@
check_vector_length <- function(n, name, min, max, call = rlang::caller_env()) {
if (!rlang::is_na(min)) {
if (n < min) {
cli::cli_abort(
c("{.arg {name}} requires a minimum length of {.val {min}}",
"i" = "{.arg {name}} is of length {.val {n}}"
),
class = "RtGam_invalid_input",
call = call
)
}
}
if (!rlang::is_na(max)) {
if (n > max) {
cli::cli_abort(
c("{.arg {name}} requires a maximum length of {.val {max}}",
"i" = "{.arg {name}} is of length {.val {n}}"
),
class = "RtGam_invalid_input",
call = call
)
}
}
invisible()
}
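The checkers signal errors with the class `RtGam_invalid_input`, which lets callers and tests catch input problems specifically. Below is a base-R sketch of that pattern. `abort_invalid_input()` and `check_length_sketch()` are hypothetical stand-ins written for illustration and do not use `cli` or `rlang`.

```r
# Hypothetical base-R stand-in for a classed abort
abort_invalid_input <- function(message) {
  stop(structure(
    list(message = message, call = NULL),
    class = c("RtGam_invalid_input", "error", "condition")
  ))
}

# Same min/max logic as the length checker above, where NA means "no bound"
check_length_sketch <- function(n, name, min = NA, max = NA) {
  if (!is.na(min) && n < min) {
    abort_invalid_input(
      sprintf("`%s` requires a minimum length of %s", name, min)
    )
  }
  if (!is.na(max) && n > max) {
    abort_invalid_input(
      sprintf("`%s` requires a maximum length of %s", name, max)
    )
  }
  invisible()
}

# A length-2 `k` violates the max = 1 bound; catch the error by its class
result <- tryCatch(
  check_length_sketch(length(c(1, 2)), "k", min = 1, max = 1),
  RtGam_invalid_input = function(cnd) conditionMessage(cnd)
)
print(result)
```

Catching by class rather than by message text keeps tests stable if the wording of the message changes.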

check_vectors_equal_length <- function(cases,
reference_date,
group,
@@ -97,14 +123,14 @@ check_no_missingness <- function(x, arg = "x", call = rlang::caller_env()) {
}
}

check_elements_non_neg <- function(x, arg = "x", call = rlang::caller_env()) {
check_elements_above_min <- function(x, arg, min, call = rlang::caller_env()) {
# Greater than or equal to `min` or is NA
is_non_neg <- (x >= 0) | is.na(x)
if (!all(is_non_neg)) {
is_above_min <- (x >= min) | is.na(x)
if (!all(is_above_min)) {
cli::cli_abort(
c("{.arg {arg}} has negative elements",
"!" = "All elements must be 0 or greater",
"i" = "Elements {.val {which(!is_non_neg)}} are negative"
c("{.arg {arg}} has elements smaller than {.val {min}}",
"!" = "All elements must be {.val {min}} or greater",
"i" = "Elements {.val {which(!is_above_min)}} are smaller"
),
class = "RtGam_invalid_input",
call = call
70 changes: 70 additions & 0 deletions R/formula.R
@@ -0,0 +1,70 @@
#' Build formula for `mgcv::gam()`
#'
#' Build up the formula as a string and return a formula object meant for
#' use by [`mgcv::gam()`]. The formula components are built up based on the
#' parameters passed to [`RtGam`]. The model assumes that the
#' smoothness of epidemic dynamics does not change dramatically over a three
#' week period. For periods shorter than three weeks, the model uses a simple
#' thin plate spline for the global epidemic trend. For periods of three weeks
#' or longer, the model uses an adaptive smoother with an additional penalty
#' basis for each additional three week period. The model smoothly interpolates
#' between the penalty bases and uses P-splines for the smoothing basis.
#'
#' Hierarchical modeling of groups is not yet supported, but when implemented
#' it will use Model GS from Pedersen et al., 2019.
#'
#' @param n_timesteps Number of distinct timesteps in the dataframe
#' returned from [`prepare_inputs()`]
#' @param k Global basis dimension to be partitioned between the model smooths
#' @param is_grouped Whether to use a hierarchical model. Not yet supported.
#' @return A formula to be used by [`mgcv::gam()`]
#' @noRd
formula_creator <- function(n_timesteps, k, is_grouped) {
outcome <- "cases"
intercept <- "1"

penalty_basis_dim <- penalty_basis_creator(n_timesteps)
smooth_basis_dim <- smooth_basis_creator(k)

# Apply adaptive spline if 3 weeks or more of data are available
# nolint start
if (n_timesteps >= 21) {
# With the adaptive basis, m sets the dimension of the penalty basis, not the
# order of the smoothing penalty as it does in the other smoothing bases.
plus_global_trend <- glue::glue("+ s(timesteps,
k = {smooth_basis_dim[['global_trend']]},
m = {penalty_basis_dim},
bs = 'ad')") # nolint
} else {
# Adaptive penalty with `m = 1` is equivalent to a non-adaptive smooth, but
# thin-plate performance is a bit better than p-spline, so we prefer to
# fall back to thin-plate.
plus_global_trend <- glue::glue("+ s(timesteps,
k = {smooth_basis_dim[['global_trend']]},
bs = 'tp')")
}
# nolint end

f <- glue::glue("{outcome} ~ {intercept} {plus_global_trend}")
stats::as.formula(f)
}
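To see what the two branches produce, the formula assembly can be sketched in base R. `formula_sketch()` is a hypothetical helper that uses `sprintf()` in place of `glue`. The formula is only parsed, not evaluated, so `mgcv` is not needed here.

```r
# Hypothetical stand-in for the formula assembly shown above
formula_sketch <- function(n_timesteps, k) {
  k <- as.integer(k)
  if (n_timesteps >= 21) {
    # Adaptive smooth: one penalty basis dimension per 21 days, plus one
    m <- as.integer(floor(n_timesteps / 21) + 1)
    smooth <- sprintf("s(timesteps, k = %d, m = %d, bs = 'ad')", k, m)
  } else {
    # Short series: single-penalty thin-plate spline
    smooth <- sprintf("s(timesteps, k = %d, bs = 'tp')", k)
  }
  stats::as.formula(paste("cases ~ 1 +", smooth))
}

short_formula <- formula_sketch(n_timesteps = 14, k = 10)
long_formula <- formula_sketch(n_timesteps = 50, k = 23)
print(short_formula)
print(long_formula)
```

Building the formula as a string first keeps the branching simple, at the cost of deferring syntax errors until `as.formula()` parses the string.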

#' Create a penalty per three weeks of data for the global trend
#'
#' @inheritParams formula_creator
#' @return The penalty basis dimension for the global trend
#' @noRd
penalty_basis_creator <- function(n_timesteps) {
as.integer(floor(n_timesteps / 21) + 1)
}
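The boundaries of the penalty basis dimension are easy to get off by one, so here is a quick tabulation. `penalty_dim_sketch()` is a hypothetical copy of the expression above, defined here only so the block runs on its own.

```r
# Hypothetical copy of the penalty basis rule: floor(n / 21) + 1
penalty_dim_sketch <- function(n_timesteps) {
  as.integer(floor(n_timesteps / 21) + 1)
}

n_timesteps <- c(20, 21, 41, 42, 63)
penalty_dims <- vapply(n_timesteps, penalty_dim_sketch, integer(1))
print(penalty_dims)
# 1 2 2 3 4
```

Under this rule a third penalty basis dimension appears at exactly 42 timesteps, so the two-penalty range is 21 to 41.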

#' Partition global basis dimension into components
#'
#' @param k The global basis dimension
#' @return A list with named components matching formula components
#' @noRd
smooth_basis_creator <- function(k) {
list(
"global_trend" = k
)
}
15 changes: 14 additions & 1 deletion R/validate.R
@@ -25,7 +25,7 @@ validate_cases <- function(cases, call) {
check_vector(cases, arg, call)
check_no_missingness(cases, arg, call)
check_integer(cases, arg, call)
check_elements_non_neg(cases, arg, call)
check_elements_above_min(cases, arg, min = 0, call = call)
invisible()
}

@@ -45,3 +45,16 @@ validate_group <- function(group, call) {
)
}
}

#' Used by both dimensionality_heuristic() and RtGam()
#' @noRd
validate_min_dimensionality <- function(k, call) {
arg <- "k"
check_vector(k, arg, call = call)
check_no_missingness(k, arg, call)
check_integer(k, arg, call)
check_elements_above_min(k, arg, min = 3, call = call)
check_vector_length(length(k), arg, min = 1, max = 1, call = call)

invisible()
}