From 42a0383f45f038dbb0d0e5581689382ed42fc89f Mon Sep 17 00:00:00 2001 From: IlyasRidhuan Date: Thu, 13 Jun 2024 17:08:12 +0000 Subject: [PATCH] feat(avm): cpp msm changes --- .../vm/avm_trace/avm_execution.cpp | 7 + .../barretenberg/vm/avm_trace/avm_opcode.cpp | 4 + .../barretenberg/vm/avm_trace/avm_trace.cpp | 251 ++++++++++++++++++ .../barretenberg/vm/avm_trace/avm_trace.hpp | 5 + .../vm/avm_trace/gadgets/avm_ecc.cpp | 17 ++ .../vm/avm_trace/gadgets/avm_ecc.hpp | 9 +- .../vm/tests/avm_execution.test.cpp | 76 ++++++ 7 files changed, 366 insertions(+), 3 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_execution.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_execution.cpp index a3447e24c6e..f3d2dac5f2d 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_execution.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_execution.cpp @@ -677,6 +677,13 @@ std::vector Execution::gen_trace(std::vector const& instructio std::get(inst.operands.at(6)), std::get(inst.operands.at(7))); break; + case OpCode::MSM: + trace_builder.op_variable_msm(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4))); + break; case OpCode::REVERT: trace_builder.op_revert(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.cpp index 8ffe817e23a..2439fd4e0a2 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.cpp @@ -144,6 +144,10 @@ std::string to_string(OpCode opcode) return "SHA256"; case OpCode::PEDERSEN: return "PEDERSEN"; + case OpCode::ECADD: + return "ECADD"; + case OpCode::MSM: + return "MSM"; case OpCode::TORADIXLE: return "TORADIXLE"; case OpCode::SHA256COMPRESSION: diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp index 5c1ac86c7c3..4ee754de507 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp @@ -14,7 +14,9 @@ #include #include "barretenberg/common/throw_or_abort.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" #include "barretenberg/numeric/uint256/uint256.hpp" +#include "barretenberg/polynomials/univariate.hpp" #include "barretenberg/vm/avm_trace/avm_common.hpp" #include "barretenberg/vm/avm_trace/avm_helper.hpp" #include "barretenberg/vm/avm_trace/avm_opcode.hpp" @@ -3632,6 +3634,255 @@ void AvmTraceBuilder::op_ec_add(uint8_t indirect, FF(internal_return_ptr), { result.is_point_at_infinity() }); } + +// This function is a bit overloaded with logic around reconstructing points and scalars that could probably be moved to +// the gadget at some stage (although this is another temporary gadget..) +void AvmTraceBuilder::op_variable_msm(uint8_t indirect, + uint32_t points_offset, + uint32_t scalars_offset, + uint32_t output_offset, + uint32_t point_length_offset) +{ + auto clk = static_cast(main_trace.size()) + 1; + // This will all get refactored as part of the indirection refactor + bool tag_match = true; + uint32_t direct_points_offset = points_offset; + uint32_t direct_scalars_offset = scalars_offset; + uint32_t direct_output_offset = output_offset; + // Resolve the indirects + bool indirect_points_flag = is_operand_indirect(indirect, 0); + bool indirect_scalars_flag = is_operand_indirect(indirect, 1); + bool indirect_output_flag = is_operand_indirect(indirect, 2); + + // Read in the points first + if (indirect_points_flag) { + auto read_ind_a = + mem_trace_builder.indirect_read_and_load_from_memory(call_ptr, clk, IndirectRegister::IND_A, points_offset); + direct_points_offset = uint32_t(read_ind_a.val); + tag_match = tag_match && read_ind_a.tag_match; + } + + auto read_points = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IA, direct_points_offset, AvmMemoryTag::FF, AvmMemoryTag::U0); + + // Read in the scalars + if (indirect_scalars_flag) { + auto read_ind_b = mem_trace_builder.indirect_read_and_load_from_memory( + call_ptr, clk, IndirectRegister::IND_B, scalars_offset); + direct_scalars_offset = uint32_t(read_ind_b.val); + tag_match = tag_match && read_ind_b.tag_match; + } + auto read_scalars = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IB, direct_scalars_offset, AvmMemoryTag::FF, AvmMemoryTag::U0); + + // In the refactor we will have the read_slice function handle indirects as well + main_trace.push_back(Row{ + .avm_main_clk = clk, + .avm_main_ia = read_points.val, + .avm_main_ib = read_scalars.val, + .avm_main_ind_a = indirect_points_flag ? FF(points_offset) : FF(0), + .avm_main_ind_b = indirect_scalars_flag ? FF(scalars_offset) : FF(0), + .avm_main_ind_op_a = FF(static_cast(indirect_points_flag)), + .avm_main_ind_op_b = FF(static_cast(indirect_scalars_flag)), + .avm_main_internal_return_ptr = FF(internal_return_ptr), + .avm_main_mem_idx_a = FF(direct_points_offset), + .avm_main_mem_idx_b = FF(direct_scalars_offset), + .avm_main_mem_op_a = FF(1), + .avm_main_mem_op_b = FF(1), + .avm_main_pc = FF(pc++), + .avm_main_r_in_tag = FF(static_cast(AvmMemoryTag::FF)), + .avm_main_tag_err = FF(static_cast(!tag_match)), + }); + clk++; + + // Read the points length (different row since it has a different memory tag) + auto points_length_read = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IA, point_length_offset, AvmMemoryTag::U32, AvmMemoryTag::U0); + main_trace.push_back(Row{ + .avm_main_clk = clk, + .avm_main_ia = points_length_read.val, + .avm_main_internal_return_ptr = FF(internal_return_ptr), + .avm_main_mem_idx_a = FF(point_length_offset), + .avm_main_mem_op_a = FF(1), + .avm_main_pc = FF(pc), + .avm_main_r_in_tag = FF(static_cast(AvmMemoryTag::U32)), + .avm_main_tag_err = FF(static_cast(!points_length_read.tag_match)), + }); + clk++; + + // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] with the types [FF, FF, U8, FF, FF, U8, ...] + uint32_t num_points = uint32_t(points_length_read.val) / 3; // 3 elements per point + // We need to split up the reads due to the memory tags, + std::vector points_coords_vec; + std::vector points_inf_vec; + std::vector scalars_vec; + // Read the coordinates first, +2 since we read 2 points per row + for (uint32_t i = 0; i < num_points; i += 2) { + // We can read up to 4 coordinates per row (x1,y1,x2,y2) + // Each pair of coordinates are separated by 3 memory addressess + auto point_x1_read = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IA, direct_points_offset + i * 3, AvmMemoryTag::FF, AvmMemoryTag::U0); + auto point_y1_read = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IB, direct_points_offset + i * 3 + 1, AvmMemoryTag::FF, AvmMemoryTag::U0); + auto point_x2_read = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IC, direct_points_offset + (i + 1) * 3, AvmMemoryTag::FF, AvmMemoryTag::U0); + auto point_y2_read = mem_trace_builder.read_and_load_from_memory(call_ptr, + clk, + IntermRegister::ID, + direct_points_offset + (i + 1) * 3 + 1, + AvmMemoryTag::FF, + AvmMemoryTag::U0); + bool tag_match = + point_x1_read.tag_match && point_y1_read.tag_match && point_x2_read.tag_match && point_y2_read.tag_match; + points_coords_vec.insert(points_coords_vec.end(), + { point_x1_read.val, point_y1_read.val, point_x2_read.val, point_y2_read.val }); + main_trace.push_back(Row{ + .avm_main_clk = clk, + .avm_main_ia = point_x1_read.val, + .avm_main_ib = point_y1_read.val, + .avm_main_ic = point_x2_read.val, + .avm_main_id = point_y2_read.val, + .avm_main_internal_return_ptr = FF(internal_return_ptr), + .avm_main_mem_idx_a = FF(direct_points_offset + i * 3), + .avm_main_mem_idx_b = FF(direct_points_offset + i * 3 + 1), + .avm_main_mem_idx_c = FF(direct_points_offset + (i + 1) * 3), + .avm_main_mem_idx_d = FF(direct_points_offset + (i + 1) * 3 + 1), + .avm_main_mem_op_a = FF(1), + .avm_main_mem_op_b = FF(1), + .avm_main_mem_op_c = FF(1), + .avm_main_mem_op_d = FF(1), + .avm_main_pc = FF(pc), + .avm_main_r_in_tag = FF(static_cast(AvmMemoryTag::FF)), + .avm_main_tag_err = FF(static_cast(!tag_match)), + }); + clk++; + } + // Read the Infinities flags, +4 since we read 4 points row + for (uint32_t i = 0; i < num_points; i += 4) { + // We can read up to 4 infinities per row + // Each infinity flag is separated by 3 memory addressess + uint32_t offset = direct_points_offset + i * 3 + 2; + auto point_inf1_read = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IA, offset, AvmMemoryTag::U8, AvmMemoryTag::U0); + offset += 3; + + auto point_inf2_read = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IB, offset, AvmMemoryTag::U8, AvmMemoryTag::U0); + offset += 3; + + auto point_inf3_read = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::IC, offset, AvmMemoryTag::U8, AvmMemoryTag::U0); + offset += 3; + + auto point_inf4_read = mem_trace_builder.read_and_load_from_memory( + call_ptr, clk, IntermRegister::ID, offset, AvmMemoryTag::U8, AvmMemoryTag::U0); + + points_inf_vec.insert(points_inf_vec.end(), + { point_inf1_read.val, point_inf2_read.val, point_inf3_read.val, point_inf4_read.val }); + bool tag_match = point_inf1_read.tag_match && point_inf2_read.tag_match && point_inf3_read.tag_match && + point_inf4_read.tag_match; + main_trace.push_back(Row{ + .avm_main_clk = clk, + .avm_main_ia = point_inf1_read.val, + .avm_main_ib = point_inf2_read.val, + .avm_main_ic = point_inf3_read.val, + .avm_main_id = point_inf4_read.val, + .avm_main_internal_return_ptr = FF(internal_return_ptr), + .avm_main_mem_idx_a = FF(direct_points_offset + i * 3 + 2), + .avm_main_mem_idx_b = FF(direct_points_offset + (i + 1) * 3 + 2), + .avm_main_mem_idx_c = FF(direct_points_offset + (i + 2) * 3 + 2), + .avm_main_mem_idx_d = FF(direct_points_offset + (i + 3) * 3 + 2), + .avm_main_mem_op_a = FF(1), + .avm_main_mem_op_b = FF(1), + .avm_main_mem_op_c = FF(1), + .avm_main_mem_op_d = FF(1), + .avm_main_pc = FF(pc), + .avm_main_r_in_tag = FF(static_cast(AvmMemoryTag::U8)), + .avm_main_tag_err = FF(static_cast(!tag_match)), + }); + clk++; + } + // Scalar read length is num_points* 2 since scalars are stored as lo and hi limbs + uint32_t scalar_read_length = num_points * 2; + auto num_scalar_rows = read_slice_to_memory(call_ptr, + clk, + direct_scalars_offset, + AvmMemoryTag::FF, + AvmMemoryTag::U0, + FF(internal_return_ptr), + scalar_read_length, + scalars_vec); + clk += num_scalar_rows; + // Reconstruct Grumpkin points + std::vector points; + for (size_t i = 0; i < num_points; i++) { + grumpkin::g1::Fq x = points_coords_vec[i * 2]; + grumpkin::g1::Fq y = points_coords_vec[i * 2 + 1]; + bool is_inf = points_inf_vec[i] == 1; + if (is_inf) { + points.emplace_back(grumpkin::g1::affine_element::infinity()); + } else { + points.emplace_back(x, y); + } + } + // Reconstruct Grumpkin scalars + // Scalars are stored as [lo1, hi1, lo2, hi2, ...] with the types [FF, FF, FF, FF, ...] + std::vector scalars; + for (size_t i = 0; i < num_points; i++) { + FF lo = scalars_vec[i * 2]; + FF hi = scalars_vec[i * 2 + 1]; + // hi is shifted 128 bits + uint256_t scalar = (uint256_t(hi) << 128) + uint256_t(lo); + scalars.emplace_back(scalar); + } + // Perform the variable MSM - could just put the logic in here since there are no constraints. + auto result = ecc_trace_builder.variable_msm(points, scalars, clk); + // Write the result back to memory [x, y, inf] with tags [FF, FF, U8] + if (indirect_output_flag) { + auto read_ind_a = + mem_trace_builder.indirect_read_and_load_from_memory(call_ptr, clk, IndirectRegister::IND_A, output_offset); + direct_output_offset = uint32_t(read_ind_a.val); + } + mem_trace_builder.write_into_memory( + call_ptr, clk, IntermRegister::IA, direct_output_offset, result.x, AvmMemoryTag::U0, AvmMemoryTag::FF); + mem_trace_builder.write_into_memory( + call_ptr, clk, IntermRegister::IB, direct_output_offset + 1, result.y, AvmMemoryTag::U0, AvmMemoryTag::FF); + main_trace.push_back(Row{ + .avm_main_clk = clk, + .avm_main_ia = result.x, + .avm_main_ib = result.y, + .avm_main_ind_a = indirect_output_flag ? FF(output_offset) : FF(0), + .avm_main_ind_op_a = FF(static_cast(indirect_output_flag)), + .avm_main_internal_return_ptr = FF(internal_return_ptr), + .avm_main_mem_idx_a = FF(direct_output_offset), + .avm_main_mem_idx_b = FF(direct_output_offset + 1), + .avm_main_mem_op_a = FF(1), + .avm_main_mem_op_b = FF(1), + .avm_main_pc = FF(pc), + .avm_main_rwa = FF(1), + .avm_main_rwb = FF(1), + .avm_main_w_in_tag = FF(static_cast(AvmMemoryTag::FF)), + }); + clk++; + // Write the infinity + mem_trace_builder.write_into_memory(call_ptr, + clk, + IntermRegister::IA, + direct_output_offset + 2, + result.is_point_at_infinity(), + AvmMemoryTag::U0, + AvmMemoryTag::U8); + main_trace.push_back(Row{ + .avm_main_clk = clk, + .avm_main_ia = static_cast(result.is_point_at_infinity()), + .avm_main_internal_return_ptr = FF(internal_return_ptr), + .avm_main_mem_idx_a = FF(direct_output_offset + 2), + .avm_main_mem_op_a = FF(1), + .avm_main_pc = FF(pc), + .avm_main_rwa = FF(1), + .avm_main_w_in_tag = FF(static_cast(AvmMemoryTag::U8)), + }); +} // Finalise Lookup Counts // // For log derivative lookups, we require a column that contains the number of times each lookup is consumed diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp index aa2ff3c9900..9c9b157b48f 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp @@ -204,6 +204,11 @@ class AvmTraceBuilder { uint32_t rhs_y_offset, uint32_t rhs_is_inf_offset, uint32_t output_offset); + void op_variable_msm(uint8_t indirect, + uint32_t points_offset, + uint32_t scalars_offset, + uint32_t output_offset, + uint32_t point_length_offset); private: // Used for the standard indirect address resolution of three operands opcode. diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.cpp index fd3fc8955d1..494da13d192 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.cpp @@ -30,4 +30,21 @@ element AvmEccTraceBuilder::embedded_curve_add(element lhs, element rhs, uint32_ return result; } +element AvmEccTraceBuilder::variable_msm(const std::vector& points, + const std::vector& scalars, + uint32_t clk) +{ + // Replace this with pippenger if/when we have the time + auto result = grumpkin::g1::affine_point_at_infinity; + for (size_t i = 0; i < points.size(); ++i) { + result = result + points[i] * scalars[i]; + } + + std::tuple result_tuple = { result.x, result.y, result.is_point_at_infinity() }; + + ecc_trace.push_back({ .clk = clk, .result = result_tuple }); + + return result; +} + } // namespace bb::avm_trace diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.hpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.hpp index 6450a33db39..71f070c85a3 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.hpp @@ -9,9 +9,9 @@ class AvmEccTraceBuilder { public: struct EccTraceEntry { uint32_t clk = 0; - std::tuple p1; // x, y, is_infinity - std::tuple p2; - std::tuple result; + std::tuple p1 = { FF(0), FF(0), true }; // x, y, is_infinity + std::tuple p2 = { FF(0), FF(0), true }; + std::tuple result = { FF(0), FF(0), true }; }; AvmEccTraceBuilder(); @@ -21,6 +21,9 @@ class AvmEccTraceBuilder { grumpkin::g1::affine_element embedded_curve_add(grumpkin::g1::affine_element lhs, grumpkin::g1::affine_element rhs, uint32_t clk); + grumpkin::g1::affine_element variable_msm(const std::vector& points, + const std::vector& scalars, + uint32_t clk); private: std::vector ecc_trace; diff --git a/barretenberg/cpp/src/barretenberg/vm/tests/avm_execution.test.cpp b/barretenberg/cpp/src/barretenberg/vm/tests/avm_execution.test.cpp index bfa5339d362..9f3d453452b 100644 --- a/barretenberg/cpp/src/barretenberg/vm/tests/avm_execution.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/tests/avm_execution.test.cpp @@ -1400,6 +1400,82 @@ TEST_F(AvmExecutionTests, embeddedCurveAddOpCode) validate_trace(std::move(trace), public_inputs); } +// Positive test with MSM +TEST_F(AvmExecutionTests, msmOpCode) +{ + grumpkin::g1::affine_element a = grumpkin::g1::affine_element::random_element(); + FF a_is_inf = a.is_point_at_infinity(); + grumpkin::g1::affine_element b = grumpkin::g1::affine_element::random_element(); + FF b_is_inf = b.is_point_at_infinity(); + + grumpkin::g1::Fr scalar_a = grumpkin::g1::Fr::random_element(); + FF scalar_a_lo = uint256_t::from_uint128(uint128_t(scalar_a)); + FF scalar_a_hi = uint256_t(scalar_a) >> 128; + grumpkin::g1::Fr scalar_b = grumpkin::g1::Fr::random_element(); + FF scalar_b_lo = uint256_t::from_uint128(uint128_t(scalar_b)); + FF scalar_b_hi = uint256_t(scalar_b) >> 128; + auto expected_result = a * scalar_a + b * scalar_b; + std::vector expected_output = { expected_result.x, expected_result.y, expected_result.is_point_at_infinity() }; + // Send all the input as Fields and cast them to U8 later + std::vector calldata = { FF(a.x), FF(a.y), a_is_inf, FF(b.x), FF(b.y), + b_is_inf, scalar_a_lo, scalar_a_hi, scalar_b_lo, scalar_b_hi }; + std::string bytecode_hex = to_hex(OpCode::CALLDATACOPY) + // Calldatacopy...should fix the limit on calldatacopy + "00" // Indirect flag + "00000000" // cd_offset 0 + "0000000a" // copy_size (10 elements) + "00000000" // dst_offset 0 + + to_hex(OpCode::CAST) + // opcode CAST inf to U8 + "00" // Indirect flag + "01" // U8 tag field + "00000002" // a_is_inf + "00000002" // + + to_hex(OpCode::CAST) + // opcode CAST inf to U8 + "00" // Indirect flag + "01" // U8 tag field + "00000005" // b_is_inf + "00000005" // + + to_hex(OpCode::SET) + // opcode SET for length + "00" // Indirect flag + "03" // U32 + "00000006" // Length of point elements (6) + "0000000b" // dst offset (11) + + to_hex(OpCode::SET) + // SET Indirects + "00" // Indirect flag + "03" // U32 + "00000000" // points offset + "0000000d" // dst offset + + + to_hex(OpCode::SET) + // SET Indirects + "00" // Indirect flag + "03" // U32 + "00000006" // scalars offset + "0000000e" + // dst offset + to_hex(OpCode::SET) + // SET Indirects + "00" // Indirect flag + "03" // U32 + "0000000c" // output offset + "0000000f" + // dst offset + to_hex(OpCode::MSM) + // opcode MSM + "07" // Indirect flag (first 3 indirect) + "0000000d" // points offset + "0000000e" // scalars offset + "0000000f" // output offset + "0000000b" // length offset + + to_hex(OpCode::RETURN) + // opcode RETURN + "00" // Indirect flag + "0000000c" // ret offset 12 (this overwrites) + "00000003"; // ret size 3 + + auto bytecode = hex_to_bytes(bytecode_hex); + auto instructions = Deserialization::parse(bytecode); + + // Assign a vector that we will mutate internally in gen_trace to store the return values; + std::vector returndata; + auto trace = Execution::gen_trace(instructions, returndata, calldata, public_inputs_vec); + + EXPECT_EQ(returndata, expected_output); + + validate_trace(std::move(trace)); +} // Positive test for Kernel Input opcodes TEST_F(AvmExecutionTests, kernelInputOpcodes) {