Skip to content

Commit

Permalink
feat: represent assertions more similarly to function calls (#6103)
Browse files Browse the repository at this point in the history
# Description

## Problem

Resolves #6102

## Summary

See the relevant issue, but in addition to LSP working better, and
getting a better error message when giving an incorrect number of
arguments to `assert` and `assert_eq`, I think the code ends up being
slightly simpler.

## Additional Context

## Documentation

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
asterite authored Sep 20, 2024
1 parent cd81f85 commit 3ecd0e2
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 176 deletions.
44 changes: 25 additions & 19 deletions aztec_macros/src/transforms/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,14 +316,17 @@ fn create_static_check(fname: &str, is_private: bool) -> Statement {
.iter()
.fold(variable("context"), |acc, member| member_access(acc, member))
};
make_statement(StatementKind::Constrain(ConstrainStatement(
make_eq(is_static_call_expr, expression(ExpressionKind::Literal(Literal::Bool(true)))),
Some(expression(ExpressionKind::Literal(Literal::Str(format!(
"Function {} can only be called statically",
fname
))))),
ConstrainKind::Assert,
)))
make_statement(StatementKind::Constrain(ConstrainStatement {
kind: ConstrainKind::Assert,
arguments: vec![
make_eq(is_static_call_expr, expression(ExpressionKind::Literal(Literal::Bool(true)))),
expression(ExpressionKind::Literal(Literal::Str(format!(
"Function {} can only be called statically",
fname
)))),
],
span: Default::default(),
}))
}

/// Creates a check for internal functions ensuring that the caller is self.
Expand All @@ -332,17 +335,20 @@ fn create_static_check(fname: &str, is_private: bool) -> Statement {
/// assert(context.msg_sender() == context.this_address(), "Function can only be called internally");
/// ```
fn create_internal_check(fname: &str) -> Statement {
make_statement(StatementKind::Constrain(ConstrainStatement(
make_eq(
method_call(variable("context"), "msg_sender", vec![]),
method_call(variable("context"), "this_address", vec![]),
),
Some(expression(ExpressionKind::Literal(Literal::Str(format!(
"Function {} can only be called internally",
fname
))))),
ConstrainKind::Assert,
)))
make_statement(StatementKind::Constrain(ConstrainStatement {
kind: ConstrainKind::Assert,
arguments: vec![
make_eq(
method_call(variable("context"), "msg_sender", vec![]),
method_call(variable("context"), "this_address", vec![]),
),
expression(ExpressionKind::Literal(Literal::Str(format!(
"Function {} can only be called internally",
fname
)))),
],
span: Default::default(),
}))
}

/// Creates a call to assert_initialization_matches_address_preimage to be inserted
Expand Down
5 changes: 1 addition & 4 deletions aztec_macros/src/utils/parse_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,7 @@ fn empty_statement(statement: &mut Statement) {
}

fn empty_constrain_statement(constrain_statement: &mut ConstrainStatement) {
empty_expression(&mut constrain_statement.0);
if let Some(expression) = &mut constrain_statement.1 {
empty_expression(expression);
}
empty_expressions(&mut constrain_statement.arguments);
}

fn empty_expressions(expressions: &mut [Expression]) {
Expand Down
47 changes: 40 additions & 7 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,27 @@ pub enum LValue {
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ConstrainStatement(pub Expression, pub Option<Expression>, pub ConstrainKind);
pub struct ConstrainStatement {
pub kind: ConstrainKind,
pub arguments: Vec<Expression>,
pub span: Span,
}

impl Display for ConstrainStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.kind {
ConstrainKind::Assert | ConstrainKind::AssertEq => write!(
f,
"{}({})",
self.kind,
vecmap(&self.arguments, |arg| arg.to_string()).join(", ")
),
ConstrainKind::Constrain => {
write!(f, "constrain {}", &self.arguments[0])
}
}
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ConstrainKind {
Expand All @@ -571,6 +591,25 @@ pub enum ConstrainKind {
Constrain,
}

impl ConstrainKind {
pub fn required_arguments_count(&self) -> usize {
match self {
ConstrainKind::Assert | ConstrainKind::Constrain => 1,
ConstrainKind::AssertEq => 2,
}
}
}

impl Display for ConstrainKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConstrainKind::Assert => write!(f, "assert"),
ConstrainKind::AssertEq => write!(f, "assert_eq"),
ConstrainKind::Constrain => write!(f, "constrain"),
}
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Pattern {
Identifier(Ident),
Expand Down Expand Up @@ -885,12 +924,6 @@ impl Display for LetStatement {
}
}

impl Display for ConstrainStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "constrain {}", self.0)
}
}

impl Display for AssignStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{} = {}", self.lvalue, self.expression)
Expand Down
6 changes: 1 addition & 5 deletions compiler/noirc_frontend/src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1117,11 +1117,7 @@ impl ConstrainStatement {
}

pub fn accept_children(&self, visitor: &mut impl Visitor) {
self.0.accept(visitor);

if let Some(exp) = &self.1 {
exp.accept(visitor);
}
visit_expressions(&self.arguments, visitor);
}
}

Expand Down
54 changes: 48 additions & 6 deletions compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use noirc_errors::{Location, Span};
use noirc_errors::{Location, Span, Spanned};

use crate::{
ast::{AssignStatement, ConstrainStatement, LValue},
ast::{
AssignStatement, BinaryOpKind, ConstrainKind, ConstrainStatement, Expression,
ExpressionKind, InfixExpression, LValue,
},
hir::{
resolution::errors::ResolverError,
type_check::{Source, TypeCheckError},
Expand Down Expand Up @@ -110,12 +113,51 @@ impl<'context> Elaborator<'context> {
(HirStatement::Let(let_), Type::Unit)
}

pub(super) fn elaborate_constrain(&mut self, stmt: ConstrainStatement) -> (HirStatement, Type) {
let expr_span = stmt.0.span;
let (expr_id, expr_type) = self.elaborate_expression(stmt.0);
pub(super) fn elaborate_constrain(
&mut self,
mut stmt: ConstrainStatement,
) -> (HirStatement, Type) {
let span = stmt.span;
let min_args_count = stmt.kind.required_arguments_count();
let max_args_count = min_args_count + 1;
let actual_args_count = stmt.arguments.len();

let (message, expr) = if !(min_args_count..=max_args_count).contains(&actual_args_count) {
self.push_err(TypeCheckError::AssertionParameterCountMismatch {
kind: stmt.kind,
found: actual_args_count,
span,
});

// Given that we already produced an error, let's make this an `assert(true)` so
// we don't get further errors.
let message = None;
let kind = ExpressionKind::Literal(crate::ast::Literal::Bool(true));
let expr = Expression { kind, span };
(message, expr)
} else {
let message =
(actual_args_count != min_args_count).then(|| stmt.arguments.pop().unwrap());
let expr = match stmt.kind {
ConstrainKind::Assert | ConstrainKind::Constrain => stmt.arguments.pop().unwrap(),
ConstrainKind::AssertEq => {
let rhs = stmt.arguments.pop().unwrap();
let lhs = stmt.arguments.pop().unwrap();
let span = Span::from(lhs.span.start()..rhs.span.end());
let operator = Spanned::from(span, BinaryOpKind::Equal);
let kind =
ExpressionKind::Infix(Box::new(InfixExpression { lhs, operator, rhs }));
Expression { kind, span }
}
};
(message, expr)
};

let expr_span = expr.span;
let (expr_id, expr_type) = self.elaborate_expression(expr);

// Must type check the assertion message expression so that we instantiate bindings
let msg = stmt.1.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0);
let msg = message.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0);

self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch {
expr_typ: expr_type.to_string(),
Expand Down
11 changes: 6 additions & 5 deletions compiler/noirc_frontend/src/hir/comptime/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,11 +683,12 @@ fn remove_interned_in_statement_kind(
r#type: remove_interned_in_unresolved_type(interner, let_statement.r#type),
..let_statement
}),
StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement(
remove_interned_in_expression(interner, constrain.0),
constrain.1.map(|expr| remove_interned_in_expression(interner, expr)),
constrain.2,
)),
StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement {
arguments: vecmap(constrain.arguments, |expr| {
remove_interned_in_expression(interner, expr)
}),
..constrain
}),
StatementKind::Expression(expr) => {
StatementKind::Expression(remove_interned_in_expression(interner, expr))
}
Expand Down
11 changes: 9 additions & 2 deletions compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ impl HirStatement {
}
HirStatement::Constrain(constrain) => {
let expr = constrain.0.to_display_ast(interner);
let message = constrain.2.map(|message| message.to_display_ast(interner));
let mut arguments = vec![expr];
if let Some(message) = constrain.2 {
arguments.push(message.to_display_ast(interner));
}

// TODO: Find difference in usage between Assert & AssertEq
StatementKind::Constrain(ConstrainStatement(expr, message, ConstrainKind::Assert))
StatementKind::Constrain(ConstrainStatement {
kind: ConstrainKind::Assert,
arguments,
span,
})
}
HirStatement::Assign(assign) => StatementKind::Assign(AssignStatement {
lvalue: assign.lvalue.to_display_ast(interner),
Expand Down
39 changes: 28 additions & 11 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1240,9 +1240,17 @@ fn expr_as_assert(
location: Location,
) -> IResult<Value> {
expr_as(interner, arguments, return_type.clone(), location, |expr| {
if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr {
if constrain.2 == ConstrainKind::Assert {
let predicate = Value::expression(constrain.0.kind);
if let ExprValue::Statement(StatementKind::Constrain(mut constrain)) = expr {
if constrain.kind == ConstrainKind::Assert
&& !constrain.arguments.is_empty()
&& constrain.arguments.len() <= 2
{
let (message, predicate) = if constrain.arguments.len() == 1 {
(None, constrain.arguments.pop().unwrap())
} else {
(Some(constrain.arguments.pop().unwrap()), constrain.arguments.pop().unwrap())
};
let predicate = Value::expression(predicate.kind);

let option_type = extract_option_generic_type(return_type);
let Type::Tuple(mut tuple_types) = option_type else {
Expand All @@ -1251,7 +1259,7 @@ fn expr_as_assert(
assert_eq!(tuple_types.len(), 2);

let option_type = tuple_types.pop().unwrap();
let message = constrain.1.map(|message| Value::expression(message.kind));
let message = message.map(|msg| Value::expression(msg.kind));
let message = option(option_type, message).ok()?;

Some(Value::Tuple(vec![predicate, message]))
Expand All @@ -1272,14 +1280,23 @@ fn expr_as_assert_eq(
location: Location,
) -> IResult<Value> {
expr_as(interner, arguments, return_type.clone(), location, |expr| {
if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr {
if constrain.2 == ConstrainKind::AssertEq {
let ExpressionKind::Infix(infix) = constrain.0.kind else {
panic!("Expected AssertEq constrain statement to have an infix expression");
if let ExprValue::Statement(StatementKind::Constrain(mut constrain)) = expr {
if constrain.kind == ConstrainKind::AssertEq
&& constrain.arguments.len() >= 2
&& constrain.arguments.len() <= 3
{
let (message, rhs, lhs) = if constrain.arguments.len() == 2 {
(None, constrain.arguments.pop().unwrap(), constrain.arguments.pop().unwrap())
} else {
(
Some(constrain.arguments.pop().unwrap()),
constrain.arguments.pop().unwrap(),
constrain.arguments.pop().unwrap(),
)
};

let lhs = Value::expression(infix.lhs.kind);
let rhs = Value::expression(infix.rhs.kind);
let lhs = Value::expression(lhs.kind);
let rhs = Value::expression(rhs.kind);

let option_type = extract_option_generic_type(return_type);
let Type::Tuple(mut tuple_types) = option_type else {
Expand All @@ -1288,7 +1305,7 @@ fn expr_as_assert_eq(
assert_eq!(tuple_types.len(), 3);

let option_type = tuple_types.pop().unwrap();
let message = constrain.1.map(|message| Value::expression(message.kind));
let message = message.map(|message| Value::expression(message.kind));
let message = option(option_type, message).ok()?;

Some(Value::Tuple(vec![lhs, rhs, message]))
Expand Down
10 changes: 10 additions & 0 deletions compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use noirc_errors::CustomDiagnostic as Diagnostic;
use noirc_errors::Span;
use thiserror::Error;

use crate::ast::ConstrainKind;
use crate::ast::{BinaryOpKind, FunctionReturnType, IntegerBitSize, Signedness};
use crate::hir::resolution::errors::ResolverError;
use crate::hir_def::expr::HirBinaryOp;
Expand Down Expand Up @@ -59,6 +60,8 @@ pub enum TypeCheckError {
AccessUnknownMember { lhs_type: Type, field_name: String, span: Span },
#[error("Function expects {expected} parameters but {found} were given")]
ParameterCountMismatch { expected: usize, found: usize, span: Span },
#[error("{} expects {} or {} parameters but {found} were given", kind, kind.required_arguments_count(), kind.required_arguments_count() + 1)]
AssertionParameterCountMismatch { kind: ConstrainKind, found: usize, span: Span },
#[error("{item} expects {expected} generics but {found} were given")]
GenericCountMismatch { item: String, expected: usize, found: usize, span: Span },
#[error("{item} has incompatible `unconstrained`")]
Expand Down Expand Up @@ -260,6 +263,13 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic {
let msg = format!("Function expects {expected} parameter{empty_or_s} but {found} {was_or_were} given");
Diagnostic::simple_error(msg, String::new(), *span)
}
TypeCheckError::AssertionParameterCountMismatch { kind, found, span } => {
let was_or_were = if *found == 1 { "was" } else { "were" };
let min = kind.required_arguments_count();
let max = min + 1;
let msg = format!("{kind} expects {min} or {max} parameters but {found} {was_or_were} given");
Diagnostic::simple_error(msg, String::new(), *span)
}
TypeCheckError::GenericCountMismatch { item, expected, found, span } => {
let empty_or_s = if *expected == 1 { "" } else { "s" };
let was_or_were = if *found == 1 { "was" } else { "were" };
Expand Down
17 changes: 6 additions & 11 deletions compiler/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ where
choice((
assertion::constrain(expr_parser.clone()),
assertion::assertion(expr_parser.clone()),
assertion::assertion_eq(expr_parser.clone()),
declaration(expr_parser.clone()),
assignment(expr_parser.clone()),
if_statement(expr_no_constructors.clone(), statement.clone()),
Expand Down Expand Up @@ -1629,17 +1628,13 @@ mod test {
Case { source: "let", expect: "let $error = Error", errors: 3 },
Case { source: "foo = one two three", expect: "foo = one", errors: 1 },
Case { source: "constrain", expect: "constrain Error", errors: 2 },
Case { source: "assert", expect: "constrain Error", errors: 1 },
Case { source: "assert", expect: "assert()", errors: 1 },
Case { source: "constrain x ==", expect: "constrain (x == Error)", errors: 2 },
Case { source: "assert(x ==)", expect: "constrain (x == Error)", errors: 1 },
Case { source: "assert(x == x, x)", expect: "constrain (x == x)", errors: 0 },
Case { source: "assert_eq(x,)", expect: "constrain (Error == Error)", errors: 1 },
Case {
source: "assert_eq(x, x, x, x)",
expect: "constrain (Error == Error)",
errors: 1,
},
Case { source: "assert_eq(x, x, x)", expect: "constrain (x == x)", errors: 0 },
Case { source: "assert(x ==)", expect: "assert((x == Error))", errors: 1 },
Case { source: "assert(x == x, x)", expect: "assert((x == x), x)", errors: 0 },
Case { source: "assert_eq(x,)", expect: "assert_eq(x)", errors: 0 },
Case { source: "assert_eq(x, x, x, x)", expect: "assert_eq(x, x, x, x)", errors: 0 },
Case { source: "assert_eq(x, x, x)", expect: "assert_eq(x, x, x)", errors: 0 },
];

check_cases_with_errors(&cases[..], fresh_statement());
Expand Down
Loading

0 comments on commit 3ecd0e2

Please sign in to comment.