Skip to content

Commit

Permalink
Allow revert with custom error
Browse files Browse the repository at this point in the history
  • Loading branch information
cburgdorf committed Jul 1, 2021
1 parent e2972ba commit f9aa721
Show file tree
Hide file tree
Showing 22 changed files with 714 additions and 111 deletions.
5 changes: 5 additions & 0 deletions analyzer/src/namespace/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ impl Struct {
.map(|(_, typ)| typ)
}

/// Return the types of all fields
pub fn get_field_types(&self) -> Vec<FixedSize> {
self.fields.iter().cloned().map(|(_, typ)| typ).collect()
}

/// Return the index of the given field name
pub fn get_field_index(&self, name: &str) -> Option<usize> {
self.fields.iter().position(|(field, _)| field == name)
Expand Down
28 changes: 27 additions & 1 deletion analyzer/src/traversal/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ fn func_stmt(
Assert { .. } => assert(scope, context, stmt),
Expr { .. } => expr(scope, context, stmt),
Pass => Ok(()),
Revert { .. } => Ok(()),
Revert { .. } => revert(scope, context, stmt),
Break | Continue => {
loop_flow_statement(scope, context, stmt);
Ok(())
Expand Down Expand Up @@ -439,6 +439,32 @@ fn assert(
unreachable!()
}

fn revert(
scope: Shared<BlockScope>,
context: &mut Context,
stmt: &Node<fe::FuncStmt>,
) -> Result<(), FatalError> {
if let fe::FuncStmt::Revert { error } = &stmt.kind {
if let Some(error_expr) = error {
let error_attributes = expressions::expr(Rc::clone(&scope), context, error_expr, None)?;
if !matches!(error_attributes.typ, Type::Struct(_)) {
context.error(
"`revert` error must be a struct",
error_expr.span,
format!(
"this has type `{}`; expected a struct",
error_attributes.typ
),
);
}
}

return Ok(());
}

unreachable!()
}

fn func_return(
scope: Shared<BlockScope>,
context: &mut Context,
Expand Down
6 changes: 3 additions & 3 deletions common/src/utils/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ use crate::utils::keccak;

/// Formats the name and fields and calculates the 32 byte keccak256 value of
/// the signature.
pub fn event_topic(name: &str, fields: Vec<String>) -> String {
pub fn event_topic(name: &str, fields: &[String]) -> String {
sign_event_or_func(name, fields, 32)
}
/// Formats the name and params and calculates the 4 byte keccak256 value of the
/// signature.
pub fn func_selector(name: &str, params: Vec<String>) -> String {
pub fn func_selector(name: &str, params: &[String]) -> String {
sign_event_or_func(name, params, 4)
}

fn sign_event_or_func(name: &str, params: Vec<String>, size: usize) -> String {
fn sign_event_or_func(name: &str, params: &[String], size: usize) -> String {
let signature = format!("{}({})", name, params.join(","));
keccak::partial(signature.as_bytes(), size)
}
9 changes: 9 additions & 0 deletions compiler/src/yul/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::yul::AnalyzerContext;
use fe_analyzer::namespace::types::{FeString, Struct};
use indexmap::IndexSet;

// This is contract context, but it's used all over so it has a short name.
Expand All @@ -10,6 +11,12 @@ pub struct Context<'a> {

/// Names of contracts that have been created inside of this contract.
pub created_contracts: IndexSet<String>,

/// Strings that can be used as revert error in assertions
pub assert_strings: IndexSet<FeString>,

// Structs that can be used as errors in revert statements
pub revert_errors: IndexSet<Struct>,
}

impl<'a> Context<'a> {
Expand All @@ -18,6 +25,8 @@ impl<'a> Context<'a> {
analysis,
string_literals: IndexSet::new(),
created_contracts: IndexSet::new(),
assert_strings: IndexSet::new(),
revert_errors: IndexSet::new(),
}
}
}
48 changes: 44 additions & 4 deletions compiler/src/yul/mappers/functions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use fe_analyzer::namespace::types::FixedSize;

use crate::yul::mappers::{assignments, declarations, expressions};
use crate::yul::names;
use crate::yul::operations::abi as abi_operations;
use crate::yul::operations::data as data_operations;
use crate::yul::Context;
use fe_analyzer::context::ExpressionAttributes;
Expand Down Expand Up @@ -56,7 +59,7 @@ fn func_stmt(context: &mut Context, stmt: &Node<fe::FuncStmt>) -> yul::Statement
fe::FuncStmt::Pass => statement! { pop(0) },
fe::FuncStmt::Break => break_statement(context, stmt),
fe::FuncStmt::Continue => continue_statement(context, stmt),
fe::FuncStmt::Revert { .. } => revert(stmt),
fe::FuncStmt::Revert { .. } => revert(context, stmt),
}
}

Expand Down Expand Up @@ -118,8 +121,28 @@ fn expr(context: &mut Context, stmt: &Node<fe::FuncStmt>) -> yul::Statement {
}
}

fn revert(stmt: &Node<fe::FuncStmt>) -> yul::Statement {
if let fe::FuncStmt::Revert { .. } = &stmt.kind {
fn revert(context: &mut Context, stmt: &Node<fe::FuncStmt>) -> yul::Statement {
if let fe::FuncStmt::Revert { error } = &stmt.kind {
if let Some(error_expr) = error {
let error_attributes = context
.analysis
.get_expression(error_expr)
.expect("missing expression");

if let Type::Struct(val) = &error_attributes.typ {
context.revert_errors.insert(val.clone());

let revert_data = expressions::expr(context, error_expr);
let size =
abi_operations::encode_size(vec![val.clone()], vec![revert_data.clone()]);
let revert_fn = names::revert_name(&val.name, &val.get_field_types());

return statement! {
([revert_fn]([revert_data], [size]))
};
}
}

return statement! { revert(0, 0) };
}

Expand Down Expand Up @@ -150,7 +173,24 @@ fn assert(context: &mut Context, stmt: &Node<fe::FuncStmt>) -> yul::Statement {
return match msg {
Some(val) => {
let msg = expressions::expr(context, val);
statement! { if (iszero([test])) { (revert_with_reason_string([msg])) } }
let msg_expr = context
.analysis
.get_expression(val)
.expect("missing expression");

if let Type::String(str) = &msg_expr.typ {
let size = abi_operations::encode_size(vec![str.clone()], vec![msg.clone()]);
let fixed_size = FixedSize::String(str.clone());
context.assert_strings.insert(str.clone());
let revert_fn = names::error_revert_name(&[fixed_size]);

return statement! {
if (iszero([test])) {
([revert_fn]([msg], [size]))
}
};
}
unreachable!()
}
None => statement! { if (iszero([test])) { (revert(0, 0)) } },
};
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/yul/names.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use fe_analyzer::namespace::types::{AbiDecodeLocation, AbiEncoding, Integer};
use fe_common::utils::abi as abi_utils;
use yultsur::*;

/// Generate a function name to perform checked addition
Expand Down Expand Up @@ -59,6 +60,30 @@ pub fn var_name(name: &str) -> yul::Identifier {
identifier! { (format!("${}", name)) }
}

/// Generate a revert function name for the name `Error` and a given set of types
pub fn error_revert_name<T: AbiEncoding>(types: &[T]) -> yul::Identifier {
revert_name("Error", types)
}

/// Generates a revert function name for a given name and types
pub fn revert_name<T: AbiEncoding>(name: &str, types: &[T]) -> yul::Identifier {
let type_names = types
.iter()
.map(|param| param.lower_snake())
.collect::<Vec<String>>();

let abi_names = types
.iter()
.map(|param| param.abi_selector_name())
.collect::<Vec<String>>();

let selector = abi_utils::func_selector(name, &abi_names);

let name = format!("revert_with_{}_{}", selector, &type_names.join("_"));

identifier! { (name) }
}

/// Generates an ABI encoding function name for a given set of types.
pub fn encode_name<T: AbiEncoding>(types: &[T]) -> yul::Identifier {
let type_names = types
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/yul/runtime/abi_dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ fn selector(name: &str, params: &[FixedSize]) -> yul::Literal {
.map(|param| param.abi_selector_name())
.collect::<Vec<String>>();

literal! {(abi_utils::func_selector(name, params))}
literal! {(abi_utils::func_selector(name, &params))}
}

fn selection(name: &str, params: &[FixedSize]) -> yul::Expression {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/yul/runtime/functions/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn calls(contract: Contract) -> Vec<yul::Statement> {
.unzip();
// the function selector must be added to the first 4 bytes of the calldata
let selector = {
let selector = abi_utils::func_selector(&function.name, param_names);
let selector = abi_utils::func_selector(&function.name, &param_names);
literal_expression! { (selector) }
};
// the operations used to encode the parameters
Expand Down
26 changes: 0 additions & 26 deletions compiler/src/yul/runtime/functions/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub fn all() -> Vec<yul::Statement> {
mcopys(),
mloadn(),
mstoren(),
revert_with_reason_string(),
scopym(),
scopys(),
set_zero(),
Expand Down Expand Up @@ -381,28 +380,3 @@ pub fn load_data_string() -> yul::Statement {
}
}
}

/// Revert with encoded reason string
pub fn revert_with_reason_string() -> yul::Statement {
function_definition! {
function revert_with_reason_string(reason) {
// Function selector for Error(string)
(let ptr := alloc_mstoren(0x08C379A0, 4))

// Write the (fixed) data offset into the next 32 bytes of memory
(pop((alloc_mstoren(0x0000000000000000000000000000000000000000000000000000000000000020, 32))))

// Read the size of the string
(let reason_size := mloadn(reason, 32))

//Copy the whole reason string (length + data) to the current segment of memory
(pop((mcopym(reason , (add(reason_size, 32))))))

// Right pad the reason bytes to a multiple of 32 bytes
(let padding := sub((ceil32(reason_size)), reason_size))
(pop((alloc(padding))))

(revert(ptr, (add(68, (add(reason_size, padding))))))
}
}
}
1 change: 1 addition & 0 deletions compiler/src/yul/runtime/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod abi;
pub mod contracts;
pub mod data;
pub mod math;
pub mod revert;
pub mod structs;

/// Returns all functions that should be available during runtime.
Expand Down
49 changes: 49 additions & 0 deletions compiler/src/yul/runtime/functions/revert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use crate::yul::names;
use fe_analyzer::namespace::types::Struct;
use fe_analyzer::namespace::types::{AbiEncoding, FixedSize};
use fe_common::utils::abi as abi_utils;
use yultsur::*;

fn selector(name: &str, params: &[FixedSize]) -> yul::Expression {
let params = params
.iter()
.map(|param| param.abi_selector_name())
.collect::<Vec<String>>();

literal_expression! {(abi_utils::func_selector(name, &params))}
}

/// Generate a YUL function to revert with the `Error` signature and the
/// given set of params.
/// NOTE: This is currently used for `assert False, "message"` statements which are
/// encoded as `Error(msg="message")`. This will be removed in the future.
pub fn generate_revert_fn_for_assert(params: &[FixedSize]) -> yul::Statement {
generate_revert_fn("Error", params, params)
}

/// Generate a YUL function to revert with a specific struct used as error data
pub fn generate_struct_revert(val: &Struct) -> yul::Statement {
let struct_fields = val.get_field_types();
generate_revert_fn(&val.name, &[FixedSize::Struct(val.clone())], &struct_fields)
}

/// Generate a YUL function that can be used to revert with data
pub fn generate_revert_fn(
name: &str,
encoding_params: &[FixedSize],
selector_params: &[FixedSize],
) -> yul::Statement {
let abi_encode_fn = names::encode_name(encoding_params);

let function_name = names::revert_name(name, selector_params);

let selector = selector(name, &selector_params);

return function_definition! {
function [function_name](data_ptr, size) {
(let ptr := alloc_mstoren([selector], 4))
(pop(([abi_encode_fn](data_ptr))))
(revert(ptr, (add(4, size))))
}
};
}
41 changes: 40 additions & 1 deletion compiler/src/yul/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ pub fn build(context: &Context, contract: &Node<fe::Contract>) -> Vec<yul::State
.map(|function| function.param_types())
.collect();

let assert_strings_batch = context
.assert_strings
.clone()
.into_iter()
.map(|val| vec![val.into()])
.collect::<Vec<_>>();

let revert_errors_batch = context
.revert_errors
.clone()
.into_iter()
.map(|val| val.get_field_types())
.collect::<Vec<_>>();

let structs_batch = attributes
.structs
.clone()
Expand All @@ -46,6 +60,8 @@ pub fn build(context: &Context, contract: &Node<fe::Contract>) -> Vec<yul::State
public_functions_batch,
events_batch,
contracts_batch,
assert_strings_batch,
revert_errors_batch,
structs_batch,
]
.concat();
Expand Down Expand Up @@ -99,7 +115,30 @@ pub fn build(context: &Context, contract: &Node<fe::Contract>) -> Vec<yul::State
.collect::<Vec<_>>()
.concat();

return [std, encoding, decoding, contract_calls, struct_apis].concat();
let revert_calls_from_assert = context
.assert_strings
.clone()
.into_iter()
.map(|val| functions::revert::generate_revert_fn_for_assert(&[val.into()]))
.collect::<Vec<_>>();

let revert_calls = context
.revert_errors
.clone()
.into_iter()
.map(|val| functions::revert::generate_struct_revert(&val))
.collect::<Vec<_>>();

return [
std,
encoding,
decoding,
contract_calls,
revert_calls_from_assert,
revert_calls,
struct_apis,
]
.concat();
}

panic!("missing contract attributes")
Expand Down
Loading

0 comments on commit f9aa721

Please sign in to comment.