Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#8 [3/4]: Implement model fitting with mgcv #20

Merged
merged 24 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
ef4a05b
Implement model fitting with `{mgcv}`
zsusswein Jun 30, 2024
7247691
Suppress public docs of internal function
zsusswein Jun 30, 2024
7151ed2
Add checks and warnings for unwise inputs
zsusswein Jul 2, 2024
72ff6c4
Refactor to S3 methods for fitting backends
zsusswein Jul 6, 2024
6c909e1
Add doc for missing param
zsusswein Jul 6, 2024
5e4de38
Explicitly namespace `modifyList()`
zsusswein Jul 6, 2024
31f4196
Clarify documentation
zsusswein Jul 6, 2024
becc4a3
Test warnings throw for suboptimal params
zsusswein Jul 6, 2024
652ff5b
Default args in S3 methods w/ user-supplied in ...
zsusswein Jul 10, 2024
17f48d4
Move backend check from input val to S3 dispatch
zsusswein Jul 10, 2024
0ca92eb
Whitespace
zsusswein Aug 26, 2024
60365e0
Move do.call() outside of fit_model()
zsusswein Aug 26, 2024
58deaf0
Dynamically find methods for `fit_model()`
zsusswein Aug 27, 2024
705f22e
Minimal working print method + RtGam() return
zsusswein Aug 27, 2024
ca70843
Add some basic diagnostic checks
zsusswein Aug 27, 2024
45625aa
Clean up existing tests
zsusswein Aug 28, 2024
c8f6157
Tests for print and diagnostics
zsusswein Aug 28, 2024
dbfaf52
Document `check_diagnostics()`
zsusswein Aug 28, 2024
ba251e1
Update R/RtGam.R
zsusswein Aug 29, 2024
037b2b6
pre-commit
github-actions[bot] Aug 29, 2024
fe87c71
Move `{withr}` to suggests
zsusswein Aug 29, 2024
6a6db82
DRY diagnostic functionality
zsusswein Aug 29, 2024
c0cc8d8
`pkg::func()` -> `func()` in `@examples`
zsusswein Aug 29, 2024
4f73459
Rename `format_for_return()` -> `new_RtGam()`
zsusswein Aug 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ repos:
- id: parsable-R
- id: no-browser-statement
- id: no-print-statement
exclude: '^tests/testthat/test-print\.R$'
- id: no-debug-statement
- id: deps-in-desc
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand All @@ -25,6 +26,7 @@ repos:
files: '^\.Rbuildignore$'
- id: end-of-file-fixer
exclude: '\.Rd'
exclude: 'tests/testthat/_snaps/'
- repo: https://github.com/pre-commit-ci/pre-commit-ci-config
rev: v1.6.1
hooks:
Expand Down
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ URL: https://github.com/cdcgov/cfa-gam-rt,
BugReports: https://github.com/cdcgov/cfa-gam-rt/issues
Suggests:
testthat (>= 3.0.0),
pkgdown
pkgdown,
withr
Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Generated by roxygen2: do not edit by hand

S3method(fit_model,RtGam_bam)
S3method(fit_model,RtGam_gam)
S3method(fit_model,default)
S3method(print,RtGam)
export(RtGam)
export(check_diagnostics)
export(penalty_dim_heuristic)
export(smooth_dim_heuristic)
importFrom(rlang,abort)
72 changes: 61 additions & 11 deletions R/RtGam.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#' @param reference_date The associated date on which the count of incident
#' `cases` occurred. Missing dates are not allowed and dates can only occur
#' once.
#' @param group The grouping variable for the case/reference-date pair. Not
#' yet implemented and a value other than `NULL` will throw an error.
#' @param group The grouping variable for the case/reference-date pair. Not yet
#' implemented and a value other than `NULL` will throw an error.
#' @param k An integer, the _total_ dimension of all the smoothing basis
#' functions. Defaults to `smooth_dim_heuristic(length(cases))`, which picks a
#' reasonable estimate based on the number of provided data points. This total
Expand All @@ -32,37 +32,85 @@
#' time. An increase in this value above the default should be done carefully.
#' See [penalty_dim_heuristic()] for more information on `m` and when to
#' consider changing the default.
#' @param backend One of `gam` or `bam`; defaults to `gam`. In general, models
#' should be fit with [mgcv::gam()]. If [mgcv::gam()] is too slow,
#' [mgcv::bam()] converges more quickly but introduces some additional
#' numerical error. Note that the `bam` backend uses the `discrete = TRUE`
#' option for an additional speedup. See [mgcv::bam()] for more information.
#' @param warn_for_diagnostic_failure Should warnings be issued for
#' automatically identified diagnostic issues? Defaults to TRUE. A list of
#' quantitative model diagnostics can be inspected in the `diagnostics` slot
#' of the returned `RtGam` object.
#' @param ... Additional arguments passed to the specified modelling backend.
#' For example, the default negative binomial error structure could be changed
#' to poisson in the default [mgcv::gam] backend by passing `family =
#' "poisson"`.
#' @seealso [smooth_dim_heuristic()] more information on the smoothing basis
#' dimension and [mgcv::choose.k] for more general guidance on GAMs from
#' `mgcv`
#' dimension, [mgcv::choose.k] for more general guidance on GAMs from `mgcv`,
#' and [mgcv::gam]/[mgcv::bam] for documentation on arguments to the model
#' fitting functions.
#' @return Stub function: NULL
#' @export
#' @examples
#' cases <- c(1, 2, 3)
#' reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03"))
#' mod <- RtGam::RtGam(cases, reference_date)
#' withr::with_seed(12345, {
#' cases <- rpois(20, 10)
#' })
#' reference_date <- seq.Date(
#' from = as.Date("2023-01-01"),
#' length.out = 20,
#' by = "day"
#' )
#' fit <- RtGam(cases, reference_date)
#' fit
RtGam <- function(cases,
reference_date,
group = NULL,
k = smooth_dim_heuristic(length(cases)),
m = penalty_dim_heuristic(length(unique(reference_date)))) {
m = penalty_dim_heuristic(length(unique(reference_date))),
backend = "gam",
warn_for_diagnostic_failure = TRUE,
...) {
check_required_inputs_provided(
cases,
reference_date,
group,
k,
m
m,
backend
)
validate(cases, reference_date, group, k, m)

df <- prepare_inputs(cases, reference_date, group)
df <- dataset_creator(cases, reference_date, group, backend)
formula <- formula_creator(
k = k,
m = m,
is_grouped = !rlang::is_null(group)
)

invisible(NULL)
fit <- do.call(
fit_model,
list(
data = df,
formula = formula,
...
)
)
diagnostics <- calculate_diagnostics(fit)

RtGam_object <- new_RtGam(
fit = fit,
df = df,
group = group,
k = k,
m = m,
backend = backend,
formula = formula,
diagnostics = diagnostics
)

check_diagnostics(RtGam_object, warn_for_diagnostic_failure)

return(RtGam_object)
}

#' Propose total smoothing basis dimension from number of data points
Expand Down Expand Up @@ -201,6 +249,8 @@ smooth_dim_heuristic <- function(n) {
#' ## Very slow
#'
#' Decreasing the penalty basis dimension makes the model less demanding to fit.
#' `mgcv` describes an adaptive penalty with 10 basis dimensions and 200 data
#' points as roughly equivalent to fitting 10 GAMs each from 20 data points.
#' Using a single penalty throughout the model is much simpler than using an
#' adaptive smooth and should be preferred where possible. See
#' `[mgcv::smooth.construct.ad.smooth.spec]` for more information on how the
Expand Down
18 changes: 18 additions & 0 deletions R/checkers.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ check_required_inputs_provided <- function(cases,
group,
k,
m,
backend,
call = rlang::caller_env()) {
rlang::check_required(cases, "cases", call = call)
rlang::check_required(reference_date, "reference_date", call = call)
Expand All @@ -128,6 +129,23 @@ check_no_missingness <- function(x, arg = "x", call = rlang::caller_env()) {
}
}

check_elements_below_max <- function(x, arg, max, call = rlang::caller_env()) {
# Greater than or equal to 0 or is NA
is_below_max <- all((x <= max) | is.na(x))
if (!all(is_below_max)) {
cli::cli_abort(
c("{.arg {arg}} has elements larger than {.val {max}}",
"!" = "All elements must be {.val {max}} or less",
"i" = "Elements {.val {which(!is_below_max)}} are larger"
),
class = "RtGam_invalid_input",
call = call
)
}
invisible()
}


check_elements_above_min <- function(x, arg, min, call = rlang::caller_env()) {
# Greater than or equal to 0 or is NA
is_above_min <- (x >= min) | is.na(x)
Expand Down
29 changes: 26 additions & 3 deletions R/prepare_inputs.R → R/dataset_creator.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
#'
#' @inheritParams RtGam
#' @return A dataframe for mgcv
prepare_inputs <- function(cases, reference_date, group) {
dataset_creator <- function(cases, reference_date, group, backend) {
cases_int <- integerify_cases(cases)

timestep <- dates_to_timesteps(
reference_date,
min_supplied_date = min(reference_date),
Expand All @@ -13,12 +15,33 @@ prepare_inputs <- function(cases, reference_date, group) {
group <- rep(NA, length(cases))
}

data.frame(
cases = cases,
dat <- data.frame(
cases = cases_int,
timestep = timestep,
reference_date = reference_date,
group = group
)

class(dat) <- c(glue::glue("RtGam_{backend}"), class(dat))
dat
}

#' Convert dates to an integer if needed
#'
#' @param cases The user-supplied cases vector
#' @return cases_int Cases verified to be an int
#' @noRd
integerify_cases <- function(cases) {
if (!rlang::is_integer(cases)) {
cli::cli_warn(c(
"Coercing {.arg cases} to an integer vector",
"i" = "{.arg cases} is a {.obj_type_friendly {cases}}",
"x" = "RtGam uses a count model, requiring integer-valued cases"
))
as.integer(cases)
} else {
cases
}
}

#' Convert an arbitrary vector of dates to a vector of timesteps
Expand Down
108 changes: 108 additions & 0 deletions R/diagnostics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#' Check quantitative diagnostics from a fitted RtGam model
#'
#' Evaluates for convergence, effective degrees of freedom, and residual
#' autocorrelation. If `warn_for_diagnostic_failure` is set to TRUE, will issue
#' warnings when potential diagnostic issues are detected. The diagnostics are
#' invisibly returned as a list and also stored within the `diagnostics` element
#' of the provided model object.
#'
#' @param fit A fitted `RtGam` model object. This should be the result of
#' calling `RtGam::RtGam()` with appropriate data.
#' @param warn_for_diagnostic_failure A logical value indicating whether to
#' issue warnings if diagnostic checks suggest potential issues with the model
#' fit. Defaults to TRUE, meaning that warnings will be issued by default.
#'
#' @return Invisibly returns a list containing diagnostic results:
#' - `model_converged`: Logical indicating if the model has converged.
#' - `k_prime`: The maximum available number of degrees of freedom that could
#' be used in the GAM fit.
#' - `k_edf`: Estimated degrees of freedom actually used by the smooth terms
#' in the model.
#' - `k_index`: The ratio of the residual variance of differenced
#' near-neighbor residuals to the overall residual variance. This should be
#' near 1 or above.
#' - `k_p_value`: P-value for testing if k' is adequate for modeling the data.
#' - `k_to_edf_ratio`: Ratio of k' to effective degrees of freedom of the
#' smooth terms. k' should be well below the available edf.
#' - `residual_autocorrelation`: Autocorrelation coefficients for residuals
#' up to lag 7 or one-tenth of series length, whichever is smaller.
#'
#' @export
#' @seealso [mgcv::k.check] for a description of the diagnostic tests,
#' [mgcv::choose.k] for a description of discussion of choosing the basis
#' dimension, and Wood, Simon N. Generalized additive models: an introduction
#' with R. chapman and hall/CRC, 2017. for a derivation of the metrics.
#' @examples
#' withr::with_seed(12345, {
#' cases <- rpois(20, 10)
#' })
#' reference_date <- seq.Date(
#' from = as.Date("2023-01-01"),
#' length.out = 20,
#' by = "day"
#' )
#' fit <- RtGam::RtGam(cases, reference_date)
#' check_diagnostics(fit)
check_diagnostics <- function(fit, warn_for_diagnostic_failure = TRUE) {
diagnostics <- fit[["diagnostics"]]
if (warn_for_diagnostic_failure) {
issue_diagnostic_warnings(diagnostics)
}
invisible(diagnostics)
}

calculate_diagnostics <- function(fit) {
converged <- fit$converged
k_check <- mgcv::k.check(fit)
max_lag <- min(7, round(nrow(fit$model) / 7))
rho <- stats::acf(fit$residuals, plot = FALSE, lag.max = max_lag)[[1]][, , 1]

list(
model_converged = converged,
k_prime = k_check[1],
k_edf = k_check[2],
k_index = k_check[3],
k_p_value = k_check[4],
k_to_edf_ratio = k_check[2] / k_check[1],
residual_autocorrelation = rho[2:length(rho)]
)
}

issue_diagnostic_warnings <- function(diagnostics) {
if (!diagnostics[["model_converged"]]) {
cli::cli_alert_danger(
c("Model failed to converge. Inference is not reliable.")
)
}
if (diagnostics[["k_to_edf_ratio"]] > 0.9) {
cli::cli_bullets(c(
"x" = "Effective degrees of freedom is near the supplied upper bound",
"!" = "Consider increasing {.arg k}",
"*" = "Actual: {.val {round(diagnostics[['k_edf']], 3)}}",
"*" = "Upper bound: {.val {diagnostics[['k_prime']]}}"
))
}
if (diagnostics[["k_p_value"]] < 0.05) {
cli::cli_bullets(
c(
"!" = "k-index for one or more smooths is below 1",
"*" = "k-index: {.val {round(diagnostics[['k_index']], 3)}}",
"*" = "Associated p-value: {.val {round(diagnostics[['k_p_value']],
2)}}",
"!" = "Suggests potential unmodeled residual trend.
Inspect model and/or consider increasing {.arg k}"
)
)
}
if (any(abs(diagnostics[["residual_autocorrelation"]]) > 0.5)) {
cli::cli_bullets(c(
"x" = "Residual autocorrelation present",
"*" = "Rho: {.val {round(diagnostics[['residual_autocorrelation']],
2)}}",
"*" = "Inspect manually with {.code acf(residuals(fit$model))}",
"!" = "Consider increasing {.arg k} and/or
specifying {.arg rho} with {.arg backend} bam"
))
}
invisible(NULL)
}
58 changes: 58 additions & 0 deletions R/fit_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
fit_model <- function(data, formula, ...) {
UseMethod("fit_model")
}

#' @export
fit_model.RtGam_gam <- function(
data,
formula,
family = "nb",
method = "REML",
...) {
# Override the defaults in formals with the user-supplied args in dots
mgcv::gam(
formula = formula,
family = family,
data = data,
method = method,
...
)
}

#' @export
fit_model.RtGam_bam <- function(
data,
formula,
family = "nb",
method = "fREML",
discrete = TRUE,
...) {
mgcv::bam(
formula = formula,
family = family,
data = data,
method = method,
discrete = discrete,
...
)
}

#' Used to throw informative error if non-supported backend supplied
#' @export
#' @noRd
fit_model.default <- function(
data,
formula,
...) {
requested_backend <- class(data)[1]
all_backends <- methods(fit_model)
# Drop fit_model.default
supported_backends <- all_backends[!(all_backends == "fit_model.default")]

cli::cli_abort(
c("Requested {.field backend} {.val {requested_backend}} not supported",
"!" = "Supported backends: {.val {supported_backends}}"
),
class = "RtGam_invalid_input"
)
}
Loading
Loading