Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Knuth's division #131

Merged
merged 6 commits into from
Mar 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 62 additions & 35 deletions lib/intx/div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,28 @@
// Licensed under the Apache License, Version 2.0.

#include "div.hpp"
#include <cassert>
#include <tuple>

// UNREACHABLE: marks a point the optimizer may assume is never executed.
// Reaching it at run time is undefined behavior.
#if defined(_MSC_VER)
#define UNREACHABLE __assume(0)
#else
#define UNREACHABLE __builtin_unreachable()
#endif

// UNLIKELY(EXPR): evaluates EXPR as a condition and, on GCC/Clang, hints that it
// is expected to be false. On MSVC it is a plain pass-through; the argument is
// parenthesized so the expansion stays a single operand in any context.
#if defined(_MSC_VER)
#define UNLIKELY(EXPR) (EXPR)
#else
#define UNLIKELY(EXPR) __builtin_expect(static_cast<bool>(EXPR), false)
#endif

// REQUIRE(X): precondition check.
// - With NDEBUG: a violated condition is declared unreachable (optimizer hint).
// - Otherwise: plain assert.
// The NDEBUG form is wrapped in do { } while (false) so that it behaves as one
// statement: it requires a trailing semicolon and cannot capture an `else`
// that follows it, matching how the assert form behaves.
#if defined(NDEBUG)
#define REQUIRE(X)       \
    do                   \
    {                    \
        if (!(X))        \
            UNREACHABLE; \
    } while (false)
#else
#define REQUIRE assert
#endif

namespace intx
{
Expand Down Expand Up @@ -54,6 +76,37 @@ inline uint128 udivrem_by2(uint64_t u[], int m, uint128 d) noexcept
return r;
}

/// Multi-word addition: s = x + y.
///
/// Walks the words from index 0 to len - 1, passing the carry flag produced for
/// one word into add_with_carry() for the next one.
///
/// @param s    Destination array of len words.
/// @param x    First addend, len words.
/// @param y    Second addend, len words.
/// @param len  Number of words to process; must be at least 3 (enforced via REQUIRE).
/// @return     The carry flag left over after the last word.
inline bool add(uint64_t s[], const uint64_t x[], const uint64_t y[], int len) noexcept
{
    // OPT: Add MinLen template parameter and unroll first loop iterations.
    REQUIRE(len >= 3);

    bool carry_flag = false;
    int pos = 0;
    while (pos < len)
    {
        // The sum word and the outgoing carry are unpacked together.
        std::tie(s[pos], carry_flag) = add_with_carry(x[pos], y[pos], carry_flag);
        ++pos;
    }
    return carry_flag;
}

/// Multi-word multiply-and-subtract: r = x - multiplier * y.
///
/// For every word, the borrow carried over from the previous word is subtracted
/// from x[i] first, then the low half of y[i] * multiplier. The next borrow is the
/// high half of that product plus the carry flags of both subtractions.
///
/// @param r           Destination array of len words.
/// @param x           Minuend, len words.
/// @param y           Words that get multiplied by multiplier, len words.
/// @param len         Number of words to process; must be at least 3 (enforced via REQUIRE).
/// @param multiplier  Single-word factor applied to every word of y.
/// @return            The borrow left over after the last word.
inline uint64_t submul(
    uint64_t r[], const uint64_t x[], const uint64_t y[], int len, uint64_t multiplier) noexcept
{
    // OPT: Add MinLen template parameter and unroll first loop iterations.
    REQUIRE(len >= 3);

    uint64_t pending = 0;
    int pos = 0;
    while (pos < len)
    {
        // Both inputs are read before r[pos] is written.
        const auto diff = sub_with_carry(x[pos], pending);
        const auto prod = umul(y[pos], multiplier);
        const auto out = sub_with_carry(diff.value, prod.lo);
        r[pos] = out.value;

        pending = prod.hi + diff.carry + out.carry;
        ++pos;
    }
    return pending;
}

void udivrem_knuth(uint64_t q[], uint64_t un[], int m, const uint64_t vn[], int n) noexcept
{
const auto divisor = uint128{vn[n - 1], vn[n - 2]};
Expand All @@ -67,67 +120,41 @@ void udivrem_knuth(uint64_t q[], uint64_t un[], int m, const uint64_t vn[], int
uint64_t qhat;
uint128 rhat;
const auto dividend = uint128{u2, u1};
if (dividend.hi >= divisor.hi) // Will overflow:
if (UNLIKELY(dividend.hi >= divisor.hi)) // Division overflows.
{
qhat = ~uint64_t{0};
rhat = dividend - uint128{divisor.hi, 0};
rhat += divisor.hi;

// Adjustment.
// OPT: This is not needed but helps avoiding negative case.
// Adjustment (not needed for correctness, but helps avoiding "add back" case).
if (rhat.hi == 0 && umul(qhat, divisor.lo) > uint128{rhat.lo, u0})
--qhat;
}
else
{
auto res = udivrem_2by1(dividend, divisor.hi, reciprocal);
const auto res = udivrem_2by1(dividend, divisor.hi, reciprocal);
qhat = res.quot;
rhat = res.rem;

if (umul(qhat, divisor.lo) > uint128{rhat.lo, u0})
const auto p = umul(qhat, divisor.lo);
if (p > uint128{rhat.lo, u0})
{
--qhat;
rhat += divisor.hi;

// Adjustment.
// OPT: This is not needed but helps avoiding negative case.
if (rhat.hi == 0 && umul(qhat, divisor.lo) > uint128{rhat.lo, u0})
// Adjustment (not needed for correctness, but helps avoiding "add back" case).
if (rhat.hi == 0 && (p - divisor.lo) > uint128{rhat.lo, u0})
--qhat;
}
}

// Multiply and subtract.
uint64_t borrow = 0;
for (int i = 0; i < n; ++i)
{
const auto p = umul(qhat, vn[i]);
const auto s = uint128{un[i + j]} - borrow - p.lo;
un[i + j] = s.lo;
borrow = p.hi - s.hi;
}

const auto borrow = submul(&un[j], &un[j], vn, n, qhat);
un[j + n] = u2 - borrow;
if (u2 < borrow) // Too much subtracted, add back.
{
--qhat;

uint64_t carry = 0;
for (int i = 0; i < n; ++i)
{
auto s = uint128{un[i + j]} + vn[i] + carry;
un[i + j] = s.lo;
carry = s.hi;
}
un[j + n] += carry;

// TODO: Consider this alternative implementation:
// bool k = false;
// for (int i = 0; i < n; ++i) {
// auto limit = std::min(un[j+i],vn[i]);
// un[i + j] += vn[i] + k;
// k = un[i + j] < limit || (k && un[i + j] == limit);
// }
// un[j+n] += k;
un[j + n] += add(&un[j], &un[j], vn, n);
}

q[j] = qhat; // Store quotient digit.
Expand Down
12 changes: 12 additions & 0 deletions test/unittests/test_div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,18 @@ static div_test_case<uint512> div_test_cases[] = {
0x10000000000000001_u512,
0x7fffffffffffffff80000000000000000000000000000000_u512,
},
{
0x00e8e8e8e2000100000009ea02000000000000ff3ffffff800000010002200000000000000000000000000000000000000000000000000000000000000000000_u512,
0x00e8e8e8e2000100000009ea02000000000000ff3ffffff800000010002280ff0000000000000000000000000000000000000000000000000000000000000000_u512,
0,
0x00e8e8e8e2000100000009ea02000000000000ff3ffffff800000010002200000000000000000000000000000000000000000000000000000000000000000000_u512,
},
{
0x000000c9700000000000000000023f00c00014ff0000000000000000223008050000000000000000000000000000000000000000000000000000000000000000_u512,
0x00000000c9700000000000000000023f00c00014ff002c0000000000002231080000000000000000000000000000000000000000000000000000000000000000_u512,
0xff,
0x00000000c9700000000000000000023f00c00014fed42c00000000000021310d0000000000000000000000000000000000000000000000000000000000000000_u512,
},
};

TEST(div, udivrem_512)
Expand Down