Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move package depends to suggest #798

Merged
merged 12 commits into from
Sep 30, 2024
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,12 @@ Imports:
cli,
data.table,
futile.logger (>= 1.4),
future,
future.apply,
ggplot2,
lifecycle,
lubridate,
methods,
patchwork,
posterior,
progressr,
purrr,
R.utils (>= 2.0.0),
Rcpp (>= 0.12.0),
Expand All @@ -126,9 +123,12 @@ Imports:
Suggests:
cmdstanr,
covr,
future,
future.apply,
here,
knitr,
precommit,
progressr,
rmarkdown,
spelling,
testthat,
Expand Down
6 changes: 0 additions & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,6 @@ importFrom(futile.logger,flog.threshold)
importFrom(futile.logger,flog.trace)
importFrom(futile.logger,flog.warn)
importFrom(futile.logger,ftry)
importFrom(future,availableCores)
importFrom(future,plan)
importFrom(future,tweak)
importFrom(future.apply,future_lapply)
importFrom(ggplot2,.data)
importFrom(ggplot2,aes)
importFrom(ggplot2,coord_cartesian)
Expand Down Expand Up @@ -205,8 +201,6 @@ importFrom(lubridate,days)
importFrom(lubridate,wday)
importFrom(patchwork,plot_layout)
importFrom(posterior,mcse_mean)
importFrom(progressr,progressor)
importFrom(progressr,with_progress)
importFrom(purrr,compact)
importFrom(purrr,flatten)
importFrom(purrr,keep)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

- `fix_dist()` has been renamed to `fix_parameters()` because it removes the uncertainty in a distribution's parameters. By @sbfnk in #733 and reviewed by @jamesmbaazam.
- `plot.dist_spec` now uses color instead of line types to display pmfs vs cmfs. By @jamesmbaazam in #788 and reviewed by @sbfnk.
- The use of the `{progressr}` package for displaying progress bars is now optional, as is the use of `{future}` and `{future.apply}` for parallelisation. By @sbfnk in #798 and reviewed by @seabbs.

## Bug fixes

Expand Down
1 change: 0 additions & 1 deletion R/deprecated.R
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ lognorm_dist_def <- function(mean, mean_sd,
#' @inheritParams estimate_infections
#' @inheritParams adjust_infection_to_report
#' @importFrom data.table data.table rbindlist
#' @importFrom future.apply future_lapply
report_cases <- function(case_estimates,
case_forecast = NULL,
delays,
Expand Down
18 changes: 10 additions & 8 deletions R/estimate_delay.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ dist_fit <- function(values = NULL, samples = 1000, cores = 1,
#'
#' @return A `<dist_spec>` object summarising the bootstrapped distribution
#' @importFrom purrr list_transpose
#' @importFrom future.apply future_lapply
#' @importFrom rstan extract
#' @importFrom data.table data.table rbindlist
#' @importFrom cli cli_abort col_blue
Expand Down Expand Up @@ -199,7 +198,7 @@ bootstrapped_dist_fit <- function(values, dist = "lognormal",
dist_samples <- get_single_dist(values, samples = samples)
} else {
## Fit each sub sample
dist_samples <- future.apply::future_lapply(1:bootstraps,
dist_samples <- lapply_func(1:bootstraps,
function(boot) {
get_single_dist(
sample(values,
Expand All @@ -209,12 +208,15 @@ bootstrapped_dist_fit <- function(values, dist = "lognormal",
samples = ceiling(samples / bootstraps)
)
},
future.scheduling = Inf,
future.globals = c(
"values", "bootstraps", "samples",
"bootstrap_samples", "get_single_dist"
),
future.packages = "data.table", future.seed = TRUE
future.opts = list(
future.scheduling = Inf,
future.globals = c(
"values", "bootstraps", "samples",
"bootstrap_samples", "get_single_dist"
),
future.packages = "data.table",
future.seed = TRUE
)
)


Expand Down
6 changes: 2 additions & 4 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#'
#' @importFrom futile.logger flog.debug flog.info flog.error
#' @importFrom R.utils withTimeout
#' @importFrom future.apply future_lapply
#' @importFrom purrr compact
#' @importFrom rstan sflist2stanfit sampling
#' @importFrom rlang abort cnd_muffle
Expand Down Expand Up @@ -103,12 +102,11 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
chains <- args$chains
args$chains <- 1
args$cores <- 1
fits <- future.apply::future_lapply(1:chains,
fits <- lapply_func(1:chains,
fit_chain,
stan_args = args,
max_time = max_execution_time,
catch = TRUE,
future.seed = TRUE
catch = TRUE
)
if (stuck_chains > 0) {
fits[1:stuck_chains] <- NULL
Expand Down
30 changes: 18 additions & 12 deletions R/regional_epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#'
#' Regions can be estimated in parallel using the `{future}` package (see
#' [setup_future()]). The progress of producing estimates across multiple
#' regions is tracked using the `{progressr}` package. Modify this behaviour
#' regions can be tracked using the `{progressr}` package. Modify this behaviour
#' using [progressr::handlers()] and enable it in batch by setting
#' `R_PROGRESSR_ENABLE=TRUE` as an environment variable.
#'
Expand Down Expand Up @@ -54,13 +54,11 @@
#' @export
#' @seealso [epinow()] [estimate_infections()] [setup_future()]
#' [regional_summary()]
#' @importFrom future.apply future_lapply
#' @importFrom data.table as.data.table setDT copy setorder
#' @importFrom purrr safely map compact keep
#' @importFrom futile.logger flog.info flog.warn flog.trace
#' @importFrom R.utils withTimeout
#' @importFrom rlang cnd_muffle
#' @importFrom progressr with_progress progressor
#' @examples
#' \donttest{
#' # set number of cores to use
Expand Down Expand Up @@ -161,9 +159,8 @@ regional_epinow <- function(data,
" function"
)

progressr::with_progress({
progress_fn <- progressr::progressor(along = regions)
regional_out <- future.apply::future_lapply(regions, run_region,
run_regions <- function(progress_fn = NULL) {
lapply_func(regions, run_region,
generation_time = generation_time,
delays = delays,
truncation = truncation,
Expand All @@ -186,10 +183,19 @@ regional_epinow <- function(data,
progress_fn = progress_fn,
verbose = verbose,
...,
future.scheduling = Inf,
future.seed = TRUE
future.opts = list(
future.scheduling = Inf,
future.seed = TRUE
)
)
})
}
if (requireNamespace("progressr", quietly = TRUE)) {
progressr::with_progress({
regional_out <- run_regions(progressr::progressor(along = regions))
})
} else {
regional_out <- run_regions()
}

out <- process_regions(regional_out, regions)
regional_out <- out$all
Expand Down Expand Up @@ -313,7 +319,7 @@ clean_regions <- function(data, non_zero_points) {
#'
#' @param target_region Character string indicating the region being evaluated
#' @param progress_fn Function as returned by [progressr::progressor()]. Allows
#' the use of a progress bar.
#' the use of a progress bar. If NULL (default), no progress bar is used.
#'
#' @param complete_logger Character string indicating the logger to output
#' the completion of estimation to.
Expand Down Expand Up @@ -341,7 +347,7 @@ run_region <- function(target_region,
output,
complete_logger,
verbose,
progress_fn,
progress_fn = NULL,
...) {
futile.logger::flog.info("Initialising estimates for: %s", target_region,
name = "EpiNow2.epinow"
Expand Down Expand Up @@ -390,7 +396,7 @@ run_region <- function(target_region,
complete_logger
)

if (!missing(progress_fn)) {
if (!is.null(progress_fn)) {
progress_fn(sprintf("Region: %s", target_region))
}
return(out)
Expand Down
14 changes: 12 additions & 2 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ setup_default_logging <- function(logs = tempdir(check = TRUE),
#' A utility function that aims to streamline the set up
#' of the required future backend with sensible defaults for most users of
#' [regional_epinow()]. More advanced users are recommended to setup their own
#' `{future}` backend based on their available resources.
#' `{future}` backend based on their available resources. Running this requires
#' the `{future}` package to be installed.
#'
#' @param strategies A vector length 1 to 2 of strategies to pass to
#' [future::plan()]. Nesting of parallelisation is from the top level down.
Expand All @@ -136,7 +137,6 @@ setup_default_logging <- function(logs = tempdir(check = TRUE),
#'
#' @inheritParams regional_epinow
#' @importFrom futile.logger flog.error flog.info flog.debug
#' @importFrom future availableCores plan tweak
#' @importFrom cli cli_abort
#' @export
#' @return Numeric number of cores to use per worker. If greater than 1 pass to
Expand All @@ -145,6 +145,16 @@ setup_default_logging <- function(logs = tempdir(check = TRUE),
setup_future <- function(data,
strategies = c("multisession", "multisession"),
min_cores_per_worker = 4) {
if (!requireNamespace("future", quietly = TRUE)) {
futile.logger::flog.error(
"The future package is required for parallelisation"
)
cli_abort(
c(
"!" = "The future package is required for parallelisation."
)
)
}
if (length(strategies) > 2 || length(strategies) == 0) {
futile.logger::flog.error("1 or 2 strategies should be used")
cli_abort(
Expand Down
48 changes: 25 additions & 23 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,11 @@ simulate_infections <- function(estimates, R, initial_infections,
#' simulate. May decrease run times due to reduced IO costs but this is still
#' being evaluated. If set to NULL then all simulations are done at once.
#'
#' @param verbose Logical defaults to [interactive()]. Should a progress bar
#' (from `progressr`) be shown.
#' @param verbose Logical defaults to [interactive()]. If the `progressr`
#' package is available, a progress bar will be shown.
#' @inheritParams stan_opts
#' @importFrom rstan extract sampling
#' @importFrom purrr list_transpose map safely compact
#' @importFrom future.apply future_lapply
#' @importFrom progressr with_progress progressor
#' @importFrom data.table rbindlist as.data.table
#' @importFrom lubridate days
#' @importFrom checkmate assert_class assert_names test_numeric test_data_frame
Expand Down Expand Up @@ -472,39 +470,43 @@ forecast_infections <- function(estimates,

safe_batch <- safely(batch_simulate)

if (backend == "cmdstanr") {
lapply_func <- lapply ## future_lapply can't handle cmdstanr
} else {
lapply_func <- function(...) future_lapply(future.seed = TRUE, ...)
}

## simulate in batches
with_progress({
if (verbose) {
p <- progressor(along = batches)
}
out <- lapply_func(batches,
process_batches <- function(p = NULL) {
lapply_func(batches,
function(batch) {
if (verbose) {
if (!is.null(p)) {
p()
}
safe_batch(
estimates, draws, model,
shift, dates, batch[[1]],
batch[[2]]
)[[1]]
}
},
future.opts = list(
future.seed = TRUE
),
backend = backend
)
})
}

## simulate in batches
if (verbose && requireNamespace("progressr", quietly = TRUE)) {
p <- progressr::progressor(along = batches)
progressr::with_progress({
regional_out <- process_batches(p)
})
} else {
regional_out <- process_batches()
}

## join batches
out <- compact(out)
out <- list_transpose(out, simplify = FALSE)
out <- map(out, rbindlist)
regional_out <- compact(regional_out)
regional_out <- list_transpose(regional_out, simplify = FALSE)
regional_out <- map(regional_out, rbindlist)

## format output
format_out <- format_fit(
posterior_samples = out,
posterior_samples = regional_out,
horizon = estimates$args$horizon,
shift = shift,
burn_in = 0,
Expand Down
15 changes: 15 additions & 0 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,21 @@ set_dt_single_thread <- function() {
)
}

#' Choose a parallel or sequential apply function
#'
#' Internal function that chooses an appropriate "apply"-type function (either
#' [lapply()] or [future.apply::future_lapply()])
#' @return A function that can be used to apply a function to a list
#' @keywords internal
#' @inheritParams stan_opts
lapply_func <- function(..., backend = "rstan", future.opts = list()) {
if (requireNamespace("future.apply", quietly = TRUE) && backend == "rstan") {
do.call(future.apply::future_lapply, c(list(...), future.opts))
} else {
lapply(...)
}
}

#' @importFrom stats glm median na.omit pexp pgamma plnorm quasipoisson rexp
#' @importFrom stats rlnorm rnorm rpois runif sd var rgamma pnorm
globalVariables(
Expand Down
4 changes: 2 additions & 2 deletions man/forecast_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/lapply_func.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/regional_epinow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading