Skip to content

Commit

Permalink
feat!: add opcode for poseidon2 permutation (AztecProtocol#4214)
Browse files Browse the repository at this point in the history
Related to issue: noir-lang/noir#4037

The PR adds the opcode to ACIR and updates BB and Noir accordingly.
Furthermore you can use it via a foreign function in the stdlib. This
will generate the proper ACIR opcode but the solver will not be able to
solve it and BB will skip it.
  • Loading branch information
guipublic authored Jan 25, 2024
1 parent 39d697f commit 1f91b23
Show file tree
Hide file tree
Showing 14 changed files with 345 additions and 9 deletions.
142 changes: 140 additions & 2 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,16 @@ struct BlackBoxFuncCall {
static BigIntToLeBytes bincodeDeserialize(std::vector<uint8_t>);
};

struct Poseidon2Permutation {
std::vector<Circuit::FunctionInput> inputs;
std::vector<Circuit::Witness> outputs;
uint32_t len;

friend bool operator==(const Poseidon2Permutation&, const Poseidon2Permutation&);
std::vector<uint8_t> bincodeSerialize() const;
static Poseidon2Permutation bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AND,
XOR,
RANGE,
Expand All @@ -277,7 +287,8 @@ struct BlackBoxFuncCall {
BigIntMul,
BigIntDiv,
BigIntFromLeBytes,
BigIntToLeBytes>
BigIntToLeBytes,
Poseidon2Permutation>
value;

friend bool operator==(const BlackBoxFuncCall&, const BlackBoxFuncCall&);
Expand Down Expand Up @@ -664,6 +675,16 @@ struct BlackBoxOp {
static BigIntToLeBytes bincodeDeserialize(std::vector<uint8_t>);
};

struct Poseidon2Permutation {
Circuit::HeapVector message;
Circuit::HeapArray output;
Circuit::RegisterIndex len;

friend bool operator==(const Poseidon2Permutation&, const Poseidon2Permutation&);
std::vector<uint8_t> bincodeSerialize() const;
static Poseidon2Permutation bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Sha256,
Blake2s,
Blake3,
Expand All @@ -681,7 +702,8 @@ struct BlackBoxOp {
BigIntMul,
BigIntDiv,
BigIntFromLeBytes,
BigIntToLeBytes>
BigIntToLeBytes,
Poseidon2Permutation>
value;

friend bool operator==(const BlackBoxOp&, const BlackBoxOp&);
Expand Down Expand Up @@ -3269,6 +3291,65 @@ Circuit::BlackBoxFuncCall::BigIntToLeBytes serde::Deserializable<

namespace Circuit {

inline bool operator==(const BlackBoxFuncCall::Poseidon2Permutation& lhs,
const BlackBoxFuncCall::Poseidon2Permutation& rhs)
{
if (!(lhs.inputs == rhs.inputs)) {
return false;
}
if (!(lhs.outputs == rhs.outputs)) {
return false;
}
if (!(lhs.len == rhs.len)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BlackBoxFuncCall::Poseidon2Permutation::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxFuncCall::Poseidon2Permutation>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxFuncCall::Poseidon2Permutation BlackBoxFuncCall::Poseidon2Permutation::bincodeDeserialize(
std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxFuncCall::Poseidon2Permutation>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BlackBoxFuncCall::Poseidon2Permutation>::serialize(
const Circuit::BlackBoxFuncCall::Poseidon2Permutation& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
serde::Serializable<decltype(obj.len)>::serialize(obj.len, serializer);
}

template <>
template <typename Deserializer>
Circuit::BlackBoxFuncCall::Poseidon2Permutation serde::Deserializable<
Circuit::BlackBoxFuncCall::Poseidon2Permutation>::deserialize(Deserializer& deserializer)
{
Circuit::BlackBoxFuncCall::Poseidon2Permutation obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
obj.len = serde::Deserializable<decltype(obj.len)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BlackBoxOp& lhs, const BlackBoxOp& rhs)
{
if (!(lhs.value == rhs.value)) {
Expand Down Expand Up @@ -4352,6 +4433,63 @@ Circuit::BlackBoxOp::BigIntToLeBytes serde::Deserializable<Circuit::BlackBoxOp::

namespace Circuit {

inline bool operator==(const BlackBoxOp::Poseidon2Permutation& lhs, const BlackBoxOp::Poseidon2Permutation& rhs)
{
if (!(lhs.message == rhs.message)) {
return false;
}
if (!(lhs.output == rhs.output)) {
return false;
}
if (!(lhs.len == rhs.len)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BlackBoxOp::Poseidon2Permutation::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxOp::Poseidon2Permutation>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxOp::Poseidon2Permutation BlackBoxOp::Poseidon2Permutation::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxOp::Poseidon2Permutation>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BlackBoxOp::Poseidon2Permutation>::serialize(
const Circuit::BlackBoxOp::Poseidon2Permutation& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.message)>::serialize(obj.message, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
serde::Serializable<decltype(obj.len)>::serialize(obj.len, serializer);
}

template <>
template <typename Deserializer>
Circuit::BlackBoxOp::Poseidon2Permutation serde::Deserializable<Circuit::BlackBoxOp::Poseidon2Permutation>::deserialize(
Deserializer& deserializer)
{
Circuit::BlackBoxOp::Poseidon2Permutation obj;
obj.message = serde::Deserializable<decltype(obj.message)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
obj.len = serde::Deserializable<decltype(obj.len)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BlockId& lhs, const BlockId& rhs)
{
if (!(lhs.value == rhs.value)) {
Expand Down
112 changes: 110 additions & 2 deletions noir/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,17 @@ namespace Circuit {
static BigIntToLeBytes bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AND, XOR, RANGE, SHA256, Blake2s, Blake3, SchnorrVerify, PedersenCommitment, PedersenHash, EcdsaSecp256k1, EcdsaSecp256r1, FixedBaseScalarMul, EmbeddedCurveAdd, Keccak256, Keccak256VariableLength, Keccakf1600, RecursiveAggregation, BigIntAdd, BigIntNeg, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes> value;
struct Poseidon2Permutation {
std::vector<Circuit::FunctionInput> inputs;
std::vector<Circuit::Witness> outputs;
uint32_t len;

friend bool operator==(const Poseidon2Permutation&, const Poseidon2Permutation&);
std::vector<uint8_t> bincodeSerialize() const;
static Poseidon2Permutation bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AND, XOR, RANGE, SHA256, Blake2s, Blake3, SchnorrVerify, PedersenCommitment, PedersenHash, EcdsaSecp256k1, EcdsaSecp256r1, FixedBaseScalarMul, EmbeddedCurveAdd, Keccak256, Keccak256VariableLength, Keccakf1600, RecursiveAggregation, BigIntAdd, BigIntNeg, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation> value;

friend bool operator==(const BlackBoxFuncCall&, const BlackBoxFuncCall&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -641,7 +651,17 @@ namespace Circuit {
static BigIntToLeBytes bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Sha256, Blake2s, Blake3, Keccak256, Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, PedersenCommitment, PedersenHash, FixedBaseScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntNeg, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes> value;
struct Poseidon2Permutation {
Circuit::HeapVector message;
Circuit::HeapArray output;
Circuit::RegisterIndex len;

friend bool operator==(const Poseidon2Permutation&, const Poseidon2Permutation&);
std::vector<uint8_t> bincodeSerialize() const;
static Poseidon2Permutation bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Sha256, Blake2s, Blake3, Keccak256, Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, PedersenCommitment, PedersenHash, FixedBaseScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntNeg, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation> value;

friend bool operator==(const BlackBoxOp&, const BlackBoxOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -2784,6 +2804,50 @@ Circuit::BlackBoxFuncCall::BigIntToLeBytes serde::Deserializable<Circuit::BlackB
return obj;
}

namespace Circuit {

inline bool operator==(const BlackBoxFuncCall::Poseidon2Permutation &lhs, const BlackBoxFuncCall::Poseidon2Permutation &rhs) {
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.outputs == rhs.outputs)) { return false; }
if (!(lhs.len == rhs.len)) { return false; }
return true;
}

inline std::vector<uint8_t> BlackBoxFuncCall::Poseidon2Permutation::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxFuncCall::Poseidon2Permutation>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxFuncCall::Poseidon2Permutation BlackBoxFuncCall::Poseidon2Permutation::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxFuncCall::Poseidon2Permutation>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BlackBoxFuncCall::Poseidon2Permutation>::serialize(const Circuit::BlackBoxFuncCall::Poseidon2Permutation &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
serde::Serializable<decltype(obj.len)>::serialize(obj.len, serializer);
}

template <>
template <typename Deserializer>
Circuit::BlackBoxFuncCall::Poseidon2Permutation serde::Deserializable<Circuit::BlackBoxFuncCall::Poseidon2Permutation>::deserialize(Deserializer &deserializer) {
Circuit::BlackBoxFuncCall::Poseidon2Permutation obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
obj.len = serde::Deserializable<decltype(obj.len)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BlackBoxOp &lhs, const BlackBoxOp &rhs) {
Expand Down Expand Up @@ -3624,6 +3688,50 @@ Circuit::BlackBoxOp::BigIntToLeBytes serde::Deserializable<Circuit::BlackBoxOp::
return obj;
}

namespace Circuit {

inline bool operator==(const BlackBoxOp::Poseidon2Permutation &lhs, const BlackBoxOp::Poseidon2Permutation &rhs) {
if (!(lhs.message == rhs.message)) { return false; }
if (!(lhs.output == rhs.output)) { return false; }
if (!(lhs.len == rhs.len)) { return false; }
return true;
}

inline std::vector<uint8_t> BlackBoxOp::Poseidon2Permutation::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxOp::Poseidon2Permutation>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxOp::Poseidon2Permutation BlackBoxOp::Poseidon2Permutation::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxOp::Poseidon2Permutation>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BlackBoxOp::Poseidon2Permutation>::serialize(const Circuit::BlackBoxOp::Poseidon2Permutation &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.message)>::serialize(obj.message, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
serde::Serializable<decltype(obj.len)>::serialize(obj.len, serializer);
}

template <>
template <typename Deserializer>
Circuit::BlackBoxOp::Poseidon2Permutation serde::Deserializable<Circuit::BlackBoxOp::Poseidon2Permutation>::deserialize(Deserializer &deserializer) {
Circuit::BlackBoxOp::Poseidon2Permutation obj;
obj.message = serde::Deserializable<decltype(obj.message)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
obj.len = serde::Deserializable<decltype(obj.len)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BlockId &lhs, const BlockId &rhs) {
Expand Down
4 changes: 4 additions & 0 deletions noir/acvm-repo/acir/src/circuit/black_box_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ pub enum BlackBoxFunc {
BigIntFromLeBytes,
/// BigInt to le bytes
BigIntToLeBytes,
/// Permutation function of Poseidon2
Poseidon2Permutation,
}

impl std::fmt::Display for BlackBoxFunc {
Expand Down Expand Up @@ -92,6 +94,7 @@ impl BlackBoxFunc {
BlackBoxFunc::BigIntDiv => "bigint_div",
BlackBoxFunc::BigIntFromLeBytes => "bigint_from_le_bytes",
BlackBoxFunc::BigIntToLeBytes => "bigint_to_le_bytes",
BlackBoxFunc::Poseidon2Permutation => "poseidon2_permutation",
}
}

Expand Down Expand Up @@ -119,6 +122,7 @@ impl BlackBoxFunc {
"bigint_div" => Some(BlackBoxFunc::BigIntDiv),
"bigint_from_le_bytes" => Some(BlackBoxFunc::BigIntFromLeBytes),
"bigint_to_le_bytes" => Some(BlackBoxFunc::BigIntToLeBytes),
"poseidon2_permutation" => Some(BlackBoxFunc::Poseidon2Permutation),
_ => None,
}
}
Expand Down
20 changes: 17 additions & 3 deletions noir/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ pub enum BlackBoxFuncCall {
input: u32,
outputs: Vec<Witness>,
},
/// Applies the Poseidon2 permutation function to the given state,
/// outputting the permuted state.
Poseidon2Permutation {
/// Input state for the permutation of Poseidon2
inputs: Vec<FunctionInput>,
/// Permuted state
outputs: Vec<Witness>,
/// State length (in number of field elements)
/// It is the length of inputs and outputs vectors
len: u32,
},
}

impl BlackBoxFuncCall {
Expand Down Expand Up @@ -171,7 +182,8 @@ impl BlackBoxFuncCall {
BlackBoxFuncCall::BigIntMul { .. } => BlackBoxFunc::BigIntMul,
BlackBoxFuncCall::BigIntDiv { .. } => BlackBoxFunc::BigIntDiv,
BlackBoxFuncCall::BigIntFromLeBytes { .. } => BlackBoxFunc::BigIntFromLeBytes,
&BlackBoxFuncCall::BigIntToLeBytes { .. } => BlackBoxFunc::BigIntToLeBytes,
BlackBoxFuncCall::BigIntToLeBytes { .. } => BlackBoxFunc::BigIntToLeBytes,
BlackBoxFuncCall::Poseidon2Permutation { .. } => BlackBoxFunc::Poseidon2Permutation,
}
}

Expand All @@ -188,7 +200,8 @@ impl BlackBoxFuncCall {
| BlackBoxFuncCall::Keccakf1600 { inputs, .. }
| BlackBoxFuncCall::PedersenCommitment { inputs, .. }
| BlackBoxFuncCall::PedersenHash { inputs, .. }
| BlackBoxFuncCall::BigIntFromLeBytes { inputs, .. } => inputs.to_vec(),
| BlackBoxFuncCall::BigIntFromLeBytes { inputs, .. }
| BlackBoxFuncCall::Poseidon2Permutation { inputs, .. } => inputs.to_vec(),
BlackBoxFuncCall::AND { lhs, rhs, .. } | BlackBoxFuncCall::XOR { lhs, rhs, .. } => {
vec![*lhs, *rhs]
}
Expand Down Expand Up @@ -282,7 +295,8 @@ impl BlackBoxFuncCall {
| BlackBoxFuncCall::Blake3 { outputs, .. }
| BlackBoxFuncCall::Keccak256 { outputs, .. }
| BlackBoxFuncCall::Keccakf1600 { outputs, .. }
| BlackBoxFuncCall::Keccak256VariableLength { outputs, .. } => outputs.to_vec(),
| BlackBoxFuncCall::Keccak256VariableLength { outputs, .. }
| BlackBoxFuncCall::Poseidon2Permutation { outputs, .. } => outputs.to_vec(),
BlackBoxFuncCall::AND { output, .. }
| BlackBoxFuncCall::XOR { output, .. }
| BlackBoxFuncCall::SchnorrVerify { output, .. }
Expand Down
Loading

0 comments on commit 1f91b23

Please sign in to comment.