Skip to content

Commit

Permalink
add support for bidirectional iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
maxbachmann committed Sep 21, 2023
1 parent 05946d7 commit cff24a0
Show file tree
Hide file tree
Showing 19 changed files with 285 additions and 85 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
## Changelog

## [2.1.0] -
### Changed
- add support for bidirectional iterators


### [2.0.0] - 2023-06-02
#### Changed
- added argument ``pad`` to Hamming distance. This controls whether sequences of different
Expand Down
6 changes: 3 additions & 3 deletions rapidfuzz/details/PatternMatchVector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,11 @@ struct BlockPatternMatchVector {
template <typename InputIt>
void insert(Range<InputIt> s) noexcept
{
auto len = s.size();
uint64_t mask = 1;
for (ptrdiff_t i = 0; i < len; ++i) {
ptrdiff_t i = 0;
for (auto iter = s.begin(); iter != s.end(); ++iter,++i) {
size_t block = static_cast<size_t>(i) / 64;
insert_mask(block, s[i], mask);
insert_mask(block, *iter, mask);
mask = rotl(mask, 1);
}
}
Expand Down
17 changes: 15 additions & 2 deletions rapidfuzz/details/Range.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,31 @@ class Range {
{
return !empty();
}

template <
typename... Dummy, typename IterCopy = Iter,
typename = std::enable_if_t<std::is_base_of_v<
std::random_access_iterator_tag, typename std::iterator_traits<IterCopy>::iterator_category>>>
constexpr decltype(auto) operator[](ptrdiff_t n) const
{
return _first[n];
}

constexpr void remove_prefix(ptrdiff_t n)
{
_first += n;
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<Iter>::iterator_category>)
_first += n;
else
for (ptrdiff_t i = 0; i < n; ++i)
_first++;
}
constexpr void remove_suffix(ptrdiff_t n)
{
_last -= n;
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<Iter>::iterator_category>)
_last -= n;
else
for (ptrdiff_t i = 0; i < n; ++i)
_last--;
}

constexpr Range subseq(ptrdiff_t pos = 0, ptrdiff_t count = std::numeric_limits<ptrdiff_t>::max())
Expand Down
12 changes: 8 additions & 4 deletions rapidfuzz/distance/DamerauLevenshtein_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,28 @@ int64_t damerau_levenshtein_distance_zhao(Range<InputIt1> s1, Range<InputIt2> s2
IntType* R1 = &R1_arr[1];
IntType* FR = &FR_arr[1];

auto iter_s1 = s1.begin();
for (IntType i = 1; i <= len1; i++) {
std::swap(R, R1);
IntType last_col_id = -1;
IntType last_i2l1 = R[0];
R[0] = i;
IntType T = maxVal;

auto iter_s2 = s2.begin();
for (IntType j = 1; j <= len2; j++) {
ptrdiff_t diag = R1[j - 1] + static_cast<IntType>(s1[i - 1] != s2[j - 1]);
ptrdiff_t diag = R1[j - 1] + static_cast<IntType>(*iter_s1 != *iter_s2);
ptrdiff_t left = R[j - 1] + 1;
ptrdiff_t up = R1[j] + 1;
ptrdiff_t temp = std::min({diag, left, up});

if (s1[i - 1] == s2[j - 1]) {
if (*iter_s1 == *iter_s2) {
last_col_id = j; // last occurence of s1_i
FR[j] = R1[j - 2]; // save H_k-1,j-2
T = last_i2l1; // save H_i-2,l-1
}
else {
ptrdiff_t k = last_row_id.get(static_cast<uint64_t>(s2[j - 1])).val;
ptrdiff_t k = last_row_id.get(static_cast<uint64_t>(*iter_s2)).val;
ptrdiff_t l = last_col_id;

if ((j - l) == 1) {
Expand All @@ -89,8 +91,10 @@ int64_t damerau_levenshtein_distance_zhao(Range<InputIt1> s1, Range<InputIt2> s2

last_i2l1 = R[j];
R[j] = static_cast<IntType>(temp);
iter_s2++;
}
last_row_id[s1[i - 1]].val = i;
last_row_id[*iter_s1].val = i;
iter_s1++;
}

int64_t dist = R[s2.size()];
Expand Down
4 changes: 3 additions & 1 deletion rapidfuzz/distance/Hamming_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ class Hamming : public DistanceBase<Hamming, int64_t, 0, std::numeric_limits<int

ptrdiff_t min_len = std::min(s1.size(), s2.size());
int64_t dist = std::max(s1.size(), s2.size());
auto iter_s1 = s1.begin();
auto iter_s2 = s2.begin();
for (ptrdiff_t i = 0; i < min_len; ++i)
dist -= bool(s1[i] == s2[i]);
dist -= bool(*(iter_s1++) == *(iter_s2++));

return (dist <= score_cutoff) ? dist : score_cutoff + 1;
}
Expand Down
24 changes: 14 additions & 10 deletions rapidfuzz/distance/LCSseq_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,17 @@ int64_t lcs_seq_mbleven2018(Range<InputIt1> s1, Range<InputIt2> s2, int64_t scor
int64_t max_len = 0;

for (uint8_t ops : possible_ops) {
ptrdiff_t s1_pos = 0;
ptrdiff_t s2_pos = 0;
auto iter_s1 = s1.begin();
auto iter_s2 = s2.begin();
int64_t cur_len = 0;

while (s1_pos < len1 && s2_pos < len2) {
if (s1[s1_pos] != s2[s2_pos]) {
while (iter_s1 != s1.end() && iter_s2 != s2.end()) {
if (*iter_s1 != *iter_s2) {
if (!ops) break;
if (ops & 1)
s1_pos++;
iter_s1++;
else if (ops & 2)
s2_pos++;
iter_s2++;
#if defined(__GNUC__) && !defined(__clang__) && !defined(__ICC) && __GNUC__ < 10
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wconversion"
Expand All @@ -108,8 +108,8 @@ int64_t lcs_seq_mbleven2018(Range<InputIt1> s1, Range<InputIt2> s2, int64_t scor
}
else {
cur_len++;
s1_pos++;
s2_pos++;
iter_s1++;
iter_s2++;
}
}

Expand Down Expand Up @@ -171,16 +171,18 @@ auto lcs_unroll(const PMV& block, Range<InputIt1>, Range<InputIt2> s2, int64_t s
LCSseqResult<RecordMatrix> res;
if constexpr (RecordMatrix) res.S = ShiftedBitMatrix<uint64_t>(s2.size(), N, ~UINT64_C(0));

auto iter_s2 = s2.begin();
for (ptrdiff_t i = 0; i < s2.size(); ++i) {
uint64_t carry = 0;
unroll<size_t, N>([&](size_t word) {
uint64_t Matches = block.get(word, s2[i]);
uint64_t Matches = block.get(word, *iter_s2);
uint64_t u = S[word] & Matches;
uint64_t x = addc64(S[word], u, carry, &carry);
S[word] = x | (S[word] - u);

if constexpr (RecordMatrix) res.S[i][word] = S[word];
});
iter_s2++;
}

res.sim = 0;
Expand All @@ -201,10 +203,11 @@ auto lcs_blockwise(const PMV& block, Range<InputIt1>, Range<InputIt2> s2, int64_
LCSseqResult<RecordMatrix> res;
if constexpr (RecordMatrix) res.S = ShiftedBitMatrix<uint64_t>(s2.size(), words, ~UINT64_C(0));

auto iter_s2 = s2.begin();
for (ptrdiff_t i = 0; i < s2.size(); ++i) {
uint64_t carry = 0;
for (size_t word = 0; word < words; ++word) {
const uint64_t Matches = block.get(word, s2[i]);
const uint64_t Matches = block.get(word, *iter_s2);
uint64_t Stemp = S[word];

uint64_t u = Stemp & Matches;
Expand All @@ -214,6 +217,7 @@ auto lcs_blockwise(const PMV& block, Range<InputIt1>, Range<InputIt2> s2, int64_

if constexpr (RecordMatrix) res.S[i][word] = S[word];
}
iter_s2++;
}

res.sim = 0;
Expand Down
49 changes: 27 additions & 22 deletions rapidfuzz/distance/Levenshtein_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ int64_t levenshtein_mbleven2018(Range<InputIt1> s1, Range<InputIt2> s2, int64_t
assert(len1 > 0);
assert(len2 > 0);
assert(*s1.begin() != *s2.begin());
assert(*(s1.end() - 1) != *(s2.end() - 1));
assert(*(--s1.end()) != *(--s2.end()));

if (len1 < len2) return levenshtein_mbleven2018(s2, s1, max);

Expand All @@ -175,15 +175,16 @@ int64_t levenshtein_mbleven2018(Range<InputIt1> s1, Range<InputIt2> s2, int64_t
int64_t dist = max + 1;

for (uint8_t ops : possible_ops) {
ptrdiff_t s1_pos = 0;
ptrdiff_t s2_pos = 0;
auto iter_s1 = s1.begin();
auto iter_s2 = s2.begin();
int64_t cur_dist = 0;
while (s1_pos < len1 && s2_pos < len2) {
if (s1[s1_pos] != s2[s2_pos]) {

while (iter_s1 != s1.end() && iter_s2 != s2.end()) {
if (*iter_s1 != *iter_s2) {
cur_dist++;
if (!ops) break;
if (ops & 1) s1_pos++;
if (ops & 2) s2_pos++;
if (ops & 1) iter_s1++;
if (ops & 2) iter_s2++;
#if defined(__GNUC__) && !defined(__clang__) && !defined(__ICC) && __GNUC__ < 10
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wconversion"
Expand All @@ -194,11 +195,11 @@ int64_t levenshtein_mbleven2018(Range<InputIt1> s1, Range<InputIt2> s2, int64_t
#endif
}
else {
s1_pos++;
s2_pos++;
iter_s1++;
iter_s2++;
}
}
cur_dist += (len1 - s1_pos) + (len2 - s2_pos);
cur_dist += std::distance(iter_s1, s1.end()) + std::distance(iter_s2, s2.end());
dist = std::min(dist, cur_dist);
}

Expand Down Expand Up @@ -245,9 +246,10 @@ auto levenshtein_hyrroe2003(const PM_Vec& PM, Range<InputIt1> s1, Range<InputIt2
uint64_t mask = UINT64_C(1) << (s1.size() - 1);

/* Searching */
for (ptrdiff_t i = 0; i < s2.size(); ++i) {
auto iter_s2 = s2.begin();
for (ptrdiff_t i = 0; iter_s2 != s2.end(); ++iter_s2,++i) {
/* Step 1: Computing D0 */
uint64_t PM_j = PM.get(0, s2[i]);
uint64_t PM_j = PM.get(0, *iter_s2);
uint64_t X = PM_j;
uint64_t D0 = (((X & VP) + VP) ^ VP) | X | VN;

Expand Down Expand Up @@ -494,25 +496,27 @@ auto levenshtein_hyrroe2003_small_band(Range<InputIt1> s1, Range<InputIt2> s2, i
int64_t break_score = max + s2.size() - (s1.size() - max);
HybridGrowingHashmap<typename Range<InputIt1>::value_type, std::pair<ptrdiff_t, uint64_t>> PM;

for (ptrdiff_t j = -max; j < 0; ++j) {
auto& x = PM[s1[j + max]];
auto iter_s1 = s1.begin();
for (ptrdiff_t j = -max; j < 0; ++iter_s1,++j) {
auto& x = PM[*iter_s1];
x.second = shr64(x.second, j - x.first) | (UINT64_C(1) << 63);
x.first = j;
}

/* Searching */
ptrdiff_t i = 0;
for (; i < s1.size() - max; ++i) {
auto iter_s2 = s2.begin();
for (; i < s1.size() - max; ++iter_s2,++iter_s1,++i) {
/* Step 1: Computing D0 */
/* update bitmasks online */
uint64_t PM_j = 0;
if (i + max < s1.size()) {
auto& x = PM[s1[i + max]];
auto& x = PM[*iter_s1];
x.second = shr64(x.second, i - x.first) | (UINT64_C(1) << 63);
x.first = i;
}
{
auto x = PM.get(s2[i]);
auto x = PM.get(*iter_s2);
PM_j = shr64(x.second, i - x.first);
}

Expand Down Expand Up @@ -541,17 +545,17 @@ auto levenshtein_hyrroe2003_small_band(Range<InputIt1> s1, Range<InputIt2> s2, i
}
}

for (; i < s2.size(); ++i) {
for (; i < s2.size(); ++iter_s2,++iter_s1,++i) {
/* Step 1: Computing D0 */
/* update bitmasks online */
uint64_t PM_j = 0;
if (i + max < s1.size()) {
auto& x = PM[s1[i + max]];
auto& x = PM[*iter_s1];
x.second = shr64(x.second, i - x.first) | (UINT64_C(1) << 63);
x.first = i;
}
{
auto x = PM.get(s2[i]);
auto x = PM.get(*iter_s2);
PM_j = shr64(x.second, i - x.first);
}

Expand Down Expand Up @@ -631,7 +635,8 @@ auto levenshtein_hyrroe2003_block(const BlockPatternMatchVector& PM, Range<Input
1;

/* Searching */
for (ptrdiff_t row = 0; row < s2.size(); ++row) {
auto iter_s2 = s2.begin();
for (ptrdiff_t row = 0; row < s2.size(); ++iter_s2,++row) {
uint64_t HP_carry = 1;
uint64_t HN_carry = 0;

Expand All @@ -642,7 +647,7 @@ auto levenshtein_hyrroe2003_block(const BlockPatternMatchVector& PM, Range<Input

auto advance_block = [&](size_t word) {
/* Step 1: Computing D0 */
uint64_t PM_j = PM.get(word, s2[row]);
uint64_t PM_j = PM.get(word, *iter_s2);
uint64_t VN = vecs[word].VN;
uint64_t VP = vecs[word].VP;

Expand Down
5 changes: 3 additions & 2 deletions rapidfuzz/distance/OSA_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ int64_t osa_hyrroe2003_block(const BlockPatternMatchVector& PM, Range<InputIt1>
std::vector<Row> new_vecs(words + 1);

/* Searching */
for (ptrdiff_t row = 0; row < s2.size(); ++row) {
auto iter_s2 = s2.begin();
for (ptrdiff_t row = 0; row < s2.size(); ++iter_s2,++row) {
uint64_t HP_carry = 1;
uint64_t HN_carry = 0;

Expand All @@ -210,7 +211,7 @@ int64_t osa_hyrroe2003_block(const BlockPatternMatchVector& PM, Range<InputIt1>
/* PM of last word */
uint64_t PM_last = new_vecs[word].PM;

uint64_t PM_j = PM.get(word, s2[row]);
uint64_t PM_j = PM.get(word, *iter_s2);
uint64_t X = PM_j;
uint64_t TR = ((((~D0) & X) << 1) | (((~D0_last) & PM_last) >> 63)) & PM_j_old;

Expand Down
2 changes: 1 addition & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ else()
FetchContent_Declare(
Catch2
GIT_REPOSITORY https://github.com/catchorg/Catch2.git
GIT_TAG v3.0.1
GIT_TAG v3.4.0
)
FetchContent_MakeAvailable(Catch2)
endif()
Expand Down
Loading

0 comments on commit cff24a0

Please sign in to comment.