Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#8 [2/4]: Model fit helpers #18

Merged
merged 41 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
99272a3
Add `k` argument with documentation
zsusswein Jun 8, 2024
ec173d6
Default basis dimension selector and documentation
zsusswein Jun 10, 2024
80a451d
Implement and test setting k
zsusswein Jun 11, 2024
0b5cd8c
Model formula helper and documentation
zsusswein Jun 12, 2024
59fac3f
Grammar tweaks to `dimensionality_heuristic()`
zsusswein Jun 22, 2024
feda446
Remove `k` documentation from `RtGam()`
zsusswein Jun 22, 2024
dafc329
Better explanation of why use piecewise for `k`
zsusswein Jun 22, 2024
8957487
Wording tweaks and line-length formatting
zsusswein Jun 22, 2024
4a0e5db
dimensionality_heuristic -> smooth_dim_heuristic
zsusswein Jun 22, 2024
c6b76cb
Typo
zsusswein Jun 22, 2024
48da2d6
Refactor penalty basis selector into public func
zsusswein Jun 24, 2024
8783d78
Drop language linking groups to geography
zsusswein Jun 24, 2024
d4010f6
pre-commit
github-actions[bot] Jun 24, 2024
661483a
Small change to trigger GHA
zsusswein Jun 24, 2024
1c10e44
pre-commit
github-actions[bot] Jun 24, 2024
fda5165
Fix WARNING from usage
zsusswein Jun 24, 2024
0e60e5e
Drop obsolete comment
zsusswein Jul 2, 2024
cf29e5c
Implement model fitting with `{mgcv}`
zsusswein Jun 30, 2024
d87f361
Suppress public docs of internal function
zsusswein Jun 30, 2024
afa8102
Add checks and warnings for unwise inputs
zsusswein Jul 2, 2024
24d310c
Refactor to S3 methods for fitting backends
zsusswein Jul 6, 2024
6d841d1
Add doc for missing param
zsusswein Jul 6, 2024
046d055
Explicitly namespace `modifyList()`
zsusswein Jul 6, 2024
19ae8f0
Clarify documentation
zsusswein Jul 6, 2024
55ded03
Test warnings throw for suboptimal params
zsusswein Jul 6, 2024
bb0c13d
Default args in S3 methods w/ user-supplied in ...
zsusswein Jul 10, 2024
318fa9a
Move backend check from input val to S3 dispatch
zsusswein Jul 10, 2024
5b1085d
Whitespace
zsusswein Aug 26, 2024
3962347
Move do.call() outside of fit_model()
zsusswein Aug 26, 2024
8591b8a
Dynamically find methods for `fit_model()`
zsusswein Aug 27, 2024
93f4402
Minimal working print method + RtGam() return
zsusswein Aug 27, 2024
e9f1911
Add some basic diagnostic checks
zsusswein Aug 27, 2024
b19258f
Clean up existing tests
zsusswein Aug 28, 2024
6d2af44
Tests for print and diagnostics
zsusswein Aug 28, 2024
25d452e
Document `check_diagnostics()`
zsusswein Aug 28, 2024
b965903
Update R/RtGam.R
zsusswein Aug 29, 2024
84b8d92
pre-commit
github-actions[bot] Aug 29, 2024
5a12ee8
Move `{withr}` to suggests
zsusswein Aug 29, 2024
0c849f6
DRY diagnostic functionality
zsusswein Aug 29, 2024
6bfb09d
`pkg::func()` -> `func()` in `@examples`
zsusswein Aug 29, 2024
eed7b48
Rename `format_for_return()` -> `new_RtGam()`
zsusswein Aug 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Imports:
cli,
glue,
mgcv,
rlang
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Generated by roxygen2: do not edit by hand

export(RtGam)
export(penalty_dim_heuristic)
export(smooth_dim_heuristic)
importFrom(rlang,abort)
234 changes: 225 additions & 9 deletions R/RtGam.R
Original file line number Diff line number Diff line change
@@ -1,31 +1,247 @@
#' Stub: Will become fitting/class constructor
#' Fit a generalized additive model to incident cases
#'
#' Incident cases are modeled as a smooth function of time with a generalized
#' additive model (GAM). The model is fit with [mgcv::gam()] and some
#' familiarity with `mgcv` may be helpful.
#'
#' # Model specification
#'
#' Incident cases (\eqn{y}) are modeled as smoothly changing over time:
#'
#' \deqn{\text{log}\{E(y)\} = \alpha + f_{\text{global}(t)}}
#'
#' where incidence is negative-binomially distributed and \eqn{f(t)} is a smooth
#' function of time.
#'
#' Including as a stub function to showcase control flow
#' @param cases A vector of non-negative incident case counts occurring on an
#' associated `reference_date`. Missing values (NAs) are not allowed.
#' @param reference_date The associated date on which the count of incident
#' `cases` occurred. Missing dates are not allowed and dates can only occur
#' once.
#' @param group The geographic grouping for the case/reference-date pair. Not
#' yet implemented and a value other than `NULL` will throw an error.
#'
#' @returns Stub function: NULL
#' @param group The grouping variable for the case/reference-date pair. Not
#' yet implemented and a value other than `NULL` will throw an error.
#' @param k An integer, the _total_ dimension of all the smoothing basis
#' functions. Defaults to `smooth_dim_heuristic(length(cases))`, which picks a
#' reasonable estimate based on the number of provided data points. This total
#' dimension is partitioned between the different smooths in the model. In the
#' case of diagnostic issues in a fitted `RtGam` model, increasing the value
#' of `k` above this default and refitting the model is a good first step. See
#' the [smooth_dim_heuristic()] documentation for more information.
#' @param m An integer, the dimension of the penalty basis for the global smooth
#' trend. If `m` is greater than 1, the smooth's wiggliness can change over
#' time. An increase in this value above the default should be done carefully.
#' See [penalty_dim_heuristic()] for more information on `m` and when to
#' consider changing the default.
#' @seealso [smooth_dim_heuristic()] more information on the smoothing basis
#' dimension and [mgcv::choose.k] for more general guidance on GAMs from
#' `mgcv`
#' @return Stub function: NULL
#' @export
#' @examples
#' cases <- c(1, 2, 3)
#' reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03"))
#' mod <- RtGam::RtGam(cases, reference_date)
RtGam <- function(cases,
reference_date,
group = NULL) {
group = NULL,
k = smooth_dim_heuristic(length(cases)),
m = penalty_dim_heuristic(length(unique(reference_date)))) {
check_required_inputs_provided(
cases,
reference_date,
group
group,
k,
m
)
validate(cases, reference_date, group)
validate(cases, reference_date, group, k, m)

df <- prepare_inputs(cases, reference_date, group)
formula <- formula_creator(
k = k,
m = m,
is_grouped = !rlang::is_null(group)
)

invisible(NULL)
}

#' Propose total smoothing basis dimension from number of data points
#'
#' Return a reasonable value for the `k` argument of [RtGam] (the _total_ smooth
#' basis dimension of the model's one or more smooth predictors) based on the
#' number of data points. The smooth basis dimension controls the maximum
#' degrees of freedom (and by proxy the "wiggliness") of the smooth predictors.
#' The estimation procedure leans toward providing an excess number of degrees
#' of freedom to the model. The consequence is slower model fits, but a better
#' chance of avoiding avoiding non-convergence due to undersmoothing. If
#' manually supplying a value to `k` rather than relying on the default
#' estimate, see *When to use a different value* for [RtGam]-specific
#' implementation guidance and [mgcv::choose.k] for more general debugging
#' guidance from the underlying model fitting package. Note that `k` may be a
#' minimum of 2 or a maximum of the number of data points.
#'
#' # How `k` is used
#'
#' The model is composed of one or more smooth predictors, depending the
#' specifics of the model specification. In a simple model with only one smooth
#' predictor, all the degrees of freedom from `k` would be applied to that
#' single smooth. In a more complex model composed of multiple smooth
#' predictors, the total degrees degrees of freedom made available by `k` would
#' be partitioned between the different smooths.
#'
#' # When to use a different value
#'
#' ## Model non-convergence
#'
#' When an [RtGam] model does not converge, a reasonable first debugging step is
#' to increase the value of `k` and refit the model. Commonly, GAMs exhibit
#' diagnostic issues when the model does not have enough flexibility to
#' represent the underlying data generating process. Increasing `k` above the
#' default estimate provides more flexibility.
#'
#' However, insufficient flexibility is not the only source of non-convergence.
#' When increasing `k` does not improve the default model diagnostics, manual
#' model checking via [mgcv::gam.check()] may be needed. Also see
#' [mgcv::choose.k] for guidance.
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
#'
#' ## Slow model fits
#'
#' [RtGam] models usually fit faster when the model has less flexibility (lower
#' values of `k`). The guess from [smooth_dim_heuristic()] leans toward
#' providing excess degrees of freedom, so model fits may take a little longer
#' than needed. If models are taking a long time to converge, it would be
#' reasonable to set `k` to a small value, checking for convergence, and
#' increasing `k` if needed until the model convergences. This approach may or
#' may not be faster than simply waiting for a model with a higher `k` to fit.
#'
#' ## Very wiggly data
#'
#' If running models in a setting where the data seem quite wiggly, exhibiting
#' sharp jumps or drops, a model with more flexibility than normal may be
#' needed. `k` should be increased to the maximum possible value. When running
#' pre-set models in production, it would also be reasonable to fix the value of
#' `k` above the default. Because GAMs penalize model wiggliness, the fit to
#' both wiggly and non-wiggly data is likely to be satisfactory, at the cost of
#' increased runtime.
#'
#' # Implementation details
#'
#' The algorithm to pick `k` is a piecewise function. When \eqn{n \le 10}, then
#' the returned value is \eqn{n}. When \eqn{n > 10}, then the returned value is
#' \eqn{ \lceil \sqrt{10n} \rceil }. This approach is loosely inspired by Ward
#' et al., 2021. As in Ward et al. the degrees of freedom of the spline (1) is
#' set to a reasonably high value to avoid oversmoothing and (2) scales with the
#' dimension of the data to accommodate changing trends over time.
#'
#' [smooth_dim_heuristic()] uses a piecewise function because each smooth
#' parameter needs its own degrees of freedom, which adds a fixed initial setup
#' cost. When the dimension of the data is small, the default value of `k`
#' increases linearly with the data to accommodate this fixed setup cost. When
#' the dimension of the data is larger, the default value of `k` increases with
#' the square root of the data to balance having sufficient basis dimension to
#' fit to changing trends over time without having so many dimensions that model
#' fits are very slow.
#'
#' @param n An integer, the dimension of the data.
#' @return An integer, the proposed _total_ smooth basis dimensionality
#' available to the [RtGam] model.
#' @references Ward, Thomas, et al. "Growth, reproduction numbers and factors
#' affecting the spread of SARS-CoV-2 novel variants of concern in the UK from
#' October 2020 to July 2021: a modelling analysis." BMJ open 11.11 (2021):
#' e056636.
#' @seealso [RtGam()] for the use-case and additional documentation as well as
#' [mgcv::choose.k] and [mgcv::gam.check] for more general guidance from
#' `mgcv`.
#' @export
#' @examples
#' cases <- 1:10
#' k <- smooth_dim_heuristic(length(cases))
smooth_dim_heuristic <- function(n) {
# Input checks
rlang::check_required(n, "n", call = rlang::caller_env())
check_vector(n)
check_integer(n)
check_no_missingness(n)
check_elements_above_min(n, "n", min = 1)
check_vector_length(length(n), "n", min = 1, max = 1)

if (n < 10) {
n
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
} else {
as.integer(ceiling(sqrt(10 * n)))
}
}

#' Propose a penalty basis dimension based on the number of observed dates
#'
#' Return a reasonable value for the `m` argument of [RtGam()] based on the
#' number of dates that cases are observed. The `m` argument controls the
#' dimension of the smoothing penalty basis for the model's global smooth trend
#' (see the *Model specification* section of the [RtGam()] documentation for
#' more information about the global trend). The penalty basis dimension
#' controls how much the wiggliness of the global smooth trend can vary over
#' time. Higher values of `m` help the model to adapt quickly to different
#' epidemic regimes, but are computationally costly.
#'
#' # How `m` is used
#'
#' The parameter `m` controls the penalty basis dimension of the model's global
#' smooth trend. If `m` is 1, there will be single constant penalty on
#' wiggliness over the entire smooth and [RtGam] will use a thin-plate spline
#' basis for its superior performance in single-penalty settings. If `m` is 2 or
#' more, the model will use `m` distinct penalties on the smooth trend's
#' wiggliness and use an adaptive spline basis. The realized penalty at each
#' timepoint smoothly interpolates between the `m` estimated wiggliness
#' penalties. This adaptive penalty increases the computational cost of the
#' model, but allows for a single model to adapt to changing epidemic dynamics
#' without oversmoothing or introducing spurious wiggly trends.
#'
#' # When to use a different value
#'
#' ## Very slow
#'
#' Decreasing the penalty basis dimension makes the model less demanding to fit.
#' Using a single penalty throughout the model is much simpler than using an
#' adaptive smooth and should be preferred where possible. See
#' `[mgcv::smooth.construct.ad.smooth.spec]` for more information on how the
#' adaptive smooth basis uses the penalty dimension.
#'
#' ## Observed over-smoothing of non-stationary data
#'
#' If a fitted model is observably over-smoothing, it may be reasonable to refit
#' with a higher penalty basis dimension. Moments with a sudden change in
#' epidemic dynamics, such as a sharp epidemic peak, can be challenging to fit
#' with smooth functions. This option should be used with care due to the
#' increased computational cost.
#'
#' # Implementation details
#'
#' The algorithm to pick `m` is \eqn{\lfloor \frac{n}{21} \rfloor + 1} where
#' \eqn{n \in \mathbb{W}} is the number of observed dates. This algorithm
#' assumes that over a 21-day period, epidemic dynamics remain roughly similarly
#' wiggly. Sharp jumps or drops requiring a very wiggly trend would remain
#' similarly plausible over much of the 21-day band.
#'
#' @param n An integer, the number of dates with an associated case observation.
#' @return An integer, the proposed penalty basis dimension to be used by the
#' global trend.
#' @seealso [RtGam()] for the use-case and additional documentation as well as
#' [mgcv::smooth.construct.ad.smooth.spec] for an explanation of the
#' underlying adaptive-smooth machinery.
#' @export
#' @examples
#' # Default use invokes `unique()` in case of repeated dates from groups
#' reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03"))
#' m <- penalty_dim_heuristic(length(reference_date))
#'
penalty_dim_heuristic <- function(n) {
# Input checks
rlang::check_required(n, "n", call = rlang::caller_env())
check_vector(n)
check_integer(n)
check_no_missingness(n)
check_elements_above_min(n, "n", min = 1)
check_vector_length(length(n), "n", min = 1, max = 1)

as.integer(floor(n / 21) + 1)
}
43 changes: 37 additions & 6 deletions R/checkers.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
check_vector_length <- function(n, name, min, max, call = rlang::caller_env()) {
if (!rlang::is_na(min)) {
if (n < min) {
cli::cli_abort(
c("{.arg {name}} requires a minimum length of {.val {min}}",
"i" = "{.arg {name}} is of length {.val {n}}"
),
class = "RtGam_invalid_input",
call = call
)
}
}
if (!rlang::is_na(max)) {
if (n > max) {
cli::cli_abort(
c("{.arg {name}} requires a maximum length of {.val {max}}",
"i" = "{.arg {name}} is of length {.val {n}}"
),
class = "RtGam_invalid_input",
call = call
)
}
}
invisible()
}

check_vectors_equal_length <- function(cases,
reference_date,
group,
Expand Down Expand Up @@ -75,10 +101,15 @@ check_dates_unique <- function(reference_date,
check_required_inputs_provided <- function(cases,
reference_date,
group,
k,
m,
call = rlang::caller_env()) {
rlang::check_required(cases, "cases", call = call)
rlang::check_required(reference_date, "reference_date", call = call)
rlang::check_required(group, "group", call = call)
rlang::check_required(k, "k", call = call)
rlang::check_required(m, "m", call = call)

invisible()
}

Expand All @@ -97,14 +128,14 @@ check_no_missingness <- function(x, arg = "x", call = rlang::caller_env()) {
}
}

check_elements_non_neg <- function(x, arg = "x", call = rlang::caller_env()) {
check_elements_above_min <- function(x, arg, min, call = rlang::caller_env()) {
# Greater than or equal to 0 or is NA
is_non_neg <- (x >= 0) | is.na(x)
if (!all(is_non_neg)) {
is_above_min <- (x >= min) | is.na(x)
if (!all(is_above_min)) {
cli::cli_abort(
c("{.arg {arg}} has negative elements",
"!" = "All elements must be 0 or greater",
"i" = "Elements {.val {which(!is_non_neg)}} are negative"
c("{.arg {arg}} has elements smaller than {.val {min}}",
"!" = "All elements must be {.val {min}} or greater",
"i" = "Elements {.val {which(!is_above_min)}} are smaller"
),
class = "RtGam_invalid_input",
call = call
Expand Down
61 changes: 61 additions & 0 deletions R/formula.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#' Build formula for `mgcv::gam()`
#'
#' Build up the formula as a string and and return a formula object meant for
#' use by [`mgcv::gam()`]. The formula components are built up based on the
#' parameters passed to [`RtGam`]. The model makes the assumptions that the
#' smoothness of epidemic dynamics does not change dramatically over a three
#' week period. For periods of three weeks or shorter, the model uses a simple
#' thin plate spline for the global epidemic trend. For periods of longer than
#' three days, the model uses an adaptive smoother with an additional penalty
#' basis for each additional three week period. The model smoothly interpolates
#' between the penalty bases and uses p-splines for the smoothing basis.
#'
#' Currently support for groups via hierarchical modeling is not supported,
#' but when implemented it will use Model GS from Pederson et al., 2019.
#'
#' @param n_timesteps Number of distinct timesteps in the dataframe
#' returned from [`prepare_inputs()`]
#' @param k Global basis dimension to be partitioned between the model smooths
#' @param m Penalty basis dimension on the global smooth
#' @param is_grouped Whether to use a hierarchical model. Not yet supported.
#' @return A formula to be used by [`mgcv::gam()`]
#' @noRd
formula_creator <- function(k, m, is_grouped) {
outcome <- "cases"
intercept <- "1"

smooth_basis_dim <- smooth_basis_creator(k)

# Apply adaptive spline if 3 weeks or more of data are available
seabbs marked this conversation as resolved.
Show resolved Hide resolved
# nolint start
if (m > 1) {
# With adaptive basis, m refers to the order of the penalty matrix not the
# order of the smoothing penalty as it does in the other smoothing bases.
plus_global_trend <- glue::glue("+ s(timesteps,
k = {smooth_basis_dim[['global_trend']]},
m = {m},
bs = 'ad')") # nolint
} else {
# Adaptive penalty with `m = 1` is equivalent to a non-adaptive smooth but
# thin-plate performance is a bit better than p-spline, so preference to
# fall back to thin-plate.
plus_global_trend <- glue::glue("+ s(timesteps,
seabbs marked this conversation as resolved.
Show resolved Hide resolved
k = {smooth_basis_dim[['global_trend']]},
bs = 'tp')")
}
# nolint end

f <- glue::glue("{outcome} ~ {intercept} {plus_global_trend}")
stats::as.formula(f)
}

#' Partition global basis dimension into components
#'
#' @param k The global basis dimension
#' @return A list with named components matching formula components
#' @noRd
smooth_basis_creator <- function(k) {
list(
"global_trend" = k
)
}
Loading
Loading