From b6d4523dc1291bb3f8238d7d85a464cf00a12782 Mon Sep 17 00:00:00 2001 From: wlandau Date: Thu, 12 Sep 2019 22:32:31 -0400 Subject: [PATCH] Fix #1008 --- NEWS.md | 1 + R/transform_plan.R | 18 ++++- R/utils.R | 4 + tests/testthat/test-dsl.R | 159 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index cb25db69d..251d45999 100644 --- a/NEWS.md +++ b/NEWS.md @@ -23,6 +23,7 @@ - Remove the now-superfluous vignette. - Wrap up console and text file logging functionality into a reference class (#964). - Deprecate the `verbose` argument in various caching functions. The location of the cache is now only printed in `make()`. This made the previous feature easier to implement. +- Carry forward nested grouping variables in `combine()` (#1008). # Version 7.6.1 diff --git a/R/transform_plan.R b/R/transform_plan.R index e766e8c2f..d790a416f 100644 --- a/R/transform_plan.R +++ b/R/transform_plan.R @@ -633,7 +633,7 @@ map_by <- function(.x, .by, .f, ...) { out } ) - do.call(what = rbind, args = out) + do.call(what = drake_bind_rows, args = out) } split_by <- function(.x, .by = character(0)) { @@ -735,6 +735,10 @@ combine_step <- function(plan, row, transform, old_cols) { out[[col]] <- row[[col]] } } + groupings <- combine_tagalongs(plan, transform, old_cols) + if (nrow(groupings) == 1L) { + out <- cbind(out, groupings) + } out } @@ -794,6 +798,18 @@ splice_inner <- function(x, replacements) { } } +combine_tagalongs <- function(plan, transform, old_cols) { + combined_plan <- plan[, dsl_combine(transform), drop = FALSE] + out <- plan[complete.cases(combined_plan),, drop = FALSE] # nolint + drop <- c(old_cols, dsl_combine(transform), dsl_by(transform)) + keep <- setdiff(colnames(out), drop) + out <- out[, keep, drop = FALSE] + keep <- !vapply(out, anyNA, FUN.VALUE = logical(1)) + out <- out[, keep, drop = FALSE] + keep <- vapply(out, num_unique, FUN.VALUE = integer(1)) == 1L + utils::head(out[, keep, drop = FALSE], n = 1) +} + dsl_deps <- function(transform) UseMethod("dsl_deps") dsl_deps.map <- function(transform) { diff --git a/R/utils.R b/R/utils.R index 27b45fb77..cb49f7ab0 100644 --- a/R/utils.R +++ b/R/utils.R @@ -132,6 +132,10 @@ longest_match <- function(choices, against) { matches[which.max(nchar(matches))] } +num_unique <- function(x) { + length(unique(x)) +} + random_string <- function(exclude) { key <- NULL while (is.null(key) || (key %in% exclude)) { diff --git a/tests/testthat/test-dsl.R b/tests/testthat/test-dsl.R index 828597925..a12c1934a 100644 --- a/tests/testthat/test-dsl.R +++ b/tests/testthat/test-dsl.R @@ -1302,10 +1302,14 @@ test_with_dir("trace has correct provenance", { ), i = target( command = list(e_c_b_a_1_3, e_c_b_a_1_3_2), + x = "1", + y = "3", i = "i" ), j = target( command = list(f_c_b_a_1_3, f_c_b_a_1_3_2), + x = "1", + y = "3", j = "j" ) ) @@ -2858,3 +2862,158 @@ test_with_dir("eliminate partial tagalong grouping vars (#1009)", { ) equivalent_plans(out, exp) }) + +test_with_dir("keep nested grouping vars in combine() (#1008)", { + out <- drake_plan( + i = target(p, transform = map(p = !!(1:2))), + a = target(x * i, transform = cross(i, x = !!(1:2))), + b = target(a * y, transform = cross(a, y = !!(1:2), .id = c(p, x))), + d = target(c(b), transform = combine(b, .by = c(a))), + trace = TRUE + ) + exp <- drake_plan( + i_1L = target( + command = 1L, + p = "1L", + i = "i_1L" + ), + i_2L = target( + command = 2L, + p = "2L", + i = "i_2L" + ), + a_1L_i_1L = target( + command = 1L * i_1L, + p = "1L", + i = "i_1L", + x = "1L", + a = "a_1L_i_1L" + ), + a_2L_i_1L = target( + command = 2L * i_1L, + p = "1L", + i = "i_1L", + x = "2L", + a = "a_2L_i_1L" + ), + a_1L_i_2L = target( + command = 1L * i_2L, + p = "2L", + i = "i_2L", + x = "1L", + a = "a_1L_i_2L" + ), + a_2L_i_2L = target( + command = 2L * i_2L, + p = "2L", + i = "i_2L", + x = "2L", + a = "a_2L_i_2L" + ), + b_1L_1L = target( + command = a_1L_i_1L * 1L, + p = "1L", + i = "i_1L", + x = "1L", + a = "a_1L_i_1L", + y = "1L", + b = "b_1L_1L" + ), + b_1L_1L_2 = target( + command = a_1L_i_1L * 2L, + p = "1L", + i = "i_1L", + x = "1L", + a = "a_1L_i_1L", + y = "2L", + b = "b_1L_1L_2" + ), + b_1L_2L = target( + command = a_2L_i_1L * 1L, + p = "1L", + i = "i_1L", + x = "2L", + a = "a_2L_i_1L", + y = "1L", + b = "b_1L_2L" + ), + b_1L_2L_2 = target( + command = a_2L_i_1L * 2L, + p = "1L", + i = "i_1L", + x = "2L", + a = "a_2L_i_1L", + y = "2L", + b = "b_1L_2L_2" + ), + b_2L_1L = target( + command = a_1L_i_2L * 1L, + p = "2L", + i = "i_2L", + x = "1L", + a = "a_1L_i_2L", + y = "1L", + b = "b_2L_1L" + ), + b_2L_1L_2 = target( + command = a_1L_i_2L * 2L, + p = "2L", + i = "i_2L", + x = "1L", + a = "a_1L_i_2L", + y = "2L", + b = "b_2L_1L_2" + ), + b_2L_2L = target( + command = a_2L_i_2L * 1L, + p = "2L", + i = "i_2L", + x = "2L", + a = "a_2L_i_2L", + y = "1L", + b = "b_2L_2L" + ), + b_2L_2L_2 = target( + command = a_2L_i_2L * 2L, + p = "2L", + i = "i_2L", + x = "2L", + a = "a_2L_i_2L", + y = "2L", + b = "b_2L_2L_2" + ), + d_a_1L_i_1L = target( + command = c(b_1L_1L, b_1L_1L_2), + p = "1L", + i = "i_1L", + x = "1L", + a = "a_1L_i_1L", + d = "d_a_1L_i_1L" + ), + d_a_1L_i_2L = target( + command = c(b_2L_1L, b_2L_1L_2), + p = "2L", + i = "i_2L", + x = "1L", + a = "a_1L_i_2L", + d = "d_a_1L_i_2L" + ), + d_a_2L_i_1L = target( + command = c(b_1L_2L, b_1L_2L_2), + p = "1L", + i = "i_1L", + x = "2L", + a = "a_2L_i_1L", + d = "d_a_2L_i_1L" + ), + d_a_2L_i_2L = target( + command = c(b_2L_2L, b_2L_2L_2), + p = "2L", + i = "i_2L", + x = "2L", + a = "a_2L_i_2L", + d = "d_a_2L_i_2L" + ) + ) + equivalent_plans(out, exp) +})