Skip to content

Commit

Permalink
Add newdata arg to simulate.sdmTMB()
Browse files Browse the repository at this point in the history
  • Loading branch information
seananderson committed Sep 27, 2024
1 parent b3c0ac7 commit 7b8569c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 4 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: sdmTMB
Title: Spatial and Spatiotemporal SPDE-Based GLMMs with 'TMB'
Version: 0.6.0.9011
Version: 0.6.0.9012
Authors@R: c(
person(c("Sean", "C."), "Anderson", , "sean@seananderson.ca",
role = c("aut", "cre"),
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# sdmTMB (development version)

* Add `newdata` argument to `simulate.sdmTMB()`. This enables simulating on
a new data frame similar to how one would predict on new data.

* Add `mle_mvn_samples` argument to `simulate.sdmTMB()`. Defaults to "single".
If "multiple", then a sample from the random effects is taken for each
simulation iteration.

* Add `project()` experimental function.

* Add print method for `sdmTMB_cv()` output. #319
Expand Down
34 changes: 31 additions & 3 deletions R/tmb-sim.R
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,12 @@ sdmTMB_simulate <- function(formula,
#' effects (this only simulates observation error). `~0` or `NA` to simulate
#' new random affects (smoothers, which internally are random effects, will
#' not be simulated as new).
#' @param mle_mvn_samples Applies if `type = "mle-mvn"`. If `"single"`, take
#' a single MVN draw from the random effects. If `"multiple"`, take an MVN
#' draw from the random effects for each of the `nsim`.
#' @param model If a delta/hurdle model, which model to simulate from?
#' `NA` = combined, `1` = first model, `2` = second mdoel.
#' @param newdata Optional new data frame from which to simulate.
#' @param mcmc_samples An optional matrix of MCMC samples. See `extract_mcmc()`
#' in the \href{https://github.com/pbs-assess/sdmTMBextra}{sdmTMBextra}
#' package.
Expand Down Expand Up @@ -381,10 +385,16 @@ sdmTMB_simulate <- function(formula,
simulate.sdmTMB <- function(object, nsim = 1L, seed = sample.int(1e6, 1L),
type = c("mle-eb", "mle-mvn"),
model = c(NA, 1, 2),
re_form = NULL, mcmc_samples = NULL, silent = FALSE, ...) {
newdata = NULL,
re_form = NULL,
mle_mvn_samples = c("single", "multiple"),
mcmc_samples = NULL,
silent = FALSE,
...) {
set.seed(seed)
type <- tolower(type)
type <- match.arg(type)
mle_mvn_samples <- match.arg(mle_mvn_samples)
assert_that(as.integer(model[[1]]) %in% c(NA_integer_, 1L, 2L))

# need to re-attach environment if in fresh session
Expand All @@ -403,6 +413,16 @@ simulate.sdmTMB <- function(object, nsim = 1L, seed = sample.int(1e6, 1L),
stopifnot(length(object$tmb_data$sim_re) == 6L) # in case this gets changed
tmb_dat$sim_re <- c(rep(1L, 5L), 0L) # last is smoothers; don't simulate them
}

if (!is.null(newdata)) {
# generate prediction TMB data list
p <- predict(object, newdata = newdata, return_tmb_data = TRUE, ...)
# move data elements over
p <- move_proj_to_tmbdat(p, object, newdata)
p$sim_re <- tmb_dat$sim_re
tmb_dat <- p
}

newobj <- TMB::MakeADFun(
data = tmb_dat, map = object$tmb_map,
random = object$tmb_random, parameters = object$tmb_obj$env$parList(), DLL = "sdmTMB"
Expand All @@ -411,9 +431,17 @@ simulate.sdmTMB <- function(object, nsim = 1L, seed = sample.int(1e6, 1L),
# params MLE/MVN stuff
if (is.null(mcmc_samples)) {
if (type == "mle-mvn") {
new_par <- .one_sample_posterior(object)
if (mle_mvn_samples == "single") {
new_par <- .one_sample_posterior(object)
new_par <- replicate(nsim, new_par)
} else {
new_par <- lapply(seq_len(nsim), \(i) .one_sample_posterior(object))
new_par <- do.call(cbind, new_par)
}
} else if (type == "mle-eb") {
new_par <- object$tmb_obj$env$last.par.best
new_par <- lapply(seq_len(nsim), \(i) new_par)
new_par <- do.call(cbind, new_par)
} else {
cli_abort("`type` type not defined")
}
Expand All @@ -432,7 +460,7 @@ simulate.sdmTMB <- function(object, nsim = 1L, seed = sample.int(1e6, 1L),
} else {
for (i in seq_len(nsim)) {
if (!silent) cli::cli_progress_update()
ret[[i]] <- newobj$simulate(par = new_par, complete = FALSE)$y_i
ret[[i]] <- newobj$simulate(par = new_par[, i, drop = TRUE], complete = FALSE)$y_i
}
}
if (!silent) cli::cli_progress_done()
Expand Down
8 changes: 8 additions & 0 deletions man/simulate.sdmTMB.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 45 additions & 0 deletions tests/testthat/test-6-tmb-simulation.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,51 @@ test_that("simulate() behaves OK with or without random effects across types", {
expect_length(s, 969)
})

test_that("simulate() method works with newdata", {
skip_on_cran()
fit <- sdmTMB(
present ~ 1,
time = "year",
data = pcod_2011, spatial = "on",
spatiotemporal = "iid",
family = binomial(),
mesh = pcod_mesh_2011
)
s <- simulate(fit)
expect_true(nrow(s) == nrow(pcod_2011))
g <- replicate_df(qcs_grid, "year", unique(pcod_2011$year))
s <- simulate(fit, newdata = g)
expect_true(nrow(s) == nrow(g))
s <- simulate(fit, newdata = subset(g, year == 2011))
nrow(s)
expect_true(nrow(s) == nrow(subset(g, year == 2011)))

gg <- subset(g, year == 2011)
set.seed(1)
s <- simulate(fit, newdata = gg, nsim = 400L)
a <- apply(s, 1, mean)
p <- predict(fit, newdata = gg)
plot(a, plogis(p$est))
expect_gt(cor(a, plogis(p$est)), 0.98)

set.seed(1)
s1 <- simulate(fit, type = "mle-mvn", mle_mvn_samples = "single", nsim = 100)
set.seed(1)
s2 <- simulate(fit, type = "mle-mvn", mle_mvn_samples = "multiple", nsim = 100)
set.seed(1)
s3 <- simulate(fit, type = "mle-eb", nsim = 100)

expect_false(identical(s1, s2))
expect_false(identical(s1, s3))

sd1 <- apply(s1, 1, sd)
sd2 <- apply(s2, 1, sd)
sd3 <- apply(s3, 1, sd)

expect_lt(mean(sd1), mean(sd2))
})


# test_that("TMB Delta simulation works", {
# skip_on_cran()
# skip_if_not_installed("INLA")
Expand Down

0 comments on commit 7b8569c

Please sign in to comment.