Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement n_distinct() for multiple arguments using duckdb structs #122

Merged
merged 7 commits into from
Apr 29, 2024
24 changes: 22 additions & 2 deletions R/backend-dbplyr__duckdb_connection.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ duckdb_grepl <- function(pattern, x, ignore.case = FALSE, perl = FALSE, fixed =
}
}

duckdb_n_distinct <- function(..., na.rm = FALSE) {
sql <- pkg_method("sql", "dbplyr")

if (!identical(na.rm, FALSE)) {
stop("Parameter `na.rm = TRUE` in n_distinct() is currently not supported in DuckDB backend.", call. = FALSE)
}

# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs-with-the-row-function
str_struct <- paste0("row(", paste0(list(...), collapse = ", "), ")")

sql(paste0("COUNT(DISTINCT ", str_struct, ")"))
Comment on lines +88 to +90
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgirlich: Can you please help me review this part of the code?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically looks good to me. This also works for diverse user input, e.g. n_distinct(x, a / (b - 1)) or n_distinct(!!x).
One minor thing: you might want to check for empty dots.

}

# Customized translation functions for DuckDB SQL
# @param con A \code{\link{dbConnect}} object, as returned by \code{dbConnect()}
Expand Down Expand Up @@ -316,7 +328,8 @@ sql_translation.duckdb_connection <- function(con) {
any = sql_aggregate("BOOL_OR", "any"),
str_flatten = function(x, collapse) sql_expr(STRING_AGG(!!x, !!collapse)),
first = sql_prefix("FIRST", 1),
last = sql_prefix("LAST", 1)
last = sql_prefix("LAST", 1),
n_distinct = duckdb_n_distinct
),
sql_translator(
.parent = base_win,
Expand All @@ -333,7 +346,14 @@ sql_translation.duckdb_connection <- function(con) {
partition = win_current_group(),
order = win_current_order()
)
}
},
n_distinct =
function(..., na.rm = FALSE) {
win_over(
duckdb_n_distinct(..., na.rm = na.rm),
partition = win_current_group()
)
}
)
)
}
Expand Down
49 changes: 49 additions & 0 deletions tests/testthat/test-backend-dbplyr__duckdb_connection.R
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ test_that("aggregators translated correctly", {

expect_equal(translate(str_flatten(x, ","), window = FALSE), sql(r"{STRING_AGG(x, ',')}"))
expect_equal(translate(str_flatten(x, ","), window = TRUE), sql(r"{STRING_AGG(x, ',') OVER ()}"))

expect_equal(translate(n_distinct(x), window = FALSE), sql(r"{COUNT(DISTINCT row(x))}"))
expect_equal(translate(n_distinct(x), window = TRUE), sql(r"{COUNT(DISTINCT row(x)) OVER ()}"))
})

test_that("two variable aggregates are translated correctly", {
Expand All @@ -218,8 +221,54 @@ test_that("two variable aggregates are translated correctly", {

expect_equal(translate(cor(x, y), window = FALSE), sql(r"{CORR(x, y)}"))
expect_equal(translate(cor(x, y), window = TRUE), sql(r"{CORR(x, y) OVER ()}"))

expect_equal(translate(n_distinct(x, y), window = FALSE), sql(r"{COUNT(DISTINCT row(x, y))}"))
expect_equal(translate(n_distinct(x, y), window = TRUE), sql(r"{COUNT(DISTINCT row(x, y)) OVER ()}"))
})

test_that("n_distinct() computations are correct", {
skip_if_no_R4()
skip_if_not_installed("dplyr")
skip_if_not_installed("dbplyr")
con <- dbConnect(duckdb())
on.exit(dbDisconnect(con, shutdown = TRUE))
tbl <- dplyr::tbl
summarize <- dplyr::summarize
pull <- dplyr::pull

duckdb_register(con, "df", data.frame(x = c(1, 1, 2, 2), y = c(1, 2, 2, 2)))
duckdb_register(con, "df_na", data.frame(x = c(1, 1, 2, NA, NA), y = c(1, 2, NA, 2, NA)))

df <- tbl(con, "df")
df_na <- tbl(con, "df_na")

expect_error(
pull(summarize(df, n = n_distinct(x, na.rm = TRUE)), n)
)

# single column is working as usual
expect_equal(
pull(summarize(df, n = n_distinct(x)), n),
2
)

expect_equal(
pull(summarize(df_na, n = n_distinct(x)), n),
3
)

# two columns return correct results
expect_equal(
pull(summarize(df, n = n_distinct(x, y)), n),
3
)

# two columns containing NAs return correct results
expect_equal(
pull(summarize(df_na, n = n_distinct(x, y)), n),
5
)
})

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might want to add a test that uses a window.




Expand Down
Loading