Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: Make for loops a statement #2975

Merged
merged 2 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ pub enum ExpressionKind {
MemberAccess(Box<MemberAccessExpression>),
Cast(Box<CastExpression>),
Infix(Box<InfixExpression>),
For(Box<ForExpression>),
If(Box<IfExpression>),
Variable(Path),
Tuple(Vec<Expression>),
Expand Down Expand Up @@ -181,14 +180,6 @@ impl Expression {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ForExpression {
pub identifier: Ident,
pub start_range: Expression,
pub end_range: Expression,
pub block: Expression,
}

pub type BinaryOp = Spanned<BinaryOpKind>;

#[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Debug, Copy, Clone)]
Expand Down Expand Up @@ -469,7 +460,6 @@ impl Display for ExpressionKind {
MethodCall(call) => call.fmt(f),
Cast(cast) => cast.fmt(f),
Infix(infix) => infix.fmt(f),
For(for_loop) => for_loop.fmt(f),
If(if_expr) => if_expr.fmt(f),
Variable(path) => path.fmt(f),
Constructor(constructor) => constructor.fmt(f),
Expand Down Expand Up @@ -603,16 +593,6 @@ impl Display for BinaryOpKind {
}
}

impl Display for ForExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"for {} in {} .. {} {}",
self.identifier, self.start_range, self.end_range, self.block
)
}
}

impl Display for IfExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "if {} {}", self.condition, self.consequence)?;
Expand Down
26 changes: 23 additions & 3 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub enum Statement {
Constrain(ConstrainStatement),
Expression(Expression),
Assign(AssignStatement),
For(ForLoopStatement),
// This is an expression with a trailing semi-colon
Semi(Expression),
// This statement is the result of a recovered parse error.
Expand Down Expand Up @@ -65,13 +66,13 @@ impl Statement {
}
self
}
// A semicolon on a for loop is optional and does nothing
Statement::For(_) => self,

Statement::Expression(expr) => {
match (&expr.kind, semi, last_statement_in_block) {
// Semicolons are optional for these expressions
(ExpressionKind::Block(_), semi, _)
| (ExpressionKind::For(_), semi, _)
| (ExpressionKind::If(_), semi, _) => {
(ExpressionKind::Block(_), semi, _) | (ExpressionKind::If(_), semi, _) => {
if semi.is_some() {
Statement::Semi(expr)
} else {
Expand Down Expand Up @@ -459,13 +460,22 @@ impl LValue {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ForLoopStatement {
pub identifier: Ident,
pub start_range: Expression,
pub end_range: Expression,
pub block: Expression,
}

impl Display for Statement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Statement::Let(let_statement) => let_statement.fmt(f),
Statement::Constrain(constrain) => constrain.fmt(f),
Statement::Expression(expression) => expression.fmt(f),
Statement::Assign(assign) => assign.fmt(f),
Statement::For(for_loop) => for_loop.fmt(f),
Statement::Semi(semi) => write!(f, "{semi};"),
Statement::Error => write!(f, "Error"),
}
Expand Down Expand Up @@ -544,3 +554,13 @@ impl Display for Pattern {
}
}
}

impl Display for ForLoopStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"for {} in {} .. {} {}",
self.identifier, self.start_range, self.end_range, self.block
)
}
}
85 changes: 35 additions & 50 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
// XXX: Resolver does not check for unused functions
use crate::hir_def::expr::{
HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar,
HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral,
HirMemberAccess, HirMethodCallExpression, HirPrefixExpression,
HirCastExpression, HirConstructorExpression, HirExpression, HirIdent, HirIfExpression,
HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess,
HirMethodCallExpression, HirPrefixExpression,
};

use crate::hir_def::traits::{Trait, TraitConstraint};
Expand All @@ -26,7 +26,7 @@

use crate::graph::CrateId;
use crate::hir::def_map::{LocalModuleId, ModuleDefId, TryFromModuleDefId, MAIN_FUNCTION};
use crate::hir_def::stmt::{HirAssignStatement, HirLValue, HirPattern};
use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern};
use crate::node_interner::{
DefinitionId, DefinitionKind, ExprId, FuncId, NodeInterner, StmtId, StructId, TraitId,
};
Expand Down Expand Up @@ -957,6 +957,25 @@
let stmt = HirAssignStatement { lvalue: identifier, expression };
HirStatement::Assign(stmt)
}
Statement::For(for_loop) => {
let start_range = self.resolve_expression(for_loop.start_range);
let end_range = self.resolve_expression(for_loop.end_range);
let (identifier, block) = (for_loop.identifier, for_loop.block);

// TODO: For loop variables are currently mutable by default since we haven't
jfecher marked this conversation as resolved.
Show resolved Hide resolved
// yet implemented syntax for them to be optionally mutable.
let (identifier, block) = self.in_new_scope(|this| {
let decl = this.add_variable_decl(
identifier,
false,
true,
DefinitionKind::Local(None),
);
(decl, this.resolve_expression(block))
});

HirStatement::For(HirForStatement { start_range, end_range, block, identifier })
}
Statement::Error => HirStatement::Error,
}
}
Expand Down Expand Up @@ -1169,30 +1188,6 @@
lhs: self.resolve_expression(cast_expr.lhs),
r#type: self.resolve_type(cast_expr.r#type),
}),
ExpressionKind::For(for_expr) => {
let start_range = self.resolve_expression(for_expr.start_range);
let end_range = self.resolve_expression(for_expr.end_range);
let (identifier, block) = (for_expr.identifier, for_expr.block);

// TODO: For loop variables are currently mutable by default since we haven't
// yet implemented syntax for them to be optionally mutable.
let (identifier, block_id) = self.in_new_scope(|this| {
let decl = this.add_variable_decl(
identifier,
false,
true,
DefinitionKind::Local(None),
);
(decl, this.resolve_expression(block))
});

HirExpression::For(HirForExpression {
start_range,
end_range,
block: block_id,
identifier,
})
}
ExpressionKind::If(if_expr) => HirExpression::If(HirIfExpression {
condition: self.resolve_expression(if_expr.condition),
consequence: self.resolve_expression(if_expr.consequence),
Expand Down Expand Up @@ -1738,7 +1733,7 @@
let (hir_func, _, _) = resolver.resolve_function(func, id);

// Iterate over function statements and apply filtering function
parse_statement_blocks(
find_lambda_captures(
hir_func.block(&interner).statements(),
&interner,
&mut all_captures,
Expand All @@ -1747,33 +1742,23 @@
all_captures
}

fn parse_statement_blocks(
fn find_lambda_captures(
stmts: &[StmtId],
interner: &NodeInterner,
result: &mut Vec<Vec<String>>,
) {
let mut expr: HirExpression;

for stmt_id in stmts.iter() {
let hir_stmt = interner.statement(stmt_id);
match hir_stmt {
HirStatement::Expression(expr_id) => {
expr = interner.expression(&expr_id);
}
HirStatement::Let(let_stmt) => {
expr = interner.expression(&let_stmt.expression);
}
HirStatement::Assign(assign_stmt) => {
expr = interner.expression(&assign_stmt.expression);
}
HirStatement::Constrain(constr_stmt) => {
expr = interner.expression(&constr_stmt.0);
}
HirStatement::Semi(semi_expr) => {
expr = interner.expression(&semi_expr);
}
let expr_id = match hir_stmt {
HirStatement::Expression(expr_id) => expr_id,
HirStatement::Let(let_stmt) => let_stmt.expression,
HirStatement::Assign(assign_stmt) => assign_stmt.expression,
HirStatement::Constrain(constr_stmt) => constr_stmt.0,
HirStatement::Semi(semi_expr) => semi_expr,
HirStatement::For(for_loop) => for_loop.block,
HirStatement::Error => panic!("Invalid HirStatement!"),
}
};
let expr = interner.expression(&expr_id);
get_lambda_captures(expr, interner, result); // TODO: dyn filter function as parameter
}
}
Expand All @@ -1794,7 +1779,7 @@
// Check for other captures recursively within the lambda body
let hir_body_expr = interner.expression(&lambda_expr.body);
if let HirExpression::Block(block_expr) = hir_body_expr {
parse_statement_blocks(block_expr.statements(), interner, result);
find_lambda_captures(block_expr.statements(), interner, result);
}
}
}
Expand Down Expand Up @@ -2086,7 +2071,7 @@
println(f"I want to print {0}");

let new_val = 10;
println(f"randomstring{new_val}{new_val}");

Check warning on line 2074 in compiler/noirc_frontend/src/hir/resolution/resolver.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (randomstring)
}
fn println<T>(x : T) -> T {
x
Expand Down
36 changes: 1 addition & 35 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
types::Type,
},
node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId},
Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp,
Signedness, TypeBinding, TypeVariableKind, UnaryOp,
};

use super::{errors::TypeCheckError, TypeChecker};
Expand Down Expand Up @@ -194,40 +194,6 @@
let span = self.interner.expr_span(expr_id);
self.check_cast(lhs_type, cast_expr.r#type, span)
}
HirExpression::For(for_expr) => {
let start_range_type = self.check_expression(&for_expr.start_range);
let end_range_type = self.check_expression(&for_expr.end_range);

let start_span = self.interner.expr_span(&for_expr.start_range);
let end_span = self.interner.expr_span(&for_expr.end_range);

// Check that start range and end range have the same types
let range_span = start_span.merge(end_span);
self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch {
expected_typ: start_range_type.to_string(),
expr_typ: end_range_type.to_string(),
expr_span: range_span,
});

let fresh_id = self.interner.next_type_variable_id();
let type_variable = Shared::new(TypeBinding::Unbound(fresh_id));
let expected_type =
Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField);

self.unify(&start_range_type, &expected_type, || {
TypeCheckError::TypeCannotBeUsed {
typ: start_range_type.clone(),
place: "for loop",
span: range_span,
}
.add_context("The range of a loop must be known at compile-time")
});

self.interner.push_definition_type(for_expr.identifier.id, start_range_type);

self.check_expression(&for_expr.block);
Type::Unit
}
HirExpression::Block(block_expr) => {
let mut block_type = Type::Unit;

Expand Down Expand Up @@ -499,7 +465,7 @@
arguments: Vec<(Type, ExprId, Span)>,
span: Span,
) -> Type {
let (fntyp, param_len) = match method_ref {

Check warning on line 468 in compiler/noirc_frontend/src/hir/type_check/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (fntyp)
HirMethodReference::FuncId(func_id) => {
if func_id == FuncId::dummy_id() {
return Type::Error;
Expand Down Expand Up @@ -528,7 +494,7 @@
});
}

let (function_type, instantiation_bindings) = fntyp.instantiate(self.interner);

Check warning on line 497 in compiler/noirc_frontend/src/hir/type_check/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (fntyp)

self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings);
self.interner.push_expr_type(function_ident_id, function_type.clone());
Expand Down Expand Up @@ -918,19 +884,19 @@
&mut self,
fn_params: &Vec<Type>,
fn_ret: &Type,
callsite_args: &Vec<(Type, ExprId, Span)>,

Check warning on line 887 in compiler/noirc_frontend/src/hir/type_check/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (callsite)
span: Span,
) -> Type {
if fn_params.len() != callsite_args.len() {

Check warning on line 890 in compiler/noirc_frontend/src/hir/type_check/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (callsite)
self.errors.push(TypeCheckError::ParameterCountMismatch {
expected: fn_params.len(),
found: callsite_args.len(),

Check warning on line 893 in compiler/noirc_frontend/src/hir/type_check/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (callsite)
span,
});
return Type::Error;
}

for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) {

Check warning on line 899 in compiler/noirc_frontend/src/hir/type_check/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (callsite)
self.unify(arg, param, || TypeCheckError::TypeMismatch {
expected_typ: param.to_string(),
expr_typ: arg.to_string(),
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,11 @@ mod test {
fn basic_for_expr() {
let src = r#"
fn main(_x : Field) {
let _j = for _i in 0..10 {
for _i in 0..10 {
for _k in 0..100 {

}
};
}
}

"#;
Expand Down
38 changes: 37 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use noirc_errors::{Location, Span};

use crate::hir_def::expr::{HirExpression, HirIdent, HirLiteral};
use crate::hir_def::stmt::{
HirAssignStatement, HirConstrainStatement, HirLValue, HirLetStatement, HirPattern, HirStatement,
HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement,
HirPattern, HirStatement,
};
use crate::hir_def::types::Type;
use crate::node_interner::{DefinitionId, ExprId, StmtId};
use crate::{Shared, TypeBinding, TypeVariableKind};

use super::errors::{Source, TypeCheckError};
use super::TypeChecker;
Expand Down Expand Up @@ -48,11 +50,45 @@ impl<'interner> TypeChecker<'interner> {
HirStatement::Let(let_stmt) => self.check_let_stmt(let_stmt),
HirStatement::Constrain(constrain_stmt) => self.check_constrain_stmt(constrain_stmt),
HirStatement::Assign(assign_stmt) => self.check_assign_stmt(assign_stmt, stmt_id),
HirStatement::For(for_loop) => self.check_for_loop(for_loop),
HirStatement::Error => (),
}
Type::Unit
}

fn check_for_loop(&mut self, for_loop: HirForStatement) {
let start_range_type = self.check_expression(&for_loop.start_range);
let end_range_type = self.check_expression(&for_loop.end_range);

let start_span = self.interner.expr_span(&for_loop.start_range);
let end_span = self.interner.expr_span(&for_loop.end_range);

// Check that start range and end range have the same types
let range_span = start_span.merge(end_span);
self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch {
expected_typ: start_range_type.to_string(),
expr_typ: end_range_type.to_string(),
expr_span: range_span,
});

let fresh_id = self.interner.next_type_variable_id();
let type_variable = Shared::new(TypeBinding::Unbound(fresh_id));
let expected_type = Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField);

self.unify(&start_range_type, &expected_type, || {
TypeCheckError::TypeCannotBeUsed {
typ: start_range_type.clone(),
place: "for loop",
span: range_span,
}
.add_context("The range of a loop must be known at compile-time")
});

self.interner.push_definition_type(for_loop.identifier.id, start_range_type);

self.check_expression(&for_loop.block);
}

/// Associate a given HirPattern with the given Type, and remember
/// this association in the NodeInterner.
pub(crate) fn bind_pattern(&mut self, pattern: &HirPattern, typ: Type) {
Expand Down
9 changes: 0 additions & 9 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
Call(HirCallExpression),
MethodCall(HirMethodCallExpression),
Cast(HirCastExpression),
For(HirForExpression),
If(HirIfExpression),
Tuple(Vec<ExprId>),
Lambda(HirLambda),
Expand All @@ -48,14 +47,6 @@
pub id: DefinitionId,
}

#[derive(Debug, Clone)]
pub struct HirForExpression {
pub identifier: HirIdent,
pub start_range: ExprId,
pub end_range: ExprId,
pub block: ExprId,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct HirBinaryOp {
pub kind: BinaryOpKind,
Expand Down Expand Up @@ -108,7 +99,7 @@
pub rhs: ExprId,
}

/// This is always a struct field access `mystruct.field`

Check warning on line 102 in compiler/noirc_frontend/src/hir_def/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (mystruct)
/// and never a method call. The later is represented by HirMethodCallExpression.
#[derive(Debug, Clone)]
pub struct HirMemberAccess {
Expand Down Expand Up @@ -160,7 +151,7 @@

/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphizaiton.

Check warning on line 154 in compiler/noirc_frontend/src/hir_def/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (monomorphizaiton)
TraitMethodId(Type, TraitMethodId),
}

Expand Down
Loading