Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First draft of predict method #48

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ RoxygenNote: 7.3.2
Imports:
cli,
glue,
gratia,
mgcv,
rlang
Depends:
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
S3method(fit_model,RtGam_bam)
S3method(fit_model,RtGam_gam)
S3method(fit_model,default)
S3method(predict,RtGam)
S3method(print,RtGam)
export(RtGam)
export(check_diagnostics)
export(dataset_creator)
export(dates_to_timesteps)
export(penalty_dim_heuristic)
export(smooth_dim_heuristic)
importFrom(rlang,"%||%")
importFrom(rlang,abort)
2 changes: 1 addition & 1 deletion R/checkers.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ check_no_missingness <- function(x, arg = "x", call = rlang::caller_env()) {

check_elements_below_max <- function(x, arg, max, call = rlang::caller_env()) {
# Greater than or equal to 0 or is NA
is_below_max <- all((x <= max) | is.na(x))
is_below_max <- (x <= max) | is.na(x)
if (!all(is_below_max)) {
cli::cli_abort(
c("{.arg {arg}} has elements larger than {.val {max}}",
Expand Down
233 changes: 233 additions & 0 deletions R/predict.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
#' Predict Method for RtGam Models
#'
#' Generates predictions from an `RtGam` fit. Prediction dates can be specified
#' flexibly using multiple approaches.
#'
#' @param object An `RtGamFit` object from [RtGam()].
#' @param type A string specifying the prediction type. Options are
#' `"obs_cases"`, `"incidence"`, `"r"`, and `"Rt"`. Currently, only
#' `"obs_cases"` is supported. Matching is enforced by [rlang::arg_match()].
#' @param horizon Optional. Integer specifying forecast days from the last date
#' in the fit. For example, `horizon = 7` returns a 7-day forecast.
#' @param min_date Optional. A date-like object marking the start of the
#' prediction period.
#' @param max_date Optional. A date-like object marking the end of the
#' prediction period.
#' @param n Number of posterior samples to use. Default is 10.
#' @param mean_delay Optional. Numeric mean delay used in the prediction.
#' @param gi_pmf Optional. A vector representing the generation interval PMF.
#' @param seed Random seed for reproducibility. Default is 12345.
#' @param ... Additional arguments passed to lower-level functions.
#'
#' @details
#' Prediction dates can be set in four ways:
#'
#' 1. **Using Fit Object Alone**: Predictions span the full date range in the
#' original model fit.
#' 2. **Using `horizon`**: Forecasts extend `horizon` days from the fit’s last
#' date.
#' 3. **Using `min_date` and `horizon`**: Predictions start at `min_date` and
#' end `horizon` days after the fit’s last date.
#' 4. **Using `min_date` and `max_date`**: Predictions span all dates between
#' these two (inclusive).
#'
#' @return
#' A dataframe in [tidy format](https://www.jstatsoft.org/article/view/v059i10),
#' with each row representing a draw for a specific date:
#'
#' - `reference_date`: Date of the prediction.
#' - `.response`: Predicted value (e.g., observed cases).
#' - `.draw`: ID of the posterior draw.
#'
#' Example output:
#' ```
#' reference_date .response .draw
#' 1 2023-01-01 18 1
#' 2 2023-01-02 13 1
#' 3 2023-01-03 21 1
#' ```
#'
#' @export
predict.RtGam <- function(
object,
type,
horizon = NULL,
min_date = NULL,
max_date = NULL,
n = 10,
mean_delay = NULL,
gi_pmf = NULL,
seed = 12345,
...) {
rlang::arg_match(type,
values = c(
"obs_cases",
"incidence",
"r",
"Rt"
),
call = rlang::caller_env()
)
if (type != "obs_cases") {
if (rlang::is_null(mean_delay)) {
cli::cli_abort("{.arg mean_delay} is required when type is {.val {type}}")
}
check_integer(mean_delay, "gi_pmf")
}
if (type == "Rt") {
if (rlang::is_null(gi_pmf)) {
cli::cli_abort("{.arg gi_pmf} is required when type is {.val Rt}")
}
check_vector(gi_pmf, "gi_pmf")
check_no_missingness(gi_pmf, "gi_pmf")
check_elements_above_min(gi_pmf, "gi_pmf", 0)
check_elements_below_max(gi_pmf, "gi_pmf", 1)
check_sums_to_one(gi_pmf, "gi_pmf")
}

desired_dates <- parse_predict_dates(
object = object,
min_date = min_date,
max_date = max_date,
horizon = horizon
)
timesteps <- prep_timesteps_for_pred(
type = type,
desired_min_date = min(desired_dates),
desired_max_date = max(desired_dates),
fit_min_date = object[["min_date"]],
fit_max_date = object[["max_date"]],
mean_delay = mean_delay
)

if (type == "obs_cases") {
predict_obs_cases(
object,
desired_dates,
timesteps,
n = n,
seed = seed,
...
)
} else {
cli::cli_abort("{.val {type}} not yet implemented}")
}
}

#' Posterior predicted cases
#' @noRd
predict_obs_cases <- function(object, desired_dates, timesteps, n, seed, ...) {
newdata <- data.frame(
timestep = timesteps,
.row = seq_along(timesteps),
reference_date = desired_dates
)

# Use `posterior_samples()` over `fitted_samples()` to get response
# w/ obs uncertainty
fitted <- gratia::posterior_samples(
object[["model"]],
data = newdata,
unconditional = TRUE,
n = n,
seed = seed,
...
)

merged <- merge(fitted,
newdata,
by = ".row"
)
data.frame(
reference_date = merged[["reference_date"]],
.response = merged[[".response"]],
.draw = merged[[".draw"]]
)
}

#' Convert from user specification to necessary date range
#'
#' @inheritParams predict.RtGam
#' @param call The calling environment to be reflected in the error message
#'
#' @return List with two elements: min_date and max_date
#' @keyword internal
#' @importFrom rlang %||%
parse_predict_dates <- function(
object,
min_date = NULL,
max_date = NULL,
horizon = NULL,
call = rlang::caller_env()) {
if (!rlang::is_null(min_date)) check_date(min_date, call = call)
if (!rlang::is_null(max_date)) check_date(max_date, call = call)
if (!rlang::is_null(horizon)) check_integer(horizon, call = call)

# Handle horizon to estimate dates if provided
if (!rlang::is_null(horizon)) {
if (!rlang::is_null(max_date)) {
cli::cli_abort("Cannot specify both {.arg horizon} and {.arg max_date}",
call = call
)
}
min_date <- min_date %||% (object[["max_date"]] + 1)
max_date <- object[["max_date"]] + horizon + 1
} else {
# Default to object's date range if not specified
min_date <- min_date %||% object[["min_date"]]
max_date <- max_date %||% object[["max_date"]]
}

# Ensure min_date is before max_date
if (min_date >= max_date) {
cli::cli_alert_warning("Swapping {.arg min_date} and {.arg max_date}")
cli::cli_alert(c(
"{.arg min_date} {.val {min_date}} ",
"is after {.arg max_date} {.val {max_date}}"
))
temp_var <- max_date
max_date <- min_date
min_date <- temp_var
}

seq.Date(
from = min_date,
to = max_date,
by = "day"
)
}

#' Convert from user-specified dates to internal timesteps
#'
#' @inheritParams predict.RtGam
#' @return Double vector, the timesteps to predict
prep_timesteps_for_pred <- function(
type,
fit_min_date,
fit_max_date,
desired_min_date,
desired_max_date,
mean_delay,
call = rlang::caller_env()) {
if (type == "incidence" || type == "growth_rate") {
# Shift cases up by mean delay to get projected incidence on day
desired_min_date <- desired_min_date + mean_delay
desired_max_date <- desired_max_date + mean_delay
} else if (type == "Rt") {
# Shift up by mean delay to move to incidence scale and also pad by the
# GI on either side to prevent missing dates in the convolution
desired_min_date <- desired_min_date + desired_mean_delay - length(gi_pmf)
desired_max_date <- desired_max_date + mean_delay + length(gi_pmf)
}

dates <- seq.Date(
from = desired_min_date,
to = desired_max_date,
by = "day"
)
dates_to_timesteps(
dates,
min_supplied_date = fit_min_date,
max_supplied_date = fit_max_date
)
}
109 changes: 109 additions & 0 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
test_that("Dates are parsed correctly", {
# Passing just `object` gives the same min and max dates
expected_object_dates <- list(
min_date = as.Date("2023-01-01"),
max_date = as.Date("2023-01-15")
)
actual_object_dates <- parse_predict_dates(
object = expected_object_dates
)
expect_equal(
actual_object_dates,
expected_object_dates
)

# Passing horizon gives you n day ahead forcast
object_dates <- list(
min_date = as.Date("2023-01-01"),
max_date = as.Date("2023-01-15")
)
horizon <- 5
expected_horizon_dates <- list(
min_date = object_dates$max_date + 1,
max_date = object_dates$max_date + 1 + horizon
)

actual_horizon_dates <- parse_predict_dates(
object = object_dates,
horizon = horizon
)
expect_equal(actual_horizon_dates, expected_horizon_dates)

# Horizon + min_date gives period from min_date to horizon
object_dates <- list(
min_date = as.Date("2023-01-01"),
max_date = as.Date("2023-01-15")
)
min_date <- as.Date("2023-01-10")
horizon <- 3
expected_min_horiz_dates <- list(
min_date = min_date,
max_date = object_dates$max_date + 1 + horizon
)
actual_min_horiz_dates <- parse_predict_dates(
object = object_dates,
horizon = horizon,
min_date = min_date
)
expect_equal(actual_min_horiz_dates, expected_min_horiz_dates)

# Specifying min and max dates returns from min to max
object_dates <- list(
min_date = as.Date("2023-01-01"),
max_date = as.Date("2023-01-15")
)
min_date <- as.Date("2023-01-02")
max_date <- as.Date("2023-01-20")
expected_min_max_dates <- list(
min_date = min_date,
max_date = max_date
)
actual_min_max_dates <- parse_predict_dates(
object = object,
min_date = min_date,
max_date = max_date
)
expect_equal(actual_min_max_dates, expected_min_max_dates)
})

test_that("Bad dates throw appropriate status messages", {
object <- list(
min_date = as.Date("2023-01-01"),
max_date = as.Date("2023-01-15")
)
# Type errors
expect_error(
parse_predict_dates(object, max_date = "potato"),
class = "RtGam_type_error"
)
expect_error(
parse_predict_dates(object, min_date = "potato"),
class = "RtGam_type_error"
)
expect_error(
parse_predict_dates(object, horizon = "potato"),
class = "RtGam_type_error"
)

# Specifying both horizon and max date
expect_error(
parse_predict_dates(list(), horizon = 15, max_date = as.Date("2023-01-01"))
)

# min_date after max_date
min_date <- as.Date("2023-01-01")
max_date <- min_date + 1
expect_warning(
actual <- parse_predict_dates(
list(),
# Swap min and max to invoke warning
min_date = max_date,
max_date = min_date
),
class = "RtGam_predict_dates_backward"
)
expect_equal(
actual,
list(min_date = min_date, max_date = max_date)
)
})
Loading