Skip to content

Commit

Permalink
Merge pull request #214 from tidymodels/fingerprint-attribute
Browse files Browse the repository at this point in the history
Add fingerprinting hash to rset attributes
  • Loading branch information
topepo authored Feb 1, 2021
2 parents 68b6c57 + 8b20a09 commit ae94bbc
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 99 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: rsample
Title: General Resampling Infrastructure
Version: 0.0.8.9000
Version: 0.0.8.9001
Authors@R: c(
person(given = "Max", family = "Kuhn", email = "max@rstudio.com", role = c("aut", "cre")),
person(given = "Fanny", family = "Chow", email = "fannybchow@gmail.com", role = c("aut")),
Expand Down
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

S3method("[",rset)
S3method("names<-",rset)
S3method(.get_fingerprint,default)
S3method(.get_fingerprint,rset)
S3method(as.data.frame,rsplit)
S3method(as.integer,rsplit)
S3method(complement,apparent_split)
Expand Down Expand Up @@ -202,6 +204,7 @@ S3method(vec_restore,sliding_period)
S3method(vec_restore,sliding_window)
S3method(vec_restore,validation_split)
S3method(vec_restore,vfold_cv)
export(.get_fingerprint)
export(add_resample_id)
export(all_of)
export(analysis)
Expand All @@ -215,7 +218,6 @@ export(contains)
export(default_complement)
export(ends_with)
export(everything)
export(fingerprint)
export(form_pred)
export(gather)
export(gather.rset)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

* Fixed an issue where empty assessment sets couldn't be created by `make_splits()` (#188).

* A `fingerprint()` function was added to create a hash value that can be used to compare `rset` objects.
* `rset` objects now contain a "fingerprint" attribute that can be used to check to see if the same object uses the same resamples.

* The `reg_intervals()` function is a convenience function for `lm()`, `glm()`, `survreg()`, and `coxph()` models.

Expand Down
70 changes: 34 additions & 36 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,53 +58,51 @@ split_unnamed <- function(x, f) {
unname(out)
}


## -----------------------------------------------------------------------------

#' Create a cryptographical hash value for `rset` objects.
#'
#' This function uses the distinct rows in the data set and the column(s) for the
#' resample identifier and the splits to produce a character string that can be
#' used to determine if another object shares the same splits.
#' Obtain a identifier for the resamples
#'
#' The comparison is based on the unique contents of the `id` and `splits`
#' columns. Attributes are not used in the comparison.
#' @param x An `rset` object.
#' This function returns a hash (or NA) for an attribute that is created when
#' the `rset` was initially constructed. This can be used to compare with other
#' resampling objects to see if they are the same.
#' @param x An `rset` or `tune_results` object.
#' @param ... Not currently used.
#' @return A character string.
#' @return A character value or `NA_character_` if the object was created prior
#' to `rsample` version 0.1.0.
#' @rdname get_fingerprint
#' @aliases .get_fingerprint
#' @examples
#' set.seed(1)
#' fingerprint(vfold_cv(mtcars))
#' .get_fingerprint(vfold_cv(mtcars))
#'
#' set.seed(1)
#' fingerprint(vfold_cv(mtcars))
#' .get_fingerprint(vfold_cv(mtcars))
#'
#' set.seed(2)
#' fingerprint(vfold_cv(mtcars))
#' .get_fingerprint(vfold_cv(mtcars))
#'
#' set.seed(1)
#' fingerprint(vfold_cv(mtcars, repeats = 2))
#' .get_fingerprint(vfold_cv(mtcars, repeats = 2))
#' @export
fingerprint <- function(x, ...) {
# For iterative models, the splits are replicated multiple times. Get the
# unique id values and has those rows
is_id_var <- col_starts_with_id(names(x))
id_vars <- names(x)[is_id_var]
if (length(id_vars) == 0) {
rlang::abort("No ID columns were found.")
}
if (!any(names(x) == "splits")) {
rlang::abort("The 'split' column was not found.")
}
.get_fingerprint <- function(x, ...) {
UseMethod(".get_fingerprint")
}

x <-
dplyr::select(x, splits, dplyr::all_of(id_vars)) %>%
dplyr::distinct() %>%
dplyr::arrange(!!!id_vars) %>%
tibble::as_tibble()
attrib <- attributes(x)
attrib <- attrib[names(attrib) %in% c("row.names", "names", "class")]
attributes(x) <- attrib
rlang::hash(x)
#' @export
#' @rdname get_fingerprint
.get_fingerprint.default <- function(x, ...) {
cls <- paste("'", class(x), "'", sep = ", ")
rlang::abort(
paste("No `.get_fingerprint()` method for this class(es)", cls)
)
}

#' @export
#' @rdname get_fingerprint
.get_fingerprint.rset <- function(x, ...) {
att <- attributes(x)
if (any(names(att) == "fingerprint")) {
res <- att$fingerprint
} else {
res <- NA_character_
}
res
}
6 changes: 6 additions & 0 deletions R/rset.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#' @param attrib An optional named list of attributes to add to the object.
#' @param subclass A character vector of subclasses to add.
#' @return An `rset` object.
#' @details Once the new `rset` is constructed, an additional attribute called
#' "fingerprint" is added that is a hash of the `rset`. This can be used to
#' make sure other objects have the exact same resamples.
#' @keywords internal
#' @export
new_rset <- function(splits, ids, attrib = NULL,
Expand Down Expand Up @@ -71,6 +74,9 @@ new_rset <- function(splits, ids, attrib = NULL,
res <- add_class(res, cls = subclass, at_end = FALSE)
}

fingerprint <- rlang::hash(res)
attr(res, "fingerprint") <- fingerprint

res
}

Expand Down
38 changes: 0 additions & 38 deletions man/fingerprint.Rd

This file was deleted.

41 changes: 41 additions & 0 deletions man/get_fingerprint.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions man/new_rset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions tests/testthat/test_fingerprint.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,17 @@
test_that("fingerprinting", {
set.seed(1)
rs_1 <- vfold_cv(mtcars)
fp_1 <- fingerprint(rs_1)
fp_1 <- .get_fingerprint(rs_1)

set.seed(1)
fp_2 <- fingerprint(vfold_cv(mtcars))
fp_2 <- .get_fingerprint(vfold_cv(mtcars))

set.seed(1)
fp_3 <- fingerprint(vfold_cv(mtcars, repeats = 2))
fp_3 <- .get_fingerprint(vfold_cv(mtcars, repeats = 2))

expect_true(class(fp_1) == "character")
expect_true(class(fp_2) == "character")
expect_true(class(fp_3) == "character")
expect_equal(fp_1, fp_2)
expect_false(fp_1 == fp_3)

expect_error(
fingerprint(vfold_cv(mtcars) %>% dplyr::select(-id)),
"No ID columns were found"
)
expect_error(
fingerprint(vfold_cv(mtcars) %>% dplyr::select(-splits)),
"The 'split' column was not found"
)

# test cases where the rows of the rset are expaned (e.g. in tune_bayes())
set.seed(1)
rs_2 <- vfold_cv(mtcars)
rs_3 <- rs_2[rep(1:10, 3), ]
fp_4 <- fingerprint(rs_3)
expect_equal(fp_1, fp_4)

})
4 changes: 2 additions & 2 deletions tests/testthat/test_rset.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
context("Rset constructor")
context("rset constructor")

library(testthat)
library(rsample)
Expand Down Expand Up @@ -37,7 +37,7 @@ test_that('rset with attributes', {
attrib = args
)
expect_equal(sort(names(attributes(res3))),
c("class", "names", "row.names", "value"))
c("class", "fingerprint", "names", "row.names", "value"))
expect_equal(attr(res3, "value"), "potato")
})

Expand Down

0 comments on commit ae94bbc

Please sign in to comment.