Skip to content

Commit

Permalink
feat(avm): cpp msm changes
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasRidhuan committed Jun 18, 2024
1 parent 6b3d04a commit 42a0383
Show file tree
Hide file tree
Showing 7 changed files with 366 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,13 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
std::get<uint32_t>(inst.operands.at(6)),
std::get<uint32_t>(inst.operands.at(7)));
break;
case OpCode::MSM:
trace_builder.op_variable_msm(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(1)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)),
std::get<uint32_t>(inst.operands.at(4)));
break;
case OpCode::REVERT:
trace_builder.op_revert(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(1)),
Expand Down
4 changes: 4 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ std::string to_string(OpCode opcode)
return "SHA256";
case OpCode::PEDERSEN:
return "PEDERSEN";
case OpCode::ECADD:
return "ECADD";
case OpCode::MSM:
return "MSM";
case OpCode::TORADIXLE:
return "TORADIXLE";
case OpCode::SHA256COMPRESSION:
Expand Down
251 changes: 251 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
#include <vector>

#include "barretenberg/common/throw_or_abort.hpp"
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/vm/avm_trace/avm_common.hpp"
#include "barretenberg/vm/avm_trace/avm_helper.hpp"
#include "barretenberg/vm/avm_trace/avm_opcode.hpp"
Expand Down Expand Up @@ -3632,6 +3634,255 @@ void AvmTraceBuilder::op_ec_add(uint8_t indirect,
FF(internal_return_ptr),
{ result.is_point_at_infinity() });
}

// This function is a bit overloaded with logic around reconstructing points and scalars that could probably be moved to
// the gadget at some stage (although this is another temporary gadget..)
void AvmTraceBuilder::op_variable_msm(uint8_t indirect,
uint32_t points_offset,
uint32_t scalars_offset,
uint32_t output_offset,
uint32_t point_length_offset)
{
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;
// This will all get refactored as part of the indirection refactor
bool tag_match = true;
uint32_t direct_points_offset = points_offset;
uint32_t direct_scalars_offset = scalars_offset;
uint32_t direct_output_offset = output_offset;
// Resolve the indirects
bool indirect_points_flag = is_operand_indirect(indirect, 0);
bool indirect_scalars_flag = is_operand_indirect(indirect, 1);
bool indirect_output_flag = is_operand_indirect(indirect, 2);

// Read in the points first
if (indirect_points_flag) {
auto read_ind_a =
mem_trace_builder.indirect_read_and_load_from_memory(call_ptr, clk, IndirectRegister::IND_A, points_offset);
direct_points_offset = uint32_t(read_ind_a.val);
tag_match = tag_match && read_ind_a.tag_match;
}

auto read_points = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, direct_points_offset, AvmMemoryTag::FF, AvmMemoryTag::U0);

// Read in the scalars
if (indirect_scalars_flag) {
auto read_ind_b = mem_trace_builder.indirect_read_and_load_from_memory(
call_ptr, clk, IndirectRegister::IND_B, scalars_offset);
direct_scalars_offset = uint32_t(read_ind_b.val);
tag_match = tag_match && read_ind_b.tag_match;
}
auto read_scalars = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IB, direct_scalars_offset, AvmMemoryTag::FF, AvmMemoryTag::U0);

// In the refactor we will have the read_slice function handle indirects as well
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = read_points.val,
.avm_main_ib = read_scalars.val,
.avm_main_ind_a = indirect_points_flag ? FF(points_offset) : FF(0),
.avm_main_ind_b = indirect_scalars_flag ? FF(scalars_offset) : FF(0),
.avm_main_ind_op_a = FF(static_cast<uint32_t>(indirect_points_flag)),
.avm_main_ind_op_b = FF(static_cast<uint32_t>(indirect_scalars_flag)),
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_points_offset),
.avm_main_mem_idx_b = FF(direct_scalars_offset),
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_b = FF(1),
.avm_main_pc = FF(pc++),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::FF)),
.avm_main_tag_err = FF(static_cast<uint32_t>(!tag_match)),
});
clk++;

// Read the points length (different row since it has a different memory tag)
auto points_length_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, point_length_offset, AvmMemoryTag::U32, AvmMemoryTag::U0);
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = points_length_read.val,
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(point_length_offset),
.avm_main_mem_op_a = FF(1),
.avm_main_pc = FF(pc),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U32)),
.avm_main_tag_err = FF(static_cast<uint32_t>(!points_length_read.tag_match)),
});
clk++;

// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] with the types [FF, FF, U8, FF, FF, U8, ...]
uint32_t num_points = uint32_t(points_length_read.val) / 3; // 3 elements per point
// We need to split up the reads due to the memory tags,
std::vector<FF> points_coords_vec;
std::vector<FF> points_inf_vec;
std::vector<FF> scalars_vec;
// Read the coordinates first, +2 since we read 2 points per row
for (uint32_t i = 0; i < num_points; i += 2) {
// We can read up to 4 coordinates per row (x1,y1,x2,y2)
// Each pair of coordinates are separated by 3 memory addressess
auto point_x1_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, direct_points_offset + i * 3, AvmMemoryTag::FF, AvmMemoryTag::U0);
auto point_y1_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IB, direct_points_offset + i * 3 + 1, AvmMemoryTag::FF, AvmMemoryTag::U0);
auto point_x2_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IC, direct_points_offset + (i + 1) * 3, AvmMemoryTag::FF, AvmMemoryTag::U0);
auto point_y2_read = mem_trace_builder.read_and_load_from_memory(call_ptr,
clk,
IntermRegister::ID,
direct_points_offset + (i + 1) * 3 + 1,
AvmMemoryTag::FF,
AvmMemoryTag::U0);
bool tag_match =
point_x1_read.tag_match && point_y1_read.tag_match && point_x2_read.tag_match && point_y2_read.tag_match;
points_coords_vec.insert(points_coords_vec.end(),
{ point_x1_read.val, point_y1_read.val, point_x2_read.val, point_y2_read.val });
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = point_x1_read.val,
.avm_main_ib = point_y1_read.val,
.avm_main_ic = point_x2_read.val,
.avm_main_id = point_y2_read.val,
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_points_offset + i * 3),
.avm_main_mem_idx_b = FF(direct_points_offset + i * 3 + 1),
.avm_main_mem_idx_c = FF(direct_points_offset + (i + 1) * 3),
.avm_main_mem_idx_d = FF(direct_points_offset + (i + 1) * 3 + 1),
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_b = FF(1),
.avm_main_mem_op_c = FF(1),
.avm_main_mem_op_d = FF(1),
.avm_main_pc = FF(pc),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::FF)),
.avm_main_tag_err = FF(static_cast<uint32_t>(!tag_match)),
});
clk++;
}
// Read the Infinities flags, +4 since we read 4 points row
for (uint32_t i = 0; i < num_points; i += 4) {
// We can read up to 4 infinities per row
// Each infinity flag is separated by 3 memory addressess
uint32_t offset = direct_points_offset + i * 3 + 2;
auto point_inf1_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, offset, AvmMemoryTag::U8, AvmMemoryTag::U0);
offset += 3;

auto point_inf2_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IB, offset, AvmMemoryTag::U8, AvmMemoryTag::U0);
offset += 3;

auto point_inf3_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IC, offset, AvmMemoryTag::U8, AvmMemoryTag::U0);
offset += 3;

auto point_inf4_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::ID, offset, AvmMemoryTag::U8, AvmMemoryTag::U0);

points_inf_vec.insert(points_inf_vec.end(),
{ point_inf1_read.val, point_inf2_read.val, point_inf3_read.val, point_inf4_read.val });
bool tag_match = point_inf1_read.tag_match && point_inf2_read.tag_match && point_inf3_read.tag_match &&
point_inf4_read.tag_match;
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = point_inf1_read.val,
.avm_main_ib = point_inf2_read.val,
.avm_main_ic = point_inf3_read.val,
.avm_main_id = point_inf4_read.val,
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_points_offset + i * 3 + 2),
.avm_main_mem_idx_b = FF(direct_points_offset + (i + 1) * 3 + 2),
.avm_main_mem_idx_c = FF(direct_points_offset + (i + 2) * 3 + 2),
.avm_main_mem_idx_d = FF(direct_points_offset + (i + 3) * 3 + 2),
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_b = FF(1),
.avm_main_mem_op_c = FF(1),
.avm_main_mem_op_d = FF(1),
.avm_main_pc = FF(pc),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U8)),
.avm_main_tag_err = FF(static_cast<uint32_t>(!tag_match)),
});
clk++;
}
// Scalar read length is num_points* 2 since scalars are stored as lo and hi limbs
uint32_t scalar_read_length = num_points * 2;
auto num_scalar_rows = read_slice_to_memory(call_ptr,
clk,
direct_scalars_offset,
AvmMemoryTag::FF,
AvmMemoryTag::U0,
FF(internal_return_ptr),
scalar_read_length,
scalars_vec);
clk += num_scalar_rows;
// Reconstruct Grumpkin points
std::vector<grumpkin::g1::affine_element> points;
for (size_t i = 0; i < num_points; i++) {
grumpkin::g1::Fq x = points_coords_vec[i * 2];
grumpkin::g1::Fq y = points_coords_vec[i * 2 + 1];
bool is_inf = points_inf_vec[i] == 1;
if (is_inf) {
points.emplace_back(grumpkin::g1::affine_element::infinity());
} else {
points.emplace_back(x, y);
}
}
// Reconstruct Grumpkin scalars
// Scalars are stored as [lo1, hi1, lo2, hi2, ...] with the types [FF, FF, FF, FF, ...]
std::vector<grumpkin::fr> scalars;
for (size_t i = 0; i < num_points; i++) {
FF lo = scalars_vec[i * 2];
FF hi = scalars_vec[i * 2 + 1];
// hi is shifted 128 bits
uint256_t scalar = (uint256_t(hi) << 128) + uint256_t(lo);
scalars.emplace_back(scalar);
}
// Perform the variable MSM - could just put the logic in here since there are no constraints.
auto result = ecc_trace_builder.variable_msm(points, scalars, clk);
// Write the result back to memory [x, y, inf] with tags [FF, FF, U8]
if (indirect_output_flag) {
auto read_ind_a =
mem_trace_builder.indirect_read_and_load_from_memory(call_ptr, clk, IndirectRegister::IND_A, output_offset);
direct_output_offset = uint32_t(read_ind_a.val);
}
mem_trace_builder.write_into_memory(
call_ptr, clk, IntermRegister::IA, direct_output_offset, result.x, AvmMemoryTag::U0, AvmMemoryTag::FF);
mem_trace_builder.write_into_memory(
call_ptr, clk, IntermRegister::IB, direct_output_offset + 1, result.y, AvmMemoryTag::U0, AvmMemoryTag::FF);
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = result.x,
.avm_main_ib = result.y,
.avm_main_ind_a = indirect_output_flag ? FF(output_offset) : FF(0),
.avm_main_ind_op_a = FF(static_cast<uint32_t>(indirect_output_flag)),
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_output_offset),
.avm_main_mem_idx_b = FF(direct_output_offset + 1),
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_b = FF(1),
.avm_main_pc = FF(pc),
.avm_main_rwa = FF(1),
.avm_main_rwb = FF(1),
.avm_main_w_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::FF)),
});
clk++;
// Write the infinity
mem_trace_builder.write_into_memory(call_ptr,
clk,
IntermRegister::IA,
direct_output_offset + 2,
result.is_point_at_infinity(),
AvmMemoryTag::U0,
AvmMemoryTag::U8);
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = static_cast<uint8_t>(result.is_point_at_infinity()),
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_output_offset + 2),
.avm_main_mem_op_a = FF(1),
.avm_main_pc = FF(pc),
.avm_main_rwa = FF(1),
.avm_main_w_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U8)),
});
}
// Finalise Lookup Counts
//
// For log derivative lookups, we require a column that contains the number of times each lookup is consumed
Expand Down
5 changes: 5 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ class AvmTraceBuilder {
uint32_t rhs_y_offset,
uint32_t rhs_is_inf_offset,
uint32_t output_offset);
void op_variable_msm(uint8_t indirect,
uint32_t points_offset,
uint32_t scalars_offset,
uint32_t output_offset,
uint32_t point_length_offset);

private:
// Used for the standard indirect address resolution of three operands opcode.
Expand Down
17 changes: 17 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/gadgets/avm_ecc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,21 @@ element AvmEccTraceBuilder::embedded_curve_add(element lhs, element rhs, uint32_
return result;
}

element AvmEccTraceBuilder::variable_msm(const std::vector<element>& points,
const std::vector<grumpkin::fr>& scalars,
uint32_t clk)
{
// Replace this with pippenger if/when we have the time
auto result = grumpkin::g1::affine_point_at_infinity;
for (size_t i = 0; i < points.size(); ++i) {
result = result + points[i] * scalars[i];
}

std::tuple<FF, FF, bool> result_tuple = { result.x, result.y, result.is_point_at_infinity() };

ecc_trace.push_back({ .clk = clk, .result = result_tuple });

return result;
}

} // namespace bb::avm_trace
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class AvmEccTraceBuilder {
public:
struct EccTraceEntry {
uint32_t clk = 0;
std::tuple<FF, FF, bool> p1; // x, y, is_infinity
std::tuple<FF, FF, bool> p2;
std::tuple<FF, FF, bool> result;
std::tuple<FF, FF, bool> p1 = { FF(0), FF(0), true }; // x, y, is_infinity
std::tuple<FF, FF, bool> p2 = { FF(0), FF(0), true };
std::tuple<FF, FF, bool> result = { FF(0), FF(0), true };
};

AvmEccTraceBuilder();
Expand All @@ -21,6 +21,9 @@ class AvmEccTraceBuilder {
grumpkin::g1::affine_element embedded_curve_add(grumpkin::g1::affine_element lhs,
grumpkin::g1::affine_element rhs,
uint32_t clk);
grumpkin::g1::affine_element variable_msm(const std::vector<grumpkin::g1::affine_element>& points,
const std::vector<grumpkin::fr>& scalars,
uint32_t clk);

private:
std::vector<EccTraceEntry> ecc_trace;
Expand Down
Loading

0 comments on commit 42a0383

Please sign in to comment.