Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize string sort for default collation UTF8MB4_BIN #5375

Merged
merged 19 commits into from
Jul 21, 2022
Merged
166 changes: 123 additions & 43 deletions dbms/src/Columns/ColumnString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#include <Columns/ColumnsCommon.h>
#include <Common/HashTable/Hash.h>
#include <DataStreams/ColumnGathererStream.h>
#include <Functions/CollationOperatorOptimized.h>
#include <fmt/core.h>


/// Used in the `reserve` method, when the number of rows is known, but sizes of elements are not.
#define APPROX_STRING_SIZE 64

Expand Down Expand Up @@ -321,54 +323,94 @@ int ColumnString::compareAtWithCollationImpl(size_t n, size_t m, const IColumn &
);
}


template <bool positive>
struct ColumnString::lessWithCollation
// Derived must implement function `int compare(const char *, size_t, const char *, size_t)`.
template <bool positive, typename Derived>
struct ColumnString::LessWithCollation
{
const ColumnString & parent;
const ICollator & collator;
const Derived & inner;

lessWithCollation(const ColumnString & parent_, const ICollator & collator_)
LessWithCollation(const ColumnString & parent_, const Derived & inner_)
: parent(parent_)
, collator(collator_)
, inner(inner_)
{}

bool operator()(size_t lhs, size_t rhs) const
FLATTEN_INLINE_PURE inline bool operator()(size_t lhs, size_t rhs) const
{
int res = collator.compare(
int res = inner.compare(
reinterpret_cast<const char *>(&parent.chars[parent.offsetAt(lhs)]),
parent.sizeAt(lhs) - 1, // Skip last zero byte.
reinterpret_cast<const char *>(&parent.chars[parent.offsetAt(rhs)]),
parent.sizeAt(rhs) - 1 // Skip last zero byte.
);

return positive ? (res < 0) : (res > 0);
if constexpr (positive)
{
return (res < 0);
}
else
{
return (res > 0);
}
}
};

void ColumnString::getPermutationWithCollationImpl(const ICollator & collator, bool reverse, size_t limit, Permutation & res) const
struct Utf8MB4BinCmp
{
size_t s = offsets.size();
res.resize(s);
for (size_t i = 0; i < s; ++i)
res[i] = i;

if (limit >= s)
limit = 0;
static FLATTEN_INLINE_PURE inline int compare(const char * s1, size_t length1, const char * s2, size_t length2)
{
return DB::BinCollatorCompare<true>(s1, length1, s2, length2);
}
};

if (limit)
// common util functions
template <>
struct ColumnString::LessWithCollation<false, void>
{
// `CollationCmpImpl` must implement function `int compare(const char *, size_t, const char *, size_t)`.
template <typename CollationCmpImpl>
static void getPermutationWithCollationImpl(const ColumnString & src, const CollationCmpImpl & collator_cmp_impl, bool reverse, size_t limit, Permutation & res)
{
if (reverse)
std::partial_sort(res.begin(), res.begin() + limit, res.end(), lessWithCollation<false>(*this, collator));
size_t s = src.offsets.size();
res.resize(s);
for (size_t i = 0; i < s; ++i)
res[i] = i;

if (limit >= s)
limit = 0;

if (limit)
{
if (reverse)
std::partial_sort(res.begin(), res.begin() + limit, res.end(), LessWithCollation<false, CollationCmpImpl>(src, collator_cmp_impl));
else
std::partial_sort(res.begin(), res.begin() + limit, res.end(), LessWithCollation<true, CollationCmpImpl>(src, collator_cmp_impl));
}
else
std::partial_sort(res.begin(), res.begin() + limit, res.end(), lessWithCollation<true>(*this, collator));
{
if (reverse)
std::sort(res.begin(), res.end(), LessWithCollation<false, CollationCmpImpl>(src, collator_cmp_impl));
else
std::sort(res.begin(), res.end(), LessWithCollation<true, CollationCmpImpl>(src, collator_cmp_impl));
}
}
else
};

void ColumnString::getPermutationWithCollationImpl(const ICollator & collator, bool reverse, size_t limit, Permutation & res) const
{
using PermutationWithCollationUtils = ColumnString::LessWithCollation<false, void>;

// optimize path for default collator `UTF8MB4_BIN`
if (TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::UTF8MB4_BIN) == &collator)
{
if (reverse)
std::sort(res.begin(), res.end(), lessWithCollation<false>(*this, collator));
else
std::sort(res.begin(), res.end(), lessWithCollation<true>(*this, collator));
Utf8MB4BinCmp cmp_impl;
PermutationWithCollationUtils::getPermutationWithCollationImpl(*this, cmp_impl, reverse, limit, res);
///
solotzg marked this conversation as resolved.
Show resolved Hide resolved
return;
}

{
PermutationWithCollationUtils::getPermutationWithCollationImpl(*this, collator, reverse, limit, res);
}
}

Expand All @@ -379,35 +421,73 @@ void ColumnString::updateWeakHash32(WeakHash32 & hash, const TiDB::TiDBCollatorP
if (hash.getData().size() != s)
throw Exception(fmt::format("Size of WeakHash32 does not match size of column: column size is {}, hash size is {}", s, hash.getData().size()), ErrorCodes::LOGICAL_ERROR);

const UInt8 * pos = chars.data();
UInt32 * hash_data = hash.getData().data();
Offset prev_offset = 0;

if (collator != nullptr)
{
for (const auto & offset : offsets)
if (collator->getCollatorId() == TiDB::ITiDBCollator::UTF8MB4_BIN)
{
auto str_size = offset - prev_offset;
/// Skip last zero byte.
auto sort_key = collator->sortKey(reinterpret_cast<const char *>(pos), str_size - 1, sort_key_container);
*hash_data = ::updateWeakHash32(reinterpret_cast<const UInt8 *>(sort_key.data), sort_key.size, *hash_data);

pos += str_size;
prev_offset = offset;
// Skip last zero byte.
LoopOneColumn(chars, offsets, offsets.size(), [&](const std::string_view & view, size_t) {
auto sort_key = BinCollatorSortKey<true>(view.data(), view.size());
*hash_data = ::updateWeakHash32(reinterpret_cast<const UInt8 *>(sort_key.data), sort_key.size, *hash_data);
++hash_data;
});
}
else
{
// Skip last zero byte.
LoopOneColumn(chars, offsets, offsets.size(), [&](const std::string_view & view, size_t) {
auto sort_key = collator->sortKey(view.data(), view.size(), sort_key_container);
*hash_data = ::updateWeakHash32(reinterpret_cast<const UInt8 *>(sort_key.data), sort_key.size, *hash_data);
++hash_data;
});
}
}
else
{
// Skip last zero byte.
LoopOneColumn(chars, offsets, offsets.size(), [&](const std::string_view & view, size_t) {
solotzg marked this conversation as resolved.
Show resolved Hide resolved
*hash_data = ::updateWeakHash32(reinterpret_cast<const UInt8 *>(view.data()), view.size(), *hash_data);
++hash_data;
});
}
}

void ColumnString::updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr & collator, String & sort_key_container) const
{
if (collator != nullptr)
{
if (collator->getCollatorId() == TiDB::ITiDBCollator::UTF8MB4_BIN)
{
// Skip last zero byte.
LoopOneColumn(chars, offsets, offsets.size(), [&hash_values](const std::string_view & view, size_t i) {
auto sort_key = BinCollatorSortKey<true>(view.data(), view.size());
size_t string_size = sort_key.size;
hash_values[i].update(reinterpret_cast<const char *>(&string_size), sizeof(string_size));
hash_values[i].update(sort_key.data, sort_key.size);
});
}
else
{
// Skip last zero byte.
LoopOneColumn(chars, offsets, offsets.size(), [&](const std::string_view & view, size_t i) {
auto sort_key = collator->sortKey(view.data(), view.size(), sort_key_container);
size_t string_size = sort_key.size;
hash_values[i].update(reinterpret_cast<const char *>(&string_size), sizeof(string_size));
hash_values[i].update(sort_key.data, sort_key.size);
});
}
}
else
{
for (const auto & offset : offsets)
for (size_t i = 0; i < offsets.size(); ++i)
{
auto str_size = offset - prev_offset;
/// Skip last zero byte.
*hash_data = ::updateWeakHash32(pos, str_size - 1, *hash_data);
size_t string_size = sizeAt(i);
size_t offset = offsetAt(i);

pos += str_size;
prev_offset = offset;
++hash_data;
hash_values[i].update(reinterpret_cast<const char *>(&string_size), sizeof(string_size));
hash_values[i].update(reinterpret_cast<const char *>(&chars[offset]), string_size);
}
}
}
Expand Down
41 changes: 7 additions & 34 deletions dbms/src/Columns/ColumnString.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class ColumnString final : public COWPtrHelper<IColumn, ColumnString>
template <bool positive>
struct less;

template <bool positive>
struct lessWithCollation;
template <bool positive, typename Derived>
struct LessWithCollation;

ColumnString() = default;

Expand Down Expand Up @@ -118,7 +118,7 @@ class ColumnString final : public COWPtrHelper<IColumn, ColumnString>

void insert(const Field & x) override
{
const String & s = DB::get<const String &>(x);
const auto & s = DB::get<const String &>(x);
const size_t old_size = chars.size();
const size_t size_to_append = s.size() + 1;
const size_t new_size = old_size + size_to_append;
Expand All @@ -134,7 +134,7 @@ class ColumnString final : public COWPtrHelper<IColumn, ColumnString>

void insertFrom(const IColumn & src_, size_t n) override
{
const ColumnString & src = static_cast<const ColumnString &>(src_);
const auto & src = static_cast<const ColumnString &>(src_);

if (n != 0)
{
Expand Down Expand Up @@ -213,7 +213,7 @@ class ColumnString final : public COWPtrHelper<IColumn, ColumnString>

if (collator != nullptr)
{
/// Skip last zero byte.
// Skip last zero byte.
auto sort_key = collator->sortKey(reinterpret_cast<const char *>(src), string_size - 1, sort_key_container);
string_size = sort_key.size;
src = sort_key.data;
Expand Down Expand Up @@ -259,34 +259,7 @@ class ColumnString final : public COWPtrHelper<IColumn, ColumnString>
}
}

void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr & collator, String & sort_key_container) const override
{
if (collator != nullptr)
{
for (size_t i = 0; i < offsets.size(); ++i)
{
size_t string_size = sizeAt(i);
size_t offset = offsetAt(i);

/// Skip last zero byte.
auto sort_key = collator->sortKey(reinterpret_cast<const char *>(&chars[offset]), string_size - 1, sort_key_container);
string_size = sort_key.size;
hash_values[i].update(reinterpret_cast<const char *>(&string_size), sizeof(string_size));
hash_values[i].update(sort_key.data, sort_key.size);
}
}
else
{
for (size_t i = 0; i < offsets.size(); ++i)
{
size_t string_size = sizeAt(i);
size_t offset = offsetAt(i);

hash_values[i].update(reinterpret_cast<const char *>(&string_size), sizeof(string_size));
hash_values[i].update(reinterpret_cast<const char *>(&chars[offset]), string_size);
}
}
}
void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr & collator, String & sort_key_container) const override;

void updateWeakHash32(WeakHash32 & hash, const TiDB::TiDBCollatorPtr &, String &) const override;

Expand All @@ -304,7 +277,7 @@ class ColumnString final : public COWPtrHelper<IColumn, ColumnString>

int compareAt(size_t n, size_t m, const IColumn & rhs_, int /*nan_direction_hint*/) const override
{
const ColumnString & rhs = static_cast<const ColumnString &>(rhs_);
const auto & rhs = static_cast<const ColumnString &>(rhs_);

const size_t size = sizeAt(n);
const size_t rhs_size = rhs.sizeAt(m);
Expand Down
Loading