diff --git a/cpp/include/cudf/utilities/traits.hpp b/cpp/include/cudf/utilities/traits.hpp index f4e7e3e2a6d..40a833112e1 100644 --- a/cpp/include/cudf/utilities/traits.hpp +++ b/cpp/include/cudf/utilities/traits.hpp @@ -142,6 +142,31 @@ constexpr inline bool is_equality_comparable() return detail::is_equality_comparable_impl::value; } +namespace detail { +/** + * @brief Helper functor to check if a specified type `T` supports equality comparisons. + */ +struct unary_equality_comparable_functor { + template + bool operator()() const + { + return cudf::is_equality_comparable(); + } +}; +} // namespace detail + +/** + * @brief Checks whether `data_type` `type` supports equality comparisons. + * + * @param type Data_type for comparison. + * @return true If `type` supports equality comparisons. + * @return false If `type` does not support equality comparisons. + */ +inline bool is_equality_comparable(data_type type) +{ + return cudf::type_dispatcher(type, detail::unary_equality_comparable_functor{}); +} + /** * @brief Indicates whether the type `T` is a numeric type. * diff --git a/cpp/src/groupby/common/utils.hpp b/cpp/src/groupby/common/utils.hpp index 3da20fb9af3..2804dea576e 100644 --- a/cpp/src/groupby/common/utils.hpp +++ b/cpp/src/groupby/common/utils.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-20, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/groupby/groupby.cu b/cpp/src/groupby/groupby.cu index 533f193d692..bdaccba38dc 100644 --- a/cpp/src/groupby/groupby.cu +++ b/cpp/src/groupby/groupby.cu @@ -27,10 +27,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -76,6 +78,9 @@ std::pair, std::vector> groupby::disp // Optionally flatten nested key columns. auto [flattened_keys, _, __, ___] = flatten_nested_columns(_keys, {}, {}, column_nullability::FORCE); + auto is_supported_key_type = [](auto col) { return cudf::is_equality_comparable(col.type()); }; + CUDF_EXPECTS(std::all_of(flattened_keys.begin(), flattened_keys.end(), is_supported_key_type), + "Unsupported groupby key type does not support equality comparison"); auto [grouped_keys, results] = detail::hash::groupby(flattened_keys, requests, _include_null_keys, stream, mr); return std::make_pair(unflatten_nested_columns(std::move(grouped_keys), _keys), diff --git a/cpp/src/groupby/sort/sort_helper.cu b/cpp/src/groupby/sort/sort_helper.cu index 69d68f7b6bc..c4905b86ab9 100644 --- a/cpp/src/groupby/sort/sort_helper.cu +++ b/cpp/src/groupby/sort/sort_helper.cu @@ -23,8 +23,10 @@ #include #include #include +#include #include #include +#include #include #include @@ -102,6 +104,9 @@ sort_groupby_helper::sort_groupby_helper(table_view const& keys, auto [flattened_keys, _, __, struct_null_vectors] = flatten_nested_columns(keys, {}, {}, column_nullability::FORCE); + auto is_supported_key_type = [](auto col) { return cudf::is_equality_comparable(col.type()); }; + CUDF_EXPECTS(std::all_of(flattened_keys.begin(), flattened_keys.end(), is_supported_key_type), + "Unsupported groupby key type does not support equality comparison"); _struct_null_vectors = std::move(struct_null_vectors); _keys = flattened_keys; diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index d9553d463ab..03f7967cee0 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -62,6 +62,7 @@ ConfigureTest(GROUPBY_TEST groupby/count_tests.cpp groupby/groups_tests.cpp groupby/keys_tests.cpp + groupby/lists_tests.cpp groupby/m2_tests.cpp groupby/min_tests.cpp groupby/max_scan_tests.cpp diff --git a/cpp/tests/groupby/lists_tests.cpp b/cpp/tests/groupby/lists_tests.cpp new file mode 100644 index 00000000000..11b8ffa92b9 --- /dev/null +++ b/cpp/tests/groupby/lists_tests.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include + +#include + +namespace cudf { +namespace test { + +template +struct groupby_lists_test : public cudf::test::BaseFixture { +}; + +TYPED_TEST_SUITE(groupby_lists_test, cudf::test::FixedWidthTypes); + +namespace { +// Checking with a single aggregation, and aggregation column. +// This test is orthogonal to the aggregation type; it focuses on testing the grouping +// with LISTS keys. +auto sum_agg() { return cudf::make_sum_aggregation(); } + +void test_sort_based_sum_agg(column_view const& keys, column_view const& values) +{ + test_single_agg( + keys, values, keys, values, sum_agg(), force_use_sort_impl::YES, null_policy::INCLUDE); +} + +void test_hash_based_sum_agg(column_view const& keys, column_view const& values) +{ + test_single_agg( + keys, values, keys, values, sum_agg(), force_use_sort_impl::NO, null_policy::INCLUDE); +} + +} // namespace + +TYPED_TEST(groupby_lists_test, top_level_lists_are_unsupported) +{ + // Test that grouping on LISTS columns fails visibly. + + // clang-format off + auto keys = lists_column_wrapper { {1,1}, {2,2}, {3,3}, {1,1}, {2,2} }; + auto values = fixed_width_column_wrapper { 0, 1, 2, 3, 4 }; + // clang-format on + + EXPECT_THROW(test_sort_based_sum_agg(keys, values), cudf::logic_error); + EXPECT_THROW(test_hash_based_sum_agg(keys, values), cudf::logic_error); +} + +} // namespace test +} // namespace cudf diff --git a/cpp/tests/groupby/structs_tests.cpp b/cpp/tests/groupby/structs_tests.cpp index 00126a4a5a0..3715ba8d17b 100644 --- a/cpp/tests/groupby/structs_tests.cpp +++ b/cpp/tests/groupby/structs_tests.cpp @@ -22,8 +22,6 @@ #include #include -#include "cudf/aggregation.hpp" -#include "cudf/types.hpp" using namespace cudf::test::iterators; @@ -34,7 +32,7 @@ template struct groupby_structs_test : public cudf::test::BaseFixture { }; -TYPED_TEST_CASE(groupby_structs_test, cudf::test::FixedWidthTypes); +TYPED_TEST_SUITE(groupby_structs_test, cudf::test::FixedWidthTypes); using V = int32_t; // Type of Aggregation Column. using M0 = int32_t; // Type of STRUCT's first (i.e. 0th) member. @@ -79,27 +77,43 @@ void print_agg_results(column_view const& keys, column_view const& vals) } } -void test_sum_agg(column_view const& keys, - column_view const& values, - column_view const& expected_keys, - column_view const& expected_values) +void test_sort_based_sum_agg(column_view const& keys, + column_view const& values, + column_view const& expected_keys, + column_view const& expected_values) { test_single_agg(keys, values, expected_keys, expected_values, sum_agg(), - force_use_sort_impl::NO, + force_use_sort_impl::YES, null_policy::INCLUDE); +} + +void test_hash_based_sum_agg(column_view const& keys, + column_view const& values, + column_view const& expected_keys, + column_view const& expected_values) +{ test_single_agg(keys, values, expected_keys, expected_values, sum_agg(), - force_use_sort_impl::YES, + force_use_sort_impl::NO, null_policy::INCLUDE); } +void test_sum_agg(column_view const& keys, + column_view const& values, + column_view const& expected_keys, + column_view const& expected_values) +{ + test_sort_based_sum_agg(keys, values, expected_keys, expected_values); + test_hash_based_sum_agg(keys, values, expected_keys, expected_values); +} + } // namespace TYPED_TEST(groupby_structs_test, basic) @@ -312,7 +326,8 @@ TYPED_TEST(groupby_structs_test, lists_are_unsupported) // clang-format on auto keys = structs{{member_0, member_1}}; - EXPECT_THROW(test_sum_agg(keys, values, keys, values), cudf::logic_error); + EXPECT_THROW(test_sort_based_sum_agg(keys, values, keys, values), cudf::logic_error); + EXPECT_THROW(test_hash_based_sum_agg(keys, values, keys, values), cudf::logic_error); } } // namespace test