From 9672ff091482c32ad94ebcc28267c06b7793a4ce Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Fri, 1 Jul 2022 11:52:23 -0400 Subject: [PATCH 1/9] First stab at reshuffling --- DESCRIPTION | 3 +- NAMESPACE | 1 + R/boot.R | 4 +- R/compat-vctrs-helpers.R | 35 +++-- R/misc.R | 35 ++++- R/slide.R | 4 + man/reshuffle_rset.Rd | 24 +++ tests/testthat/_snaps/misc.md | 272 ++++++++++++++++++++++++++++++++++ tests/testthat/test-misc.R | 9 ++ 9 files changed, 368 insertions(+), 19 deletions(-) create mode 100644 man/reshuffle_rset.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 4848574d..bf4351e5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,7 +31,8 @@ Imports: tibble, tidyr, tidyselect, - vctrs (>= 0.3.0) + vctrs (>= 0.3.0), + withr Suggests: broom, covr, diff --git a/NAMESPACE b/NAMESPACE index 12cb3db2..9f4332ec 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -269,6 +269,7 @@ export(num_range) export(permutations) export(populate) export(reg_intervals) +export(reshuffle_rset) export(reverse_splits) export(rolling_origin) export(rsample2caret) diff --git a/R/boot.R b/R/boot.R index 0318d704..a7f83472 100644 --- a/R/boot.R +++ b/R/boot.R @@ -91,7 +91,9 @@ bootstraps <- boot_att <- list( times = times, apparent = apparent, - strata = !is.null(strata) + strata = !is.null(strata), + breaks = breaks, + pool = pool ) new_rset( diff --git a/R/compat-vctrs-helpers.R b/R/compat-vctrs-helpers.R index c26988ed..9659386e 100644 --- a/R/compat-vctrs-helpers.R +++ b/R/compat-vctrs-helpers.R @@ -118,22 +118,25 @@ test_data <- function() { # Delay assignment because we are creating this directly in the R script # and not all of the required helpers might have been sourced yet. delayedAssign("rset_subclasses", { - list( - bootstraps = bootstraps(test_data()), - group_bootstraps = group_bootstraps(test_data(), y), - vfold_cv = vfold_cv(test_data(), v = 10, repeats = 2), - group_vfold_cv = group_vfold_cv(test_data(), y), - loo_cv = loo_cv(test_data()), - mc_cv = mc_cv(test_data()), - group_mc_cv = group_mc_cv(test_data(), y), - nested_cv = nested_cv(test_data(), outside = vfold_cv(v = 3), inside = bootstraps(times = 5)), - validation_split = validation_split(test_data()), - rolling_origin = rolling_origin(test_data()), - sliding_window = sliding_window(test_data()), - sliding_index = sliding_index(test_data(), index), - sliding_period = sliding_period(test_data(), index, "week"), - manual_rset = manual_rset(bootstraps(test_data())$splits[1:2], c("ID1", "ID2")), - apparent = apparent(test_data()) + withr::with_seed( + 123, + list( + bootstraps = bootstraps(test_data()), + group_bootstraps = group_bootstraps(test_data(), y), + vfold_cv = vfold_cv(test_data(), v = 10, repeats = 2), + group_vfold_cv = group_vfold_cv(test_data(), y), + loo_cv = loo_cv(test_data()), + mc_cv = mc_cv(test_data()), + group_mc_cv = group_mc_cv(test_data(), y), + nested_cv = nested_cv(test_data(), outside = vfold_cv(v = 3), inside = bootstraps(times = 5)), + validation_split = validation_split(test_data()), + rolling_origin = rolling_origin(test_data()), + sliding_window = sliding_window(test_data()), + sliding_index = sliding_index(test_data(), index), + sliding_period = sliding_period(test_data(), index, "week"), + manual_rset = manual_rset(bootstraps(test_data())$splits[1:2], c("ID1", "ID2")), + apparent = apparent(test_data()) + ) ) }) diff --git a/R/misc.R b/R/misc.R index 7164167b..078c2b1c 100644 --- a/R/misc.R +++ b/R/misc.R @@ -220,7 +220,40 @@ reverse_splits.rset <- function(x, ...) { x } +#' "Reshuffle" an rset to re-generate a new rset with the same parameters +#' +#' This function re-generates an rset object, using the same arguments as used +#' to generate the original. +#' +#' @param rset The `rset` object to be reshuffled +#' +#' @return An rset of the same class as `rset`. +#' +#' @examples +#' set.seed(123) +#' (starting_splits <- group_vfold_cv(mtcars, cyl, v = 3)) +#' reshuffle_rset(starting_splits) +#' +#' @export +reshuffle_rset <- function(rset) { + if (!inherits(rset, "rset")) { + rlang::abort("`rset` must be an rset object") + } + if (inherits(rset, "manual_rset")) { + rlang::abort("`manual_rset` objects cannot be reshuffled") + } + arguments <- attributes(rset) + useful_arguments <- names(formals(arguments$class[[1]])) + useful_arguments <- arguments[useful_arguments] + useful_arguments <- useful_arguments[!is.na(names(useful_arguments))] + if (identical(useful_arguments$strata, FALSE)) { + useful_arguments$strata <- NULL + } - + do.call( + arguments$class[[1]], + c(list(data = rset$splits[[1]]$data), useful_arguments) + ) +} diff --git a/R/slide.R b/R/slide.R index 0293dc98..20f5cab5 100644 --- a/R/slide.R +++ b/R/slide.R @@ -312,6 +312,7 @@ sliding_index <- function(data, rlang::abort("`index` must specify exactly one column in `data`.") } + index_attrib <- index index <- data[[loc]] seq <- vctrs::vec_seq_along(data) @@ -352,6 +353,7 @@ sliding_index <- function(data, ids <- names0(length(indices), prefix = "Slice") attrib <- list( + index = index_attrib, lookback = lookback, assess_start = assess_start, assess_stop = assess_stop, @@ -406,6 +408,7 @@ sliding_period <- function(data, rlang::abort("`index` must specify exactly one column in `data`.") } + index_attrib <- index index <- data[[loc]] seq <- vctrs::vec_seq_along(data) @@ -452,6 +455,7 @@ sliding_period <- function(data, ids <- names0(length(indices), prefix = "Slice") attrib <- list( + index = index_attrib, period = period, lookback = lookback, assess_start = assess_start, diff --git a/man/reshuffle_rset.Rd b/man/reshuffle_rset.Rd new file mode 100644 index 00000000..75cb2aa0 --- /dev/null +++ b/man/reshuffle_rset.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{reshuffle_rset} +\alias{reshuffle_rset} +\title{"Reshuffle" an rset to re-generate a new rset with the same parameters} +\usage{ +reshuffle_rset(rset) +} +\arguments{ +\item{rset}{The \code{rset} object to be reshuffled} +} +\value{ +An rset of the same class as \code{rset}. +} +\description{ +This function re-generates an rset object, using the same arguments as used +to generate the original. +} +\examples{ +set.seed(123) +(starting_splits <- group_vfold_cv(mtcars, cyl, v = 3)) +reshuffle_rset(starting_splits) + +} diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index f0b5fb49..0a4b5d14 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -22,3 +22,275 @@ Error in `reverse_splits()`: ! Permutations cannot have their splits reversed +# reshuffle_rset is working + + Code + out + Output + # Bootstrap sampling + # A tibble: 25 x 2 + splits id + + 1 Bootstrap01 + 2 Bootstrap02 + 3 Bootstrap03 + 4 Bootstrap04 + 5 Bootstrap05 + 6 Bootstrap06 + 7 Bootstrap07 + 8 Bootstrap08 + 9 Bootstrap09 + 10 Bootstrap10 + # ... with 15 more rows + +--- + + Code + out + Output + # Bootstrap sampling + # A tibble: 25 x 2 + splits id + + 1 Bootstrap01 + 2 Bootstrap02 + 3 Bootstrap03 + 4 Bootstrap04 + 5 Bootstrap05 + 6 Bootstrap06 + 7 Bootstrap07 + 8 Bootstrap08 + 9 Bootstrap09 + 10 Bootstrap10 + # ... with 15 more rows + +--- + + Code + out + Output + # 10-fold cross-validation repeated 2 times + # A tibble: 20 x 3 + splits id id2 + + 1 Repeat1 Fold01 + 2 Repeat1 Fold02 + 3 Repeat1 Fold03 + 4 Repeat1 Fold04 + 5 Repeat1 Fold05 + 6 Repeat1 Fold06 + 7 Repeat1 Fold07 + 8 Repeat1 Fold08 + 9 Repeat1 Fold09 + 10 Repeat1 Fold10 + 11 Repeat2 Fold01 + 12 Repeat2 Fold02 + 13 Repeat2 Fold03 + 14 Repeat2 Fold04 + 15 Repeat2 Fold05 + 16 Repeat2 Fold06 + 17 Repeat2 Fold07 + 18 Repeat2 Fold08 + 19 Repeat2 Fold09 + 20 Repeat2 Fold10 + +--- + + Code + out + Output + # Group 10-fold cross-validation + # A tibble: 10 x 2 + splits id + + 1 Resample01 + 2 Resample02 + 3 Resample03 + 4 Resample04 + 5 Resample05 + 6 Resample06 + 7 Resample07 + 8 Resample08 + 9 Resample09 + 10 Resample10 + +--- + + Code + out + Output + # Leave-one-out cross-validation + # A tibble: 50 x 2 + splits id + + 1 Resample1 + 2 Resample2 + 3 Resample3 + 4 Resample4 + 5 Resample5 + 6 Resample6 + 7 Resample7 + 8 Resample8 + 9 Resample9 + 10 Resample10 + # ... with 40 more rows + +--- + + Code + out + Output + # Monte Carlo cross-validation (0.75/0.25) with 25 resamples + # A tibble: 25 x 2 + splits id + + 1 Resample01 + 2 Resample02 + 3 Resample03 + 4 Resample04 + 5 Resample05 + 6 Resample06 + 7 Resample07 + 8 Resample08 + 9 Resample09 + 10 Resample10 + # ... with 15 more rows + +--- + + Code + out + Output + # Grouped Monte Carlo cross-validation (0.75/0.25) with 25 resamples + # A tibble: 25 x 2 + splits id + + 1 Resample01 + 2 Resample02 + 3 Resample03 + 4 Resample04 + 5 Resample05 + 6 Resample06 + 7 Resample07 + 8 Resample08 + 9 Resample09 + 10 Resample10 + # ... with 15 more rows + +--- + + Code + out + Output + # Nested resampling: + # outer: 3-fold cross-validation + # inner: Bootstrap sampling + # A tibble: 3 x 3 + splits id inner_resamples + + 1 Fold1 + 2 Fold2 + 3 Fold3 + +--- + + Code + out + Output + # Validation Set Split (0.75/0.25) + # A tibble: 1 x 2 + splits id + + 1 validation + +--- + + Code + out + Output + # Rolling origin forecast resampling + # A tibble: 45 x 2 + splits id + + 1 Slice01 + 2 Slice02 + 3 Slice03 + 4 Slice04 + 5 Slice05 + 6 Slice06 + 7 Slice07 + 8 Slice08 + 9 Slice09 + 10 Slice10 + # ... with 35 more rows + +--- + + Code + out + Output + # Sliding window resampling + # A tibble: 49 x 2 + splits id + + 1 Slice01 + 2 Slice02 + 3 Slice03 + 4 Slice04 + 5 Slice05 + 6 Slice06 + 7 Slice07 + 8 Slice08 + 9 Slice09 + 10 Slice10 + # ... with 39 more rows + +--- + + Code + out + Output + # Sliding index resampling + # A tibble: 49 x 2 + splits id + + 1 Slice01 + 2 Slice02 + 3 Slice03 + 4 Slice04 + 5 Slice05 + 6 Slice06 + 7 Slice07 + 8 Slice08 + 9 Slice09 + 10 Slice10 + # ... with 39 more rows + +--- + + Code + out + Output + # Sliding period resampling + # A tibble: 7 x 2 + splits id + + 1 Slice1 + 2 Slice2 + 3 Slice3 + 4 Slice4 + 5 Slice5 + 6 Slice6 + 7 Slice7 + +--- + + Code + out + Output + # Apparent sampling + # A tibble: 1 x 2 + splits id + + 1 Apparent + diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index 34b2d00e..e2ac2ab0 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -28,3 +28,12 @@ test_that("reverse_splits is working", { ) }) + +test_that("reshuffle_rset is working", { + for (x in rset_subclasses) { + if (inherits(x, "manual_rset")) next + withr::with_seed(123, out <- reshuffle_rset(x)) + expect_snapshot(out) + } +}) + From 1ee58b4a209c81fe39e1dedcb1f2bbe78f2863a0 Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Fri, 1 Jul 2022 12:29:03 -0400 Subject: [PATCH 2/9] Fix tests --- NAMESPACE | 1 + R/rsample-package.R | 1 + tests/testthat/_snaps/misc.md | 272 ---------------------------------- tests/testthat/test-misc.R | 25 +++- 4 files changed, 23 insertions(+), 276 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 9f4332ec..aa047689 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -348,3 +348,4 @@ importFrom(vctrs,vec_count) importFrom(vctrs,vec_in) importFrom(vctrs,vec_slice) importFrom(vctrs,vec_unique_count) +importFrom(withr,with_seed) diff --git a/R/rsample-package.R b/R/rsample-package.R index 5d9a01df..39c3e72f 100644 --- a/R/rsample-package.R +++ b/R/rsample-package.R @@ -21,6 +21,7 @@ NULL #' @importFrom tidyselect vars_select one_of #' @importFrom furrr future_map_dfr #' @importFrom tidyr gather +#' @importFrom withr with_seed #------------------------------------------------------------------------------# diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 0a4b5d14..f0b5fb49 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -22,275 +22,3 @@ Error in `reverse_splits()`: ! Permutations cannot have their splits reversed -# reshuffle_rset is working - - Code - out - Output - # Bootstrap sampling - # A tibble: 25 x 2 - splits id - - 1 Bootstrap01 - 2 Bootstrap02 - 3 Bootstrap03 - 4 Bootstrap04 - 5 Bootstrap05 - 6 Bootstrap06 - 7 Bootstrap07 - 8 Bootstrap08 - 9 Bootstrap09 - 10 Bootstrap10 - # ... with 15 more rows - ---- - - Code - out - Output - # Bootstrap sampling - # A tibble: 25 x 2 - splits id - - 1 Bootstrap01 - 2 Bootstrap02 - 3 Bootstrap03 - 4 Bootstrap04 - 5 Bootstrap05 - 6 Bootstrap06 - 7 Bootstrap07 - 8 Bootstrap08 - 9 Bootstrap09 - 10 Bootstrap10 - # ... with 15 more rows - ---- - - Code - out - Output - # 10-fold cross-validation repeated 2 times - # A tibble: 20 x 3 - splits id id2 - - 1 Repeat1 Fold01 - 2 Repeat1 Fold02 - 3 Repeat1 Fold03 - 4 Repeat1 Fold04 - 5 Repeat1 Fold05 - 6 Repeat1 Fold06 - 7 Repeat1 Fold07 - 8 Repeat1 Fold08 - 9 Repeat1 Fold09 - 10 Repeat1 Fold10 - 11 Repeat2 Fold01 - 12 Repeat2 Fold02 - 13 Repeat2 Fold03 - 14 Repeat2 Fold04 - 15 Repeat2 Fold05 - 16 Repeat2 Fold06 - 17 Repeat2 Fold07 - 18 Repeat2 Fold08 - 19 Repeat2 Fold09 - 20 Repeat2 Fold10 - ---- - - Code - out - Output - # Group 10-fold cross-validation - # A tibble: 10 x 2 - splits id - - 1 Resample01 - 2 Resample02 - 3 Resample03 - 4 Resample04 - 5 Resample05 - 6 Resample06 - 7 Resample07 - 8 Resample08 - 9 Resample09 - 10 Resample10 - ---- - - Code - out - Output - # Leave-one-out cross-validation - # A tibble: 50 x 2 - splits id - - 1 Resample1 - 2 Resample2 - 3 Resample3 - 4 Resample4 - 5 Resample5 - 6 Resample6 - 7 Resample7 - 8 Resample8 - 9 Resample9 - 10 Resample10 - # ... with 40 more rows - ---- - - Code - out - Output - # Monte Carlo cross-validation (0.75/0.25) with 25 resamples - # A tibble: 25 x 2 - splits id - - 1 Resample01 - 2 Resample02 - 3 Resample03 - 4 Resample04 - 5 Resample05 - 6 Resample06 - 7 Resample07 - 8 Resample08 - 9 Resample09 - 10 Resample10 - # ... with 15 more rows - ---- - - Code - out - Output - # Grouped Monte Carlo cross-validation (0.75/0.25) with 25 resamples - # A tibble: 25 x 2 - splits id - - 1 Resample01 - 2 Resample02 - 3 Resample03 - 4 Resample04 - 5 Resample05 - 6 Resample06 - 7 Resample07 - 8 Resample08 - 9 Resample09 - 10 Resample10 - # ... with 15 more rows - ---- - - Code - out - Output - # Nested resampling: - # outer: 3-fold cross-validation - # inner: Bootstrap sampling - # A tibble: 3 x 3 - splits id inner_resamples - - 1 Fold1 - 2 Fold2 - 3 Fold3 - ---- - - Code - out - Output - # Validation Set Split (0.75/0.25) - # A tibble: 1 x 2 - splits id - - 1 validation - ---- - - Code - out - Output - # Rolling origin forecast resampling - # A tibble: 45 x 2 - splits id - - 1 Slice01 - 2 Slice02 - 3 Slice03 - 4 Slice04 - 5 Slice05 - 6 Slice06 - 7 Slice07 - 8 Slice08 - 9 Slice09 - 10 Slice10 - # ... with 35 more rows - ---- - - Code - out - Output - # Sliding window resampling - # A tibble: 49 x 2 - splits id - - 1 Slice01 - 2 Slice02 - 3 Slice03 - 4 Slice04 - 5 Slice05 - 6 Slice06 - 7 Slice07 - 8 Slice08 - 9 Slice09 - 10 Slice10 - # ... with 39 more rows - ---- - - Code - out - Output - # Sliding index resampling - # A tibble: 49 x 2 - splits id - - 1 Slice01 - 2 Slice02 - 3 Slice03 - 4 Slice04 - 5 Slice05 - 6 Slice06 - 7 Slice07 - 8 Slice08 - 9 Slice09 - 10 Slice10 - # ... with 39 more rows - ---- - - Code - out - Output - # Sliding period resampling - # A tibble: 7 x 2 - splits id - - 1 Slice1 - 2 Slice2 - 3 Slice3 - 4 Slice4 - 5 Slice5 - 6 Slice6 - 7 Slice7 - ---- - - Code - out - Output - # Apparent sampling - # A tibble: 1 x 2 - splits id - - 1 Apparent - diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index e2ac2ab0..97c25db7 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -30,10 +30,27 @@ test_that("reverse_splits is working", { }) test_that("reshuffle_rset is working", { - for (x in rset_subclasses) { - if (inherits(x, "manual_rset")) next - withr::with_seed(123, out <- reshuffle_rset(x)) - expect_snapshot(out) + + supported_subclasses <- rset_subclasses[ + setdiff(names(rset_subclasses), "manual_rset") + ] + + # Reshuffling with the same seed, in the same order, + # should recreate the same objects + out <- withr::with_seed( + 123, + lapply( + supported_subclasses, + reshuffle_rset + ) + ) + + for (i in seq_along(supported_subclasses)) { + expect_identical( + out[[i]], + supported_subclasses[[i]] + ) } + }) From afb0c45bf8d5190796476c1878d54f9fa91dfd30 Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Fri, 1 Jul 2022 12:56:13 -0400 Subject: [PATCH 3/9] pkgdown --- _pkgdown.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/_pkgdown.yml b/_pkgdown.yml index 70ab1945..79e11fc3 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -65,6 +65,7 @@ reference: - make_splits - make_strata - populate + - reshuffle_rset - reverse_splits - rsample2caret - rset_reconstruct From 48df2989b8371d70bfced696fa06ae8593b411a2 Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Tue, 5 Jul 2022 17:39:46 -0400 Subject: [PATCH 4/9] Test strata when applicable --- tests/testthat/test-misc.R | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index f656ea09..586ad231 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -32,6 +32,7 @@ test_that("reverse_splits is working", { test_that("reshuffle_rset is working", { + skip_if_not(rlang::is_installed("withr")) supported_subclasses <- rset_subclasses[ setdiff(names(rset_subclasses), "manual_rset") ] @@ -53,5 +54,25 @@ test_that("reshuffle_rset is working", { ) } + supports_strata <- purrr::map_lgl( + names(supported_subclasses), + ~ any(names(formals(.x)) == "strata") + ) + supports_strata <- names(supported_subclasses)[supports_strata] + supports_strata <- supported_subclasses[supports_strata] + for (i in seq_along(supports_strata)) { + set.seed(123) + resample <- do.call( + names(supports_strata)[i], + list( + data = test_data(), + strata = "y", + breaks = 2, + pool = 0.2 + ) + ) + set.seed(123) + reshuffled_resample <- reshuffle_rset(resample) + expect_identical(resample, reshuffled_resample) + } }) - From 366576228b6a6c059e3d440dbac0c0b08c2123fc Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Tue, 5 Jul 2022 18:34:28 -0400 Subject: [PATCH 5/9] Support strata --- DESCRIPTION | 3 +-- NAMESPACE | 1 - NEWS.md | 6 ++++++ R/boot.R | 5 ++++- R/compat-vctrs-helpers.R | 2 +- R/mc.R | 7 ++++++- R/misc.R | 5 +++++ R/permutations.R | 3 ++- R/printing.R | 12 ++++++------ R/rsample-package.R | 1 - R/validation_split.R | 7 ++++++- R/vfold.R | 10 +++++++++- tests/testthat/test-misc.R | 4 ++-- 13 files changed, 48 insertions(+), 18 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index d8b20c51..81b6bf08 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,8 +31,7 @@ Imports: tibble, tidyr, tidyselect, - vctrs (>= 0.3.0), - withr + vctrs (>= 0.3.0) Suggests: broom, covr, diff --git a/NAMESPACE b/NAMESPACE index d5b47024..1c7de961 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -374,4 +374,3 @@ importFrom(vctrs,vec_count) importFrom(vctrs,vec_in) importFrom(vctrs,vec_slice) importFrom(vctrs,vec_unique_count) -importFrom(withr,with_seed) diff --git a/NEWS.md b/NEWS.md index 8ffb9606..ee3ad220 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,11 @@ # rsample (development version) +* Added `break` and `pool` as attributes to all functions which support stratification (#329). + +* Changed the "strata" attribute on rset objects so that it now is either a character vector identifying the column used to stratify the data, or `FALSE` if stratification was not used. (#329) + +* Added a new function, `reshuffle_rset()`, which takes an `rset` object and generates a new version of it using the same arguments but the current random seed. (#79, #329) + * Added arguments to control how `group_vfold_cv()` combines groups. Use `balance = "groups"` to assign (roughly) the same number of groups to each fold, or `balance = "observations"` to assign (roughly) the same number of observations to each fold. * Added a `repeats` argument to `group_vfold_cv()` (#330). diff --git a/R/boot.R b/R/boot.R index ab3c54af..83b57807 100644 --- a/R/boot.R +++ b/R/boot.R @@ -88,10 +88,13 @@ bootstraps <- split_objs <- bind_rows(split_objs, apparent(data)) } + if (is.null(strata)) strata <- FALSE + names(strata) <- NULL + boot_att <- list( times = times, apparent = apparent, - strata = !is.null(strata), + strata = strata, breaks = breaks, pool = pool ) diff --git a/R/compat-vctrs-helpers.R b/R/compat-vctrs-helpers.R index 3a260376..bbcc8289 100644 --- a/R/compat-vctrs-helpers.R +++ b/R/compat-vctrs-helpers.R @@ -136,7 +136,7 @@ delayedAssign("rset_subclasses", { sliding_window = sliding_window(test_data()), sliding_index = sliding_index(test_data(), index), sliding_period = sliding_period(test_data(), index, "week"), - manual_rset = manual_rset(bootstraps(test_data())$splits[1:2], c("ID1", "ID2")), + manual_rset = manual_rset(list(initial_time_split(test_data()), initial_time_split(test_data())), c("ID1", "ID2")), apparent = apparent(test_data()), permutations = permutations(test_data(), y) ) diff --git a/R/mc.R b/R/mc.R index b49f1b06..645e9337 100644 --- a/R/mc.R +++ b/R/mc.R @@ -73,10 +73,15 @@ mc_cv <- function(data, prop = 3 / 4, times = 25, split_objs$splits <- map(split_objs$splits, rm_out) + if (is.null(strata)) strata <- FALSE + names(strata) <- NULL + mc_att <- list( prop = prop, times = times, - strata = !is.null(strata) + strata = strata, + breaks = breaks, + pool = pool ) new_rset( diff --git a/R/misc.R b/R/misc.R index 078c2b1c..29929cf4 100644 --- a/R/misc.R +++ b/R/misc.R @@ -250,6 +250,11 @@ reshuffle_rset <- function(rset) { useful_arguments <- useful_arguments[!is.na(names(useful_arguments))] if (identical(useful_arguments$strata, FALSE)) { useful_arguments$strata <- NULL + } else if (identical(useful_arguments$strata, TRUE)) { + rlang::abort( + "Cannot reshuffle this rset (`attr(rset, 'strata')` is `TRUE`, not a column identifier)", + i = "If the original object was created with an older version of rsample, try recreating it with the newest version of the package" + ) } do.call( diff --git a/R/permutations.R b/R/permutations.R index 28ad6849..d1723bf2 100644 --- a/R/permutations.R +++ b/R/permutations.R @@ -76,7 +76,8 @@ permutations <- function(data, perm_att <- list( times = times, apparent = apparent, - col_id = col_id + col_id = col_id, + permute = names(col_id) ) new_rset( diff --git a/R/printing.R b/R/printing.R index f9794dff..dbeeb742 100644 --- a/R/printing.R +++ b/R/printing.R @@ -9,7 +9,7 @@ pretty.vfold_cv <- function(x, ...) { if (details$repeats > 1) { res <- paste(res, "repeated", details$repeats, "times") } - if (details$strata) { + if (!identical(details$strata, FALSE)) { res <- paste(res, "using stratification") } res @@ -57,7 +57,7 @@ pretty.mc_cv <- function(x, ...) { details$times, " resamples " ) - if (details$strata) { + if (!identical(details$strata, FALSE)) { res <- paste(res, "using stratification") } res @@ -73,7 +73,7 @@ pretty.validation_split <- function(x, ...) { signif(1 - details$prop, 2), ") " ) - if (details$strata) { + if (!identical(details$strata, FALSE)) { res <- paste(res, "using stratification") } res @@ -89,7 +89,7 @@ pretty.group_validation_split <- function(x, ...) { signif(1 - details$prop, 2), ") " ) - if (details$strata) { + if (!identical(details$strata, FALSE)) { res <- paste(res, "using stratification") } res @@ -124,7 +124,7 @@ pretty.nested_cv <- function(x, ...) { pretty.bootstraps <- function(x, ...) { details <- attributes(x) res <- "Bootstrap sampling" - if (details$strata) { + if (!identical(details$strata, FALSE)) { res <- paste(res, "using stratification") } if (details$apparent) { @@ -137,7 +137,7 @@ pretty.bootstraps <- function(x, ...) { pretty.group_bootstraps <- function(x, ...) { details <- attributes(x) res <- "Group bootstrap sampling" - if (details$strata) { + if (!identical(details$strata, FALSE)) { res <- paste(res, "using stratification") } if (details$apparent) { diff --git a/R/rsample-package.R b/R/rsample-package.R index 39c3e72f..5d9a01df 100644 --- a/R/rsample-package.R +++ b/R/rsample-package.R @@ -21,7 +21,6 @@ NULL #' @importFrom tidyselect vars_select one_of #' @importFrom furrr future_map_dfr #' @importFrom tidyr gather -#' @importFrom withr with_seed #------------------------------------------------------------------------------# diff --git a/R/validation_split.R b/R/validation_split.R index 1e254b49..7afbf34d 100644 --- a/R/validation_split.R +++ b/R/validation_split.R @@ -52,9 +52,14 @@ validation_split <- function(data, prop = 3 / 4, split_objs$splits <- map(split_objs$splits, rm_out) class(split_objs$splits[[1]]) <- c("val_split", "rsplit") + if (is.null(strata)) strata <- FALSE + names(strata) <- NULL + val_att <- list( prop = prop, - strata = !is.null(strata) + strata = strata, + breaks = breaks, + pool = pool ) new_rset( diff --git a/R/vfold.R b/R/vfold.R index 6b3f7a0f..7bc688a7 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -102,7 +102,15 @@ vfold_cv <- function(data, v = 10, repeats = 1, ## Save some overall information - cv_att <- list(v = v, repeats = repeats, strata = !is.null(strata)) + if (is.null(strata)) strata <- FALSE + names(strata) <- NULL + cv_att <- list( + v = v, + repeats = repeats, + strata = strata, + breaks = breaks, + pool = pool + ) new_rset( splits = split_objs$splits, diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index d6634e22..d9c9afbf 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -37,7 +37,7 @@ test_that("reshuffle_rset is working", { skip_if_not(rlang::is_installed("withr")) supported_subclasses <- rset_subclasses[ - setdiff(names(rset_subclasses), "manual_rset") + setdiff(names(rset_subclasses), c("manual_rset")) ] # Reshuffling with the same seed, in the same order, @@ -71,7 +71,7 @@ test_that("reshuffle_rset is working", { data = test_data(), strata = "y", breaks = 2, - pool = 0.2 + pool = 0.1 ) ) set.seed(123) From 3533894d3d3f8426054448bd9a8c8889f4ca7064 Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Tue, 5 Jul 2022 18:48:40 -0400 Subject: [PATCH 6/9] Test a bit more thoroughly --- tests/testthat/_snaps/misc.md | 36 ++++++++++++ tests/testthat/test-misc.R | 106 +++++++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 1 deletion(-) diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index f0b5fb49..98e5fb30 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -22,3 +22,39 @@ Error in `reverse_splits()`: ! Permutations cannot have their splits reversed +--- + + Code + reverse_splits(1) + Condition + Error in `reverse_splits()`: + ! `x` must be either an `rsplit` or an `rset` object + +--- + + Code + reverse_splits(permutes) + Condition + Error in `reverse_splits()`: + ! Permutations cannot have their splits reversed + +--- + + Code + reverse_splits(permutes$splits[[1]]) + Condition + Error in `reverse_splits()`: + ! Permutations cannot have their splits reversed + +# reshuffle_rset is working + + Cannot reshuffle this rset (`attr(rset, 'strata')` is `TRUE`, not a column identifier) + +--- + + `manual_rset` objects cannot be reshuffled + +--- + + `rset` must be an rset object + diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index d9c9afbf..2943d698 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -57,12 +57,18 @@ test_that("reshuffle_rset is working", { ) } + # Check to make sure that stratification, + # with non-default arguments, + # is supported by reshuffled_resample + + # Select any function in rset_subclasses with a strata argument supports_strata <- purrr::map_lgl( names(supported_subclasses), ~ any(names(formals(.x)) == "strata") ) supports_strata <- names(supported_subclasses)[supports_strata] supports_strata <- supported_subclasses[supports_strata] + for (i in seq_along(supports_strata)) { set.seed(123) resample <- do.call( @@ -71,7 +77,7 @@ test_that("reshuffle_rset is working", { data = test_data(), strata = "y", breaks = 2, - pool = 0.1 + pool = 0.2 ) ) set.seed(123) @@ -79,3 +85,101 @@ test_that("reshuffle_rset is working", { expect_identical(resample, reshuffled_resample) } }) +test_that("reverse_splits is working", { + skip_if_not(rlang::is_installed("withr")) + + reversable_subclasses <- setdiff(names(rset_subclasses), "permutations") + reversable_subclasses <- rset_subclasses[reversable_subclasses] + for (x in reversable_subclasses) { + + set.seed(123) + rev_x <- reverse_splits(x) + expect_identical(analysis(x$splits[[1]]), assessment(rev_x$splits[[1]])) + expect_identical(assessment(x$splits[[1]]), analysis(rev_x$splits[[1]])) + expect_identical(class(x), class(rev_x)) + expect_identical(class(x$splits[[1]]), class(rev_x$splits[[1]])) + + } + + expect_snapshot( + reverse_splits(1), + error = TRUE + ) + + permutes <- permutations(mtcars, cyl) + + expect_snapshot( + reverse_splits(permutes), + error = TRUE + ) + + expect_snapshot( + reverse_splits(permutes$splits[[1]]), + error = TRUE + ) + +}) + +test_that("reshuffle_rset is working", { + + skip_if_not(rlang::is_installed("withr")) + supported_subclasses <- rset_subclasses[ + setdiff(names(rset_subclasses), c("manual_rset")) + ] + + # Reshuffling with the same seed, in the same order, + # should recreate the same objects + out <- withr::with_seed( + 123, + lapply( + supported_subclasses, + reshuffle_rset + ) + ) + + for (i in seq_along(supported_subclasses)) { + expect_identical( + out[[i]], + supported_subclasses[[i]] + ) + } + + # Check to make sure that stratification, + # with non-default arguments, + # is supported by reshuffled_resample + + # Select any function in rset_subclasses with a strata argument: + supports_strata <- purrr::map_lgl( + names(supported_subclasses), + ~ any(names(formals(.x)) == "strata") + ) + supports_strata <- names(supported_subclasses)[supports_strata] + + for (i in seq_along(supports_strata)) { + # Fit those functions with non-default arguments: + set.seed(123) + resample <- do.call( + supports_strata[i], + list( + data = test_data(), + strata = "y", + breaks = 2, + pool = 0.2 + ) + ) + # Reshuffle them under the same seed to ensure they're identical + set.seed(123) + reshuffled_resample <- reshuffle_rset(resample) + expect_identical(resample, reshuffled_resample) + } + + resample <- vfold_cv(mtcars, strata = cyl) + attr(resample, "strata") <- TRUE + + expect_snapshot_error(reshuffle_rset(resample)) + + expect_snapshot_error(reshuffle_rset(rset_subclasses[["manual_rset"]])) + + expect_snapshot_error(reshuffle_rset(rset_subclasses[["manual_rset"]]$splits[[1]])) + +}) From de9ccae6e4cd01e45f5e57cb8992c058cd3a97d2 Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Wed, 6 Jul 2022 09:44:54 -0400 Subject: [PATCH 7/9] Update NEWS --- NEWS.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/NEWS.md b/NEWS.md index ee3ad220..53eab459 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,12 @@ # rsample (development version) +* rset objects should now always have all parameters used to create them as +attributes (#329). + +* Objects returned by sliding functions now have an `index` attribute, where appropriate, containing the column name used as an index (#329). + +* Objects returned by `permutations()` now have a `permutes` attribute containing the column name used for permutation (#329). + * Added `break` and `pool` as attributes to all functions which support stratification (#329). * Changed the "strata" attribute on rset objects so that it now is either a character vector identifying the column used to stratify the data, or `FALSE` if stratification was not used. (#329) From 996afc15a9317bee17650e64cb18057f9afc15dc Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Wed, 6 Jul 2022 17:40:43 -0600 Subject: [PATCH 8/9] Tiny tiny edits --- NEWS.md | 5 ++--- R/misc.R | 2 +- man/reshuffle_rset.Rd | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/NEWS.md b/NEWS.md index 53eab459..fbe7211c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,6 @@ # rsample (development version) -* rset objects should now always have all parameters used to create them as -attributes (#329). +* rset objects now include all parameters used to create them as attributes (#329). * Objects returned by sliding functions now have an `index` attribute, where appropriate, containing the column name used as an index (#329). @@ -128,7 +127,7 @@ attributes (#329). # `rsample` 0.0.4 -Small maintenence release. +Small maintenance release. ## Minor improvements and fixes diff --git a/R/misc.R b/R/misc.R index 29929cf4..a77be434 100644 --- a/R/misc.R +++ b/R/misc.R @@ -222,7 +222,7 @@ reverse_splits.rset <- function(x, ...) { #' "Reshuffle" an rset to re-generate a new rset with the same parameters #' -#' This function re-generates an rset object, using the same arguments as used +#' This function re-generates an rset object, using the same arguments used #' to generate the original. #' #' @param rset The `rset` object to be reshuffled diff --git a/man/reshuffle_rset.Rd b/man/reshuffle_rset.Rd index 75cb2aa0..b3ebc6d6 100644 --- a/man/reshuffle_rset.Rd +++ b/man/reshuffle_rset.Rd @@ -13,7 +13,7 @@ reshuffle_rset(rset) An rset of the same class as \code{rset}. } \description{ -This function re-generates an rset object, using the same arguments as used +This function re-generates an rset object, using the same arguments used to generate the original. } \examples{ From c26dffc5a78e321b04dfb0668df9ed73b92b3191 Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Thu, 7 Jul 2022 07:24:36 -0400 Subject: [PATCH 9/9] Set strata to NULL if missing --- .gitignore | 1 + NEWS.md | 2 +- R/boot.R | 4 +--- R/mc.R | 4 +--- R/printing.R | 16 ++++++++++------ R/validation_split.R | 4 +--- R/vfold.R | 3 +-- 7 files changed, 16 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 4d06fed3..f89762e0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ revdep/checks.noindex revdep/library.noindex revdep/cloud.noindex docs +inst/doc diff --git a/NEWS.md b/NEWS.md index 53eab459..ee00c5a3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,7 +9,7 @@ attributes (#329). * Added `break` and `pool` as attributes to all functions which support stratification (#329). -* Changed the "strata" attribute on rset objects so that it now is either a character vector identifying the column used to stratify the data, or `FALSE` if stratification was not used. (#329) +* Changed the "strata" attribute on rset objects so that it now is either a character vector identifying the column used to stratify the data, and is not present (set to `NULL`) if stratification was not used. (#329) * Added a new function, `reshuffle_rset()`, which takes an `rset` object and generates a new version of it using the same arguments but the current random seed. (#79, #329) diff --git a/R/boot.R b/R/boot.R index 83b57807..46256d0a 100644 --- a/R/boot.R +++ b/R/boot.R @@ -88,9 +88,7 @@ bootstraps <- split_objs <- bind_rows(split_objs, apparent(data)) } - if (is.null(strata)) strata <- FALSE - names(strata) <- NULL - + if (!is.null(strata)) names(strata) <- NULL boot_att <- list( times = times, apparent = apparent, diff --git a/R/mc.R b/R/mc.R index 645e9337..477776ed 100644 --- a/R/mc.R +++ b/R/mc.R @@ -73,9 +73,7 @@ mc_cv <- function(data, prop = 3 / 4, times = 25, split_objs$splits <- map(split_objs$splits, rm_out) - if (is.null(strata)) strata <- FALSE - names(strata) <- NULL - + if (!is.null(strata)) names(strata) <- NULL mc_att <- list( prop = prop, times = times, diff --git a/R/printing.R b/R/printing.R index dbeeb742..ac34286e 100644 --- a/R/printing.R +++ b/R/printing.R @@ -9,7 +9,7 @@ pretty.vfold_cv <- function(x, ...) { if (details$repeats > 1) { res <- paste(res, "repeated", details$repeats, "times") } - if (!identical(details$strata, FALSE)) { + if (has_strata(details)) { res <- paste(res, "using stratification") } res @@ -57,7 +57,7 @@ pretty.mc_cv <- function(x, ...) { details$times, " resamples " ) - if (!identical(details$strata, FALSE)) { + if (has_strata(details)) { res <- paste(res, "using stratification") } res @@ -73,7 +73,7 @@ pretty.validation_split <- function(x, ...) { signif(1 - details$prop, 2), ") " ) - if (!identical(details$strata, FALSE)) { + if (has_strata(details)) { res <- paste(res, "using stratification") } res @@ -89,7 +89,7 @@ pretty.group_validation_split <- function(x, ...) { signif(1 - details$prop, 2), ") " ) - if (!identical(details$strata, FALSE)) { + if (has_strata(details)) { res <- paste(res, "using stratification") } res @@ -124,7 +124,7 @@ pretty.nested_cv <- function(x, ...) { pretty.bootstraps <- function(x, ...) { details <- attributes(x) res <- "Bootstrap sampling" - if (!identical(details$strata, FALSE)) { + if (has_strata(details)) { res <- paste(res, "using stratification") } if (details$apparent) { @@ -137,7 +137,7 @@ pretty.bootstraps <- function(x, ...) { pretty.group_bootstraps <- function(x, ...) { details <- attributes(x) res <- "Group bootstrap sampling" - if (!identical(details$strata, FALSE)) { + if (has_strata(details)) { res <- paste(res, "using stratification") } if (details$apparent) { @@ -356,3 +356,7 @@ print.vfold_cv <- function(x, ...) { class(x) <- class(x)[!(class(x) %in% c("vfold_cv", "rset"))] print(x, ...) } + +has_strata <- function(x) { + !is.null(x$strata) && !identical(x$strata, FALSE) +} diff --git a/R/validation_split.R b/R/validation_split.R index 7afbf34d..7fd70636 100644 --- a/R/validation_split.R +++ b/R/validation_split.R @@ -52,9 +52,7 @@ validation_split <- function(data, prop = 3 / 4, split_objs$splits <- map(split_objs$splits, rm_out) class(split_objs$splits[[1]]) <- c("val_split", "rsplit") - if (is.null(strata)) strata <- FALSE - names(strata) <- NULL - + if (!is.null(strata)) names(strata) <- NULL val_att <- list( prop = prop, strata = strata, diff --git a/R/vfold.R b/R/vfold.R index 7bc688a7..6a5602c8 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -102,8 +102,7 @@ vfold_cv <- function(data, v = 10, repeats = 1, ## Save some overall information - if (is.null(strata)) strata <- FALSE - names(strata) <- NULL + if (!is.null(strata)) names(strata) <- NULL cv_att <- list( v = v, repeats = repeats,