From ce5e1a5fa69155420aa404e2bea4b483bac5b9a6 Mon Sep 17 00:00:00 2001 From: jfecher Date: Tue, 23 Apr 2024 13:08:48 -0500 Subject: [PATCH] chore(experimental): Add scan pass and `into_expression` for comptime interpreter (#4884) # Description ## Problem\* Resolves #4590 ## Summary\* This PR links up the comptime interpreter with the rest of the codebase. It does so by adding a scanning step where the Hir is scanned for `comptime` expressions to execute. When one of these is found, the interpreter switches to evaluation mode and evaluates the expression. Afterward, the result of the expression is inlined into the Hir via `Value::into_expression`. For `Code` values, this means the entire code block is spliced in (a macro expansion). You can now run simple programs at compile-time now as long as they don't have expansion of `quote`d values (macros) since those would require the full loop back to name resolution again. Anyway, here's an example that works now: ```rs fn main() { let x = comptime { 2 * 4 }; println(x); } ``` By monomorphization the compiler sees ```rs fn main() { let x = 8; println(x); } ``` More complex expressions within the `comptime` block should also work. Just note that this scanning + evaluation is currently only within functions. So comptime globals won't be evaluated. ## Additional Context I may try splitting out this PR into several smaller ones but need to look into where to make the split more. Leaving this up for now in case people find it useful or it's not as big as I expected. Future changes: Architecture-wise we're only missing the ability to make the full loop and go back to name resolution after a macro expansion now. Since we already have the Hir -> Ast pass (although it can be considerably improved since it is rather broken still), we only need to call it after the comptime interpreter and run name resolution again. ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [ ] I have tested the changes locally. - [ ] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- compiler/noirc_frontend/src/ast/expression.rs | 2 + compiler/noirc_frontend/src/ast/statement.rs | 8 +- .../noirc_frontend/src/hir/comptime/errors.rs | 53 +++++ .../src/hir/comptime/hir_to_ast.rs | 26 ++- .../src/hir/comptime/interpreter.rs | 152 +++---------- .../noirc_frontend/src/hir/comptime/mod.rs | 5 + .../noirc_frontend/src/hir/comptime/scan.rs | 207 ++++++++++++++++++ .../noirc_frontend/src/hir/comptime/tests.rs | 4 +- .../noirc_frontend/src/hir/comptime/value.rs | 170 ++++++++++++++ .../src/hir/def_collector/dc_crate.rs | 162 ++++++++------ .../src/hir/resolution/globals.rs | 8 +- .../src/hir/resolution/resolver.rs | 19 +- .../noirc_frontend/src/hir/type_check/expr.rs | 65 +++--- .../noirc_frontend/src/hir/type_check/stmt.rs | 2 +- compiler/noirc_frontend/src/hir_def/expr.rs | 4 +- compiler/noirc_frontend/src/hir_def/stmt.rs | 2 +- .../src/monomorphization/mod.rs | 8 +- compiler/noirc_frontend/src/node_interner.rs | 10 +- compiler/noirc_frontend/src/parser/parser.rs | 4 +- compiler/noirc_frontend/src/tests.rs | 2 +- tooling/nargo_fmt/src/rewrite/expr.rs | 3 + tooling/nargo_fmt/src/visitor/stmt.rs | 2 +- 22 files changed, 667 insertions(+), 251 deletions(-) create mode 100644 compiler/noirc_frontend/src/hir/comptime/errors.rs create mode 100644 compiler/noirc_frontend/src/hir/comptime/scan.rs create mode 100644 compiler/noirc_frontend/src/hir/comptime/value.rs diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 92c1add80a6..5659de46588 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -28,6 +28,7 @@ pub enum ExpressionKind { Lambda(Box), Parenthesized(Box), Quote(BlockExpression), + Comptime(BlockExpression), Error, } @@ -504,6 +505,7 @@ impl Display for ExpressionKind { Lambda(lambda) => lambda.fmt(f), Parenthesized(sub_expr) => write!(f, "({sub_expr})"), Quote(block) => write!(f, "quote {block}"), + Comptime(block) => write!(f, "comptime {block}"), Error => write!(f, "Error"), } } diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 1831a046f5b..f37c7adc983 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -40,7 +40,7 @@ pub enum StatementKind { Break, Continue, /// This statement should be executed at compile-time - Comptime(Box), + CompTime(Box), // This is an expression with a trailing semi-colon Semi(Expression), // This statement is the result of a recovered parse error. @@ -87,10 +87,10 @@ impl StatementKind { } self } - StatementKind::Comptime(mut statement) => { + StatementKind::CompTime(mut statement) => { *statement = statement.add_semicolon(semi, span, last_statement_in_block, emit_error); - StatementKind::Comptime(statement) + StatementKind::CompTime(statement) } // A semicolon on a for loop is optional and does nothing StatementKind::For(_) => self, @@ -685,7 +685,7 @@ impl Display for StatementKind { StatementKind::For(for_loop) => for_loop.fmt(f), StatementKind::Break => write!(f, "break"), StatementKind::Continue => write!(f, "continue"), - StatementKind::Comptime(statement) => write!(f, "comptime {statement}"), + StatementKind::CompTime(statement) => write!(f, "comptime {statement}"), StatementKind::Semi(semi) => write!(f, "{semi};"), StatementKind::Error => write!(f, "Error"), } diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs new file mode 100644 index 00000000000..d3db7fcaee9 --- /dev/null +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -0,0 +1,53 @@ +use crate::{node_interner::DefinitionId, Type}; +use acvm::FieldElement; +use noirc_errors::Location; + +use super::value::Value; + +/// The possible errors that can halt the interpreter. +#[derive(Debug)] +pub enum InterpreterError { + ArgumentCountMismatch { expected: usize, actual: usize, call_location: Location }, + TypeMismatch { expected: Type, value: Value, location: Location }, + NoValueForId { id: DefinitionId, location: Location }, + IntegerOutOfRangeForType { value: FieldElement, typ: Type, location: Location }, + ErrorNodeEncountered { location: Location }, + NonFunctionCalled { value: Value, location: Location }, + NonBoolUsedInIf { value: Value, location: Location }, + NonBoolUsedInConstrain { value: Value, location: Location }, + FailingConstraint { message: Option, location: Location }, + NoMethodFound { object: Value, typ: Type, location: Location }, + NonIntegerUsedInLoop { value: Value, location: Location }, + NonPointerDereferenced { value: Value, location: Location }, + NonTupleOrStructInMemberAccess { value: Value, location: Location }, + NonArrayIndexed { value: Value, location: Location }, + NonIntegerUsedAsIndex { value: Value, location: Location }, + NonIntegerIntegerLiteral { typ: Type, location: Location }, + NonIntegerArrayLength { typ: Type, location: Location }, + NonNumericCasted { value: Value, location: Location }, + IndexOutOfBounds { index: usize, length: usize, location: Location }, + ExpectedStructToHaveField { value: Value, field_name: String, location: Location }, + TypeUnsupported { typ: Type, location: Location }, + InvalidValueForUnary { value: Value, operator: &'static str, location: Location }, + InvalidValuesForBinary { lhs: Value, rhs: Value, operator: &'static str, location: Location }, + CastToNonNumericType { typ: Type, location: Location }, + QuoteInRuntimeCode { location: Location }, + NonStructInConstructor { typ: Type, location: Location }, + CannotInlineMacro { value: Value, location: Location }, + UnquoteFoundDuringEvaluation { location: Location }, + + Unimplemented { item: &'static str, location: Location }, + + // Perhaps this should be unreachable! due to type checking also preventing this error? + // Currently it and the Continue variant are the only interpreter errors without a Location field + BreakNotInLoop { location: Location }, + ContinueNotInLoop { location: Location }, + + // These cases are not errors, they are just used to prevent us from running more code + // until the loop can be resumed properly. These cases will never be displayed to users. + Break, + Continue, +} + +#[allow(unused)] +pub(super) type IResult = std::result::Result; diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs index 47ca7083ff0..dd23edf0004 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs @@ -9,7 +9,7 @@ use crate::ast::{ UnresolvedTypeData, UnresolvedTypeExpression, }; use crate::ast::{ConstrainStatement, Expression, Statement, StatementKind}; -use crate::hir_def::expr::{HirArrayLiteral, HirExpression, HirIdent}; +use crate::hir_def::expr::{HirArrayLiteral, HirBlockExpression, HirExpression, HirIdent}; use crate::hir_def::stmt::{HirLValue, HirPattern, HirStatement}; use crate::hir_def::types::Type; use crate::macros_api::HirLiteral; @@ -26,7 +26,7 @@ impl StmtId { #[allow(unused)] fn to_ast(self, interner: &NodeInterner) -> Statement { let statement = interner.statement(&self); - let span = interner.statement_span(&self); + let span = interner.statement_span(self); let kind = match statement { HirStatement::Let(let_stmt) => { @@ -66,8 +66,8 @@ impl StmtId { HirStatement::Expression(expr) => StatementKind::Expression(expr.to_ast(interner)), HirStatement::Semi(expr) => StatementKind::Semi(expr.to_ast(interner)), HirStatement::Error => StatementKind::Error, - HirStatement::Comptime(statement) => { - StatementKind::Comptime(Box::new(statement.to_ast(interner).kind)) + HirStatement::CompTime(statement) => { + StatementKind::CompTime(Box::new(statement.to_ast(interner).kind)) } }; @@ -108,10 +108,7 @@ impl ExprId { ExpressionKind::Literal(Literal::FmtStr(string)) } HirExpression::Literal(HirLiteral::Unit) => ExpressionKind::Literal(Literal::Unit), - HirExpression::Block(expr) => { - let statements = vecmap(expr.statements, |statement| statement.to_ast(interner)); - ExpressionKind::Block(BlockExpression { statements }) - } + HirExpression::Block(expr) => ExpressionKind::Block(expr.into_ast(interner)), HirExpression::Prefix(prefix) => ExpressionKind::Prefix(Box::new(PrefixExpression { operator: prefix.operator, rhs: prefix.rhs.to_ast(interner), @@ -172,8 +169,12 @@ impl ExprId { let body = lambda.body.to_ast(interner); ExpressionKind::Lambda(Box::new(Lambda { parameters, return_type, body })) } - HirExpression::Quote(block) => ExpressionKind::Quote(block), HirExpression::Error => ExpressionKind::Error, + HirExpression::Comptime(block) => ExpressionKind::Comptime(block.into_ast(interner)), + HirExpression::Quote(block) => ExpressionKind::Quote(block), + + // A macro was evaluated here! + HirExpression::Unquote(block) => ExpressionKind::Block(block), }; Expression::new(kind, span) @@ -353,3 +354,10 @@ impl HirArrayLiteral { } } } + +impl HirBlockExpression { + fn into_ast(self, interner: &NodeInterner) -> BlockExpression { + let statements = vecmap(self.statements, |statement| statement.to_ast(interner)); + BlockExpression { statements } + } +} diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 03839c2f0cd..c6d508a581e 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -1,12 +1,12 @@ -use std::{borrow::Cow, collections::hash_map::Entry, rc::Rc}; +use std::{collections::hash_map::Entry, rc::Rc}; use acvm::FieldElement; use im::Vector; -use iter_extended::{try_vecmap, vecmap}; +use iter_extended::try_vecmap; use noirc_errors::Location; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -use crate::ast::{BinaryOpKind, BlockExpression, FunctionKind, IntegerBitSize, Signedness}; +use crate::ast::{BinaryOpKind, FunctionKind, IntegerBitSize, Signedness}; use crate::{ hir_def::{ expr::{ @@ -22,12 +22,16 @@ use crate::{ }, macros_api::{HirExpression, HirLiteral, HirStatement, NodeInterner}, node_interner::{DefinitionId, DefinitionKind, ExprId, FuncId, StmtId}, + Shared, Type, TypeBinding, TypeBindings, TypeVariableKind, }; -use crate::{Shared, Type, TypeBinding, TypeBindings, TypeVariableKind}; + +use super::errors::{IResult, InterpreterError}; +use super::value::Value; + #[allow(unused)] -pub(crate) struct Interpreter<'interner> { +pub struct Interpreter<'interner> { /// To expand macros the Interpreter may mutate hir nodes within the NodeInterner - interner: &'interner mut NodeInterner, + pub(super) interner: &'interner mut NodeInterner, /// Each value currently in scope in the interpreter. /// Each element of the Vec represents a scope with every scope together making @@ -49,72 +53,6 @@ pub(crate) struct Interpreter<'interner> { in_comptime_context: bool, } -#[allow(unused)] -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum Value { - Unit, - Bool(bool), - Field(FieldElement), - I8(i8), - I32(i32), - I64(i64), - U8(u8), - U32(u32), - U64(u64), - String(Rc), - Function(FuncId, Type), - Closure(HirLambda, Vec, Type), - Tuple(Vec), - Struct(HashMap, Value>, Type), - Pointer(Shared), - Array(Vector, Type), - Slice(Vector, Type), - Code(Rc), -} - -/// The possible errors that can halt the interpreter. -#[allow(unused)] -#[derive(Debug)] -pub(crate) enum InterpreterError { - ArgumentCountMismatch { expected: usize, actual: usize, call_location: Location }, - TypeMismatch { expected: Type, value: Value, location: Location }, - NoValueForId { id: DefinitionId, location: Location }, - IntegerOutOfRangeForType { value: FieldElement, typ: Type, location: Location }, - ErrorNodeEncountered { location: Location }, - NonFunctionCalled { value: Value, location: Location }, - NonBoolUsedInIf { value: Value, location: Location }, - NonBoolUsedInConstrain { value: Value, location: Location }, - FailingConstraint { message: Option, location: Location }, - NoMethodFound { object: Value, typ: Type, location: Location }, - NonIntegerUsedInLoop { value: Value, location: Location }, - NonPointerDereferenced { value: Value, location: Location }, - NonTupleOrStructInMemberAccess { value: Value, location: Location }, - NonArrayIndexed { value: Value, location: Location }, - NonIntegerUsedAsIndex { value: Value, location: Location }, - NonIntegerIntegerLiteral { typ: Type, location: Location }, - NonIntegerArrayLength { typ: Type, location: Location }, - NonNumericCasted { value: Value, location: Location }, - IndexOutOfBounds { index: usize, length: usize, location: Location }, - ExpectedStructToHaveField { value: Value, field_name: String, location: Location }, - TypeUnsupported { typ: Type, location: Location }, - InvalidValueForUnary { value: Value, operator: &'static str, location: Location }, - InvalidValuesForBinary { lhs: Value, rhs: Value, operator: &'static str, location: Location }, - CastToNonNumericType { typ: Type, location: Location }, - - // Perhaps this should be unreachable! due to type checking also preventing this error? - // Currently it and the Continue variant are the only interpreter errors without a Location field - BreakNotInLoop, - ContinueNotInLoop, - - // These cases are not errors but prevent us from running more code - // until the loop can be resumed properly. - Break, - Continue, -} - -#[allow(unused)] -type IResult = std::result::Result; - #[allow(unused)] impl<'a> Interpreter<'a> { pub(crate) fn new(interner: &'a mut NodeInterner) -> Self { @@ -193,14 +131,14 @@ impl<'a> Interpreter<'a> { /// Enters a function, pushing a new scope and resetting any required state. /// Returns the previous values of the internal state, to be reset when /// `exit_function` is called. - fn enter_function(&mut self) -> (bool, Vec>) { + pub(super) fn enter_function(&mut self) -> (bool, Vec>) { // Drain every scope except the global scope let scope = self.scopes.drain(1..).collect(); self.push_scope(); (std::mem::take(&mut self.in_loop), scope) } - fn exit_function(&mut self, mut state: (bool, Vec>)) { + pub(super) fn exit_function(&mut self, mut state: (bool, Vec>)) { self.in_loop = state.0; // Keep only the global scope @@ -208,11 +146,11 @@ impl<'a> Interpreter<'a> { self.scopes.append(&mut state.1); } - fn push_scope(&mut self) { + pub(super) fn push_scope(&mut self) { self.scopes.push(HashMap::default()); } - fn pop_scope(&mut self) { + pub(super) fn pop_scope(&mut self) { self.scopes.pop(); } @@ -375,6 +313,13 @@ impl<'a> Interpreter<'a> { HirExpression::Tuple(tuple) => self.evaluate_tuple(tuple), HirExpression::Lambda(lambda) => self.evaluate_lambda(lambda, id), HirExpression::Quote(block) => Ok(Value::Code(Rc::new(block))), + HirExpression::Comptime(block) => self.evaluate_block(block), + HirExpression::Unquote(block) => { + // An Unquote expression being found is indicative of a macro being + // expanded within another comptime fn which we don't currently support. + let location = self.interner.expr_location(&id); + Err(InterpreterError::UnquoteFoundDuringEvaluation { location }) + } HirExpression::Error => { let location = self.interner.expr_location(&id); Err(InterpreterError::ErrorNodeEncountered { location }) @@ -390,7 +335,7 @@ impl<'a> Interpreter<'a> { let typ = self.interner.id_type(id); Ok(Value::Function(*function_id, typ)) } - DefinitionKind::Local(_) => dbg!(self.lookup(&ident)), + DefinitionKind::Local(_) => self.lookup(&ident), DefinitionKind::Global(global_id) => { let let_ = self.interner.get_global_let_statement(*global_id).unwrap(); self.evaluate_let(let_)?; @@ -503,7 +448,7 @@ impl<'a> Interpreter<'a> { } } - fn evaluate_block(&mut self, mut block: HirBlockExpression) -> IResult { + pub(super) fn evaluate_block(&mut self, mut block: HirBlockExpression) -> IResult { let last_statement = block.statements.pop(); self.push_scope(); @@ -1075,10 +1020,10 @@ impl<'a> Interpreter<'a> { HirStatement::Constrain(constrain) => self.evaluate_constrain(constrain), HirStatement::Assign(assign) => self.evaluate_assign(assign), HirStatement::For(for_) => self.evaluate_for(for_), - HirStatement::Break => self.evaluate_break(), - HirStatement::Continue => self.evaluate_continue(), + HirStatement::Break => self.evaluate_break(statement), + HirStatement::Continue => self.evaluate_continue(statement), HirStatement::Expression(expression) => self.evaluate(expression), - HirStatement::Comptime(statement) => self.evaluate_comptime(statement), + HirStatement::CompTime(statement) => self.evaluate_comptime(statement), HirStatement::Semi(expression) => { self.evaluate(expression)?; Ok(Value::Unit) @@ -1236,59 +1181,28 @@ impl<'a> Interpreter<'a> { Ok(Value::Unit) } - fn evaluate_break(&mut self) -> IResult { + fn evaluate_break(&mut self, id: StmtId) -> IResult { if self.in_loop { Err(InterpreterError::Break) } else { - Err(InterpreterError::BreakNotInLoop) + let location = self.interner.statement_location(id); + Err(InterpreterError::BreakNotInLoop { location }) } } - fn evaluate_continue(&mut self) -> IResult { + fn evaluate_continue(&mut self, id: StmtId) -> IResult { if self.in_loop { Err(InterpreterError::Continue) } else { - Err(InterpreterError::ContinueNotInLoop) + let location = self.interner.statement_location(id); + Err(InterpreterError::ContinueNotInLoop { location }) } } - fn evaluate_comptime(&mut self, statement: StmtId) -> IResult { + pub(super) fn evaluate_comptime(&mut self, statement: StmtId) -> IResult { let was_in_comptime = std::mem::replace(&mut self.in_comptime_context, true); let result = self.evaluate_statement(statement); self.in_comptime_context = was_in_comptime; result } } - -impl Value { - fn get_type(&self) -> Cow { - Cow::Owned(match self { - Value::Unit => Type::Unit, - Value::Bool(_) => Type::Bool, - Value::Field(_) => Type::FieldElement, - Value::I8(_) => Type::Integer(Signedness::Signed, IntegerBitSize::Eight), - Value::I32(_) => Type::Integer(Signedness::Signed, IntegerBitSize::ThirtyTwo), - Value::I64(_) => Type::Integer(Signedness::Signed, IntegerBitSize::SixtyFour), - Value::U8(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), - Value::U32(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo), - Value::U64(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour), - Value::String(value) => { - let length = Type::Constant(value.len() as u64); - Type::String(Box::new(length)) - } - Value::Function(_, typ) => return Cow::Borrowed(typ), - Value::Closure(_, _, typ) => return Cow::Borrowed(typ), - Value::Tuple(fields) => { - Type::Tuple(vecmap(fields, |field| field.get_type().into_owned())) - } - Value::Struct(_, typ) => return Cow::Borrowed(typ), - Value::Array(_, typ) => return Cow::Borrowed(typ), - Value::Slice(_, typ) => return Cow::Borrowed(typ), - Value::Code(_) => Type::Code, - Value::Pointer(element) => { - let element = element.borrow().get_type().into_owned(); - Type::MutableReference(Box::new(element)) - } - }) - } -} diff --git a/compiler/noirc_frontend/src/hir/comptime/mod.rs b/compiler/noirc_frontend/src/hir/comptime/mod.rs index 83aaddaa405..26e05d675b3 100644 --- a/compiler/noirc_frontend/src/hir/comptime/mod.rs +++ b/compiler/noirc_frontend/src/hir/comptime/mod.rs @@ -1,3 +1,8 @@ +mod errors; mod hir_to_ast; mod interpreter; +mod scan; mod tests; +mod value; + +pub use interpreter::Interpreter; diff --git a/compiler/noirc_frontend/src/hir/comptime/scan.rs b/compiler/noirc_frontend/src/hir/comptime/scan.rs new file mode 100644 index 00000000000..d4fa355627f --- /dev/null +++ b/compiler/noirc_frontend/src/hir/comptime/scan.rs @@ -0,0 +1,207 @@ +//! This module is for the scanning of the Hir by the interpreter. +//! In this initial step, the Hir is scanned for `CompTime` nodes +//! without actually executing anything until such a node is found. +//! Once such a node is found, the interpreter will call the relevant +//! evaluate method on that node type, insert the result into the Ast, +//! and continue scanning the rest of the program. +//! +//! Since it mostly just needs to recur on the Hir looking for CompTime +//! nodes, this pass is fairly simple. The only thing it really needs to +//! ensure to do is to push and pop scopes on the interpreter as needed +//! so that any variables defined within e.g. an `if` statement containing +//! a `CompTime` block aren't accessible outside of the `if`. +use crate::{ + hir_def::{ + expr::{ + HirArrayLiteral, HirBlockExpression, HirCallExpression, HirConstructorExpression, + HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, + HirMethodCallExpression, + }, + stmt::HirForStatement, + }, + macros_api::{HirExpression, HirLiteral, HirStatement}, + node_interner::{ExprId, FuncId, StmtId}, +}; + +use super::{ + errors::{IResult, InterpreterError}, + interpreter::Interpreter, +}; + +#[allow(dead_code)] +impl<'interner> Interpreter<'interner> { + /// Scan through a function, evaluating any CompTime nodes found. + /// These nodes will be modified in place, replaced with the + /// result of their evaluation. + pub fn scan_function(&mut self, function: FuncId) -> IResult<()> { + let function = self.interner.function(&function); + + let state = self.enter_function(); + self.scan_expression(function.as_expr())?; + self.exit_function(state); + Ok(()) + } + + fn scan_expression(&mut self, expr: ExprId) -> IResult<()> { + match self.interner.expression(&expr) { + HirExpression::Ident(_) => Ok(()), + HirExpression::Literal(literal) => self.scan_literal(literal), + HirExpression::Block(block) => self.scan_block(block), + HirExpression::Prefix(prefix) => self.scan_expression(prefix.rhs), + HirExpression::Infix(infix) => self.scan_infix(infix), + HirExpression::Index(index) => self.scan_index(index), + HirExpression::Constructor(constructor) => self.scan_constructor(constructor), + HirExpression::MemberAccess(member_access) => self.scan_expression(member_access.lhs), + HirExpression::Call(call) => self.scan_call(call), + HirExpression::MethodCall(method_call) => self.scan_method_call(method_call), + HirExpression::Cast(cast) => self.scan_expression(cast.lhs), + HirExpression::If(if_) => self.scan_if(if_), + HirExpression::Tuple(tuple) => self.scan_tuple(tuple), + HirExpression::Lambda(lambda) => self.scan_lambda(lambda), + HirExpression::Comptime(block) => { + let location = self.interner.expr_location(&expr); + let new_expr = + self.evaluate_block(block)?.into_expression(self.interner, location)?; + let new_expr = self.interner.expression(&new_expr); + self.interner.replace_expr(&expr, new_expr); + Ok(()) + } + HirExpression::Quote(_) => { + // This error could be detected much earlier in the compiler pipeline but + // it just makes sense for the comptime code to handle comptime things. + let location = self.interner.expr_location(&expr); + Err(InterpreterError::QuoteInRuntimeCode { location }) + } + HirExpression::Error => Ok(()), + + // Unquote should only be inserted by the comptime interpreter while expanding macros + // and is removed by the Hir -> Ast conversion pass which converts it into a normal block. + // If we find one now during scanning it most likely means the Hir -> Ast conversion + // missed it somehow. In the future we may allow users to manually write unquote + // expressions in their code but for now this is unreachable. + HirExpression::Unquote(block) => { + unreachable!("Found unquote block while scanning: {block}") + } + } + } + + fn scan_literal(&mut self, literal: HirLiteral) -> IResult<()> { + match literal { + HirLiteral::Array(elements) | HirLiteral::Slice(elements) => match elements { + HirArrayLiteral::Standard(elements) => { + for element in elements { + self.scan_expression(element)?; + } + Ok(()) + } + HirArrayLiteral::Repeated { repeated_element, length: _ } => { + self.scan_expression(repeated_element) + } + }, + HirLiteral::Bool(_) + | HirLiteral::Integer(_, _) + | HirLiteral::Str(_) + | HirLiteral::FmtStr(_, _) + | HirLiteral::Unit => Ok(()), + } + } + + fn scan_block(&mut self, block: HirBlockExpression) -> IResult<()> { + self.push_scope(); + for statement in &block.statements { + self.scan_statement(*statement)?; + } + self.pop_scope(); + Ok(()) + } + + fn scan_infix(&mut self, infix: HirInfixExpression) -> IResult<()> { + self.scan_expression(infix.lhs)?; + self.scan_expression(infix.rhs) + } + + fn scan_index(&mut self, index: HirIndexExpression) -> IResult<()> { + self.scan_expression(index.collection)?; + self.scan_expression(index.index) + } + + fn scan_constructor(&mut self, constructor: HirConstructorExpression) -> IResult<()> { + for (_, field) in constructor.fields { + self.scan_expression(field)?; + } + Ok(()) + } + + fn scan_call(&mut self, call: HirCallExpression) -> IResult<()> { + self.scan_expression(call.func)?; + for arg in call.arguments { + self.scan_expression(arg)?; + } + Ok(()) + } + + fn scan_method_call(&mut self, method_call: HirMethodCallExpression) -> IResult<()> { + self.scan_expression(method_call.object)?; + for arg in method_call.arguments { + self.scan_expression(arg)?; + } + Ok(()) + } + + fn scan_if(&mut self, if_: HirIfExpression) -> IResult<()> { + self.scan_expression(if_.condition)?; + + self.push_scope(); + self.scan_expression(if_.consequence)?; + self.pop_scope(); + + if let Some(alternative) = if_.alternative { + self.push_scope(); + self.scan_expression(alternative)?; + self.pop_scope(); + } + Ok(()) + } + + fn scan_tuple(&mut self, tuple: Vec) -> IResult<()> { + for field in tuple { + self.scan_expression(field)?; + } + Ok(()) + } + + fn scan_lambda(&mut self, lambda: HirLambda) -> IResult<()> { + self.scan_expression(lambda.body) + } + + fn scan_statement(&mut self, statement: StmtId) -> IResult<()> { + match self.interner.statement(&statement) { + HirStatement::Let(let_) => self.scan_expression(let_.expression), + HirStatement::Constrain(constrain) => self.scan_expression(constrain.0), + HirStatement::Assign(assign) => self.scan_expression(assign.expression), + HirStatement::For(for_) => self.scan_for(for_), + HirStatement::Break => Ok(()), + HirStatement::Continue => Ok(()), + HirStatement::Expression(expression) => self.scan_expression(expression), + HirStatement::Semi(semi) => self.scan_expression(semi), + HirStatement::Error => Ok(()), + HirStatement::CompTime(comptime) => { + let location = self.interner.statement_location(comptime); + let new_expr = + self.evaluate_comptime(comptime)?.into_expression(self.interner, location)?; + self.interner.replace_statement(statement, HirStatement::Expression(new_expr)); + Ok(()) + } + } + } + + fn scan_for(&mut self, for_: HirForStatement) -> IResult<()> { + // We don't need to set self.in_loop since we're not actually evaluating this loop. + // We just need to push a scope so that if there's a `comptime { .. }` expr inside this + // loop, any variables it defines aren't accessible outside of it. + self.push_scope(); + self.scan_expression(for_.block)?; + self.pop_scope(); + Ok(()) + } +} diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs index 016e7079886..1a84dae4a87 100644 --- a/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -2,7 +2,9 @@ use noirc_errors::Location; -use super::interpreter::{Interpreter, InterpreterError, Value}; +use super::errors::InterpreterError; +use super::interpreter::Interpreter; +use super::value::Value; use crate::hir::type_check::test::type_check_src_code; fn interpret_helper(src: &str, func_namespace: Vec) -> Result { diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs new file mode 100644 index 00000000000..89102344b09 --- /dev/null +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -0,0 +1,170 @@ +use std::{borrow::Cow, rc::Rc}; + +use acvm::FieldElement; +use im::Vector; +use iter_extended::{try_vecmap, vecmap}; +use noirc_errors::Location; + +use crate::{ + ast::{BlockExpression, Ident, IntegerBitSize, Signedness}, + hir_def::expr::{HirArrayLiteral, HirConstructorExpression, HirIdent, HirLambda, ImplKind}, + macros_api::{HirExpression, HirLiteral, NodeInterner}, + node_interner::{ExprId, FuncId}, + Shared, Type, +}; +use rustc_hash::FxHashMap as HashMap; + +use super::errors::{IResult, InterpreterError}; + +#[allow(unused)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Value { + Unit, + Bool(bool), + Field(FieldElement), + I8(i8), + I32(i32), + I64(i64), + U8(u8), + U32(u32), + U64(u64), + String(Rc), + Function(FuncId, Type), + Closure(HirLambda, Vec, Type), + Tuple(Vec), + Struct(HashMap, Value>, Type), + Pointer(Shared), + Array(Vector, Type), + Slice(Vector, Type), + Code(Rc), +} + +impl Value { + pub(crate) fn get_type(&self) -> Cow { + Cow::Owned(match self { + Value::Unit => Type::Unit, + Value::Bool(_) => Type::Bool, + Value::Field(_) => Type::FieldElement, + Value::I8(_) => Type::Integer(Signedness::Signed, IntegerBitSize::Eight), + Value::I32(_) => Type::Integer(Signedness::Signed, IntegerBitSize::ThirtyTwo), + Value::I64(_) => Type::Integer(Signedness::Signed, IntegerBitSize::SixtyFour), + Value::U8(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + Value::U32(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo), + Value::U64(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour), + Value::String(value) => { + let length = Type::Constant(value.len() as u64); + Type::String(Box::new(length)) + } + Value::Function(_, typ) => return Cow::Borrowed(typ), + Value::Closure(_, _, typ) => return Cow::Borrowed(typ), + Value::Tuple(fields) => { + Type::Tuple(vecmap(fields, |field| field.get_type().into_owned())) + } + Value::Struct(_, typ) => return Cow::Borrowed(typ), + Value::Array(_, typ) => return Cow::Borrowed(typ), + Value::Slice(_, typ) => return Cow::Borrowed(typ), + Value::Code(_) => Type::Code, + Value::Pointer(element) => { + let element = element.borrow().get_type().into_owned(); + Type::MutableReference(Box::new(element)) + } + }) + } + + pub(crate) fn into_expression( + self, + interner: &mut NodeInterner, + location: Location, + ) -> IResult { + let typ = self.get_type().into_owned(); + + let expression = match self { + Value::Unit => HirExpression::Literal(HirLiteral::Unit), + Value::Bool(value) => HirExpression::Literal(HirLiteral::Bool(value)), + Value::Field(value) => HirExpression::Literal(HirLiteral::Integer(value, false)), + Value::I8(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + HirExpression::Literal(HirLiteral::Integer(value, negative)) + } + Value::I32(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + HirExpression::Literal(HirLiteral::Integer(value, negative)) + } + Value::I64(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + HirExpression::Literal(HirLiteral::Integer(value, negative)) + } + Value::U8(value) => { + HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) + } + Value::U32(value) => { + HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) + } + Value::U64(value) => { + HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) + } + Value::String(value) => HirExpression::Literal(HirLiteral::Str(unwrap_rc(value))), + Value::Function(id, _typ) => { + let id = interner.function_definition_id(id); + let impl_kind = ImplKind::NotATraitMethod; + HirExpression::Ident(HirIdent { location, id, impl_kind }) + } + Value::Closure(_lambda, _env, _typ) => { + // TODO: How should a closure's environment be inlined? + let item = "returning closures from a comptime fn"; + return Err(InterpreterError::Unimplemented { item, location }); + } + Value::Tuple(fields) => { + let fields = try_vecmap(fields, |field| field.into_expression(interner, location))?; + HirExpression::Tuple(fields) + } + Value::Struct(fields, typ) => { + let fields = try_vecmap(fields, |(name, field)| { + let field = field.into_expression(interner, location)?; + Ok((Ident::new(unwrap_rc(name), location.span), field)) + })?; + + let (r#type, struct_generics) = match typ.follow_bindings() { + Type::Struct(def, generics) => (def, generics), + _ => return Err(InterpreterError::NonStructInConstructor { typ, location }), + }; + + HirExpression::Constructor(HirConstructorExpression { + r#type, + struct_generics, + fields, + }) + } + Value::Array(elements, _) => { + let elements = + try_vecmap(elements, |elements| elements.into_expression(interner, location))?; + HirExpression::Literal(HirLiteral::Array(HirArrayLiteral::Standard(elements))) + } + Value::Slice(elements, _) => { + let elements = + try_vecmap(elements, |elements| elements.into_expression(interner, location))?; + HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements))) + } + Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)), + Value::Pointer(_) => { + return Err(InterpreterError::CannotInlineMacro { value: self, location }) + } + }; + + let id = interner.push_expr(expression); + interner.push_expr_location(id, location.span, location.file); + interner.push_expr_type(id, typ); + Ok(id) + } +} + +/// Unwraps an Rc value without cloning the inner value if the reference count is 1. Clones otherwise. +fn unwrap_rc(rc: Rc) -> T { + Rc::try_unwrap(rc).unwrap_or_else(|rc| (*rc).clone()) +} diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 4c6b5ab5885..7805f36cdb2 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -1,6 +1,7 @@ use super::dc_mod::collect_defs; use super::errors::{DefCollectorErrorKind, DuplicateType}; use crate::graph::CrateId; +use crate::hir::comptime::Interpreter; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleId}; use crate::hir::resolution::errors::ResolverError; @@ -30,6 +31,15 @@ use std::collections::{BTreeMap, HashMap}; use std::vec; +#[derive(Default)] +pub struct ResolvedModule { + pub globals: Vec<(FileId, GlobalId)>, + pub functions: Vec<(FileId, FuncId)>, + pub trait_impl_functions: Vec<(FileId, FuncId)>, + + pub errors: Vec<(CompilationError, FileId)>, +} + /// Stores all of the unresolved functions in a particular file/mod #[derive(Clone)] pub struct UnresolvedFunctions { @@ -304,6 +314,8 @@ impl DefCollector { } } + let mut resolved_module = ResolvedModule { errors, ..Default::default() }; + // We must first resolve and intern the globals before we can resolve any stmts inside each function. // Each function uses its own resolver with a newly created ScopeForest, and must be resolved again to be within a function's scope // @@ -312,21 +324,29 @@ impl DefCollector { let (literal_globals, other_globals) = filter_literal_globals(def_collector.collected_globals); - let mut resolved_globals = resolve_globals(context, literal_globals, crate_id); + resolved_module.resolve_globals(context, literal_globals, crate_id); - errors.extend(resolve_type_aliases( + resolved_module.errors.extend(resolve_type_aliases( context, def_collector.collected_type_aliases, crate_id, )); - errors.extend(resolve_traits(context, def_collector.collected_traits, crate_id)); + resolved_module.errors.extend(resolve_traits( + context, + def_collector.collected_traits, + crate_id, + )); // Must resolve structs before we resolve globals. - errors.extend(resolve_structs(context, def_collector.collected_types, crate_id)); + resolved_module.errors.extend(resolve_structs( + context, + def_collector.collected_types, + crate_id, + )); // Bind trait impls to their trait. Collect trait functions, that have a // default implementation, which hasn't been overridden. - errors.extend(collect_trait_impls( + resolved_module.errors.extend(collect_trait_impls( context, crate_id, &mut def_collector.collected_traits_impls, @@ -339,54 +359,55 @@ impl DefCollector { // // These are resolved after trait impls so that struct methods are chosen // over trait methods if there are name conflicts. - errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls)); + resolved_module.errors.extend(collect_impls( + context, + crate_id, + &def_collector.collected_impls, + )); // We must wait to resolve non-integer globals until after we resolve structs since struct // globals will need to reference the struct type they're initialized to to ensure they are valid. - resolved_globals.extend(resolve_globals(context, other_globals, crate_id)); - errors.extend(resolved_globals.errors); + resolved_module.resolve_globals(context, other_globals, crate_id); // Resolve each function in the crate. This is now possible since imports have been resolved - let mut functions = Vec::new(); - functions.extend(resolve_free_functions( + resolved_module.functions = resolve_free_functions( &mut context.def_interner, crate_id, &context.def_maps, def_collector.collected_functions, None, - &mut errors, - )); + &mut resolved_module.errors, + ); - functions.extend(resolve_impls( + resolved_module.functions.extend(resolve_impls( &mut context.def_interner, crate_id, &context.def_maps, def_collector.collected_impls, - &mut errors, + &mut resolved_module.errors, )); - let impl_functions = resolve_trait_impls( + resolved_module.trait_impl_functions = resolve_trait_impls( context, def_collector.collected_traits_impls, crate_id, - &mut errors, + &mut resolved_module.errors, ); for macro_processor in macro_processors { macro_processor.process_typed_ast(&crate_id, context).unwrap_or_else( |(macro_err, file_id)| { - errors.push((macro_err.into(), file_id)); + resolved_module.errors.push((macro_err.into(), file_id)); }, ); } - errors.extend(context.def_interner.check_for_dependency_cycles()); + resolved_module.errors.extend(context.def_interner.check_for_dependency_cycles()); - errors.extend(type_check_globals(&mut context.def_interner, resolved_globals.globals)); - errors.extend(type_check_functions(&mut context.def_interner, functions)); - errors.extend(type_check_trait_impl_signatures(&mut context.def_interner, &impl_functions)); - errors.extend(type_check_functions(&mut context.def_interner, impl_functions)); - errors + resolved_module.type_check(context); + resolved_module.evaluate_comptime(&mut context.def_interner); + + resolved_module.errors } } @@ -444,48 +465,59 @@ fn filter_literal_globals( }) } -fn type_check_globals( - interner: &mut NodeInterner, - global_ids: Vec<(FileId, GlobalId)>, -) -> Vec<(CompilationError, fm::FileId)> { - global_ids - .into_iter() - .flat_map(|(file_id, global_id)| { - TypeChecker::check_global(global_id, interner) - .iter() - .cloned() - .map(|e| (e.into(), file_id)) - .collect::>() - }) - .collect() -} +impl ResolvedModule { + fn type_check(&mut self, context: &mut Context) { + self.type_check_globals(&mut context.def_interner); + self.type_check_functions(&mut context.def_interner); + self.type_check_trait_impl_function(&mut context.def_interner); + } -fn type_check_functions( - interner: &mut NodeInterner, - file_func_ids: Vec<(FileId, FuncId)>, -) -> Vec<(CompilationError, fm::FileId)> { - file_func_ids - .into_iter() - .flat_map(|(file, func)| { - type_check_func(interner, func) - .into_iter() - .map(|e| (e.into(), file)) - .collect::>() - }) - .collect() -} + fn type_check_globals(&mut self, interner: &mut NodeInterner) { + for (file_id, global_id) in self.globals.iter() { + for error in TypeChecker::check_global(*global_id, interner) { + self.errors.push((error.into(), *file_id)); + } + } + } + + fn type_check_functions(&mut self, interner: &mut NodeInterner) { + for (file, func) in self.functions.iter() { + for error in type_check_func(interner, *func) { + self.errors.push((error.into(), *file)); + } + } + } + + fn type_check_trait_impl_function(&mut self, interner: &mut NodeInterner) { + for (file, func) in self.trait_impl_functions.iter() { + for error in check_trait_impl_method_matches_declaration(interner, *func) { + self.errors.push((error.into(), *file)); + } + for error in type_check_func(interner, *func) { + self.errors.push((error.into(), *file)); + } + } + } + + /// Evaluate all `comptime` expressions in this module + fn evaluate_comptime(&self, interner: &mut NodeInterner) { + let mut interpreter = Interpreter::new(interner); + + for (_file, function) in &self.functions { + // .unwrap() is temporary here until we can convert + // from InterpreterError to (CompilationError, FileId) + interpreter.scan_function(*function).unwrap(); + } + } -fn type_check_trait_impl_signatures( - interner: &mut NodeInterner, - file_func_ids: &[(FileId, FuncId)], -) -> Vec<(CompilationError, fm::FileId)> { - file_func_ids - .iter() - .flat_map(|(file, func)| { - check_trait_impl_method_matches_declaration(interner, *func) - .into_iter() - .map(|e| (e.into(), *file)) - .collect::>() - }) - .collect() + fn resolve_globals( + &mut self, + context: &mut Context, + literal_globals: Vec, + crate_id: CrateId, + ) { + let globals = resolve_globals(context, literal_globals, crate_id); + self.globals.extend(globals.globals); + self.errors.extend(globals.errors); + } } diff --git a/compiler/noirc_frontend/src/hir/resolution/globals.rs b/compiler/noirc_frontend/src/hir/resolution/globals.rs index 9fb31271727..bcda4e75d3d 100644 --- a/compiler/noirc_frontend/src/hir/resolution/globals.rs +++ b/compiler/noirc_frontend/src/hir/resolution/globals.rs @@ -11,18 +11,12 @@ use crate::{ use fm::FileId; use iter_extended::vecmap; +#[derive(Default)] pub(crate) struct ResolvedGlobals { pub(crate) globals: Vec<(FileId, GlobalId)>, pub(crate) errors: Vec<(CompilationError, FileId)>, } -impl ResolvedGlobals { - pub(crate) fn extend(&mut self, oth: Self) { - self.globals.extend(oth.globals); - self.errors.extend(oth.errors); - } -} - pub(crate) fn resolve_globals( context: &mut Context, globals: Vec, diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 0e69b3bdeba..8d2cb17189b 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -1272,9 +1272,9 @@ impl<'a> Resolver<'a> { HirStatement::Continue } StatementKind::Error => HirStatement::Error, - StatementKind::Comptime(statement) => { + StatementKind::CompTime(statement) => { let statement = self.resolve_stmt(*statement, span); - HirStatement::Comptime(self.interner.push_stmt(statement)) + HirStatement::CompTime(self.interner.push_stmt(statement)) } } } @@ -1323,7 +1323,9 @@ impl<'a> Resolver<'a> { pub fn intern_stmt(&mut self, stmt: Statement) -> StmtId { let hir_stmt = self.resolve_stmt(stmt.kind, stmt.span); - self.interner.push_stmt(hir_stmt) + let id = self.interner.push_stmt(hir_stmt); + self.interner.push_statement_location(id, stmt.span, self.file); + id } fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue { @@ -1531,7 +1533,9 @@ impl<'a> Resolver<'a> { collection: self.resolve_expression(indexed_expr.collection), index: self.resolve_expression(indexed_expr.index), }), - ExpressionKind::Block(block_expr) => self.resolve_block(block_expr), + ExpressionKind::Block(block_expr) => { + HirExpression::Block(self.resolve_block(block_expr)) + } ExpressionKind::Constructor(constructor) => { let span = constructor.type_name.span(); @@ -1598,6 +1602,7 @@ impl<'a> Resolver<'a> { // The quoted expression isn't resolved since we don't want errors if variables aren't defined ExpressionKind::Quote(block) => HirExpression::Quote(block), + ExpressionKind::Comptime(block) => HirExpression::Comptime(self.resolve_block(block)), }; // If these lines are ever changed, make sure to change the early return @@ -1930,14 +1935,14 @@ impl<'a> Resolver<'a> { Ok(path_resolution.module_def_id) } - fn resolve_block(&mut self, block_expr: BlockExpression) -> HirExpression { + fn resolve_block(&mut self, block_expr: BlockExpression) -> HirBlockExpression { let statements = self.in_new_scope(|this| vecmap(block_expr.statements, |stmt| this.intern_stmt(stmt))); - HirExpression::Block(HirBlockExpression { statements }) + HirBlockExpression { statements } } pub fn intern_block(&mut self, block: BlockExpression) -> ExprId { - let hir_block = self.resolve_block(block); + let hir_block = HirExpression::Block(self.resolve_block(block)); self.interner.push_expr(hir_block) } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 62330732be4..0bc7673e105 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -6,8 +6,8 @@ use crate::{ hir::{resolution::resolver::verify_mutable_reference, type_check::errors::Source}, hir_def::{ expr::{ - self, HirArrayLiteral, HirBinaryOp, HirExpression, HirIdent, HirLiteral, - HirMethodCallExpression, HirMethodReference, HirPrefixExpression, ImplKind, + self, HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirExpression, HirIdent, + HirLiteral, HirMethodCallExpression, HirMethodReference, HirPrefixExpression, ImplKind, }, types::Type, }, @@ -271,34 +271,7 @@ impl<'interner> TypeChecker<'interner> { let span = self.interner.expr_span(expr_id); self.check_cast(lhs_type, cast_expr.r#type, span) } - HirExpression::Block(block_expr) => { - let mut block_type = Type::Unit; - - let statements = block_expr.statements(); - for (i, stmt) in statements.iter().enumerate() { - let expr_type = self.check_statement(stmt); - - if let crate::hir_def::stmt::HirStatement::Semi(expr) = - self.interner.statement(stmt) - { - let inner_expr_type = self.interner.id_type(expr); - let span = self.interner.expr_span(&expr); - - self.unify(&inner_expr_type, &Type::Unit, || { - TypeCheckError::UnusedResultError { - expr_type: inner_expr_type.clone(), - expr_span: span, - } - }); - } - - if i + 1 == statements.len() { - block_type = expr_type; - } - } - - block_type - } + HirExpression::Block(block_expr) => self.check_block(block_expr), HirExpression::Prefix(prefix_expr) => { let rhs_type = self.check_expression(&prefix_expr.rhs); let span = self.interner.expr_span(&prefix_expr.rhs); @@ -336,12 +309,44 @@ impl<'interner> TypeChecker<'interner> { Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } HirExpression::Quote(_) => Type::Code, + HirExpression::Comptime(block) => self.check_block(block), + + // Unquote should be inserted & removed by the comptime interpreter. + // Even if we allowed it here, we wouldn't know what type to give to the result. + HirExpression::Unquote(block) => { + unreachable!("Unquote remaining during type checking {block}") + } }; self.interner.push_expr_type(*expr_id, typ.clone()); typ } + fn check_block(&mut self, block: HirBlockExpression) -> Type { + let mut block_type = Type::Unit; + + let statements = block.statements(); + for (i, stmt) in statements.iter().enumerate() { + let expr_type = self.check_statement(stmt); + + if let crate::hir_def::stmt::HirStatement::Semi(expr) = self.interner.statement(stmt) { + let inner_expr_type = self.interner.id_type(expr); + let span = self.interner.expr_span(&expr); + + self.unify(&inner_expr_type, &Type::Unit, || TypeCheckError::UnusedResultError { + expr_type: inner_expr_type.clone(), + expr_span: span, + }); + } + + if i + 1 == statements.len() { + block_type = expr_type; + } + } + + block_type + } + /// Returns the type of the given identifier fn check_ident(&mut self, ident: HirIdent, expr_id: &ExprId) -> Type { let mut bindings = TypeBindings::new(); diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index f5f6e1e8180..064fefc8ae9 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -51,7 +51,7 @@ impl<'interner> TypeChecker<'interner> { HirStatement::Constrain(constrain_stmt) => self.check_constrain_stmt(constrain_stmt), HirStatement::Assign(assign_stmt) => self.check_assign_stmt(assign_stmt, stmt_id), HirStatement::For(for_loop) => self.check_for_loop(for_loop), - HirStatement::Comptime(statement) => return self.check_statement(&statement), + HirStatement::CompTime(statement) => return self.check_statement(&statement), HirStatement::Break | HirStatement::Continue | HirStatement::Error => (), } Type::Unit diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index d88b65d1fce..bf7d9b7b4ba 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -31,8 +31,10 @@ pub enum HirExpression { If(HirIfExpression), Tuple(Vec), Lambda(HirLambda), - Error, Quote(crate::ast::BlockExpression), + Unquote(crate::ast::BlockExpression), + Comptime(HirBlockExpression), + Error, } impl HirExpression { diff --git a/compiler/noirc_frontend/src/hir_def/stmt.rs b/compiler/noirc_frontend/src/hir_def/stmt.rs index 7e22e5ee9c0..48e7d7344e3 100644 --- a/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -20,7 +20,7 @@ pub enum HirStatement { Continue, Expression(ExprId), Semi(ExprId), - Comptime(StmtId), + CompTime(StmtId), Error, } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 74a0dd855c0..d92b6c65d7a 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -521,6 +521,12 @@ impl<'interner> Monomorphizer<'interner> { } HirExpression::Error => unreachable!("Encountered Error node during monomorphization"), HirExpression::Quote(_) => unreachable!("quote expression remaining in runtime code"), + HirExpression::Unquote(_) => { + unreachable!("unquote expression remaining in runtime code") + } + HirExpression::Comptime(_) => { + unreachable!("comptime expression remaining in runtime code") + } }; Ok(expr) @@ -626,7 +632,7 @@ impl<'interner> Monomorphizer<'interner> { HirStatement::Error => unreachable!(), // All `comptime` statements & expressions should be removed before runtime. - HirStatement::Comptime(_) => unreachable!("comptime statement in runtime code"), + HirStatement::CompTime(_) => unreachable!("comptime statement in runtime code"), } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index b0e68be4868..9d3a79820dc 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -921,10 +921,18 @@ impl NodeInterner { self.id_location(expr_id) } - pub fn statement_span(&self, stmt_id: &StmtId) -> Span { + pub fn statement_span(&self, stmt_id: StmtId) -> Span { self.id_location(stmt_id).span } + pub fn statement_location(&self, stmt_id: StmtId) -> Location { + self.id_location(stmt_id) + } + + pub fn push_statement_location(&mut self, id: StmtId, span: Span, file: FileId) { + self.id_to_location.insert(id.into(), Location::new(span, file)); + } + pub fn get_struct(&self, id: StructId) -> Shared { self.structs[&id].clone() } diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 603193d1593..858e5c4838c 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -540,7 +540,7 @@ where StatementKind::Expression(Expression::new(ExpressionKind::Block(block), span)) }), ))) - .map(|statement| StatementKind::Comptime(Box::new(statement))) + .map(|statement| StatementKind::CompTime(Box::new(statement))) } /// Comptime in an expression position only accepts entire blocks @@ -548,7 +548,7 @@ fn comptime_expr<'a, S>(statement: S) -> impl NoirParser + 'a where S: NoirParser + 'a, { - keyword(Keyword::CompTime).ignore_then(block(statement)).map(ExpressionKind::Block) + keyword(Keyword::CompTime).ignore_then(block(statement)).map(ExpressionKind::Comptime) } fn declaration<'a, P>(expr_parser: P) -> impl NoirParser + 'a diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 31bf2245b1f..ac3d7bbc4cc 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -780,7 +780,7 @@ mod test { HirStatement::Error => panic!("Invalid HirStatement!"), HirStatement::Break => panic!("Unexpected break"), HirStatement::Continue => panic!("Unexpected continue"), - HirStatement::Comptime(_) => panic!("Unexpected comptime"), + HirStatement::CompTime(_) => panic!("Unexpected comptime"), }; let expr = interner.expression(&expr_id); diff --git a/tooling/nargo_fmt/src/rewrite/expr.rs b/tooling/nargo_fmt/src/rewrite/expr.rs index e4616c99aaa..6b7dca6c5c7 100644 --- a/tooling/nargo_fmt/src/rewrite/expr.rs +++ b/tooling/nargo_fmt/src/rewrite/expr.rs @@ -159,6 +159,9 @@ pub(crate) fn rewrite( } ExpressionKind::Lambda(_) | ExpressionKind::Variable(_) => visitor.slice(span).to_string(), ExpressionKind::Quote(block) => format!("quote {}", rewrite_block(visitor, block, span)), + ExpressionKind::Comptime(block) => { + format!("comptime {}", rewrite_block(visitor, block, span)) + } ExpressionKind::Error => unreachable!(), } } diff --git a/tooling/nargo_fmt/src/visitor/stmt.rs b/tooling/nargo_fmt/src/visitor/stmt.rs index e41827c94a1..869977d5f3c 100644 --- a/tooling/nargo_fmt/src/visitor/stmt.rs +++ b/tooling/nargo_fmt/src/visitor/stmt.rs @@ -103,7 +103,7 @@ impl super::FmtVisitor<'_> { StatementKind::Error => unreachable!(), StatementKind::Break => self.push_rewrite("break;".into(), span), StatementKind::Continue => self.push_rewrite("continue;".into(), span), - StatementKind::Comptime(statement) => self.visit_stmt(*statement, span, is_last), + StatementKind::CompTime(statement) => self.visit_stmt(*statement, span, is_last), } } }