From ef4a05b1229d02546337bf705d516694f1975316 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Sun, 30 Jun 2024 20:46:22 +0000 Subject: [PATCH 01/24] Implement model fitting with `{mgcv}` Allow optional switching between gam() and bam(), both for functionality now and to illustrate how one might extend to different backends in the future. --- R/RtGam.R | 17 +++++++++--- R/checkers.R | 7 +++++ R/fit_model.R | 41 ++++++++++++++++++++++++++++ R/formula.R | 4 +-- R/prepare_inputs.R | 22 ++++++++++++++- man/RtGam.Rd | 13 +++++++-- man/fit_model.Rd | 11 ++++++++ man/prepare_inputs.Rd | 4 +-- man/validate.Rd | 4 +-- tests/testthat/test-checkers.R | 14 ++++++++++ tests/testthat/test-fit_model.R | 40 +++++++++++++++++++++++++++ tests/testthat/test-formula.R | 4 +-- tests/testthat/test-prepare_inputs.R | 14 +++++++++- 13 files changed, 178 insertions(+), 17 deletions(-) create mode 100644 R/fit_model.R create mode 100644 man/fit_model.Rd create mode 100644 tests/testthat/test-fit_model.R diff --git a/R/RtGam.R b/R/RtGam.R index caf940c..8f237fa 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -18,8 +18,8 @@ #' @param reference_date The associated date on which the count of incident #' `cases` occurred. Missing dates are not allowed and dates can only occur #' once. -#' @param group The grouping variable for the case/reference-date pair. Not -#' yet implemented and a value other than `NULL` will throw an error. +#' @param group The grouping variable for the case/reference-date pair. Not yet +#' implemented and a value other than `NULL` will throw an error. #' @param k An integer, the _total_ dimension of all the smoothing basis #' functions. Defaults to `smooth_dim_heuristic(length(cases))`, which picks a #' reasonable estimate based on the number of provided data points. This total @@ -32,6 +32,11 @@ #' time. An increase in this value above the default should be done carefully. #' See [penalty_dim_heuristic()] for more information on `m` and when to #' consider changing the default. +#' @param backend One of `gam` or `bam`; defaults to `gam`. In general, models +#' should be fit with [mgcv::gam()]. If [mgcv::gam()] is too slow, +#' [mgcv::bam()] converges more quickly but introduces some additional +#' numerical error. Note that the `bam` backend uses the `discrete = TRUE` +#' option for an additional speedup. See [mgcv::bam()] for more information. #' @seealso [smooth_dim_heuristic()] more information on the smoothing basis #' dimension and [mgcv::choose.k] for more general guidance on GAMs from #' `mgcv` @@ -45,13 +50,15 @@ RtGam <- function(cases, reference_date, group = NULL, k = smooth_dim_heuristic(length(cases)), - m = penalty_dim_heuristic(length(unique(reference_date)))) { + m = penalty_dim_heuristic(length(unique(reference_date))), + backend = "gam") { check_required_inputs_provided( cases, reference_date, group, k, - m + m, + backend ) validate(cases, reference_date, group, k, m) @@ -62,6 +69,8 @@ RtGam <- function(cases, is_grouped = !rlang::is_null(group) ) + fit <- fit_model(df, formula, backend) + invisible(NULL) } diff --git a/R/checkers.R b/R/checkers.R index fe4fe38..5f883a3 100644 --- a/R/checkers.R +++ b/R/checkers.R @@ -103,12 +103,19 @@ check_required_inputs_provided <- function(cases, group, k, m, + backend, call = rlang::caller_env()) { rlang::check_required(cases, "cases", call = call) rlang::check_required(reference_date, "reference_date", call = call) rlang::check_required(group, "group", call = call) rlang::check_required(k, "k", call = call) rlang::check_required(m, "m", call = call) + rlang::arg_match(backend, + values = c("gam", "bam"), + error_arg = "backend", + error_call = call, + multiple = FALSE + ) invisible() } diff --git a/R/fit_model.R b/R/fit_model.R new file mode 100644 index 0000000..f56a884 --- /dev/null +++ b/R/fit_model.R @@ -0,0 +1,41 @@ +#' Fit the RtGam model with {mgcv} +#' +#' Use the pre-prepared model dataset and formula. Supply warnings as needed +fit_model <- function(data, formula, backend) { + args <- args_constructor(data, formula, backend) + call <- call_constructor(backend) + + do.call( + call, + args + ) +} + +args_constructor <- function(data, formula, backend) { + backend_agnostic_args <- list( + formula = formula, + data = data, + # Negative binomial family with overdispersion param estimated + family = "nb" + ) + if (backend == "gam") { + backend_specific_args <- list( + method = "REML" + ) + } else if (backend == "bam") { + backend_specific_args <- list( + method = "fREML", + discrete = TRUE + ) + } else { + cli::cli_abort("Other backends not yet implemented") + } + + c(backend_agnostic_args, backend_specific_args) +} + +call_constructor <- function(backend) { + # This is where we could implement {brms} or mgcv::ginla() at some point + func <- paste0("mgcv::", backend) + eval(parse(text = func)) +} diff --git a/R/formula.R b/R/formula.R index bd81712..76129d3 100644 --- a/R/formula.R +++ b/R/formula.R @@ -30,7 +30,7 @@ formula_creator <- function(k, m, is_grouped) { if (m > 1) { # With adaptive basis, m refers to the order of the penalty matrix not the # order of the smoothing penalty as it does in the other smoothing bases. - plus_global_trend <- glue::glue("+ s(timesteps, + plus_global_trend <- glue::glue("+ s(timestep, k = {smooth_basis_dim[['global_trend']]}, m = {m}, bs = 'ad')") # nolint @@ -38,7 +38,7 @@ formula_creator <- function(k, m, is_grouped) { # Adaptive penalty with `m = 1` is equivalent to a non-adaptive smooth but # thin-plate performance is a bit better than p-spline, so preference to # fall back to thin-plate. - plus_global_trend <- glue::glue("+ s(timesteps, + plus_global_trend <- glue::glue("+ s(timestep, k = {smooth_basis_dim[['global_trend']]}, bs = 'tp')") } diff --git a/R/prepare_inputs.R b/R/prepare_inputs.R index aae4ca0..3863fd5 100644 --- a/R/prepare_inputs.R +++ b/R/prepare_inputs.R @@ -3,6 +3,8 @@ #' @inheritParams RtGam #' @return A dataframe for mgcv prepare_inputs <- function(cases, reference_date, group) { + cases_int <- integerify_cases(cases) + timestep <- dates_to_timesteps( reference_date, min_supplied_date = min(reference_date), @@ -14,13 +16,31 @@ prepare_inputs <- function(cases, reference_date, group) { } data.frame( - cases = cases, + cases = cases_int, timestep = timestep, reference_date = reference_date, group = group ) } +#' Convert dates to an integer if needed +#' +#' @param cases The user-supplied cases vector +#' @return cases_int Cases verified to be an int +#' @noRd +integerify_cases <- function(cases) { + if (!rlang::is_integer(cases)) { + cli::cli_warn(c( + "Coercing {.arg cases} to an integer vector", + "i" = "{.arg cases} is a {.obj_type_friendly {cases}}", + "x" = "RtGam uses a count model, requiring integer-valued cases" + )) + as.integer(cases) + } else { + cases + } +} + #' Convert an arbitrary vector of dates to a vector of timesteps #' #' The `*_supplied_date` arguments are required rather than calculated diff --git a/man/RtGam.Rd b/man/RtGam.Rd index 99269ea..44beb90 100644 --- a/man/RtGam.Rd +++ b/man/RtGam.Rd @@ -9,7 +9,8 @@ RtGam( reference_date, group = NULL, k = smooth_dim_heuristic(length(cases)), - m = penalty_dim_heuristic(length(unique(reference_date))) + m = penalty_dim_heuristic(length(unique(reference_date))), + backend = "gam" ) } \arguments{ @@ -20,8 +21,8 @@ associated \code{reference_date}. Missing values (NAs) are not allowed.} \code{cases} occurred. Missing dates are not allowed and dates can only occur once.} -\item{group}{The grouping variable for the case/reference-date pair. Not -yet implemented and a value other than \code{NULL} will throw an error.} +\item{group}{The grouping variable for the case/reference-date pair. Not yet +implemented and a value other than \code{NULL} will throw an error.} \item{k}{An integer, the \emph{total} dimension of all the smoothing basis functions. Defaults to \code{smooth_dim_heuristic(length(cases))}, which picks a @@ -36,6 +37,12 @@ trend. If \code{m} is greater than 1, the smooth's wiggliness can change over time. An increase in this value above the default should be done carefully. See \code{\link[=penalty_dim_heuristic]{penalty_dim_heuristic()}} for more information on \code{m} and when to consider changing the default.} + +\item{backend}{One of \code{gam} or \code{bam}; defaults to \code{gam}. In general, models +should be fit with \code{\link[mgcv:gam]{mgcv::gam()}}. If \code{\link[mgcv:gam]{mgcv::gam()}} is too slow, +\code{\link[mgcv:bam]{mgcv::bam()}} converges more quickly but introduces some additional +numerical error. Note that the \code{bam} backend uses the \code{discrete = TRUE} +option for an additional speedup. See \code{\link[mgcv:bam]{mgcv::bam()}} for more information.} } \value{ Stub function: NULL diff --git a/man/fit_model.Rd b/man/fit_model.Rd new file mode 100644 index 0000000..1eb2d9b --- /dev/null +++ b/man/fit_model.Rd @@ -0,0 +1,11 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_model.R +\name{fit_model} +\alias{fit_model} +\title{Fit the RtGam model with {mgcv}} +\usage{ +fit_model(data, formula, backend) +} +\description{ +Use the pre-prepared model dataset and formula. Supply warnings as needed +} diff --git a/man/prepare_inputs.Rd b/man/prepare_inputs.Rd index 5c9a342..ee6f34b 100644 --- a/man/prepare_inputs.Rd +++ b/man/prepare_inputs.Rd @@ -14,8 +14,8 @@ associated \code{reference_date}. Missing values (NAs) are not allowed.} \code{cases} occurred. Missing dates are not allowed and dates can only occur once.} -\item{group}{The grouping variable for the case/reference-date pair. Not -yet implemented and a value other than \code{NULL} will throw an error.} +\item{group}{The grouping variable for the case/reference-date pair. Not yet +implemented and a value other than \code{NULL} will throw an error.} } \value{ A dataframe for mgcv diff --git a/man/validate.Rd b/man/validate.Rd index e974784..bf066f8 100644 --- a/man/validate.Rd +++ b/man/validate.Rd @@ -14,8 +14,8 @@ associated \code{reference_date}. Missing values (NAs) are not allowed.} \code{cases} occurred. Missing dates are not allowed and dates can only occur once.} -\item{group}{The grouping variable for the case/reference-date pair. Not -yet implemented and a value other than \code{NULL} will throw an error.} +\item{group}{The grouping variable for the case/reference-date pair. Not yet +implemented and a value other than \code{NULL} will throw an error.} \item{k}{An integer, the \emph{total} dimension of all the smoothing basis functions. Defaults to \code{smooth_dim_heuristic(length(cases))}, which picks a diff --git a/tests/testthat/test-checkers.R b/tests/testthat/test-checkers.R index 253a08d..6e839f0 100644 --- a/tests/testthat/test-checkers.R +++ b/tests/testthat/test-checkers.R @@ -90,6 +90,7 @@ test_that("Required input check works", { k <- 2 m <- 1 + expect_error( check_required_inputs_provided( reference_date = reference_date, @@ -150,6 +151,18 @@ test_that("Required input check works", { ), class = "rlang_error" ) + expect_error( + check_required_inputs_provided( + cases = cases, + reference_date = reference_date, + group = group, + k = k, + m = m, + backend = "not_a_real_backend", + call = NULL + ), + class = "rlang_error" + ) expect_null( check_required_inputs_provided( cases = cases, @@ -157,6 +170,7 @@ test_that("Required input check works", { group = group, k = k, m = m, + backend = "gam", call = NULL ) ) diff --git a/tests/testthat/test-fit_model.R b/tests/testthat/test-fit_model.R new file mode 100644 index 0000000..f1c23c5 --- /dev/null +++ b/tests/testthat/test-fit_model.R @@ -0,0 +1,40 @@ +test_that("fit_model() fits a model", { + data <- data.frame(x = 1:20, y = rnbinom(20, mu = 1:20, size = 1)) + formula <- y ~ 1 + s(x) + + fit_gam <- fit_model(data, formula, backend = "gam") + fit_bam <- fit_model(data, formula, backend = "bam") + + expect_s3_class(fit_gam, "gam") + expect_s3_class(fit_bam, "bam") +}) + +test_that("arg_constructor returns fitting args", { + data <- data.frame(x = 1, y = 2) + formula <- y ~ x + backend <- "gam" + + args <- args_constructor(data, formula, backend) + + expect_type(args$formula, "language") + expect_s3_class(args$data, "data.frame") + expect_equal(args$family, "nb") + + # `mgcv::gam()` backend + expect_equal(args$method, "REML") + + # `mgcv::bam()` backend + args_bam <- args_constructor(data, formula, backend = "bam") + expect_equal(args_bam$method, "fREML") + + # Other backend + expect_error(args_constructor(data, formula, "not_a_backend")) +}) + +test_that("call_constructor returns a call", { + backend <- "gam" + + call <- call_constructor(backend) + expect_type(call, "closure") + expect_equal(call, mgcv::gam) +}) diff --git a/tests/testthat/test-formula.R b/tests/testthat/test-formula.R index ba511be..5cae9b0 100644 --- a/tests/testthat/test-formula.R +++ b/tests/testthat/test-formula.R @@ -2,7 +2,7 @@ test_that("Formula created more than 3 weeks", { k <- 10 m <- 2 is_grouped <- FALSE - expected <- "cases ~ 1 + s(timesteps, k = 10, m = 2, bs = \"ad\")" + expected <- "cases ~ 1 + s(timestep, k = 10, m = 2, bs = \"ad\")" f <- formula_creator(k, m, is_grouped) @@ -14,7 +14,7 @@ test_that("Formula created fewer than 3 weeks", { k <- 10 m <- 1 is_grouped <- FALSE - expected <- "cases ~ 1 + s(timesteps, k = 10, bs = \"tp\")" + expected <- "cases ~ 1 + s(timestep, k = 10, bs = \"tp\")" f <- formula_creator(k, m, is_grouped) diff --git a/tests/testthat/test-prepare_inputs.R b/tests/testthat/test-prepare_inputs.R index fbc6df9..65a436d 100644 --- a/tests/testthat/test-prepare_inputs.R +++ b/tests/testthat/test-prepare_inputs.R @@ -1,5 +1,5 @@ test_that("Dataframe constructed appropriately", { - cases <- c(1, 2, 3) + cases <- c(1L, 2L, 3L) reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03")) timestep <- c(0, 0.5, 1) @@ -38,3 +38,15 @@ test_that("Date conversion matches expected", { actual <- dates_to_timesteps(dates, min_date, max_date) expect_equal(actual, expected) }) + +test_that("Converts double vectors to integers with a warning", { + double_vec <- c(1, 2, 3) + integer_vec <- c(1L, 2L, 3L) + + expect_warning(actual <- integerify_cases(double_vec), + regexp = "Coercing" + ) + expect_equal(actual, integer_vec) + expect_no_message(integerify_cases(integer_vec)) + expect_equal(integerify_cases(integer_vec), integer_vec) +}) From 7247691bc6ad57dbf7b152ac37c0f6863a66c2ae Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Sun, 30 Jun 2024 20:53:43 +0000 Subject: [PATCH 02/24] Suppress public docs of internal function --- R/fit_model.R | 1 + man/fit_model.Rd | 11 ----------- 2 files changed, 1 insertion(+), 11 deletions(-) delete mode 100644 man/fit_model.Rd diff --git a/R/fit_model.R b/R/fit_model.R index f56a884..59e1d9e 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -1,6 +1,7 @@ #' Fit the RtGam model with {mgcv} #' #' Use the pre-prepared model dataset and formula. Supply warnings as needed +#' @noRd fit_model <- function(data, formula, backend) { args <- args_constructor(data, formula, backend) call <- call_constructor(backend) diff --git a/man/fit_model.Rd b/man/fit_model.Rd deleted file mode 100644 index 1eb2d9b..0000000 --- a/man/fit_model.Rd +++ /dev/null @@ -1,11 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/fit_model.R -\name{fit_model} -\alias{fit_model} -\title{Fit the RtGam model with {mgcv}} -\usage{ -fit_model(data, formula, backend) -} -\description{ -Use the pre-prepared model dataset and formula. Supply warnings as needed -} From 7151ed200237f28af399f325da77de96d256b14a Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Tue, 2 Jul 2024 14:08:44 -0400 Subject: [PATCH 03/24] Add checks and warnings for unwise inputs --- R/checkers.R | 17 +++++++++++++++++ R/formula.R | 19 +++++++++++++++++++ R/validate.R | 19 ++++++++++++++++--- tests/testthat/test-checkers.R | 14 ++++++++++++++ tests/testthat/test-validate.R | 20 +++++++++++++++++++- 5 files changed, 85 insertions(+), 4 deletions(-) diff --git a/R/checkers.R b/R/checkers.R index 5f883a3..5dbc3ad 100644 --- a/R/checkers.R +++ b/R/checkers.R @@ -135,6 +135,23 @@ check_no_missingness <- function(x, arg = "x", call = rlang::caller_env()) { } } +check_elements_below_max <- function(x, arg, max, call = rlang::caller_env()) { + # Greater than or equal to 0 or is NA + is_below_max <- all((x <= max) | is.na(x)) + if (!all(is_below_max)) { + cli::cli_abort( + c("{.arg {arg}} has elements larger than {.val {max}}", + "!" = "All elements must be {.val {max}} or less", + "i" = "Elements {.val {which(!is_below_max)}} are larger" + ), + class = "RtGam_invalid_input", + call = call + ) + } + invisible() +} + + check_elements_above_min <- function(x, arg, min, call = rlang::caller_env()) { # Greater than or equal to 0 or is NA is_above_min <- (x >= min) | is.na(x) diff --git a/R/formula.R b/R/formula.R index 76129d3..e29028c 100644 --- a/R/formula.R +++ b/R/formula.R @@ -58,3 +58,22 @@ smooth_basis_creator <- function(k) { "global_trend" = k ) } + +#' Issue warnings if parameterization allowed but suboptimal +#' @noRd +warn_for_suboptimal_params <- function(data, m, k) { + n_unique_date <- length(unique(data[["timepoint"]])) + total_dim <- nrow(data) + + if (m / n_unique_date > 0.2) { + cli::cli_warn( + c("Using {m} penalty bases with {n_unique_date} dates supplied", + "Consider decreasing penalty dimension {.arg m}", + "i" = "See {.func penalty_dim_heuristic()} for guidance" + ) + ) + } + + + invisible() +} diff --git a/R/validate.R b/R/validate.R index 5050ded..9dde9a1 100644 --- a/R/validate.R +++ b/R/validate.R @@ -14,8 +14,18 @@ validate <- function(cases, validate_cases(cases, call) validate_dates(reference_date, "reference_date", call) validate_group(group, call) - validate_min_dimensionality(k, "k", min_dim = 3, call) - validate_min_dimensionality(m, "m", min_dim = 1, call) + validate_min_dimensionality(k, + arg = "k", + min_dim = 3, + max_val = length(cases), + call + ) + validate_min_dimensionality(m, + arg = "m", + min_dim = 1, + max_val = length(unique(reference_date)), + call = call + ) # Per-group checks check_vectors_equal_length(cases, reference_date, group, call) @@ -52,12 +62,15 @@ validate_group <- function(group, call) { #' Used by both dimensionality_heuristic() and RtGam() #' @noRd -validate_min_dimensionality <- function(n, arg, min_dim, call) { +validate_min_dimensionality <- function(n, arg, min_dim, max_val = NA, call) { check_vector(n, arg, call = call) check_no_missingness(n, arg, call) check_integer(n, arg, call) check_elements_above_min(n, arg, min = min_dim, call = call) check_vector_length(length(n), arg, min = 1, max = 1, call = call) + if (!rlang::is_na(max_val)) { + check_elements_below_max(n, arg, max_val, call) + } invisible() } diff --git a/tests/testthat/test-checkers.R b/tests/testthat/test-checkers.R index 6e839f0..5542853 100644 --- a/tests/testthat/test-checkers.R +++ b/tests/testthat/test-checkers.R @@ -186,6 +186,20 @@ test_that("Missingness check works", { expect_null(check_no_missingness(no_missingness)) }) +test_that("Below max check works", { + max <- 5 + below_max <- 1:4 + above_max <- 2:6 + call <- NULL + arg <- "test" + + expect_error(check_elements_below_max(above_max, arg, max, call), + class = "RtGam_invalid_input" + ) + expect_null(check_elements_below_max(below_max, arg, max, call)) +}) + + test_that("Negative element check works", { min <- 0 non_neg <- c(0, 1, 2, NA) diff --git a/tests/testthat/test-validate.R b/tests/testthat/test-validate.R index 969caad..9e1140e 100644 --- a/tests/testthat/test-validate.R +++ b/tests/testthat/test-validate.R @@ -64,6 +64,7 @@ test_that("`validate_min_dimensionality()` is successful", { too_short <- c() good_input <- 3 min_dim <- 3 + max_val <- 0 expect_error(validate_min_dimensionality(not_a_vector, "test", min_dim), class = "RtGam_type_error" @@ -86,6 +87,23 @@ test_that("`validate_min_dimensionality()` is successful", { expect_error(validate_min_dimensionality(too_short, "test", min_dim), class = "RtGam_type_error" ) + expect_error( + validate_min_dimensionality(1:5, + "test", + min_dim, + max_val = 1 + ), + class = "RtGam_invalid_input" + ) - expect_null(validate_min_dimensionality(good_input, "test", min_dim)) + expect_null(validate_min_dimensionality( + good_input, + "test", + min_dim + )) + expect_null(validate_min_dimensionality(good_input, + "test", + min_dim, + max_val = 3 + )) }) From 72ff6c4c7b2521561b225df7d1dc2bfcdb796d66 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Sat, 6 Jul 2024 02:40:49 +0000 Subject: [PATCH 04/24] Refactor to S3 methods for fitting backends --- NAMESPACE | 2 + R/RtGam.R | 7 +-- R/{prepare_inputs.R => dataset_creator.R} | 7 ++- R/fit_model.R | 54 ++++++++----------- man/RtGam.Rd | 3 +- man/{prepare_inputs.Rd => dataset_creator.Rd} | 14 +++-- man/dates_to_timesteps.Rd | 2 +- tests/testthat/test-fit_model.R | 39 ++++---------- tests/testthat/test-prepare_inputs.R | 8 ++- 9 files changed, 60 insertions(+), 76 deletions(-) rename R/{prepare_inputs.R => dataset_creator.R} (93%) rename man/{prepare_inputs.Rd => dataset_creator.Rd} (52%) diff --git a/NAMESPACE b/NAMESPACE index b90038c..78f24fd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,7 @@ # Generated by roxygen2: do not edit by hand +S3method(fit_model,RtGam_bam) +S3method(fit_model,RtGam_gam) export(RtGam) export(penalty_dim_heuristic) export(smooth_dim_heuristic) diff --git a/R/RtGam.R b/R/RtGam.R index 8f237fa..6ce0400 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -51,7 +51,8 @@ RtGam <- function(cases, group = NULL, k = smooth_dim_heuristic(length(cases)), m = penalty_dim_heuristic(length(unique(reference_date))), - backend = "gam") { + backend = "gam", + user_supplied_args = list()) { check_required_inputs_provided( cases, reference_date, @@ -62,14 +63,14 @@ RtGam <- function(cases, ) validate(cases, reference_date, group, k, m) - df <- prepare_inputs(cases, reference_date, group) + df <- dataset_creator(cases, reference_date, group, backend) formula <- formula_creator( k = k, m = m, is_grouped = !rlang::is_null(group) ) - fit <- fit_model(df, formula, backend) + fit <- fit_model(df, formula, user_supplied_args) invisible(NULL) } diff --git a/R/prepare_inputs.R b/R/dataset_creator.R similarity index 93% rename from R/prepare_inputs.R rename to R/dataset_creator.R index 3863fd5..5bf8bc8 100644 --- a/R/prepare_inputs.R +++ b/R/dataset_creator.R @@ -2,7 +2,7 @@ #' #' @inheritParams RtGam #' @return A dataframe for mgcv -prepare_inputs <- function(cases, reference_date, group) { +dataset_creator <- function(cases, reference_date, group, backend) { cases_int <- integerify_cases(cases) timestep <- dates_to_timesteps( @@ -15,12 +15,15 @@ prepare_inputs <- function(cases, reference_date, group) { group <- rep(NA, length(cases)) } - data.frame( + dat <- data.frame( cases = cases_int, timestep = timestep, reference_date = reference_date, group = group ) + + class(dat) <- c(glue::glue("RtGam_{backend}"), class(dat)) + dat } #' Convert dates to an integer if needed diff --git a/R/fit_model.R b/R/fit_model.R index 59e1d9e..4f69265 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -1,42 +1,30 @@ -#' Fit the RtGam model with {mgcv} -#' -#' Use the pre-prepared model dataset and formula. Supply warnings as needed -#' @noRd -fit_model <- function(data, formula, backend) { - args <- args_constructor(data, formula, backend) - call <- call_constructor(backend) - - do.call( - call, - args - ) +fit_model <- function(data, formula, user_supplied_args) { + UseMethod("fit_model") } -args_constructor <- function(data, formula, backend) { - backend_agnostic_args <- list( +#' @export +fit_model.RtGam_gam <- function(data, formula, user_supplied_args) { + default_args <- list( formula = formula, data = data, - # Negative binomial family with overdispersion param estimated - family = "nb" + family = "nb", + method = "REML" ) - if (backend == "gam") { - backend_specific_args <- list( - method = "REML" - ) - } else if (backend == "bam") { - backend_specific_args <- list( - method = "fREML", - discrete = TRUE - ) - } else { - cli::cli_abort("Other backends not yet implemented") - } + args <- modifyList(default_args, user_supplied_args) - c(backend_agnostic_args, backend_specific_args) + do.call(mgcv::gam, args) } -call_constructor <- function(backend) { - # This is where we could implement {brms} or mgcv::ginla() at some point - func <- paste0("mgcv::", backend) - eval(parse(text = func)) +#' @export +fit_model.RtGam_bam <- function(data, formula, user_supplied_args) { + default_args <- list( + formula = formula, + data = data, + family = "nb", + method = "fREML", + discrete = TRUE + ) + args <- modifyList(default_args, user_supplied_args) + + do.call(mgcv::bam, args) } diff --git a/man/RtGam.Rd b/man/RtGam.Rd index 44beb90..ac6cd43 100644 --- a/man/RtGam.Rd +++ b/man/RtGam.Rd @@ -10,7 +10,8 @@ RtGam( group = NULL, k = smooth_dim_heuristic(length(cases)), m = penalty_dim_heuristic(length(unique(reference_date))), - backend = "gam" + backend = "gam", + user_supplied_args = list() ) } \arguments{ diff --git a/man/prepare_inputs.Rd b/man/dataset_creator.Rd similarity index 52% rename from man/prepare_inputs.Rd rename to man/dataset_creator.Rd index ee6f34b..40400d1 100644 --- a/man/prepare_inputs.Rd +++ b/man/dataset_creator.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/prepare_inputs.R -\name{prepare_inputs} -\alias{prepare_inputs} +% Please edit documentation in R/dataset_creator.R +\name{dataset_creator} +\alias{dataset_creator} \title{Parse input vectors into a format for \code{{mgcv}}} \usage{ -prepare_inputs(cases, reference_date, group) +dataset_creator(cases, reference_date, group, backend) } \arguments{ \item{cases}{A vector of non-negative incident case counts occurring on an @@ -16,6 +16,12 @@ once.} \item{group}{The grouping variable for the case/reference-date pair. Not yet implemented and a value other than \code{NULL} will throw an error.} + +\item{backend}{One of \code{gam} or \code{bam}; defaults to \code{gam}. In general, models +should be fit with \code{\link[mgcv:gam]{mgcv::gam()}}. If \code{\link[mgcv:gam]{mgcv::gam()}} is too slow, +\code{\link[mgcv:bam]{mgcv::bam()}} converges more quickly but introduces some additional +numerical error. Note that the \code{bam} backend uses the \code{discrete = TRUE} +option for an additional speedup. See \code{\link[mgcv:bam]{mgcv::bam()}} for more information.} } \value{ A dataframe for mgcv diff --git a/man/dates_to_timesteps.Rd b/man/dates_to_timesteps.Rd index 5760497..54e343b 100644 --- a/man/dates_to_timesteps.Rd +++ b/man/dates_to_timesteps.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/prepare_inputs.R +% Please edit documentation in R/dataset_creator.R \name{dates_to_timesteps} \alias{dates_to_timesteps} \title{Convert an arbitrary vector of dates to a vector of timesteps} diff --git a/tests/testthat/test-fit_model.R b/tests/testthat/test-fit_model.R index f1c23c5..f033b97 100644 --- a/tests/testthat/test-fit_model.R +++ b/tests/testthat/test-fit_model.R @@ -1,40 +1,19 @@ -test_that("fit_model() fits a model", { +test_that("fit_model.RtGam_gam fits a model", { data <- data.frame(x = 1:20, y = rnbinom(20, mu = 1:20, size = 1)) + class(data) <- c("RtGam_gam", class(data)) formula <- y ~ 1 + s(x) - fit_gam <- fit_model(data, formula, backend = "gam") - fit_bam <- fit_model(data, formula, backend = "bam") + fit_gam <- fit_model(data, formula, list()) expect_s3_class(fit_gam, "gam") - expect_s3_class(fit_bam, "bam") }) -test_that("arg_constructor returns fitting args", { - data <- data.frame(x = 1, y = 2) - formula <- y ~ x - backend <- "gam" - - args <- args_constructor(data, formula, backend) - - expect_type(args$formula, "language") - expect_s3_class(args$data, "data.frame") - expect_equal(args$family, "nb") - - # `mgcv::gam()` backend - expect_equal(args$method, "REML") - - # `mgcv::bam()` backend - args_bam <- args_constructor(data, formula, backend = "bam") - expect_equal(args_bam$method, "fREML") - - # Other backend - expect_error(args_constructor(data, formula, "not_a_backend")) -}) +test_that("fit_model.RtGam_bam fits a model", { + data <- data.frame(x = 1:20, y = rnbinom(20, mu = 1:20, size = 1)) + class(data) <- c("RtGam_bam", class(data)) + formula <- y ~ 1 + s(x) -test_that("call_constructor returns a call", { - backend <- "gam" + fit_gam <- fit_model(data, formula, list()) - call <- call_constructor(backend) - expect_type(call, "closure") - expect_equal(call, mgcv::gam) + expect_s3_class(fit_gam, "bam") }) diff --git a/tests/testthat/test-prepare_inputs.R b/tests/testthat/test-prepare_inputs.R index 65a436d..f6158a3 100644 --- a/tests/testthat/test-prepare_inputs.R +++ b/tests/testthat/test-prepare_inputs.R @@ -10,7 +10,9 @@ test_that("Dataframe constructed appropriately", { reference_date = reference_date, group = rep(NA, 3) ) - actual <- prepare_inputs(cases, reference_date, NULL) + class(expected) <- c("RtGam_gam", class(expected)) + + actual <- dataset_creator(cases, reference_date, NULL, "gam") expect_equal(actual, expected) # With groups @@ -21,7 +23,9 @@ test_that("Dataframe constructed appropriately", { reference_date = reference_date, group = group ) - actual <- prepare_inputs(cases, reference_date, group) + class(expected) <- c("RtGam_gam", class(expected)) + + actual <- dataset_creator(cases, reference_date, group, backend = "gam") expect_equal(actual, expected) }) From 6c909e1aee2d491f010da046d7204348b5b859de Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Sat, 6 Jul 2024 03:00:38 +0000 Subject: [PATCH 05/24] Add doc for missing param --- R/RtGam.R | 7 +++++-- man/RtGam.Rd | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/R/RtGam.R b/R/RtGam.R index 6ce0400..cf28481 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -37,9 +37,12 @@ #' [mgcv::bam()] converges more quickly but introduces some additional #' numerical error. Note that the `bam` backend uses the `discrete = TRUE` #' option for an additional speedup. See [mgcv::bam()] for more information. +#' @param user_supplied_args A list of custom arguments to pass to the model +#' fitting backend to override package defaults. #' @seealso [smooth_dim_heuristic()] more information on the smoothing basis -#' dimension and [mgcv::choose.k] for more general guidance on GAMs from -#' `mgcv` +#' dimension, [mgcv::choose.k] for more general guidance on GAMs from `mgcv`, +#' and [mgcv::gam]/[mgcv::bam] for documentation on arguments to the model +#' fitting functions. #' @return Stub function: NULL #' @export #' @examples diff --git a/man/RtGam.Rd b/man/RtGam.Rd index ac6cd43..d9a547a 100644 --- a/man/RtGam.Rd +++ b/man/RtGam.Rd @@ -44,6 +44,9 @@ should be fit with \code{\link[mgcv:gam]{mgcv::gam()}}. If \code{\link[mgcv:gam] \code{\link[mgcv:bam]{mgcv::bam()}} converges more quickly but introduces some additional numerical error. Note that the \code{bam} backend uses the \code{discrete = TRUE} option for an additional speedup. See \code{\link[mgcv:bam]{mgcv::bam()}} for more information.} + +\item{user_supplied_args}{A list of custom arguments to pass to the model +fitting backend to override package defaults.} } \value{ Stub function: NULL @@ -69,6 +72,7 @@ mod <- RtGam::RtGam(cases, reference_date) } \seealso{ \code{\link[=smooth_dim_heuristic]{smooth_dim_heuristic()}} more information on the smoothing basis -dimension and \link[mgcv:choose.k]{mgcv::choose.k} for more general guidance on GAMs from -\code{mgcv} +dimension, \link[mgcv:choose.k]{mgcv::choose.k} for more general guidance on GAMs from \code{mgcv}, +and \link[mgcv:gam]{mgcv::gam}/\link[mgcv:bam]{mgcv::bam} for documentation on arguments to the model +fitting functions. } From 5e4de380603bede09d7194aa4b0b23dea8df7526 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Sat, 6 Jul 2024 03:02:06 +0000 Subject: [PATCH 06/24] Explicitly namespace `modifyList()` --- R/fit_model.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/fit_model.R b/R/fit_model.R index 4f69265..07cc49c 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -10,7 +10,7 @@ fit_model.RtGam_gam <- function(data, formula, user_supplied_args) { family = "nb", method = "REML" ) - args <- modifyList(default_args, user_supplied_args) + args <- utils::modifyList(default_args, user_supplied_args) do.call(mgcv::gam, args) } @@ -24,7 +24,7 @@ fit_model.RtGam_bam <- function(data, formula, user_supplied_args) { method = "fREML", discrete = TRUE ) - args <- modifyList(default_args, user_supplied_args) + args <- utils::modifyList(default_args, user_supplied_args) do.call(mgcv::bam, args) } From 31f4196dbf3c6dcfc799ea734ab1e64d646da0a7 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Sat, 6 Jul 2024 03:16:09 +0000 Subject: [PATCH 07/24] Clarify documentation --- R/RtGam.R | 2 ++ R/formula.R | 12 +++++++++++- man/penalty_dim_heuristic.Rd | 2 ++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/R/RtGam.R b/R/RtGam.R index cf28481..20fe945 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -214,6 +214,8 @@ smooth_dim_heuristic <- function(n) { #' ## Very slow #' #' Decreasing the penalty basis dimension makes the model less demanding to fit. +#' `mgcv` describes an adaptive penalty with 10 basis dimensions and 200 data +#' points as roughly equivalent to fitting 10 GAMs each from 20 data points. #' Using a single penalty throughout the model is much simpler than using an #' adaptive smooth and should be preferred where possible. See #' `[mgcv::smooth.construct.ad.smooth.spec]` for more information on how the diff --git a/R/formula.R b/R/formula.R index e29028c..05e2ac6 100644 --- a/R/formula.R +++ b/R/formula.R @@ -1,4 +1,4 @@ -#' Build formula for `mgcv::gam()` +#' Build formula for model fitting backend #' #' Build up the formula as a string and and return a formula object meant for #' use by [`mgcv::gam()`]. The formula components are built up based on the @@ -60,11 +60,21 @@ smooth_basis_creator <- function(k) { } #' Issue warnings if parameterization allowed but suboptimal +#' #' @noRd warn_for_suboptimal_params <- function(data, m, k) { n_unique_date <- length(unique(data[["timepoint"]])) total_dim <- nrow(data) + # From mgcv: "Bear in mind that adaptive smoothing places quite severe demands + # on the data. For example, setting ‘m=10’ for a univariate smooth of 200 data + # is rather like estimating 10 smoothing parameters, each from a data series + # of length 20. The problem is particularly serious for smooths of 2 + # variables, where the number of smoothing parameters required to get + # reasonable flexibility in the penalty can grow rather fast, but it often + # requires a very large smoothing basis dimension to make good use of this + # flexibility. In short, adaptive smooths should be used sparingly and with + # care." if (m / n_unique_date > 0.2) { cli::cli_warn( c("Using {m} penalty bases with {n_unique_date} dates supplied", diff --git a/man/penalty_dim_heuristic.Rd b/man/penalty_dim_heuristic.Rd index 4739bb0..04903b2 100644 --- a/man/penalty_dim_heuristic.Rd +++ b/man/penalty_dim_heuristic.Rd @@ -40,6 +40,8 @@ without oversmoothing or introducing spurious wiggly trends. \subsection{Very slow}{ Decreasing the penalty basis dimension makes the model less demanding to fit. +\code{mgcv} describes an adaptive penalty with 10 basis dimensions and 200 data +points as roughly equivalent to fitting 10 GAMs each from 20 data points. Using a single penalty throughout the model is much simpler than using an adaptive smooth and should be preferred where possible. See \verb{[mgcv::smooth.construct.ad.smooth.spec]} for more information on how the From becc4a3d49bb89543015a57c7282febb0631f446 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Sat, 6 Jul 2024 03:23:15 +0000 Subject: [PATCH 08/24] Test warnings throw for suboptimal params --- tests/testthat/test-formula.R | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/testthat/test-formula.R b/tests/testthat/test-formula.R index 5cae9b0..27dc3bf 100644 --- a/tests/testthat/test-formula.R +++ b/tests/testthat/test-formula.R @@ -32,3 +32,20 @@ test_that("Smooth basis dim created successfully", { expect_equal(c("global_trend"), names(smooth_basis_dim)) expect_equal(15, smooth_basis_dim[["global_trend"]]) }) + + +test_that("`warn_for_suboptimal_params()` throws warning", { + dat <- data.frame( + cases = 1:10, + reference_date = seq.Date( + from = as.Date("2023-12-1"), + length.out = 10, + by = "day" + ), + timepoint = 1:10 / 10 + ) + k <- 1 + + expect_warning(warn_for_suboptimal_params(dat, k, m = 10)) + expect_null(warn_for_suboptimal_params(dat, k, m = 1)) +}) From 652ff5b29c58aa441e9877ac03c3b0c28034861e Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Wed, 10 Jul 2024 17:48:55 +0000 Subject: [PATCH 09/24] Default args in S3 methods w/ user-supplied in ... It makes documentation slightly more straightforward --- R/RtGam.R | 10 +++++---- R/fit_model.R | 37 +++++++++++++++++++-------------- man/RtGam.Rd | 7 ++++--- tests/testthat/test-fit_model.R | 15 +++++++++++-- 4 files changed, 44 insertions(+), 25 deletions(-) diff --git a/R/RtGam.R b/R/RtGam.R index 20fe945..f811da4 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -37,8 +37,10 @@ #' [mgcv::bam()] converges more quickly but introduces some additional #' numerical error. Note that the `bam` backend uses the `discrete = TRUE` #' option for an additional speedup. See [mgcv::bam()] for more information. -#' @param user_supplied_args A list of custom arguments to pass to the model -#' fitting backend to override package defaults. +#' @param ... Additional arguments passed to the specified modelling backend. +#' For example, the default negative binomial error structure could be changed +#' to poisson in the default [mgcv::gam] backend by passing `family = +#' "poisson"`. #' @seealso [smooth_dim_heuristic()] more information on the smoothing basis #' dimension, [mgcv::choose.k] for more general guidance on GAMs from `mgcv`, #' and [mgcv::gam]/[mgcv::bam] for documentation on arguments to the model @@ -55,7 +57,7 @@ RtGam <- function(cases, k = smooth_dim_heuristic(length(cases)), m = penalty_dim_heuristic(length(unique(reference_date))), backend = "gam", - user_supplied_args = list()) { + ...) { check_required_inputs_provided( cases, reference_date, @@ -73,7 +75,7 @@ RtGam <- function(cases, is_grouped = !rlang::is_null(group) ) - fit <- fit_model(df, formula, user_supplied_args) + fit <- fit_model(df, formula, ...) invisible(NULL) } diff --git a/R/fit_model.R b/R/fit_model.R index 07cc49c..041ac0a 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -1,30 +1,35 @@ -fit_model <- function(data, formula, user_supplied_args) { +fit_model <- function(data, formula, ...) { UseMethod("fit_model") } #' @export -fit_model.RtGam_gam <- function(data, formula, user_supplied_args) { - default_args <- list( - formula = formula, - data = data, +fit_model.RtGam_gam <- function( + data, + formula, family = "nb", - method = "REML" - ) - args <- utils::modifyList(default_args, user_supplied_args) + method = "REML", + ...) { + # Override the defaults in formals with the user-supplied args in dots + formal_arg_names <- Filter(function(x) x != "...", names(formals())) + formals <- as.list(environment())[formal_arg_names] + dots <- rlang::list2(...) + args <- utils::modifyList(formals, dots) do.call(mgcv::gam, args) } #' @export -fit_model.RtGam_bam <- function(data, formula, user_supplied_args) { - default_args <- list( - formula = formula, - data = data, +fit_model.RtGam_bam <- function( + data, + formula, family = "nb", method = "fREML", - discrete = TRUE - ) - args <- utils::modifyList(default_args, user_supplied_args) - + discrete = TRUE, + ...) { + # Override the defaults in formals with the user-supplied args in dots + formal_arg_names <- Filter(function(x) x != "...", names(formals())) + formals <- as.list(environment())[formal_arg_names] + dots <- rlang::list2(...) + args <- utils::modifyList(formals, dots) do.call(mgcv::bam, args) } diff --git a/man/RtGam.Rd b/man/RtGam.Rd index d9a547a..6ad7ec6 100644 --- a/man/RtGam.Rd +++ b/man/RtGam.Rd @@ -11,7 +11,7 @@ RtGam( k = smooth_dim_heuristic(length(cases)), m = penalty_dim_heuristic(length(unique(reference_date))), backend = "gam", - user_supplied_args = list() + ... ) } \arguments{ @@ -45,8 +45,9 @@ should be fit with \code{\link[mgcv:gam]{mgcv::gam()}}. If \code{\link[mgcv:gam] numerical error. Note that the \code{bam} backend uses the \code{discrete = TRUE} option for an additional speedup. See \code{\link[mgcv:bam]{mgcv::bam()}} for more information.} -\item{user_supplied_args}{A list of custom arguments to pass to the model -fitting backend to override package defaults.} +\item{...}{Additional arguments passed to the specified modelling backend. +For example, the default negative binomial error structure could be changed +to poisson in the default \link[mgcv:gam]{mgcv::gam} backend by passing \code{family = "poisson"}.} } \value{ Stub function: NULL diff --git a/tests/testthat/test-fit_model.R b/tests/testthat/test-fit_model.R index f033b97..fca8962 100644 --- a/tests/testthat/test-fit_model.R +++ b/tests/testthat/test-fit_model.R @@ -3,7 +3,7 @@ test_that("fit_model.RtGam_gam fits a model", { class(data) <- c("RtGam_gam", class(data)) formula <- y ~ 1 + s(x) - fit_gam <- fit_model(data, formula, list()) + fit_gam <- fit_model(data, formula) expect_s3_class(fit_gam, "gam") }) @@ -13,7 +13,18 @@ test_that("fit_model.RtGam_bam fits a model", { class(data) <- c("RtGam_bam", class(data)) formula <- y ~ 1 + s(x) - fit_gam <- fit_model(data, formula, list()) + fit_gam <- fit_model(data, formula) expect_s3_class(fit_gam, "bam") }) + +test_that("... passes modified arg to `fit_model()`", { + data <- data.frame(x = 1:20, y = rnbinom(20, mu = 1:20, size = 1)) + class(data) <- c("RtGam_gam", class(data)) + formula <- y ~ 1 + s(x) + + fit_gam <- fit_model(data, formula, family = "poisson") + + expect_s3_class(fit_gam, "gam") + expect_true(grepl("poisson", fit_gam$family$family)) +}) From 17f48d47ba089e7728ede4031d80ec6de24385e9 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Wed, 10 Jul 2024 18:22:56 +0000 Subject: [PATCH 10/24] Move backend check from input val to S3 dispatch Allowing future extensions. --- NAMESPACE | 1 + R/checkers.R | 6 ------ R/fit_model.R | 17 +++++++++++++++++ man/fit_model.default.Rd | 11 +++++++++++ tests/testthat/test-RtGam.R | 11 +++++++++++ tests/testthat/test-checkers.R | 12 ------------ tests/testthat/test-fit_model.R | 10 ++++++++++ 7 files changed, 50 insertions(+), 18 deletions(-) create mode 100644 man/fit_model.default.Rd diff --git a/NAMESPACE b/NAMESPACE index 78f24fd..78e05a6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(fit_model,RtGam_bam) S3method(fit_model,RtGam_gam) +S3method(fit_model,default) export(RtGam) export(penalty_dim_heuristic) export(smooth_dim_heuristic) diff --git a/R/checkers.R b/R/checkers.R index 5dbc3ad..174c81b 100644 --- a/R/checkers.R +++ b/R/checkers.R @@ -110,12 +110,6 @@ check_required_inputs_provided <- function(cases, rlang::check_required(group, "group", call = call) rlang::check_required(k, "k", call = call) rlang::check_required(m, "m", call = call) - rlang::arg_match(backend, - values = c("gam", "bam"), - error_arg = "backend", - error_call = call, - multiple = FALSE - ) invisible() } diff --git a/R/fit_model.R b/R/fit_model.R index 041ac0a..99153e0 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -33,3 +33,20 @@ fit_model.RtGam_bam <- function( args <- utils::modifyList(formals, dots) do.call(mgcv::bam, args) } + +#' Used to throw informative error if non-supported backend supplied +#' @export +fit_model.default <- function( + data, + formula, + ...) { + requested_backend <- class(data)[1] + supported_backends <- c("gam", "bam") + + cli::cli_abort( + c("Requested {.field backend} {.val {requested_backend}} not supported", + "!" = "Allowed backends: {.val {supported_backends}}" + ), + class = "RtGam_invalid_input" + ) +} diff --git a/man/fit_model.default.Rd b/man/fit_model.default.Rd new file mode 100644 index 0000000..fbc3926 --- /dev/null +++ b/man/fit_model.default.Rd @@ -0,0 +1,11 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_model.R +\name{fit_model.default} +\alias{fit_model.default} +\title{Used to throw informative error if non-supported backend supplied} +\usage{ +\method{fit_model}{default}(data, formula, ...) +} +\description{ +Used to throw informative error if non-supported backend supplied +} diff --git a/tests/testthat/test-RtGam.R b/tests/testthat/test-RtGam.R index 292ea08..e0191c8 100644 --- a/tests/testthat/test-RtGam.R +++ b/tests/testthat/test-RtGam.R @@ -6,3 +6,14 @@ test_that("RtGam parses inputs successfully", { expect_null(RtGam(cases, dates, group)) }) + +test_that("Non-supported backends throw error", { + timesteps <- 20 + cases <- 1:timesteps + dates <- as.Date("2023-01-01") + 1:timesteps + group <- NULL + + expect_error(RtGam(cases, dates, group, backend = "test"), + class = "RtGam_invalid_input" + ) +}) diff --git a/tests/testthat/test-checkers.R b/tests/testthat/test-checkers.R index 5542853..21bdc68 100644 --- a/tests/testthat/test-checkers.R +++ b/tests/testthat/test-checkers.R @@ -151,18 +151,6 @@ test_that("Required input check works", { ), class = "rlang_error" ) - expect_error( - check_required_inputs_provided( - cases = cases, - reference_date = reference_date, - group = group, - k = k, - m = m, - backend = "not_a_real_backend", - call = NULL - ), - class = "rlang_error" - ) expect_null( check_required_inputs_provided( cases = cases, diff --git a/tests/testthat/test-fit_model.R b/tests/testthat/test-fit_model.R index fca8962..9f704e2 100644 --- a/tests/testthat/test-fit_model.R +++ b/tests/testthat/test-fit_model.R @@ -28,3 +28,13 @@ test_that("... passes modified arg to `fit_model()`", { expect_s3_class(fit_gam, "gam") expect_true(grepl("poisson", fit_gam$family$family)) }) + +test_that("Unsupported backend errors", { + data <- data.frame(x = 1:20, y = rnbinom(20, mu = 1:20, size = 1)) + class(data) <- c("RtGam_test", class(data)) + formula <- y ~ 1 + s(x) + + expect_error(fit_model(data, formula), + class = "RtGam_invalid_input" + ) +}) From 0ca92eb9c4b9118b0bd69e933bd9dadb7f966acc Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Mon, 26 Aug 2024 22:57:15 +0000 Subject: [PATCH 11/24] Whitespace --- R/formula.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/formula.R b/R/formula.R index 05e2ac6..b46557b 100644 --- a/R/formula.R +++ b/R/formula.R @@ -84,6 +84,5 @@ warn_for_suboptimal_params <- function(data, m, k) { ) } - invisible() } From 60365e01ecd3d6dc4c01279994beda549c19570a Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Mon, 26 Aug 2024 23:39:23 +0000 Subject: [PATCH 12/24] Move do.call() outside of fit_model() As suggested by @seabbs. This allows ... to be evaluated and handles dispatch. --- R/RtGam.R | 11 +++++++++-- R/fit_model.R | 27 +++++++++++++++------------ man/RtGam.Rd | 2 +- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/R/RtGam.R b/R/RtGam.R index f811da4..e7f7cdd 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -48,7 +48,7 @@ #' @return Stub function: NULL #' @export #' @examples -#' cases <- c(1, 2, 3) +#' cases <- c(1L, 2L, 3L) #' reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03")) #' mod <- RtGam::RtGam(cases, reference_date) RtGam <- function(cases, @@ -75,7 +75,14 @@ RtGam <- function(cases, is_grouped = !rlang::is_null(group) ) - fit <- fit_model(df, formula, ...) + fit <- do.call( + fit_model, + list( + data = df, + formula = formula, + ... + ) + ) invisible(NULL) } diff --git a/R/fit_model.R b/R/fit_model.R index 99153e0..cfc1381 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -10,12 +10,13 @@ fit_model.RtGam_gam <- function( method = "REML", ...) { # Override the defaults in formals with the user-supplied args in dots - formal_arg_names <- Filter(function(x) x != "...", names(formals())) - formals <- as.list(environment())[formal_arg_names] - dots <- rlang::list2(...) - args <- utils::modifyList(formals, dots) - - do.call(mgcv::gam, args) + mgcv::gam( + formula = formula, + family = family, + data = data, + method = method, + ... + ) } #' @export @@ -26,12 +27,14 @@ fit_model.RtGam_bam <- function( method = "fREML", discrete = TRUE, ...) { - # Override the defaults in formals with the user-supplied args in dots - formal_arg_names <- Filter(function(x) x != "...", names(formals())) - formals <- as.list(environment())[formal_arg_names] - dots <- rlang::list2(...) - args <- utils::modifyList(formals, dots) - do.call(mgcv::bam, args) + mgcv::bam( + formula = formula, + fmaily = family, + data = data, + method = method, + discrete = discrete, + ... + ) } #' Used to throw informative error if non-supported backend supplied diff --git a/man/RtGam.Rd b/man/RtGam.Rd index 6ad7ec6..a6b1f68 100644 --- a/man/RtGam.Rd +++ b/man/RtGam.Rd @@ -67,7 +67,7 @@ function of time. } \examples{ -cases <- c(1, 2, 3) +cases <- c(1L, 2L, 3L) reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03")) mod <- RtGam::RtGam(cases, reference_date) } From 58deaf059ece3320665a8c51252e3299defa26fb Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Tue, 27 Aug 2024 21:00:49 +0000 Subject: [PATCH 13/24] Dynamically find methods for `fit_model()` --- R/fit_model.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/fit_model.R b/R/fit_model.R index cfc1381..77eec2a 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -44,7 +44,9 @@ fit_model.default <- function( formula, ...) { requested_backend <- class(data)[1] - supported_backends <- c("gam", "bam") + all_backends <- methods(fit_model) + # Drop fit_model.default + supported_backends <- all_backends[!(all_backends == "fit_model.default")] cli::cli_abort( c("Requested {.field backend} {.val {requested_backend}} not supported", From 705f22e31a6c37f18dbdb53f2d43b6b0d89d7479 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Tue, 27 Aug 2024 01:54:05 +0000 Subject: [PATCH 14/24] Minimal working print method + RtGam() return Has basic info but still needs model fit diagnostics --- NAMESPACE | 1 + R/RtGam.R | 10 ++++++- R/fit_model.R | 4 +-- R/print.R | 64 ++++++++++++++++++++++++++++++++++++++++ man/format_for_return.Rd | 19 ++++++++++++ man/print.RtGam.Rd | 17 +++++++++++ 6 files changed, 112 insertions(+), 3 deletions(-) create mode 100644 R/print.R create mode 100644 man/format_for_return.Rd create mode 100644 man/print.RtGam.Rd diff --git a/NAMESPACE b/NAMESPACE index 78e05a6..006d3c2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -3,6 +3,7 @@ S3method(fit_model,RtGam_bam) S3method(fit_model,RtGam_gam) S3method(fit_model,default) +S3method(print,RtGam) export(RtGam) export(penalty_dim_heuristic) export(smooth_dim_heuristic) diff --git a/R/RtGam.R b/R/RtGam.R index e7f7cdd..7ee0b3d 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -84,7 +84,15 @@ RtGam <- function(cases, ) ) - invisible(NULL) + format_for_return( + fit = fit, + df = df, + group = group, + k = k, + m = m, + backend = backend, + formula = formula + ) } #' Propose total smoothing basis dimension from number of data points diff --git a/R/fit_model.R b/R/fit_model.R index 77eec2a..dfb8c22 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -29,7 +29,7 @@ fit_model.RtGam_bam <- function( ...) { mgcv::bam( formula = formula, - fmaily = family, + family = family, data = data, method = method, discrete = discrete, @@ -50,7 +50,7 @@ fit_model.default <- function( cli::cli_abort( c("Requested {.field backend} {.val {requested_backend}} not supported", - "!" = "Allowed backends: {.val {supported_backends}}" + "!" = "Supported backends: {.val {supported_backends}}" ), class = "RtGam_invalid_input" ) diff --git a/R/print.R b/R/print.R new file mode 100644 index 0000000..5740f87 --- /dev/null +++ b/R/print.R @@ -0,0 +1,64 @@ +#' Format the RtGam object for return from the main function/constructor +#' +#' @param fit The model fit created by [fit_model] +#' @param df The dataset created by [dataset_creator] +#' +#' @return An object of type RtGam +format_for_return <- function(fit, df, group, k, m, backend, formula) { + formatted <- list( + model = fit, + data = df, + min_date = min(df[["reference_date"]]), + max_data = max(df[["reference_date"]]), + k = k, + m = m, + backend = backend, + formula = formula + ) + + structure(formatted, class = "RtGam") +} + +#' Print an RtGam object +#' +#' @param x Fitted model object of class RtGam +#' +#' @return The RtGam object, invisibly +#' @export +print.RtGam <- function(x, ...) { + cat("===============================\n") + cat("Fitted RtGam model object (") + cat(x$backend) + cat(")\n") + cat("===============================\n\n") + + cat("Model type: ") + if (x$m > 1) { + cat("Adaptive (m = ") + } else { + cat("Non-adaptive (m = ") + } + cat(x$m) + cat(")\n") + + cat("Total smoothing basis dimension: ") + cat(x$k) + + # TODO estimated edf & diagnostics + + cat("\n===============================\n") + cat("\nObserved data points: ") + cat(nrow(x$data)) + + cat("\nDistinct reference dates: ") + cat(length(unique(x[["data"]][["reference_date"]]))) + + cat("\nDistinct groups: ") + if (rlang::is_null(x[["data"]][["group"]][[1]])) { + cat("1") + } else { + cat(length(unique(x[["data"]][["group"]]))) + } + cat("\n\n") + invisible(x) +} diff --git a/man/format_for_return.Rd b/man/format_for_return.Rd new file mode 100644 index 0000000..d7498b8 --- /dev/null +++ b/man/format_for_return.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/print.R +\name{format_for_return} +\alias{format_for_return} +\title{Format the RtGam object for return from the main function/constructor} +\usage{ +format_for_return(fit, df, group, k, m, backend, formula) +} +\arguments{ +\item{fit}{The model fit created by \link{fit_model}} + +\item{df}{The dataset created by \link{dataset_creator}} +} +\value{ +An object of type RtGam +} +\description{ +Format the RtGam object for return from the main function/constructor +} diff --git a/man/print.RtGam.Rd b/man/print.RtGam.Rd new file mode 100644 index 0000000..b2edbac --- /dev/null +++ b/man/print.RtGam.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/print.R +\name{print.RtGam} +\alias{print.RtGam} +\title{Print an RtGam object} +\usage{ +\method{print}{RtGam}(x, ...) +} +\arguments{ +\item{x}{Fitted model object of class RtGam} +} +\value{ +The RtGam object, invisibly +} +\description{ +Print an RtGam object +} From ca7084306c2dbefffda820c9c28502de75e3ff95 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Tue, 27 Aug 2024 20:55:14 +0000 Subject: [PATCH 15/24] Add some basic diagnostic checks --- NAMESPACE | 1 + R/RtGam.R | 7 +++- R/diagnostics.R | 82 ++++++++++++++++++++++++++++++++++++++++ man/check_diagnostics.Rd | 32 ++++++++++++++++ 4 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 R/diagnostics.R create mode 100644 man/check_diagnostics.Rd diff --git a/NAMESPACE b/NAMESPACE index 006d3c2..a3f1825 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,6 +5,7 @@ S3method(fit_model,RtGam_gam) S3method(fit_model,default) S3method(print,RtGam) export(RtGam) +export(check_diagnostics) export(penalty_dim_heuristic) export(smooth_dim_heuristic) importFrom(rlang,abort) diff --git a/R/RtGam.R b/R/RtGam.R index 7ee0b3d..6f43d7f 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -83,6 +83,10 @@ RtGam <- function(cases, ... ) ) + diagnostics <- calculate_diagnostics(fit) + if (warn_for_diagnostic_failure) { + issue_diagnostic_warnings(diagnostics) + } format_for_return( fit = fit, @@ -91,7 +95,8 @@ RtGam <- function(cases, k = k, m = m, backend = backend, - formula = formula + formula = formula, + diagnostics = diagnostics ) } diff --git a/R/diagnostics.R b/R/diagnostics.R new file mode 100644 index 0000000..c2db387 --- /dev/null +++ b/R/diagnostics.R @@ -0,0 +1,82 @@ +#' Check diagnostics from a fitted model +#' +#' @param fit A fitted `RtGam` model object +#' @param warn_for_diagnostic_failure A bool, whether to issue warnings for +#' potential diagnostic issues. +#' +#' @return Invisibly, a list of diagnostics. This diagnostic list is also +#' present in the model object under `diagnostics`. +#' @export +#' +#' @examples +#' fit <- RtGam::RtGam( +#' cases = c(1L, 2L, 3L, 4L), +#' reference_date = as.Date(c( +#' "2023-01-01", +#' "2023-01-02", +#' "2023-01-03" +#' )) +#' ) +#' check_diagnostics(fit) +check_diagnostics <- function(fit, warn_for_diagnostic_failure = TRUE) { + diagnostics <- calculate_diagnostics(fit[["model"]]) + if (warn_for_diagnostic_failure) { + issue_diagnostic_warnings(diagnostics) + } + invisible(diagnostics) +} + +calculate_diagnostics <- function(fit) { + converged <- fit$converged + k_check <- mgcv::k.check(fit) + max_lag <- min(7, round(nrow(fit$model) / 7)) + rho <- acf(fit$residuals, plot = FALSE, lag.max = max_lag)[[1]][, , 1] + + list( + model_converged = converged, + k_edf = k_check[2], + k_index = k_check[3], + k_p_value = k_check[4], + k_to_edf_ratio = k_check[2] / k_check[1], + residual_autocorrelation = rho[2:length(rho)] + ) +} + +issue_diagnostic_warnings <- function(diagnostics) { + if (!diagnostics[["model_converged"]]) { + cli::cli_alert_danger( + c("Model failed to converge. Inference is not reliable.") + ) + } + if (diagnostics[["k_to_edf_ratio"]] > 0.9) { + cli::cli_bullets(c( + "x" = "Effective degrees of freedom is near the supplied upper bound", + "!" = "Consider increasing {.arg k}", + "*" = "Actual: {.val {round(diagnostics[['k_edf']], 3)}}", + "*" = "Upper bound: {.val {round(diagnostics[['k\\'']], 3)}}" + )) + } + if (diagnostics[["k_p_value"]] < 0.05) { + cli::cli_bullets( + c( + "!" = "k-index for one or more smooths is below 1", + "*" = "k-index: {.val {round(diagnostics[['k_index']], 3)}}", + "*" = "Associated p-value: {.val {round(diagnostics[['k_p_value']], + 2)}}", + "!" = "Suggests potential unmodeled residual trend. + Inspect model and/or consider increasing {.arg k}" + ) + ) + } + if (any(abs(diagnostics[["residual_autocorrelation"]]) > 0.5)) { + cli::cli_bullets(c( + "x" = "Residual autocorrelation present", + "*" = "Rho: {.val {round(diagnostics[['residual_autocorrelation']], + 2)}}", + "*" = "Inspect manually with {.code acf(residuals(fit$model))}", + "!" = "Consider increasing {.arg k} and/or + specifying {.arg rho} with {.arg backend} bam" + )) + } + invisible(NULL) +} diff --git a/man/check_diagnostics.Rd b/man/check_diagnostics.Rd new file mode 100644 index 0000000..34e67c3 --- /dev/null +++ b/man/check_diagnostics.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/diagnostics.R +\name{check_diagnostics} +\alias{check_diagnostics} +\title{Check diagnostics from a fitted model} +\usage{ +check_diagnostics(fit, warn_for_diagnostic_failure = TRUE) +} +\arguments{ +\item{fit}{A fitted \code{RtGam} model object} + +\item{warn_for_diagnostic_failure}{A bool, whether to issue warnings for +potential diagnostic issues.} +} +\value{ +Invisibly, a list of diagnostics. This diagnostic list is also +present in the model object under \code{diagnostics}. +} +\description{ +Check diagnostics from a fitted model +} +\examples{ +fit <- RtGam::RtGam( + cases = c(1L, 2L, 3L, 4L), + reference_date = as.Date(c( + "2023-01-01", + "2023-01-02", + "2023-01-03" + )) +) +check_diagnostics(fit) +} From 45625aa3c3b54f2e178d8f731db02ece44688007 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Wed, 28 Aug 2024 13:41:57 +0000 Subject: [PATCH 16/24] Clean up existing tests --- DESCRIPTION | 3 ++- R/RtGam.R | 18 ++++++++++++++--- R/diagnostics.R | 17 ++++++++-------- R/fit_model.R | 1 + R/print.R | 39 ++++++++++++++++++++++++++----------- man/RtGam.Rd | 19 +++++++++++++++--- man/check_diagnostics.Rd | 15 +++++++------- man/fit_model.default.Rd | 11 ----------- man/format_for_return.Rd | 19 ------------------ man/print.RtGam.Rd | 15 ++++++++++++++ tests/testthat/test-RtGam.R | 8 ++++++-- 11 files changed, 100 insertions(+), 65 deletions(-) delete mode 100644 man/fit_model.default.Rd delete mode 100644 man/format_for_return.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 62bb2d9..e4516ad 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -33,4 +33,5 @@ Imports: cli, glue, mgcv, - rlang + rlang, + withr diff --git a/R/RtGam.R b/R/RtGam.R index 6f43d7f..36fdfbd 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -37,6 +37,10 @@ #' [mgcv::bam()] converges more quickly but introduces some additional #' numerical error. Note that the `bam` backend uses the `discrete = TRUE` #' option for an additional speedup. See [mgcv::bam()] for more information. +#' @param warn_for_diagnostic_failure Should warnings be issued for +#' automatically identified diagnostic issues? Defaults to true. A list of +#' quantitative model diagnostics can be inspected in the `diagnostics` slot +#' of the returned `RtGam` object. #' @param ... Additional arguments passed to the specified modelling backend. #' For example, the default negative binomial error structure could be changed #' to poisson in the default [mgcv::gam] backend by passing `family = @@ -48,15 +52,23 @@ #' @return Stub function: NULL #' @export #' @examples -#' cases <- c(1L, 2L, 3L) -#' reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03")) -#' mod <- RtGam::RtGam(cases, reference_date) +#' withr::with_seed(12345, { +#' cases <- rpois(20, 10) +#' }) +#' reference_date <- seq.Date( +#' from = as.Date("2023-01-01"), +#' length.out = 20, +#' by = "day" +#' ) +#' fit <- RtGam::RtGam(cases, reference_date) +#' fit RtGam <- function(cases, reference_date, group = NULL, k = smooth_dim_heuristic(length(cases)), m = penalty_dim_heuristic(length(unique(reference_date))), backend = "gam", + warn_for_diagnostic_failure = TRUE, ...) { check_required_inputs_provided( cases, diff --git a/R/diagnostics.R b/R/diagnostics.R index c2db387..2b778a3 100644 --- a/R/diagnostics.R +++ b/R/diagnostics.R @@ -9,14 +9,15 @@ #' @export #' #' @examples -#' fit <- RtGam::RtGam( -#' cases = c(1L, 2L, 3L, 4L), -#' reference_date = as.Date(c( -#' "2023-01-01", -#' "2023-01-02", -#' "2023-01-03" -#' )) +#' withr::with_seed(12345, { +#' cases <- rpois(20, 10) +#' }) +#' reference_date <- seq.Date( +#' from = as.Date("2023-01-01"), +#' length.out = 20, +#' by = "day" #' ) +#' fit <- RtGam::RtGam(cases, reference_date) #' check_diagnostics(fit) check_diagnostics <- function(fit, warn_for_diagnostic_failure = TRUE) { diagnostics <- calculate_diagnostics(fit[["model"]]) @@ -30,7 +31,7 @@ calculate_diagnostics <- function(fit) { converged <- fit$converged k_check <- mgcv::k.check(fit) max_lag <- min(7, round(nrow(fit$model) / 7)) - rho <- acf(fit$residuals, plot = FALSE, lag.max = max_lag)[[1]][, , 1] + rho <- stats::acf(fit$residuals, plot = FALSE, lag.max = max_lag)[[1]][, , 1] list( model_converged = converged, diff --git a/R/fit_model.R b/R/fit_model.R index dfb8c22..c636577 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -39,6 +39,7 @@ fit_model.RtGam_bam <- function( #' Used to throw informative error if non-supported backend supplied #' @export +#' @noRd fit_model.default <- function( data, formula, diff --git a/R/print.R b/R/print.R index 5740f87..96af5ea 100644 --- a/R/print.R +++ b/R/print.R @@ -1,10 +1,13 @@ #' Format the RtGam object for return from the main function/constructor -#' -#' @param fit The model fit created by [fit_model] -#' @param df The dataset created by [dataset_creator] -#' -#' @return An object of type RtGam -format_for_return <- function(fit, df, group, k, m, backend, formula) { +#' @noRd +format_for_return <- function(fit, + df, + group, + k, + m, + backend, + formula, + diagnostics) { formatted <- list( model = fit, data = df, @@ -13,7 +16,8 @@ format_for_return <- function(fit, df, group, k, m, backend, formula) { k = k, m = m, backend = backend, - formula = formula + formula = formula, + diagnostics ) structure(formatted, class = "RtGam") @@ -22,9 +26,22 @@ format_for_return <- function(fit, df, group, k, m, backend, formula) { #' Print an RtGam object #' #' @param x Fitted model object of class RtGam +#' @param ... further arguments to be passed to or from other methods. They are +#' ignored in this function. #' #' @return The RtGam object, invisibly #' @export +#' @examples +#' withr::with_seed(12345, { +#' cases <- rpois(20, 10) +#' }) +#' reference_date <- seq.Date( +#' from = as.Date("2023-01-01"), +#' length.out = 20, +#' by = "day" +#' ) +#' fit <- RtGam::RtGam(cases, reference_date) +#' print(fit) print.RtGam <- function(x, ...) { cat("===============================\n") cat("Fitted RtGam model object (") @@ -38,11 +55,11 @@ print.RtGam <- function(x, ...) { } else { cat("Non-adaptive (m = ") } - cat(x$m) - cat(")\n") + cat(x$m, ")\n") + cat("Specified maximum smoothing basis dimension: ", x$k, "\n") + cat("Family:", x$model$family$family, "\n") + cat("Link function:", x$model$family$link) - cat("Total smoothing basis dimension: ") - cat(x$k) # TODO estimated edf & diagnostics diff --git a/man/RtGam.Rd b/man/RtGam.Rd index a6b1f68..b7a0a18 100644 --- a/man/RtGam.Rd +++ b/man/RtGam.Rd @@ -11,6 +11,7 @@ RtGam( k = smooth_dim_heuristic(length(cases)), m = penalty_dim_heuristic(length(unique(reference_date))), backend = "gam", + warn_for_diagnostic_failure = TRUE, ... ) } @@ -45,6 +46,11 @@ should be fit with \code{\link[mgcv:gam]{mgcv::gam()}}. If \code{\link[mgcv:gam] numerical error. Note that the \code{bam} backend uses the \code{discrete = TRUE} option for an additional speedup. See \code{\link[mgcv:bam]{mgcv::bam()}} for more information.} +\item{warn_for_diagnostic_failure}{Should warnings be issued for +automatically identified diagnostic issues? Defaults to true. A list of +quantitative model diagnostics can be inspected in the \code{diagnostics} slot +of the returned \code{RtGam} object.} + \item{...}{Additional arguments passed to the specified modelling backend. For example, the default negative binomial error structure could be changed to poisson in the default \link[mgcv:gam]{mgcv::gam} backend by passing \code{family = "poisson"}.} @@ -67,9 +73,16 @@ function of time. } \examples{ -cases <- c(1L, 2L, 3L) -reference_date <- as.Date(c("2023-01-01", "2023-01-02", "2023-01-03")) -mod <- RtGam::RtGam(cases, reference_date) +withr::with_seed(12345, { + cases <- rpois(20, 10) +}) +reference_date <- seq.Date( + from = as.Date("2023-01-01"), + length.out = 20, + by = "day" +) +fit <- RtGam::RtGam(cases, reference_date) +fit } \seealso{ \code{\link[=smooth_dim_heuristic]{smooth_dim_heuristic()}} more information on the smoothing basis diff --git a/man/check_diagnostics.Rd b/man/check_diagnostics.Rd index 34e67c3..6acc521 100644 --- a/man/check_diagnostics.Rd +++ b/man/check_diagnostics.Rd @@ -20,13 +20,14 @@ present in the model object under \code{diagnostics}. Check diagnostics from a fitted model } \examples{ -fit <- RtGam::RtGam( - cases = c(1L, 2L, 3L, 4L), - reference_date = as.Date(c( - "2023-01-01", - "2023-01-02", - "2023-01-03" - )) +withr::with_seed(12345, { + cases <- rpois(20, 10) +}) +reference_date <- seq.Date( + from = as.Date("2023-01-01"), + length.out = 20, + by = "day" ) +fit <- RtGam::RtGam(cases, reference_date) check_diagnostics(fit) } diff --git a/man/fit_model.default.Rd b/man/fit_model.default.Rd deleted file mode 100644 index fbc3926..0000000 --- a/man/fit_model.default.Rd +++ /dev/null @@ -1,11 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/fit_model.R -\name{fit_model.default} -\alias{fit_model.default} -\title{Used to throw informative error if non-supported backend supplied} -\usage{ -\method{fit_model}{default}(data, formula, ...) -} -\description{ -Used to throw informative error if non-supported backend supplied -} diff --git a/man/format_for_return.Rd b/man/format_for_return.Rd deleted file mode 100644 index d7498b8..0000000 --- a/man/format_for_return.Rd +++ /dev/null @@ -1,19 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/print.R -\name{format_for_return} -\alias{format_for_return} -\title{Format the RtGam object for return from the main function/constructor} -\usage{ -format_for_return(fit, df, group, k, m, backend, formula) -} -\arguments{ -\item{fit}{The model fit created by \link{fit_model}} - -\item{df}{The dataset created by \link{dataset_creator}} -} -\value{ -An object of type RtGam -} -\description{ -Format the RtGam object for return from the main function/constructor -} diff --git a/man/print.RtGam.Rd b/man/print.RtGam.Rd index b2edbac..d28fbad 100644 --- a/man/print.RtGam.Rd +++ b/man/print.RtGam.Rd @@ -8,6 +8,9 @@ } \arguments{ \item{x}{Fitted model object of class RtGam} + +\item{...}{further arguments to be passed to or from other methods. They are +ignored in this function.} } \value{ The RtGam object, invisibly @@ -15,3 +18,15 @@ The RtGam object, invisibly \description{ Print an RtGam object } +\examples{ +withr::with_seed(12345, { + cases <- rpois(20, 10) +}) +reference_date <- seq.Date( + from = as.Date("2023-01-01"), + length.out = 20, + by = "day" +) +fit <- RtGam::RtGam(cases, reference_date) +print(fit) +} diff --git a/tests/testthat/test-RtGam.R b/tests/testthat/test-RtGam.R index e0191c8..112d722 100644 --- a/tests/testthat/test-RtGam.R +++ b/tests/testthat/test-RtGam.R @@ -1,10 +1,14 @@ test_that("RtGam parses inputs successfully", { timesteps <- 20 - cases <- 1:timesteps + withr::with_seed(12345, { + cases <- rpois(20, 10) + }) dates <- as.Date("2023-01-01") + 1:timesteps group <- NULL - expect_null(RtGam(cases, dates, group)) + fit <- RtGam(cases, dates, group) + + expect_s3_class(fit, "RtGam") }) test_that("Non-supported backends throw error", { From c8f61579ed31ba13cce685ac2612f23a911e99ee Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Wed, 28 Aug 2024 12:09:58 -0400 Subject: [PATCH 17/24] Tests for print and diagnostics --- .pre-commit-config.yaml | 2 + R/diagnostics.R | 3 +- R/print.R | 18 ++-- tests/testthat/_snaps/print.md | 60 +++++++++++ tests/testthat/test-RtGam.R | 12 +++ tests/testthat/test-diagnostics.R | 168 ++++++++++++++++++++++++++++++ tests/testthat/test-print.R | 47 +++++++++ 7 files changed, 302 insertions(+), 8 deletions(-) create mode 100644 tests/testthat/_snaps/print.md create mode 100644 tests/testthat/test-diagnostics.R create mode 100644 tests/testthat/test-print.R diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 237d474..e0db795 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,7 @@ repos: - id: parsable-R - id: no-browser-statement - id: no-print-statement + exclude: '^tests/testthat/test-print\.R$' - id: no-debug-statement - id: deps-in-desc - repo: https://github.com/pre-commit/pre-commit-hooks @@ -25,6 +26,7 @@ repos: files: '^\.Rbuildignore$' - id: end-of-file-fixer exclude: '\.Rd' + exclude: 'tests/testthat/_snaps/' - repo: https://github.com/pre-commit-ci/pre-commit-ci-config rev: v1.6.1 hooks: diff --git a/R/diagnostics.R b/R/diagnostics.R index 2b778a3..84d7f9c 100644 --- a/R/diagnostics.R +++ b/R/diagnostics.R @@ -35,6 +35,7 @@ calculate_diagnostics <- function(fit) { list( model_converged = converged, + k_prime = k_check[1], k_edf = k_check[2], k_index = k_check[3], k_p_value = k_check[4], @@ -54,7 +55,7 @@ issue_diagnostic_warnings <- function(diagnostics) { "x" = "Effective degrees of freedom is near the supplied upper bound", "!" = "Consider increasing {.arg k}", "*" = "Actual: {.val {round(diagnostics[['k_edf']], 3)}}", - "*" = "Upper bound: {.val {round(diagnostics[['k\\'']], 3)}}" + "*" = "Upper bound: {.val {diagnostics[['k_prime']]}}" )) } if (diagnostics[["k_p_value"]] < 0.05) { diff --git a/R/print.R b/R/print.R index 96af5ea..b3f0e18 100644 --- a/R/print.R +++ b/R/print.R @@ -17,7 +17,7 @@ format_for_return <- function(fit, m = m, backend = backend, formula = formula, - diagnostics + diagnostics = diagnostics ) structure(formatted, class = "RtGam") @@ -43,27 +43,31 @@ format_for_return <- function(fit, #' fit <- RtGam::RtGam(cases, reference_date) #' print(fit) print.RtGam <- function(x, ...) { + # Header cat("===============================\n") cat("Fitted RtGam model object (") cat(x$backend) cat(")\n") cat("===============================\n\n") + # Adaptive cat("Model type: ") if (x$m > 1) { cat("Adaptive (m = ") } else { cat("Non-adaptive (m = ") } - cat(x$m, ")\n") - cat("Specified maximum smoothing basis dimension: ", x$k, "\n") + cat(x$m) + cat(")\n") + + # Smoothing basis + cat("Specified maximum smoothing basis dimension: ") + cat(x$k, "\n") cat("Family:", x$model$family$family, "\n") cat("Link function:", x$model$family$link) - - # TODO estimated edf & diagnostics - cat("\n===============================\n") + # Data cat("\nObserved data points: ") cat(nrow(x$data)) @@ -71,7 +75,7 @@ print.RtGam <- function(x, ...) { cat(length(unique(x[["data"]][["reference_date"]]))) cat("\nDistinct groups: ") - if (rlang::is_null(x[["data"]][["group"]][[1]])) { + if (rlang::is_na(x[["data"]][["group"]][[1]])) { cat("1") } else { cat(length(unique(x[["data"]][["group"]]))) diff --git a/tests/testthat/_snaps/print.md b/tests/testthat/_snaps/print.md new file mode 100644 index 0000000..f9b3f86 --- /dev/null +++ b/tests/testthat/_snaps/print.md @@ -0,0 +1,60 @@ +# print method produces output for single group + + Code + print(mock_RtGam) + Output + =============================== + Fitted RtGam model object (MockBackend) + =============================== + + Model type: Adaptive (m = 2) + Specified maximum smoothing basis dimension: 5 + Family: poisson + Link function: log + =============================== + + Observed data points: 10 + Distinct reference dates: 10 + Distinct groups: 1 + + +# print method produces output for multiple groups + + Code + print(mock_RtGam) + Output + =============================== + Fitted RtGam model object (MockBackend) + =============================== + + Model type: Adaptive (m = 2) + Specified maximum smoothing basis dimension: 5 + Family: poisson + Link function: log + =============================== + + Observed data points: 10 + Distinct reference dates: 10 + Distinct groups: 2 + + +# print method produces output for non-adaptive + + Code + print(mock_RtGam) + Output + =============================== + Fitted RtGam model object (MockBackend) + =============================== + + Model type: Non-adaptive (m = 1) + Specified maximum smoothing basis dimension: 5 + Family: poisson + Link function: log + =============================== + + Observed data points: 10 + Distinct reference dates: 10 + Distinct groups: 2 + + diff --git a/tests/testthat/test-RtGam.R b/tests/testthat/test-RtGam.R index 112d722..9ca18e2 100644 --- a/tests/testthat/test-RtGam.R +++ b/tests/testthat/test-RtGam.R @@ -7,8 +7,20 @@ test_that("RtGam parses inputs successfully", { group <- NULL fit <- RtGam(cases, dates, group) + expected_slots <- c( + "model", + "data", + "min_date", + "max_data", + "k", + "m", + "backend", + "formula", + "diagnostics" + ) expect_s3_class(fit, "RtGam") + expect_equal(names(fit), expected_slots) }) test_that("Non-supported backends throw error", { diff --git a/tests/testthat/test-diagnostics.R b/tests/testthat/test-diagnostics.R new file mode 100644 index 0000000..0d209e7 --- /dev/null +++ b/tests/testthat/test-diagnostics.R @@ -0,0 +1,168 @@ +test_that("check_diagnostics() runs cleanly on happy path", { + withr::with_seed(12345, { + df <- data.frame( + x = 1:20, + y = rnorm(20, 1:20) + ) + model <- mgcv::gam( + formula = as.formula("y ~ 1 + s(x)"), + data = df + ) + }) + + fit <- list(model = model) + expected_diagnostics <- c( + "model_converged", + "k_prime", + "k_edf", + "k_index", + "k_p_value", + "k_to_edf_ratio", + "residual_autocorrelation" + ) + + expect_equal(names(check_diagnostics(fit)), expected_diagnostics) + expect_invisible(check_diagnostics(fit)) + expect_no_message(check_diagnostics(fit)) +}) + +test_that("check_diagnostics() runs throws warnings for a bad fit", { + fit <- RtGam::RtGam( + cases = c(1L, 2L, 3L), + reference_date = as.Date(c("2023-01-01", "2023-01-02", "2023-01-03")), + warn_for_diagnostic_failure = FALSE + ) + expected_diagnostics <- c( + "model_converged", + "k_prime", + "k_edf", + "k_index", + "k_p_value", + "k_to_edf_ratio", + "residual_autocorrelation" + ) + + suppressMessages(expect_condition( + check_diagnostics(fit, warn_for_diagnostic_failure = TRUE), + regexp = "Residual autocorrelation present" + )) +}) + +test_that("check_diagnostics() can be silenced", { + fit <- RtGam::RtGam( + cases = c(1L, 2L, 3L), + reference_date = as.Date(c("2023-01-01", "2023-01-02", "2023-01-03")), + warn_for_diagnostic_failure = FALSE + ) + expected_diagnostics <- c( + "model_converged", + "k_prime", + "k_edf", + "k_index", + "k_p_value", + "k_to_edf_ratio", + "residual_autocorrelation" + ) + + expect_no_condition( + check_diagnostics(fit, warn_for_diagnostic_failure = FALSE) + ) +}) + +test_that("calculate_diagnostics returns expected diagnostics ", { + withr::with_seed(12345, { + df <- data.frame( + x = 1:20, + y = rnorm(20, 1:20) + ) + fit <- mgcv::gam( + formula = as.formula("y ~ 1 + s(x)"), + data = df + ) + }) + + diagnostics <- calculate_diagnostics(fit) + expected_diagnostics <- c( + "model_converged", + "k_prime", + "k_edf", + "k_index", + "k_p_value", + "k_to_edf_ratio", + "residual_autocorrelation" + ) + + expect_true(inherits(diagnostics, "list")) + expect_equal(names(diagnostics), expected_diagnostics) + expect_true(diagnostics[["model_converged"]]) + expect_true(diagnostics[["k_index"]] > 1) + expect_true(diagnostics[["k_p_value"]] > 0.05) + expect_true(all(abs(diagnostics[["residual_autocorrelation"]]) < 0.5)) +}) + +test_that("Warnings are not issued for clean diagnostics", { + diagnostics <- list( + model_converged = TRUE, + k_edf = 2, + k_index = 1.2, + k_p_value = 0.81, + k_to_edf_ratio = 0.1, + residual_autocorrelation = c(-0.2, 0.0) + ) + + expect_no_condition(issue_diagnostic_warnings(diagnostics)) +}) + +test_that("Warnings are issued for diagnostic failures", { + no_convergence <- list( + model_converged = FALSE, + k_edf = 2, + k_index = 1.2, + k_p_value = 0.81, + k_to_edf_ratio = 0.1, + residual_autocorrelation = c(-0.2, 0.0) + ) + k_to_edf_ratio_high <- list( + model_converged = TRUE, + k_prime = 9, + k_edf = 2.0, + k_index = 1.2, + k_p_value = 0.81, + k_to_edf_ratio = 1.0, + residual_autocorrelation = c(-0.2, 0.0) + ) + k_p_value_low <- list( + model_converged = FALSE, + k_edf = 2, + k_index = 1.2, + k_p_value = 0.01, + k_to_edf_ratio = 0.1, + residual_autocorrelation = c(-0.2, 0.0) + ) + has_residual_autocorrelation <- list( + model_converged = FALSE, + k_edf = 2, + k_index = 1.2, + k_p_value = 0.81, + k_to_edf_ratio = 0.1, + residual_autocorrelation = c(0.8, 0.0) + ) + + expect_condition( + issue_diagnostic_warnings(no_convergence), + regexp = "Model failed to converge. Inference is not reliable." + ) + # Wrapping in suppressMessages to catch all of multi-line warnings + suppressMessages(expect_condition( + issue_diagnostic_warnings(k_to_edf_ratio_high), + regexp = "Effective degrees of freedom is near the supplied upper bound" + )) + suppressMessages(expect_condition( + issue_diagnostic_warnings(k_p_value_low), + regexp = "k-index for one or more smooths is below 1" + )) + suppressMessages(expect_condition( + issue_diagnostic_warnings(has_residual_autocorrelation), + regexp = "Residual autocorrelation present" + )) +}) diff --git a/tests/testthat/test-print.R b/tests/testthat/test-print.R new file mode 100644 index 0000000..c935718 --- /dev/null +++ b/tests/testthat/test-print.R @@ -0,0 +1,47 @@ +test_that("print method produces output for single group", { + mock_RtGam <- list( + backend = "MockBackend", + m = 2, + k = 5, + model = list(family = list(family = "poisson", link = "log")), + data = data.frame( + reference_date = as.Date("2020-01-01") + 1:10, + group = NA + ) + ) + class(mock_RtGam) <- "RtGam" + + expect_snapshot(print(mock_RtGam)) +}) + +test_that("print method produces output for multiple groups", { + mock_RtGam <- list( + backend = "MockBackend", + m = 2, + k = 5, + model = list(family = list(family = "poisson", link = "log")), + data = data.frame( + reference_date = as.Date("2020-01-01") + 1:10, + group = c(rep("a", 5), rep("b", 5)) + ) + ) + class(mock_RtGam) <- "RtGam" + + expect_snapshot(print(mock_RtGam)) +}) + +test_that("print method produces output for non-adaptive", { + mock_RtGam <- list( + backend = "MockBackend", + m = 1, + k = 5, + model = list(family = list(family = "poisson", link = "log")), + data = data.frame( + reference_date = as.Date("2020-01-01") + 1:10, + group = c(rep("a", 5), rep("b", 5)) + ) + ) + class(mock_RtGam) <- "RtGam" + + expect_snapshot(print(mock_RtGam)) +}) From dbfaf52ecd6438a5b2f044890127e77203b49a67 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Wed, 28 Aug 2024 13:49:05 -0400 Subject: [PATCH 18/24] Document `check_diagnostics()` --- R/diagnostics.R | 38 +++++++++++++++++++++++++++++++------- man/check_diagnostics.Rd | 40 +++++++++++++++++++++++++++++++++------- 2 files changed, 64 insertions(+), 14 deletions(-) diff --git a/R/diagnostics.R b/R/diagnostics.R index 84d7f9c..da5cd57 100644 --- a/R/diagnostics.R +++ b/R/diagnostics.R @@ -1,13 +1,37 @@ -#' Check diagnostics from a fitted model +#' Check quantitative diagnostics from a fitted RtGam model #' -#' @param fit A fitted `RtGam` model object -#' @param warn_for_diagnostic_failure A bool, whether to issue warnings for -#' potential diagnostic issues. +#' Evaluates for convergence, effective degrees of freedom, and residual +#' autocorrelation. If `warn_for_diagnostic_failure` is set to TRUE, will issue +#' warnings when potential diagnostic issues are detected. The diagnostics are +#' invisibly returned as a list and also stored within the `diagnostics` element +#' of the provided model object. #' -#' @return Invisibly, a list of diagnostics. This diagnostic list is also -#' present in the model object under `diagnostics`. -#' @export +#' @param fit A fitted `RtGam` model object. This should be the result of +#' calling `RtGam::RtGam()` with appropriate data. +#' @param warn_for_diagnostic_failure A logical value indicating whether to +#' issue warnings if diagnostic checks suggest potential issues with the model +#' fit. Defaults to TRUE, meaning that warnings will be issued by default. +#' +#' @return Invisibly returns a list containing diagnostic results: +#' - `model_converged`: Logical indicating if the model has converged. +#' - `k_prime`: The maximum available number of degrees of freedom that could +#' be used in the GAM fit. +#' - `k_edf`: Estimated degrees of freedom actually used by the smooth terms +#' in the model. +#' - `k_index`: The ratio of the residual variance of differenced +#' near-neighbor residuals to the overall residual variance. This should be +#' near 1 or above. +#' - `k_p_value`: P-value for testing if k' is adequate for modeling the data. +#' - `k_to_edf_ratio`: Ratio of k' to effective degrees of freedom of the +#' smooth terms. k' should be well below the available edf. +#' - `residual_autocorrelation`: Autocorrelation coefficients for residuals +#' up to lag 7 or one-tenth of series length, whichever is smaller. #' +#' @export +#' @seealso [mgcv::k.check] for a description of the diagnostic tests, +#' [mgcv::choose.k] for a description of discussion of choosing the basis +#' dimension, and Wood, Simon N. Generalized additive models: an introduction +#' with R. chapman and hall/CRC, 2017. for a derivation of the metrics. #' @examples #' withr::with_seed(12345, { #' cases <- rpois(20, 10) diff --git a/man/check_diagnostics.Rd b/man/check_diagnostics.Rd index 6acc521..9b08fd1 100644 --- a/man/check_diagnostics.Rd +++ b/man/check_diagnostics.Rd @@ -2,22 +2,42 @@ % Please edit documentation in R/diagnostics.R \name{check_diagnostics} \alias{check_diagnostics} -\title{Check diagnostics from a fitted model} +\title{Check quantitative diagnostics from a fitted RtGam model} \usage{ check_diagnostics(fit, warn_for_diagnostic_failure = TRUE) } \arguments{ -\item{fit}{A fitted \code{RtGam} model object} +\item{fit}{A fitted \code{RtGam} model object. This should be the result of +calling \code{RtGam::RtGam()} with appropriate data.} -\item{warn_for_diagnostic_failure}{A bool, whether to issue warnings for -potential diagnostic issues.} +\item{warn_for_diagnostic_failure}{A logical value indicating whether to +issue warnings if diagnostic checks suggest potential issues with the model +fit. Defaults to TRUE, meaning that warnings will be issued by default.} } \value{ -Invisibly, a list of diagnostics. This diagnostic list is also -present in the model object under \code{diagnostics}. +Invisibly returns a list containing diagnostic results: +\itemize{ +\item \code{model_converged}: Logical indicating if the model has converged. +\item \code{k_prime}: The maximum available number of degrees of freedom that could +be used in the GAM fit. +\item \code{k_edf}: Estimated degrees of freedom actually used by the smooth terms +in the model. +\item \code{k_index}: The ratio of the residual variance of differenced +near-neighbor residuals to the overall residual variance. This should be +near 1 or above. +\item \code{k_p_value}: P-value for testing if k' is adequate for modeling the data. +\item \code{k_to_edf_ratio}: Ratio of k' to effective degrees of freedom of the +smooth terms. k' should be well below the available edf. +\item \code{residual_autocorrelation}: Autocorrelation coefficients for residuals +up to lag 7 or one-tenth of series length, whichever is smaller. +} } \description{ -Check diagnostics from a fitted model +Evaluates for convergence, effective degrees of freedom, and residual +autocorrelation. If \code{warn_for_diagnostic_failure} is set to TRUE, will issue +warnings when potential diagnostic issues are detected. The diagnostics are +invisibly returned as a list and also stored within the \code{diagnostics} element +of the provided model object. } \examples{ withr::with_seed(12345, { @@ -31,3 +51,9 @@ reference_date <- seq.Date( fit <- RtGam::RtGam(cases, reference_date) check_diagnostics(fit) } +\seealso{ +\link[mgcv:k.check]{mgcv::k.check} for a description of the diagnostic tests, +\link[mgcv:choose.k]{mgcv::choose.k} for a description of discussion of choosing the basis +dimension, and Wood, Simon N. Generalized additive models: an introduction +with R. chapman and hall/CRC, 2017. for a derivation of the metrics. +} From ba251e14573530d79468decc4a78c0d07e16416c Mon Sep 17 00:00:00 2001 From: Zachary Susswein <46581799+zsusswein@users.noreply.github.com> Date: Thu, 29 Aug 2024 09:19:17 -0400 Subject: [PATCH 19/24] Update R/RtGam.R Co-authored-by: Sam Abbott --- R/RtGam.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/RtGam.R b/R/RtGam.R index 36fdfbd..feaca3d 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -38,7 +38,7 @@ #' numerical error. Note that the `bam` backend uses the `discrete = TRUE` #' option for an additional speedup. See [mgcv::bam()] for more information. #' @param warn_for_diagnostic_failure Should warnings be issued for -#' automatically identified diagnostic issues? Defaults to true. A list of +#' automatically identified diagnostic issues? Defaults to TRUE. A list of #' quantitative model diagnostics can be inspected in the `diagnostics` slot #' of the returned `RtGam` object. #' @param ... Additional arguments passed to the specified modelling backend. From 037b2b615e0e4289f1d723feb759f287fc208e9b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 29 Aug 2024 13:20:20 +0000 Subject: [PATCH 20/24] pre-commit --- man/RtGam.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/RtGam.Rd b/man/RtGam.Rd index b7a0a18..17c1ae8 100644 --- a/man/RtGam.Rd +++ b/man/RtGam.Rd @@ -47,7 +47,7 @@ numerical error. Note that the \code{bam} backend uses the \code{discrete = TRUE option for an additional speedup. See \code{\link[mgcv:bam]{mgcv::bam()}} for more information.} \item{warn_for_diagnostic_failure}{Should warnings be issued for -automatically identified diagnostic issues? Defaults to true. A list of +automatically identified diagnostic issues? Defaults to TRUE. A list of quantitative model diagnostics can be inspected in the \code{diagnostics} slot of the returned \code{RtGam} object.} From fe87c716908f889ac8fed8f9f9b9b088fdd1e1b1 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Thu, 29 Aug 2024 09:25:43 -0400 Subject: [PATCH 21/24] Move `{withr}` to suggests --- DESCRIPTION | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index e4516ad..d340d3e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -24,7 +24,8 @@ URL: https://github.com/cdcgov/cfa-gam-rt, BugReports: https://github.com/cdcgov/cfa-gam-rt/issues Suggests: testthat (>= 3.0.0), - pkgdown + pkgdown, + withr Config/testthat/edition: 3 Encoding: UTF-8 Roxygen: list(markdown = TRUE) @@ -33,5 +34,4 @@ Imports: cli, glue, mgcv, - rlang, - withr + rlang From 6a6db82bda9364b3b048123e4ad084acbfec2061 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Thu, 29 Aug 2024 09:42:25 -0400 Subject: [PATCH 22/24] DRY diagnostic functionality By pointing the public function to the slot with the stored diagnostic list. Re-order `RtGam()` to use the public function for diagnostic reporting. --- R/RtGam.R | 9 +++++---- R/diagnostics.R | 2 +- tests/testthat/test-diagnostics.R | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/R/RtGam.R b/R/RtGam.R index feaca3d..e310ea3 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -96,11 +96,8 @@ RtGam <- function(cases, ) ) diagnostics <- calculate_diagnostics(fit) - if (warn_for_diagnostic_failure) { - issue_diagnostic_warnings(diagnostics) - } - format_for_return( + RtGam_object <- format_for_return( fit = fit, df = df, group = group, @@ -110,6 +107,10 @@ RtGam <- function(cases, formula = formula, diagnostics = diagnostics ) + + check_diagnostics(RtGam_object, warn_for_diagnostic_failure) + + return(RtGam_object) } #' Propose total smoothing basis dimension from number of data points diff --git a/R/diagnostics.R b/R/diagnostics.R index da5cd57..8a75f9e 100644 --- a/R/diagnostics.R +++ b/R/diagnostics.R @@ -44,7 +44,7 @@ #' fit <- RtGam::RtGam(cases, reference_date) #' check_diagnostics(fit) check_diagnostics <- function(fit, warn_for_diagnostic_failure = TRUE) { - diagnostics <- calculate_diagnostics(fit[["model"]]) + diagnostics <- fit[["diagnostics"]] if (warn_for_diagnostic_failure) { issue_diagnostic_warnings(diagnostics) } diff --git a/tests/testthat/test-diagnostics.R b/tests/testthat/test-diagnostics.R index 0d209e7..5d368c1 100644 --- a/tests/testthat/test-diagnostics.R +++ b/tests/testthat/test-diagnostics.R @@ -10,7 +10,7 @@ test_that("check_diagnostics() runs cleanly on happy path", { ) }) - fit <- list(model = model) + fit <- list(diagnostics = calculate_diagnostics(model)) expected_diagnostics <- c( "model_converged", "k_prime", From c0cc8d80aaf4d3186d26d86be8a3903d676d20e8 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Thu, 29 Aug 2024 09:56:47 -0400 Subject: [PATCH 23/24] `pkg::func()` -> `func()` in `@examples` --- R/RtGam.R | 2 +- man/RtGam.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/RtGam.R b/R/RtGam.R index e310ea3..6f2ac40 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -60,7 +60,7 @@ #' length.out = 20, #' by = "day" #' ) -#' fit <- RtGam::RtGam(cases, reference_date) +#' fit <- RtGam(cases, reference_date) #' fit RtGam <- function(cases, reference_date, diff --git a/man/RtGam.Rd b/man/RtGam.Rd index 17c1ae8..f8f6ae8 100644 --- a/man/RtGam.Rd +++ b/man/RtGam.Rd @@ -81,7 +81,7 @@ reference_date <- seq.Date( length.out = 20, by = "day" ) -fit <- RtGam::RtGam(cases, reference_date) +fit <- RtGam(cases, reference_date) fit } \seealso{ From 4f7345953d2296a26251ad1c51d761dade0cb48c Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Thu, 29 Aug 2024 10:05:42 -0400 Subject: [PATCH 24/24] Rename `format_for_return()` -> `new_RtGam()` --- R/RtGam.R | 2 +- R/print.R | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/R/RtGam.R b/R/RtGam.R index 6f2ac40..33eec0c 100644 --- a/R/RtGam.R +++ b/R/RtGam.R @@ -97,7 +97,7 @@ RtGam <- function(cases, ) diagnostics <- calculate_diagnostics(fit) - RtGam_object <- format_for_return( + RtGam_object <- new_RtGam( fit = fit, df = df, group = group, diff --git a/R/print.R b/R/print.R index b3f0e18..a00e379 100644 --- a/R/print.R +++ b/R/print.R @@ -1,13 +1,13 @@ #' Format the RtGam object for return from the main function/constructor #' @noRd -format_for_return <- function(fit, - df, - group, - k, - m, - backend, - formula, - diagnostics) { +new_RtGam <- function(fit, + df, + group, + k, + m, + backend, + formula, + diagnostics) { formatted <- list( model = fit, data = df,