Skip to content

Commit

Permalink
feat(avm): add temporary sha256 execution (#6604)
Browse files Browse the repository at this point in the history
Please read [contributing guidelines](CONTRIBUTING.md) and remove this
line.
  • Loading branch information
IlyasRidhuan authored May 30, 2024
1 parent 3a215ed commit 34088b4
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 41 deletions.
9 changes: 5 additions & 4 deletions barretenberg/cpp/pil/avm/avm_main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -560,10 +560,11 @@ namespace avm_main(256);
is
avm_conversion.to_radix_le_sel {avm_conversion.clk, avm_conversion.input, avm_conversion.radix, avm_conversion.num_limbs};

#[PERM_MAIN_SHA256]
sel_op_sha256 {clk, ia, ib, ic}
is
avm_sha256.sha256_compression_sel {avm_sha256.clk, avm_sha256.state, avm_sha256.input, avm_sha256.output};
// This will be enabled when we migrate just to sha256Compression, as getting sha256 to work with it is tricky.
// #[PERM_MAIN_SHA256]
// sel_op_sha256 {clk, ia, ib, ic}
// is
// avm_sha256.sha256_compression_sel {avm_sha256.clk, avm_sha256.state, avm_sha256.input, avm_sha256.output};

#[PERM_MAIN_POS2_PERM]
sel_op_poseidon2 {clk, ia, ib}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@
[[maybe_unused]] auto perm_main_alu = View(new_term.perm_main_alu); \
[[maybe_unused]] auto perm_main_bin = View(new_term.perm_main_bin); \
[[maybe_unused]] auto perm_main_conv = View(new_term.perm_main_conv); \
[[maybe_unused]] auto perm_main_sha256 = View(new_term.perm_main_sha256); \
[[maybe_unused]] auto perm_main_pos2_perm = View(new_term.perm_main_pos2_perm); \
[[maybe_unused]] auto perm_main_mem_a = View(new_term.perm_main_mem_a); \
[[maybe_unused]] auto perm_main_mem_b = View(new_term.perm_main_mem_b); \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ const std::unordered_map<OpCode, std::vector<OperandType>> OPCODE_WIRE_FORMAT =
{ OpCode::TORADIXLE,
{ OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } },
// Gadget - Hashing
{ OpCode::SHA256, { OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } },
{ OpCode::SHA256COMPRESSION,
{ OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } },
{ OpCode::POSEIDON2, { OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32 } },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,12 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
std::get<uint32_t>(inst.operands.at(1)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)));

break;
case OpCode::SHA256:
trace_builder.op_sha256(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(1)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)));
break;
case OpCode::POSEIDON2:
trace_builder.op_poseidon2_permutation(std::get<uint8_t>(inst.operands.at(0)),
Expand Down
142 changes: 142 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2203,6 +2203,148 @@ void AvmTraceBuilder::op_sha256_compression(uint8_t indirect,
call_ptr, clk, res.direct_c_offset, AvmMemoryTag::U32, AvmMemoryTag::U32, FF(internal_return_ptr), ff_result);
}

/**
* @brief SHA256 Hash with direct or indirect memory access.
* This function is temporary until we have transitioned to sha256Compression
* @param indirect byte encoding information about indirect/direct memory access.
* @param output_offset An index in memory pointing to where the first U32 value of the output array should be stored.
* @param input_offset An index in memory pointing to the first U8 value of the state array to be used in the next
* instance of sha256.
* @param input_size_offset An index in memory pointing to the U32 value of the input size.
*/
void AvmTraceBuilder::op_sha256(uint8_t indirect,
uint32_t output_offset,
uint32_t input_offset,
uint32_t input_size_offset)
{
auto clk = static_cast<uint32_t>(main_trace.size());
bool tag_match = true;
uint32_t direct_src_offset = input_offset;
uint32_t direct_dst_offset = output_offset;

bool indirect_src_flag = is_operand_indirect(indirect, 1);
bool indirect_dst_flag = is_operand_indirect(indirect, 0);

if (indirect_src_flag) {
auto read_ind_src =
mem_trace_builder.indirect_read_and_load_from_memory(call_ptr, clk, IndirectRegister::IND_A, input_offset);
direct_src_offset = uint32_t(read_ind_src.val);
tag_match = tag_match && read_ind_src.tag_match;
}

if (indirect_dst_flag) {
auto read_ind_dst =
mem_trace_builder.indirect_read_and_load_from_memory(call_ptr, clk, IndirectRegister::IND_C, output_offset);
direct_dst_offset = uint32_t(read_ind_dst.val);
tag_match = tag_match && read_ind_dst.tag_match;
}
// Note we load the input and output onto one line in the main trace and the length on the next line
// We do this so we can load two different AvmMemoryTags (u8 for the I/O and u32 for the length)
auto input_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, direct_src_offset, AvmMemoryTag::U8, AvmMemoryTag::U8);
auto output_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IC, direct_dst_offset, AvmMemoryTag::U8, AvmMemoryTag::U8);

// Store the clock time that we will use to line up the gadget later
auto sha256_op_clk = clk;
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = input_read.val, // First element of input
.avm_main_ic = output_read.val, // First element of output
.avm_main_ind_a = indirect_src_flag ? FF(input_offset) : FF(0),
.avm_main_ind_c = indirect_dst_flag ? FF(output_offset) : FF(0),
.avm_main_ind_op_a = FF(static_cast<uint32_t>(indirect_src_flag)),
.avm_main_ind_op_c = FF(static_cast<uint32_t>(indirect_dst_flag)),
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_src_offset), // input
.avm_main_mem_idx_c = FF(direct_dst_offset), // output
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_c = FF(1),
.avm_main_pc = FF(pc++),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U8)),
.avm_main_sel_op_sha256 = FF(1),
.avm_main_w_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U8)),
});
clk++;
auto input_length_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IB, input_size_offset, AvmMemoryTag::U32, AvmMemoryTag::U32);
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ib = input_length_read.val, // Message Length
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_b = FF(input_size_offset), // length
.avm_main_mem_op_b = FF(1),
.avm_main_pc = FF(pc),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U32)),
.avm_main_w_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U32)),
});
clk++;

std::vector<uint8_t> input;
input.reserve(uint32_t(input_length_read.val));

// We unroll this loop because the function typically expects arrays and for this temporary sha256 function we have
// a dynamic amount of input so we will use a vector.
auto register_order = std::array{ IntermRegister::IA, IntermRegister::IB, IntermRegister::IC, IntermRegister::ID };
// If the slice size isnt a multiple of 4, we still need an extra row to write the remainder
uint32_t const num_main_rows = static_cast<uint32_t>(input_length_read.val) / 4 +
static_cast<uint32_t>(uint32_t(input_length_read.val) % 4 != 0);
for (uint32_t i = 0; i < num_main_rows; i++) {
Row main_row{
.avm_main_clk = clk + i,
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_pc = FF(pc),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U8)),
.avm_main_w_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U8)),
};
// Write 4 values to memory in each_row
for (uint32_t j = 0; j < 4; j++) {
auto offset = i * 4 + j;
// If we exceed the slice size, we break
if (offset >= uint32_t(input_length_read.val)) {
break;
}
auto mem_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk + i, register_order[j], direct_src_offset + offset, AvmMemoryTag::U8, AvmMemoryTag::U8);
input.emplace_back(uint8_t(mem_read.val));
// This looks a bit gross, but it is fine for now.
if (j == 0) {
main_row.avm_main_ia = input.at(offset);
main_row.avm_main_mem_idx_a = FF(direct_src_offset + offset);
main_row.avm_main_mem_op_a = FF(1);
main_row.avm_main_tag_err = FF(static_cast<uint32_t>(!mem_read.tag_match));
} else if (j == 1) {
main_row.avm_main_ib = input.at(offset);
main_row.avm_main_mem_idx_b = FF(direct_src_offset + offset);
main_row.avm_main_mem_op_b = FF(1);
main_row.avm_main_tag_err = FF(static_cast<uint32_t>(!mem_read.tag_match));
} else if (j == 2) {
main_row.avm_main_ic = input.at(offset);
main_row.avm_main_mem_idx_c = FF(direct_src_offset + offset);
main_row.avm_main_mem_op_c = FF(1);
main_row.avm_main_tag_err = FF(static_cast<uint32_t>(!mem_read.tag_match));
} else {
main_row.avm_main_id = input.at(offset);
main_row.avm_main_mem_idx_d = FF(direct_src_offset + offset);
main_row.avm_main_mem_op_d = FF(1);
main_row.avm_main_tag_err = FF(static_cast<uint32_t>(!mem_read.tag_match));
}
}
main_trace.emplace_back(main_row);
}

clk += num_main_rows;

std::array<uint8_t, 32> result = sha256_trace_builder.sha256(input, sha256_op_clk);
// We convert the results to field elements here
std::vector<FF> ff_result;
for (uint32_t i = 0; i < 32; i++) {
ff_result.emplace_back(result[i]);
}
// Write the result to memory after
write_slice_to_memory(
call_ptr, clk, direct_dst_offset, AvmMemoryTag::U8, AvmMemoryTag::U8, FF(internal_return_ptr), ff_result);
}
/**
* @brief Poseidon2 Permutation with direct or indirect memory access.
*
Expand Down
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ class AvmTraceBuilder {
void op_keccakf1600(uint8_t indirect, uint32_t output_offset, uint32_t input_offset, uint32_t input_size_offset);
// Keccak operation - temporary while we transition to keccakf1600
void op_keccak(uint8_t indirect, uint32_t output_offset, uint32_t input_offset, uint32_t input_size_offset);
// SHA256 operation - temporary while we transition to sha256_compression
void op_sha256(uint8_t indirect, uint32_t output_offset, uint32_t input_offset, uint32_t input_size_offset);

private:
// Used for the standard indirect address resolution of three operands opcode.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,11 @@ std::array<uint32_t, 8> AvmSha256TraceBuilder::sha256_compression(const std::arr
return output;
}

std::array<uint8_t, 32> AvmSha256TraceBuilder::sha256(const std::vector<uint8_t>& input, uint32_t clk)
{
auto output = crypto::sha256(input);
// Cant push here since we are not using the same format as the sha256_compression
sha256_trace.push_back(Sha256TraceEntry{ clk, {}, {}, {} });
return output;
}
} // namespace bb::avm_trace
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ class AvmSha256TraceBuilder {
public:
struct Sha256TraceEntry {
uint32_t clk = 0;
std::array<uint32_t, 8> state;
std::array<uint32_t, 16> input;
std::array<uint32_t, 8> output;
std::array<uint32_t, 8> state{};
std::array<uint32_t, 16> input{};
std::array<uint32_t, 8> output{};
};

AvmSha256TraceBuilder();
Expand All @@ -21,6 +21,7 @@ class AvmSha256TraceBuilder {
std::array<uint32_t, 8> sha256_compression(const std::array<uint32_t, 8>& h_init,
const std::array<uint32_t, 16>& input,
uint32_t clk);
std::array<uint8_t, 32> sha256(const std::vector<uint8_t>& input, uint32_t clk);

private:
std::vector<Sha256TraceEntry> sha256_trace;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
#include "barretenberg/relations/generated/avm/perm_main_mem_ind_c.hpp"
#include "barretenberg/relations/generated/avm/perm_main_mem_ind_d.hpp"
#include "barretenberg/relations/generated/avm/perm_main_pos2_perm.hpp"
#include "barretenberg/relations/generated/avm/perm_main_sha256.hpp"
#include "barretenberg/vm/generated/avm_flavor.hpp"

namespace bb {
Expand Down Expand Up @@ -344,7 +343,6 @@ template <typename FF> struct AvmFullRow {
FF perm_main_alu{};
FF perm_main_bin{};
FF perm_main_conv{};
FF perm_main_sha256{};
FF perm_main_pos2_perm{};
FF perm_main_mem_a{};
FF perm_main_mem_b{};
Expand Down Expand Up @@ -500,8 +498,8 @@ class AvmCircuitBuilder {
using Polynomial = Flavor::Polynomial;
using ProverPolynomials = Flavor::ProverPolynomials;

static constexpr size_t num_fixed_columns = 409;
static constexpr size_t num_polys = 347;
static constexpr size_t num_fixed_columns = 408;
static constexpr size_t num_polys = 346;
std::vector<Row> rows;

void set_trace(std::vector<Row>&& trace) { rows = std::move(trace); }
Expand Down Expand Up @@ -1018,10 +1016,6 @@ class AvmCircuitBuilder {
return evaluate_logderivative.template operator()<perm_main_conv_relation<FF>>("PERM_MAIN_CONV");
};

auto perm_main_sha256 = [=]() {
return evaluate_logderivative.template operator()<perm_main_sha256_relation<FF>>("PERM_MAIN_SHA256");
};

auto perm_main_pos2_perm = [=]() {
return evaluate_logderivative.template operator()<perm_main_pos2_perm_relation<FF>>("PERM_MAIN_POS2_PERM");
};
Expand Down Expand Up @@ -1236,8 +1230,6 @@ class AvmCircuitBuilder {

relation_futures.emplace_back(std::async(std::launch::async, perm_main_conv));

relation_futures.emplace_back(std::async(std::launch::async, perm_main_sha256));

relation_futures.emplace_back(std::async(std::launch::async, perm_main_pos2_perm));

relation_futures.emplace_back(std::async(std::launch::async, perm_main_mem_a));
Expand Down Expand Up @@ -1361,8 +1353,6 @@ class AvmCircuitBuilder {

perm_main_conv();

perm_main_sha256();

perm_main_pos2_perm();

perm_main_mem_a();
Expand Down
Loading

0 comments on commit 34088b4

Please sign in to comment.