Skip to content

Commit

Permalink
evmmax: Make ModArith header-only and constexpr (#964)
Browse files Browse the repository at this point in the history
Move all ModArith's methods to the class definition and make them
constexpr.

Change evmmax library type to INTERFACE.

Co-authored-by: Paweł Bylica <pawel@ethereum.org>
  • Loading branch information
rodiazet and chfast authored Aug 7, 2024
1 parent 391bd64 commit b2d0672
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 139 deletions.
90 changes: 84 additions & 6 deletions include/evmmax/evmmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,112 @@ class ModArith
/// The modulus inversion, i.e. the number N' such that mod⋅N' = 2⁶⁴-1.
const uint64_t m_mod_inv;

/// Compute the modulus inverse for Montgomery multiplication, i.e. N': mod⋅N' = 2⁶⁴-1.
///
/// @param mod0 The least significant word of the modulus.
static constexpr uint64_t compute_mod_inv(uint64_t mod0) noexcept
{
// TODO: Find what is this algorithm and why it works.
uint64_t base = 0 - mod0;
uint64_t result = 1;
for (auto i = 0; i < 64; ++i)
{
result *= base;
base *= base;
}
return result;
}

/// Compute R² % mod.
static constexpr UintT compute_r_squared(const UintT& mod) noexcept
{
// R is 2^num_bits, R² is 2^(2*num_bits) and needs 2*num_bits+1 bits to represent,
// rounded to 2*num_bits+64) for intx requirements.
constexpr auto r2 = intx::uint<UintT::num_bits * 2 + 64>{1} << (UintT::num_bits * 2);
return intx::udivrem(r2, mod).rem;
}

static constexpr std::pair<uint64_t, uint64_t> addmul(
uint64_t t, uint64_t a, uint64_t b, uint64_t c) noexcept
{
const auto p = intx::umul(a, b) + t + c;
return {p[1], p[0]};
}

public:
explicit ModArith(const UintT& modulus) noexcept;
constexpr explicit ModArith(const UintT& modulus) noexcept
: mod{modulus},
m_r_squared{compute_r_squared(modulus)},
m_mod_inv{compute_mod_inv(modulus[0])}
{}

/// Converts a value to Montgomery form.
///
/// This is done by using Montgomery multiplication mul(x, R²)
/// what gives aR²R⁻¹ % mod = aR % mod.
UintT to_mont(const UintT& x) const noexcept;
constexpr UintT to_mont(const UintT& x) const noexcept { return mul(x, m_r_squared); }

/// Converts a value in Montgomery form back to normal value.
///
/// Given the x is the Montgomery form x = aR, the conversion is done by using
/// Montgomery multiplication mul(x, 1) what gives aRR⁻¹ % mod = a % mod.
UintT from_mont(const UintT& x) const noexcept;
constexpr UintT from_mont(const UintT& x) const noexcept { return mul(x, 1); }

/// Performs a Montgomery modular multiplication.
///
/// Inputs must be in Montgomery form: x = aR, y = bR.
/// This computes Montgomery multiplication xyR⁻¹ % mod what gives aRbRR⁻¹ % mod = abR % mod.
/// The result (abR) is in Montgomery form.
UintT mul(const UintT& x, const UintT& y) const noexcept;
constexpr UintT mul(const UintT& x, const UintT& y) const noexcept
{
// Coarsely Integrated Operand Scanning (CIOS) Method
// Based on 2.3.2 from
// High-Speed Algorithms & Architectures For Number-Theoretic Cryptosystems
// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf

constexpr auto S = UintT::num_words; // TODO(C++23): Make it static

intx::uint<UintT::num_bits + 64> t;
for (size_t i = 0; i != S; ++i)
{
uint64_t c = 0;
for (size_t j = 0; j != S; ++j)
std::tie(c, t[j]) = addmul(t[j], x[j], y[i], c);
auto tmp = intx::addc(t[S], c);
t[S] = tmp.value;
const auto d = tmp.carry; // TODO: Carry is 0 for sparse modulus.

const auto m = t[0] * m_mod_inv;
std::tie(c, std::ignore) = addmul(t[0], m, mod[0], 0);
for (size_t j = 1; j != S; ++j)
std::tie(c, t[j - 1]) = addmul(t[j], m, mod[j], c);
tmp = intx::addc(t[S], c);
t[S - 1] = tmp.value;
t[S] = d + tmp.carry; // TODO: Carry is 0 for sparse modulus.
}

if (t >= mod)
t -= mod;

return static_cast<UintT>(t);
}

/// Performs a modular addition. It is required that x < mod and y < mod, but x and y may be
/// but are not required to be in Montgomery form.
UintT add(const UintT& x, const UintT& y) const noexcept;
constexpr UintT add(const UintT& x, const UintT& y) const noexcept
{
const auto s = addc(x, y); // TODO: cannot overflow if modulus is sparse (e.g. 255 bits).
const auto d = subc(s.value, mod);
return (!s.carry && d.carry) ? s.value : d.value;
}

/// Performs a modular subtraction. It is required that x < mod and y < mod, but x and y may be
/// but are not required to be in Montgomery form.
UintT sub(const UintT& x, const UintT& y) const noexcept;
constexpr UintT sub(const UintT& x, const UintT& y) const noexcept
{
const auto d = subc(x, y);
const auto s = d.value + mod;
return (d.carry) ? s : d.value;
}
};
} // namespace evmmax
22 changes: 13 additions & 9 deletions lib/evmmax/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
# Copyright 2023 The evmone Authors.
# SPDX-License-Identifier: Apache-2.0

add_library(evmmax STATIC)
add_library(evmmax INTERFACE)
add_library(evmone::evmmax ALIAS evmmax)
target_compile_features(evmmax PUBLIC cxx_std_20)
target_include_directories(evmmax PUBLIC ${PROJECT_SOURCE_DIR}/include)
target_link_libraries(evmmax PUBLIC intx::intx PRIVATE evmc::evmc_cpp)
target_sources(
evmmax PRIVATE
${PROJECT_SOURCE_DIR}/include/evmmax/evmmax.hpp
evmmax.cpp
)
target_compile_features(evmmax INTERFACE cxx_std_20)
target_include_directories(evmmax INTERFACE ${PROJECT_SOURCE_DIR}/include)
target_link_libraries(evmmax INTERFACE intx::intx)

if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.19)
# We want to add the header file to the library for IDEs.
# However, cmake 3.18 does not support PRIVATE scope for INTERFACE libraries.
target_sources(
evmmax PRIVATE
${PROJECT_SOURCE_DIR}/include/evmmax/evmmax.hpp
)
endif()
117 changes: 0 additions & 117 deletions lib/evmmax/evmmax.cpp

This file was deleted.

7 changes: 3 additions & 4 deletions lib/evmone_precompiles/bn254.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@

namespace evmmax::bn254
{

namespace
{
const ModArith<uint256> Fp{FieldPrime};
const auto B = Fp.to_mont(3);
const auto B3 = Fp.to_mont(3 * 3);
constexpr ModArith Fp{FieldPrime};
constexpr auto B = Fp.to_mont(3);
constexpr auto B3 = Fp.to_mont(3 * 3);
} // namespace

bool validate(const Point& pt) noexcept
Expand Down
6 changes: 3 additions & 3 deletions lib/evmone_precompiles/secp256k1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ namespace evmmax::secp256k1
{
namespace
{
const ModArith<uint256> Fp{FieldPrime};
const auto B = Fp.to_mont(7);
const auto B3 = Fp.to_mont(7 * 3);
constexpr ModArith Fp{FieldPrime};
constexpr auto B = Fp.to_mont(7);
constexpr auto B3 = Fp.to_mont(7 * 3);

constexpr Point G{0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798_u256,
0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8_u256};
Expand Down
13 changes: 13 additions & 0 deletions test/unittests/evmmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ static auto get_test_values(const Mod& m) noexcept
};
}

[[maybe_unused]] static void constexpr_test()
{
// Make sure ModArith works in constexpr.
static constexpr ModArith m{BN254Mod};
static_assert(m.mod == BN254Mod);

static constexpr auto a = m.to_mont(3);
static constexpr auto b = m.to_mont(11);
static_assert(m.add(a, b) == m.to_mont(14));
static_assert(m.sub(a, b) == m.to_mont(BN254Mod - 8));
static_assert(m.mul(a, b) == m.to_mont(33));
}

TYPED_TEST(evmmax_test, add)
{
const TypeParam m;
Expand Down

0 comments on commit b2d0672

Please sign in to comment.