Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: set denominator to 1 during verification of dsl/big-field division #5188

Merged
merged 9 commits into from
Mar 19, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ void build_constraints(Builder& builder, AcirFormat const& constraint_system, bo

// Add big_int constraints
DSLBigInts<Builder> dsl_bigints;
dsl_bigints.set_builder(&builder);
for (const auto& constraint : constraint_system.bigint_from_le_bytes_constraints) {
create_bigint_from_le_bytes_constraint(builder, constraint, dsl_bigints);
}
for (const auto& constraint : constraint_system.bigint_operations) {
create_bigint_operations_constraint<Builder>(constraint, dsl_bigints);
create_bigint_operations_constraint<Builder>(constraint, dsl_bigints, has_valid_witness_assignments);
}
for (const auto& constraint : constraint_system.bigint_to_le_bytes_constraints) {
create_bigint_to_le_bytes_constraint(builder, constraint, dsl_bigints);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "barretenberg/common/assert.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/numeric/uintx/uintx.hpp"
#include "barretenberg/stdlib/primitives/bigfield/bigfield.hpp"
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -34,14 +35,16 @@ ModulusId modulus_param_to_id(ModulusParam param)
secp256r1::FrParams::modulus_2 == param.modulus_2 && secp256r1::FrParams::modulus_3 == param.modulus_3) {
return ModulusId::SECP256R1_FR;
}

return ModulusId::UNKNOWN;
}

template void create_bigint_operations_constraint<UltraCircuitBuilder>(const BigIntOperation& input,
DSLBigInts<UltraCircuitBuilder>& dsl_bigint);
DSLBigInts<UltraCircuitBuilder>& dsl_bigint,
bool has_valid_witness_assignments);
template void create_bigint_operations_constraint<GoblinUltraCircuitBuilder>(
const BigIntOperation& input, DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint);
const BigIntOperation& input,
DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint,
bool has_valid_witness_assignments);
template void create_bigint_addition_constraint<UltraCircuitBuilder>(const BigIntOperation& input,
DSLBigInts<UltraCircuitBuilder>& dsl_bigint);
template void create_bigint_addition_constraint<GoblinUltraCircuitBuilder>(
Expand All @@ -55,9 +58,11 @@ template void create_bigint_mul_constraint<UltraCircuitBuilder>(const BigIntOper
template void create_bigint_mul_constraint<GoblinUltraCircuitBuilder>(
const BigIntOperation& input, DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint);
template void create_bigint_div_constraint<UltraCircuitBuilder>(const BigIntOperation& input,
DSLBigInts<UltraCircuitBuilder>& dsl_bigint);
template void create_bigint_div_constraint<GoblinUltraCircuitBuilder>(
const BigIntOperation& input, DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint);
DSLBigInts<UltraCircuitBuilder>& dsl_bigint,
bool has_valid_witness_assignments);
template void create_bigint_div_constraint<GoblinUltraCircuitBuilder>(const BigIntOperation& input,
DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint,
bool has_valid_witness_assignments);

template <typename Builder>
void create_bigint_addition_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigint)
Expand Down Expand Up @@ -198,8 +203,18 @@ void create_bigint_mul_constraint(const BigIntOperation& input, DSLBigInts<Build
}

template <typename Builder>
void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigint)
void create_bigint_div_constraint(const BigIntOperation& input,
DSLBigInts<Builder>& dsl_bigint,
bool has_valid_witness_assignments)
{
if (!has_valid_witness_assignments) {
// Asserts catch the case where the divisor is zero, so we need to provide a different value (1) to avoid the
// assert
std::array<uint32_t, 5> limbs_idx;
dsl_bigint.get_witness_idx_of_limbs(input.rhs, limbs_idx);
dsl_bigint.set_value(1, limbs_idx);
}

switch (dsl_bigint.get_modulus_id(input.lhs)) {
case ModulusId::BN254_FR: {
auto lhs = dsl_bigint.bn254_fr(input.lhs);
Expand Down Expand Up @@ -244,7 +259,9 @@ void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts<Build
}

template <typename Builder>
void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigint)
void create_bigint_operations_constraint(const BigIntOperation& input,
DSLBigInts<Builder>& dsl_bigint,
bool has_valid_witness_assignments)
{
switch (input.opcode) {
case BigIntOperationType::Add: {
Expand All @@ -260,7 +277,7 @@ void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInt
break;
}
case BigIntOperationType::Div: {
create_bigint_div_constraint<Builder>(input, dsl_bigint);
create_bigint_div_constraint<Builder>(input, dsl_bigint, has_valid_witness_assignments);
break;
}
default: {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/serialize/msgpack.hpp"

#include <array>
#include <cstdint>
#include <vector>

Expand Down Expand Up @@ -77,9 +79,13 @@ template <typename Builder> class DSLBigInts {
std::map<uint32_t, big_secp256r1_fq> m_secp256r1_fq;
std::map<uint32_t, big_secp256r1_fr> m_secp256r1_fr;

Builder* builder;

public:
DSLBigInts() = default;

void set_builder(Builder* ctx) { builder = ctx; }

ModulusId get_modulus_id(uint32_t bigint_id)
{
if (this->m_bn254_fq.contains(bigint_id)) {
Expand All @@ -104,6 +110,62 @@ template <typename Builder> class DSLBigInts {
return ModulusId::UNKNOWN;
}

/// Set value of the witnesses representing the bigfield element
/// so that the bigfield value is the input value.
/// The input value is decomposed into the binary basis for the binary limbs
/// The input array must be:
/// the 4 witness index of the binary limbs, and the index of the prime limb
void set_value(uint256_t value, const std::array<uint32_t, 5> limbs_idx)
{
uint256_t limb_modulus = uint256_t(1) << big_bn254_fq::NUM_LIMB_BITS;
builder->variables[limbs_idx[4]] = value;
for (uint32_t i = 0; i < 4; i++) {
uint256_t limb = value % limb_modulus;
value = (value - limb) / limb_modulus;
builder->variables[limbs_idx[i]] = limb;
}
}

/// Utility function that retrieve the witness indexes of a bigfield element
/// for use in set_value()
void get_witness_idx_of_limbs(uint32_t bigint_id, std::array<uint32_t, 5>& limbs_idx)
{
if (m_bn254_fr.contains(bigint_id)) {
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = m_bn254_fr[bigint_id].binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = m_bn254_fr[bigint_id].prime_basis_limb.witness_index;
} else if (m_bn254_fq.contains(bigint_id)) {
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = m_bn254_fq[bigint_id].binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = m_bn254_fq[bigint_id].prime_basis_limb.witness_index;
} else if (m_secp256k1_fq.contains(bigint_id)) {
auto big_field = m_secp256k1_fq[bigint_id];
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = big_field.prime_basis_limb.witness_index;
} else if (m_secp256k1_fr.contains(bigint_id)) {
auto big_field = m_secp256k1_fr[bigint_id];
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = big_field.prime_basis_limb.witness_index;
} else if (m_secp256r1_fr.contains(bigint_id)) {
auto big_field = m_secp256r1_fr[bigint_id];
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = big_field.prime_basis_limb.witness_index;
} else if (m_secp256r1_fq.contains(bigint_id)) {
auto big_field = m_secp256r1_fq[bigint_id];
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = big_field.prime_basis_limb.witness_index;
}
}
big_bn254_fr bn254_fr(uint32_t bigint_id)
{
if (this->m_bn254_fr.contains(bigint_id)) {
Expand Down Expand Up @@ -192,14 +254,14 @@ void create_bigint_to_le_bytes_constraint(Builder& builder,
DSLBigInts<Builder>& dsl_bigints);

template <typename Builder>
void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints, bool);
template <typename Builder>
void create_bigint_addition_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
template <typename Builder>
void create_bigint_sub_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
template <typename Builder>
void create_bigint_mul_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
template <typename Builder>
void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints, bool);

} // namespace acir_format
Original file line number Diff line number Diff line change
Expand Up @@ -392,4 +392,78 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse2)
EXPECT_EQ(verifier.verify_proof(proof), true);
}

TEST_F(BigIntTests, TestBigIntDIV)
{
// 6 / 3 = 2
// 6 = bigint(1) = from_bytes(w(1))
// 3 = bigint(2) = from_bytes(w(2))
// 2 = bigint(3) = to_bytes(w(3))
BigIntOperation div_constraint{
.lhs = 1,
.rhs = 2,
.result = 3,
.opcode = BigIntOperationType::Div,
};

BigIntFromLeBytes from_le_bytes_constraint_bigint1{
.inputs = { 1 },
.modulus = { 0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA,
0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF },
.result = 1,
};
BigIntFromLeBytes from_le_bytes_constraint_bigint2{
.inputs = { 2 },
.modulus = { 0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA,
0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF },
.result = 2,
};

BigIntToLeBytes result3_to_le_bytes{
.input = 3, .result = { 3 }, //
};

AcirFormat constraint_system{
.varnum = 5,
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
.sha256_constraints = {},
.sha256_compression = {},
.schnorr_constraints = {},
.ecdsa_k1_constraints = {},
.ecdsa_r1_constraints = {},
.blake2s_constraints = {},
.blake3_constraints = {},
.keccak_constraints = {},
.keccak_var_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.poseidon2_constraints = {},
.fixed_base_scalar_mul_constraints = {},
.ec_add_constraints = {},
.recursion_constraints = {},
.bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1, from_le_bytes_constraint_bigint2 },
.bigint_to_le_bytes_constraints = { result3_to_le_bytes },
.bigint_operations = { div_constraint },
.constraints = {},
.block_constraints = {},

};

WitnessVector witness{
0, 6, 3, 2, 0,
};
auto builder = create_circuit(constraint_system, /*size_hint*/ 0, witness);
auto composer = Composer();
auto prover = composer.create_ultra_with_keccak_prover(builder);
auto proof = prover.construct_proof();
EXPECT_TRUE(CircuitChecker::check(builder));

auto builder2 = create_circuit(constraint_system, /*size_hint*/ 0, WitnessVector{});
EXPECT_TRUE(CircuitChecker::check(builder));
auto verifier2 = composer.create_ultra_with_keccak_verifier(builder);
EXPECT_EQ(verifier2.verify_proof(proof), true);
}
} // namespace acir_format::tests
Loading