diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp index fd2f23fe0f8..f37d1f0bdfc 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp @@ -100,11 +100,12 @@ void build_constraints(Builder& builder, AcirFormat const& constraint_system, bo // Add big_int constraints DSLBigInts dsl_bigints; + dsl_bigints.set_builder(&builder); for (const auto& constraint : constraint_system.bigint_from_le_bytes_constraints) { create_bigint_from_le_bytes_constraint(builder, constraint, dsl_bigints); } for (const auto& constraint : constraint_system.bigint_operations) { - create_bigint_operations_constraint(constraint, dsl_bigints); + create_bigint_operations_constraint(constraint, dsl_bigints, has_valid_witness_assignments); } for (const auto& constraint : constraint_system.bigint_to_le_bytes_constraints) { create_bigint_to_le_bytes_constraint(builder, constraint, dsl_bigints); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.cpp index a3b92e8626b..b97ff7db265 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.cpp @@ -2,6 +2,7 @@ #include "barretenberg/common/assert.hpp" #include "barretenberg/dsl/types.hpp" #include "barretenberg/numeric/uint256/uint256.hpp" +#include "barretenberg/numeric/uintx/uintx.hpp" #include "barretenberg/stdlib/primitives/bigfield/bigfield.hpp" #include #include @@ -34,14 +35,16 @@ ModulusId modulus_param_to_id(ModulusParam param) secp256r1::FrParams::modulus_2 == param.modulus_2 && secp256r1::FrParams::modulus_3 == param.modulus_3) { return ModulusId::SECP256R1_FR; } - return ModulusId::UNKNOWN; } template void create_bigint_operations_constraint(const BigIntOperation& input, - DSLBigInts& dsl_bigint); + DSLBigInts& dsl_bigint, + bool has_valid_witness_assignments); template void create_bigint_operations_constraint( - const BigIntOperation& input, DSLBigInts& dsl_bigint); + const BigIntOperation& input, + DSLBigInts& dsl_bigint, + bool has_valid_witness_assignments); template void create_bigint_addition_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigint); template void create_bigint_addition_constraint( @@ -55,9 +58,11 @@ template void create_bigint_mul_constraint(const BigIntOper template void create_bigint_mul_constraint( const BigIntOperation& input, DSLBigInts& dsl_bigint); template void create_bigint_div_constraint(const BigIntOperation& input, - DSLBigInts& dsl_bigint); -template void create_bigint_div_constraint( - const BigIntOperation& input, DSLBigInts& dsl_bigint); + DSLBigInts& dsl_bigint, + bool has_valid_witness_assignments); +template void create_bigint_div_constraint(const BigIntOperation& input, + DSLBigInts& dsl_bigint, + bool has_valid_witness_assignments); template void create_bigint_addition_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigint) @@ -198,8 +203,18 @@ void create_bigint_mul_constraint(const BigIntOperation& input, DSLBigInts -void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigint) +void create_bigint_div_constraint(const BigIntOperation& input, + DSLBigInts& dsl_bigint, + bool has_valid_witness_assignments) { + if (!has_valid_witness_assignments) { + // Asserts catch the case where the divisor is zero, so we need to provide a different value (1) to avoid the + // assert + std::array limbs_idx; + dsl_bigint.get_witness_idx_of_limbs(input.rhs, limbs_idx); + dsl_bigint.set_value(1, limbs_idx); + } + switch (dsl_bigint.get_modulus_id(input.lhs)) { case ModulusId::BN254_FR: { auto lhs = dsl_bigint.bn254_fr(input.lhs); @@ -244,7 +259,9 @@ void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts -void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigint) +void create_bigint_operations_constraint(const BigIntOperation& input, + DSLBigInts& dsl_bigint, + bool has_valid_witness_assignments) { switch (input.opcode) { case BigIntOperationType::Add: { @@ -260,7 +277,7 @@ void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInt break; } case BigIntOperationType::Div: { - create_bigint_div_constraint(input, dsl_bigint); + create_bigint_div_constraint(input, dsl_bigint, has_valid_witness_assignments); break; } default: { diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.hpp index 27e9353efb4..9e6e4f087fd 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.hpp @@ -1,6 +1,8 @@ #pragma once #include "barretenberg/dsl/types.hpp" #include "barretenberg/serialize/msgpack.hpp" + +#include #include #include @@ -77,9 +79,13 @@ template class DSLBigInts { std::map m_secp256r1_fq; std::map m_secp256r1_fr; + Builder* builder; + public: DSLBigInts() = default; + void set_builder(Builder* ctx) { builder = ctx; } + ModulusId get_modulus_id(uint32_t bigint_id) { if (this->m_bn254_fq.contains(bigint_id)) { @@ -104,6 +110,62 @@ template class DSLBigInts { return ModulusId::UNKNOWN; } + /// Set value of the witnesses representing the bigfield element + /// so that the bigfield value is the input value. + /// The input value is decomposed into the binary basis for the binary limbs + /// The input array must be: + /// the 4 witness index of the binary limbs, and the index of the prime limb + void set_value(uint256_t value, const std::array limbs_idx) + { + uint256_t limb_modulus = uint256_t(1) << big_bn254_fq::NUM_LIMB_BITS; + builder->variables[limbs_idx[4]] = value; + for (uint32_t i = 0; i < 4; i++) { + uint256_t limb = value % limb_modulus; + value = (value - limb) / limb_modulus; + builder->variables[limbs_idx[i]] = limb; + } + } + + /// Utility function that retrieve the witness indexes of a bigfield element + /// for use in set_value() + void get_witness_idx_of_limbs(uint32_t bigint_id, std::array& limbs_idx) + { + if (m_bn254_fr.contains(bigint_id)) { + for (uint32_t i = 0; i < 4; i++) { + limbs_idx[i] = m_bn254_fr[bigint_id].binary_basis_limbs[i].element.witness_index; + } + limbs_idx[4] = m_bn254_fr[bigint_id].prime_basis_limb.witness_index; + } else if (m_bn254_fq.contains(bigint_id)) { + for (uint32_t i = 0; i < 4; i++) { + limbs_idx[i] = m_bn254_fq[bigint_id].binary_basis_limbs[i].element.witness_index; + } + limbs_idx[4] = m_bn254_fq[bigint_id].prime_basis_limb.witness_index; + } else if (m_secp256k1_fq.contains(bigint_id)) { + auto big_field = m_secp256k1_fq[bigint_id]; + for (uint32_t i = 0; i < 4; i++) { + limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index; + } + limbs_idx[4] = big_field.prime_basis_limb.witness_index; + } else if (m_secp256k1_fr.contains(bigint_id)) { + auto big_field = m_secp256k1_fr[bigint_id]; + for (uint32_t i = 0; i < 4; i++) { + limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index; + } + limbs_idx[4] = big_field.prime_basis_limb.witness_index; + } else if (m_secp256r1_fr.contains(bigint_id)) { + auto big_field = m_secp256r1_fr[bigint_id]; + for (uint32_t i = 0; i < 4; i++) { + limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index; + } + limbs_idx[4] = big_field.prime_basis_limb.witness_index; + } else if (m_secp256r1_fq.contains(bigint_id)) { + auto big_field = m_secp256r1_fq[bigint_id]; + for (uint32_t i = 0; i < 4; i++) { + limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index; + } + limbs_idx[4] = big_field.prime_basis_limb.witness_index; + } + } big_bn254_fr bn254_fr(uint32_t bigint_id) { if (this->m_bn254_fr.contains(bigint_id)) { @@ -192,7 +254,7 @@ void create_bigint_to_le_bytes_constraint(Builder& builder, DSLBigInts& dsl_bigints); template -void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigints); +void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigints, bool); template void create_bigint_addition_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigints); template @@ -200,6 +262,6 @@ void create_bigint_sub_constraint(const BigIntOperation& input, DSLBigInts void create_bigint_mul_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigints); template -void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigints); +void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts& dsl_bigints, bool); } // namespace acir_format \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.test.cpp index 02cd777e5bc..550ee4a7b40 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/bigint_constraint.test.cpp @@ -392,4 +392,78 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse2) EXPECT_EQ(verifier.verify_proof(proof), true); } +TEST_F(BigIntTests, TestBigIntDIV) +{ + // 6 / 3 = 2 + // 6 = bigint(1) = from_bytes(w(1)) + // 3 = bigint(2) = from_bytes(w(2)) + // 2 = bigint(3) = to_bytes(w(3)) + BigIntOperation div_constraint{ + .lhs = 1, + .rhs = 2, + .result = 3, + .opcode = BigIntOperationType::Div, + }; + + BigIntFromLeBytes from_le_bytes_constraint_bigint1{ + .inputs = { 1 }, + .modulus = { 0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA, + 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF }, + .result = 1, + }; + BigIntFromLeBytes from_le_bytes_constraint_bigint2{ + .inputs = { 2 }, + .modulus = { 0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA, + 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF }, + .result = 2, + }; + + BigIntToLeBytes result3_to_le_bytes{ + .input = 3, .result = { 3 }, // + }; + + AcirFormat constraint_system{ + .varnum = 5, + .recursive = false, + .public_inputs = {}, + .logic_constraints = {}, + .range_constraints = {}, + .sha256_constraints = {}, + .sha256_compression = {}, + .schnorr_constraints = {}, + .ecdsa_k1_constraints = {}, + .ecdsa_r1_constraints = {}, + .blake2s_constraints = {}, + .blake3_constraints = {}, + .keccak_constraints = {}, + .keccak_var_constraints = {}, + .keccak_permutations = {}, + .pedersen_constraints = {}, + .pedersen_hash_constraints = {}, + .poseidon2_constraints = {}, + .fixed_base_scalar_mul_constraints = {}, + .ec_add_constraints = {}, + .recursion_constraints = {}, + .bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1, from_le_bytes_constraint_bigint2 }, + .bigint_to_le_bytes_constraints = { result3_to_le_bytes }, + .bigint_operations = { div_constraint }, + .constraints = {}, + .block_constraints = {}, + + }; + + WitnessVector witness{ + 0, 6, 3, 2, 0, + }; + auto builder = create_circuit(constraint_system, /*size_hint*/ 0, witness); + auto composer = Composer(); + auto prover = composer.create_ultra_with_keccak_prover(builder); + auto proof = prover.construct_proof(); + EXPECT_TRUE(CircuitChecker::check(builder)); + + auto builder2 = create_circuit(constraint_system, /*size_hint*/ 0, WitnessVector{}); + EXPECT_TRUE(CircuitChecker::check(builder)); + auto verifier2 = composer.create_ultra_with_keccak_verifier(builder); + EXPECT_EQ(verifier2.verify_proof(proof), true); +} } // namespace acir_format::tests \ No newline at end of file