Skip to content

Commit

Permalink
fix some bugs in simd Jaro/JaroWinkler implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
maxbachmann committed Oct 8, 2023
1 parent f121d28 commit 3a637d3
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 92 deletions.
91 changes: 47 additions & 44 deletions extras/rapidfuzz_amalgamated.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Licensed under the MIT License <http://opensource.org/licenses/MIT>.
// SPDX-License-Identifier: MIT
// RapidFuzz v1.0.2
// Generated: 2023-10-08 12:45:00.456286
// Generated: 2023-10-08 18:06:53.104764
// ----------------------------------------------------------
// This file is an amalgamation of multiple different files.
// You probably shouldn't edit it directly.
Expand Down Expand Up @@ -1375,7 +1375,7 @@ T bit_mask_lsb(int n)
{
T mask = static_cast<T>(-1);
if (n < static_cast<int>(sizeof(T) * 8)) {
mask += static_cast<T>(1) << n;
mask += static_cast<T>(static_cast<T>(1) << n);
}
return mask;
}
Expand Down Expand Up @@ -2345,7 +2345,7 @@ static inline native_simd<uint16_t> operator>(const native_simd<uint16_t>& a,
static inline native_simd<uint32_t> operator>(const native_simd<uint32_t>& a,
const native_simd<uint32_t>& b) noexcept
{
__m256i signbit = _mm256_set1_epi32(0x80000000);
__m256i signbit = _mm256_set1_epi32(static_cast<int32_t>(0x80000000));
__m256i a1 = _mm256_xor_si256(a, signbit);
__m256i b1 = _mm256_xor_si256(b, signbit);
return _mm256_cmpgt_epi32(a1, b1); // signed compare
Expand Down Expand Up @@ -2934,7 +2934,7 @@ static inline native_simd<uint16_t> operator>(const native_simd<uint16_t>& a,
static inline native_simd<uint32_t> operator>(const native_simd<uint32_t>& a,
const native_simd<uint32_t>& b) noexcept
{
__m128i signbit = _mm_set1_epi32(0x80000000);
__m128i signbit = _mm_set1_epi32(static_cast<int32_t>(0x80000000));
__m128i a1 = _mm_xor_si128(a, signbit);
__m128i b1 = _mm_xor_si128(b, signbit);
return _mm_cmpgt_epi32(a1, b1); // signed compare
Expand All @@ -2943,7 +2943,7 @@ static inline native_simd<uint32_t> operator>(const native_simd<uint32_t>& a,
static inline native_simd<uint64_t> operator>(const native_simd<uint64_t>& a,
const native_simd<uint64_t>& b) noexcept
{
__m128i sign32 = _mm_set1_epi32(0x80000000); // sign bit of each dword
__m128i sign32 = _mm_set1_epi32(static_cast<int32_t>(0x80000000)); // sign bit of each dword
__m128i aflip = _mm_xor_si128(a, sign32); // a with sign bits flipped to use signed compare
__m128i bflip = _mm_xor_si128(b, sign32); // b with sign bits flipped to use signed compare
__m128i equal = _mm_cmpeq_epi32(a, b); // a == b, dwords
Expand Down Expand Up @@ -3322,7 +3322,7 @@ struct CachedSimilarityBase : public CachedNormalizedMetricBase<T> {
friend T;
};

template <typename T>
template <typename T, typename ResType>
struct MultiNormalizedMetricBase {
template <typename InputIt2>
void normalized_distance(double* scores, size_t score_count, InputIt2 first2, InputIt2 last2,
Expand Down Expand Up @@ -3362,23 +3362,23 @@ struct MultiNormalizedMetricBase {
throw std::invalid_argument("scores has to have >= result_count() elements");

// reinterpretation only works when the types have the same size
int64_t* scores_i64 = nullptr;
if constexpr (sizeof(double) == sizeof(int64_t))
scores_i64 = reinterpret_cast<int64_t*>(scores);
ResType* scores_orig = nullptr;
if constexpr (sizeof(double) == sizeof(ResType))
scores_orig = reinterpret_cast<ResType*>(scores);
else
scores_i64 = new int64_t[derived.result_count()];
scores_orig = new ResType[derived.result_count()];

Range s2_(s2);
derived.distance(scores_i64, derived.result_count(), s2_);
derived.distance(scores_orig, derived.result_count(), s2_);

for (size_t i = 0; i < derived.get_input_count(); ++i) {
auto maximum = derived.maximum(i, s2);
double norm_dist =
(maximum != 0) ? static_cast<double>(scores_i64[i]) / static_cast<double>(maximum) : 0.0;
(maximum != 0) ? static_cast<double>(scores_orig[i]) / static_cast<double>(maximum) : 0.0;
scores[i] = (norm_dist <= score_cutoff) ? norm_dist : 1.0;
}

if constexpr (sizeof(double) != sizeof(int64_t)) delete[] scores_i64;
if constexpr (sizeof(double) != sizeof(ResType)) delete[] scores_orig;
}

template <typename InputIt2>
Expand All @@ -3400,7 +3400,7 @@ struct MultiNormalizedMetricBase {
};

template <typename T, typename ResType, int64_t WorstSimilarity, int64_t WorstDistance>
struct MultiDistanceBase : public MultiNormalizedMetricBase<T> {
struct MultiDistanceBase : public MultiNormalizedMetricBase<T, ResType> {
template <typename InputIt2>
void distance(ResType* scores, size_t score_count, InputIt2 first2, InputIt2 last2,
ResType score_cutoff = static_cast<ResType>(WorstDistance)) const
Expand Down Expand Up @@ -3451,7 +3451,7 @@ struct MultiDistanceBase : public MultiNormalizedMetricBase<T> {
};

template <typename T, typename ResType, int64_t WorstSimilarity, int64_t WorstDistance>
struct MultiSimilarityBase : public MultiNormalizedMetricBase<T> {
struct MultiSimilarityBase : public MultiNormalizedMetricBase<T, ResType> {
template <typename InputIt2>
void distance(ResType* scores, size_t score_count, InputIt2 first2, InputIt2 last2,
ResType score_cutoff = static_cast<ResType>(WorstDistance)) const
Expand Down Expand Up @@ -4726,7 +4726,7 @@ struct MultiLCSseq : public detail::MultiSimilarityBase<MultiLCSseq<MaxLen>, int
std::numeric_limits<int64_t>::max()> {
private:
friend detail::MultiSimilarityBase<MultiLCSseq<MaxLen>, int64_t, 0, std::numeric_limits<int64_t>::max()>;
friend detail::MultiNormalizedMetricBase<MultiLCSseq<MaxLen>>;
friend detail::MultiNormalizedMetricBase<MultiLCSseq<MaxLen>, int64_t>;

constexpr static size_t get_vec_size()
{
Expand Down Expand Up @@ -5012,7 +5012,7 @@ struct MultiIndel
: public detail::MultiDistanceBase<MultiIndel<MaxLen>, int64_t, 0, std::numeric_limits<int64_t>::max()> {
private:
friend detail::MultiDistanceBase<MultiIndel<MaxLen>, int64_t, 0, std::numeric_limits<int64_t>::max()>;
friend detail::MultiNormalizedMetricBase<MultiIndel<MaxLen>>;
friend detail::MultiNormalizedMetricBase<MultiIndel<MaxLen>, int64_t>;

public:
MultiIndel(size_t count) : scorer(count)
Expand Down Expand Up @@ -5570,7 +5570,7 @@ double jaro_similarity(const BlockPatternMatchVector& PM, Range<InputIt1> P, Ran
#ifdef RAPIDFUZZ_SIMD
template <typename VecType, typename InputIt, int _lto_hack = RAPIDFUZZ_LTO_HACK>
void jaro_similarity_simd(Range<double*> scores, const detail::BlockPatternMatchVector& block,
const std::vector<size_t>& s1_lengths, Range<InputIt> s2,
const std::vector<int64_t>& s1_lengths, Range<InputIt> s2,
double score_cutoff) noexcept
{
# ifdef RAPIDFUZZ_AVX2
Expand All @@ -5588,15 +5588,15 @@ void jaro_similarity_simd(Range<double*> scores, const detail::BlockPatternMatch
size_t result_index = 0;

if (score_cutoff > 1.0) {
for (int64_t i = 0; i < s1_lengths.size(); i++)
for (int64_t i = 0; i < static_cast<int64_t>(s1_lengths.size()); i++)
scores[i] = 0.0;

return;
}

if (s2.empty()) {
for (int64_t i = 0; i < s1_lengths.size(); i++)
scores[i] = s1_lengths[i] ? 0.0 : 1.0;
for (size_t i = 0; i < s1_lengths.size(); i++)
scores[static_cast<int64_t>(i)] = s1_lengths[i] ? 0.0 : 1.0;

return;
}
Expand All @@ -5617,8 +5617,8 @@ void jaro_similarity_simd(Range<double*> scores, const detail::BlockPatternMatch

if (Bound > maxBound) maxBound = Bound;

boundMaskSize_[i] = bit_mask_lsb<VecType>(2 * Bound);
boundMask_[i] = bit_mask_lsb<VecType>(Bound + 1);
boundMaskSize_[i] = bit_mask_lsb<VecType>(static_cast<int>(2 * Bound));
boundMask_[i] = bit_mask_lsb<VecType>(static_cast<int>(Bound + 1));
});

if (s2_cur.size() > lastRelevantChar) s2_cur.remove_suffix(s2_cur.size() - lastRelevantChar);
Expand Down Expand Up @@ -5649,7 +5649,7 @@ void jaro_similarity_simd(Range<double*> scores, const detail::BlockPatternMatch
P_flag.store(P_flags.data());
alignas(32) std::array<VecType, vec_width> T_flags;
T_flag.store(T_flags.data());
for (int64_t i = 0; i < vec_width; ++i) {
for (size_t i = 0; i < vec_width; ++i) {
VecType CommonChars = counts[i];
if (!jaro_common_char_filter(s1_lengths[result_index], s2.size(), CommonChars, score_cutoff)) {
scores[static_cast<int64_t>(result_index)] = 0.0;
Expand All @@ -5661,20 +5661,22 @@ void jaro_similarity_simd(Range<double*> scores, const detail::BlockPatternMatch
VecType T_flag_cur = T_flags[i];
size_t Transpositions = 0;

int64_t cur_block = i / 4;
int64_t offset = 8 * (i % 4);
static constexpr size_t vecs_per_word = vec_width / vecs;
size_t cur_block = i / vecs_per_word;
int64_t offset = static_cast<int64_t>(sizeof(VecType) * 8 * (i % vecs_per_word));
while (T_flag_cur) {
uint64_t PatternFlagMask = blsi(P_flag_cur);
VecType PatternFlagMask = blsi(P_flag_cur);

Transpositions +=
!(block.get(cur_block, s2[countr_zero(T_flag_cur)]) & (PatternFlagMask << offset));
uint64_t PM_j = block.get(cur_block, s2[countr_zero(T_flag_cur)]);
Transpositions += !(PM_j & (static_cast<uint64_t>(PatternFlagMask) << offset));

T_flag_cur = blsr(T_flag_cur);
P_flag_cur ^= PatternFlagMask;
}

double Sim =
jaro_calculate_similarity(s1_lengths[result_index], s2.size(), CommonChars, Transpositions);

scores[static_cast<int64_t>(result_index)] = (Sim >= score_cutoff) ? Sim : 0;
result_index++;
}
Expand Down Expand Up @@ -5763,7 +5765,7 @@ struct MultiJaro : public detail::MultiSimilarityBase<MultiJaro<MaxLen>, double,

private:
friend detail::MultiSimilarityBase<MultiJaro<MaxLen>, double, 0, 1>;
friend detail::MultiNormalizedMetricBase<MultiJaro<MaxLen>>;
friend detail::MultiNormalizedMetricBase<MultiJaro<MaxLen>, double>;

constexpr static size_t get_vec_size()
{
Expand Down Expand Up @@ -5829,7 +5831,7 @@ struct MultiJaro : public detail::MultiSimilarityBase<MultiJaro<MaxLen>, double,

if (pos >= input_count) throw std::invalid_argument("out of bounds insert");

str_lens[pos] = static_cast<size_t>(len);
str_lens[pos] = len;
for (; first1 != last1; ++first1) {
PM.insert(block, *first1, block_pos);
block_pos++;
Expand Down Expand Up @@ -5857,7 +5859,7 @@ struct MultiJaro : public detail::MultiSimilarityBase<MultiJaro<MaxLen>, double,
}

template <typename InputIt2>
double maximum(size_t s1_idx, detail::Range<InputIt2>) const
double maximum([[maybe_unused]] size_t s1_idx, detail::Range<InputIt2>) const
{
return 1.0;
}
Expand All @@ -5870,7 +5872,7 @@ struct MultiJaro : public detail::MultiSimilarityBase<MultiJaro<MaxLen>, double,
size_t input_count;
size_t pos = 0;
detail::BlockPatternMatchVector PM;
std::vector<size_t> str_lens;
std::vector<int64_t> str_lens;
};

} /* namespace experimental */
Expand Down Expand Up @@ -6070,7 +6072,7 @@ struct MultiJaroWinkler : public detail::MultiSimilarityBase<MultiJaroWinkler<Ma

private:
friend detail::MultiSimilarityBase<MultiJaroWinkler<MaxLen>, double, 0, 1>;
friend detail::MultiNormalizedMetricBase<MultiJaroWinkler<MaxLen>>;
friend detail::MultiNormalizedMetricBase<MultiJaroWinkler<MaxLen>, double>;

public:
MultiJaroWinkler(size_t count, double prefix_weight_) : scorer(count), prefix_weight(prefix_weight_)
Expand Down Expand Up @@ -6102,8 +6104,8 @@ struct MultiJaroWinkler : public detail::MultiSimilarityBase<MultiJaroWinkler<Ma
scorer.insert(first1, last1);
size_t len = static_cast<size_t>(std::distance(first1, last1));
std::array<uint64_t, 4> prefix;
for (size_t i = 0; i < std::min<int64_t>(len, 4); ++i)
prefix[i] = (uint64_t)first1[i];
for (size_t i = 0; i < std::min(len, size_t(4)); ++i)
prefix[i] = static_cast<uint64_t>(first1[static_cast<ptrdiff_t>(i)]);

str_lens.push_back(len);
prefixes.push_back(prefix);
Expand All @@ -6117,15 +6119,16 @@ struct MultiJaroWinkler : public detail::MultiSimilarityBase<MultiJaroWinkler<Ma
if (score_count < result_count())
throw std::invalid_argument("scores has to have >= result_count() elements");

scorer.similarity(scores, score_count, s2, score_cutoff);
scorer.similarity(scores, score_count, s2, std::min(0.7, score_cutoff));

for (size_t i = 0; i < get_input_count(); ++i) {
if (scores[i] > 0.7) {
int64_t min_len = std::min<int64_t>(s2.size(), str_lens[i]);
int64_t max_prefix = std::min<int64_t>(min_len, 4);
int64_t prefix = 0;
size_t min_len = std::min(static_cast<size_t>(s2.size()), str_lens[i]);
size_t max_prefix = std::min(min_len, size_t(4));
size_t prefix = 0;
for (; prefix < max_prefix; ++prefix)
if (s2[prefix] != prefixes[i][prefix]) break;
if (static_cast<uint64_t>(s2[static_cast<ptrdiff_t>(prefix)]) != prefixes[i][prefix])
break;

scores[i] += static_cast<double>(prefix) * prefix_weight * (1.0 - scores[i]);
}
Expand All @@ -6135,7 +6138,7 @@ struct MultiJaroWinkler : public detail::MultiSimilarityBase<MultiJaroWinkler<Ma
}

template <typename InputIt2>
double maximum(size_t s1_idx, detail::Range<InputIt2>) const
double maximum([[maybe_unused]] size_t s1_idx, detail::Range<InputIt2>) const
{
return 1.0;
}
Expand Down Expand Up @@ -7674,7 +7677,7 @@ struct MultiLevenshtein : public detail::MultiDistanceBase<MultiLevenshtein<MaxL
private:
friend detail::MultiDistanceBase<MultiLevenshtein<MaxLen>, int64_t, 0,
std::numeric_limits<int64_t>::max()>;
friend detail::MultiNormalizedMetricBase<MultiLevenshtein<MaxLen>>;
friend detail::MultiNormalizedMetricBase<MultiLevenshtein<MaxLen>, int64_t>;

constexpr static size_t get_vec_size()
{
Expand Down Expand Up @@ -8240,7 +8243,7 @@ struct MultiOSA
: public detail::MultiDistanceBase<MultiOSA<MaxLen>, int64_t, 0, std::numeric_limits<int64_t>::max()> {
private:
friend detail::MultiDistanceBase<MultiOSA<MaxLen>, int64_t, 0, std::numeric_limits<int64_t>::max()>;
friend detail::MultiNormalizedMetricBase<MultiOSA<MaxLen>>;
friend detail::MultiNormalizedMetricBase<MultiOSA<MaxLen>, int64_t>;

constexpr static size_t get_vec_size()
{
Expand Down
20 changes: 10 additions & 10 deletions rapidfuzz/details/distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ struct CachedSimilarityBase : public CachedNormalizedMetricBase<T> {
friend T;
};

template <typename T>
template <typename T, typename ResType>
struct MultiNormalizedMetricBase {
template <typename InputIt2>
void normalized_distance(double* scores, size_t score_count, InputIt2 first2, InputIt2 last2,
Expand Down Expand Up @@ -402,23 +402,23 @@ struct MultiNormalizedMetricBase {
throw std::invalid_argument("scores has to have >= result_count() elements");

// reinterpretation only works when the types have the same size
int64_t* scores_i64 = nullptr;
if constexpr (sizeof(double) == sizeof(int64_t))
scores_i64 = reinterpret_cast<int64_t*>(scores);
ResType* scores_orig = nullptr;
if constexpr (sizeof(double) == sizeof(ResType))
scores_orig = reinterpret_cast<ResType*>(scores);
else
scores_i64 = new int64_t[derived.result_count()];
scores_orig = new ResType[derived.result_count()];

Range s2_(s2);
derived.distance(scores_i64, derived.result_count(), s2_);
derived.distance(scores_orig, derived.result_count(), s2_);

for (size_t i = 0; i < derived.get_input_count(); ++i) {
auto maximum = derived.maximum(i, s2);
double norm_dist =
(maximum != 0) ? static_cast<double>(scores_i64[i]) / static_cast<double>(maximum) : 0.0;
(maximum != 0) ? static_cast<double>(scores_orig[i]) / static_cast<double>(maximum) : 0.0;
scores[i] = (norm_dist <= score_cutoff) ? norm_dist : 1.0;
}

if constexpr (sizeof(double) != sizeof(int64_t)) delete[] scores_i64;
if constexpr (sizeof(double) != sizeof(ResType)) delete[] scores_orig;
}

template <typename InputIt2>
Expand All @@ -440,7 +440,7 @@ struct MultiNormalizedMetricBase {
};

template <typename T, typename ResType, int64_t WorstSimilarity, int64_t WorstDistance>
struct MultiDistanceBase : public MultiNormalizedMetricBase<T> {
struct MultiDistanceBase : public MultiNormalizedMetricBase<T, ResType> {
template <typename InputIt2>
void distance(ResType* scores, size_t score_count, InputIt2 first2, InputIt2 last2,
ResType score_cutoff = static_cast<ResType>(WorstDistance)) const
Expand Down Expand Up @@ -491,7 +491,7 @@ struct MultiDistanceBase : public MultiNormalizedMetricBase<T> {
};

template <typename T, typename ResType, int64_t WorstSimilarity, int64_t WorstDistance>
struct MultiSimilarityBase : public MultiNormalizedMetricBase<T> {
struct MultiSimilarityBase : public MultiNormalizedMetricBase<T, ResType> {
template <typename InputIt2>
void distance(ResType* scores, size_t score_count, InputIt2 first2, InputIt2 last2,
ResType score_cutoff = static_cast<ResType>(WorstDistance)) const
Expand Down
2 changes: 1 addition & 1 deletion rapidfuzz/details/intrinsics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ T bit_mask_lsb(int n)
{
T mask = static_cast<T>(-1);
if (n < static_cast<int>(sizeof(T) * 8)) {
mask += static_cast<T>(1) << n;
mask += static_cast<T>(static_cast<T>(1) << n);
}
return mask;
}
Expand Down
2 changes: 1 addition & 1 deletion rapidfuzz/details/simd_avx2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ static inline native_simd<uint16_t> operator>(const native_simd<uint16_t>& a,
static inline native_simd<uint32_t> operator>(const native_simd<uint32_t>& a,
const native_simd<uint32_t>& b) noexcept
{
__m256i signbit = _mm256_set1_epi32(0x80000000);
__m256i signbit = _mm256_set1_epi32(static_cast<int32_t>(0x80000000));
__m256i a1 = _mm256_xor_si256(a, signbit);
__m256i b1 = _mm256_xor_si256(b, signbit);
return _mm256_cmpgt_epi32(a1, b1); // signed compare
Expand Down
4 changes: 2 additions & 2 deletions rapidfuzz/details/simd_sse2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ static inline native_simd<uint16_t> operator>(const native_simd<uint16_t>& a,
static inline native_simd<uint32_t> operator>(const native_simd<uint32_t>& a,
const native_simd<uint32_t>& b) noexcept
{
__m128i signbit = _mm_set1_epi32(0x80000000);
__m128i signbit = _mm_set1_epi32(static_cast<int32_t>(0x80000000));
__m128i a1 = _mm_xor_si128(a, signbit);
__m128i b1 = _mm_xor_si128(b, signbit);
return _mm_cmpgt_epi32(a1, b1); // signed compare
Expand All @@ -571,7 +571,7 @@ static inline native_simd<uint32_t> operator>(const native_simd<uint32_t>& a,
static inline native_simd<uint64_t> operator>(const native_simd<uint64_t>& a,
const native_simd<uint64_t>& b) noexcept
{
__m128i sign32 = _mm_set1_epi32(0x80000000); // sign bit of each dword
__m128i sign32 = _mm_set1_epi32(static_cast<int32_t>(0x80000000)); // sign bit of each dword
__m128i aflip = _mm_xor_si128(a, sign32); // a with sign bits flipped to use signed compare
__m128i bflip = _mm_xor_si128(b, sign32); // b with sign bits flipped to use signed compare
__m128i equal = _mm_cmpeq_epi32(a, b); // a == b, dwords
Expand Down
Loading

0 comments on commit 3a637d3

Please sign in to comment.