-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* benchmarking skeleton * adapt options * update gitignore * add action * add more triggers * add name * fetch main * update message * use remote branch as there is no tracking one * update documentation * try with braces * suppress messages * use two cores * rename dir * collect range as well as 90% * update formatting * increase number of iterations * only trigger when no changes (I think) * fix typo
- Loading branch information
Showing
7 changed files
with
231 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
name: stan model benchmark | ||
|
||
on: | ||
workflow_dispatch: | ||
pull_request: | ||
branches: main | ||
paths: | ||
- inst/stan/** | ||
- .github/workflows/stan-model-benchmark.yaml | ||
- inst/dev/benchmark*.R | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.event.number }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
stan-model-benchmark: | ||
runs-on: ubuntu-latest | ||
env: | ||
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} | ||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- uses: r-lib/actions/setup-r@v2 | ||
with: | ||
use-public-rspm: true | ||
|
||
- uses: r-lib/actions/setup-r-dependencies@v2 | ||
with: | ||
packages: | | ||
local::. | ||
here | ||
purrr | ||
stan-dev/cmdstanr | ||
- name: Install cmdstan | ||
run: | | ||
cmdstanr::check_cmdstan_toolchain(fix = TRUE) | ||
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE) | ||
shell: Rscript {0} | ||
|
||
- name: Checkout main branch in parallel and move to separate dir | ||
run: | | ||
mv inst/stan inst/stan-branch | ||
git -c protocol.version=2 fetch --no-tags --prune --no-recurse-submodules --depth=1 origin main | ||
git checkout origin/main inst/stan | ||
mv inst/stan inst/stan-main | ||
mv inst/stan-branch inst/stan | ||
- name: Benchmark | ||
run: | | ||
Rscript inst/dev/benchmark.R | ||
- id: output | ||
name: Output to environment variable | ||
if: ${{ hashFiles('inst/dev/benchmark-results.md') != '' }} | ||
run: | | ||
echo 'BENCHMARK<<EOF' > $GITHUB_OUTPUT | ||
cat inst/dev/benchmark-results.md >> $GITHUB_OUTPUT | ||
echo 'EOF' >> $GITHUB_OUTPUT | ||
- name: Post comment | ||
if: ${{ hashFiles('inst/dev/benchmark-results.md') != '' }} | ||
uses: actions/github-script@v7 | ||
env: | ||
BENCHMARK: ${{ steps.output.outputs.BENCHMARK }} | ||
with: | ||
script: | | ||
github.rest.issues.createComment({ | ||
issue_number: context.issue.number, | ||
owner: context.repo.owner, | ||
repo: context.repo.repo, | ||
body: process.env.BENCHMARK | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
##' Create a benchmark profile | ||
##' | ||
##' This runs the `estimate_infections` function using a given stan model file | ||
##' multiple times with given seeds and extracts the `cmdstanr` profiling | ||
##' information each time. | ||
##' @param dir directory that contains the stan model file | ||
##' @param seeds a vector of random seeds to use; this determines how often | ||
##' `estimate_infections` is run | ||
##' @return a data.table of profile informations, with the run id given as | ||
##' `iter` | ||
create_profiles <- function(dir = file.path("inst", "stan"), | ||
seeds = sample(.Machine$integer.max, 1)) { | ||
compiled_model <- EpiNow2:::package_model(dir = dir) | ||
profiles <- suppressMessages(purrr::map(seeds, \(x) { | ||
set.seed(x) | ||
fit <- estimate_infections( | ||
reported_cases = reported_cases, | ||
generation_time = generation_time_opts(fixed_generation_time), | ||
delays = delay_opts(delays), | ||
rt = rt_opts(prior = list(mean = 2, sd = 0.2)), | ||
stan = stan_opts( | ||
samples = 1000, chains = 2, object = compiled_model, | ||
cores = 2 | ||
), | ||
verbose = FALSE | ||
) | ||
df <- as.data.table(rbindlist(fit$fit$profiles(), idcol = "chain")) | ||
|
||
return(df) | ||
})) | ||
return(data.table::rbindlist(profiles, idcol = "iter")) | ||
} | ||
##' Calculate bootstrap mean and credible intervals | ||
##' | ||
##' Credible intervals are calculated from resampled quantiles | ||
##' @param x numeric vector | ||
##' @param n_boot number of bootstrap iterations; if NULL (default) will take | ||
##' length of x | ||
##' @return a `data.table` with one row, containing the mean, 50% credible | ||
##' intervals (`low`/`high`) and 90% credible intervals (`lower`/`higher`) | ||
bootci <- function(x, n_boot = NULL) { | ||
if (is.null(n_boot)) n_boot <- length(x) | ||
m <- matrix(sample(x, n_boot * length(x), replace = TRUE), n_boot, length(x)) | ||
means <- apply(m, 1, mean) | ||
dt <- data.table::data.table( | ||
mean = mean(x), | ||
low = quantile(means, 0.25), | ||
high = quantile(means, 0.75), | ||
lower = quantile(means, 0.05), | ||
higher = quantile(means, 0.95), | ||
lowest = range(means)[1], | ||
highest = range(means)[2] | ||
) | ||
return(list(dt)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
library("data.table") | ||
library("EpiNow2") | ||
library("knitr") | ||
|
||
## number of times to run estimate_infections | ||
n_iter <- 15 | ||
|
||
## random seeds | ||
seeds <- sample(.Machine$integer.max, n_iter) | ||
|
||
source(file.path("touchstone", "setup.R")) | ||
source(file.path("inst", "dev", "benchmark-functions.R")) | ||
|
||
profiles <- list() | ||
|
||
## generate profiles for different versions of the stan model | ||
profiles[["branch"]] <- create_profiles(file.path("inst", "stan"), seeds) | ||
profiles[["main"]] <- create_profiles(file.path("inst", "stan-main"), seeds) | ||
|
||
## merge profiles from the two chains into one, round total time and only keep | ||
## name and total_time columns; then average across chains; sort by time from | ||
## longest to shortest | ||
summary <- rbindlist(profiles, idcol = "branch") | ||
setnames(summary, "name", "operation") | ||
|
||
summary_ci <- summary[, rbindlist(bootci(total_time)), by = .(branch, operation)] | ||
setorder(summary_ci, -mean) | ||
summary_ci <- summary_ci[, | ||
lapply(.SD, signif, 2), by = .(branch, operation) | ||
] | ||
summary_ci[, lapply(.SD, as.character)] | ||
summary_means <- dcast( | ||
summary_ci, operation ~ branch, value.var = "mean" | ||
) | ||
|
||
## now calculate the percentage change | ||
format_summary <- dcast( | ||
summary, operation + iter + chain ~ branch, value.var = "total_time" | ||
) | ||
changes <- format_summary[, | ||
change := round((branch - main) / branch * 100) | ||
][, list( | ||
mean = round(mean(change)), | ||
min = min(change), | ||
max = max(change) | ||
), by = "operation" | ||
][, list( | ||
mean = mean, | ||
range = paste0("(", min, ", ", max, ")"), | ||
trend = fcase( | ||
min > 0, "slowdown", | ||
max < 0, "speedup", | ||
default = "no change" | ||
) | ||
), by = "operation" | ||
] | ||
|
||
if (any(changes$trend != "no change")) { | ||
format_summary <- merge( | ||
summary_means, | ||
changes[, .(operation, mean, range, trend)], | ||
by = "operation" | ||
) | ||
|
||
setorder(format_summary, -main) | ||
format_summary <- format_summary[, lapply(.SD, as.character)] | ||
setnames(format_summary, "mean", "% change") | ||
|
||
sink(file = file.path("inst", "dev", "benchmark-results.md")) | ||
knitr::kable( | ||
format_summary, align = "lrrrr", | ||
caption = "Benchmarking results (mean time in seconds)." | ||
) | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.