diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 3b35b49ead3..51837debb24 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -19,9 +19,9 @@ use rustc_hash::FxHashMap as HashMap; use crate::{ ast::{ - ArrayLiteral, BlockExpression, ConstrainKind, Expression, ExpressionKind, FunctionKind, - FunctionReturnType, IntegerBitSize, LValue, Literal, Pattern, Statement, StatementKind, - UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, + ArrayLiteral, BlockExpression, ConstrainKind, Expression, ExpressionKind, ForRange, + FunctionKind, FunctionReturnType, IntegerBitSize, LValue, Literal, Pattern, Statement, + StatementKind, UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, }, hir::{ comptime::{ @@ -75,6 +75,8 @@ impl<'local, 'context> Interpreter<'local, 'context> { "expr_as_constructor" => { expr_as_constructor(interner, arguments, return_type, location) } + "expr_as_for" => expr_as_for(interner, arguments, return_type, location), + "expr_as_for_range" => expr_as_for_range(interner, arguments, return_type, location), "expr_as_function_call" => { expr_as_function_call(interner, arguments, return_type, location) } @@ -1418,7 +1420,7 @@ fn expr_as_comptime( // fn as_constructor(self) -> Option<(Quoted, [(Quoted, Expr)])> fn expr_as_constructor( - interner: &mut NodeInterner, + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, @@ -1448,6 +1450,55 @@ fn expr_as_constructor( option(return_type, option_value) } +// fn as_for(self) -> Option<(Quoted, Expr, Expr)> +fn expr_as_for( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(interner, arguments, return_type, location, |expr| { + if let ExprValue::Statement(StatementKind::For(for_statement)) = expr { + if let ForRange::Array(array) = for_statement.range { + let identifier = + Value::Quoted(Rc::new(vec![Token::Ident(for_statement.identifier.0.contents)])); + let array = Value::expression(array.kind); + let body = Value::expression(for_statement.block.kind); + Some(Value::Tuple(vec![identifier, array, body])) + } else { + None + } + } else { + None + } + }) +} + +// fn as_for_range(self) -> Option<(Quoted, Expr, Expr, Expr)> +fn expr_as_for_range( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(interner, arguments, return_type, location, |expr| { + if let ExprValue::Statement(StatementKind::For(for_statement)) = expr { + if let ForRange::Range(from, to) = for_statement.range { + let identifier = + Value::Quoted(Rc::new(vec![Token::Ident(for_statement.identifier.0.contents)])); + let from = Value::expression(from.kind); + let to = Value::expression(to.kind); + let body = Value::expression(for_statement.block.kind); + Some(Value::Tuple(vec![identifier, from, to, body])) + } else { + None + } + } else { + None + } + }) +} + // fn as_function_call(self) -> Option<(Expr, [Expr])> fn expr_as_function_call( interner: &NodeInterner, diff --git a/docs/docs/noir/standard_library/meta/expr.md b/docs/docs/noir/standard_library/meta/expr.md index c56568f5379..ddbbcd7cdde 100644 --- a/docs/docs/noir/standard_library/meta/expr.md +++ b/docs/docs/noir/standard_library/meta/expr.md @@ -73,6 +73,20 @@ return each statement in the block. If this expression is a constructor `Type { field1: expr1, ..., fieldN: exprN }`, return the type and the fields. +### as_for + +#include_code as_for noir_stdlib/src/meta/expr.nr rust + +If this expression is a for statement over a single expression, return the identifier, +the expression and the for loop body. + +### as_for_range + +#include_code as_for noir_stdlib/src/meta/expr.nr rust + +If this expression is a for statement over a range, return the identifier, +the range start, the range end and the for loop body. + ### as_function_call #include_code as_function_call noir_stdlib/src/meta/expr.nr rust diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index 5c6a6f2236e..caf3fa172c4 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -72,6 +72,20 @@ impl Expr { comptime fn as_constructor(self) -> Option<(UnresolvedType, [(Quoted, Expr)])> {} // docs:end:as_constructor + /// If this expression is a for statement over a single expression, return the identifier, + /// the expression and the for loop body. + #[builtin(expr_as_for)] + // docs:start:as_for + comptime fn as_for(self) -> Option<(Quoted, Expr, Expr)> {} + // docs:end:as_for + + /// If this expression is a for statement over a range, return the identifier, + /// the range start, the range end and the for loop body. + #[builtin(expr_as_for_range)] + // docs:start:as_for_range + comptime fn as_for_range(self) -> Option<(Quoted, Expr, Expr, Expr)> {} + // docs:end:as_for_range + /// If this expression is a function call `foo(arg1, ..., argN)`, return /// the function and a slice of each argument. #[builtin(expr_as_function_call)] @@ -218,6 +232,8 @@ impl Expr { let result = result.or_else(|| modify_constructor(self, f)); let result = result.or_else(|| modify_if(self, f)); let result = result.or_else(|| modify_index(self, f)); + let result = result.or_else(|| modify_for(self, f)); + let result = result.or_else(|| modify_for_range(self, f)); let result = result.or_else(|| modify_let(self, f)); let result = result.or_else(|| modify_function_call(self, f)); let result = result.or_else(|| modify_member_access(self, f)); @@ -388,6 +404,29 @@ comptime fn modify_index(expr: Expr, f: fn[Env](Expr) -> Option) -> O ) } +comptime fn modify_for(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_for().map( + |expr: (Quoted, Expr, Expr)| { + let (identifier, array, body) = expr; + let array = array.modify(f); + let body = body.modify(f); + new_for(identifier, array, body) + } + ) +} + +comptime fn modify_for_range(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_for_range().map( + |expr: (Quoted, Expr, Expr, Expr)| { + let (identifier, from, to, body) = expr; + let from = from.modify(f); + let to = to.modify(f); + let body = body.modify(f); + new_for_range(identifier, from, to, body) + } + ) +} + comptime fn modify_let(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_let().map( |expr: (Expr, Option, Expr)| { @@ -548,6 +587,14 @@ comptime fn new_if(condition: Expr, consequence: Expr, alternative: Option } } +comptime fn new_for(identifier: Quoted, array: Expr, body: Expr) -> Expr { + quote { for $identifier in $array { $body } }.as_expr().unwrap() +} + +comptime fn new_for_range(identifier: Quoted, from: Expr, to: Expr, body: Expr) -> Expr { + quote { for $identifier in $from .. $to { $body } }.as_expr().unwrap() +} + comptime fn new_index(object: Expr, index: Expr) -> Expr { quote { $object[$index] }.as_expr().unwrap() } diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr index 9eb4dc8a694..50b10c45e59 100644 --- a/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -670,6 +670,58 @@ mod tests { } } + #[test] + fn test_expr_as_for_statement() { + comptime + { + let expr = quote { for x in 2 { 3 } }.as_expr().unwrap(); + let (index, array, body) = expr.as_for().unwrap(); + assert_eq(index, quote { x }); + assert_eq(array.as_integer().unwrap(), (2, false)); + assert_eq(body.as_block().unwrap()[0].as_integer().unwrap(), (3, false)); + } + } + + #[test] + fn test_expr_modify_for_statement() { + comptime + { + let expr = quote { for x in 2 { 3 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (index, array, body) = expr.as_for().unwrap(); + assert_eq(index, quote { x }); + assert_eq(array.as_integer().unwrap(), (4, false)); + assert_eq(body.as_block().unwrap()[0].as_block().unwrap()[0].as_integer().unwrap(), (6, false)); + } + } + + #[test] + fn test_expr_as_for_range_statement() { + comptime + { + let expr = quote { for x in 2..3 { 4 } }.as_expr().unwrap(); + let (index, from, to, body) = expr.as_for_range().unwrap(); + assert_eq(index, quote { x }); + assert_eq(from.as_integer().unwrap(), (2, false)); + assert_eq(to.as_integer().unwrap(), (3, false)); + assert_eq(body.as_block().unwrap()[0].as_integer().unwrap(), (4, false)); + } + } + + #[test] + fn test_expr_modify_for_range_statement() { + comptime + { + let expr = quote { for x in 2..3 { 4 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (index, from, to, body) = expr.as_for_range().unwrap(); + assert_eq(index, quote { x }); + assert_eq(from.as_integer().unwrap(), (4, false)); + assert_eq(to.as_integer().unwrap(), (6, false)); + assert_eq(body.as_block().unwrap()[0].as_block().unwrap()[0].as_integer().unwrap(), (8, false)); + } + } + #[test] fn test_automatically_unwraps_parenthesized_expression() { comptime