diff --git a/cpp/src/join/distinct_hash_join.cu b/cpp/src/join/distinct_hash_join.cu index 7c834d1a96b..981a7bf0dea 100644 --- a/cpp/src/join/distinct_hash_join.cu +++ b/cpp/src/join/distinct_hash_join.cu @@ -205,18 +205,14 @@ CUDF_KERNEL void distinct_join_probe_kernel(Iter iter, cudf::size_type buffer_size = 0; while (idx - block.thread_rank() < n) { // the whole thread block falls into the same iteration - cudf::size_type thread_count{0}; - cudf::size_type build_idx{0}; - if (idx < n) { - auto const found = hash_table.find(*(iter + idx)); - thread_count = found != hash_table.end(); - build_idx = static_cast(found->second); - } + auto const found = idx < n ? hash_table.find(*(iter + idx)) : hash_table.end(); + auto const has_match = found != hash_table.end(); // Use a whole-block scan to calculate the output location cudf::size_type offset; cudf::size_type block_count; - block_scan(block_scan_temp_storage).ExclusiveSum(thread_count, offset, block_count); + block_scan(block_scan_temp_storage) + .ExclusiveSum(static_cast(has_match), offset, block_count); if (buffer_size + block_count > buffer_capacity) { flush_buffer(block, buffer_size, buffer, counter, build_indices, probe_indices); @@ -224,8 +220,9 @@ CUDF_KERNEL void distinct_join_probe_kernel(Iter iter, buffer_size = 0; } - if (thread_count == 1) { - buffer[buffer_size + offset] = cuco::pair{build_idx, static_cast(idx)}; + if (has_match) { + buffer[buffer_size + offset] = cuco::pair{static_cast(found->second), + static_cast(idx)}; } buffer_size += block_count; block.sync();