diff --git a/compiler/src/yul/mappers/functions.rs b/compiler/src/yul/mappers/functions.rs index f34ea63b80..56da43bb3e 100644 --- a/compiler/src/yul/mappers/functions.rs +++ b/compiler/src/yul/mappers/functions.rs @@ -188,10 +188,15 @@ fn emit(context: &Context, stmt: &Node) -> yul::Statement { } fn assert(context: &Context, stmt: &Node) -> yul::Statement { - if let fe::FuncStmt::Assert { test, msg: _ } = &stmt.kind { + if let fe::FuncStmt::Assert { test, msg } = &stmt.kind { let test = expressions::expr(context, test); - - return statement! { if (iszero([test])) { (revert(0, 0)) } }; + return match msg { + Some(val) => { + let msg = expressions::expr(context, val); + statement! { if (iszero([test])) { (revert_with_reason_string([msg])) } } + }, + None => statement! { if (iszero([test])) { (revert(0, 0)) } } + } } unreachable!() diff --git a/compiler/tests/features.rs b/compiler/tests/features.rs index 7bc129eb7f..0820135280 100644 --- a/compiler/tests/features.rs +++ b/compiler/tests/features.rs @@ -75,17 +75,25 @@ fn test_assert() { let exit1 = harness.capture_call(&mut executor, "bar", &[uint_token(4)]); - assert!(matches!( - exit1, - evm::Capture::Exit((evm::ExitReason::Revert(_), _)) - )); + match exit1 { + evm::Capture::Exit((evm::ExitReason::Revert(_), output)) => assert_eq!(output.len(), 0), + _ => panic!("Did not revert correctly") + } let exit2 = harness.capture_call(&mut executor, "bar", &[uint_token(42)]); assert!(matches!( exit2, evm::Capture::Exit((evm::ExitReason::Succeed(_), _)) - )) + )); + + let exit3 = harness.capture_call(&mut executor, "baz", &[uint_token(4)]); + + match exit3 { + evm::Capture::Exit((evm::ExitReason::Revert(_), output)) => assert_eq!(output, encode_error_reason("Must be greater than five")), + _ => panic!("Did not revert correctly") + } + }) } diff --git a/compiler/tests/fixtures/features/assert.fe b/compiler/tests/fixtures/features/assert.fe index fc0c768778..a83e13bbcc 100644 --- a/compiler/tests/fixtures/features/assert.fe +++ b/compiler/tests/fixtures/features/assert.fe @@ -1,3 +1,6 @@ contract Foo: pub def bar(baz: u256): - assert baz > 5 \ No newline at end of file + assert baz > 5 + + pub def baz(baz: u256): + assert baz > 5, "Must be greater than five" \ No newline at end of file