Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for missing NAs in estimate_infection() model #528

Merged
merged 18 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* The functions `get_dist`, `get_generation_time`, `get_incubation_period` have been deprecated and replaced with examples. By @sbfnk in #481 and reviewed by @seabbs.
* The utility function `update_list()` has been deprecated in favour of `utils::modifyList()` because it comes with an installation of R. By @jamesmbaazam in #491 and reviewed by @seabbs.
* The `fixed` argument to `dist_spec` has been deprecated and replaced by a `fix_dist()` function. By @sbfnk in #503 and reviewed by @seabbs.
* Updated `estimate_infections()` so that rather than imputing missing data, it now skips these data points in the likelihood. This is a breaking change as it alters the behaviour of the model when dates are missing from a time series but are known to be zero. We recommend that users check their results when updating to this version but expect this to in most cases improve performance. By @seabbs in #528 and reviewed by @sbfnk.

## Documentation

Expand Down
76 changes: 48 additions & 28 deletions R/create.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
#' Create Clean Reported Cases
#' @description `r lifecycle::badge("stable")`
#' Cleans a data frame of reported cases by replacing missing dates with 0
#' cases and applies an optional threshold at which point 0 cases are replaced
#' with a moving average of observed cases. See `zero_threshold` for details.
#' Filters leading zeros, completes dates, and applies an optional threshold at
#' which point 0 cases are replaced with a user supplied value (defaults to
#' `NA`).
#'
#' @param filter_leading_zeros Logical, defaults to TRUE. Should zeros at the
#' start of the time series be filtered out.
#'
#' @param zero_threshold `r lifecycle::badge("experimental")` Numeric defaults
#' to Inf. Indicates if detected zero cases are meaningful by using a threshold
#' number of cases based on the 7 day average. If the average is above this
#' threshold then the zero is replaced with the backwards looking rolling
#' average. If set to infinity then no changes are made.
#' number of cases based on the 7-day average. If the average is above this
#' threshold then the zero is replaced using `fill`.
#'
#' @param fill Numeric, defaults to NA. Value to use to replace NA values or
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
#' zeroes that are flagged because the 7-day average is above the
#' `zero_threshold`. If the default NA is used then dates with NA values or with
#' 7-day averages above the `zero_threshold` will be skipped in model fitting.
#' If this is set to 0 then the only effect is to replace NA values with 0.
#'
#' @inheritParams estimate_infections
#' @importFrom data.table copy merge.data.table setorder setDT frollsum
#' @return A cleaned data frame of reported cases
#' @author Sam Abbott
#' @author Lloyd Chapman
#' @export
#' @examples
#' create_clean_reported_cases(example_confirmed, 7)
create_clean_reported_cases <- function(reported_cases, horizon,
filter_leading_zeros = TRUE,
zero_threshold = Inf) {
zero_threshold = Inf,
fill = NA_integer_) {
reported_cases <- data.table::setDT(reported_cases)
reported_cases_grid <- data.table::copy(reported_cases)[,
.(date = seq(min(date), max(date) + horizon, by = "days"))
Expand All @@ -35,35 +43,35 @@ create_clean_reported_cases <- function(reported_cases, horizon,
if (is.null(reported_cases$breakpoint)) {
reported_cases$breakpoint <- 0
}
reported_cases <- reported_cases[
is.na(confirm), confirm := 0][, .(date = date, confirm, breakpoint)
]
reported_cases <- reported_cases[is.na(breakpoint), breakpoint := 0]
reported_cases[is.na(breakpoint), breakpoint := 0]
reported_cases <- data.table::setorder(reported_cases, date)
## Filter out 0 reported cases from the beginning of the data
if (filter_leading_zeros) {
reported_cases <- reported_cases[order(date)][
,
cum_cases := cumsum(confirm)
][cum_cases > 0][, cum_cases := NULL]
date >= min(date[confirm[!is.na(confirm)] > 0])
]
}

# Calculate `average_7_day` which for rows with `confirm == 0`
# (the only instance where this is being used) equates to the 7-day
# right-aligned moving average at the previous data point.
reported_cases <-
reported_cases[
,
`:=`(average_7_day = (
data.table::frollsum(confirm, n = 8, na.rm = TRUE)
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
) / 7
)
]
# Check case counts preceding zero case counts and set to 7 day average if
# average over last 7 days is greater than a threshold
if (!is.infinite(zero_threshold)) {
reported_cases <-
reported_cases[
,
`:=`(average_7 = (data.table::frollsum(confirm, n = 8)) / 7)
]
reported_cases <- reported_cases[
confirm == 0 & average_7 > zero_threshold,
confirm := as.integer(average_7)
][
,
"average_7" := NULL
confirm == 0 & average_7_day > zero_threshold,
confirm := NA_integer_
]
}
reported_cases[is.na(confirm), confirm := fill]
reported_cases[, "average_7_day" := NULL]
return(reported_cases)
}

Expand Down Expand Up @@ -429,14 +437,26 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#' @author Sam Abbott
#' @author Sebastian Funk
#' @export
#' @examples
#' create_stan_data(
#' example_confirmed, 7, rt_opts(), gp_opts(), obs_opts(), 7,
#' backcalc_opts(), create_shifted_cases(example_confirmed, 7, 14, 7)
#' )
create_stan_data <- function(reported_cases, seeding_time,
rt, gp, obs, horizon,
backcalc, shifted_cases) {

cases <- reported_cases[(seeding_time + 1):(.N - horizon)]$confirm
cases <- reported_cases[(seeding_time + 1):(.N - horizon)]
cases[, lookup := seq_len(.N)]
complete_cases <- cases[!is.na(cases$confirm)]
cases_time <- complete_cases$lookup
complete_cases <- complete_cases$confirm
cases <- cases$confirm

data <- list(
cases = cases,
cases = complete_cases,
cases_time = cases_time,
lt = length(cases_time),
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
Expand All @@ -455,7 +475,7 @@ create_stan_data <- function(reported_cases, seeding_time,
first_week <- data.table::data.table(
confirm = cases[seq_len(min(7, length(cases)))],
t = seq_len(min(7, length(cases)))
)
)[!is.na(confirm)]
data$prior_infections <- log(mean(first_week$confirm, na.rm = TRUE))
data$prior_infections <- ifelse(
is.na(data$prior_infections) || is.null(data$prior_infections),
Expand Down
2 changes: 1 addition & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ estimate_infections <- function(reported_cases,
name = "EpiNow2.epinow.estimate_infections"
)
}
# Make sure there are no missing dates and order cases
# Order cases
reported_cases <- create_clean_reported_cases(
reported_cases, horizon,
filter_leading_zeros = filter_leading_zeros,
Expand Down
4 changes: 2 additions & 2 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,9 @@ globalVariables(
"New confirmed cases by infection date", "Data", "R", "reference",
".SD", "day_of_week", "forecast_type", "measure", "numeric_estimate",
"point", "strat", "estimate", "breakpoint", "variable", "value.V1",
"central_lower", "central_upper", "mean_sd", "sd_sd", "average_7",
"central_lower", "central_upper", "mean_sd", "sd_sd", "average_7_day",
"..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm",
"report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled",
"scaling", "sdlog"
"scaling", "sdlog", "lookup"
)
)
4 changes: 3 additions & 1 deletion inst/stan/data/observations.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
int t; // unobserved time
int lt; // timepoints in the likelihood
int seeding_time; // time period used for seeding and not observed
int horizon; // forecast horizon
int future_time; // time in future for Rt
array[t - horizon - seeding_time] int<lower = 0> cases; // observed cases
array[lt] int<lower = 0> cases; // observed cases
array[lt] int cases_time; // time of observed cases
vector<lower = 0>[t] shifted_cases; // prior infections (for backcalculation)
5 changes: 3 additions & 2 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ model {
// observed reports from mean of reports (update likelihood)
if (likelihood) {
report_lp(
cases, obs_reports, rep_phi, phi_mean, phi_sd, model_type, obs_weight
cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type,
obs_weight
);
}
}
Expand Down Expand Up @@ -191,7 +192,7 @@ generated quantities {
// log likelihood of model
if (return_likelihood) {
log_lik = report_log_lik(
cases, obs_reports, rep_phi, model_type, obs_weight
cases, obs_reports[cases_time], rep_phi, model_type, obs_weight
);
}
}
23 changes: 16 additions & 7 deletions man/create_clean_reported_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/create_stan_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions man/epinow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions man/estimate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions tests/testthat/test-create_clean_reported_cases.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

test_that("create_clean_reported_cases runs without errors", {
expect_no_error(create_clean_reported_cases(example_confirmed, 7))
})

test_that("create_clean_reported_cases returns a data table", {
result <- create_clean_reported_cases(example_confirmed, 7)
expect_s3_class(result, "data.table")
})

test_that("create_clean_reported_cases filters leading zeros correctly", {
# Modify example_confirmed to have leading zeros
modified_data <- example_confirmed
modified_data[1:3, "confirm"] <- 0

result <- create_clean_reported_cases(modified_data, 7)
# Check if the first row with non-zero cases is retained
expect_equal(
result$date[1], min(modified_data$date[modified_data$confirm > 0])
)
})

test_that("create_clean_reported_cases replaces zero cases correctly", {
# Modify example_confirmed to have zero cases that should be replaced
modified_data <- example_confirmed
modified_data$confirm[10:16] <- 0
threshold <- 10

result <- create_clean_reported_cases(
modified_data, 0, zero_threshold = threshold
)
# Check if zero cases within the threshold are replaced
expect_equal(sum(result$confirm == 0, na.rm = TRUE), 0)
})
Empty file.
8 changes: 8 additions & 0 deletions tests/testthat/test-estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ test_that("estimate_infections successfully returns estimates using default sett
test_estimate_infections(reported_cases)
})

test_that("estimate_infections successfully returns estimates when passed NA values", {
skip_on_cran()
reported_cases_na <- data.table::copy(reported_cases)
reported_cases_na[sample(1:30, 5), confirm := NA]
test_estimate_infections(reported_cases_na)
})


test_that("estimate_infections successfully returns estimates using no delays", {
skip_on_cran()
test_estimate_infections(reported_cases, delay = FALSE)
Expand Down