Replace div_two trick in invNTT.
1. I found that the div_two trick in invNTT is slow: it introduces O(N log N) extra reductions.
   It seems better to merge the n^{-1} step with the last layer of invNTT.
   Modular multiplication by a fixed value is fast via Shoup's acceleration.
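
In equations, the change amounts to the following (notation paraphrased from the code below, not wording from the commit). The old inverse butterfly folded the n^{-1} = 2^{-log2(n)} factor into every layer through the div_two tables:

    x0' = (x0 + x1) / 2 mod q,        x1' = (x0 - x1) * (w / 2) mod q

The new code runs plain butterflies, x0' = x0 + x1 and x1' = (x0 - x1) * w, in all layers but the last, and applies n^{-1} only once, in the final layer:

    x0' = (x0 + x1) * n^{-1} mod q,   x1' = (x0 - x1) * (n^{-1} w) mod q

Over log2(n) layers of n/2 butterflies each, this trades O(n log n) modular halvings for n Shoup multiplications by fixed constants.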
fionser committed Mar 24, 2020
1 parent 1561711 commit 82f595a
Showing 2 changed files with 94 additions and 155 deletions.
211 changes: 94 additions & 117 deletions native/src/seal/util/smallntt.cpp
@@ -42,8 +42,6 @@ namespace seal
scaled_root_powers_.release();
inv_root_powers_.release();
scaled_inv_root_powers_.release();
inv_root_powers_div_two_.release();
scaled_inv_root_powers_div_two_.release();
inv_degree_modulo_ = 0;
coeff_count_power_ = 0;
coeff_count_ = 0;
@@ -68,8 +66,6 @@ namespace seal
inv_root_powers_ = allocate_uint(coeff_count_, pool_);
scaled_root_powers_ = allocate_uint(coeff_count_, pool_);
scaled_inv_root_powers_ = allocate_uint(coeff_count_, pool_);
inv_root_powers_div_two_ = allocate_uint(coeff_count_, pool_);
scaled_inv_root_powers_div_two_ = allocate_uint(coeff_count_, pool_);
modulus_ = modulus;

// We defer parameter checking to try_minimal_primitive_root(...)
@@ -98,16 +94,6 @@
ntt_scale_powers_of_primitive_root(inv_root_powers_.get(),
scaled_inv_root_powers_.get());

// Populate the tables storing (scaled versions of) the powers of
// roots^-1 mod q, each divided by two, in bit-scrambled order.
for (size_t i = 0; i < coeff_count_; i++)
{
inv_root_powers_div_two_[i] =
div2_uint_mod(inv_root_powers_[i], modulus_);
}
ntt_scale_powers_of_primitive_root(inv_root_powers_div_two_.get(),
scaled_inv_root_powers_div_two_.get());

// Reorder inv_root_powers_ so that the access pattern in the inverse NTT is sequential.
std::vector<uint64_t> tmp(coeff_count_);
uint64_t *ptr = tmp.data() + 1;
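
The rest of this hunk is collapsed; a hypothetical sketch of how the reordering plausibly continues, with the loop bounds inferred from the sequential inv_root_index walk in the inverse NTT below:

    // Copy each layer's twiddle factors, previously fetched as h + i with
    // h = m/2, into consecutive slots starting at tmp[1], so that the
    // inverse NTT can walk the table with a single increasing index.
    // (std::copy_n is from <algorithm>.)
    for (size_t m = coeff_count_; m > 1; m >>= 1)
    {
        size_t h = m >> 1;
        for (size_t i = 0; i < h; i++)
        {
            *ptr++ = inv_root_powers_[h + i];
        }
    }
    std::copy_n(tmp.data(), coeff_count_, inv_root_powers_.get());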
@@ -153,18 +139,72 @@ namespace seal

// Compute floor(input * beta / p), where beta = 2^64
// and 0 < p < beta.
static inline uint64_t precompute_mulmod(uint64_t y, uint64_t p) {
uint64_t wide_quotient[2]{ 0, 0 };
uint64_t wide_coeff[2]{ 0, y };
divide_uint128_uint64_inplace(wide_coeff, p, wide_quotient);
return wide_quotient[0];
}

void SmallNTTTables::ntt_scale_powers_of_primitive_root(
const uint64_t *input, uint64_t *destination) const
{
for (size_t i = 0; i < coeff_count_; i++, input++, destination++)
{
uint64_t wide_quotient[2]{ 0, 0 };
uint64_t wide_coeff[2]{ 0, *input };
divide_uint128_uint64_inplace(wide_coeff, modulus_.value(), wide_quotient);
*destination = wide_quotient[0];
*destination = precompute_mulmod(*input, modulus_.value());
}
}

struct ntt_body {
const uint64_t modulus, two_times_modulus;
ntt_body(uint64_t modulus) : modulus(modulus), two_times_modulus(modulus << 1) {}

// x0' <- x0 + w * x1
// x1' <- x0 - w * x1
inline void forward(uint64_t *x0, uint64_t *x1, uint64_t W, uint64_t Wprime) const {
uint64_t u = *x0;
uint64_t v = mulmod_lazy(*x1, W, Wprime);

u -= select(two_times_modulus, u < two_times_modulus);
*x0 = u + v;
*x1 = u - v + two_times_modulus;
}

// x0' <- x0 + x1
// x1' <- w * (x0 - x1)
inline void backward(uint64_t *x0, uint64_t *x1, uint64_t W, uint64_t Wprime) const {
uint64_t u = *x0;
uint64_t v = *x1;
uint64_t t = u + v;
t -= select(two_times_modulus, t < two_times_modulus);

*x0 = t;
*x1 = mulmod_lazy(u - v + two_times_modulus, W, Wprime);
}

inline void backward_last(uint64_t *x0, uint64_t *x1, uint64_t inv_N, uint64_t inv_Nprime, uint64_t inv_N_W, uint64_t inv_N_Wprime) const {
uint64_t u = *x0;
uint64_t v = *x1;
uint64_t t = u + v;
t -= select(two_times_modulus, t < two_times_modulus);

*x0 = mulmod_lazy(t, inv_N, inv_Nprime);
*x1 = mulmod_lazy(u - v + two_times_modulus, inv_N_W, inv_N_Wprime);
}

// x * y mod p using Shoup's trick, i.e., yprime = floor(2^64 * y / p);
// the result is only lazily reduced, into [0, 2p)
inline uint64_t mulmod_lazy(uint64_t x, uint64_t y, uint64_t yprime) const {
unsigned long long q;
multiply_uint64_hw64(x, yprime, &q);
return x * y - q * modulus;
}

// Return 0 if cond is true, else return b
inline uint64_t select(uint64_t b, bool cond) const {
return (b & -(uint64_t) cond) ^ b;
}
};
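
Every butterfly above leans on mulmod_lazy staying below 2p even for lazy inputs as large as 4p. A standalone brute-force check of that bound for a toy prime (my own sketch, not part of the commit; the unsigned __int128 GCC/Clang extension stands in for multiply_uint64_hw64):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint64_t p = 97;
        for (uint64_t y = 0; y < p; y++) {
            // Shoup precomputation: yprime = floor(2^64 * y / p).
            uint64_t yprime = static_cast<uint64_t>(
                (static_cast<unsigned __int128>(y) << 64) / p);
            for (uint64_t x = 0; x < 4 * p; x++) {  // lazy inputs can reach [0, 4p)
                uint64_t hi = static_cast<uint64_t>(
                    (static_cast<unsigned __int128>(x) * yprime) >> 64);
                uint64_t r = x * y - hi * p;
                assert(r < 2 * p && r % p == (x * y) % p);
            }
        }
    }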

/**
This function computes in-place the negacyclic NTT. The input is
a polynomial a of degree n in R_q, where n is assumed to be a power of
@@ -178,10 +218,8 @@
void ntt_negacyclic_harvey_lazy(uint64_t *operand,
const SmallNTTTables &tables)
{
uint64_t modulus = tables.modulus().value();
uint64_t two_times_modulus = modulus * 2;
ntt_body ntt(tables.modulus().value());

// Return the NTT in scrambled order
size_t n = size_t(1) << tables.coeff_count_power();
size_t t = n >> 1;
for (size_t m = 1; m < n; m <<= 1)
@@ -197,33 +235,12 @@

uint64_t *X = operand + j1;
uint64_t *Y = X + t;
uint64_t currX;
unsigned long long Q;
for (size_t j = j1; j < j2; j += 4)
{
currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
multiply_uint64_hw64(Wprime, *Y, &Q);
Q = *Y * W - Q * modulus;
*X++ = currX + Q;
*Y++ = currX + (two_times_modulus - Q);

currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
multiply_uint64_hw64(Wprime, *Y, &Q);
Q = *Y * W - Q * modulus;
*X++ = currX + Q;
*Y++ = currX + (two_times_modulus - Q);

currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
multiply_uint64_hw64(Wprime, *Y, &Q);
Q = *Y * W - Q * modulus;
*X++ = currX + Q;
*Y++ = currX + (two_times_modulus - Q);

currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
multiply_uint64_hw64(Wprime, *Y, &Q);
Q = *Y * W - Q * modulus;
*X++ = currX + Q;
*Y++ = currX + (two_times_modulus - Q);
ntt.forward(X++, Y++, W, Wprime);
ntt.forward(X++, Y++, W, Wprime);
ntt.forward(X++, Y++, W, Wprime);
ntt.forward(X++, Y++, W, Wprime);
}
}
}
@@ -238,17 +255,9 @@

uint64_t *X = operand + j1;
uint64_t *Y = X + t;
uint64_t currX;
unsigned long long Q;
for (size_t j = j1; j < j2; j++)
{
// The Harvey butterfly: assume X, Y in [0, 2p), and return X', Y' in [0, 4p).
// X', Y' = X + WY, X - WY (mod p).
currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
multiply_uint64_hw64(Wprime, *Y, &Q);
Q = W * *Y - Q * modulus;
*X++ = currX + Q;
*Y++ = currX + (two_times_modulus - Q);
ntt.forward(X++, Y++, W, Wprime);
}
}
}
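
The lazy transform leaves coefficients in [0, 4p), so a non-lazy caller needs a final correction pass. A sketch of that pass (my reading of what a non-lazy wrapper would do, not code from this commit):

    #include <cstddef>
    #include <cstdint>

    // Reduce each lazy coefficient from [0, 4p) to [0, p); modulus is the
    // prime q and n the transform size, as in the function above.
    static void correct_after_lazy_ntt(uint64_t *operand, size_t n, uint64_t modulus)
    {
        uint64_t two_times_modulus = modulus << 1;
        for (size_t i = 0; i < n; i++)
        {
            if (operand[i] >= two_times_modulus) operand[i] -= two_times_modulus;
            if (operand[i] >= modulus) operand[i] -= modulus;
        }
    }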
@@ -259,103 +268,71 @@
// Inverse negacyclic NTT using Harvey's butterfly. (See Patrick Longa and Michael Naehrig).
void inverse_ntt_negacyclic_harvey_lazy(uint64_t *operand, const SmallNTTTables &tables)
{
uint64_t modulus = tables.modulus().value();
uint64_t two_times_modulus = modulus * 2;
ntt_body ntt(tables.modulus().value());

// The input is expected in the bit-reversed (scrambled) order produced by the forward NTT.
size_t n = size_t(1) << tables.coeff_count_power();
const size_t n = size_t(1) << tables.coeff_count_power();
size_t t = 1;

for (size_t m = n; m > 1; m >>= 1)
size_t inv_root_index = 1;
// Stop at m > 2: the last layer is handled separately below, merged with the n^{-1} step.
for (size_t m = n; m > 2; m >>= 1)
{
size_t j1 = 0;
size_t h = m >> 1;
if (t >= 4)
{
for (size_t i = 0; i < h; i++)
for (size_t i = 0; i < h; i++, ++inv_root_index)
{
size_t j2 = j1 + t;
// Need the powers of phi^{-1} in bit-reversed order
const uint64_t W = tables.get_from_inv_root_powers_div_two(h + i);
const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two(h + i);
const uint64_t W = tables.get_from_inv_root_powers(inv_root_index);
const uint64_t Wprime = tables.get_from_scaled_inv_root_powers(inv_root_index);

uint64_t *U = operand + j1;
uint64_t *V = U + t;
uint64_t currU;
uint64_t T;
unsigned long long H;
for (size_t j = j1; j < j2; j += 4)
{
T = two_times_modulus - *V + *U;
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
multiply_uint64_hw64(Wprime, T, &H);
*V++ = T * W - H * modulus;

T = two_times_modulus - *V + *U;
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
multiply_uint64_hw64(Wprime, T, &H);
*V++ = T * W - H * modulus;

T = two_times_modulus - *V + *U;
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
multiply_uint64_hw64(Wprime, T, &H);
*V++ = T * W - H * modulus;

T = two_times_modulus - *V + *U;
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
multiply_uint64_hw64(Wprime, T, &H);
*V++ = T * W - H * modulus;
ntt.backward(U++, V++, W, Wprime);
ntt.backward(U++, V++, W, Wprime);
ntt.backward(U++, V++, W, Wprime);
ntt.backward(U++, V++, W, Wprime);
}
j1 += (t << 1);
}
}
else
{
for (size_t i = 0; i < h; i++)
for (size_t i = 0; i < h; i++, ++inv_root_index)
{
size_t j2 = j1 + t;
// Need the powers of phi^{-1} in bit-reversed order
const uint64_t W = tables.get_from_inv_root_powers_div_two(h + i);
const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two(h + i);
const uint64_t W = tables.get_from_inv_root_powers(inv_root_index);
const uint64_t Wprime = tables.get_from_scaled_inv_root_powers(inv_root_index);

uint64_t *U = operand + j1;
uint64_t *V = U + t;
uint64_t currU;
uint64_t T;
unsigned long long H;
for (size_t j = j1; j < j2; j++)
{
// U = x[j], V = x[j + t]

// Compute U - V + 2q
T = two_times_modulus - *V + *U;

// Cleverly check whether currU + currV >= two_times_modulus
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));

// Need to make it so that div2_uint_mod takes values that are > q.
//div2_uint_mod(U, modulusptr, coeff_uint64_count, U);
// We use also the fact that parity of currU is same as parity of T.
// Since our modulus is always so small that currU + masked_modulus < 2^64,
// we never need to worry about wrapping around when adding masked_modulus.
//uint64_t masked_modulus = modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1));
//uint64_t carry = add_uint64(currU, masked_modulus, 0, &currU);
//currU += modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1));
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;

multiply_uint64_hw64(Wprime, T, &H);
// Effectively, the next two multiplications are performed modulo beta = 2**wordsize.
*V++ = W * T - H * modulus;
ntt.backward(U++, V++, W, Wprime);
}
j1 += (t << 1);
}
}
t <<= 1;
}

// merge n^{-1} with the last layer of invNTT
const uint64_t W = tables.get_from_inv_root_powers(inv_root_index);
const uint64_t inv_N = *(tables.get_inv_degree_modulo());
const uint64_t inv_N_W = multiply_uint_uint_mod(inv_N, W, tables.modulus());
const uint64_t inv_Nprime = precompute_mulmod(inv_N, tables.modulus().value());
const uint64_t inv_N_Wprime = precompute_mulmod(inv_N_W, tables.modulus().value());

uint64_t *U = operand;
uint64_t *V = U + (n / 2);
for (size_t j = n / 2; j < n; j++)
{
ntt.backward_last(U++, V++, inv_N, inv_Nprime, inv_N_W, inv_N_Wprime);
}
}
}
}
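
To see the merged last layer end to end, here is a toy round trip with n = 2, q = 5, psi = 2 (so psi^2 = 4 = -1 mod q): a standalone sketch mirroring ntt_body::forward and ntt_body::backward_last, not part of the commit, with unsigned __int128 standing in for SEAL's multiply_uint64_hw64 and divide_uint128_uint64_inplace.

    #include <cassert>
    #include <cstdint>

    static const uint64_t q = 5, two_q = 10;

    // floor(2^64 * y / q): the Shoup constant for multiplying by y.
    static uint64_t shoup(uint64_t y) {
        return static_cast<uint64_t>((static_cast<unsigned __int128>(y) << 64) / q);
    }

    // x * y mod q, lazily reduced into [0, 2q).
    static uint64_t mulmod_lazy(uint64_t x, uint64_t y, uint64_t yprime) {
        uint64_t hi = static_cast<uint64_t>(
            (static_cast<unsigned __int128>(x) * yprime) >> 64);
        return x * y - hi * q;
    }

    int main() {
        uint64_t a0 = 1, a1 = 2;

        // Forward butterfly with W = psi = 2: outputs land in [0, 4q).
        // (a0 < 2q here, so forward()'s conditional reduction is a no-op.)
        uint64_t W = 2;
        uint64_t v = mulmod_lazy(a1, W, shoup(W));
        uint64_t x0 = a0 + v, x1 = a0 - v + two_q;

        // Merged last inverse layer with n^{-1} = 3 and
        // n^{-1} * psi^{-1} = 3 * 3 = 9 = 4 (mod 5).
        uint64_t inv_N = 3, inv_N_W = 4;
        uint64_t t = x0 + x1;
        if (t >= two_q) t -= two_q;
        uint64_t b0 = mulmod_lazy(t, inv_N, shoup(inv_N)) % q;
        uint64_t b1 = mulmod_lazy(x0 - x1 + two_q, inv_N_W, shoup(inv_N_W)) % q;

        assert(b0 == a0 && b1 == a1);  // the round trip recovers (1, 2)
        return 0;
    }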
38 changes: 0 additions & 38 deletions native/src/seal/util/smallntt.h
@@ -113,38 +113,6 @@ namespace seal
return scaled_inv_root_powers_[index];
}

SEAL_NODISCARD inline auto get_from_inv_root_powers_div_two(
std::size_t index) const -> std::uint64_t
{
#ifdef SEAL_DEBUG
if (index >= coeff_count_)
{
throw std::out_of_range("index");
}
if (!generated_)
{
throw std::logic_error("tables are not generated");
}
#endif
return inv_root_powers_div_two_[index];
}

SEAL_NODISCARD inline auto get_from_scaled_inv_root_powers_div_two(
std::size_t index) const -> std::uint64_t
{
#ifdef SEAL_DEBUG
if (index >= coeff_count_)
{
throw std::out_of_range("index");
}
if (!generated_)
{
throw std::logic_error("tables are not generated");
}
#endif
return scaled_inv_root_powers_div_two_[index];
}

SEAL_NODISCARD inline auto get_inv_degree_modulo() const
-> const std::uint64_t*
{
@@ -203,12 +171,6 @@
// Size coeff_count_
Pointer<decltype(root_)> scaled_root_powers_;

// Size coeff_count_
Pointer<decltype(root_)> inv_root_powers_div_two_;

// Size coeff_count_
Pointer<decltype(root_)> scaled_inv_root_powers_div_two_;

int coeff_count_power_ = 0;

std::size_t coeff_count_ = 0;