From db139794585f961572a986b4a2793f8d3429c1e7 Mon Sep 17 00:00:00 2001 From: jeanmon Date: Mon, 9 Sep 2024 06:57:48 +0000 Subject: [PATCH] 8285: unit test for dsl avm recursive verifier --- .../dsl/acir_format/acir_format_mocks.cpp | 3 + .../avm_recursion_constraint.test.cpp | 140 ++++++++++++++++++ .../recursion/avm_recursive_verifier.test.cpp | 8 +- 3 files changed, 147 insertions(+), 4 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/dsl/acir_format/avm_recursion_constraint.test.cpp diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format_mocks.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format_mocks.cpp index 2d392f0a88c..e1240d85e82 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format_mocks.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format_mocks.cpp @@ -94,6 +94,9 @@ void mock_opcode_indices(acir_format::AcirFormat& constraint_system) for (size_t i = 0; i < constraint_system.honk_recursion_constraints.size(); i++) { constraint_system.original_opcode_indices.honk_recursion_constraints.push_back(current_opcode++); } + for (size_t i = 0; i < constraint_system.avm_recursion_constraints.size(); i++) { + constraint_system.original_opcode_indices.avm_recursion_constraints.push_back(current_opcode++); + } for (size_t i = 0; i < constraint_system.ivc_recursion_constraints.size(); i++) { constraint_system.original_opcode_indices.ivc_recursion_constraints.push_back(current_opcode++); } diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/avm_recursion_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/avm_recursion_constraint.test.cpp new file mode 100644 index 00000000000..30e65fbb4b2 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/avm_recursion_constraint.test.cpp @@ -0,0 +1,140 @@ +#ifndef DISABLE_AZTEC_VM + +#include "avm_recursion_constraint.hpp" +#include "acir_format.hpp" +#include "acir_format_mocks.hpp" +#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp" +#include "barretenberg/sumcheck/instance/prover_instance.hpp" +#include "barretenberg/ultra_honk/ultra_prover.hpp" +#include "barretenberg/ultra_honk/ultra_verifier.hpp" +#include "barretenberg/vm/avm/generated/circuit_builder.hpp" +#include "barretenberg/vm/avm/generated/composer.hpp" +#include "barretenberg/vm/avm/generated/flavor.hpp" +#include "barretenberg/vm/avm/generated/prover.hpp" +#include "barretenberg/vm/avm/generated/verifier.hpp" +#include "barretenberg/vm/avm/tests/helpers.test.hpp" +#include "barretenberg/vm/avm/trace/trace.hpp" +#include "proof_surgeon.hpp" +#include +#include +#include + +using namespace acir_format; +using namespace bb; +using namespace bb::avm_trace; + +class AcirAvmRecursionConstraint : public ::testing::Test { + public: + using InnerBuilder = AvmCircuitBuilder; + using InnerProver = AvmProver; + using InnerVerifier = AvmVerifier; + + using OuterProver = UltraProver; + using OuterVerifier = UltraVerifier; + using OuterDeciderProvingKey = DeciderProvingKey_; + + using DeciderProvingKey = DeciderProvingKey_; + using OuterVerificationKey = UltraFlavor::VerificationKey; + using OuterBuilder = UltraCircuitBuilder; + + static void SetUpTestSuite() { bb::srs::init_crs_factory("../srs_db/ignition"); } + + static InnerBuilder create_inner_circuit() + { + VmPublicInputs public_inputs; + std::array kernel_inputs{}; + kernel_inputs.at(DA_GAS_LEFT_CONTEXT_INPUTS_OFFSET) = 1000000; + kernel_inputs.at(L2_GAS_LEFT_CONTEXT_INPUTS_OFFSET) = 1000000; + std::get<0>(public_inputs) = kernel_inputs; + + AvmTraceBuilder trace_builder(public_inputs); + InnerBuilder builder; + + trace_builder.op_set(0, 15, 1, AvmMemoryTag::U8); + trace_builder.op_set(0, 12, 2, AvmMemoryTag::U8); + trace_builder.op_add(0, 1, 2, 3, AvmMemoryTag::U8); + trace_builder.op_sub(0, 3, 2, 3, AvmMemoryTag::U8); + trace_builder.op_mul(0, 1, 1, 3, AvmMemoryTag::U8); + trace_builder.op_return(0, 0, 0); + auto trace = trace_builder.finalize(); // Passing true enables a longer trace with lookups + + builder.set_trace(std::move(trace)); + builder.check_circuit(); + return builder; + } + + /** + * @brief Create a circuit that recursively verifies one or more inner avm circuits + */ + static OuterBuilder create_outer_circuit(std::vector& inner_avm_circuits) + { + std::vector avm_recursion_constraints; + + SlabVector witness; + + for (auto& avm_circuit : inner_avm_circuits) { + AvmComposer composer = AvmComposer(); + InnerProver prover = composer.create_prover(avm_circuit); + InnerVerifier verifier = composer.create_verifier(avm_circuit); + + std::vector key_witnesses = verifier.key->to_field_elements(); + std::vector proof_witnesses = prover.construct_proof(); + // const size_t num_public_inputs = verifier.key->num_public_inputs; + + // Helper to append some values to the witness vector and return their corresponding indices + auto add_to_witness_and_track_indices = + [&witness](const std::vector& input) -> std::vector { + std::vector indices; + indices.reserve(input.size()); + auto witness_idx = static_cast(witness.size()); + for (const auto& value : input) { + witness.push_back(value); + indices.push_back(witness_idx++); + } + return indices; + }; + + RecursionConstraint avm_recursion_constraint{ + .key = add_to_witness_and_track_indices(key_witnesses), + .proof = add_to_witness_and_track_indices(proof_witnesses), + .public_inputs = {}, + .key_hash = 0, // not used + .proof_type = AVM, + }; + avm_recursion_constraints.push_back(avm_recursion_constraint); + } + + std::vector avm_recursion_opcode_indices(avm_recursion_constraints.size()); + std::iota(avm_recursion_opcode_indices.begin(), avm_recursion_opcode_indices.end(), 0); + + AcirFormat constraint_system; + constraint_system.varnum = static_cast(witness.size()); + constraint_system.recursive = false; + constraint_system.num_acir_opcodes = static_cast(avm_recursion_constraints.size()); + constraint_system.avm_recursion_constraints = avm_recursion_constraints; + constraint_system.original_opcode_indices = create_empty_original_opcode_indices(); + + mock_opcode_indices(constraint_system); + auto outer_circuit = create_circuit(constraint_system, /*size_hint*/ 0, witness); + return outer_circuit; + } +}; + +TEST_F(AcirAvmRecursionConstraint, TestBasicSingleAvmRecursionConstraint) +{ + std::vector layer_1_circuits; + layer_1_circuits.push_back(create_inner_circuit()); + auto layer_2_circuit = create_outer_circuit(layer_1_circuits); + + info("circuit gates = ", layer_2_circuit.get_num_gates()); + + auto proving_key = std::make_shared(layer_2_circuit); + OuterProver prover(proving_key); + info("prover gates = ", proving_key->proving_key.circuit_size); + auto proof = prover.construct_proof(); + auto verification_key = std::make_shared(proving_key->proving_key); + OuterVerifier verifier(verification_key); + EXPECT_EQ(verifier.verify_proof(proof), true); +} + +#endif // DISABLE_AZTEC_VM \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/recursion/avm_recursive_verifier.test.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/recursion/avm_recursive_verifier.test.cpp index a0a3a421b0d..946cd2e15fb 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/recursion/avm_recursive_verifier.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/recursion/avm_recursive_verifier.test.cpp @@ -35,8 +35,8 @@ class AvmRecursiveTests : public ::testing::Test { using RecursiveVerifier = AvmRecursiveVerifier_; using OuterBuilder = typename RecursiveFlavor::CircuitBuilder; - using OuterProver = UltraProver_; - using OuterVerifier = UltraVerifier_; + using OuterProver = UltraProver; + using OuterVerifier = UltraVerifier; using OuterDeciderProvingKey = DeciderProvingKey_; static void SetUpTestSuite() { bb::srs::init_crs_factory("../srs_db/ignition"); } @@ -65,8 +65,8 @@ TEST_F(AvmRecursiveTests, recursion) { AvmCircuitBuilder circuit_builder = generate_avm_circuit(); AvmComposer composer = AvmComposer(); - AvmProver prover = composer.create_prover(circuit_builder); - AvmVerifier verifier = composer.create_verifier(circuit_builder); + InnerProver prover = composer.create_prover(circuit_builder); + InnerVerifier verifier = composer.create_verifier(circuit_builder); HonkProof proof = prover.construct_proof();