diff --git a/cpp/include/cudf/strings/string_view.cuh b/cpp/include/cudf/strings/string_view.cuh index a7559c7fbcb..238d55d580e 100644 --- a/cpp/include/cudf/strings/string_view.cuh +++ b/cpp/include/cudf/strings/string_view.cuh @@ -274,7 +274,8 @@ __device__ inline int string_view::compare(const char* data, size_type bytes) co size_type const len1 = size_bytes(); const unsigned char* ptr1 = reinterpret_cast(this->data()); const unsigned char* ptr2 = reinterpret_cast(data); - size_type idx = 0; + if ((ptr1 == ptr2) && (bytes == len1)) return 0; + size_type idx = 0; for (; (idx < len1) && (idx < bytes); ++idx) { if (*ptr1 != *ptr2) return static_cast(*ptr1) - static_cast(*ptr2); ++ptr1; diff --git a/cpp/src/reductions/scan/scan_inclusive.cu b/cpp/src/reductions/scan/scan_inclusive.cu index f729f812b28..1beb9ecb282 100644 --- a/cpp/src/reductions/scan/scan_inclusive.cu +++ b/cpp/src/reductions/scan/scan_inclusive.cu @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -67,7 +68,46 @@ rmm::device_buffer mask_scan(const column_view& input_view, namespace { /** - * @brief Dispatcher for running Scan operation on input column + * @brief Strings inclusive scan operator + * + * This was specifically created to work around a thrust issue + * https://github.com/NVIDIA/thrust/issues/1479 + * where invalid values are passed to the operator. + * + * This operator will accept index values, check them and then + * run the `Op` operation on the individual string_view objects. + * The returned result is the index (`lhs` or `rhs`) of the element `Op` selected; `lhs` on ties. 
+ */ +template +struct string_scan_operator { + column_device_view const col; ///< strings column device view + string_view const null_replacement{}; ///< value used when element is null + bool const has_nulls; ///< true if col has null elements + + string_scan_operator(column_device_view const& col, bool has_nulls = true) + : col{col}, null_replacement{Op::template identity()}, has_nulls{has_nulls} + { + CUDF_EXPECTS(type_id::STRING == col.type().id(), "the data type mismatch"); + // verify the validity bitmask is non-null; otherwise is_null_nocheck() will crash + if (has_nulls) CUDF_EXPECTS(col.nullable(), "column with nulls must have a validity bitmask"); + } + + CUDA_DEVICE_CALLABLE + size_type operator()(size_type lhs, size_type rhs) const + { + // thrust::inclusive_scan may pass us garbage values so we need to protect ourselves; + // in these cases the return value does not matter since the result is not used + if (lhs < 0 || rhs < 0 || lhs >= col.size() || rhs >= col.size()) return 0; + string_view d_lhs = + has_nulls && col.is_null_nocheck(lhs) ? null_replacement : col.element(lhs); + string_view d_rhs = + has_nulls && col.is_null_nocheck(rhs) ? null_replacement : col.element(rhs); + return Op{}(d_lhs, d_rhs) == d_lhs ? 
lhs : rhs; + } +}; + +/** + * @brief Dispatcher for running a Scan operation on an input column * * @tparam Op device binary operator */ @@ -117,22 +157,25 @@ struct scan_dispatcher { { auto d_input = column_device_view::create(input_view, stream); - rmm::device_uvector result(input_view.size(), stream); - auto begin = - make_null_replacement_iterator(*d_input, Op::template identity(), input_view.has_nulls()); - thrust::inclusive_scan( - rmm::exec_policy(stream), begin, begin + input_view.size(), result.data(), Op{}); - - CHECK_CUDA(stream.value()); - return cudf::make_strings_column(result, Op::template identity(), stream, mr); + // scan over the row indices; each result is the index of the string the operator selected + rmm::device_uvector result(input_view.size(), stream); + thrust::inclusive_scan(rmm::exec_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(input_view.size()), + result.begin(), + string_scan_operator{*d_input, input_view.has_nulls()}); + + // call gather using the indices to build the output column + return cudf::strings::detail::gather( + strings_column_view(input_view), result.begin(), result.end(), false, stream, mr); } public: /** - * @brief creates new column from input column by applying scan operation + * @brief Creates a new column from the input column by applying the scan operation * - * @param input input column view - * @param inclusive inclusive or exclusive scan + * @param input Input column view + * @param null_handling How null row entries are to be processed * @param stream CUDA stream used for device memory operations and kernel launches. 
* @param mr Device memory resource used to allocate the returned column's device memory * @return diff --git a/cpp/tests/reductions/scan_tests.cpp b/cpp/tests/reductions/scan_tests.cpp index 92ba1f9e60f..ef5a66a2019 100644 --- a/cpp/tests/reductions/scan_tests.cpp +++ b/cpp/tests/reductions/scan_tests.cpp @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -394,3 +395,38 @@ TYPED_TEST(ScanTest, LeadingNulls) this->scan_test(v, b, cudf::make_min_aggregation(), scan_type::INCLUSIVE, null_policy::INCLUDE); this->scan_test(v, b, cudf::make_min_aggregation(), scan_type::EXCLUSIVE, null_policy::INCLUDE); } + +class ScanStringsTest : public ScanTest { +}; + +TEST_F(ScanStringsTest, MoreStringsMinMax) +{ + int row_count = 512; + + auto data_begin = cudf::detail::make_counting_transform_iterator(0, [](auto idx) { + char const s[] = {static_cast('a' + (idx % 26)), 0}; + return std::string(s); + }); + auto validity = cudf::detail::make_counting_transform_iterator( + 0, [](auto idx) -> bool { return (idx % 23) != 22; }); + cudf::test::strings_column_wrapper col(data_begin, data_begin + row_count, validity); + + thrust::host_vector v(data_begin, data_begin + row_count); + thrust::host_vector b(validity, validity + row_count); + + this->scan_test(v, {}, cudf::make_min_aggregation(), scan_type::INCLUSIVE); + this->scan_test(v, b, cudf::make_min_aggregation(), scan_type::INCLUSIVE); + this->scan_test(v, b, cudf::make_min_aggregation(), scan_type::INCLUSIVE, null_policy::EXCLUDE); + + this->scan_test(v, {}, cudf::make_min_aggregation(), scan_type::EXCLUSIVE); + this->scan_test(v, b, cudf::make_min_aggregation(), scan_type::EXCLUSIVE); + this->scan_test(v, b, cudf::make_min_aggregation(), scan_type::EXCLUSIVE, null_policy::EXCLUDE); + + this->scan_test(v, {}, cudf::make_max_aggregation(), scan_type::INCLUSIVE); + this->scan_test(v, b, cudf::make_max_aggregation(), scan_type::INCLUSIVE); + this->scan_test(v, b, cudf::make_max_aggregation(), 
scan_type::INCLUSIVE, null_policy::EXCLUDE); + + this->scan_test(v, {}, cudf::make_max_aggregation(), scan_type::EXCLUSIVE); + this->scan_test(v, b, cudf::make_max_aggregation(), scan_type::EXCLUSIVE); + this->scan_test(v, b, cudf::make_max_aggregation(), scan_type::EXCLUSIVE, null_policy::EXCLUDE); +}