Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reshuffle_rset #329

Merged
merged 13 commits into from
Jul 7, 2022
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ export(num_range)
export(permutations)
export(populate)
export(reg_intervals)
export(reshuffle_rset)
export(reverse_splits)
export(rolling_origin)
export(rsample2caret)
Expand Down
13 changes: 13 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# rsample (development version)

* rset objects should now always have all parameters used to create them as
attributes (#329).

* Objects returned by sliding functions now have an `index` attribute, where appropriate, containing the column name used as an index (#329).

* Objects returned by `permutations()` now have a `permutes` attribute containing the column name used for permutation (#329).

* Added `break` and `pool` as attributes to all functions which support stratification (#329).

* Changed the "strata" attribute on rset objects so that it now is either a character vector identifying the column used to stratify the data, or `FALSE` if stratification was not used. (#329)

* Added a new function, `reshuffle_rset()`, which takes an `rset` object and generates a new version of it using the same arguments but the current random seed. (#79, #329)

* Added arguments to control how `group_vfold_cv()` combines groups. Use `balance = "groups"` to assign (roughly) the same number of groups to each fold, or `balance = "observations"` to assign (roughly) the same number of observations to each fold.

* Added a `repeats` argument to `group_vfold_cv()` (#330).
Expand Down
7 changes: 6 additions & 1 deletion R/boot.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,15 @@ bootstraps <-
split_objs <- bind_rows(split_objs, apparent(data))
}

if (is.null(strata)) strata <- FALSE
names(strata) <- NULL

boot_att <- list(
times = times,
apparent = apparent,
strata = !is.null(strata)
strata = strata,
breaks = breaks,
pool = pool
)

new_rset(
Expand Down
2 changes: 1 addition & 1 deletion R/compat-vctrs-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ delayedAssign("rset_subclasses", {
sliding_window = sliding_window(test_data()),
sliding_index = sliding_index(test_data(), index),
sliding_period = sliding_period(test_data(), index, "week"),
manual_rset = manual_rset(bootstraps(test_data())$splits[1:2], c("ID1", "ID2")),
manual_rset = manual_rset(list(initial_time_split(test_data()), initial_time_split(test_data())), c("ID1", "ID2")),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using bootstraps here means that manual_rset "spent" randomness, which had knock on effects in testing (because we exclude manual_rset from things like reshuffling, but then don't have the same seed active by the time we're rebuilding permutations()). Changing this to initial_time_split() avoids the issue and didn't need any changes in testing.

apparent = apparent(test_data()),
permutations = permutations(test_data(), y)
)
Expand Down
7 changes: 6 additions & 1 deletion R/mc.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,15 @@ mc_cv <- function(data, prop = 3 / 4, times = 25,

split_objs$splits <- map(split_objs$splits, rm_out)

if (is.null(strata)) strata <- FALSE
names(strata) <- NULL

mc_att <- list(
prop = prop,
times = times,
strata = !is.null(strata)
strata = strata,
breaks = breaks,
pool = pool
)

new_rset(
Expand Down
43 changes: 43 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,46 @@ reverse_splits.rset <- function(x, ...) {

x
}

#' "Reshuffle" an rset to re-generate a new rset with the same parameters
#'
#' This function re-generates an rset object, using the same arguments as used
#' to generate the original.
#'
#' @param rset The `rset` object to be reshuffled
#'
#' @return An rset of the same class as `rset`.
#'
#' @examples
#' set.seed(123)
#' (starting_splits <- group_vfold_cv(mtcars, cyl, v = 3))
#' reshuffle_rset(starting_splits)
#'
#' @export
reshuffle_rset <- function(rset) {
if (!inherits(rset, "rset")) {
rlang::abort("`rset` must be an rset object")
}

if (inherits(rset, "manual_rset")) {
rlang::abort("`manual_rset` objects cannot be reshuffled")
}

arguments <- attributes(rset)
useful_arguments <- names(formals(arguments$class[[1]]))
useful_arguments <- arguments[useful_arguments]
useful_arguments <- useful_arguments[!is.na(names(useful_arguments))]
if (identical(useful_arguments$strata, FALSE)) {
useful_arguments$strata <- NULL
} else if (identical(useful_arguments$strata, TRUE)) {
rlang::abort(
"Cannot reshuffle this rset (`attr(rset, 'strata')` is `TRUE`, not a column identifier)",
i = "If the original object was created with an older version of rsample, try recreating it with the newest version of the package"
)
}

do.call(
arguments$class[[1]],
c(list(data = rset$splits[[1]]$data), useful_arguments)
)
}
3 changes: 2 additions & 1 deletion R/permutations.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ permutations <- function(data,
perm_att <- list(
times = times,
apparent = apparent,
col_id = col_id
col_id = col_id,
permute = names(col_id)
)

new_rset(
Expand Down
12 changes: 6 additions & 6 deletions R/printing.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pretty.vfold_cv <- function(x, ...) {
if (details$repeats > 1) {
res <- paste(res, "repeated", details$repeats, "times")
}
if (details$strata) {
if (!identical(details$strata, FALSE)) {
res <- paste(res, "using stratification")
}
res
Expand Down Expand Up @@ -57,7 +57,7 @@ pretty.mc_cv <- function(x, ...) {
details$times,
" resamples "
)
if (details$strata) {
if (!identical(details$strata, FALSE)) {
res <- paste(res, "using stratification")
}
res
Expand All @@ -73,7 +73,7 @@ pretty.validation_split <- function(x, ...) {
signif(1 - details$prop, 2),
") "
)
if (details$strata) {
if (!identical(details$strata, FALSE)) {
res <- paste(res, "using stratification")
}
res
Expand All @@ -89,7 +89,7 @@ pretty.group_validation_split <- function(x, ...) {
signif(1 - details$prop, 2),
") "
)
if (details$strata) {
if (!identical(details$strata, FALSE)) {
res <- paste(res, "using stratification")
}
res
Expand Down Expand Up @@ -124,7 +124,7 @@ pretty.nested_cv <- function(x, ...) {
pretty.bootstraps <- function(x, ...) {
details <- attributes(x)
res <- "Bootstrap sampling"
if (details$strata) {
if (!identical(details$strata, FALSE)) {
res <- paste(res, "using stratification")
}
if (details$apparent) {
Expand All @@ -137,7 +137,7 @@ pretty.bootstraps <- function(x, ...) {
pretty.group_bootstraps <- function(x, ...) {
details <- attributes(x)
res <- "Group bootstrap sampling"
if (details$strata) {
if (!identical(details$strata, FALSE)) {
res <- paste(res, "using stratification")
}
if (details$apparent) {
Expand Down
4 changes: 4 additions & 0 deletions R/slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ sliding_index <- function(data,
rlang::abort("`index` must specify exactly one column in `data`.")
}

index_attrib <- index
index <- data[[loc]]

seq <- vctrs::vec_seq_along(data)
Expand Down Expand Up @@ -352,6 +353,7 @@ sliding_index <- function(data,
ids <- names0(length(indices), prefix = "Slice")

attrib <- list(
index = index_attrib,
lookback = lookback,
assess_start = assess_start,
assess_stop = assess_stop,
Expand Down Expand Up @@ -406,6 +408,7 @@ sliding_period <- function(data,
rlang::abort("`index` must specify exactly one column in `data`.")
}

index_attrib <- index
index <- data[[loc]]

seq <- vctrs::vec_seq_along(data)
Expand Down Expand Up @@ -452,6 +455,7 @@ sliding_period <- function(data,
ids <- names0(length(indices), prefix = "Slice")

attrib <- list(
index = index_attrib,
period = period,
lookback = lookback,
assess_start = assess_start,
Expand Down
7 changes: 6 additions & 1 deletion R/validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,14 @@ validation_split <- function(data, prop = 3 / 4,
split_objs$splits <- map(split_objs$splits, rm_out)
class(split_objs$splits[[1]]) <- c("val_split", "rsplit")

if (is.null(strata)) strata <- FALSE
names(strata) <- NULL

val_att <- list(
prop = prop,
strata = !is.null(strata)
strata = strata,
breaks = breaks,
pool = pool
)

new_rset(
Expand Down
10 changes: 9 additions & 1 deletion R/vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,15 @@ vfold_cv <- function(data, v = 10, repeats = 1,

## Save some overall information

cv_att <- list(v = v, repeats = repeats, strata = !is.null(strata))
if (is.null(strata)) strata <- FALSE
names(strata) <- NULL
cv_att <- list(
v = v,
repeats = repeats,
strata = strata,
breaks = breaks,
pool = pool
)

new_rset(
splits = split_objs$splits,
Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ reference:
- make_splits
- make_strata
- populate
- reshuffle_rset
- reverse_splits
- rsample2caret
- rset_reconstruct
Expand Down
24 changes: 24 additions & 0 deletions man/reshuffle_rset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions tests/testthat/_snaps/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,39 @@
Error in `reverse_splits()`:
! Permutations cannot have their splits reversed

---

Code
reverse_splits(1)
Condition
Error in `reverse_splits()`:
! `x` must be either an `rsplit` or an `rset` object

---

Code
reverse_splits(permutes)
Condition
Error in `reverse_splits()`:
! Permutations cannot have their splits reversed

---

Code
reverse_splits(permutes$splits[[1]])
Condition
Error in `reverse_splits()`:
! Permutations cannot have their splits reversed

# reshuffle_rset is working

Cannot reshuffle this rset (`attr(rset, 'strata')` is `TRUE`, not a column identifier)

---

`manual_rset` objects cannot be reshuffled

---

`rset` must be an rset object

Loading