Skip to content

Commit

Permalink
feat: faster square roots (#2694)
Browse files Browse the repository at this point in the history
We use the Tonelli-Shanks square root algorithm to perform square roots,
required for deriving generators on the Grumpkin curve.

Our existing implementation uses a slow algorithm that requires ~1,000
field multiplications per square root.

This PR implements a newer algorithm by Bernstein that uses precomputed
lookup tables to reduce the number of field multiplications per square root.

https://cr.yp.to/papers/sqroot-20011123-retypeset20220327.pdf
  • Loading branch information
zac-williamson authored Oct 30, 2024
1 parent cf2ca2e commit 722ec5c
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ TEST(secp256k1, TestSqr)
}
}

// Checks that fq::sqrt() succeeds on an element that is a square by construction,
// and that the returned root squares back to the original input.
TEST(secp256k1, SqrtRandom)
{
    constexpr size_t n = 1;
    for (size_t i = 0; i < n; ++i) {
        // The input is produced by squaring, so it must have a square root.
        secp256k1::fq input = secp256k1::fq::random_element().sqr();
        auto [is_sqr, root] = input.sqrt();
        // sqrt() must flag the input as a square; without this check a <false, 0> result
        // would only be caught indirectly by the comparison below.
        EXPECT_EQ(is_sqr, true);
        secp256k1::fq root_test = root.sqr();
        EXPECT_EQ(root_test, input);
    }
}

TEST(secp256k1, TestArithmetic)
{
secp256k1::fq a = secp256k1::fq::random_element();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,10 @@ template <class Params_> struct alignas(32) field {
*
* @return <true, root> if the element is a quadratic residue, <false, 0> if it's not
*/
constexpr std::pair<bool, field> sqrt() const noexcept;

constexpr std::pair<bool, field> sqrt() const noexcept
requires((Params_::modulus_0 & 0x3UL) == 0x3UL);
constexpr std::pair<bool, field> sqrt() const noexcept
requires((Params_::modulus_0 & 0x3UL) != 0x3UL);
BB_INLINE constexpr void self_neg() & noexcept;

BB_INLINE constexpr void self_to_montgomery_form() & noexcept;
Expand Down
236 changes: 161 additions & 75 deletions barretenberg/cpp/src/barretenberg/ecc/fields/field_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,102 +452,187 @@ template <class T> void field<T>::batch_invert(std::span<field> coeffs) noexcept
}
}

/**
* @brief Implements an optimised variant of Tonelli-Shanks via lookup tables.
* Algorithm taken from https://cr.yp.to/papers/sqroot-20011123-retypeset20220327.pdf
* "FASTER SQUARE ROOTS IN ANNOYING FINITE FIELDS" by D. Bernstein
* Page 5 "Accelerated Discrete Logarithm"
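* @tparam T the parameter type the field is instantiated with
* @return a square root of this element, or zero if the Euler's criterion check fails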
*/
template <class T> constexpr field<T> field<T>::tonelli_shanks_sqrt() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::tonelli_shanks_sqrt");
// The Tonelli-Shanks algorithm begins by finding an odd integer Q and an integer s,
// such that (p - 1) = Q.2^{s}

// We can compute the square root of a, by considering a^{(Q + 1) / 2} = R
// Once we have found such an R, we have
// R^{2} = a^{Q + 1} = a^{Q}a
// If a^{Q} = 1, we have found our square root.
// Otherwise, we have a^{Q} = t, where t is a 2^{s-1}'th root of unity.
// This is because t^{2^{s-1}} = a^{Q.2^{s-1}}.
// We know that (p - 1) = Q.2^{s}, therefore t^{2^{s-1}} = a^{(p - 1) / 2}
// From Euler's criterion, if a is a quadratic residue, a^{(p - 1) / 2} = 1
// i.e. t^{2^{s-1}} = 1

// To proceed with computing our square root, we want to transform t into a smaller subgroup,
// specifically, the 2^{s-2}'th roots of unity.
// We do this by finding some value b, such that
// (t.b^2)^{2^{s-2}} = 1 and R' = R.b
// Finding such a b is trivial, because from Euler's criterion, we know that,
// for any quadratic non-residue z, z^{(p - 1) / 2} = -1
// i.e. z^{Q.2^{s-1}} = -1
// => z^Q is a 2^{s-1}'th root of -1
// => z^{2Q} is a 2^{s-2}'th root of -1
// Since t^{2^{s-1}} = 1, t^{2^{s-2}} is either 1 (t is already in the smaller subgroup) or -1
// => in the latter case, t.z^{2Q} is a 2^{s-2}'th root of unity (i.e. take b = z^Q).

// We can iteratively transform t into ever smaller subgroups, until t = 1.
// At each iteration, we need to find a new value for b, which we can obtain
// by repeatedly squaring z^{Q}
constexpr uint256_t Q = (modulus - 1) >> static_cast<uint64_t>(primitive_root_log_size() - 1);
constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 2;

// __to_montgomery_form(Q_minus_one_over_two, Q_minus_one_over_two);
field z = coset_generator(0); // the generator is a non-residue
field b = pow(Q_minus_one_over_two);
field r = operator*(b); // r = a^{(Q + 1) / 2}
field t = r * b; // t = a^{(Q - 1) / 2 + (Q + 1) / 2} = a^{Q}
// We can determine s by counting the least significant set bit of `p - 1`
// We pick elements `r, g` such that g = r^Q and r is not a square.
// (the coset generators are all nonresidues and satisfy this condition)
//
// To find the square root of `u`, consider `v = u^((Q - 1) / 2)`
// There exists an integer `e` where uv^2 = g^e (see Theorem 3.1 in paper).
// If `u` is a square, `e` is even and (uvg^{-e/2})^2 = u^2v^2g^{-e} = u^{Q+1}g^{-e} = u
//
// The goal of the algorithm is twofold:
// 1. find `e` given `u`
// 2. compute `sqrt(u) = uvg^{-e/2}`
constexpr uint256_t Q = (modulus - 1) >> static_cast<uint64_t>(primitive_root_log_size());
constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 1;
field v = pow(Q_minus_one_over_two);
field uv = operator*(v); // uv = u^{(Q + 1) / 2}
// uvv = g^e for some unknown e. Goal is to find e.
field uvv = uv * v; // uvv = u^{(Q - 1) / 2 + (Q + 1) / 2} = u^{Q}

// check if t is a square with euler's criterion
// if not, we don't have a quadratic residue and a has no square root!
field check = t;
field check = uvv;
for (size_t i = 0; i < primitive_root_log_size() - 1; ++i) {
check.self_sqr();
}
if (check != one()) {
return zero();
if (check != 1) {
return 0;
}
field t1 = z.pow(Q_minus_one_over_two);
field t2 = t1 * z;
field c = t2 * t1; // z^Q

size_t m = primitive_root_log_size();
constexpr field g = coset_generator(0).pow(Q);
constexpr field g_inv = coset_generator(0).pow(modulus - 1 - Q);
constexpr size_t root_bits = primitive_root_log_size();
constexpr size_t table_bits = 6;
constexpr size_t num_tables = root_bits / table_bits + (root_bits % table_bits != 0 ? 1 : 0);
constexpr size_t num_offset_tables = num_tables - 1;
constexpr size_t table_size = static_cast<size_t>(1UL) << table_bits;

using GTable = std::array<field, table_size>;
constexpr auto get_g_table = [&](const field& h) {
GTable result;
result[0] = 1;
for (size_t i = 1; i < table_size; ++i) {
result[i] = result[i - 1] * h;
}
return result;
};
constexpr std::array<GTable, num_tables> g_tables = [&]() {
field working_base = g_inv;
std::array<GTable, num_tables> result;
for (size_t i = 0; i < num_tables; ++i) {
result[i] = get_g_table(working_base);
for (size_t j = 0; j < table_bits; ++j) {
working_base.self_sqr();
}
}
return result;
}();
constexpr std::array<GTable, num_offset_tables> offset_g_tables = [&]() {
field working_base = g_inv;
for (size_t i = 0; i < root_bits % table_bits; ++i) {
working_base.self_sqr();
}
std::array<GTable, num_offset_tables> result;
for (size_t i = 0; i < num_offset_tables; ++i) {
result[i] = get_g_table(working_base);
for (size_t j = 0; j < table_bits; ++j) {
working_base.self_sqr();
}
}
return result;
}();

constexpr GTable root_table_a = get_g_table(g.pow(1UL << ((num_tables - 1) * table_bits)));
constexpr GTable root_table_b = get_g_table(g.pow(1UL << (root_bits - table_bits)));
// compute uvv, uvv^{2^table_bits}, uvv^{2^{table_bits*2}}, ..., uvv^{2^{table_bits*(num_tables-1)}}
std::array<field, num_tables> uvv_powers;
field base = uvv;
for (size_t i = 0; i < num_tables - 1; ++i) {
uvv_powers[i] = base;
for (size_t j = 0; j < table_bits; ++j) {
base.self_sqr();
}
}
uvv_powers[num_tables - 1] = base;
std::array<size_t, num_tables> e_slices;
for (size_t i = 0; i < num_tables; ++i) {
size_t table_index = num_tables - 1 - i;
field target = uvv_powers[table_index];
for (size_t j = 0; j < i; ++j) {
size_t e_idx = num_tables - 1 - (i - 1) + j;
size_t g_idx = num_tables - 2 - j;

field g_lookup;
if (j != i - 1) {
g_lookup = offset_g_tables[g_idx - 1][e_slices[e_idx]]; // e1
} else {
g_lookup = g_tables[g_idx][e_slices[e_idx]];
}
target *= g_lookup;
}
size_t count = 0;

if (i == 0) {
for (auto& x : root_table_a) {
if (x == target) {
break;
}
count += 1;
}
} else {
for (auto& x : root_table_b) {
if (x == target) {
break;
}
count += 1;
}
}

while (t != one()) {
size_t i = 0;
field t2m = t;
ASSERT(count != table_size);
e_slices[table_index] = count;
}

// find the smallest value of m, such that t^{2^m} = 1
while (t2m != one()) {
t2m.self_sqr();
i += 1;
// We want to compute g^{-e/2} which requires computing `e/2` via our slice representation
for (size_t i = 0; i < num_tables; ++i) {
auto& e_slice = e_slices[num_tables - 1 - i];
// e_slices[num_tables - 1] is always even.
// From theorem 3.1 (https://cr.yp.to/papers/sqroot-20011123-retypeset20220327.pdf)
// if slice is odd, propagate the downshifted bit into previous slice value
if ((e_slice & 1UL) == 1UL) {
size_t borrow_value = (i == 1) ? 1UL << ((root_bits % table_bits) - 1) : (1UL << (table_bits - 1));
e_slices[num_tables - i] += borrow_value;
}
e_slice >>= 1;
}

size_t j = m - i - 1;
b = c;
while (j > 0) {
b.self_sqr();
--j;
} // b = z^2^(m-i-1)

c = b.sqr();
t = t * c;
r = r * b;
m = i;
field g_pow_minus_e_over_2 = 1;
for (size_t i = 0; i < num_tables; ++i) {
if (i == 0) {
g_pow_minus_e_over_2 *= g_tables[i][e_slices[num_tables - 1 - i]];
} else {
g_pow_minus_e_over_2 *= offset_g_tables[i - 1][e_slices[num_tables - 1 - i]];
}
}
return r;
return uv * g_pow_minus_e_over_2;
}

template <class T> constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
template <class T>
constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
requires((T::modulus_0 & 0x3UL) == 0x3UL)
{
BB_OP_COUNT_TRACK_NAME("fr::sqrt");
field root;
if constexpr ((T::modulus_0 & 0x3UL) == 0x3UL) {
constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2;
root = pow(sqrt_exponent);
} else {
root = tonelli_shanks_sqrt();
}
constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2;
field root = pow(sqrt_exponent);
if ((root * root) == (*this)) {
return std::pair<bool, field>(true, root);
}
return std::pair<bool, field>(false, field::zero());
}

} // namespace bb;
/**
 * @brief Square root overload selected when the low two bits of T::modulus_0 are not both set.
 *
 * Obtains a candidate from tonelli_shanks_sqrt() and accepts it only if multiplying it by itself
 * reproduces this element.
 *
 * @return <true, root> if the candidate squares back to this element, <false, 0> otherwise
 */
template <class T>
constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
    requires((T::modulus_0 & 0x3UL) != 0x3UL)
{
    const field candidate = tonelli_shanks_sqrt();
    const bool squares_back = (candidate * candidate) == (*this);
    if (!squares_back) {
        // The candidate is not a root of this element; report failure with a zero root.
        return std::pair<bool, field>(false, field::zero());
    }
    return std::pair<bool, field>(true, candidate);
}

template <class T> constexpr field<T> field<T>::operator/(const field& other) const noexcept
{
Expand Down Expand Up @@ -634,8 +719,8 @@ constexpr std::array<field<T>, field<T>::COSET_GENERATOR_SIZE> field<T>::compute

size_t count = 1;
while (count < n) {
// work_variable contains a new field element, and we need to test that, for all previous vector elements,
// result[i] / work_variable is not a member of our subgroup
// work_variable contains a new field element, and we need to test that, for all previous vector
// elements, result[i] / work_variable is not a member of our subgroup
field work_inverse = work_variable.invert();
bool valid = true;
for (size_t j = 0; j < count; ++j) {
Expand Down Expand Up @@ -674,8 +759,9 @@ template <class Params> void field<Params>::msgpack_pack(auto& packer) const
// The field is first converted from Montgomery form, similar to how the old format did it.
auto adjusted = from_montgomery_form();

// The data is then converted to big endian format using htonll, which stands for "host to network long long".
// This is necessary because the data will be written to a raw msgpack buffer, which requires big endian format.
// The data is then converted to big endian format using htonll, which stands for "host to network long
// long". This is necessary because the data will be written to a raw msgpack buffer, which requires big
// endian format.
uint64_t bin_data[4] = {
htonll(adjusted.data[3]), htonll(adjusted.data[2]), htonll(adjusted.data[1]), htonll(adjusted.data[0])
};
Expand All @@ -693,8 +779,8 @@ template <class Params> void field<Params>::msgpack_unpack(auto o)
// The binary data is first extracted from the msgpack object.
std::array<uint8_t, sizeof(data)> raw_data = o;

// The binary data is then read as big endian uint64_t's. This is done by casting the raw data to uint64_t* and then
// using ntohll ("network to host long long") to correct the endianness to the host's endianness.
// The binary data is then read as big endian uint64_t's. This is done by casting the raw data to uint64_t*
// and then using ntohll ("network to host long long") to correct the endianness to the host's endianness.
uint64_t* cast_data = (uint64_t*)&raw_data[0]; // NOLINT
uint64_t reversed[] = { ntohll(cast_data[3]), ntohll(cast_data[2]), ntohll(cast_data[1]), ntohll(cast_data[0]) };

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export class BarretenbergWasmMain extends BarretenbergWasmBase {
module: WebAssembly.Module,
threads = Math.min(getNumCpu(), BarretenbergWasmMain.MAX_THREADS),
logger: (msg: string) => void = debug,
initial = 30,
initial = 31,
maximum = 2 ** 16,
) {
this.logger = logger;
Expand Down
2 changes: 1 addition & 1 deletion yarn-project/foundation/src/wasm/wasm_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export class WasmModule implements IWasmModule {
* @param initMethod - Defaults to calling '_initialize'.
* @param maximum - 8192 maximum by default. 512mb.
*/
public async init(initial = 30, maximum = 8192, initMethod: string | null = '_initialize') {
public async init(initial = 31, maximum = 8192, initMethod: string | null = '_initialize') {
this.debug(
`initial mem: ${initial} pages, ${(initial * 2 ** 16) / (1024 * 1024)}mb. max mem: ${maximum} pages, ${
(maximum * 2 ** 16) / (1024 * 1024)
Expand Down

0 comments on commit 722ec5c

Please sign in to comment.