Skip to content

Commit

Permalink
feat: Encode static error strings in the ABI (#9552)
Browse files Browse the repository at this point in the history
Avoids embedding revert string in circuits. Instead, static string
errors get a specific selector, and they are encoded in the ABI. We use
noir_abi to resolve those messages in the error case.

---------

Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
  • Loading branch information
sirasistant and TomAFrench authored Nov 4, 2024
1 parent f47cc17 commit 1a41d42
Show file tree
Hide file tree
Showing 59 changed files with 454 additions and 590 deletions.
14 changes: 2 additions & 12 deletions avm-transpiler/src/transpile_contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@ use serde::{Deserialize, Serialize};
use acvm::acir::circuit::Program;
use noirc_errors::debug_info::ProgramDebugInfo;

use crate::transpile::{
brillig_to_avm, map_brillig_pcs_to_avm_pcs, patch_assert_message_pcs, patch_debug_info_pcs,
};
use crate::utils::{extract_brillig_from_acir_program, extract_static_assert_messages};
use fxhash::FxHashMap as HashMap;
use crate::transpile::{brillig_to_avm, map_brillig_pcs_to_avm_pcs, patch_debug_info_pcs};
use crate::utils::extract_brillig_from_acir_program;

/// Representation of a contract with some transpiled functions
#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -51,7 +48,6 @@ pub struct AvmContractFunctionArtifact {
)]
pub debug_symbols: ProgramDebugInfo,
pub brillig_names: Vec<String>,
pub assert_messages: HashMap<usize, String>,
}

/// Representation of an ACIR contract function but with
Expand Down Expand Up @@ -96,16 +92,11 @@ impl From<CompiledAcirContractArtifact> for TranspiledContractArtifact {
// Extract Brillig Opcodes from acir
let acir_program = function.bytecode;
let brillig_bytecode = extract_brillig_from_acir_program(&acir_program);
let assert_messages = extract_static_assert_messages(&acir_program);
info!("Extracted Brillig program has {} instructions", brillig_bytecode.len());

// Map Brillig pcs to AVM pcs (index is Brillig PC, value is AVM PC)
let brillig_pcs_to_avm_pcs = map_brillig_pcs_to_avm_pcs(brillig_bytecode);

// Patch the assert messages with updated PCs
let assert_messages =
patch_assert_message_pcs(assert_messages, &brillig_pcs_to_avm_pcs);

// Transpile to AVM
let avm_bytecode = brillig_to_avm(brillig_bytecode, &brillig_pcs_to_avm_pcs);

Expand All @@ -132,7 +123,6 @@ impl From<CompiledAcirContractArtifact> for TranspiledContractArtifact {
bytecode: base64::prelude::BASE64_STANDARD.encode(avm_bytecode),
debug_symbols: ProgramDebugInfo { debug_infos },
brillig_names: function.brillig_names,
assert_messages,
},
));
} else {
Expand Down
31 changes: 1 addition & 30 deletions avm-transpiler/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use fxhash::FxHashMap as HashMap;

use acvm::acir::circuit::brillig::BrilligFunctionId;
use acvm::{AcirField, FieldElement};
use log::{debug, info, trace};

use acvm::acir::brillig::Opcode as BrilligOpcode;
use acvm::acir::circuit::{AssertionPayload, Opcode, Program};
use acvm::acir::circuit::{Opcode, Program};

use crate::instructions::{AvmInstruction, AvmOperand};
use crate::opcodes::AvmOpcode;
Expand Down Expand Up @@ -39,33 +37,6 @@ pub fn extract_brillig_from_acir_program(
&program.unconstrained_functions[0].bytecode
}

/// Assertion messages that are static strings are stored in the assert_messages map of the ACIR program.
pub fn extract_static_assert_messages(program: &Program<FieldElement>) -> HashMap<usize, String> {
assert_eq!(
program.functions.len(),
1,
"An AVM program should have only a single ACIR function with a 'BrilligCall'"
);
let main_function = &program.functions[0];
main_function
.assert_messages
.iter()
.filter_map(|(location, payload)| {
if let AssertionPayload::StaticString(static_string) = payload {
Some((
location
.to_brillig_location()
.expect("Assert message is not for the brillig function")
.0,
static_string.clone(),
))
} else {
None
}
})
.collect()
}

/// Print inputs, outputs, and instructions in a Brillig program
pub fn dbg_print_brillig_program(brillig_bytecode: &[BrilligOpcode<FieldElement>]) {
trace!("Printing Brillig program...");
Expand Down
126 changes: 10 additions & 116 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once

#include "barretenberg/common/throw_or_abort.hpp"
#include "bincode.hpp"
#include "serde.hpp"

Expand Down Expand Up @@ -1276,24 +1275,8 @@ struct ExpressionOrMemory {
};

struct AssertionPayload {

struct StaticString {
std::string value;

friend bool operator==(const StaticString&, const StaticString&);
std::vector<uint8_t> bincodeSerialize() const;
static StaticString bincodeDeserialize(std::vector<uint8_t>);
};

struct Dynamic {
std::tuple<uint64_t, std::vector<Program::ExpressionOrMemory>> value;

friend bool operator==(const Dynamic&, const Dynamic&);
std::vector<uint8_t> bincodeSerialize() const;
static Dynamic bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<StaticString, Dynamic> value;
uint64_t error_selector;
std::vector<Program::ExpressionOrMemory> payload;

friend bool operator==(const AssertionPayload&, const AssertionPayload&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -1394,7 +1377,10 @@ namespace Program {

inline bool operator==(const AssertionPayload& lhs, const AssertionPayload& rhs)
{
if (!(lhs.value == rhs.value)) {
if (!(lhs.error_selector == rhs.error_selector)) {
return false;
}
if (!(lhs.payload == rhs.payload)) {
return false;
}
return true;
Expand Down Expand Up @@ -1425,7 +1411,8 @@ void serde::Serializable<Program::AssertionPayload>::serialize(const Program::As
Serializer& serializer)
{
serializer.increase_container_depth();
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
serde::Serializable<decltype(obj.error_selector)>::serialize(obj.error_selector, serializer);
serde::Serializable<decltype(obj.payload)>::serialize(obj.payload, serializer);
serializer.decrease_container_depth();
}

Expand All @@ -1435,107 +1422,14 @@ Program::AssertionPayload serde::Deserializable<Program::AssertionPayload>::dese
{
deserializer.increase_container_depth();
Program::AssertionPayload obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
obj.error_selector = serde::Deserializable<decltype(obj.error_selector)>::deserialize(deserializer);
obj.payload = serde::Deserializable<decltype(obj.payload)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}

namespace Program {

inline bool operator==(const AssertionPayload::StaticString& lhs, const AssertionPayload::StaticString& rhs)
{
if (!(lhs.value == rhs.value)) {
return false;
}
return true;
}

inline std::vector<uint8_t> AssertionPayload::StaticString::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<AssertionPayload::StaticString>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline AssertionPayload::StaticString AssertionPayload::StaticString::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<AssertionPayload::StaticString>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::AssertionPayload::StaticString>::serialize(
const Program::AssertionPayload::StaticString& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Program::AssertionPayload::StaticString serde::Deserializable<Program::AssertionPayload::StaticString>::deserialize(
Deserializer& deserializer)
{
Program::AssertionPayload::StaticString obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const AssertionPayload::Dynamic& lhs, const AssertionPayload::Dynamic& rhs)
{
if (!(lhs.value == rhs.value)) {
return false;
}
return true;
}

inline std::vector<uint8_t> AssertionPayload::Dynamic::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<AssertionPayload::Dynamic>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline AssertionPayload::Dynamic AssertionPayload::Dynamic::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<AssertionPayload::Dynamic>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::AssertionPayload::Dynamic>::serialize(const Program::AssertionPayload::Dynamic& obj,
Serializer& serializer)
{
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Program::AssertionPayload::Dynamic serde::Deserializable<Program::AssertionPayload::Dynamic>::deserialize(
Deserializer& deserializer)
{
Program::AssertionPayload::Dynamic obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BinaryFieldOp& lhs, const BinaryFieldOp& rhs)
{
if (!(lhs.value == rhs.value)) {
Expand Down
105 changes: 8 additions & 97 deletions noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,24 +1216,8 @@ namespace Program {
};

struct AssertionPayload {

struct StaticString {
std::string value;

friend bool operator==(const StaticString&, const StaticString&);
std::vector<uint8_t> bincodeSerialize() const;
static StaticString bincodeDeserialize(std::vector<uint8_t>);
};

struct Dynamic {
std::tuple<uint64_t, std::vector<Program::ExpressionOrMemory>> value;

friend bool operator==(const Dynamic&, const Dynamic&);
std::vector<uint8_t> bincodeSerialize() const;
static Dynamic bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<StaticString, Dynamic> value;
uint64_t error_selector;
std::vector<Program::ExpressionOrMemory> payload;

friend bool operator==(const AssertionPayload&, const AssertionPayload&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -1334,7 +1318,8 @@ namespace Program {
namespace Program {

inline bool operator==(const AssertionPayload &lhs, const AssertionPayload &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
if (!(lhs.error_selector == rhs.error_selector)) { return false; }
if (!(lhs.payload == rhs.payload)) { return false; }
return true;
}

Expand All @@ -1359,7 +1344,8 @@ template <>
template <typename Serializer>
void serde::Serializable<Program::AssertionPayload>::serialize(const Program::AssertionPayload &obj, Serializer &serializer) {
serializer.increase_container_depth();
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
serde::Serializable<decltype(obj.error_selector)>::serialize(obj.error_selector, serializer);
serde::Serializable<decltype(obj.payload)>::serialize(obj.payload, serializer);
serializer.decrease_container_depth();
}

Expand All @@ -1368,87 +1354,12 @@ template <typename Deserializer>
Program::AssertionPayload serde::Deserializable<Program::AssertionPayload>::deserialize(Deserializer &deserializer) {
deserializer.increase_container_depth();
Program::AssertionPayload obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
obj.error_selector = serde::Deserializable<decltype(obj.error_selector)>::deserialize(deserializer);
obj.payload = serde::Deserializable<decltype(obj.payload)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}

namespace Program {

inline bool operator==(const AssertionPayload::StaticString &lhs, const AssertionPayload::StaticString &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> AssertionPayload::StaticString::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<AssertionPayload::StaticString>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline AssertionPayload::StaticString AssertionPayload::StaticString::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<AssertionPayload::StaticString>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::AssertionPayload::StaticString>::serialize(const Program::AssertionPayload::StaticString &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Program::AssertionPayload::StaticString serde::Deserializable<Program::AssertionPayload::StaticString>::deserialize(Deserializer &deserializer) {
Program::AssertionPayload::StaticString obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const AssertionPayload::Dynamic &lhs, const AssertionPayload::Dynamic &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> AssertionPayload::Dynamic::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<AssertionPayload::Dynamic>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline AssertionPayload::Dynamic AssertionPayload::Dynamic::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<AssertionPayload::Dynamic>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::AssertionPayload::Dynamic>::serialize(const Program::AssertionPayload::Dynamic &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Program::AssertionPayload::Dynamic serde::Deserializable<Program::AssertionPayload::Dynamic>::deserialize(Deserializer &deserializer) {
Program::AssertionPayload::Dynamic obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BinaryFieldOp &lhs, const BinaryFieldOp &rhs) {
Expand Down
Loading

0 comments on commit 1a41d42

Please sign in to comment.