Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: Make for loops a statement #2975

Merged
merged 2 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ pub enum ExpressionKind {
MemberAccess(Box<MemberAccessExpression>),
Cast(Box<CastExpression>),
Infix(Box<InfixExpression>),
For(Box<ForExpression>),
If(Box<IfExpression>),
Variable(Path),
Tuple(Vec<Expression>),
Expand Down Expand Up @@ -181,14 +180,6 @@ impl Expression {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ForExpression {
pub identifier: Ident,
pub start_range: Expression,
pub end_range: Expression,
pub block: Expression,
}

pub type BinaryOp = Spanned<BinaryOpKind>;

#[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Debug, Copy, Clone)]
Expand Down Expand Up @@ -469,7 +460,6 @@ impl Display for ExpressionKind {
MethodCall(call) => call.fmt(f),
Cast(cast) => cast.fmt(f),
Infix(infix) => infix.fmt(f),
For(for_loop) => for_loop.fmt(f),
If(if_expr) => if_expr.fmt(f),
Variable(path) => path.fmt(f),
Constructor(constructor) => constructor.fmt(f),
Expand Down Expand Up @@ -603,16 +593,6 @@ impl Display for BinaryOpKind {
}
}

impl Display for ForExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"for {} in {} .. {} {}",
self.identifier, self.start_range, self.end_range, self.block
)
}
}

impl Display for IfExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "if {} {}", self.condition, self.consequence)?;
Expand Down
26 changes: 23 additions & 3 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub enum Statement {
Constrain(ConstrainStatement),
Expression(Expression),
Assign(AssignStatement),
For(ForLoopStatement),
// This is an expression with a trailing semi-colon
Semi(Expression),
// This statement is the result of a recovered parse error.
Expand Down Expand Up @@ -65,13 +66,13 @@ impl Statement {
}
self
}
// A semicolon on a for loop is optional and does nothing
Statement::For(_) => self,

Statement::Expression(expr) => {
match (&expr.kind, semi, last_statement_in_block) {
// Semicolons are optional for these expressions
(ExpressionKind::Block(_), semi, _)
| (ExpressionKind::For(_), semi, _)
| (ExpressionKind::If(_), semi, _) => {
(ExpressionKind::Block(_), semi, _) | (ExpressionKind::If(_), semi, _) => {
if semi.is_some() {
Statement::Semi(expr)
} else {
Expand Down Expand Up @@ -459,13 +460,22 @@ impl LValue {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ForLoopStatement {
pub identifier: Ident,
pub start_range: Expression,
pub end_range: Expression,
pub block: Expression,
}

impl Display for Statement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Statement::Let(let_statement) => let_statement.fmt(f),
Statement::Constrain(constrain) => constrain.fmt(f),
Statement::Expression(expression) => expression.fmt(f),
Statement::Assign(assign) => assign.fmt(f),
Statement::For(for_loop) => for_loop.fmt(f),
Statement::Semi(semi) => write!(f, "{semi};"),
Statement::Error => write!(f, "Error"),
}
Expand Down Expand Up @@ -544,3 +554,13 @@ impl Display for Pattern {
}
}
}

impl Display for ForLoopStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"for {} in {} .. {} {}",
self.identifier, self.start_range, self.end_range, self.block
)
}
}
85 changes: 35 additions & 50 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
// XXX: Resolver does not check for unused functions
use crate::hir_def::expr::{
HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar,
HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral,
HirMemberAccess, HirMethodCallExpression, HirPrefixExpression,
HirCastExpression, HirConstructorExpression, HirExpression, HirIdent, HirIfExpression,
HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess,
HirMethodCallExpression, HirPrefixExpression,
};

use crate::hir_def::traits::{Trait, TraitConstraint};
Expand All @@ -26,7 +26,7 @@ use std::rc::Rc;

use crate::graph::CrateId;
use crate::hir::def_map::{LocalModuleId, ModuleDefId, TryFromModuleDefId, MAIN_FUNCTION};
use crate::hir_def::stmt::{HirAssignStatement, HirLValue, HirPattern};
use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern};
use crate::node_interner::{
DefinitionId, DefinitionKind, ExprId, FuncId, NodeInterner, StmtId, StructId, TraitId,
};
Expand Down Expand Up @@ -957,6 +957,25 @@ impl<'a> Resolver<'a> {
let stmt = HirAssignStatement { lvalue: identifier, expression };
HirStatement::Assign(stmt)
}
Statement::For(for_loop) => {
let start_range = self.resolve_expression(for_loop.start_range);
let end_range = self.resolve_expression(for_loop.end_range);
let (identifier, block) = (for_loop.identifier, for_loop.block);

// TODO: For loop variables are currently mutable by default since we haven't
jfecher marked this conversation as resolved.
Show resolved Hide resolved
// yet implemented syntax for them to be optionally mutable.
let (identifier, block) = self.in_new_scope(|this| {
let decl = this.add_variable_decl(
identifier,
false,
true,
DefinitionKind::Local(None),
);
(decl, this.resolve_expression(block))
});

HirStatement::For(HirForStatement { start_range, end_range, block, identifier })
}
Statement::Error => HirStatement::Error,
}
}
Expand Down Expand Up @@ -1169,30 +1188,6 @@ impl<'a> Resolver<'a> {
lhs: self.resolve_expression(cast_expr.lhs),
r#type: self.resolve_type(cast_expr.r#type),
}),
ExpressionKind::For(for_expr) => {
let start_range = self.resolve_expression(for_expr.start_range);
let end_range = self.resolve_expression(for_expr.end_range);
let (identifier, block) = (for_expr.identifier, for_expr.block);

// TODO: For loop variables are currently mutable by default since we haven't
// yet implemented syntax for them to be optionally mutable.
let (identifier, block_id) = self.in_new_scope(|this| {
let decl = this.add_variable_decl(
identifier,
false,
true,
DefinitionKind::Local(None),
);
(decl, this.resolve_expression(block))
});

HirExpression::For(HirForExpression {
start_range,
end_range,
block: block_id,
identifier,
})
}
ExpressionKind::If(if_expr) => HirExpression::If(HirIfExpression {
condition: self.resolve_expression(if_expr.condition),
consequence: self.resolve_expression(if_expr.consequence),
Expand Down Expand Up @@ -1738,7 +1733,7 @@ mod test {
let (hir_func, _, _) = resolver.resolve_function(func, id);

// Iterate over function statements and apply filtering function
parse_statement_blocks(
find_lambda_captures(
hir_func.block(&interner).statements(),
&interner,
&mut all_captures,
Expand All @@ -1747,33 +1742,23 @@ mod test {
all_captures
}

fn parse_statement_blocks(
fn find_lambda_captures(
stmts: &[StmtId],
interner: &NodeInterner,
result: &mut Vec<Vec<String>>,
) {
let mut expr: HirExpression;

for stmt_id in stmts.iter() {
let hir_stmt = interner.statement(stmt_id);
match hir_stmt {
HirStatement::Expression(expr_id) => {
expr = interner.expression(&expr_id);
}
HirStatement::Let(let_stmt) => {
expr = interner.expression(&let_stmt.expression);
}
HirStatement::Assign(assign_stmt) => {
expr = interner.expression(&assign_stmt.expression);
}
HirStatement::Constrain(constr_stmt) => {
expr = interner.expression(&constr_stmt.0);
}
HirStatement::Semi(semi_expr) => {
expr = interner.expression(&semi_expr);
}
let expr_id = match hir_stmt {
HirStatement::Expression(expr_id) => expr_id,
HirStatement::Let(let_stmt) => let_stmt.expression,
HirStatement::Assign(assign_stmt) => assign_stmt.expression,
HirStatement::Constrain(constr_stmt) => constr_stmt.0,
HirStatement::Semi(semi_expr) => semi_expr,
HirStatement::For(for_loop) => for_loop.block,
HirStatement::Error => panic!("Invalid HirStatement!"),
}
};
let expr = interner.expression(&expr_id);
get_lambda_captures(expr, interner, result); // TODO: dyn filter function as parameter
}
}
Expand All @@ -1794,7 +1779,7 @@ mod test {
// Check for other captures recursively within the lambda body
let hir_body_expr = interner.expression(&lambda_expr.body);
if let HirExpression::Block(block_expr) = hir_body_expr {
parse_statement_blocks(block_expr.statements(), interner, result);
find_lambda_captures(block_expr.statements(), interner, result);
}
}
}
Expand Down
36 changes: 1 addition & 35 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
types::Type,
},
node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId},
Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp,
Signedness, TypeBinding, TypeVariableKind, UnaryOp,
};

use super::{errors::TypeCheckError, TypeChecker};
Expand Down Expand Up @@ -194,40 +194,6 @@ impl<'interner> TypeChecker<'interner> {
let span = self.interner.expr_span(expr_id);
self.check_cast(lhs_type, cast_expr.r#type, span)
}
HirExpression::For(for_expr) => {
let start_range_type = self.check_expression(&for_expr.start_range);
let end_range_type = self.check_expression(&for_expr.end_range);

let start_span = self.interner.expr_span(&for_expr.start_range);
let end_span = self.interner.expr_span(&for_expr.end_range);

// Check that start range and end range have the same types
let range_span = start_span.merge(end_span);
self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch {
expected_typ: start_range_type.to_string(),
expr_typ: end_range_type.to_string(),
expr_span: range_span,
});

let fresh_id = self.interner.next_type_variable_id();
let type_variable = Shared::new(TypeBinding::Unbound(fresh_id));
let expected_type =
Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField);

self.unify(&start_range_type, &expected_type, || {
TypeCheckError::TypeCannotBeUsed {
typ: start_range_type.clone(),
place: "for loop",
span: range_span,
}
.add_context("The range of a loop must be known at compile-time")
});

self.interner.push_definition_type(for_expr.identifier.id, start_range_type);

self.check_expression(&for_expr.block);
Type::Unit
}
HirExpression::Block(block_expr) => {
let mut block_type = Type::Unit;

Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,11 @@ mod test {
fn basic_for_expr() {
let src = r#"
fn main(_x : Field) {
let _j = for _i in 0..10 {
for _i in 0..10 {
for _k in 0..100 {

}
};
}
}

"#;
Expand Down
38 changes: 37 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use noirc_errors::{Location, Span};

use crate::hir_def::expr::{HirExpression, HirIdent, HirLiteral};
use crate::hir_def::stmt::{
HirAssignStatement, HirConstrainStatement, HirLValue, HirLetStatement, HirPattern, HirStatement,
HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement,
HirPattern, HirStatement,
};
use crate::hir_def::types::Type;
use crate::node_interner::{DefinitionId, ExprId, StmtId};
use crate::{Shared, TypeBinding, TypeVariableKind};

use super::errors::{Source, TypeCheckError};
use super::TypeChecker;
Expand Down Expand Up @@ -48,11 +50,45 @@ impl<'interner> TypeChecker<'interner> {
HirStatement::Let(let_stmt) => self.check_let_stmt(let_stmt),
HirStatement::Constrain(constrain_stmt) => self.check_constrain_stmt(constrain_stmt),
HirStatement::Assign(assign_stmt) => self.check_assign_stmt(assign_stmt, stmt_id),
HirStatement::For(for_loop) => self.check_for_loop(for_loop),
HirStatement::Error => (),
}
Type::Unit
}

fn check_for_loop(&mut self, for_loop: HirForStatement) {
let start_range_type = self.check_expression(&for_loop.start_range);
let end_range_type = self.check_expression(&for_loop.end_range);

let start_span = self.interner.expr_span(&for_loop.start_range);
let end_span = self.interner.expr_span(&for_loop.end_range);

// Check that start range and end range have the same types
let range_span = start_span.merge(end_span);
self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch {
expected_typ: start_range_type.to_string(),
expr_typ: end_range_type.to_string(),
expr_span: range_span,
});

let fresh_id = self.interner.next_type_variable_id();
let type_variable = Shared::new(TypeBinding::Unbound(fresh_id));
let expected_type = Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField);

self.unify(&start_range_type, &expected_type, || {
TypeCheckError::TypeCannotBeUsed {
typ: start_range_type.clone(),
place: "for loop",
span: range_span,
}
.add_context("The range of a loop must be known at compile-time")
});

self.interner.push_definition_type(for_loop.identifier.id, start_range_type);

self.check_expression(&for_loop.block);
}

/// Associate a given HirPattern with the given Type, and remember
/// this association in the NodeInterner.
pub(crate) fn bind_pattern(&mut self, pattern: &HirPattern, typ: Type) {
Expand Down
9 changes: 0 additions & 9 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ pub enum HirExpression {
Call(HirCallExpression),
MethodCall(HirMethodCallExpression),
Cast(HirCastExpression),
For(HirForExpression),
If(HirIfExpression),
Tuple(Vec<ExprId>),
Lambda(HirLambda),
Expand All @@ -48,14 +47,6 @@ pub struct HirIdent {
pub id: DefinitionId,
}

#[derive(Debug, Clone)]
pub struct HirForExpression {
pub identifier: HirIdent,
pub start_range: ExprId,
pub end_range: ExprId,
pub block: ExprId,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct HirBinaryOp {
pub kind: BinaryOpKind,
Expand Down
Loading
Loading