diff --git a/NEWS.md b/NEWS.md index fc0488e67..94aa9e32e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -41,6 +41,7 @@ * Replaced descriptions and plot labels to be more general and clearer. By @sbfnk in #621 and reviewed by @jamesmbaazam. * Argument choices have been moved into default arguments. By @sbfnk in #622 and reviewed by @seabbs. * `simulate_infections()` gained the argument `seeding_time` to change the seeding time. Additionally, the documentation was improved. By @sbfnk in #627 and reviewed by @jamesmbaazam. +* The model-specific `weigh_delay_priors` argument has been deprecated in favour of delay-specific prior weighting using `weight_priors`. See `generation_time_opts()`, `delay_opts()`, and `trunc_opts()`. By @sbfnk in #630 and reviewed by @jamesmbaazam. ## Model changes diff --git a/R/create.R b/R/create.R index d0fb3ed2f..78abe8d10 100644 --- a/R/create.R +++ b/R/create.R @@ -718,10 +718,12 @@ create_stan_args <- function(stan = stan_opts(), ##' Create delay variables for stan ##' ##' @param ... Named delay distributions. The names are assigned to IDs -##' @param weight Numeric, weight associated with delay priors; default: 1 +##' @param time_points Integer, the number of time points in the data; +##' determines weight associated with weighted delay priors; default: 1 ##' @return A list of variables as expected by the stan model ##' @importFrom purrr transpose map flatten -create_stan_delays <- function(..., weight = 1) { +create_stan_delays <- function(..., time_points = 1L) { + delays <- list(...) ## discretise delays <- map(list(...), discretise) ## convolve where appropriate @@ -739,23 +741,23 @@ create_stan_delays <- function(..., weight = 1) { ids[type_n > 0] <- seq_len(sum(type_n > 0)) names(ids) <- paste(names(type_n), "id", sep = "_") - delays <- flatten(delays) - parametric <- unname( - vapply(delays, function(x) x$distribution != "nonparametric", logical(1)) - ) - param_length <- unname(vapply(delays[parametric], function(x) { + flat_delays <- flatten(delays) + parametric <- unname(vapply( + flat_delays, function(x) x$distribution != "nonparametric", logical(1) + )) + param_length <- unname(vapply(flat_delays[parametric], function(x) { length(x$parameters) }, numeric(1))) - nonparam_length <- unname(vapply(delays[!parametric], function(x) { + nonparam_length <- unname(vapply(flat_delays[!parametric], function(x) { length(x$pmf) }, numeric(1))) distributions <- unname(as.character( - map(delays[parametric], ~ .x$distribution) + map(flat_delays[parametric], ~ .x$distribution) )) ## create stan object ret <- list( - n = length(delays), + n = length(flat_delays), n_p = sum(parametric), n_np = sum(!parametric), types = sum(type_n > 0), @@ -771,15 +773,15 @@ create_stan_delays <- function(..., weight = 1) { ret$types_groups <- array(c(0, cumsum(unname(type_n[type_n > 0]))) + 1) ret$params_mean <- array(unname(as.numeric( - map(flatten(map(delays[parametric], ~ .x$parameters)), mean) + map(flatten(map(flat_delays[parametric], ~ .x$parameters)), mean) ))) ret$params_sd <- array(unname(as.numeric( - map(flatten(map(delays[parametric], ~ .x$parameters)), sd_dist) + map(flatten(map(flat_delays[parametric], ~ .x$parameters)), sd_dist) ))) ret$max <- array(max_delay[parametric]) ret$np_pmf <- array(unname(as.numeric( - flatten(map(delays[!parametric], ~ .x$pmf)) + flatten(map(flat_delays[!parametric], ~ .x$pmf)) ))) ## get non zero length delay pmf lengths ret$np_pmf_groups <- array(c(0, cumsum(nonparam_length)) + 1) @@ -791,12 +793,16 @@ create_stan_delays <- function(..., weight = 1) { ret$params_length <- sum(param_length) ## set lower bounds ret$params_lower <- array(unname(as.numeric(flatten( - map(delays[parametric], function(x) { + map(flat_delays[parametric], function(x) { lower_bounds(x$distribution)[names(x$parameters)] }) )))) ## assign prior weights - ret$weight <- array(rep(weight, ret$n_p)) + weight_priors <- vapply( + delays[parametric], attr, "weight_prior", FUN.VALUE = logical(1) + ) + ret$weight <- array(rep(1, ret$n_p)) + ret$weight[weight_priors] <- time_points ## assign distribution ret$dist <- array(match(distributions, c("lognormal", "gamma")) - 1L) diff --git a/R/dist_spec.R b/R/dist_spec.R index df5e5fa66..8e8ae7362 100644 --- a/R/dist_spec.R +++ b/R/dist_spec.R @@ -454,7 +454,8 @@ discretise <- function(x, silent = TRUE) { } } }) - attr(ret, "class") <- c("dist_spec", "list") + ## preserve attributes + attributes(ret) <- attributes(x) return(ret) } #' @rdname discretise @@ -525,7 +526,7 @@ apply_tolerance <- function(x, tolerance) { if (!is(x, "dist_spec")) { stop("Can only apply tolerance to distributions in a .") } - x <- lapply(x, function(x) { + y <- lapply(x, function(x) { if (x$distribution == "nonparametric") { cmf <- cumsum(x$pmf) new_pmf <- x$pmf[c(TRUE, (1 - cmf[-length(cmf)]) >= tolerance)] @@ -536,8 +537,9 @@ apply_tolerance <- function(x, tolerance) { } }) - attr(x, "class") <- c("dist_spec", "list") - return(x) + ## preserve attributes + attributes(y) <- attributes(x) + return(y) } #' Prints the parameters of one or more delay distributions diff --git a/R/estimate_infections.R b/R/estimate_infections.R index e326c87b7..23ce1ecb1 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -203,9 +203,7 @@ estimate_infections <- function(reported_cases, gt = generation_time, delay = delays, trunc = truncation, - weight = ifelse( - weigh_delay_priors, data$t - data$seeding_time - data$horizon, 1 - ) + time_points = data$t - data$seeding_time - data$horizon )) # Set up default settings diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 6213ca6a0..2af4b893d 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -140,7 +140,7 @@ estimate_secondary <- function(reports, meanlog = Normal(2.5, 0.5), sdlog = Normal(0.47, 0.25), max = 30 - ) + ), weight_prior = FALSE ), truncation = trunc_opts(), obs = obs_opts(), @@ -209,7 +209,7 @@ estimate_secondary <- function(reports, data <- c(data, create_stan_delays( delay = delays, trunc = truncation, - weight = ifelse(weigh_delay_priors, data$t, 1) + time_points = data$t )) # observation model data diff --git a/R/estimate_truncation.R b/R/estimate_truncation.R index 92c0f4aef..53f172e4d 100644 --- a/R/estimate_truncation.R +++ b/R/estimate_truncation.R @@ -47,12 +47,8 @@ #' @param model A compiled stan model to override the default model. May be #' useful for package developers or those developing extensions. #' -#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors -#' will be weighted by the number of observation data points, in doing so -#' approximately placing an independent prior at each time step and usually -#' preventing the posteriors from shifting. If FALSE (default), no weight will -#' be applied, i.e. delay distributions will be treated as a single -#' parameters. +#' @param weigh_delay_priors Deprecated; use the `weight_prior` option in +#' [trunc_opts()] instead. #' #' @param verbose Logical, should model fitting progress be returned. #' @@ -123,7 +119,15 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, "estimate_truncation(stan)" ) } - # Validate inputs + if (!missing(weigh_delay_priors)) { + lifecycle::deprecate_warn( + "1.5.0", + "estimate_truncation(weigh_delay_priors)", + "trunc_opts(weight_prior)", + detail = "This argument will be removed completely in version 2.0.0" + ) + } + # Validate inputs walk(obs, check_reports_valid, model = "estimate_truncation") assert_class(truncation, "dist_spec") assert_class(model, "stanfit", null.ok = TRUE) @@ -233,7 +237,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, data <- c(data, create_stan_delays( trunc = truncation, - weight = ifelse(weigh_delay_priors, data$t, 1) + time_points = data$t )) # initial conditions diff --git a/R/opts.R b/R/opts.R index 7059ee239..1555b3074 100644 --- a/R/opts.R +++ b/R/opts.R @@ -12,8 +12,13 @@ #' @param max deprecated; use `dist` instead #' @param fixed deprecated; use `dist` instead #' @param prior_weight deprecated; prior weights are now specified as a -#' model option. Use the `weigh_delay_priors` argument of -#' [estimate_infections()] instead. +#' model option. Use the `weight_prior` argument instead +#' @param weight_prior Logical; if TRUE (default), any priors given in `dist` +#' will be weighted by the number of observation data points, in doing so +#' approximately placing an independent prior at each time step and usually +#' preventing the posteriors from shifting. If FALSE, no weight will be +#' applied, i.e. any parameters in `dist` will be treated as a single +#' parameters. #' @inheritParams apply_tolerance #' @return A `` object summarising the input delay #' distributions. @@ -40,7 +45,8 @@ #' generation_time_opts(example_generation_time) generation_time_opts <- function(dist = Fixed(1), ..., disease, source, max = 14, fixed = FALSE, - prior_weight, tolerance = 0.001) { + prior_weight, tolerance = 0.001, + weight_prior = TRUE) { deprecated_options_given <- FALSE dot_options <- list(...) @@ -82,7 +88,7 @@ generation_time_opts <- function(dist = Fixed(1), ..., if (!missing(prior_weight)) { deprecate_warn( "1.4.0", "generation_time_opts(prior_weight)", - "estimate_infections(weigh_delay_prior)", + "generation_time_opts(weight_prior)", "This argument will be removed in version 2.0.0." ) } @@ -107,6 +113,7 @@ generation_time_opts <- function(dist = Fixed(1), ..., } check_stan_delay(dist) attr(dist, "tolerance") <- tolerance + attr(dist, "weight_prior") <- weight_prior attr(dist, "class") <- c("generation_time_opts", class(dist)) return(dist) } @@ -189,6 +196,7 @@ secondary_opts <- function(type = c("incidence", "prevalence"), ...) { #' @param ... deprecated; use `dist` instead #' @param fixed deprecated; use `dist` instead #' @inheritParams apply_tolerance +#' @inheritParams generation_time_opts #' @return A `` object summarising the input delay distributions. #' @seealso [convert_to_logmean()] [convert_to_logsd()] #' [bootstrapped_dist_fit()] [dist_spec()] @@ -207,7 +215,8 @@ secondary_opts <- function(type = c("incidence", "prevalence"), ...) { #' #' # Multiple delays (in this case twice the same) #' delay_opts(delay + delay) -delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) { +delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001, + weight_prior = TRUE) { dot_options <- list(...) if (!is(dist, "dist_spec")) { ## could be old syntax if (is.list(dist)) { @@ -240,6 +249,7 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) { } check_stan_delay(dist) attr(dist, "tolerance") <- tolerance + attr(dist, "weight_prior") <- weight_prior attr(dist, "class") <- c("delay_opts", class(dist)) return(dist) } @@ -254,6 +264,12 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) { #' @param dist A delay distribution or series of delay distributions reflecting #' the truncation generated using [dist_spec()] or [estimate_truncation()]. #' Default is fixed distribution with maximum 0, i.e. no truncation +#' @param weight_prior Logical; if TRUE, the truncation prior will be weighted +#' by the number of observation data points, in doing so approximately placing +#' an independent prior at each time step and usually preventing the +#' posteriors from shifting. If FALSE (default), no weight will be applied, +#' i.e. the truncation distribution will be treated as a single parameter. +#' #' @inheritParams apply_tolerance #' @return A `` object summarising the input truncation #' distribution. @@ -267,7 +283,8 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) { #' #' # truncation dist #' trunc_opts(dist = LogNormal(mean = 3, sd = 2, max = 10)) -trunc_opts <- function(dist = Fixed(0), tolerance = 0.001) { +trunc_opts <- function(dist = Fixed(0), tolerance = 0.001, + weight_prior = FALSE) { if (!is(dist, "dist_spec")) { if (is.list(dist)) { dist <- do.call(dist_spec, dist) @@ -285,6 +302,7 @@ trunc_opts <- function(dist = Fixed(0), tolerance = 0.001) { } check_stan_delay(dist) attr(dist, "tolerance") <- tolerance + attr(dist, "weight_prior") <- weight_prior attr(dist, "class") <- c("trunc_opts", class(dist)) return(dist) } diff --git a/man/create_stan_delays.Rd b/man/create_stan_delays.Rd index ceec9a27d..35e017c08 100644 --- a/man/create_stan_delays.Rd +++ b/man/create_stan_delays.Rd @@ -4,12 +4,13 @@ \alias{create_stan_delays} \title{Create delay variables for stan} \usage{ -create_stan_delays(..., weight = 1) +create_stan_delays(..., time_points = 1L) } \arguments{ \item{...}{Named delay distributions. The names are assigned to IDs} -\item{weight}{Numeric, weight associated with delay priors; default: 1} +\item{time_points}{Integer, the number of time points in the data; +determines weight associated with weighted delay priors; default: 1} } \value{ A list of variables as expected by the stan model diff --git a/man/delay_opts.Rd b/man/delay_opts.Rd index 46ffb13a6..ea5c8dccd 100644 --- a/man/delay_opts.Rd +++ b/man/delay_opts.Rd @@ -4,7 +4,13 @@ \alias{delay_opts} \title{Delay Distribution Options} \usage{ -delay_opts(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) +delay_opts( + dist = Fixed(0), + ..., + fixed = FALSE, + tolerance = 0.001, + weight_prior = TRUE +) } \arguments{ \item{dist}{A delay distribution or series of delay distributions. Default is @@ -15,6 +21,13 @@ a fixed distribution with all mass at 0, i.e. no delay.} \item{fixed}{deprecated; use \code{dist} instead} \item{tolerance}{Numeric; the desired tolerance level.} + +\item{weight_prior}{Logical; if TRUE (default), any priors given in \code{dist} +will be weighted by the number of observation data points, in doing so +approximately placing an independent prior at each time step and usually +preventing the posteriors from shifting. If FALSE, no weight will be +applied, i.e. any parameters in \code{dist} will be treated as a single +parameters.} } \value{ A \verb{} object summarising the input delay distributions. diff --git a/man/estimate_secondary.Rd b/man/estimate_secondary.Rd index e8ecfd100..d83798cc2 100644 --- a/man/estimate_secondary.Rd +++ b/man/estimate_secondary.Rd @@ -8,7 +8,7 @@ estimate_secondary( reports, secondary = secondary_opts(), delays = delay_opts(LogNormal(meanlog = Normal(2.5, 0.5), sdlog = Normal(0.47, 0.25), - max = 30)), + max = 30), weight_prior = FALSE), truncation = trunc_opts(), obs = obs_opts(), stan = stan_opts(), diff --git a/man/estimate_truncation.Rd b/man/estimate_truncation.Rd index 884d280f7..e8da5d958 100644 --- a/man/estimate_truncation.Rd +++ b/man/estimate_truncation.Rd @@ -54,12 +54,8 @@ to Inf. Indicates if detected zero cases are meaningful by using a threshold number of cases based on the 7-day average. If the average is above this threshold then the zero is replaced using \code{fill}.} -\item{weigh_delay_priors}{Logical. If TRUE, all delay distribution priors -will be weighted by the number of observation data points, in doing so -approximately placing an independent prior at each time step and usually -preventing the posteriors from shifting. If FALSE (default), no weight will -be applied, i.e. delay distributions will be treated as a single -parameters.} +\item{weigh_delay_priors}{Deprecated; use the \code{weight_prior} option in +\code{\link[=trunc_opts]{trunc_opts()}} instead.} \item{verbose}{Logical, should model fitting progress be returned.} diff --git a/man/generation_time_opts.Rd b/man/generation_time_opts.Rd index 95f8b2e3f..a687024d6 100644 --- a/man/generation_time_opts.Rd +++ b/man/generation_time_opts.Rd @@ -12,7 +12,8 @@ generation_time_opts( max = 14, fixed = FALSE, prior_weight, - tolerance = 0.001 + tolerance = 0.001, + weight_prior = TRUE ) } \arguments{ @@ -30,10 +31,16 @@ distribution is given a fixed generation time of 1 will be assumed.} \item{fixed}{deprecated; use \code{dist} instead} \item{prior_weight}{deprecated; prior weights are now specified as a -model option. Use the \code{weigh_delay_priors} argument of -\code{\link[=estimate_infections]{estimate_infections()}} instead.} +model option. Use the \code{weight_prior} argument instead} \item{tolerance}{Numeric; the desired tolerance level.} + +\item{weight_prior}{Logical; if TRUE (default), any priors given in \code{dist} +will be weighted by the number of observation data points, in doing so +approximately placing an independent prior at each time step and usually +preventing the posteriors from shifting. If FALSE, no weight will be +applied, i.e. any parameters in \code{dist} will be treated as a single +parameters.} } \value{ A \verb{} object summarising the input delay diff --git a/man/trunc_opts.Rd b/man/trunc_opts.Rd index 1f9b0dedc..81ceb25fa 100644 --- a/man/trunc_opts.Rd +++ b/man/trunc_opts.Rd @@ -4,7 +4,7 @@ \alias{trunc_opts} \title{Truncation Distribution Options} \usage{ -trunc_opts(dist = Fixed(0), tolerance = 0.001) +trunc_opts(dist = Fixed(0), tolerance = 0.001, weight_prior = FALSE) } \arguments{ \item{dist}{A delay distribution or series of delay distributions reflecting @@ -12,6 +12,12 @@ the truncation generated using \code{\link[=dist_spec]{dist_spec()}} or \code{\l Default is fixed distribution with maximum 0, i.e. no truncation} \item{tolerance}{Numeric; the desired tolerance level.} + +\item{weight_prior}{Logical; if TRUE, the truncation prior will be weighted +by the number of observation data points, in doing so approximately placing +an independent prior at each time step and usually preventing the +posteriors from shifting. If FALSE (default), no weight will be applied, +i.e. the truncation distribution will be treated as a single parameter.} } \value{ A \verb{} object summarising the input truncation diff --git a/tests/testthat/_snaps/simulate-secondary.md b/tests/testthat/_snaps/simulate-secondary.md index 5fce9e7f7..deef060da 100644 --- a/tests/testthat/_snaps/simulate-secondary.md +++ b/tests/testthat/_snaps/simulate-secondary.md @@ -27,6 +27,6 @@ 126: 2020-06-26 315 127: 2020-06-27 157 128: 2020-06-28 237 - 129: 2020-06-29 259 + 129: 2020-06-29 260 130: 2020-06-30 234 diff --git a/tests/testthat/test-delays.R b/tests/testthat/test-delays.R index 046d483d4..2480166d9 100644 --- a/tests/testthat/test-delays.R +++ b/tests/testthat/test-delays.R @@ -6,7 +6,7 @@ test_stan_delays <- function(generation_time = generation_time_opts(Fixed(1)), generation_time = generation_time, delays = delays, truncation = truncation, - weight = 10 + time_points = 10 ) return(unlist(unname(data[params]))) } diff --git a/tests/testthat/test-epinow.R b/tests/testthat/test-epinow.R index dc644afbe..2e6dfb08d 100644 --- a/tests/testthat/test-epinow.R +++ b/tests/testthat/test-epinow.R @@ -125,7 +125,9 @@ test_that("epinow runs without error when saving to disk", { test_that("epinow can produce partial output as specified", { out <- suppressWarnings(epinow( reported_cases = reported_cases, - generation_time = generation_time_opts(example_generation_time), + generation_time = generation_time_opts( + example_generation_time, weight_prior = FALSE + ), delays = delay_opts(example_incubation_period + reporting_delay), stan = stan_opts( samples = 25, warmup = 25,