Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fit EpiNow2 model #26

Merged
merged 11 commits into from
Sep 12, 2024
Merged
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Description: Add logging, metadata handling, and data handling
fitting hundreds of models in parallel.
License: Apache License (>= 2)
Encoding: UTF-8
Remotes: github::epiforecasts/EpiNow2@bcf297cf36a93cc56123bc3c9e8cebfb1421a962
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Suggests:
Expand All @@ -28,6 +29,7 @@ Imports:
cli,
DBI,
duckdb,
EpiNow2 (>= 1.4.0),
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
rlang
URL: https://cdcgov.github.io/cfa-epinow2-pipeline/
Depends:
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ export(apply_exclusions)
export(download_from_azure_blob)
export(fetch_blob_container)
export(fetch_credential_from_env_var)
export(fit_model)
export(format_delay_interval)
export(format_generation_interval)
export(format_right_truncation)
export(read_data)
export(read_disease_parameters)
export(read_exclusions)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# CFAEpiNow2Pipeline (development version)

* Fit EpiNow2 model using params and fixed seed
* Removed `.vscode` folder from repo
* Read and apply exclusions to case data
* Data reader and processor
Expand Down
181 changes: 181 additions & 0 deletions R/fit_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#' Fit an EpiNow2 model
#'
#' @param data, in the format returned by [CFAEpiNow2Pipeline::read_data()]
#' @param parameters As returned from
#' [CFAEpiNow2Pipeline::read_disease_parameters()]
#' @param seed The random seed, used for both initialization by EpiNow2 in R and
#' sampling in Stan
#' @param horizon The number of days, as an integer, to forecast
#' @param priors A list of lists. The first level should contain the key `rt`
#' with elements `mean` and `sd` and the key `gp` with element `alpha_sd`.
#' @param sampler_opts A list. The Stan sampler options to be passed through
#' EpiNow2. It has required keys: `cores`, `chains`, `iter_warmup`,
#' `iter_sampling`, `max_treedepth`, and `adapt_delta`.
#'
#' @return A fitted model object of class `epinow` or, if model fitting fails,
#' an NA is returned with a warning
#' @export
fit_model <- function(
data,
parameters,
seed,
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
horizon,
priors,
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
sampler_opts) {
# Priors ------------------------------------------------------------------
rt <- EpiNow2::rt_opts(
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
list(
mean = priors[["rt"]][["mean"]],
sd = priors[["rt"]][["sd"]]
)
)
gp <- EpiNow2::gp_opts(
alpha_sd = priors[["gp"]][["alpha_sd"]]
)

# Distributions -----------------------------------------------------------
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
generation_time <- format_generation_interval(
parameters[["generation_interval"]]
)
delays <- format_delay_interval(
parameters[["delay_interval"]]
)
truncation <- format_right_truncation(
parameters[["right_truncation"]],
data
)

# Stan sampler ------------------------------------------------------------
stan <- EpiNow2::stan_opts(
cores = sampler_opts[["cores"]],
chains = sampler_opts[["chains"]],
# NOTE: seed gets used twice -- as the seed passed here to the Stan sampler
# and below as the R PRNG seed for EpiNow2 initialization
seed = seed,
warmup = sampler_opts[["iter_warmup"]],
samples = sampler_opts[["iter_samples"]],
control = list(
adapt_delta = sampler_opts[["adapt_delta"]],
max_treedepth = sampler_opts[["max_treedepth"]]
)
)

df <- data.frame(
confirm = data[["confirm"]],
date = as.Date(data[["reference_date"]])
)
rlang::try_fetch(
withr::with_seed(seed, {
EpiNow2::epinow(
df,
generation_time = generation_time,
delays = delays,
truncation = truncation,
horizon = horizon,
rt = rt,
gp = gp,
stan = stan,
verbose = interactive()
)
}),
# Downgrade model erroring out to a warning so we can catch and return
error = function(cnd) {
cli::cli_warn(
"Model fitting failed. Returning NA.",
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
parent = cnd,
class = "failing_fit"
)
NA
}
)
}

#' Format PMFs for EpiNow2
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
#'
#' Opinionated wrappers around EpiNow2::generation_time_opts(),
#' EpiNow2::delay_opts(), or EpiNow2::dist_spec() that formats the generation
#' interval, delay, or right truncation parameters as an object ready for input
#' to EpiNow2.
#'
#' Delays or right truncation are optional and can be skipped by passing `pmf =
#' NA`.
#'
#' @param pmf As returned by [CFAEpiNow2Pipeline::read_disease_parameters()]. A
#' PMF vector or an NA, if not applying the PMF to the model fit.
#'
#' @return An EpiNow2::*_opts() formatted object or NA with a message
#' @name opts_formatter
NULL

#' @rdname opts_formatter
#' @export
format_generation_interval <- function(pmf) {
if (
rlang::is_na(pmf) || rlang::is_null(pmf)
) {
cli::cli_abort("No generation time PMF specified but is required",
class = "Missing_GI"
)
}

suppressWarnings({
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
EpiNow2::generation_time_opts(
dist = EpiNow2::dist_spec(
pmf = pmf
)
)
})
}

#' @rdname opts_formatter
#' @export
format_delay_interval <- function(pmf) {
if (
rlang::is_na(pmf) || rlang::is_null(pmf)
) {
cli::cli_alert("Not adjusting for infection to case delay")
EpiNow2::delay_opts()
} else {
suppressWarnings({
EpiNow2::delay_opts(
dist = EpiNow2::dist_spec(
pmf = pmf
)
)
})
}
}

#' @inheritParams fit_model
#' @rdname opts_formatter
#' @export
format_right_truncation <- function(pmf, data) {
if (
rlang::is_na(pmf) || rlang::is_null(pmf)
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
) {
cli::cli_alert("Not adjusting for right truncation")
EpiNow2::trunc_opts()
} else if (length(pmf) > nrow(data)) {
# Nasty bug we ran into where **left-hand** side of the PMF was being
# silently removed if length of the PMF was longer than the data,
# effectively eliminating the right-truncation correction
athowes marked this conversation as resolved.
Show resolved Hide resolved

cli::cli_abort(
c(
"Right truncation PMF longer than the data",
"PMF length: {.val {length(pmf)}}",
"Data length: {.val {nrow(data)}}",
"PMF can only be up to length as the data"
),
class = "right_trunc_too_long"
)
} else {
suppressWarnings({
EpiNow2::trunc_opts(
dist = EpiNow2::dist_spec(
pmf = pmf
)
)
})
}
}
12 changes: 8 additions & 4 deletions man/apply_exclusions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions man/fit_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions man/opts_formatter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 3 additions & 4 deletions man/read_exclusions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading