#8 [3/4]: Implement model fitting with mgcv #20

Merged
merged 24 commits into from
Aug 29, 2024
Changes from 13 commits
ef4a05b
Implement model fitting with `{mgcv}`
zsusswein Jun 30, 2024
7247691
Suppress public docs of internal function
zsusswein Jun 30, 2024
7151ed2
Add checks and warnings for unwise inputs
zsusswein Jul 2, 2024
72ff6c4
Refactor to S3 methods for fitting backends
zsusswein Jul 6, 2024
6c909e1
Add doc for missing param
zsusswein Jul 6, 2024
5e4de38
Explicitly namespace `modifyList()`
zsusswein Jul 6, 2024
31f4196
Clarify documentation
zsusswein Jul 6, 2024
becc4a3
Test warnings throw for suboptimal params
zsusswein Jul 6, 2024
652ff5b
Default args in S3 methods w/ user-supplied in ...
zsusswein Jul 10, 2024
17f48d4
Move backend check from input val to S3 dispatch
zsusswein Jul 10, 2024
0ca92eb
Whitespace
zsusswein Aug 26, 2024
60365e0
Move do.call() outside of fit_model()
zsusswein Aug 26, 2024
58deaf0
Dynamically find methods for `fit_model()`
zsusswein Aug 27, 2024
705f22e
Minimal working print method + RtGam() return
zsusswein Aug 27, 2024
ca70843
Add some basic diagnostic checks
zsusswein Aug 27, 2024
45625aa
Clean up existing tests
zsusswein Aug 28, 2024
c8f6157
Tests for print and diagnostics
zsusswein Aug 28, 2024
dbfaf52
Document `check_diagnostics()`
zsusswein Aug 28, 2024
ba251e1
Update R/RtGam.R
zsusswein Aug 29, 2024
037b2b6
pre-commit
github-actions[bot] Aug 29, 2024
fe87c71
Move `{withr}` to suggests
zsusswein Aug 29, 2024
6a6db82
DRY diagnostic functionality
zsusswein Aug 29, 2024
c0cc8d8
`pkg::func()` -> `func()` in `@examples`
zsusswein Aug 29, 2024
4f73459
Rename `format_for_return()` -> `new_RtGam()`
zsusswein Aug 29, 2024
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -1,5 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method(fit_model,RtGam_bam)
S3method(fit_model,RtGam_gam)
S3method(fit_model,default)
export(RtGam)
export(penalty_dim_heuristic)
export(smooth_dim_heuristic)
40 changes: 32 additions & 8 deletions R/RtGam.R
@@ -18,8 +18,8 @@
#' @param reference_date The associated date on which the count of incident
#' `cases` occurred. Missing dates are not allowed and dates can only occur
#' once.
#' @param group The grouping variable for the case/reference-date pair. Not
#' yet implemented and a value other than `NULL` will throw an error.
#' @param group The grouping variable for the case/reference-date pair. Not yet
#' implemented and a value other than `NULL` will throw an error.
#' @param k An integer, the _total_ dimension of all the smoothing basis
#' functions. Defaults to `smooth_dim_heuristic(length(cases))`, which picks a
#' reasonable estimate based on the number of provided data points. This total
@@ -32,36 +32,58 @@
#' time. An increase in this value above the default should be done carefully.
#' See [penalty_dim_heuristic()] for more information on `m` and when to
#' consider changing the default.
#' @param backend One of `gam` or `bam`; defaults to `gam`. In general, models
#' should be fit with [mgcv::gam()]. If [mgcv::gam()] is too slow,
#' [mgcv::bam()] converges more quickly but introduces some additional
#' numerical error. Note that the `bam` backend uses the `discrete = TRUE`
#' option for an additional speedup. See [mgcv::bam()] for more information.
#' @param ... Additional arguments passed to the specified modelling backend.
#' For example, the default negative binomial error structure could be changed
#' to Poisson in the default [mgcv::gam] backend by passing `family =
#' "poisson"`.
#' @seealso [smooth_dim_heuristic()] more information on the smoothing basis
#' dimension and [mgcv::choose.k] for more general guidance on GAMs from
#' `mgcv`
#' dimension, [mgcv::choose.k] for more general guidance on GAMs from `mgcv`,
#' and [mgcv::gam]/[mgcv::bam] for documentation on arguments to the model
#' fitting functions.
#' @return Stub function: NULL
#' @export
#' @examples
#' cases <- c(1, 2, 3)
#' cases <- c(1L, 2L, 3L)
#' reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03"))
#' mod <- RtGam::RtGam(cases, reference_date)
RtGam <- function(cases,
reference_date,
group = NULL,
k = smooth_dim_heuristic(length(cases)),
m = penalty_dim_heuristic(length(unique(reference_date)))) {
m = penalty_dim_heuristic(length(unique(reference_date))),
backend = "gam",
...) {
check_required_inputs_provided(
cases,
reference_date,
group,
k,
m
m,
backend
)
validate(cases, reference_date, group, k, m)

df <- prepare_inputs(cases, reference_date, group)
df <- dataset_creator(cases, reference_date, group, backend)
formula <- formula_creator(
k = k,
m = m,
is_grouped = !rlang::is_null(group)
)

fit <- do.call(
fit_model,
list(
data = df,
formula = formula,
...
)
)

invisible(NULL)
}
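
The `do.call()` above forwards any user-supplied `...` to the backend method, whose own formals hold the defaults. A minimal base-R sketch of that pattern follows; `fit_sketch()` and `wrapper_sketch()` are hypothetical stand-ins for `fit_model()` and `RtGam()`, and nothing is actually fitted.

```r
# Hypothetical stand-in for a backend method: the defaults live in the formals.
fit_sketch <- function(data, formula, family = "nb", method = "REML", ...) {
  list(family = family, method = method, extra = list(...))
}

# Hypothetical wrapper mirroring the do.call() pattern used in RtGam().
# Anything passed through `...` is matched by name against fit_sketch()'s
# formals, so a user-supplied `family` replaces the default.
wrapper_sketch <- function(data, formula, ...) {
  do.call(fit_sketch, list(data = data, formula = formula, ...))
}

default_fit <- wrapper_sketch(data.frame(), y ~ x)
custom_fit <- wrapper_sketch(data.frame(), y ~ x, family = "poisson")
```

With this layout the wrapper never needs to know which arguments a backend accepts; unmatched arguments simply fall through to the method's own `...`.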

@@ -201,6 +223,8 @@ smooth_dim_heuristic <- function(n) {
#' ## Very slow
#'
#' Decreasing the penalty basis dimension makes the model less demanding to fit.
#' `mgcv` describes an adaptive penalty with 10 basis dimensions and 200 data
#' points as roughly equivalent to fitting 10 GAMs each from 20 data points.
#' Using a single penalty throughout the model is much simpler than using an
#' adaptive smooth and should be preferred where possible. See
#' `[mgcv::smooth.construct.ad.smooth.spec]` for more information on how the
18 changes: 18 additions & 0 deletions R/checkers.R
@@ -103,6 +103,7 @@ check_required_inputs_provided <- function(cases,
group,
k,
m,
backend,
call = rlang::caller_env()) {
rlang::check_required(cases, "cases", call = call)
rlang::check_required(reference_date, "reference_date", call = call)
@@ -128,6 +129,23 @@ check_no_missingness <- function(x, arg = "x", call = rlang::caller_env()) {
}
}

check_elements_below_max <- function(x, arg, max, call = rlang::caller_env()) {
# Less than or equal to max or is NA. Keep the mask element-wise so that
# `which(!is_below_max)` can report the offending positions.
is_below_max <- (x <= max) | is.na(x)
if (!all(is_below_max)) {
cli::cli_abort(
c("{.arg {arg}} has elements larger than {.val {max}}",
"!" = "All elements must be {.val {max}} or less",
"i" = "Elements {.val {which(!is_below_max)}} are larger"
),
class = "RtGam_invalid_input",
call = call
)
}
invisible()
}
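
For the error message to list offending positions, the comparison has to stay element-wise until `which()` is called. A base-R sketch of such a bound check is below; `check_below_max_sketch()` is a hypothetical name and `stop()` stands in for `cli::cli_abort()`.

```r
# Hypothetical base-R version of an upper-bound check.
check_below_max_sketch <- function(x, max) {
  # Element-wise mask: NA elements pass here and are left to a separate
  # missingness check.
  is_below_max <- (x <= max) | is.na(x)
  if (!all(is_below_max)) {
    stop(
      "Elements ", paste(which(!is_below_max), collapse = ", "),
      " are larger than ", max,
      call. = FALSE
    )
  }
  invisible(NULL)
}

ok <- check_below_max_sketch(c(1, NA, 3), max = 3)
msg <- tryCatch(
  check_below_max_sketch(c(1, 5, 3, 7), max = 3),
  error = function(e) conditionMessage(e)
)
```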


check_elements_above_min <- function(x, arg, min, call = rlang::caller_env()) {
# Greater than or equal to min or is NA
is_above_min <- (x >= min) | is.na(x)
29 changes: 26 additions & 3 deletions R/prepare_inputs.R → R/dataset_creator.R
@@ -2,7 +2,9 @@
#'
#' @inheritParams RtGam
#' @return A dataframe for mgcv
prepare_inputs <- function(cases, reference_date, group) {
dataset_creator <- function(cases, reference_date, group, backend) {
cases_int <- integerify_cases(cases)

timestep <- dates_to_timesteps(
reference_date,
min_supplied_date = min(reference_date),
@@ -13,12 +15,33 @@ prepare_inputs <- function(cases, reference_date, group) {
group <- rep(NA, length(cases))
}

data.frame(
cases = cases,
dat <- data.frame(
cases = cases_int,
timestep = timestep,
reference_date = reference_date,
group = group
)

class(dat) <- c(glue::glue("RtGam_{backend}"), class(dat))
dat
}

#' Convert cases to an integer vector if needed
#'
#' @param cases The user-supplied cases vector
#' @return cases_int Cases verified to be an int
#' @noRd
integerify_cases <- function(cases) {
if (!rlang::is_integer(cases)) {
cli::cli_warn(c(
"Coercing {.arg cases} to an integer vector",
"i" = "{.arg cases} is a {.obj_type_friendly {cases}}",
"x" = "RtGam uses a count model, requiring integer-valued cases"
))
as.integer(cases)
} else {
cases
}
}
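
As a rough illustration of what the coercion does to non-integer input, here is a base-R analogue; `integerify_sketch()` is a hypothetical name and `warning()` stands in for `cli::cli_warn()`. Note that `as.integer()` truncates toward zero rather than rounding, so fractional counts are silently lowered once the warning has been issued.

```r
# Hypothetical base-R analogue of the coercion step.
integerify_sketch <- function(cases) {
  if (is.integer(cases)) {
    return(cases)
  }
  warning("Coercing `cases` to an integer vector", call. = FALSE)
  # Truncates toward zero: 2.7 becomes 2L
  as.integer(cases)
}

already_int <- integerify_sketch(c(1L, 2L, 3L))
coerced <- suppressWarnings(integerify_sketch(c(1, 2.7, 3)))
```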

#' Convert an arbitrary vector of dates to a vector of timesteps
57 changes: 57 additions & 0 deletions R/fit_model.R
@@ -0,0 +1,57 @@
fit_model <- function(data, formula, ...) {
UseMethod("fit_model")
}

#' @export
fit_model.RtGam_gam <- function(
data,
formula,
family = "nb",
method = "REML",
...) {
# Override the defaults in formals with the user-supplied args in dots
mgcv::gam(
formula = formula,
family = family,
data = data,
method = method,
...
)
}

#' @export
fit_model.RtGam_bam <- function(
data,
formula,
family = "nb",
method = "fREML",
discrete = TRUE,
...) {
mgcv::bam(
formula = formula,
family = family,
data = data,
method = method,
discrete = discrete,
...
)
}

#' Used to throw informative error if non-supported backend supplied
#' @export
fit_model.default <- function(
data,
formula,
...) {
requested_backend <- class(data)[1]
all_backends <- utils::methods(fit_model)
# Drop fit_model.default
supported_backends <- all_backends[!(all_backends == "fit_model.default")]

cli::cli_abort(
c("Requested {.field backend} {.val {requested_backend}} not supported",
"!" = "Allowed backends: {.val {supported_backends}}"
),
class = "RtGam_invalid_input"
)
}
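
The backend selection relies on S3 dispatch: `dataset_creator()` prepends a class such as `RtGam_gam` to the data frame, and the generic picks the matching method or falls through to the default. A self-contained sketch of that mechanism is below. The names `tag_backend()` and `describe_backend()` are hypothetical, `paste0()` stands in for `glue::glue()`, and the methods return a label instead of fitting a model.

```r
# Tag a data frame with a backend-specific class.
tag_backend <- function(dat, backend) {
  class(dat) <- c(paste0("RtGam_", backend), class(dat))
  dat
}

# Hypothetical generic and methods; they only report which backend would run.
describe_backend <- function(data, ...) UseMethod("describe_backend")
describe_backend.RtGam_gam <- function(data, ...) "mgcv::gam"
describe_backend.RtGam_bam <- function(data, ...) "mgcv::bam"
describe_backend.default <- function(data, ...) {
  # The first class is the requested backend tag, so report it in the error
  stop("Backend class ", class(data)[1], " not supported", call. = FALSE)
}

dat <- data.frame(cases = 1:3, timestep = c(0, 0.5, 1))
gam_label <- describe_backend(tag_backend(dat, "gam"))
bam_label <- describe_backend(tag_backend(dat, "bam"))
bad_msg <- tryCatch(
  describe_backend(tag_backend(dat, "stan")),
  error = function(e) conditionMessage(e)
)
```

One consequence of this design is that an unsupported backend is only detected at dispatch time, which matches the commit "Move backend check from input val to S3 dispatch".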
34 changes: 31 additions & 3 deletions R/formula.R
@@ -1,4 +1,4 @@
#' Build formula for `mgcv::gam()`
#' Build formula for model fitting backend
#'
#' Build up the formula as a string and return a formula object meant for
#' use by [`mgcv::gam()`]. The formula components are built up based on the
@@ -30,15 +30,15 @@ formula_creator <- function(k, m, is_grouped) {
if (m > 1) {
# With adaptive basis, m refers to the order of the penalty matrix not the
# order of the smoothing penalty as it does in the other smoothing bases.
plus_global_trend <- glue::glue("+ s(timesteps,
plus_global_trend <- glue::glue("+ s(timestep,
k = {smooth_basis_dim[['global_trend']]},
m = {m},
bs = 'ad')") # nolint
} else {
# Adaptive penalty with `m = 1` is equivalent to a non-adaptive smooth but
# thin-plate performance is a bit better than p-spline, so preference to
# fall back to thin-plate.
plus_global_trend <- glue::glue("+ s(timesteps,
plus_global_trend <- glue::glue("+ s(timestep,
k = {smooth_basis_dim[['global_trend']]},
bs = 'tp')")
}
@@ -58,3 +58,31 @@ smooth_basis_creator <- function(k) {
"global_trend" = k
)
}
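
The string-then-formula approach in `formula_creator()` can be sketched in base R as follows. This is an illustration only: `formula_sketch()` is a hypothetical name, `sprintf()` stands in for `glue::glue()`, and the `cases ~ 1` left-hand side and intercept are assumptions, since that part of the real function is not visible in this diff. The smooth terms are never evaluated here, so `mgcv` does not need to be loaded.

```r
# Hypothetical base-R sketch of building the global-trend smooth as text.
formula_sketch <- function(k, m) {
  smooth <- if (m > 1) {
    # Adaptive smooth: `m` is the dimension of the penalty basis
    sprintf("s(timestep, k = %d, m = %d, bs = 'ad')", k, m)
  } else {
    # Fall back to a thin-plate spline when only one penalty is requested
    sprintf("s(timestep, k = %d, bs = 'tp')", k)
  }
  # Assumed outcome and intercept; the real formula may differ
  stats::as.formula(paste("cases ~ 1 +", smooth))
}

adaptive <- formula_sketch(k = 10L, m = 2L)
thin_plate <- formula_sketch(k = 10L, m = 1L)
```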

#' Issue warnings if parameterization allowed but suboptimal
#'
#' @noRd
warn_for_suboptimal_params <- function(data, m, k) {
n_unique_date <- length(unique(data[["timestep"]]))
total_dim <- nrow(data)

# From mgcv: "Bear in mind that adaptive smoothing places quite severe demands
# on the data. For example, setting ‘m=10’ for a univariate smooth of 200 data
# is rather like estimating 10 smoothing parameters, each from a data series
# of length 20. The problem is particularly serious for smooths of 2
# variables, where the number of smoothing parameters required to get
# reasonable flexibility in the penalty can grow rather fast, but it often
# requires a very large smoothing basis dimension to make good use of this
# flexibility. In short, adaptive smooths should be used sparingly and with
# care."
if (m / n_unique_date > 0.2) {
cli::cli_warn(
c("Using {m} penalty bases with {n_unique_date} dates supplied",
"Consider decreasing penalty dimension {.arg m}",
"i" = "See {.fn penalty_dim_heuristic} for guidance"
)
)
}

invisible()
}
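
The warning condition reduces to a ratio test: warn when the penalty dimension `m` is more than 20% of the number of unique dates. A small sketch with a hypothetical helper, `penalty_ratio_too_high()`, makes the threshold concrete.

```r
# Hypothetical helper: TRUE when `m` exceeds `threshold` times the number of
# unique dates. The 0.2 default matches the threshold used in the diff.
penalty_ratio_too_high <- function(m, n_unique_date, threshold = 0.2) {
  m / n_unique_date > threshold
}

low_ratio <- penalty_ratio_too_high(m = 2, n_unique_date = 20)  # ratio 0.10
high_ratio <- penalty_ratio_too_high(m = 5, n_unique_date = 20) # ratio 0.25
```

Here `low_ratio` is `FALSE` and `high_ratio` is `TRUE`, so only the second configuration would trigger the warning.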
19 changes: 16 additions & 3 deletions R/validate.R
@@ -14,8 +14,18 @@ validate <- function(cases,
validate_cases(cases, call)
validate_dates(reference_date, "reference_date", call)
validate_group(group, call)
validate_min_dimensionality(k, "k", min_dim = 3, call)
validate_min_dimensionality(m, "m", min_dim = 1, call)
validate_min_dimensionality(k,
arg = "k",
min_dim = 3,
max_val = length(cases),
call = call
)
validate_min_dimensionality(m,
arg = "m",
min_dim = 1,
max_val = length(unique(reference_date)),
call = call
)

# Per-group checks
check_vectors_equal_length(cases, reference_date, group, call)
@@ -52,12 +62,15 @@ validate_group <- function(group, call) {

#' Used by both dimensionality_heuristic() and RtGam()
#' @noRd
validate_min_dimensionality <- function(n, arg, min_dim, call) {
validate_min_dimensionality <- function(n, arg, min_dim, max_val = NA, call) {
check_vector(n, arg, call = call)
check_no_missingness(n, arg, call)
check_integer(n, arg, call)
check_elements_above_min(n, arg, min = min_dim, call = call)
check_vector_length(length(n), arg, min = 1, max = 1, call = call)
if (!rlang::is_na(max_val)) {
check_elements_below_max(n, arg, max_val, call)
}

invisible()
}
25 changes: 19 additions & 6 deletions man/RtGam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
