Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

distribution interface to dist_spec #504

Merged
merged 136 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
136 commits
Select commit Hold shift + click to select a range
cbb9cdb
add distribution functions
sbfnk Oct 14, 2023
78ea512
deprecate "empty" distribution
sbfnk Nov 10, 2023
92f185d
make sd S3
sbfnk Nov 13, 2023
2053969
only generate samples if any params aren't natural
sbfnk Nov 13, 2023
ab1b175
update stan model with new dist interface
sbfnk Nov 17, 2023
524ebc2
update lognormal parameters
sbfnk Nov 17, 2023
dc0ce34
return mean function to previous functionality
sbfnk Nov 17, 2023
bbec631
update data
sbfnk Nov 17, 2023
7dd25c9
deprecate dist_def functions
sbfnk Nov 21, 2023
f26136c
use natural parametrisations in dist_def functions
sbfnk Nov 21, 2023
52c51ae
deprecate dist_spec
sbfnk Nov 21, 2023
7c9edb0
extract_single_dist function
sbfnk Nov 21, 2023
7e423fe
update fix_dist to work with compsosite dists
sbfnk Nov 21, 2023
aec50d0
extract squash
sbfnk Nov 21, 2023
46a9c92
update parameters to extract
sbfnk Nov 21, 2023
890e6b5
specify lower bounds in function
sbfnk Nov 21, 2023
7a58063
pass lower bounds to stan model
sbfnk Nov 21, 2023
10acdb7
update sample/report functions
sbfnk Nov 21, 2023
e2f62de
max squash adjust report
sbfnk Nov 21, 2023
6c0ec90
update dist functions to new syntax
sbfnk Nov 21, 2023
0155317
re-create data
sbfnk Nov 21, 2023
0e2ef97
update get_dist to new syntax
sbfnk Nov 21, 2023
0518157
fully deprecate get fnuctions
sbfnk Nov 21, 2023
bd9ba78
create delay inits separately
sbfnk Nov 21, 2023
9c65d77
max squash again
sbfnk Nov 21, 2023
c72efe9
return correct dist in estimate_truncation
sbfnk Nov 21, 2023
37524bb
few more examples/docs
sbfnk Nov 21, 2023
cad0dd2
fix tests
sbfnk Nov 21, 2023
19b2a91
add documentation to dist interface
sbfnk Nov 21, 2023
48388c7
add input checks
sbfnk Nov 21, 2023
4439db7
sd function to work with composite dists
sbfnk Nov 21, 2023
df6d08a
warn when not using natural parameters
sbfnk Nov 21, 2023
f2db571
ensure bounds are respected in stan
sbfnk Nov 22, 2023
50b59b7
add empty distribution for legacy reasons
sbfnk Nov 22, 2023
9e53614
add checks to dist_skel
sbfnk Nov 22, 2023
93356c6
use lapply for parameters
sbfnk Nov 22, 2023
3718c80
don't calculate sd if length 1
sbfnk Nov 22, 2023
b70c06f
use uncertain reporting in example
sbfnk Nov 22, 2023
58ba10b
don't add one to sd
sbfnk Nov 22, 2023
ac3c87d
return correct parameters
sbfnk Nov 22, 2023
89acc25
dist_skel: calculate rate everywhere
sbfnk Nov 22, 2023
8eaf13e
update dist_skel examples
sbfnk Nov 22, 2023
b1d33e1
add missing man file
sbfnk Nov 22, 2023
44f9717
don't run internal examples
sbfnk Nov 22, 2023
f05c8cc
demote warning to message
sbfnk Nov 22, 2023
a785f4b
update syntax everywhere
sbfnk Nov 22, 2023
c6d62be
add news item
sbfnk Nov 22, 2023
c95a296
turn sd into an internal function
sbfnk Nov 22, 2023
b3e9001
fix distribution documentation
sbfnk Nov 22, 2023
d7fcc60
remove obselete default
sbfnk Nov 22, 2023
2fe6f02
spell checking
sbfnk Nov 22, 2023
6403ff8
use correct sd function
sbfnk Nov 22, 2023
286108a
linting
sbfnk Nov 22, 2023
5636b01
remove obsolete tests
sbfnk Nov 22, 2023
a53d800
loop over all parameters
sbfnk Nov 22, 2023
5c1fe7a
update touchstone arguments
sbfnk Nov 22, 2023
65c86fa
linting
sbfnk Nov 22, 2023
502ceb8
fix regex search/replace gone wrong
sbfnk Nov 22, 2023
7c72b0f
remove obsolete space
sbfnk Nov 26, 2023
88277c5
update strategy for estimating uncertainty
sbfnk Nov 26, 2023
cac13ff
update uncertain parameter transformations
sbfnk Nov 26, 2023
d91defd
add missing sd to parameter sampling
sbfnk Nov 27, 2023
e23ac3d
update / recompile vignettes
sbfnk Nov 27, 2023
7e79e6e
update var names
sbfnk Nov 27, 2023
30f5aed
rename argument in docs
sbfnk Nov 27, 2023
0970260
update man pages
sbfnk Nov 27, 2023
322cdd0
update test result
sbfnk Nov 27, 2023
a17e85d
add reviewer
sbfnk Nov 30, 2023
136e303
base scaling on variance, not sd
sbfnk Nov 30, 2023
f3eeba8
re-render vignettes
sbfnk Nov 30, 2023
5f7d162
full text capitalisation of distributions
sbfnk Dec 8, 2023
ac5b200
separate dist_spec from stan model
sbfnk Jan 11, 2024
5daa83b
adjust tests/code for new dist_spec set up
sbfnk Jan 11, 2024
d5e4ddf
re-create examples
sbfnk Jan 11, 2024
2ddb915
re-doc
sbfnk Jan 11, 2024
fff4c99
update tests
sbfnk Jan 11, 2024
64f8e21
new dist_spec in estimate_truncation example
sbfnk Jan 11, 2024
9a301ec
update get_seeding_time with updated dist_spec
sbfnk Jan 11, 2024
ee008db
estimate_truncation and seeding time tests
sbfnk Jan 11, 2024
ae59fdb
update truncation dist in estimate_truncation
sbfnk Jan 11, 2024
4d3e8d9
remove more uses of old dist_spec
sbfnk Jan 11, 2024
1187201
SD explicitly to zero for fixed
sbfnk Jan 11, 2024
5173586
give names
sbfnk Jan 11, 2024
96a1277
fix typo
sbfnk Jan 11, 2024
c92fb21
fix indent
sbfnk Jan 11, 2024
e03f365
fix another typo
sbfnk Jan 11, 2024
9c4f45a
squash bugs highlighted by tests
sbfnk Jan 11, 2024
52eeaf1
remove missing variable
sbfnk Jan 12, 2024
a475499
linting
sbfnk Jan 12, 2024
f4f43d4
add missing docs
sbfnk Jan 12, 2024
6ec6394
import transpose
sbfnk Jan 12, 2024
a35933a
ensure sd is positive
sbfnk Jan 12, 2024
8b95ccf
fix estimate_truncation example
sbfnk Jan 12, 2024
9a69221
make tolerance user-settable
sbfnk Jan 12, 2024
4c8626d
use purrr::map instead of lapply
sbfnk Jan 12, 2024
46e1f8c
fix stan dist test
sbfnk Jan 12, 2024
2becf78
fix plotting
sbfnk Jan 16, 2024
748a62c
Apply suggestions from code review
sbfnk Feb 20, 2024
57dce2b
rate and scale examples for Gamma
sbfnk Feb 20, 2024
7a7d29e
capitalise gamma and lognormal
sbfnk Feb 20, 2024
e01b874
change to single hash
sbfnk Feb 20, 2024
9374acd
use bar in normal_cdf
sbfnk Feb 20, 2024
2a3495f
remove estraneous backticks
sbfnk Feb 20, 2024
204feff
remove space before left parenthesis
sbfnk Feb 20, 2024
ca47727
split up dist.R
sbfnk Feb 21, 2024
4cb0108
move deprecated `dist_spec` function
sbfnk Feb 21, 2024
e8e6623
add examples
sbfnk Feb 21, 2024
3a5f45a
initial design sketch
sbfnk Feb 26, 2024
ca56d2c
make parameter conversion more flexible
sbfnk Feb 28, 2024
c6884e3
add test for alternative gama params
sbfnk Feb 28, 2024
81e8bdd
Merge remote-tracking branch 'origin/main' into dist-interface
sbfnk Feb 28, 2024
bb861dc
update syntax in simulate_infections
sbfnk Feb 28, 2024
52855bb
add missing tag
sbfnk Feb 28, 2024
29db428
update man pages
sbfnk Feb 28, 2024
14a2c17
update estimate_secondary tests
sbfnk Feb 28, 2024
df2253a
update simulate_infections for new interface
sbfnk Feb 28, 2024
a16c237
udpate snapshots
sbfnk Feb 28, 2024
19eb1eb
get_dist deprecation test with natural params
sbfnk Feb 28, 2024
12460c2
update phi syntax
sbfnk Feb 28, 2024
5b57b56
hide internal example
sbfnk Feb 28, 2024
ebe9032
Merge remote-tracking branch 'origin/main' into dist-interface
sbfnk Feb 28, 2024
68b0463
update deprecations
sbfnk Feb 28, 2024
9f8d86a
use toString
sbfnk Feb 28, 2024
3ced453
pmf -> NonParametric
sbfnk Feb 28, 2024
78bd559
Merge branch 'main' into dist-interface
seabbs Feb 28, 2024
a7729ed
Merge remote-tracking branch 'origin/main' into dist-interface
sbfnk Feb 29, 2024
a630d69
add american spelling
sbfnk Feb 29, 2024
b4c04ca
fix gamma deprecation
sbfnk Mar 4, 2024
0d88b4a
Merge branch 'main' into dist-interface
sbfnk Mar 4, 2024
c39b28e
add new functions to pkgdown
sbfnk Mar 6, 2024
ac2d854
Merge branch 'main' into dist-interface
sbfnk Mar 6, 2024
d1c3cf7
update vignette
sbfnk Mar 6, 2024
4ffd533
recompile vignettes
sbfnk Mar 6, 2024
75ce0c2
Merge branch 'main' into dist-interface
seabbs Mar 12, 2024
eac039f
Apply suggestions from code review
sbfnk Mar 12, 2024
33caa9c
link to design doc
sbfnk Mar 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method("+",dist_spec)
S3method(c,dist_spec)
S3method(max,dist_spec)
S3method(mean,dist_spec)
S3method(plot,dist_spec)
S3method(plot,epinow)
Expand All @@ -10,16 +12,23 @@ S3method(plot,estimate_truncation)
S3method(print,dist_spec)
S3method(summary,epinow)
S3method(summary,estimate_infections)
export(Fixed)
export(Gamma)
export(LogNormal)
export(NonParametric)
export(Normal)
export(R_to_growth)
export(add_day_of_week)
export(adjust_infection_to_report)
export(apply_tolerance)
export(backcalc_opts)
export(bootstrapped_dist_fit)
export(calc_CrI)
export(calc_CrIs)
export(calc_summary_measures)
export(calc_summary_stats)
export(clean_nowcasts)
export(collapse)
export(construct_output)
export(convert_to_logmean)
export(convert_to_logsd)
Expand All @@ -35,6 +44,8 @@ export(create_shifted_cases)
export(create_stan_args)
export(create_stan_data)
export(delay_opts)
export(discretise)
export(discretize)
export(dist_fit)
export(dist_skel)
export(dist_spec)
Expand Down Expand Up @@ -195,16 +206,19 @@ importFrom(posterior,mcse_mean)
importFrom(progressr,progressor)
importFrom(progressr,with_progress)
importFrom(purrr,compact)
importFrom(purrr,flatten)
importFrom(purrr,keep)
importFrom(purrr,list_transpose)
importFrom(purrr,map)
importFrom(purrr,map2_dbl)
importFrom(purrr,map_chr)
importFrom(purrr,map_dbl)
importFrom(purrr,map_dfc)
importFrom(purrr,pmap_dbl)
importFrom(purrr,quietly)
importFrom(purrr,reduce)
importFrom(purrr,safely)
importFrom(purrr,transpose)
importFrom(purrr,walk)
importFrom(rlang,abort)
importFrom(rlang,arg_match)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* `simulate_infections` has been renamed to `forecast_infections` in line with `simulate_secondary` and `forecast_secondary`. The terminology is: a forecast is done from a fit to existing data, a simulation from first principles. By @sbfnk in #544 and reviewed by @seabbs.
* A new `simulate_infections` function has been added that can be used to simulate from the model from given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam.
* The function `init_cumulative_fit()` has been deprecated. By @jamesmbaazam in #541 and reviewed by @sbfnk.
* The interface to generating delay distributions has been completely overhauled. Instead of calling `dist_spec()` users now specify distributions using functions that represent the available distributions, i.e. `LogNormal()`, `Gamma()` and `Fixed()`. Uncertainty is specified using calls of the same nature, to `Normal()`. More information on the underlying design can be found in `inst/dev/design_dist.md` By @sbfnk in #504 and reviewed by @seabbs.

## Documentation

Expand Down
232 changes: 227 additions & 5 deletions R/adjust.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ adjust_infection_to_report <- function(infections, delay_defs,
# Reset DT Defaults on Exit
set_dt_single_thread()

## deprecated
sample_single_dist <- function(input, delay_def) {
## Define sample delay fn
sample_delay_fn <- function(n, ...) {
Expand All @@ -111,14 +112,50 @@ adjust_infection_to_report <- function(infections, delay_defs,
return(out)
}

report <- sample_single_dist(infections, delay_defs[[1]])

if (length(delay_defs) > 1) {
for (def in 2:length(delay_defs)) {
report <- sample_single_dist(report, delay_defs[[def]])
sample_dist_spec <- function(input, delay_def) {
## Define sample delay fn
sample_delay_fn <- function(n, dist, cum, ...) {
fixed_dist <- discretise(fix_dist(delay_def, strategy = "sample"))
if (dist) {
fixed_dist[[1]]$pmf[n + 1]
} else {
sample(seq_along(fixed_dist[[1]]$pmf) - 1, size = n, replace = TRUE)
}
}

## Infection to onset
out <- EpiNow2::sample_approx_dist(
cases = input,
dist_fn = sample_delay_fn,
max_value = max(delay_def),
direction = "forwards",
type = type,
truncate_future = FALSE
)

return(out)
}

if (is(delay_defs, "dist_spec")) {
report <- sample_dist_spec(infections, extract_single_dist(delay_defs, 1))
if (length(delay_defs) > 1) {
for (def in seq(2, length(delay_defs))) {
report <- sample_dist_spec(report, extract_single_dist(delay_defs, def))
}
}
} else {
deprecate_warn(
"1.5.0",
"adjust_infection_to_report(delay_defs = 'should be a dist_spec')",
details = "Specifying this as a list of data tables is deprecated."
)
report <- sample_single_dist(infections, delay_defs[[1]])
if (length(delay_defs) > 1) {
for (def in 2:length(delay_defs)) {
report <- sample_single_dist(report, delay_defs[[def]])
}
}
}
## Add a weekly reporting effect if present
if (!missing(reporting_effect)) {
reporting_effect <- data.table::data.table(
Expand Down Expand Up @@ -146,3 +183,188 @@ adjust_infection_to_report <- function(infections, delay_defs,
}
return(report)
}

#' Approximate Sampling a Distribution using Counts
#'
#' @description `r lifecycle::badge("soft-deprecated")`
#' Convolves cases by a PMF function. This function will soon be removed or
#' replaced with a more robust stan implementation.
#'
#' @param cases A `<data.frame>` of cases (in date order) with the following
#' variables: `date` and `cases`.
#'
#' @param max_value Numeric, maximum value to allow. Defaults to 120 days
#'
#' @param direction Character string, defato "backwards". Direction in which to
#' map cases. Supports either "backwards" or "forwards".
#'
#' @param dist_fn Function that takes two arguments with the first being
#' numeric and the second being logical (and defined as `dist`). Should return
#' the probability density or a sample from the defined distribution. See
#' the examples for more.
#'
#' @param earliest_allowed_mapped A character string representing a date
#' ("2020-01-01"). Indicates the earliest allowed mapped value.
#'
#' @param type Character string indicating the method to use to transform
#' counts. Supports either "sample" which approximates sampling or "median"
#' would shift by the median of the distribution.
#'
#' @param truncate_future Logical, should cases be truncated if they occur
#' after the first date reported in the data. Defaults to `TRUE`.
#'
#' @return A `<data.table>` of cases by date of onset
#' @export
#' @importFrom purrr map_dfc
#' @importFrom data.table data.table setorder
#' @importFrom lubridate days
#' @examples
#' \donttest{
#' cases <- example_confirmed
#' cases <- cases[, cases := as.integer(confirm)]
#' print(cases)
#'
#' # total cases
#' sum(cases$cases)
#'
#' delay_fn <- function(n, dist, cum) {
#' if (dist) {
#' pgamma(n + 0.9999, 2, 1) - pgamma(n - 1e-5, 2, 1)
#' } else {
#' as.integer(rgamma(n, 2, 1))
#' }
#' }
#'
#' onsets <- sample_approx_dist(
#' cases = cases,
#' dist_fn = delay_fn
#' )
#'
#' # estimated onset distribution
#' print(onsets)
#'
#' # check that sum is equal to reported cases
#' total_onsets <- median(
#' purrr::map_dbl(
#' 1:10,
#' ~ sum(sample_approx_dist(
#' cases = cases,
#' dist_fn = delay_fn
#' )$cases)
#' )
#' )
#' total_onsets
#'
#'
#' # map from onset cases to reported
#' reports <- sample_approx_dist(
#' cases = cases,
#' dist_fn = delay_fn,
#' direction = "forwards"
#' )
#'
#'
#' # map from onset cases to reported using a mean shift
#' reports <- sample_approx_dist(
#' cases = cases,
#' dist_fn = delay_fn,
#' direction = "forwards",
#' type = "median"
#' )
#' }
sample_approx_dist <- function(cases = NULL,
dist_fn = NULL,
max_value = 120,
earliest_allowed_mapped = NULL,
direction = "backwards",
type = "sample",
truncate_future = TRUE) {
if (type == "sample") {
if (direction == "backwards") {
direction_fn <- rev
} else if (direction == "forwards") {
direction_fn <- function(x) {
x
}
}
# reverse cases so starts with current first
reversed_cases <- direction_fn(cases$cases)
reversed_cases[is.na(reversed_cases)] <- 0
# draw from the density fn of the dist
draw <- dist_fn(0:max_value, dist = TRUE, cum = FALSE)

# approximate cases
mapped_cases <- do.call(cbind, purrr::map(
seq_along(reversed_cases),
~ c(
rep(0, . - 1),
stats::rbinom(
length(draw),
rep(reversed_cases[.], length(draw)),
draw
),
rep(0, length(reversed_cases) - .)
)
))


# set dates order based on direction mapping
if (direction == "backwards") {
dates <- seq(min(cases$date) - lubridate::days(length(draw) - 1),
max(cases$date),
by = "days"
)
} else if (direction == "forwards") {
dates <- seq(min(cases$date),
max(cases$date) + lubridate::days(length(draw) - 1),
by = "days"
)
}

# summarises movements and sample for placement of non-integer cases
case_sum <- direction_fn(rowSums(mapped_cases))
floor_case_sum <- floor(case_sum)
sample_cases <- floor_case_sum +
as.numeric((runif(seq_along(case_sum)) < (case_sum - floor_case_sum)))

# summarise imputed onsets and build output data.table
mapped_cases <- data.table::data.table(
date = dates,
cases = sample_cases
)

# filter out all zero cases until first recorded case
mapped_cases <- data.table::setorder(mapped_cases, date)
mapped_cases <- mapped_cases[
,
cum_cases := cumsum(cases)
][cum_cases != 0][, cum_cases := NULL]
} else if (type == "median") {
shift <- as.integer(
median(as.integer(dist_fn(1000, dist = FALSE)), na.rm = TRUE)
)

if (direction == "backwards") {
mapped_cases <- data.table::copy(cases)[
,
date := date - lubridate::days(shift)
]
} else if (direction == "forwards") {
mapped_cases <- data.table::copy(cases)[
,
date := date + lubridate::days(shift)
]
}
}

if (!is.null(earliest_allowed_mapped)) {
mapped_cases <- mapped_cases[date >= as.Date(earliest_allowed_mapped)]
}

# filter out future cases
if (direction == "forwards" && truncate_future) {
max_date <- max(cases$date)
mapped_cases <- mapped_cases[date <= max_date]
}
return(mapped_cases)
}
50 changes: 50 additions & 0 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,53 @@ check_reports_valid <- function(reports, model) {
assert_numeric(reports$confirm, lower = 0)
}
}

#' Validate probability distribution for passing to stan
#'
#' @description
#' `check_stan_delay()` checks that the supplied data is a `<dist_spec>`,
#' that it is a supported distribution, and that is has a finite maximum.
#'
#' @param dist A `dist_spec` object.`
#' @importFrom checkmate assert_class
#' @importFrom rlang arg_match
#' @return Called for its side effects.
#' @keywords internal
check_stan_delay <- function(dist) {
# Check that `dist` is a `dist_spec`
assert_class(dist, "dist_spec")
# Check that `dist` is lognormal or gamma or nonparametric
distributions <- vapply(dist, function(x) x$distribution, character(1))
if (
!all(distributions %in% c("lognormal", "gamma", "fixed", "nonparametric"))
) {
stop(
"Distributions passed to the model need to be lognormal, gamma, fixed ",
"or nonparametric."
)
}
# Check that `dist` has parameters that are either numeric or normal
# distributions with numeric parameters and infinite maximum
numeric_parameters <- vapply(dist$parameters, is.numeric, logical(1))
normal_parameters <- vapply(
dist$parameters,
function(x) {
is(x, "dist_spec") &&
x$distribution == "normal" &&
all(vapply(x$parameters, is.numeric, logical(1))) &&
is.infinite(x$max)
},
logical(1)
)
if (!all(numeric_parameters | normal_parameters)) {
stop(
"Delay distributions passed to the model need to have parameters that ",
"are either numeric or normally distributed with numeric parameters ",
"and infinite maximum."
)
}
# Check that `dist` has a finite maximum
if (any(is.infinite(max(dist)))) {
stop("All distribution passed to the model need to have a finite maximum")
}
}
Loading
Loading