From 79beef1fdcd40bf0906275bd615441b05d7b9bae Mon Sep 17 00:00:00 2001
From: Ilyas Ridhuan
Date: Thu, 23 May 2024 16:14:19 +0100
Subject: [PATCH] test(avm): AVM Minimal lookup table for testing (#6641)

---
 .../barretenberg/vm/avm_trace/avm_trace.cpp   | 248 +++++++++++++-----
 .../barretenberg/vm/avm_trace/avm_trace.hpp   |   2 +-
 .../vm/tests/avm_inter_table.test.cpp         |   9 +-
 .../barretenberg/vm/tests/helpers.test.cpp    |   1 -
 4 files changed, 185 insertions(+), 75 deletions(-)

diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
index 0c5ee151acc..b05968b77ca 100644
--- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
+++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
@@ -1887,6 +1887,102 @@ void AvmTraceBuilder::finalise_mem_trace_lookup_counts()
     }
 }
 
+// WARNING: FOR TESTING ONLY
+// Generates the minimal lookup table for the binary trace
+uint32_t finalize_bin_trace_lookup_for_testing(std::vector<Row>& main_trace, AvmBinaryTraceBuilder& bin_trace_builder)
+{
+    // Generate ByteLength Lookup table of instruction tags to the number of bytes
+    // {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
+    for (auto const& [clk, count] : bin_trace_builder.byte_operation_counter) {
+        // from the clk we can derive the a and b inputs
+        auto b = static_cast<uint8_t>(clk);
+        auto a = static_cast<uint8_t>(clk >> 8);
+        auto op_id = static_cast<uint8_t>(clk >> 16);
+        uint8_t bit_op = 0;
+        if (op_id == 0) {
+            bit_op = a & b;
+        } else if (op_id == 1) {
+            bit_op = a | b;
+        } else {
+            bit_op = a ^ b;
+        }
+        if (clk > (main_trace.size() - 1)) {
+            main_trace.push_back(Row{
+                .avm_main_clk = FF(clk),
+                .avm_byte_lookup_bin_sel = FF(1),
+                .avm_byte_lookup_table_input_a = a,
+                .avm_byte_lookup_table_input_b = b,
+                .avm_byte_lookup_table_op_id = op_id,
+                .avm_byte_lookup_table_output = bit_op,
+                .lookup_byte_operations_counts = count,
+            });
+        } else {
+            main_trace.at(clk).lookup_byte_operations_counts = count;
+            main_trace.at(clk).avm_byte_lookup_bin_sel = FF(1);
+            main_trace.at(clk).avm_byte_lookup_table_op_id = op_id;
+            main_trace.at(clk).avm_byte_lookup_table_input_a = a;
+            main_trace.at(clk).avm_byte_lookup_table_input_b = b;
+            main_trace.at(clk).avm_byte_lookup_table_output = bit_op;
+        }
+        // Add the counter value stored throughout the execution
+    }
+    return static_cast<uint32_t>(main_trace.size());
+}
+
+// WARNING: FOR TESTING ONLY
+// Generates the lookup table for the range checks without doing a full 2**16 rows
+uint32_t finalize_rng_chks_for_testing(std::vector<Row>& main_trace,
+                                       AvmAluTraceBuilder& alu_trace_builder,
+                                       AvmMemTraceBuilder& mem_trace_builder,
+                                       const std::unordered_map<uint16_t, uint32_t>& mem_rng_check_lo_counts,
+                                       const std::unordered_map<uint16_t, uint32_t>& mem_rng_check_mid_counts,
+                                       std::unordered_map<uint8_t, uint32_t> mem_rng_check_hi_counts)
+{
+    // Build the main_trace, and add any new rows with specific clks that line up with lookup reads
+
+    // Is there a "spread-like" operator in cpp, or can I make this generic over the first param of the unordered map?
+    std::vector<std::unordered_map<uint8_t, uint32_t>> u8_rng_chks = { alu_trace_builder.u8_range_chk_counters[0],
+                                                                       alu_trace_builder.u8_range_chk_counters[1],
+                                                                       alu_trace_builder.u8_pow_2_counters[0],
+                                                                       alu_trace_builder.u8_pow_2_counters[1],
+                                                                       std::move(mem_rng_check_hi_counts) };
+
+    auto custom_clk = std::set<uint32_t>{};
+    for (auto const& row : u8_rng_chks) {
+        for (auto const& [key, value] : row) {
+            custom_clk.insert(key);
+        }
+    }
+    for (auto const& row : alu_trace_builder.u16_range_chk_counters) {
+        for (auto const& [key, value] : row) {
+            custom_clk.insert(key);
+        }
+    }
+    for (auto const& row : alu_trace_builder.div_u64_range_chk_counters) {
+        for (auto const& [key, value] : row) {
+            custom_clk.insert(key);
+        }
+    }
+    for (auto const& [key, value] : mem_rng_check_lo_counts) {
+        custom_clk.insert(key);
+    }
+    for (auto const& [key, value] : mem_rng_check_mid_counts) {
+        custom_clk.insert(key);
+    }
+
+    for (auto const& [clk, count] : mem_trace_builder.m_tag_err_lookup_counts) {
+        custom_clk.insert(clk);
+    }
+
+    auto old_size = main_trace.size() - 1;
+    for (auto const& clk : custom_clk) {
+        if (clk > old_size) {
+            main_trace.push_back(Row{ .avm_main_clk = FF(clk) });
+        }
+    }
+    return static_cast<uint32_t>(main_trace.size());
+}
+
 /**
  * @brief Finalisation of the memory trace and incorporating it to the main trace.
  * In particular, sorting the memory trace, setting .m_lastAccess and
@@ -1895,10 +1991,8 @@ void AvmTraceBuilder::finalise_mem_trace_lookup_counts()
  *
  * @return The main trace
  */
-std::vector<Row> AvmTraceBuilder::finalize()
+std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_check_required)
 {
-    // bool const range_check_required = alu_trace_builder.is_range_check_required();
-    bool const range_check_required = true;
     auto mem_trace = mem_trace_builder.finalize();
     auto alu_trace = alu_trace_builder.finalize();
     auto conv_trace = conversion_trace_builder.finalize();
@@ -1910,7 +2004,9 @@ std::vector<Row> AvmTraceBuilder::finalize()
     size_t bin_trace_size = bin_trace.size();
 
     // Get tag_err counts from the mem_trace_builder
-    finalise_mem_trace_lookup_counts();
+    if (range_check_required) {
+        finalise_mem_trace_lookup_counts();
+    }
 
     // Data structure to collect all lookup counts pertaining to 16-bit/32-bit range checks in memory trace
     std::unordered_map<uint16_t, uint32_t> mem_rng_check_lo_counts;
@@ -1920,10 +2016,10 @@ std::vector<Row> AvmTraceBuilder::finalize()
     // Main Trace needs to be at least as big as the biggest subtrace.
     // If the bin_trace_size has entries, we need the main_trace to be as big as our byte lookup table (3 *
     // 2**16 long)
-    size_t const lookup_table_size = bin_trace_size > 0 ? 3 * (1 << 16) : 0;
+    size_t const lookup_table_size = (bin_trace_size > 0 && range_check_required) ? 3 * (1 << 16) : 0;
     size_t const range_check_size = range_check_required ? UINT16_MAX + 1 : 0;
-    std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size, lookup_table_size,
-                                        range_check_size, conv_trace_size, KERNEL_INPUTS_LENGTH };
+    std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size, lookup_table_size,
+                                        range_check_size, conv_trace_size, KERNEL_INPUTS_LENGTH, min_trace_size };
     auto trace_size = std::max_element(trace_sizes.begin(), trace_sizes.end());
 
     // We only need to pad with zeroes to the size to the largest trace here, pow_2 padding is handled in the
@@ -2185,7 +2281,16 @@ std::vector<Row> AvmTraceBuilder::finalize()
         }
     }
 
-    for (size_t i = 0; i < main_trace_size; i++) {
+    auto new_trace_size = range_check_required ? main_trace_size
main_trace_size + : finalize_rng_chks_for_testing(main_trace, + alu_trace_builder, + mem_trace_builder, + mem_rng_check_lo_counts, + mem_rng_check_mid_counts, + mem_rng_check_hi_counts); + auto old_trace_size = main_trace_size - 1; + + for (size_t i = 0; i < new_trace_size; i++) { auto& r = main_trace.at(i); if ((r.avm_main_sel_op_add == FF(1) || r.avm_main_sel_op_sub == FF(1) || r.avm_main_sel_op_mul == FF(1) || @@ -2202,49 +2307,49 @@ std::vector AvmTraceBuilder::finalize() r.avm_main_space_id = r.avm_main_call_ptr; }; - if (i <= UINT8_MAX) { - r.lookup_u8_0_counts = alu_trace_builder.u8_range_chk_counters[0][static_cast(i)]; - r.lookup_u8_1_counts = alu_trace_builder.u8_range_chk_counters[1][static_cast(i)]; - r.lookup_pow_2_0_counts = alu_trace_builder.u8_pow_2_counters[0][static_cast(i)]; - r.lookup_pow_2_1_counts = alu_trace_builder.u8_pow_2_counters[1][static_cast(i)]; - r.lookup_mem_rng_chk_hi_counts = mem_rng_check_hi_counts[static_cast(i)]; + r.avm_main_clk = i > old_trace_size ? r.avm_main_clk : FF(i); + auto counter = i > old_trace_size ? static_cast(r.avm_main_clk) : static_cast(i); + r.incl_main_tag_err_counts = mem_trace_builder.m_tag_err_lookup_counts[static_cast(counter)]; + if (counter <= UINT8_MAX) { + r.lookup_u8_0_counts = alu_trace_builder.u8_range_chk_counters[0][static_cast(counter)]; + r.lookup_u8_1_counts = alu_trace_builder.u8_range_chk_counters[1][static_cast(counter)]; + r.lookup_pow_2_0_counts = alu_trace_builder.u8_pow_2_counters[0][static_cast(counter)]; + r.lookup_pow_2_1_counts = alu_trace_builder.u8_pow_2_counters[1][static_cast(counter)]; + r.lookup_mem_rng_chk_hi_counts = mem_rng_check_hi_counts[static_cast(counter)]; r.avm_main_sel_rng_8 = FF(1); - r.avm_main_table_pow_2 = uint256_t(1) << uint256_t(i); + r.avm_main_table_pow_2 = uint256_t(1) << uint256_t(counter); } - - if (i <= UINT16_MAX) { + if (counter <= UINT16_MAX) { // We add to the clk here in case our trace is smaller than our range checks // There might be a cleaner way to do this in the future as this only applies // when our trace (excluding range checks) is < 2**16 - r.lookup_u16_0_counts = alu_trace_builder.u16_range_chk_counters[0][static_cast(i)]; - r.lookup_u16_1_counts = alu_trace_builder.u16_range_chk_counters[1][static_cast(i)]; - r.lookup_u16_2_counts = alu_trace_builder.u16_range_chk_counters[2][static_cast(i)]; - r.lookup_u16_3_counts = alu_trace_builder.u16_range_chk_counters[3][static_cast(i)]; - r.lookup_u16_4_counts = alu_trace_builder.u16_range_chk_counters[4][static_cast(i)]; - r.lookup_u16_5_counts = alu_trace_builder.u16_range_chk_counters[5][static_cast(i)]; - r.lookup_u16_6_counts = alu_trace_builder.u16_range_chk_counters[6][static_cast(i)]; - r.lookup_u16_7_counts = alu_trace_builder.u16_range_chk_counters[7][static_cast(i)]; - r.lookup_u16_8_counts = alu_trace_builder.u16_range_chk_counters[8][static_cast(i)]; - r.lookup_u16_9_counts = alu_trace_builder.u16_range_chk_counters[9][static_cast(i)]; - r.lookup_u16_10_counts = alu_trace_builder.u16_range_chk_counters[10][static_cast(i)]; - r.lookup_u16_11_counts = alu_trace_builder.u16_range_chk_counters[11][static_cast(i)]; - r.lookup_u16_12_counts = alu_trace_builder.u16_range_chk_counters[12][static_cast(i)]; - r.lookup_u16_13_counts = alu_trace_builder.u16_range_chk_counters[13][static_cast(i)]; - r.lookup_u16_14_counts = alu_trace_builder.u16_range_chk_counters[14][static_cast(i)]; - - r.lookup_mem_rng_chk_mid_counts = mem_rng_check_mid_counts[static_cast(i)]; - r.lookup_mem_rng_chk_lo_counts = 
-
-            r.lookup_div_u16_0_counts = alu_trace_builder.div_u64_range_chk_counters[0][static_cast<uint16_t>(i)];
-            r.lookup_div_u16_1_counts = alu_trace_builder.div_u64_range_chk_counters[1][static_cast<uint16_t>(i)];
-            r.lookup_div_u16_2_counts = alu_trace_builder.div_u64_range_chk_counters[2][static_cast<uint16_t>(i)];
-            r.lookup_div_u16_3_counts = alu_trace_builder.div_u64_range_chk_counters[3][static_cast<uint16_t>(i)];
-            r.lookup_div_u16_4_counts = alu_trace_builder.div_u64_range_chk_counters[4][static_cast<uint16_t>(i)];
-            r.lookup_div_u16_5_counts = alu_trace_builder.div_u64_range_chk_counters[5][static_cast<uint16_t>(i)];
-            r.lookup_div_u16_6_counts = alu_trace_builder.div_u64_range_chk_counters[6][static_cast<uint16_t>(i)];
-            r.lookup_div_u16_7_counts = alu_trace_builder.div_u64_range_chk_counters[7][static_cast<uint16_t>(i)];
-
-            r.avm_main_clk = FF(static_cast<uint32_t>(i));
+            r.lookup_u16_0_counts = alu_trace_builder.u16_range_chk_counters[0][static_cast<uint16_t>(counter)];
+            r.lookup_u16_1_counts = alu_trace_builder.u16_range_chk_counters[1][static_cast<uint16_t>(counter)];
+            r.lookup_u16_2_counts = alu_trace_builder.u16_range_chk_counters[2][static_cast<uint16_t>(counter)];
+            r.lookup_u16_3_counts = alu_trace_builder.u16_range_chk_counters[3][static_cast<uint16_t>(counter)];
+            r.lookup_u16_4_counts = alu_trace_builder.u16_range_chk_counters[4][static_cast<uint16_t>(counter)];
+            r.lookup_u16_5_counts = alu_trace_builder.u16_range_chk_counters[5][static_cast<uint16_t>(counter)];
+            r.lookup_u16_6_counts = alu_trace_builder.u16_range_chk_counters[6][static_cast<uint16_t>(counter)];
+            r.lookup_u16_7_counts = alu_trace_builder.u16_range_chk_counters[7][static_cast<uint16_t>(counter)];
+            r.lookup_u16_8_counts = alu_trace_builder.u16_range_chk_counters[8][static_cast<uint16_t>(counter)];
+            r.lookup_u16_9_counts = alu_trace_builder.u16_range_chk_counters[9][static_cast<uint16_t>(counter)];
+            r.lookup_u16_10_counts = alu_trace_builder.u16_range_chk_counters[10][static_cast<uint16_t>(counter)];
+            r.lookup_u16_11_counts = alu_trace_builder.u16_range_chk_counters[11][static_cast<uint16_t>(counter)];
+            r.lookup_u16_12_counts = alu_trace_builder.u16_range_chk_counters[12][static_cast<uint16_t>(counter)];
+            r.lookup_u16_13_counts = alu_trace_builder.u16_range_chk_counters[13][static_cast<uint16_t>(counter)];
+            r.lookup_u16_14_counts = alu_trace_builder.u16_range_chk_counters[14][static_cast<uint16_t>(counter)];
+
+            r.lookup_mem_rng_chk_mid_counts = mem_rng_check_mid_counts[static_cast<uint16_t>(counter)];
+            r.lookup_mem_rng_chk_lo_counts = mem_rng_check_lo_counts[static_cast<uint16_t>(counter)];
+
+            r.lookup_div_u16_0_counts = alu_trace_builder.div_u64_range_chk_counters[0][static_cast<uint16_t>(counter)];
+            r.lookup_div_u16_1_counts = alu_trace_builder.div_u64_range_chk_counters[1][static_cast<uint16_t>(counter)];
+            r.lookup_div_u16_2_counts = alu_trace_builder.div_u64_range_chk_counters[2][static_cast<uint16_t>(counter)];
+            r.lookup_div_u16_3_counts = alu_trace_builder.div_u64_range_chk_counters[3][static_cast<uint16_t>(counter)];
+            r.lookup_div_u16_4_counts = alu_trace_builder.div_u64_range_chk_counters[4][static_cast<uint16_t>(counter)];
+            r.lookup_div_u16_5_counts = alu_trace_builder.div_u64_range_chk_counters[5][static_cast<uint16_t>(counter)];
+            r.lookup_div_u16_6_counts = alu_trace_builder.div_u64_range_chk_counters[6][static_cast<uint16_t>(counter)];
+            r.lookup_div_u16_7_counts = alu_trace_builder.div_u64_range_chk_counters[7][static_cast<uint16_t>(counter)];
             r.avm_main_sel_rng_16 = FF(1);
         }
     }
@@ -2281,29 +2386,33 @@ std::vector<Row> AvmTraceBuilder::finalize()
 
     // Only generate precomputed byte tables if we are actually going to use them in this main trace.
     if (bin_trace_size > 0) {
-        // Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id.
-        for (size_t op_id = 0; op_id < 3; op_id++) {
-            for (size_t input_a = 0; input_a <= UINT8_MAX; input_a++) {
-                for (size_t input_b = 0; input_b <= UINT8_MAX; input_b++) {
-                    auto a = static_cast<uint8_t>(input_a);
-                    auto b = static_cast<uint8_t>(input_b);
-
-                    // Derive a unique row index given op_id, a, and b.
-                    auto main_trace_index = static_cast<uint32_t>((op_id << 16) + (input_a << 8) + b);
-
-                    main_trace.at(main_trace_index).avm_byte_lookup_bin_sel = FF(1);
-                    main_trace.at(main_trace_index).avm_byte_lookup_table_op_id = op_id;
-                    main_trace.at(main_trace_index).avm_byte_lookup_table_input_a = a;
-                    main_trace.at(main_trace_index).avm_byte_lookup_table_input_b = b;
-                    // Add the counter value stored throughout the execution
-                    main_trace.at(main_trace_index).lookup_byte_operations_counts =
-                        bin_trace_builder.byte_operation_counter[main_trace_index];
-                    if (op_id == 0) {
-                        main_trace.at(main_trace_index).avm_byte_lookup_table_output = a & b;
-                    } else if (op_id == 1) {
-                        main_trace.at(main_trace_index).avm_byte_lookup_table_output = a | b;
-                    } else {
-                        main_trace.at(main_trace_index).avm_byte_lookup_table_output = a ^ b;
+        if (!range_check_required) {
+            finalize_bin_trace_lookup_for_testing(main_trace, bin_trace_builder);
+        } else {
+            // Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id.
+            for (size_t op_id = 0; op_id < 3; op_id++) {
+                for (size_t input_a = 0; input_a <= UINT8_MAX; input_a++) {
+                    for (size_t input_b = 0; input_b <= UINT8_MAX; input_b++) {
+                        auto a = static_cast<uint8_t>(input_a);
+                        auto b = static_cast<uint8_t>(input_b);
+
+                        // Derive a unique row index given op_id, a, and b.
+                        auto main_trace_index = static_cast<uint32_t>((op_id << 16) + (input_a << 8) + b);
+
+                        main_trace.at(main_trace_index).avm_byte_lookup_bin_sel = FF(1);
+                        main_trace.at(main_trace_index).avm_byte_lookup_table_op_id = op_id;
+                        main_trace.at(main_trace_index).avm_byte_lookup_table_input_a = a;
+                        main_trace.at(main_trace_index).avm_byte_lookup_table_input_b = b;
+                        // Add the counter value stored throughout the execution
+                        main_trace.at(main_trace_index).lookup_byte_operations_counts =
+                            bin_trace_builder.byte_operation_counter[main_trace_index];
+                        if (op_id == 0) {
+                            main_trace.at(main_trace_index).avm_byte_lookup_table_output = a & b;
+                        } else if (op_id == 1) {
+                            main_trace.at(main_trace_index).avm_byte_lookup_table_output = a | b;
+                        } else {
+                            main_trace.at(main_trace_index).avm_byte_lookup_table_output = a ^ b;
+                        }
                     }
                 }
             }
@@ -2313,6 +2422,7 @@ std::vector<Row> AvmTraceBuilder::finalize()
         for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
             // The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range
             // [1,5]
+            main_trace.at(avm_in_tag).avm_byte_lookup_bin_sel = FF(1);
             main_trace.at(avm_in_tag).avm_byte_lookup_table_in_tags = avm_in_tag + 1;
             main_trace.at(avm_in_tag).avm_byte_lookup_table_byte_lengths = static_cast<uint8_t>(pow(2, avm_in_tag));
             main_trace.at(avm_in_tag).lookup_byte_lengths_counts =
diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp
index 40655884d7d..36baa3b76fd 100644
--- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp
+++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp
@@ -25,7 +25,7 @@ class AvmTraceBuilder {
   public:
     AvmTraceBuilder(std::array<FF, KERNEL_INPUTS_LENGTH> kernel_inputs = {});
 
-    std::vector<Row> finalize();
+    std::vector<Row> finalize(uint32_t min_trace_size = 0, bool range_check_required = false);
     void reset();
 
     uint32_t getPc() const { return pc; }
diff --git a/barretenberg/cpp/src/barretenberg/vm/tests/avm_inter_table.test.cpp b/barretenberg/cpp/src/barretenberg/vm/tests/avm_inter_table.test.cpp
index 8f8eb327cf9..a13e69df23d 100644
--- a/barretenberg/cpp/src/barretenberg/vm/tests/avm_inter_table.test.cpp
+++ b/barretenberg/cpp/src/barretenberg/vm/tests/avm_inter_table.test.cpp
@@ -142,13 +142,14 @@ class AvmRangeCheckNegativeTests : public AvmInterTableTests {
     size_t mem_idx;
     size_t alu_idx;
 
-    void genTraceAdd(uint128_t const& a, uint128_t const& b, uint128_t const& c, AvmMemoryTag tag)
+    void genTraceAdd(
+        uint128_t const& a, uint128_t const& b, uint128_t const& c, AvmMemoryTag tag, uint32_t min_trace_size = 0)
     {
         trace_builder.op_set(0, a, 0, tag);
         trace_builder.op_set(0, b, 1, tag);
         trace_builder.op_add(0, 0, 1, 2, tag); // 7 + 8 = 15
         trace_builder.return_op(0, 0, 0);
-        trace = trace_builder.finalize();
+        trace = trace_builder.finalize(min_trace_size);
 
         // Find the row with addition operation and retrieve clk.
         auto row =
@@ -254,7 +255,7 @@ TEST_F(AvmRangeCheckNegativeTests, additionU8Reg1)
 // Out-of-range value in register u16_r0
 TEST_F(AvmRangeCheckNegativeTests, additionU16Reg0)
 {
-    genTraceAdd(1200, 2000, 3200, AvmMemoryTag::U16);
+    genTraceAdd(1200, 2000, 3200, AvmMemoryTag::U16, 130);
     auto& row = trace.at(main_idx);
     auto& mem_row = trace.at(mem_idx);
     auto& alu_row = trace.at(alu_idx);
@@ -649,4 +650,4 @@ TEST_F(AvmPermMainMemNegativeTests, wrongClkIcInMem)
     EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "PERM_MAIN_MEM_C");
 }
 
-} // namespace tests_avm
\ No newline at end of file
+} // namespace tests_avm
diff --git a/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp b/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp
index 5f5cac0afeb..48871af5d5b 100644
--- a/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp
+++ b/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp
@@ -37,7 +37,6 @@ void validate_trace(std::vector<Row>&& trace, std::array