From 7f1acc682308651655b585b451435e00d9d89117 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Mon, 23 Sep 2024 18:53:38 -0400 Subject: [PATCH 1/5] WIP: summary --- DESCRIPTION | 1 + R/summary.R | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 R/summary.R diff --git a/DESCRIPTION b/DESCRIPTION index 99238e1..e8d26b5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -35,6 +35,7 @@ RoxygenNote: 7.3.2 Imports: cli, glue, + gratia, mgcv, rlang Depends: diff --git a/R/summary.R b/R/summary.R new file mode 100644 index 0000000..54ac636 --- /dev/null +++ b/R/summary.R @@ -0,0 +1,60 @@ +summary.RtGam <- function( + object, + type, + horizon = 14, + min_date = NULL, + max_date = NULL, + ...) { + rlang::arg_match(type, + values = c( + "obs_cases", + "latent_cases", + "r", + "Rt" + ) + ) + if (rlang::is_null(min_date)) { + min_date <- object[["min_date"]] + } + if (rlang::is_null(max_date)) { + max_date <- object[["max_data"]] + } + + if (min_date >= max_date) { + rlang::abort(c( + "{.arg min_date} must be greater than {.arg max_date}", + "{.arg {min_date}}: {.val {min_date}}", + "{.arg max_date}: {.val {max_date}}" + )) + } + desired_dates <- seq.Date( + from = min_date, + to = max_date + horizon, + by = "day" + ) + timesteps <- dates_to_timesteps(desired_dates, + min_supplied_date = object[["min_date"]], + max_supplied_date = object[["max_data"]] + ) + + newdata <- data.frame( + timestep = timesteps, + .row = seq_along(timesteps), + reference_date = desired_dates + ) + + fitted <- gratia::posterior_samples(object[["model"]], + data = newdata, ... + ) + + merged <- merge(fitted, + newdata, + by = ".row" + ) + + data.frame( + reference_date = merged[["reference_date"]], + .response = merged[[".response"]], + .draw = merged[[".draw"]] + ) +} From cf5d77c46a83089cdd5f3dd83234925ad8752431 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Wed, 2 Oct 2024 07:40:04 -0400 Subject: [PATCH 2/5] Broken --- R/summary.R | 159 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 137 insertions(+), 22 deletions(-) diff --git a/R/summary.R b/R/summary.R index 54ac636..0354c8d 100644 --- a/R/summary.R +++ b/R/summary.R @@ -1,23 +1,35 @@ -summary.RtGam <- function( +predict.RtGam <- function( object, type, - horizon = 14, + horizon = NULL, min_date = NULL, max_date = NULL, + n = 10, + mean_delay = NULL, + seed = 12345, ...) { rlang::arg_match(type, values = c( "obs_cases", - "latent_cases", + "incidence", "r", "Rt" ) ) + # If horizon, estimate from the last date to the forecast horizon + if (!rlang::is_null(horizon)) { + rlang::check_exclusive(max_date, horizon) + if (rlang::is_null(min_date)) { + min_date <- object[["max_date"]] + 1 + } + max_date <- object[["max_date"]] + horizon + } + # Else if nothing is specified, estimate period w/ obs data if (rlang::is_null(min_date)) { min_date <- object[["min_date"]] } if (rlang::is_null(max_date)) { - max_date <- object[["max_data"]] + max_date <- object[["max_date"]] } if (min_date >= max_date) { @@ -27,29 +39,132 @@ summary.RtGam <- function( "{.arg max_date}: {.val {max_date}}" )) } - desired_dates <- seq.Date( - from = min_date, - to = max_date + horizon, - by = "day" - ) - timesteps <- dates_to_timesteps(desired_dates, - min_supplied_date = object[["min_date"]], - max_supplied_date = object[["max_data"]] - ) - newdata <- data.frame( - timestep = timesteps, - .row = seq_along(timesteps), - reference_date = desired_dates - ) + if (type == "obs_cases") { + desired_dates <- seq.Date( + from = min_date, + to = max_date, + by = "day" + # TODO: does this handle doubles properly? How should they be handled? + ) - fitted <- gratia::posterior_samples(object[["model"]], - data = newdata, ... - ) + timesteps <- dates_to_timesteps(desired_dates, + min_supplied_date = object[["min_date"]], + max_supplied_date = object[["max_date"]] + ) + newdata <- data.frame( + timestep = timesteps, + .row = seq_along(timesteps), + reference_date = desired_dates + ) + + fitted <- gratia::posterior_samples( + object[["model"]], + data = newdata, + unconditional = TRUE, + ... + ) + print(head(fitted)) + } else if (type == "incidence") { + # I'm breaking it out because we're going to want to also remove day of week effect I think. + if (!rlang::is_bare_numeric(mean_delay)) { + cli::cli_abort("{.arg mean_delay} is required when {.arg type} is {.val obs_incidence}") + } + + # Extract incidence at mean-shifted dates + desired_dates <- seq.Date( + from = min_date, + to = max_date, + by = "day" + # TODO: does this handle doubles properly? How should they be handled? + ) + + # Shift desired dates _forward_ to get the corresponding date in fitted cases + mean_shifted_dates <- desired_dates + mean_delay + + + timesteps <- dates_to_timesteps(mean_shifted_dates, + min_supplied_date = object[["min_date"]], + max_supplied_date = object[["max_date"]] + ) + + # And associate the timestep with the incidence date, not the case date to shift back from cases to incidence + newdata <- data.frame( + timestep = timesteps, + .row = seq_along(timesteps), + reference_date = desired_dates + ) + + # Use `posterior_samples()` over `fitted_samples()` to get response w/ obs uncertainty + fitted <- gratia::posterior_samples( + object[["model"]], + data = newdata, + unconditional = TRUE, + ... + ) + } else if (type == "r") { + if (!rlang::is_bare_numeric(mean_delay)) { + cli::cli_abort("{.arg mean_delay} is required when {.arg type} is {.val obs_incidence}") + } + + # Extract incidence at mean-shifted dates + desired_dates <- seq.Date( + from = min_date, + to = max_date, + iy = "day" + ) + + # Shift desired dates _forward_ to get the corresponding date in fitted cases + mean_shifted_dates <- desired_dates + mean_delay + + + timesteps <- dates_to_timesteps(mean_shifted_dates, + min_supplied_date = object[["min_date"]], + max_supplied_date = object[["max_date"]] + ) + + # Construct timesteps as with incidence and shift by delta for discrete derivative + newdata <- data.frame( + timestep = timesteps, + reference_date = desired_dates + ) + delta <- 1e-9 + timesteps_shifted <- timesteps + delta + print(head(timesteps)) + print(head(timesteps_shifted)) + all_timesteps <- c(rbind(timesteps, timesteps_shifted)) + print(head(all_timesteps)) + # `rbind()` interleaves vectors, so we can difference them + ds <- gratia::data_slice(fit[["model"]], timestep = all_timesteps) + + all_timesteps <- gratia::fitted_samples( + object[["model"]], + data = ds, + n = n, + seed = seed, + unconditional = TRUE, + scale = "response", + ... + ) + is_orig_timestep <- all_timesteps[[".row"]] %% 2 == 1 + differenced_draws <- all_timesteps[[".fitted"]][is_orig_timestep] - all_timesteps[[".fitted"]][!is_orig_timestep] + print(head(differenced_draws)) + growth_rate <- differenced_draws[is_orig_timestep] / delta + print(head(growth_rate)) + fitted <- data.frame( + .response = growth_rate, + .row = 1:length(growth_rate), + .draw = all_timesteps[[".draw"]][is_orig_timestep], + timestep = timesteps + ) + print(head(fitted)) + } else { + cli::cli_abort("Not implemented") + } merged <- merge(fitted, newdata, - by = ".row" + by = "timestep" ) data.frame( From f4d85ce63fd5a534c92de332a0f498e7295b9f21 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Sun, 13 Oct 2024 16:28:06 -0400 Subject: [PATCH 3/5] Save current status before deleting a bunch of stuff --- ' | 251 ++++++++++++++++++++++++++++++++++ R/checkers.R | 2 +- R/summary.R | 208 ++++++++++++++++++---------- tests/testthat/test-predict.R | 109 +++++++++++++++ 4 files changed, 499 insertions(+), 71 deletions(-) create mode 100644 ' create mode 100644 tests/testthat/test-predict.R diff --git a/' b/' new file mode 100644 index 0000000..d202fe8 --- /dev/null +++ b/' @@ -0,0 +1,251 @@ +#' Stub docs: Predict method +#' +#' Supports specifying the the dates for prediction four ways: +#' 1. Passing fit and not specifying any date arguments. In this +#' predictions are returned for the dates in between the minimum and maximum +#' dates passed to [RtGam()]. +#' 2. Specifying `horizon`. A forecast of `horizon` days is returned, starting +#' from last day passed to [RtGam()]. To extract a 7 day-ahead forecast, pass +#' `predict(RtGamFit, horizon = 7)`. +#' 3. Specifying `min_date` and `horizon`. Predictions are returned for the +#' period from `min_date` through `horizon` days after last day passed +#' to [RtGam()]. +#' 4. Specifying `min_date` and `max_date`. Predictions are returned for each day +#' between these two dates (inclusive). +#' +#' @export +predict.RtGam <- function( + object, + type, + horizon = NULL, + min_date = NULL, + max_date = NULL, + n = 10, + mean_delay = NULL, + gi_pmf, + seed = 12345, + ...) { + rlang::arg_match(type, + values = c( + "obs_cases", + "incidence", + "r", + "Rt" + ), + call = rlang::caller_env() + ) + desired_dates <- parse_predict_dates( + object = object, + min_date = min_date, + max_date = max_date, + horizon = horizon + ) + timesteps <- prep_timesteps_for_pred( + type = type, + min_date = object[["min_date"]], + max_date = object[["max_date"]] +) + + if (type == "obs_cases") { + desired_dates <- seq.Date( + from = min_date, + to = max_date, + by = "day" + # TODO: does this handle doubles properly? How should they be handled? + ) + + timesteps <- dates_to_timesteps(desired_dates, + min_supplied_date = object[["min_date"]], + max_supplied_date = object[["max_date"]] + ) + newdata <- data.frame( + timestep = timesteps, + .row = seq_along(timesteps), + reference_date = desired_dates + ) + + fitted <- gratia::posterior_samples( + object[["model"]], + data = newdata, + unconditional = TRUE, + ... + ) + print(head(fitted)) + } else if (type == "incidence") { + # I'm breaking it out because we're going to want to also remove day of week effect I think. + if (!rlang::is_bare_numeric(mean_delay)) { + cli::cli_abort("{.arg mean_delay} is required when {.arg type} is {.val obs_incidence}") + } + + # Extract incidence at mean-shifted dates + desired_dates <- seq.Date( + from = min_date, + to = max_date, + by = "day" + # TODO: does this handle doubles properly? How should they be handled? + ) + + # Shift desired dates _forward_ to get the corresponding date in fitted cases + mean_shifted_dates <- desired_dates + mean_delay + + + timesteps <- dates_to_timesteps(mean_shifted_dates, + min_supplied_date = object[["min_date"]], + max_supplied_date = object[["max_date"]] + ) + + # And associate the timestep with the incidence date, not the case date to shift back from cases to incidence + newdata <- data.frame( + timestep = timesteps, + .row = seq_along(timesteps), + reference_date = desired_dates + ) + + # Use `posterior_samples()` over `fitted_samples()` to get response w/ obs uncertainty + fitted <- gratia::posterior_samples( + object[["model"]], + data = newdata, + unconditional = TRUE, + ... + ) + } else if (type == "r") { + if (!rlang::is_bare_numeric(mean_delay)) { + cli::cli_abort("{.arg mean_delay} is required when {.arg type} is {.val obs_incidence}") + } + + # Extract incidence at mean-shifted dates + desired_dates <- seq.Date( + from = min_date, + to = max_date, + iy = "day" + ) + + # Shift desired dates _forward_ to get the corresponding date in fitted cases + mean_shifted_dates <- desired_dates + mean_delay + + + timesteps <- dates_to_timesteps(mean_shifted_dates, + min_supplied_date = object[["min_date"]], + max_supplied_date = object[["max_date"]] + ) + + # Construct timesteps as with incidence and shift by delta for discrete derivative + newdata <- data.frame( + timestep = timesteps, + reference_date = desired_dates + ) + delta <- 1e-9 + timesteps_shifted <- timesteps + delta + print(head(timesteps)) + print(head(timesteps_shifted)) + all_timesteps <- c(rbind(timesteps, timesteps_shifted)) + print(head(all_timesteps)) + # `rbind()` interleaves vectors, so we can difference them + ds <- gratia::data_slice(fit[["model"]], timestep = all_timesteps) + + all_timesteps <- gratia::fitted_samples( + object[["model"]], + data = ds, + n = n, + seed = seed, + unconditional = TRUE, + scale = "response", + ... + ) + is_orig_timestep <- all_timesteps[[".row"]] %% 2 == 1 + differenced_draws <- all_timesteps[[".fitted"]][is_orig_timestep] - all_timesteps[[".fitted"]][!is_orig_timestep] + print(head(differenced_draws)) + growth_rate <- differenced_draws[is_orig_timestep] / delta + print(head(growth_rate)) + fitted <- data.frame( + .response = growth_rate, + .row = 1:length(growth_rate), + .draw = all_timesteps[[".draw"]][is_orig_timestep], + timestep = timesteps + ) + print(head(fitted)) + } else { + cli::cli_abort("Not implemented") + } + + merged <- merge(fitted, + newdata, + by = "timestep" + ) + + data.frame( + reference_date = merged[["reference_date"]], + .response = merged[[".response"]], + .draw = merged[[".draw"]] + ) +} + +#' Convert from user specification to necessary date range +#' +#' @inheritParams predict.RtGam +#' @param call The calling environment to be reflected in the error message +#' +#' @return List with two elements: min_date and max_date +parse_predict_dates <- function( + object, + min_date = NULL, + max_date = NULL, + horizon = NULL, + call = rlang::caller_env() +) { + if (!rlang::is_null(min_date)) check_date(min_date, call = call) + if (!rlang::is_null(max_date)) check_date(max_date, call = call) + if (!rlang::is_null(horizon)) check_integer(horizon, call = call) + + # Handle horizon to estimate dates if provided + if (!is.null(horizon)) { + rlang::check_exclusive(max_date, horizon, call = call) + min_date <- min_date %||% (object[["max_date"]] + 1) + max_date <- object[["max_date"]] + horizon + 1 + } else { + # Default to object's date range if not specified + min_date <- min_date %||% object[["min_date"]] + max_date <- max_date %||% object[["max_date"]] + } + + # Ensure min_date is before max_date + if (min_date >= max_date) { + cli::cli_warn(c( + "Swapping {.arg min_date} and {.arg max_date}", + "{.arg min_date} {.val {min_date}} is after {.arg max_date} {.val {max_date}}" + ), + class = "RtGam_predict_dates_backward", + call = call + ) + temp_var <- max_date + max_date <- min_date + min_date <- temp_var + } + + list(min_date = as.Date(min_date), max_date = as.Date(max_date)) +} + +#' Convert from user-specified dates to internal timesteps +#' +#' @inheritParams predict.RtGam +#' @return Double vector, the timesteps to predict +prep_timesteps_for_pred <- function( +type, +min_date, +max_date, +mean_delay +call = rlang::caller_env() +) { + + if (type == "incidence" || type == "growth_rate") { + # Shift cases up by mean delay to get projected incidence on day + min_date = min_date + mean_delay + max_date = max_date + mean_delay + } else if (type == "Rt") + # Shift up by mean delay to move to incidence scale and also pad by the + # GI on either side to prevent missing dates in the convolution + min_date = min_date + mean_delay - length(gi_pmf) + max_date = max_date + mean_delay + length(gi_pmf) + } + +} diff --git a/R/checkers.R b/R/checkers.R index e084b4c..8ec332e 100644 --- a/R/checkers.R +++ b/R/checkers.R @@ -131,7 +131,7 @@ check_no_missingness <- function(x, arg = "x", call = rlang::caller_env()) { check_elements_below_max <- function(x, arg, max, call = rlang::caller_env()) { # Greater than or equal to 0 or is NA - is_below_max <- all((x <= max) | is.na(x)) + is_below_max <- (x <= max) | is.na(x) if (!all(is_below_max)) { cli::cli_abort( c("{.arg {arg}} has elements larger than {.val {max}}", diff --git a/R/summary.R b/R/summary.R index 0354c8d..226daab 100644 --- a/R/summary.R +++ b/R/summary.R @@ -1,3 +1,19 @@ +#' Stub docs: Predict method +#' +#' Supports specifying the the dates for prediction four ways: +#' 1. Passing fit and not specifying any date arguments. In this +#' predictions are returned for the dates in between the minimum and maximum +#' dates passed to [RtGam()]. +#' 2. Specifying `horizon`. A forecast of `horizon` days is returned, starting +#' from last day passed to [RtGam()]. To extract a 7 day-ahead forecast, pass +#' `predict(RtGamFit, horizon = 7)`. +#' 3. Specifying `min_date` and `horizon`. Predictions are returned for the +#' period from `min_date` through `horizon` days after last day passed +#' to [RtGam()]. +#' 4. Specifying `min_date` and `max_date`. Predictions are returned for each day +#' between these two dates (inclusive). +#' +#' @export predict.RtGam <- function( object, type, @@ -6,6 +22,7 @@ predict.RtGam <- function( max_date = NULL, n = 10, mean_delay = NULL, + gi_pmf = NULL, seed = 12345, ...) { rlang::arg_match(type, @@ -14,44 +31,55 @@ predict.RtGam <- function( "incidence", "r", "Rt" - ) + ), + call = rlang::caller_env() ) - # If horizon, estimate from the last date to the forecast horizon - if (!rlang::is_null(horizon)) { - rlang::check_exclusive(max_date, horizon) - if (rlang::is_null(min_date)) { - min_date <- object[["max_date"]] + 1 + if (type != "obs_cases") { + if (rlang::is_null(mean_delay)) { + cli::cli_abort("{.arg mean_delay} is required when type is {.val {type}}") } - max_date <- object[["max_date"]] + horizon - } - # Else if nothing is specified, estimate period w/ obs data - if (rlang::is_null(min_date)) { - min_date <- object[["min_date"]] + check_integer(mean_delay, "gi_pmf") } - if (rlang::is_null(max_date)) { - max_date <- object[["max_date"]] + if (type == "Rt") { + if (rlang::is_null(gi_pmf)) { + cli::cli_abort("{.arg gi_pmf} is required when type is {.val Rt}") + } + check_vector(gi_pmf, "gi_pmf") + check_no_missingness(gi_pmf, "gi_pmf") + check_elements_above_min(gi_pmf, "gi_pmf", 0) + check_elements_below_max(gi_pmf, "gi_pmf", 1) + check_sums_to_one(gi_pmf, "gi_pmf") } - if (min_date >= max_date) { - rlang::abort(c( - "{.arg min_date} must be greater than {.arg max_date}", - "{.arg {min_date}}: {.val {min_date}}", - "{.arg max_date}: {.val {max_date}}" - )) - } + desired_dates <- parse_predict_dates( + object = object, + min_date = min_date, + max_date = max_date, + horizon = horizon + ) + timesteps <- prep_timesteps_for_pred( + type = type, + desired_min_date = desired_dates[["min_date"]], + desired_max_date = desired_dates[["max_date"]], + fit_min_date = object[["min_date"]], + fit_max_date = object[["max_date"]], + mean_delay = mean_delay + ) if (type == "obs_cases") { - desired_dates <- seq.Date( - from = min_date, - to = max_date, - by = "day" - # TODO: does this handle doubles properly? How should they be handled? + predict <- predict_obs_cases( + object, + desired_dates, + timesteps, + n = n, + seed = seed, + ... ) - timesteps <- dates_to_timesteps(desired_dates, - min_supplied_date = object[["min_date"]], - max_supplied_date = object[["max_date"]] - ) + } else { + cli::cli_abort("{.val {type}} not yet implemented}") + } + newdata <- data.frame( timestep = timesteps, .row = seq_along(timesteps), @@ -67,26 +95,6 @@ predict.RtGam <- function( print(head(fitted)) } else if (type == "incidence") { # I'm breaking it out because we're going to want to also remove day of week effect I think. - if (!rlang::is_bare_numeric(mean_delay)) { - cli::cli_abort("{.arg mean_delay} is required when {.arg type} is {.val obs_incidence}") - } - - # Extract incidence at mean-shifted dates - desired_dates <- seq.Date( - from = min_date, - to = max_date, - by = "day" - # TODO: does this handle doubles properly? How should they be handled? - ) - - # Shift desired dates _forward_ to get the corresponding date in fitted cases - mean_shifted_dates <- desired_dates + mean_delay - - - timesteps <- dates_to_timesteps(mean_shifted_dates, - min_supplied_date = object[["min_date"]], - max_supplied_date = object[["max_date"]] - ) # And associate the timestep with the incidence date, not the case date to shift back from cases to incidence newdata <- data.frame( @@ -103,27 +111,7 @@ predict.RtGam <- function( ... ) } else if (type == "r") { - if (!rlang::is_bare_numeric(mean_delay)) { - cli::cli_abort("{.arg mean_delay} is required when {.arg type} is {.val obs_incidence}") - } - - # Extract incidence at mean-shifted dates - desired_dates <- seq.Date( - from = min_date, - to = max_date, - iy = "day" - ) - - # Shift desired dates _forward_ to get the corresponding date in fitted cases - mean_shifted_dates <- desired_dates + mean_delay - - - timesteps <- dates_to_timesteps(mean_shifted_dates, - min_supplied_date = object[["min_date"]], - max_supplied_date = object[["max_date"]] - ) - - # Construct timesteps as with incidence and shift by delta for discrete derivative + # Construct timesteps as with incidence and shift by delta for discrete derivative newdata <- data.frame( timestep = timesteps, reference_date = desired_dates @@ -173,3 +161,83 @@ predict.RtGam <- function( .draw = merged[[".draw"]] ) } + +#' Convert from user specification to necessary date range +#' +#' @inheritParams predict.RtGam +#' @param call The calling environment to be reflected in the error message +#' +#' @return List with two elements: min_date and max_date +parse_predict_dates <- function( + object, + min_date = NULL, + max_date = NULL, + horizon = NULL, + call = rlang::caller_env()) { + if (!rlang::is_null(min_date)) check_date(min_date, call = call) + if (!rlang::is_null(max_date)) check_date(max_date, call = call) + if (!rlang::is_null(horizon)) check_integer(horizon, call = call) + + # Handle horizon to estimate dates if provided + if (!is.null(horizon)) { + rlang::check_exclusive(max_date, horizon, call = call) + min_date <- min_date %||% (object[["max_date"]] + 1) + max_date <- object[["max_date"]] + horizon + 1 + } else { + # Default to object's date range if not specified + min_date <- min_date %||% object[["min_date"]] + max_date <- max_date %||% object[["max_date"]] + } + + # Ensure min_date is before max_date + if (min_date >= max_date) { + cli::cli_warn( + c( + "Swapping {.arg min_date} and {.arg max_date}", + "{.arg min_date} {.val {min_date}} is after {.arg max_date} {.val {max_date}}" + ), + class = "RtGam_predict_dates_backward", + call = call + ) + temp_var <- max_date + max_date <- min_date + min_date <- temp_var + } + + list(min_date = as.Date(min_date), max_date = as.Date(max_date)) +} + +#' Convert from user-specified dates to internal timesteps +#' +#' @inheritParams predict.RtGam +#' @return Double vector, the timesteps to predict +prep_timesteps_for_pred <- function( + type, + fit_min_date, + fit_max_date, + desired_min_date, + desired_max_date, + mean_delay, + call = rlang::caller_env()) { + if (type == "incidence" || type == "growth_rate") { + # Shift cases up by mean delay to get projected incidence on day + desired_min_date <- desired_min_date + mean_delay + desired_max_date <- desired_max_date + mean_delay + } else if (type == "Rt") { + # Shift up by mean delay to move to incidence scale and also pad by the + # GI on either side to prevent missing dates in the convolution + desired_min_date <- desired_min_date + desired_mean_delay - length(gi_pmf) + desired_max_date <- desired_max_date + mean_delay + length(gi_pmf) + } + + dates <- seq.Date( + from = desired_min_date, + to = desired_max_date, + by = "day" + ) + dates_to_timesteps( + dates, + min_supplied_date = fit_min_date, + max_supplied_date = fit_max_date + ) +} diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R new file mode 100644 index 0000000..1140022 --- /dev/null +++ b/tests/testthat/test-predict.R @@ -0,0 +1,109 @@ +test_that("Dates are parsed correctly", { + # Passing just `object` gives the same min and max dates + expected_object_dates <- list( + min_date = as.Date("2023-01-01"), + max_date = as.Date("2023-01-15") + ) + actual_object_dates <- parse_predict_dates( + object = expected_object_dates + ) + expect_equal( + actual_object_dates, + expected_object_dates + ) + + # Passing horizon gives you n day ahead forcast + object_dates <- list( + min_date = as.Date("2023-01-01"), + max_date = as.Date("2023-01-15") + ) + horizon <- 5 + expected_horizon_dates <- list( + min_date = object_dates$max_date + 1, + max_date = object_dates$max_date + 1 + horizon + ) + + actual_horizon_dates <- parse_predict_dates( + object = object_dates, + horizon = horizon + ) + expect_equal(actual_horizon_dates, expected_horizon_dates) + + # Horizon + min_date gives period from min_date to horizon + object_dates <- list( + min_date = as.Date("2023-01-01"), + max_date = as.Date("2023-01-15") + ) + min_date <- as.Date("2023-01-10") + horizon <- 3 + expected_min_horiz_dates <- list( + min_date = min_date, + max_date = object_dates$max_date + 1 + horizon + ) + actual_min_horiz_dates <- parse_predict_dates( + object = object_dates, + horizon = horizon, + min_date = min_date + ) + expect_equal(actual_min_horiz_dates, expected_min_horiz_dates) + + # Specifying min and max dates returns from min to max + object_dates <- list( + min_date = as.Date("2023-01-01"), + max_date = as.Date("2023-01-15") + ) + min_date <- as.Date("2023-01-02") + max_date <- as.Date("2023-01-20") + expected_min_max_dates <- list( + min_date = min_date, + max_date = max_date + ) + actual_min_max_dates <- parse_predict_dates( + object = object, + min_date = min_date, + max_date = max_date + ) + expect_equal(actual_min_max_dates, expected_min_max_dates) +}) + +test_that("Bad dates throw appropriate status messages", { + object <- list( + min_date = as.Date("2023-01-01"), + max_date = as.Date("2023-01-15") + ) + # Type errors + expect_error( + parse_predict_dates(object, max_date = "potato"), + class = "RtGam_type_error" + ) + expect_error( + parse_predict_dates(object, min_date = "potato"), + class = "RtGam_type_error" + ) + expect_error( + parse_predict_dates(object, horizon = "potato"), + class = "RtGam_type_error" + ) + + # Specifying both horizon and max date + expect_error( + parse_predict_dates(list(), horizon = 15, max_date = as.Date("2023-01-01")) + ) + + # min_date after max_date + min_date <- as.Date("2023-01-01") + max_date <- min_date + 1 + expect_warning( + actual <- parse_predict_dates( + list(), + # Swap min and max to invoke warning + min_date = max_date, + max_date = min_date + ), + class = "RtGam_predict_dates_backward" + ) + expect_equal( + actual, + list(min_date = min_date, max_date = max_date) + ) +}) From ca63ffa7b88850d5443fa4fbb93179310416b602 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Mon, 14 Oct 2024 10:46:56 -0400 Subject: [PATCH 4/5] Pausing here for feedback --- ' | 251 ------------------------------------- R/{summary.R => predict.R} | 177 ++++++++++++-------------- 2 files changed, 81 insertions(+), 347 deletions(-) delete mode 100644 ' rename R/{summary.R => predict.R} (52%) diff --git a/' b/' deleted file mode 100644 index d202fe8..0000000 --- a/' +++ /dev/null @@ -1,251 +0,0 @@ -#' Stub docs: Predict method -#' -#' Supports specifying the the dates for prediction four ways: -#' 1. Passing fit and not specifying any date arguments. In this -#' predictions are returned for the dates in between the minimum and maximum -#' dates passed to [RtGam()]. -#' 2. Specifying `horizon`. A forecast of `horizon` days is returned, starting -#' from last day passed to [RtGam()]. To extract a 7 day-ahead forecast, pass -#' `predict(RtGamFit, horizon = 7)`. -#' 3. Specifying `min_date` and `horizon`. Predictions are returned for the -#' period from `min_date` through `horizon` days after last day passed -#' to [RtGam()]. -#' 4. Specifying `min_date` and `max_date`. Predictions are returned for each day -#' between these two dates (inclusive). -#' -#' @export -predict.RtGam <- function( - object, - type, - horizon = NULL, - min_date = NULL, - max_date = NULL, - n = 10, - mean_delay = NULL, - gi_pmf, - seed = 12345, - ...) { - rlang::arg_match(type, - values = c( - "obs_cases", - "incidence", - "r", - "Rt" - ), - call = rlang::caller_env() - ) - desired_dates <- parse_predict_dates( - object = object, - min_date = min_date, - max_date = max_date, - horizon = horizon - ) - timesteps <- prep_timesteps_for_pred( - type = type, - min_date = object[["min_date"]], - max_date = object[["max_date"]] -) - - if (type == "obs_cases") { - desired_dates <- seq.Date( - from = min_date, - to = max_date, - by = "day" - # TODO: does this handle doubles properly? How should they be handled? - ) - - timesteps <- dates_to_timesteps(desired_dates, - min_supplied_date = object[["min_date"]], - max_supplied_date = object[["max_date"]] - ) - newdata <- data.frame( - timestep = timesteps, - .row = seq_along(timesteps), - reference_date = desired_dates - ) - - fitted <- gratia::posterior_samples( - object[["model"]], - data = newdata, - unconditional = TRUE, - ... - ) - print(head(fitted)) - } else if (type == "incidence") { - # I'm breaking it out because we're going to want to also remove day of week effect I think. - if (!rlang::is_bare_numeric(mean_delay)) { - cli::cli_abort("{.arg mean_delay} is required when {.arg type} is {.val obs_incidence}") - } - - # Extract incidence at mean-shifted dates - desired_dates <- seq.Date( - from = min_date, - to = max_date, - by = "day" - # TODO: does this handle doubles properly? How should they be handled? - ) - - # Shift desired dates _forward_ to get the corresponding date in fitted cases - mean_shifted_dates <- desired_dates + mean_delay - - - timesteps <- dates_to_timesteps(mean_shifted_dates, - min_supplied_date = object[["min_date"]], - max_supplied_date = object[["max_date"]] - ) - - # And associate the timestep with the incidence date, not the case date to shift back from cases to incidence - newdata <- data.frame( - timestep = timesteps, - .row = seq_along(timesteps), - reference_date = desired_dates - ) - - # Use `posterior_samples()` over `fitted_samples()` to get response w/ obs uncertainty - fitted <- gratia::posterior_samples( - object[["model"]], - data = newdata, - unconditional = TRUE, - ... - ) - } else if (type == "r") { - if (!rlang::is_bare_numeric(mean_delay)) { - cli::cli_abort("{.arg mean_delay} is required when {.arg type} is {.val obs_incidence}") - } - - # Extract incidence at mean-shifted dates - desired_dates <- seq.Date( - from = min_date, - to = max_date, - iy = "day" - ) - - # Shift desired dates _forward_ to get the corresponding date in fitted cases - mean_shifted_dates <- desired_dates + mean_delay - - - timesteps <- dates_to_timesteps(mean_shifted_dates, - min_supplied_date = object[["min_date"]], - max_supplied_date = object[["max_date"]] - ) - - # Construct timesteps as with incidence and shift by delta for discrete derivative - newdata <- data.frame( - timestep = timesteps, - reference_date = desired_dates - ) - delta <- 1e-9 - timesteps_shifted <- timesteps + delta - print(head(timesteps)) - print(head(timesteps_shifted)) - all_timesteps <- c(rbind(timesteps, timesteps_shifted)) - print(head(all_timesteps)) - # `rbind()` interleaves vectors, so we can difference them - ds <- gratia::data_slice(fit[["model"]], timestep = all_timesteps) - - all_timesteps <- gratia::fitted_samples( - object[["model"]], - data = ds, - n = n, - seed = seed, - unconditional = TRUE, - scale = "response", - ... - ) - is_orig_timestep <- all_timesteps[[".row"]] %% 2 == 1 - differenced_draws <- all_timesteps[[".fitted"]][is_orig_timestep] - all_timesteps[[".fitted"]][!is_orig_timestep] - print(head(differenced_draws)) - growth_rate <- differenced_draws[is_orig_timestep] / delta - print(head(growth_rate)) - fitted <- data.frame( - .response = growth_rate, - .row = 1:length(growth_rate), - .draw = all_timesteps[[".draw"]][is_orig_timestep], - timestep = timesteps - ) - print(head(fitted)) - } else { - cli::cli_abort("Not implemented") - } - - merged <- merge(fitted, - newdata, - by = "timestep" - ) - - data.frame( - reference_date = merged[["reference_date"]], - .response = merged[[".response"]], - .draw = merged[[".draw"]] - ) -} - -#' Convert from user specification to necessary date range -#' -#' @inheritParams predict.RtGam -#' @param call The calling environment to be reflected in the error message -#' -#' @return List with two elements: min_date and max_date -parse_predict_dates <- function( - object, - min_date = NULL, - max_date = NULL, - horizon = NULL, - call = rlang::caller_env() -) { - if (!rlang::is_null(min_date)) check_date(min_date, call = call) - if (!rlang::is_null(max_date)) check_date(max_date, call = call) - if (!rlang::is_null(horizon)) check_integer(horizon, call = call) - - # Handle horizon to estimate dates if provided - if (!is.null(horizon)) { - rlang::check_exclusive(max_date, horizon, call = call) - min_date <- min_date %||% (object[["max_date"]] + 1) - max_date <- object[["max_date"]] + horizon + 1 - } else { - # Default to object's date range if not specified - min_date <- min_date %||% object[["min_date"]] - max_date <- max_date %||% object[["max_date"]] - } - - # Ensure min_date is before max_date - if (min_date >= max_date) { - cli::cli_warn(c( - "Swapping {.arg min_date} and {.arg max_date}", - "{.arg min_date} {.val {min_date}} is after {.arg max_date} {.val {max_date}}" - ), - class = "RtGam_predict_dates_backward", - call = call - ) - temp_var <- max_date - max_date <- min_date - min_date <- temp_var - } - - list(min_date = as.Date(min_date), max_date = as.Date(max_date)) -} - -#' Convert from user-specified dates to internal timesteps -#' -#' @inheritParams predict.RtGam -#' @return Double vector, the timesteps to predict -prep_timesteps_for_pred <- function( -type, -min_date, -max_date, -mean_delay -call = rlang::caller_env() -) { - - if (type == "incidence" || type == "growth_rate") { - # Shift cases up by mean delay to get projected incidence on day - min_date = min_date + mean_delay - max_date = max_date + mean_delay - } else if (type == "Rt") - # Shift up by mean delay to move to incidence scale and also pad by the - # GI on either side to prevent missing dates in the convolution - min_date = min_date + mean_delay - length(gi_pmf) - max_date = max_date + mean_delay + length(gi_pmf) - } - -} diff --git a/R/summary.R b/R/predict.R similarity index 52% rename from R/summary.R rename to R/predict.R index 226daab..e69b77e 100644 --- a/R/summary.R +++ b/R/predict.R @@ -1,17 +1,51 @@ -#' Stub docs: Predict method +#' Predict Method for RtGam Models #' -#' Supports specifying the the dates for prediction four ways: -#' 1. Passing fit and not specifying any date arguments. In this -#' predictions are returned for the dates in between the minimum and maximum -#' dates passed to [RtGam()]. -#' 2. Specifying `horizon`. A forecast of `horizon` days is returned, starting -#' from last day passed to [RtGam()]. To extract a 7 day-ahead forecast, pass -#' `predict(RtGamFit, horizon = 7)`. -#' 3. Specifying `min_date` and `horizon`. Predictions are returned for the -#' period from `min_date` through `horizon` days after last day passed -#' to [RtGam()]. -#' 4. Specifying `min_date` and `max_date`. Predictions are returned for each day -#' between these two dates (inclusive). +#' Generates predictions from an `RtGam` fit. Prediction dates can be specified +#' flexibly using multiple approaches. +#' +#' @param object An `RtGamFit` object from [RtGam()]. +#' @param type A string specifying the prediction type. Options are +#' `"obs_cases"`, `"incidence"`, `"r"`, and `"Rt"`. Currently, only +#' `"obs_cases"` is supported. Matching is enforced by [rlang::arg_match()]. +#' @param horizon Optional. Integer specifying forecast days from the last date +#' in the fit. For example, `horizon = 7` returns a 7-day forecast. +#' @param min_date Optional. A date-like object marking the start of the +#' prediction period. +#' @param max_date Optional. A date-like object marking the end of the +#' prediction period. +#' @param n Number of posterior samples to use. Default is 10. +#' @param mean_delay Optional. Numeric mean delay used in the prediction. +#' @param gi_pmf Optional. A vector representing the generation interval PMF. +#' @param seed Random seed for reproducibility. Default is 12345. +#' @param ... Additional arguments passed to lower-level functions. +#' +#' @details +#' Prediction dates can be set in four ways: +#' +#' 1. **Using Fit Object Alone**: Predictions span the full date range in the +#' original model fit. +#' 2. **Using `horizon`**: Forecasts extend `horizon` days from the fit’s last +#' date. +#' 3. **Using `min_date` and `horizon`**: Predictions start at `min_date` and +#' end `horizon` days after the fit’s last date. +#' 4. **Using `min_date` and `max_date`**: Predictions span all dates between +#' these two (inclusive). +#' +#' @return +#' A dataframe in [tidy format](https://www.jstatsoft.org/article/view/v059i10), +#' with each row representing a draw for a specific date: +#' +#' - `reference_date`: Date of the prediction. +#' - `.response`: Predicted value (e.g., observed cases). +#' - `.draw`: ID of the posterior draw. +#' +#' Example output: +#' ``` +#' reference_date .response .draw +#' 1 2023-01-01 18 1 +#' 2 2023-01-02 13 1 +#' 3 2023-01-03 21 1 +#' ``` #' #' @export predict.RtGam <- function( @@ -59,15 +93,15 @@ predict.RtGam <- function( ) timesteps <- prep_timesteps_for_pred( type = type, - desired_min_date = desired_dates[["min_date"]], - desired_max_date = desired_dates[["max_date"]], + desired_min_date = min(desired_dates), + desired_max_date = max(desired_dates), fit_min_date = object[["min_date"]], fit_max_date = object[["max_date"]], mean_delay = mean_delay ) if (type == "obs_cases") { - predict <- predict_obs_cases( + predict_obs_cases( object, desired_dates, timesteps, @@ -75,86 +109,35 @@ predict.RtGam <- function( seed = seed, ... ) - } else { cli::cli_abort("{.val {type}} not yet implemented}") } +} - newdata <- data.frame( - timestep = timesteps, - .row = seq_along(timesteps), - reference_date = desired_dates - ) - - fitted <- gratia::posterior_samples( - object[["model"]], - data = newdata, - unconditional = TRUE, - ... - ) - print(head(fitted)) - } else if (type == "incidence") { - # I'm breaking it out because we're going to want to also remove day of week effect I think. - - # And associate the timestep with the incidence date, not the case date to shift back from cases to incidence - newdata <- data.frame( - timestep = timesteps, - .row = seq_along(timesteps), - reference_date = desired_dates - ) - - # Use `posterior_samples()` over `fitted_samples()` to get response w/ obs uncertainty - fitted <- gratia::posterior_samples( - object[["model"]], - data = newdata, - unconditional = TRUE, - ... - ) - } else if (type == "r") { - # Construct timesteps as with incidence and shift by delta for discrete derivative - newdata <- data.frame( - timestep = timesteps, - reference_date = desired_dates - ) - delta <- 1e-9 - timesteps_shifted <- timesteps + delta - print(head(timesteps)) - print(head(timesteps_shifted)) - all_timesteps <- c(rbind(timesteps, timesteps_shifted)) - print(head(all_timesteps)) - # `rbind()` interleaves vectors, so we can difference them - ds <- gratia::data_slice(fit[["model"]], timestep = all_timesteps) +#' Posterior predicted cases +#' @noRd +predict_obs_cases <- function(object, desired_dates, timesteps, n, seed, ...) { + newdata <- data.frame( + timestep = timesteps, + .row = seq_along(timesteps), + reference_date = desired_dates + ) - all_timesteps <- gratia::fitted_samples( - object[["model"]], - data = ds, - n = n, - seed = seed, - unconditional = TRUE, - scale = "response", - ... - ) - is_orig_timestep <- all_timesteps[[".row"]] %% 2 == 1 - differenced_draws <- all_timesteps[[".fitted"]][is_orig_timestep] - all_timesteps[[".fitted"]][!is_orig_timestep] - print(head(differenced_draws)) - growth_rate <- differenced_draws[is_orig_timestep] / delta - print(head(growth_rate)) - fitted <- data.frame( - .response = growth_rate, - .row = 1:length(growth_rate), - .draw = all_timesteps[[".draw"]][is_orig_timestep], - timestep = timesteps - ) - print(head(fitted)) - } else { - cli::cli_abort("Not implemented") - } + # Use `posterior_samples()` over `fitted_samples()` to get response + # w/ obs uncertainty + fitted <- gratia::posterior_samples( + object[["model"]], + data = newdata, + unconditional = TRUE, + n = n, + seed = seed, + ... + ) merged <- merge(fitted, newdata, - by = "timestep" + by = ".row" ) - data.frame( reference_date = merged[["reference_date"]], .response = merged[[".response"]], @@ -168,6 +151,7 @@ predict.RtGam <- function( #' @param call The calling environment to be reflected in the error message #' #' @return List with two elements: min_date and max_date +#' @keyword internal parse_predict_dates <- function( object, min_date = NULL, @@ -191,20 +175,21 @@ parse_predict_dates <- function( # Ensure min_date is before max_date if (min_date >= max_date) { - cli::cli_warn( - c( - "Swapping {.arg min_date} and {.arg max_date}", - "{.arg min_date} {.val {min_date}} is after {.arg max_date} {.val {max_date}}" - ), - class = "RtGam_predict_dates_backward", - call = call - ) + cli::cli_alert_warning("Swapping {.arg min_date} and {.arg max_date}") + cli::cli_alert(c( + "{.arg min_date} {.val {min_date}} ", + "is after {.arg max_date} {.val {max_date}}" + )) temp_var <- max_date max_date <- min_date min_date <- temp_var } - list(min_date = as.Date(min_date), max_date = as.Date(max_date)) + seq.Date( + from = min_date, + to = max_date, + by = "day" + ) } #' Convert from user-specified dates to internal timesteps From 5611a4d5b44432c524c579762f7aacd434d79db8 Mon Sep 17 00:00:00 2001 From: Zachary Susswein Date: Tue, 15 Oct 2024 09:56:34 -0400 Subject: [PATCH 5/5] Fix missing rlang call --- NAMESPACE | 2 ++ R/predict.R | 9 +++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index b17c04c..e4fafe6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -3,6 +3,7 @@ S3method(fit_model,RtGam_bam) S3method(fit_model,RtGam_gam) S3method(fit_model,default) +S3method(predict,RtGam) S3method(print,RtGam) export(RtGam) export(check_diagnostics) @@ -10,4 +11,5 @@ export(dataset_creator) export(dates_to_timesteps) export(penalty_dim_heuristic) export(smooth_dim_heuristic) +importFrom(rlang,"%||%") importFrom(rlang,abort) diff --git a/R/predict.R b/R/predict.R index e69b77e..a9e064a 100644 --- a/R/predict.R +++ b/R/predict.R @@ -152,6 +152,7 @@ predict_obs_cases <- function(object, desired_dates, timesteps, n, seed, ...) { #' #' @return List with two elements: min_date and max_date #' @keyword internal +#' @importFrom rlang %||% parse_predict_dates <- function( object, min_date = NULL, @@ -163,8 +164,12 @@ parse_predict_dates <- function( if (!rlang::is_null(horizon)) check_integer(horizon, call = call) # Handle horizon to estimate dates if provided - if (!is.null(horizon)) { - rlang::check_exclusive(max_date, horizon, call = call) + if (!rlang::is_null(horizon)) { + if (!rlang::is_null(max_date)) { + cli::cli_abort("Cannot specify both {.arg horizon} and {.arg max_date}", + call = call + ) + } min_date <- min_date %||% (object[["max_date"]] + 1) max_date <- object[["max_date"]] + horizon + 1 } else {