Fix large strings handling in nvtext::character_tokenize #15829

Merged · 15 commits · Jun 10, 2024
6 changes: 4 additions & 2 deletions cpp/benchmarks/text/tokenize.cpp
@@ -39,8 +39,10 @@ static void bench_tokenize(nvbench::state& state)
state.skip("Skip benchmarks greater than size_type limit");
}

-  data_profile const profile = data_profile_builder().distribution(
-    cudf::type_id::STRING, distribution_id::NORMAL, 0, row_width);
+  data_profile const profile =
+    data_profile_builder()
+      .distribution(cudf::type_id::STRING, distribution_id::NORMAL, 0, row_width)
+      .no_validity();
auto const column = create_random_column(cudf::type_id::STRING, row_count{num_rows}, profile);
cudf::strings_column_view input(column->view());

3 changes: 2 additions & 1 deletion cpp/include/nvtext/tokenize.hpp
@@ -176,7 +176,8 @@ std::unique_ptr<cudf::column> count_tokens(
* t is now ["h","e","l","l","o"," ","w","o","r","l","d","g","o","o","d","b","y","e"]
* @endcode
*
- * All null row entries are ignored and the output contains all valid rows.
+ * @throw std::invalid_argument if `input` contains nulls
+ * @throw std::overflow_error if the output would produce more than max size_type rows
*
* @param input Strings column to tokenize
* @param stream CUDA stream used for device memory operations and kernel launches
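The hunk above replaces the old "nulls are ignored" contract with explicit exceptions. A minimal caller-side sketch of the new contract follows; the main() harness is hypothetical and the column wrapper comes from the cudf test utilities, mirroring the test file further below.

#include <cudf_test/column_wrapper.hpp>

#include <cudf/strings/strings_column_view.hpp>
#include <nvtext/tokenize.hpp>

#include <iostream>
#include <stdexcept>

int main()
{
  // a nulls-free column tokenizes into one output row per UTF-8 character
  cudf::test::strings_column_wrapper input({"hello", "world"});
  auto tokens = nvtext::character_tokenize(cudf::strings_column_view(input));
  std::cout << tokens->size() << " tokens\n";  // prints: 10 tokens

  // a column with any null row now throws instead of skipping the null
  cudf::test::strings_column_wrapper with_null({"hello", ""}, {true, false});
  try {
    nvtext::character_tokenize(cudf::strings_column_view(with_null));
  } catch (std::invalid_argument const& e) {
    std::cout << "rejected: " << e.what() << "\n";
  }
  return 0;
}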
66 changes: 51 additions & 15 deletions cpp/src/text/tokenize.cu
@@ -21,6 +21,7 @@
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/algorithm.cuh>
+#include <cudf/detail/utilities/integer_utils.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
@@ -35,6 +36,7 @@
#include <rmm/exec_policy.hpp>
#include <rmm/resource_ref.hpp>

+#include <cuda/atomic>
#include <thrust/copy.h>
#include <thrust/count.h>
#include <thrust/for_each.h>
@@ -99,6 +101,31 @@ std::unique_ptr<cudf::column> tokenize_fn(cudf::size_type strings_count,
return cudf::strings::detail::make_strings_column(tokens.begin(), tokens.end(), stream, mr);
}

+constexpr int64_t block_size       = 512;  // number of threads per block
+constexpr int64_t bytes_per_thread = 4;    // bytes processed per thread
+
+CUDF_KERNEL void count_characters(uint8_t const* d_chars, int64_t chars_bytes, int64_t* d_output)
+{
+  auto const idx      = cudf::detail::grid_1d::global_thread_id();
+  auto const byte_idx = static_cast<int64_t>(idx) * bytes_per_thread;
+  auto const lane_idx = static_cast<cudf::size_type>(threadIdx.x);
+
+  using block_reduce = cub::BlockReduce<int64_t, block_size>;
+  __shared__ typename block_reduce::TempStorage temp_storage;
+
+  int64_t count = 0;
+  // each thread processes multiple bytes
+  for (auto i = byte_idx; (i < (byte_idx + bytes_per_thread)) && (i < chars_bytes); ++i) {
+    count += cudf::strings::detail::is_begin_utf8_char(d_chars[i]);
+  }
+  auto const total = block_reduce(temp_storage).Reduce(count, cub::Sum());
+
+  if ((lane_idx == 0) && (total > 0)) {
+    cuda::atomic_ref<int64_t, cuda::thread_scope_block> ref{*d_output};
+    ref.fetch_add(total, cuda::std::memory_order_relaxed);
+  }
+}

} // namespace

// detail APIs
@@ -176,35 +203,44 @@ std::unique_ptr<cudf::column> character_tokenize(cudf::strings_column_view const
return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING});
}

-  auto offsets     = strings_column.offsets();
-  auto offset      = cudf::strings::detail::get_offset_value(offsets, strings_column.offset(), stream);
-  auto chars_bytes = cudf::strings::detail::get_offset_value(
-                       offsets, strings_column.offset() + strings_count, stream) -
-                     offset;
+  CUDF_EXPECTS(
+    strings_column.null_count() == 0, "input must not contain nulls", std::invalid_argument);
+
+  auto const offsets = strings_column.offsets();
+  auto const offset =
+    cudf::strings::detail::get_offset_value(offsets, strings_column.offset(), stream);
+  auto const chars_bytes = cudf::strings::detail::get_offset_value(
+                             offsets, strings_column.offset() + strings_count, stream) -
+                           offset;
+  // no bytes -- this could happen in an all-empty column
+  if (chars_bytes == 0) { return cudf::make_empty_column(cudf::type_id::STRING); }
auto d_chars =
strings_column.parent().data<uint8_t>(); // unsigned is necessary for checking bits
d_chars += offset;

// To minimize memory, count the number of characters so we can
// build the output offsets without an intermediate buffer.
// In the worst case each byte is a character so the output is 4x the input.
-  cudf::size_type num_characters = thrust::count_if(
-    rmm::exec_policy(stream), d_chars, d_chars + chars_bytes, [] __device__(uint8_t byte) {
-      return cudf::strings::detail::is_begin_utf8_char(byte);
-    });
+  rmm::device_scalar<int64_t> d_count(0, stream);
+  auto const num_blocks = cudf::util::div_rounding_up_safe(
+    cudf::util::div_rounding_up_safe(chars_bytes, static_cast<int64_t>(bytes_per_thread)),
+    block_size);
+  count_characters<<<num_blocks, block_size, 0, stream.value()>>>(
+    d_chars, chars_bytes, d_count.data());
+  auto const num_characters = d_count.value(stream);

// no characters check -- this could happen in all-empty or all-null strings column
if (num_characters == 0) {
return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING});
}
+  // number of characters becomes the number of rows so need to check the row limit
+  CUDF_EXPECTS(
+    num_characters + 1 < static_cast<int64_t>(std::numeric_limits<cudf::size_type>::max()),
+    "output exceeds the column size limit",
+    std::overflow_error);

// create output offsets column
-  // -- conditionally copy a counting iterator where
-  //    the first byte of each character is located
auto offsets_column = cudf::make_numeric_column(
offsets.type(), num_characters + 1, cudf::mask_state::UNALLOCATED, stream, mr);
auto d_new_offsets =
cudf::detail::offsetalator_factory::make_output_iterator(offsets_column->mutable_view());
+  // offsets are at the beginning byte of each character
cudf::detail::copy_if_safe(
thrust::counting_iterator<int64_t>(0),
thrust::counting_iterator<int64_t>(chars_bytes + 1),
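For reference, the counting pattern the new kernel uses (each thread scans a fixed window of bytes, a block-wide cub reduction sums the per-thread counts, one thread per block atomically adds the block total to a global counter) can be exercised as a standalone CUDA program. This is a hedged sketch, not the library code: is_begin_utf8_char is re-implemented locally from its cudf definition (a byte begins a UTF-8 character unless it is a continuation byte), the sketch uses cuda::thread_scope_device for the cross-block atomic, and the host harness is hypothetical.

#include <cub/block/block_reduce.cuh>

#include <cuda/atomic>
#include <cuda_runtime.h>

#include <cstdint>
#include <cstdio>

constexpr int64_t block_size       = 512;
constexpr int64_t bytes_per_thread = 4;

// a byte begins a UTF-8 character unless it is a continuation byte (10xxxxxx)
__device__ bool is_begin_utf8_char(uint8_t byte) { return (byte & 0xC0) != 0x80; }

__global__ void count_characters(uint8_t const* d_chars, int64_t chars_bytes, int64_t* d_output)
{
  auto const idx      = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  auto const byte_idx = idx * bytes_per_thread;

  using block_reduce = cub::BlockReduce<int64_t, block_size>;
  __shared__ typename block_reduce::TempStorage temp_storage;

  int64_t count = 0;
  // each thread counts character-start bytes in its window
  for (auto i = byte_idx; (i < byte_idx + bytes_per_thread) && (i < chars_bytes); ++i) {
    count += is_begin_utf8_char(d_chars[i]);
  }
  auto const total = block_reduce(temp_storage).Reduce(count, cub::Sum());

  if (threadIdx.x == 0 && total > 0) {
    // device scope: multiple blocks accumulate into the same counter
    cuda::atomic_ref<int64_t, cuda::thread_scope_device> ref{*d_output};
    ref.fetch_add(total, cuda::std::memory_order_relaxed);
  }
}

int main()
{
  char const h_text[]       = "the mousé ate the cheese";  // é is a two-byte character
  int64_t const chars_bytes = sizeof(h_text) - 1;          // 25 bytes, 24 characters

  uint8_t* d_chars{};
  int64_t* d_count{};
  cudaMalloc(&d_chars, chars_bytes);
  cudaMalloc(&d_count, sizeof(int64_t));
  cudaMemcpy(d_chars, h_text, chars_bytes, cudaMemcpyHostToDevice);
  cudaMemset(d_count, 0, sizeof(int64_t));

  // ceil(ceil(bytes / bytes_per_thread) / block_size), as in the library code
  auto const num_threads = (chars_bytes + bytes_per_thread - 1) / bytes_per_thread;
  auto const num_blocks  = (num_threads + block_size - 1) / block_size;
  count_characters<<<static_cast<unsigned>(num_blocks), static_cast<unsigned>(block_size)>>>(
    d_chars, chars_bytes, d_count);

  int64_t h_count = 0;
  cudaMemcpy(&h_count, d_count, sizeof(int64_t), cudaMemcpyDeviceToHost);
  std::printf("%lld characters\n", static_cast<long long>(h_count));  // 24 characters

  cudaFree(d_chars);
  cudaFree(d_count);
  return 0;
}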
10 changes: 2 additions & 8 deletions cpp/tests/text/tokenize_tests.cpp
@@ -111,17 +111,13 @@ TEST_F(TextTokenizeTest, TokenizeErrorTest)

TEST_F(TextTokenizeTest, CharacterTokenize)
{
-  std::vector<char const*> h_strings{"the mousé ate the cheese", nullptr, ""};
-  cudf::test::strings_column_wrapper strings(
-    h_strings.begin(),
-    h_strings.end(),
-    thrust::make_transform_iterator(h_strings.begin(), [](auto str) { return str != nullptr; }));
+  cudf::test::strings_column_wrapper input({"the mousé ate the cheese", ""});

cudf::test::strings_column_wrapper expected{"t", "h", "e", " ", "m", "o", "u", "s",
"é", " ", "a", "t", "e", " ", "t", "h",
"e", " ", "c", "h", "e", "e", "s", "e"};

-  auto results = nvtext::character_tokenize(cudf::strings_column_view(strings));
+  auto results = nvtext::character_tokenize(cudf::strings_column_view(input));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
}

@@ -151,8 +147,6 @@ TEST_F(TextTokenizeTest, TokenizeEmptyTest)
EXPECT_EQ(results->size(), 0);
results = nvtext::character_tokenize(all_empty);
EXPECT_EQ(results->size(), 0);
-  results = nvtext::character_tokenize(all_null);
-  EXPECT_EQ(results->size(), 0);
auto const delimiter = cudf::string_scalar{""};
results = nvtext::tokenize_with_vocabulary(view, all_empty, delimiter);
EXPECT_EQ(results->size(), 0);
13 changes: 7 additions & 6 deletions python/cudf/cudf/core/column/string.py
@@ -552,16 +552,17 @@ def join(
return self._return_or_inplace(data)

def _split_by_character(self):
-        result_col = libstrings.character_tokenize(self._column)
+        col = self._column.fillna("")  # sanitize nulls
+        result_col = libstrings.character_tokenize(col)

-        offset_col = self._column.children[0]
+        offset_col = col.children[0]

return cudf.core.column.ListColumn(
-            size=len(self._column),
-            dtype=cudf.ListDtype(self._column.dtype),
-            mask=self._column.mask,
+            size=len(col),
+            dtype=cudf.ListDtype(col.dtype),
+            mask=col.mask,
offset=0,
-            null_count=self._column.null_count,
+            null_count=0,
children=(offset_col, result_col),
)

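The Python layer above now sanitizes nulls with fillna("") before tokenizing. The same sanitization can be expressed on the C++ side for callers that cannot guarantee a nulls-free column; a sketch under the assumption that replacing null rows with empty strings is acceptable for the caller (the function name is illustrative):

#include <cudf/column/column.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/replace.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <nvtext/tokenize.hpp>

#include <memory>

std::unique_ptr<cudf::column> character_tokenize_sanitized(cudf::column_view const& strings)
{
  // replace null rows with empty strings so character_tokenize does not throw
  auto const empty     = cudf::string_scalar("");
  auto const sanitized = cudf::replace_nulls(strings, empty);
  return nvtext::character_tokenize(cudf::strings_column_view(sanitized->view()));
}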
2 changes: 0 additions & 2 deletions python/cudf/cudf/tests/text/test_text_methods.py
@@ -426,7 +426,6 @@ def test_character_tokenize_series():
[
"hello world",
"sdf",
-            None,
(
"goodbye, one-two:three~four+five_six@sev"
"en#eight^nine heŒŽ‘•™œ$µ¾ŤƠé DŽ"
@@ -543,7 +542,6 @@ def test_character_tokenize_index():
[
"hello world",
"sdf",
-            None,
(
"goodbye, one-two:three~four+five_six@sev"
"en#eight^nine heŒŽ‘•™œ$µ¾ŤƠé DŽ"