Skip to content

Commit

Permalink
test(avm): AVM Minimial lookup table for testing (#6641)
Browse files Browse the repository at this point in the history
Please read [contributing guidelines](CONTRIBUTING.md) and remove this
line.
  • Loading branch information
IlyasRidhuan authored May 23, 2024
1 parent 7dfc369 commit 79beef1
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 75 deletions.
248 changes: 179 additions & 69 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,102 @@ void AvmTraceBuilder::finalise_mem_trace_lookup_counts()
}
}

// WARNING: FOR TESTING ONLY
// Generates the minimal lookup table for the binary trace
uint32_t finalize_bin_trace_lookup_for_testing(std::vector<Row>& main_trace, AvmBinaryTraceBuilder& bin_trace_builder)
{
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (auto const& [clk, count] : bin_trace_builder.byte_operation_counter) {
// from the clk we can derive the a and b inputs
auto b = static_cast<uint8_t>(clk);
auto a = static_cast<uint8_t>(clk >> 8);
auto op_id = static_cast<uint8_t>(clk >> 16);
uint8_t bit_op = 0;
if (op_id == 0) {
bit_op = a & b;
} else if (op_id == 1) {
bit_op = a | b;
} else {
bit_op = a ^ b;
}
if (clk > (main_trace.size() - 1)) {
main_trace.push_back(Row{
.avm_main_clk = FF(clk),
.avm_byte_lookup_bin_sel = FF(1),
.avm_byte_lookup_table_input_a = a,
.avm_byte_lookup_table_input_b = b,
.avm_byte_lookup_table_op_id = op_id,
.avm_byte_lookup_table_output = bit_op,
.lookup_byte_operations_counts = count,
});
} else {
main_trace.at(clk).lookup_byte_operations_counts = count;
main_trace.at(clk).avm_byte_lookup_bin_sel = FF(1);
main_trace.at(clk).avm_byte_lookup_table_op_id = op_id;
main_trace.at(clk).avm_byte_lookup_table_input_a = a;
main_trace.at(clk).avm_byte_lookup_table_input_b = b;
main_trace.at(clk).avm_byte_lookup_table_output = bit_op;
}
// Add the counter value stored throughout the execution
}
return static_cast<uint32_t>(main_trace.size());
}

// WARNING: FOR TESTING ONLY
// Generates the lookup table for the range checks without doing a full 2**16 rows
uint32_t finalize_rng_chks_for_testing(std::vector<Row>& main_trace,
AvmAluTraceBuilder& alu_trace_builder,
AvmMemTraceBuilder& mem_trace_builder,
const std::unordered_map<uint16_t, uint32_t>& mem_rng_check_lo_counts,
const std::unordered_map<uint16_t, uint32_t>& mem_rng_check_mid_counts,
std::unordered_map<uint8_t, uint32_t> mem_rng_check_hi_counts)
{
// Build the main_trace, and add any new rows with specific clks that line up with lookup reads

// Is there a "spread-like" operator in cpp or can I make it generric of the first param of the unordered map
std::vector<std::unordered_map<uint8_t, uint32_t>> u8_rng_chks = { alu_trace_builder.u8_range_chk_counters[0],
alu_trace_builder.u8_range_chk_counters[1],
alu_trace_builder.u8_pow_2_counters[0],
alu_trace_builder.u8_pow_2_counters[1],
std::move(mem_rng_check_hi_counts) };

auto custom_clk = std::set<uint32_t>{};
for (auto const& row : u8_rng_chks) {
for (auto const& [key, value] : row) {
custom_clk.insert(key);
}
}
for (auto const& row : alu_trace_builder.u16_range_chk_counters) {
for (auto const& [key, value] : row) {
custom_clk.insert(key);
}
}
for (auto const& row : alu_trace_builder.div_u64_range_chk_counters) {
for (auto const& [key, value] : row) {
custom_clk.insert(key);
}
}
for (auto const& [key, value] : mem_rng_check_lo_counts) {
custom_clk.insert(key);
}
for (auto const& [key, value] : mem_rng_check_mid_counts) {
custom_clk.insert(key);
}

for (auto const& [clk, count] : mem_trace_builder.m_tag_err_lookup_counts) {
custom_clk.insert(clk);
}

auto old_size = main_trace.size() - 1;
for (auto const& clk : custom_clk) {
if (clk > old_size) {
main_trace.push_back(Row{ .avm_main_clk = FF(clk) });
}
}
return static_cast<uint32_t>(main_trace.size());
}

/**
* @brief Finalisation of the memory trace and incorporating it to the main trace.
* In particular, sorting the memory trace, setting .m_lastAccess and
Expand All @@ -1895,10 +1991,8 @@ void AvmTraceBuilder::finalise_mem_trace_lookup_counts()
*
* @return The main trace
*/
std::vector<Row> AvmTraceBuilder::finalize()
std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_check_required)
{
// bool const range_check_required = alu_trace_builder.is_range_check_required();
bool const range_check_required = true;
auto mem_trace = mem_trace_builder.finalize();
auto alu_trace = alu_trace_builder.finalize();
auto conv_trace = conversion_trace_builder.finalize();
Expand All @@ -1910,7 +2004,9 @@ std::vector<Row> AvmTraceBuilder::finalize()
size_t bin_trace_size = bin_trace.size();

// Get tag_err counts from the mem_trace_builder
finalise_mem_trace_lookup_counts();
if (range_check_required) {
finalise_mem_trace_lookup_counts();
}

// Data structure to collect all lookup counts pertaining to 16-bit/32-bit range checks in memory trace
std::unordered_map<uint16_t, uint32_t> mem_rng_check_lo_counts;
Expand All @@ -1920,10 +2016,10 @@ std::vector<Row> AvmTraceBuilder::finalize()
// Main Trace needs to be at least as big as the biggest subtrace.
// If the bin_trace_size has entries, we need the main_trace to be as big as our byte lookup table (3 *
// 2**16 long)
size_t const lookup_table_size = bin_trace_size > 0 ? 3 * (1 << 16) : 0;
size_t const lookup_table_size = (bin_trace_size > 0 && range_check_required) ? 3 * (1 << 16) : 0;
size_t const range_check_size = range_check_required ? UINT16_MAX + 1 : 0;
std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size, lookup_table_size,
range_check_size, conv_trace_size, KERNEL_INPUTS_LENGTH };
std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size, lookup_table_size,
range_check_size, conv_trace_size, KERNEL_INPUTS_LENGTH, min_trace_size };
auto trace_size = std::max_element(trace_sizes.begin(), trace_sizes.end());

// We only need to pad with zeroes to the size to the largest trace here, pow_2 padding is handled in the
Expand Down Expand Up @@ -2185,7 +2281,16 @@ std::vector<Row> AvmTraceBuilder::finalize()
}
}

for (size_t i = 0; i < main_trace_size; i++) {
auto new_trace_size = range_check_required ? main_trace_size
: finalize_rng_chks_for_testing(main_trace,
alu_trace_builder,
mem_trace_builder,
mem_rng_check_lo_counts,
mem_rng_check_mid_counts,
mem_rng_check_hi_counts);
auto old_trace_size = main_trace_size - 1;

for (size_t i = 0; i < new_trace_size; i++) {
auto& r = main_trace.at(i);

if ((r.avm_main_sel_op_add == FF(1) || r.avm_main_sel_op_sub == FF(1) || r.avm_main_sel_op_mul == FF(1) ||
Expand All @@ -2202,49 +2307,49 @@ std::vector<Row> AvmTraceBuilder::finalize()
r.avm_main_space_id = r.avm_main_call_ptr;
};

if (i <= UINT8_MAX) {
r.lookup_u8_0_counts = alu_trace_builder.u8_range_chk_counters[0][static_cast<uint8_t>(i)];
r.lookup_u8_1_counts = alu_trace_builder.u8_range_chk_counters[1][static_cast<uint8_t>(i)];
r.lookup_pow_2_0_counts = alu_trace_builder.u8_pow_2_counters[0][static_cast<uint8_t>(i)];
r.lookup_pow_2_1_counts = alu_trace_builder.u8_pow_2_counters[1][static_cast<uint8_t>(i)];
r.lookup_mem_rng_chk_hi_counts = mem_rng_check_hi_counts[static_cast<uint8_t>(i)];
r.avm_main_clk = i > old_trace_size ? r.avm_main_clk : FF(i);
auto counter = i > old_trace_size ? static_cast<uint32_t>(r.avm_main_clk) : static_cast<uint32_t>(i);
r.incl_main_tag_err_counts = mem_trace_builder.m_tag_err_lookup_counts[static_cast<uint32_t>(counter)];
if (counter <= UINT8_MAX) {
r.lookup_u8_0_counts = alu_trace_builder.u8_range_chk_counters[0][static_cast<uint8_t>(counter)];
r.lookup_u8_1_counts = alu_trace_builder.u8_range_chk_counters[1][static_cast<uint8_t>(counter)];
r.lookup_pow_2_0_counts = alu_trace_builder.u8_pow_2_counters[0][static_cast<uint8_t>(counter)];
r.lookup_pow_2_1_counts = alu_trace_builder.u8_pow_2_counters[1][static_cast<uint8_t>(counter)];
r.lookup_mem_rng_chk_hi_counts = mem_rng_check_hi_counts[static_cast<uint8_t>(counter)];
r.avm_main_sel_rng_8 = FF(1);
r.avm_main_table_pow_2 = uint256_t(1) << uint256_t(i);
r.avm_main_table_pow_2 = uint256_t(1) << uint256_t(counter);
}

if (i <= UINT16_MAX) {
if (counter <= UINT16_MAX) {
// We add to the clk here in case our trace is smaller than our range checks
// There might be a cleaner way to do this in the future as this only applies
// when our trace (excluding range checks) is < 2**16
r.lookup_u16_0_counts = alu_trace_builder.u16_range_chk_counters[0][static_cast<uint16_t>(i)];
r.lookup_u16_1_counts = alu_trace_builder.u16_range_chk_counters[1][static_cast<uint16_t>(i)];
r.lookup_u16_2_counts = alu_trace_builder.u16_range_chk_counters[2][static_cast<uint16_t>(i)];
r.lookup_u16_3_counts = alu_trace_builder.u16_range_chk_counters[3][static_cast<uint16_t>(i)];
r.lookup_u16_4_counts = alu_trace_builder.u16_range_chk_counters[4][static_cast<uint16_t>(i)];
r.lookup_u16_5_counts = alu_trace_builder.u16_range_chk_counters[5][static_cast<uint16_t>(i)];
r.lookup_u16_6_counts = alu_trace_builder.u16_range_chk_counters[6][static_cast<uint16_t>(i)];
r.lookup_u16_7_counts = alu_trace_builder.u16_range_chk_counters[7][static_cast<uint16_t>(i)];
r.lookup_u16_8_counts = alu_trace_builder.u16_range_chk_counters[8][static_cast<uint16_t>(i)];
r.lookup_u16_9_counts = alu_trace_builder.u16_range_chk_counters[9][static_cast<uint16_t>(i)];
r.lookup_u16_10_counts = alu_trace_builder.u16_range_chk_counters[10][static_cast<uint16_t>(i)];
r.lookup_u16_11_counts = alu_trace_builder.u16_range_chk_counters[11][static_cast<uint16_t>(i)];
r.lookup_u16_12_counts = alu_trace_builder.u16_range_chk_counters[12][static_cast<uint16_t>(i)];
r.lookup_u16_13_counts = alu_trace_builder.u16_range_chk_counters[13][static_cast<uint16_t>(i)];
r.lookup_u16_14_counts = alu_trace_builder.u16_range_chk_counters[14][static_cast<uint16_t>(i)];

r.lookup_mem_rng_chk_mid_counts = mem_rng_check_mid_counts[static_cast<uint16_t>(i)];
r.lookup_mem_rng_chk_lo_counts = mem_rng_check_lo_counts[static_cast<uint16_t>(i)];

r.lookup_div_u16_0_counts = alu_trace_builder.div_u64_range_chk_counters[0][static_cast<uint16_t>(i)];
r.lookup_div_u16_1_counts = alu_trace_builder.div_u64_range_chk_counters[1][static_cast<uint16_t>(i)];
r.lookup_div_u16_2_counts = alu_trace_builder.div_u64_range_chk_counters[2][static_cast<uint16_t>(i)];
r.lookup_div_u16_3_counts = alu_trace_builder.div_u64_range_chk_counters[3][static_cast<uint16_t>(i)];
r.lookup_div_u16_4_counts = alu_trace_builder.div_u64_range_chk_counters[4][static_cast<uint16_t>(i)];
r.lookup_div_u16_5_counts = alu_trace_builder.div_u64_range_chk_counters[5][static_cast<uint16_t>(i)];
r.lookup_div_u16_6_counts = alu_trace_builder.div_u64_range_chk_counters[6][static_cast<uint16_t>(i)];
r.lookup_div_u16_7_counts = alu_trace_builder.div_u64_range_chk_counters[7][static_cast<uint16_t>(i)];

r.avm_main_clk = FF(static_cast<uint32_t>(i));
r.lookup_u16_0_counts = alu_trace_builder.u16_range_chk_counters[0][static_cast<uint16_t>(counter)];
r.lookup_u16_1_counts = alu_trace_builder.u16_range_chk_counters[1][static_cast<uint16_t>(counter)];
r.lookup_u16_2_counts = alu_trace_builder.u16_range_chk_counters[2][static_cast<uint16_t>(counter)];
r.lookup_u16_3_counts = alu_trace_builder.u16_range_chk_counters[3][static_cast<uint16_t>(counter)];
r.lookup_u16_4_counts = alu_trace_builder.u16_range_chk_counters[4][static_cast<uint16_t>(counter)];
r.lookup_u16_5_counts = alu_trace_builder.u16_range_chk_counters[5][static_cast<uint16_t>(counter)];
r.lookup_u16_6_counts = alu_trace_builder.u16_range_chk_counters[6][static_cast<uint16_t>(counter)];
r.lookup_u16_7_counts = alu_trace_builder.u16_range_chk_counters[7][static_cast<uint16_t>(counter)];
r.lookup_u16_8_counts = alu_trace_builder.u16_range_chk_counters[8][static_cast<uint16_t>(counter)];
r.lookup_u16_9_counts = alu_trace_builder.u16_range_chk_counters[9][static_cast<uint16_t>(counter)];
r.lookup_u16_10_counts = alu_trace_builder.u16_range_chk_counters[10][static_cast<uint16_t>(counter)];
r.lookup_u16_11_counts = alu_trace_builder.u16_range_chk_counters[11][static_cast<uint16_t>(counter)];
r.lookup_u16_12_counts = alu_trace_builder.u16_range_chk_counters[12][static_cast<uint16_t>(counter)];
r.lookup_u16_13_counts = alu_trace_builder.u16_range_chk_counters[13][static_cast<uint16_t>(counter)];
r.lookup_u16_14_counts = alu_trace_builder.u16_range_chk_counters[14][static_cast<uint16_t>(counter)];

r.lookup_mem_rng_chk_mid_counts = mem_rng_check_mid_counts[static_cast<uint16_t>(counter)];
r.lookup_mem_rng_chk_lo_counts = mem_rng_check_lo_counts[static_cast<uint16_t>(counter)];

r.lookup_div_u16_0_counts = alu_trace_builder.div_u64_range_chk_counters[0][static_cast<uint16_t>(counter)];
r.lookup_div_u16_1_counts = alu_trace_builder.div_u64_range_chk_counters[1][static_cast<uint16_t>(counter)];
r.lookup_div_u16_2_counts = alu_trace_builder.div_u64_range_chk_counters[2][static_cast<uint16_t>(counter)];
r.lookup_div_u16_3_counts = alu_trace_builder.div_u64_range_chk_counters[3][static_cast<uint16_t>(counter)];
r.lookup_div_u16_4_counts = alu_trace_builder.div_u64_range_chk_counters[4][static_cast<uint16_t>(counter)];
r.lookup_div_u16_5_counts = alu_trace_builder.div_u64_range_chk_counters[5][static_cast<uint16_t>(counter)];
r.lookup_div_u16_6_counts = alu_trace_builder.div_u64_range_chk_counters[6][static_cast<uint16_t>(counter)];
r.lookup_div_u16_7_counts = alu_trace_builder.div_u64_range_chk_counters[7][static_cast<uint16_t>(counter)];
r.avm_main_sel_rng_16 = FF(1);
}
}
Expand Down Expand Up @@ -2281,29 +2386,33 @@ std::vector<Row> AvmTraceBuilder::finalize()

// Only generate precomputed byte tables if we are actually going to use them in this main trace.
if (bin_trace_size > 0) {
// Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id.
for (size_t op_id = 0; op_id < 3; op_id++) {
for (size_t input_a = 0; input_a <= UINT8_MAX; input_a++) {
for (size_t input_b = 0; input_b <= UINT8_MAX; input_b++) {
auto a = static_cast<uint8_t>(input_a);
auto b = static_cast<uint8_t>(input_b);

// Derive a unique row index given op_id, a, and b.
auto main_trace_index = static_cast<uint32_t>((op_id << 16) + (input_a << 8) + b);

main_trace.at(main_trace_index).avm_byte_lookup_bin_sel = FF(1);
main_trace.at(main_trace_index).avm_byte_lookup_table_op_id = op_id;
main_trace.at(main_trace_index).avm_byte_lookup_table_input_a = a;
main_trace.at(main_trace_index).avm_byte_lookup_table_input_b = b;
// Add the counter value stored throughout the execution
main_trace.at(main_trace_index).lookup_byte_operations_counts =
bin_trace_builder.byte_operation_counter[main_trace_index];
if (op_id == 0) {
main_trace.at(main_trace_index).avm_byte_lookup_table_output = a & b;
} else if (op_id == 1) {
main_trace.at(main_trace_index).avm_byte_lookup_table_output = a | b;
} else {
main_trace.at(main_trace_index).avm_byte_lookup_table_output = a ^ b;
if (!range_check_required) {
finalize_bin_trace_lookup_for_testing(main_trace, bin_trace_builder);
} else {
// Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id.
for (size_t op_id = 0; op_id < 3; op_id++) {
for (size_t input_a = 0; input_a <= UINT8_MAX; input_a++) {
for (size_t input_b = 0; input_b <= UINT8_MAX; input_b++) {
auto a = static_cast<uint8_t>(input_a);
auto b = static_cast<uint8_t>(input_b);

// Derive a unique row index given op_id, a, and b.
auto main_trace_index = static_cast<uint32_t>((op_id << 16) + (input_a << 8) + b);

main_trace.at(main_trace_index).avm_byte_lookup_bin_sel = FF(1);
main_trace.at(main_trace_index).avm_byte_lookup_table_op_id = op_id;
main_trace.at(main_trace_index).avm_byte_lookup_table_input_a = a;
main_trace.at(main_trace_index).avm_byte_lookup_table_input_b = b;
// Add the counter value stored throughout the execution
main_trace.at(main_trace_index).lookup_byte_operations_counts =
bin_trace_builder.byte_operation_counter[main_trace_index];
if (op_id == 0) {
main_trace.at(main_trace_index).avm_byte_lookup_table_output = a & b;
} else if (op_id == 1) {
main_trace.at(main_trace_index).avm_byte_lookup_table_output = a | b;
} else {
main_trace.at(main_trace_index).avm_byte_lookup_table_output = a ^ b;
}
}
}
}
Expand All @@ -2313,6 +2422,7 @@ std::vector<Row> AvmTraceBuilder::finalize()
for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range
// [1,5]
main_trace.at(avm_in_tag).avm_byte_lookup_bin_sel = FF(1);
main_trace.at(avm_in_tag).avm_byte_lookup_table_in_tags = avm_in_tag + 1;
main_trace.at(avm_in_tag).avm_byte_lookup_table_byte_lengths = static_cast<uint8_t>(pow(2, avm_in_tag));
main_trace.at(avm_in_tag).lookup_byte_lengths_counts =
Expand Down
Loading

0 comments on commit 79beef1

Please sign in to comment.