Skip to content

Commit

Permalink
cmdstanr benchmarking (#616)
Browse files Browse the repository at this point in the history
* benchmarking skeleton

* adapt options

* update gitignore

* add action

* add more triggers

* add name

* fetch main

* update message

* use remote branch as there is no tracking one

* update documentation

* try with braces

* suppress messages

* use two cores

* rename dir

* collect range as well as 90%

* update formatting

* increase number of iterations

* only trigger when no changes (I think)

* fix typo
  • Loading branch information
sbfnk authored Mar 25, 2024
1 parent a2600ab commit 5a27186
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 13 deletions.
74 changes: 74 additions & 0 deletions .github/workflows/stan-model-benchmark.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
name: stan model benchmark

on:
workflow_dispatch:
pull_request:
branches: main
paths:
- inst/stan/**
- .github/workflows/stan-model-benchmark.yaml
- inst/dev/benchmark*.R

concurrency:
group: ${{ github.workflow }}-${{ github.event.number }}
cancel-in-progress: true

jobs:
stan-model-benchmark:
runs-on: ubuntu-latest
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v4

- uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
packages: |
local::.
here
purrr
stan-dev/cmdstanr
- name: Install cmdstan
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE)
shell: Rscript {0}

- name: Checkout main branch in parallel and move to separate dir
run: |
mv inst/stan inst/stan-branch
git -c protocol.version=2 fetch --no-tags --prune --no-recurse-submodules --depth=1 origin main
git checkout origin/main inst/stan
mv inst/stan inst/stan-main
mv inst/stan-branch inst/stan
- name: Benchmark
run: |
Rscript inst/dev/benchmark.R
- id: output
name: Output to environment variable
if: ${{ hashFiles('inst/dev/benchmark-results.md') != '' }}
run: |
echo 'BENCHMARK<<EOF' > $GITHUB_OUTPUT
cat inst/dev/benchmark-results.md >> $GITHUB_OUTPUT
echo 'EOF' >> $GITHUB_OUTPUT
- name: Post comment
if: ${{ hashFiles('inst/dev/benchmark-results.md') != '' }}
uses: actions/github-script@v7
env:
BENCHMARK: ${{ steps.output.outputs.BENCHMARK }}
with:
script: |
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: process.env.BENCHMARK
})
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ man/figures/*.png
inst/stan/*
!inst/stan/*/
!inst/stan/*.stan

# benchmarking
inst/stan-main
inst/dev/benchmark-results.md
19 changes: 15 additions & 4 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -894,11 +894,22 @@ stan_opts <- function(object = NULL,
)
}
opts <- list()
if (!is.null(object) && !missing(backend)) {
warning(
"`backend` option will be ignored as a stan model object has been passed."
)
if (!is.null(object)) {
if (!missing(backend)) {
warning(
"`backend` option will be ignored as a stan model object has been ",
"passed."
)
}
if (inherits(object, "stanmodel")) {
backend <- "rstan"
} else if (inherits(object, "CmdStanModel")) {
backend <- "cmdstanr"
} else {
stop("`object` must be a stan model object")
}
} else {
backend <- arg_match(backend, values = c("rstan", "cmdstanr"))
opts <- c(opts, list(backend = backend))
}
opts <- c(opts, list(
Expand Down
13 changes: 6 additions & 7 deletions R/stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' "estimate_infections" (default), "simulate_infections", "estimate_secondary",
#' "simulate_secondary", "estimate_truncation" or "dist_fit".
#'
#' @param include A character string specifying the path to any stan
#' @param dir A character string specifying the path to any stan
#' files to include in the model. If missing the package default is used.
#'
#' @param verbose Logical, defaults to `TRUE`. Should verbose
Expand All @@ -23,17 +23,16 @@ package_model <- function(model = c(
"estimate_secondary", "simulate_secondary",
"estimate_truncation", "dist_fit"
),
include = system.file("stan", package = "EpiNow2"),
dir = system.file("stan", package = "EpiNow2"),
verbose = FALSE,
...) {
model <- arg_match(model)
model_file <- system.file(
"stan", paste0(model, ".stan"),
package = "EpiNow2"
model_file <- file.path(
dir, paste0(model, ".stan")
)
if (verbose) {
message(sprintf("Using model %s.", model))
message(sprintf("include is %s.", toString(include)))
message(sprintf("dir is %s.", toString(dir)))
}

monitor <- suppressMessages
Expand All @@ -44,7 +43,7 @@ package_model <- function(model = c(
}
model <- monitor(cmdstanr::cmdstan_model(
model_file,
include_paths = include,
include_paths = dir,
...
))
return(model)
Expand Down
55 changes: 55 additions & 0 deletions inst/dev/benchmark-functions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
##' Create a benchmark profile
##'
##' This runs the `estimate_infections` function using a given stan model file
##' multiple times with given seeds and extracts the `cmdstanr` profiling
##' information each time.
##' @param dir directory that contains the stan model file
##' @param seeds a vector of random seeds to use; this determines how often
##' `estimate_infections` is run
##' @return a data.table of profile informations, with the run id given as
##' `iter`
create_profiles <- function(dir = file.path("inst", "stan"),
seeds = sample(.Machine$integer.max, 1)) {
compiled_model <- EpiNow2:::package_model(dir = dir)
profiles <- suppressMessages(purrr::map(seeds, \(x) {
set.seed(x)
fit <- estimate_infections(
reported_cases = reported_cases,
generation_time = generation_time_opts(fixed_generation_time),
delays = delay_opts(delays),
rt = rt_opts(prior = list(mean = 2, sd = 0.2)),
stan = stan_opts(
samples = 1000, chains = 2, object = compiled_model,
cores = 2
),
verbose = FALSE
)
df <- as.data.table(rbindlist(fit$fit$profiles(), idcol = "chain"))

return(df)
}))
return(data.table::rbindlist(profiles, idcol = "iter"))
}
##' Calculate bootstrap mean and credible intervals
##'
##' Credible intervals are calculated from resampled quantiles
##' @param x numeric vector
##' @param n_boot number of bootstrap iterations; if NULL (default) will take
##' length of x
##' @return a `data.table` with one row, containing the mean, 50% credible
##' intervals (`low`/`high`) and 90% credible intervals (`lower`/`higher`)
bootci <- function(x, n_boot = NULL) {
if (is.null(n_boot)) n_boot <- length(x)
m <- matrix(sample(x, n_boot * length(x), replace = TRUE), n_boot, length(x))
means <- apply(m, 1, mean)
dt <- data.table::data.table(
mean = mean(x),
low = quantile(means, 0.25),
high = quantile(means, 0.75),
lower = quantile(means, 0.05),
higher = quantile(means, 0.95),
lowest = range(means)[1],
highest = range(means)[2]
)
return(list(dt))
}
75 changes: 75 additions & 0 deletions inst/dev/benchmark.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
library("data.table")
library("EpiNow2")
library("knitr")

## number of times to run estimate_infections
n_iter <- 15

## random seeds
seeds <- sample(.Machine$integer.max, n_iter)

source(file.path("touchstone", "setup.R"))
source(file.path("inst", "dev", "benchmark-functions.R"))

profiles <- list()

## generate profiles for different versions of the stan model
profiles[["branch"]] <- create_profiles(file.path("inst", "stan"), seeds)
profiles[["main"]] <- create_profiles(file.path("inst", "stan-main"), seeds)

## merge profiles from the two chains into one, round total time and only keep
## name and total_time columns; then average across chains; sort by time from
## longest to shortest
summary <- rbindlist(profiles, idcol = "branch")
setnames(summary, "name", "operation")

summary_ci <- summary[, rbindlist(bootci(total_time)), by = .(branch, operation)]
setorder(summary_ci, -mean)
## print
summary_ci <- summary_ci[,
lapply(.SD, signif, 2), by = .(branch, operation)
]
summary_ci[, lapply(.SD, as.character)]
summary_means <- dcast(
summary_ci, operation ~ branch, value.var = "mean"
)

## now calculate the percentage change
format_summary <- dcast(
summary, operation + iter + chain ~ branch, value.var = "total_time"
)
changes <- format_summary[,
change := round((branch - main) / branch * 100)
][, list(
mean = round(mean(change)),
min = min(change),
max = max(change)
), by = "operation"
][, list(
mean = mean,
range = paste0("(", min, ", ", max, ")"),
trend = fcase(
min > 0, "slowdown",
max < 0, "speedup",
default = "no change"
)
), by = "operation"
]

if (any(changes$trend != "no change")) {
format_summary <- merge(
summary_means,
changes[, .(operation, mean, range, trend)],
by = "operation"
)

setorder(format_summary, -main)
format_summary <- format_summary[, lapply(.SD, as.character)]
setnames(format_summary, "mean", "% change")

sink(file = file.path("inst", "dev", "benchmark-results.md"))
knitr::kable(
format_summary, align = "lrrrr",
caption = "Benchmarking results (mean time in seconds)."
)
}
4 changes: 2 additions & 2 deletions man/package_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5a27186

Please sign in to comment.