Skip to content

Commit

Permalink
initial commit. biggroup objects track whether they are points at inf…
Browse files Browse the repository at this point in the history
…inity, and have +/- methods that correctly handle points at infinity
  • Loading branch information
zac-williamson committed Apr 29, 2024
1 parent d4cb410 commit 9f6b4ef
Show file tree
Hide file tree
Showing 15 changed files with 517 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ template <typename Builder, typename T> class bigfield {

bigfield conditional_negate(const bool_t<Builder>& predicate) const;
bigfield conditional_select(const bigfield& other, const bool_t<Builder>& predicate) const;
static bigfield conditional_assign(const bool_t<Builder>& predicate, const bigfield& lhs, const bigfield& rhs)
{
return rhs.conditional_select(lhs, predicate);
}

bool_t<Builder> operator==(const bigfield& other) const;

void assert_is_in_field() const;
void assert_less_than(const uint256_t upper_limit) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,45 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
fq_ct ret = fq_ct::div_check_denominator_nonzero({}, a_ct);
EXPECT_NE(ret.get_context(), nullptr);
}

static void test_assert_equal_not_equal()
{
auto builder = Builder();
size_t num_repetitions = 10;
for (size_t i = 0; i < num_repetitions; ++i) {
fq inputs[4]{ fq::random_element(), fq::random_element(), fq::random_element(), fq::random_element() };

fq_ct a(witness_ct(&builder, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
fq_ct b(witness_ct(&builder, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
fq_ct c(witness_ct(&builder, fr(uint256_t(inputs[2]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[2]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
fq_ct d(witness_ct(&builder, fr(uint256_t(inputs[3]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[3]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

fq_ct two(witness_ct(&builder, fr(2)),
witness_ct(&builder, fr(0)),
witness_ct(&builder, fr(0)),
witness_ct(&builder, fr(0)));
fq_ct t0 = a + a;
fq_ct t1 = a * two;

t0.assert_equal(t1);
t0.assert_is_not_equal(c);
t0.assert_is_not_equal(d);
stdlib::bool_t<Builder> is_equal_a = t0 == t1;
stdlib::bool_t<Builder> is_equal_b = t0 == c;
EXPECT_TRUE(is_equal_a.get_value());
EXPECT_FALSE(is_equal_b.get_value());
}
bool result = CircuitChecker::check(builder);
EXPECT_EQ(result, true);
}
};

// Define types for which the above tests will be constructed.
Expand Down Expand Up @@ -930,6 +969,11 @@ TYPED_TEST(stdlib_bigfield, division_context)
TestFixture::test_division_context();
}

TYPED_TEST(stdlib_bigfield, assert_equal_not_equal)
{
TestFixture::test_assert_equal_not_equal();
}

// // This test was disabled before the refactor to use TYPED_TEST's/
// TEST(stdlib_bigfield, DISABLED_test_div_against_constants)
// {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,57 @@ bigfield<Builder, T> bigfield<Builder, T>::conditional_select(const bigfield& ot
return result;
}

/**
* @brief Validate whether two bigfield elements are equal to each other
* @details To evaluate whether `(a == b)`, we use result boolean `r` to evaluate the following logic:
* (n.b all algebra involving bigfield elements is done in the bigfield)
* 1. If `r == 1` , `a - b == 0`
* 2. If `r == 0`, `a - b` posesses an inverse `I` i.e. `(a - b) * I - 1 == 0`
* We efficiently evaluate this logic by evaluating a single expression `(a - b)*X = Y`
* We use conditional assignment logic to define `X, Y` to be the following:
* If `r == 1` then `X = 1, Y = 0`
* If `r == 0` then `X = I, Y = 1`
* This allows us to evaluate `operator==` using only 1 bigfield multiplication operation.
* We can check the product equals 0 or 1 by directly evaluating the binary basis/prime basis limbs of Y.
* i.e. if `r == 1` then `(a - b)*X` should have 0 for all limb values
* if `r == 0` then `(a - b)*X` should have 1 in the least significant binary basis limb and 0 elsewhere
* @tparam Builder
* @tparam T
* @param other
* @return bool_t<Builder>
*/
template <typename Builder, typename T> bool_t<Builder> bigfield<Builder, T>::operator==(const bigfield& other) const
{
Builder* ctx = context ? context : other.get_context();
auto lhs = get_value() % modulus_u512;
auto rhs = other.get_value() % modulus_u512;
bool is_equal_raw = (lhs == rhs);
bool_t<Builder> is_equal = witness_t<Builder>(ctx, is_equal_raw);

bigfield diff = (*this) - other;

// TODO: get native values efficiently (i.e. if u512 value fits in a u256, subtract off modulus until u256 fits
// into finite field)
native diff_native = native((diff.get_value() % modulus_u512).lo);
native inverse_native = is_equal_raw ? 0 : diff_native.invert();

bigfield inverse = bigfield::from_witness(ctx, inverse_native);

bigfield multiplicand = bigfield::conditional_assign(is_equal, one(), inverse);

bigfield product = diff * multiplicand;

field_t result = field_t<Builder>::conditional_assign(is_equal, 0, 1);

product.prime_basis_limb.assert_equal(result);
product.binary_basis_limbs[0].element.assert_equal(result);
product.binary_basis_limbs[1].element.assert_equal(0);
product.binary_basis_limbs[2].element.assert_equal(0);
product.binary_basis_limbs[3].element.assert_equal(0);

return is_equal;
}

/**
* REDUCTION CHECK
*
Expand Down Expand Up @@ -1747,6 +1798,7 @@ template <typename Builder, typename T> void bigfield<Builder, T>::assert_equal(
std::cerr << "bigfield: calling assert equal on 2 CONSTANT bigfield elements...is this intended?" << std::endl;
return;
} else if (other.is_constant()) {
// TODO: wtf?
// evaluate a strict equality - make sure *this is reduced first, or an honest prover
// might not be able to satisfy these constraints.
field_t<Builder> t0 = (binary_basis_limbs[0].element - other.binary_basis_limbs[0].element);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace bb::stdlib {
// ( ͡° ͜ʖ ͡°)
template <class Builder, class Fq, class Fr, class NativeGroup> class element {
public:
using bool_t = stdlib::bool_t<Builder>;

struct secp256k1_wnaf {
std::vector<field_t<Builder>> wnaf;
field_t<Builder> positive_skew;
Expand All @@ -38,27 +40,41 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
element(const Fq& x, const Fq& y);

element(const element& other);
element(element&& other);
element(element&& other) noexcept;

static element from_witness(Builder* ctx, const typename NativeGroup::affine_element& input)
{
Fq x = Fq::from_witness(ctx, input.x);
Fq y = Fq::from_witness(ctx, input.y);
element out(x, y);
element out;
if (input.is_point_at_infinity()) {
Fq x = Fq::from_witness(ctx, NativeGroup::affine_one.x);
Fq y = Fq::from_witness(ctx, NativeGroup::affine_one.y);
out.x = x;
out.y = y;
} else {
Fq x = Fq::from_witness(ctx, input.x);
Fq y = Fq::from_witness(ctx, input.y);
out.x = x;
out.y = y;
}
out.set_point_at_infinity(witness_t<Builder>(ctx, input.is_point_at_infinity()));
out.validate_on_curve();
return out;
}

void validate_on_curve() const
{
Fq b(get_context(), uint256_t(NativeGroup::curve_b));
Fq _b = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), b);
Fq _x = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), x);
Fq _y = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), y);
if constexpr (!NativeGroup::has_a) {
// we validate y^2 = x^3 + b by setting "fix_remainder_zero = true" when calling mult_madd
Fq::mult_madd({ x.sqr(), y }, { x, -y }, { b }, true);
Fq::mult_madd({ _x.sqr(), _y }, { _x, -_y }, { _b }, true);
} else {
Fq a(get_context(), uint256_t(NativeGroup::curve_a));
Fq _a = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), a);
// we validate y^2 = x^3 + ax + b by setting "fix_remainder_zero = true" when calling mult_madd
Fq::mult_madd({ x.sqr(), x, y }, { x, a, -y }, { b }, true);
Fq::mult_madd({ _x.sqr(), _x, _y }, { _x, _a, -_y }, { _b }, true);
}
}

Expand All @@ -72,7 +88,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
}

element& operator=(const element& other);
element& operator=(element&& other);
element& operator=(element&& other) noexcept;

byte_array<Builder> to_byte_array() const
{
Expand All @@ -82,6 +98,9 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
return result;
}

element checked_unconditional_add(const element& other) const;
element checked_unconditional_subtract(const element& other) const;

element operator+(const element& other) const;
element operator-(const element& other) const;
element operator-() const
Expand All @@ -100,11 +119,11 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
*this = *this - other;
return *this;
}
std::array<element, 2> add_sub(const element& other) const;
std::array<element, 2> checked_unconditional_add_sub(const element& other) const;

element operator*(const Fr& other) const;

element conditional_negate(const bool_t<Builder>& predicate) const
element conditional_negate(const bool_t& predicate) const
{
element result(*this);
result.y = result.y.conditional_negate(predicate);
Expand Down Expand Up @@ -176,9 +195,13 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {

typename NativeGroup::affine_element get_value() const
{
uint512_t x_val = x.get_value();
uint512_t y_val = y.get_value();
return typename NativeGroup::affine_element(x_val.lo, y_val.lo);
uint512_t x_val = x.get_value() % Fq::modulus_u512;
uint512_t y_val = y.get_value() % Fq::modulus_u512;
auto result = typename NativeGroup::affine_element(x_val.lo, y_val.lo);
if (is_point_at_infinity().get_value()) {
result.self_set_infinity();
}
return result;
}

// compute a multi-scalar-multiplication by creating a precomputed lookup table for each point,
Expand Down Expand Up @@ -229,7 +252,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
template <typename X = NativeGroup, typename = typename std::enable_if_t<std::is_same<X, secp256k1::g1>::value>>
static element secp256k1_ecdsa_mul(const element& pubkey, const Fr& u1, const Fr& u2);

static std::vector<bool_t<Builder>> compute_naf(const Fr& scalar, const size_t max_num_bits = 0);
static std::vector<bool_t> compute_naf(const Fr& scalar, const size_t max_num_bits = 0);

template <size_t max_num_bits = 0, size_t WNAF_SIZE = 4>
static std::vector<field_t<Builder>> compute_wnaf(const Fr& scalar);
Expand Down Expand Up @@ -265,10 +288,15 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
return nullptr;
}

bool_t is_point_at_infinity() const { return _is_infinity; }
void set_point_at_infinity(const bool_t& is_infinity) { _is_infinity = is_infinity; }

Fq x;
Fq y;

private:
bool_t _is_infinity;

template <size_t num_elements, typename = typename std::enable_if<HasPlookup<Builder>>>
static std::array<twin_rom_table<Builder>, 5> create_group_element_rom_tables(
const std::array<element, num_elements>& elements, std::array<uint256_t, 8>& limb_max);
Expand Down Expand Up @@ -367,7 +395,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
lookup_table_base(const lookup_table_base& other) = default;
lookup_table_base& operator=(const lookup_table_base& other) = default;

element get(const std::array<bool_t<Builder>, length>& bits) const;
element get(const std::array<bool_t, length>& bits) const;

element operator[](const size_t idx) const { return element_table[idx]; }

Expand Down Expand Up @@ -397,7 +425,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
lookup_table_plookup(const lookup_table_plookup& other) = default;
lookup_table_plookup& operator=(const lookup_table_plookup& other) = default;

element get(const std::array<bool_t<Builder>, length>& bits) const;
element get(const std::array<bool_t, length>& bits) const;

element operator[](const size_t idx) const { return element_table[idx]; }

Expand Down Expand Up @@ -608,7 +636,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
return chain_add_accumulator(add_accumulator[0]);
}

element::chain_add_accumulator get_chain_add_accumulator(std::vector<bool_t<Builder>>& naf_entries) const
element::chain_add_accumulator get_chain_add_accumulator(std::vector<bool_t>& naf_entries) const
{
std::vector<element> round_accumulator;
for (size_t j = 0; j < num_sixes; ++j) {
Expand Down Expand Up @@ -660,7 +688,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
return (accumulator);
}

element get(std::vector<bool_t<Builder>>& naf_entries) const
element get(std::vector<bool_t>& naf_entries) const
{
std::vector<element> round_accumulator;
for (size_t j = 0; j < num_sixes; ++j) {
Expand Down Expand Up @@ -812,21 +840,21 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
return chain_add_accumulator(add_accumulator[0]);
}

element::chain_add_accumulator get_chain_add_accumulator(std::vector<bool_t<Builder>>& naf_entries) const
element::chain_add_accumulator get_chain_add_accumulator(std::vector<bool_t>& naf_entries) const
{
std::vector<element> round_accumulator;
for (size_t j = 0; j < num_quads; ++j) {
round_accumulator.push_back(quad_tables[j].get(std::array<bool_t<Builder>, 4>{
round_accumulator.push_back(quad_tables[j].get(std::array<bool_t, 4>{
naf_entries[4 * j], naf_entries[4 * j + 1], naf_entries[4 * j + 2], naf_entries[4 * j + 3] }));
}

if (has_triple) {
round_accumulator.push_back(triple_tables[0].get(std::array<bool_t<Builder>, 3>{
round_accumulator.push_back(triple_tables[0].get(std::array<bool_t, 3>{
naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1], naf_entries[num_quads * 4 + 2] }));
}
if (has_twin) {
round_accumulator.push_back(twin_tables[0].get(
std::array<bool_t<Builder>, 2>{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1] }));
std::array<bool_t, 2>{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1] }));
}
if (has_singleton) {
round_accumulator.push_back(singletons[0].conditional_negate(naf_entries[num_points - 1]));
Expand All @@ -849,7 +877,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
return (accumulator);
}

element get(std::vector<bool_t<Builder>>& naf_entries) const
element get(std::vector<bool_t>& naf_entries) const
{
std::vector<element> round_accumulator;
for (size_t j = 0; j < num_quads; ++j) {
Expand All @@ -858,7 +886,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
}

if (has_triple) {
round_accumulator.push_back(triple_tables[0].get(std::array<bool_t<Builder>, 3>{
round_accumulator.push_back(triple_tables[0].get(std::array<bool_t, 3>{
naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1], naf_entries[num_quads * 4 + 2] }));
}
if (has_twin) {
Expand Down
Loading

0 comments on commit 9f6b4ef

Please sign in to comment.