Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvement for cudf::strings::like #13594

Merged
merged 3 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 39 additions & 20 deletions cpp/benchmarks/string/like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,41 @@

#include <cudf/copying.hpp>
#include <cudf/filling.hpp>
#include <cudf/strings/combine.hpp>
#include <cudf/strings/contains.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <nvbench/nvbench.cuh>

namespace {
std::unique_ptr<cudf::column> build_input_column(cudf::size_type n_rows, int32_t hit_rate)
std::unique_ptr<cudf::column> build_input_column(cudf::size_type n_rows,
cudf::size_type row_width,
int32_t hit_rate)
{
// build input table using the following data
auto data = cudf::test::strings_column_wrapper({
"123 abc 4567890 DEFGHI 0987 5W43", // matches always;
"012345 6789 01234 56789 0123 456", // the rest do not match
"abc 4567890 DEFGHI 0987 Wxyz 123",
"abcdefghijklmnopqrstuvwxyz 01234",
"",
"AbcéDEFGHIJKLMNOPQRSTUVWXYZ 01",
"9876543210,abcdefghijklmnopqrstU",
"9876543210,abcdefghijklmnopqrstU",
"123 édf 4567890 DéFG 0987 X5",
"1",
});
auto data_view = cudf::column_view(data);
auto raw_data = cudf::test::strings_column_wrapper(
{
"123 abc 4567890 DEFGHI 0987 5W43", // matches always;
"012345 6789 01234 56789 0123 456", // the rest do not match
"abc 4567890 DEFGHI 0987 Wxyz 123",
"abcdefghijklmnopqrstuvwxyz 01234",
"",
"AbcéDEFGHIJKLMNOPQRSTUVWXYZ 01",
"9876543210,abcdefghijklmnopqrstU",
"9876543210,abcdefghijklmnopqrstU",
"123 édf 4567890 DéFG 0987 X5",
"1",
})
.release();
if (row_width / 32 > 1) {
std::vector<cudf::column_view> columns;
for (int i = 0; i < row_width / 32; ++i) {
columns.push_back(raw_data->view());
}
raw_data = cudf::strings::concatenate(cudf::table_view(columns));
}
auto data_view = raw_data->view();

// compute number of rows in n_rows that should match
auto matches = static_cast<int32_t>(n_rows * hit_rate) / 100;
Expand Down Expand Up @@ -71,14 +83,20 @@ std::unique_ptr<cudf::column> build_input_column(cudf::size_type n_rows, int32_t

static void bench_like(nvbench::state& state)
{
auto const n_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const hit_rate = static_cast<int32_t>(state.get_int64("hit_rate"));
auto const n_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const hit_rate = static_cast<int32_t>(state.get_int64("hit_rate"));

auto col = build_input_column(n_rows, hit_rate);
if (static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(row_width) >=
static_cast<std::size_t>(std::numeric_limits<cudf::size_type>::max())) {
state.skip("Skip benchmarks greater than size_type limit");
}

auto col = build_input_column(n_rows, row_width, hit_rate);
auto input = cudf::strings_column_view(col->view());

// This pattern forces reading the entire target string (when matched expected)
auto pattern = std::string("% 5W4_"); // regex equivalent: ".* 5W4."
auto pattern = std::string("% 5W4_"); // regex equivalent: ".* 5W4.$"

state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value()));
// gather some throughput statistics as well
Expand All @@ -93,5 +111,6 @@ static void bench_like(nvbench::state& state)

NVBENCH_BENCH(bench_like)
.set_name("strings_like")
.add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216})
.add_int64_axis("hit_rate", {1, 5, 10, 25, 70, 100});
.add_int64_axis("row_width", {32, 64, 128, 256, 512})
.add_int64_axis("num_rows", {32768, 262144, 2097152, 16777216})
.add_int64_axis("hit_rate", {10, 25, 70, 100});
27 changes: 18 additions & 9 deletions cpp/src/strings/like.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ struct like_fn {
auto const d_str = d_strings.element<string_view>(idx);
auto const d_pattern = patterns_itr[idx];

// using only iterators to better handle UTF-8 characters
auto target_itr = d_str.begin();
// incrementing by bytes instead of character improves performance 10-20%
auto target_itr = d_str.data();
auto pattern_itr = d_pattern.begin();

auto const target_end = d_str.end();
auto const target_end = target_itr + d_str.size_bytes();
auto const pattern_end = d_pattern.end();
auto const esc_char = d_escape.empty() ? 0 : d_escape[0];

Expand All @@ -75,12 +75,20 @@ struct like_fn {
escaped && (pattern_itr + 1 < pattern_end) ? *(++pattern_itr) : *pattern_itr;

if (escaped || (pattern_char != multi_wildcard)) {
size_type char_width = 0;
// check match with the current character
result = ((target_itr != target_end) && ((!escaped && pattern_char == single_wildcard) ||
(pattern_char == *target_itr)));
result = (target_itr != target_end);
if (result) {
if (escaped || pattern_char != single_wildcard) {
char_utf8 target_char = 0;
// retrieve the target character to compare with the current pattern_char
char_width = to_char_utf8(target_itr, target_char);
result = (pattern_char == target_char);
}
}
if (!result) { break; }
++target_itr;
++pattern_itr;
target_itr += char_width ? char_width : bytes_in_utf8_byte(*target_itr);
} else {
// process wildcard '%'
result = true;
Expand All @@ -92,8 +100,8 @@ struct like_fn {
// save positions
last_pattern_itr = pattern_itr;
last_target_itr = target_itr;
}
} // next pattern character
} // next pattern character
}

if (result && (target_itr == target_end)) { break; } // success

Expand All @@ -103,7 +111,8 @@ struct like_fn {

// restore saved positions
pattern_itr = last_pattern_itr;
target_itr = ++last_target_itr;
last_target_itr += bytes_in_utf8_byte(*last_target_itr);
target_itr = last_target_itr;
}
return result;
}
Expand Down