diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 92c1add80a6..5659de46588 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -28,6 +28,7 @@ pub enum ExpressionKind { Lambda(Box), Parenthesized(Box), Quote(BlockExpression), + Comptime(BlockExpression), Error, } @@ -504,6 +505,7 @@ impl Display for ExpressionKind { Lambda(lambda) => lambda.fmt(f), Parenthesized(sub_expr) => write!(f, "({sub_expr})"), Quote(block) => write!(f, "quote {block}"), + Comptime(block) => write!(f, "comptime {block}"), Error => write!(f, "Error"), } } diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 1831a046f5b..f37c7adc983 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -40,7 +40,7 @@ pub enum StatementKind { Break, Continue, /// This statement should be executed at compile-time - Comptime(Box), + CompTime(Box), // This is an expression with a trailing semi-colon Semi(Expression), // This statement is the result of a recovered parse error. @@ -87,10 +87,10 @@ impl StatementKind { } self } - StatementKind::Comptime(mut statement) => { + StatementKind::CompTime(mut statement) => { *statement = statement.add_semicolon(semi, span, last_statement_in_block, emit_error); - StatementKind::Comptime(statement) + StatementKind::CompTime(statement) } // A semicolon on a for loop is optional and does nothing StatementKind::For(_) => self, @@ -685,7 +685,7 @@ impl Display for StatementKind { StatementKind::For(for_loop) => for_loop.fmt(f), StatementKind::Break => write!(f, "break"), StatementKind::Continue => write!(f, "continue"), - StatementKind::Comptime(statement) => write!(f, "comptime {statement}"), + StatementKind::CompTime(statement) => write!(f, "comptime {statement}"), StatementKind::Semi(semi) => write!(f, "{semi};"), StatementKind::Error => write!(f, "Error"), } diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs new file mode 100644 index 00000000000..d3db7fcaee9 --- /dev/null +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -0,0 +1,53 @@ +use crate::{node_interner::DefinitionId, Type}; +use acvm::FieldElement; +use noirc_errors::Location; + +use super::value::Value; + +/// The possible errors that can halt the interpreter. +#[derive(Debug)] +pub enum InterpreterError { + ArgumentCountMismatch { expected: usize, actual: usize, call_location: Location }, + TypeMismatch { expected: Type, value: Value, location: Location }, + NoValueForId { id: DefinitionId, location: Location }, + IntegerOutOfRangeForType { value: FieldElement, typ: Type, location: Location }, + ErrorNodeEncountered { location: Location }, + NonFunctionCalled { value: Value, location: Location }, + NonBoolUsedInIf { value: Value, location: Location }, + NonBoolUsedInConstrain { value: Value, location: Location }, + FailingConstraint { message: Option, location: Location }, + NoMethodFound { object: Value, typ: Type, location: Location }, + NonIntegerUsedInLoop { value: Value, location: Location }, + NonPointerDereferenced { value: Value, location: Location }, + NonTupleOrStructInMemberAccess { value: Value, location: Location }, + NonArrayIndexed { value: Value, location: Location }, + NonIntegerUsedAsIndex { value: Value, location: Location }, + NonIntegerIntegerLiteral { typ: Type, location: Location }, + NonIntegerArrayLength { typ: Type, location: Location }, + NonNumericCasted { value: Value, location: Location }, + IndexOutOfBounds { index: usize, length: usize, location: Location }, + ExpectedStructToHaveField { value: Value, field_name: String, location: Location }, + TypeUnsupported { typ: Type, location: Location }, + InvalidValueForUnary { value: Value, operator: &'static str, location: Location }, + InvalidValuesForBinary { lhs: Value, rhs: Value, operator: &'static str, location: Location }, + CastToNonNumericType { typ: Type, location: Location }, + QuoteInRuntimeCode { location: Location }, + NonStructInConstructor { typ: Type, location: Location }, + CannotInlineMacro { value: Value, location: Location }, + UnquoteFoundDuringEvaluation { location: Location }, + + Unimplemented { item: &'static str, location: Location }, + + // Perhaps this should be unreachable! due to type checking also preventing this error? + // Currently it and the Continue variant are the only interpreter errors without a Location field + BreakNotInLoop { location: Location }, + ContinueNotInLoop { location: Location }, + + // These cases are not errors, they are just used to prevent us from running more code + // until the loop can be resumed properly. These cases will never be displayed to users. + Break, + Continue, +} + +#[allow(unused)] +pub(super) type IResult = std::result::Result; diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs index 47ca7083ff0..dd23edf0004 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs @@ -9,7 +9,7 @@ use crate::ast::{ UnresolvedTypeData, UnresolvedTypeExpression, }; use crate::ast::{ConstrainStatement, Expression, Statement, StatementKind}; -use crate::hir_def::expr::{HirArrayLiteral, HirExpression, HirIdent}; +use crate::hir_def::expr::{HirArrayLiteral, HirBlockExpression, HirExpression, HirIdent}; use crate::hir_def::stmt::{HirLValue, HirPattern, HirStatement}; use crate::hir_def::types::Type; use crate::macros_api::HirLiteral; @@ -26,7 +26,7 @@ impl StmtId { #[allow(unused)] fn to_ast(self, interner: &NodeInterner) -> Statement { let statement = interner.statement(&self); - let span = interner.statement_span(&self); + let span = interner.statement_span(self); let kind = match statement { HirStatement::Let(let_stmt) => { @@ -66,8 +66,8 @@ impl StmtId { HirStatement::Expression(expr) => StatementKind::Expression(expr.to_ast(interner)), HirStatement::Semi(expr) => StatementKind::Semi(expr.to_ast(interner)), HirStatement::Error => StatementKind::Error, - HirStatement::Comptime(statement) => { - StatementKind::Comptime(Box::new(statement.to_ast(interner).kind)) + HirStatement::CompTime(statement) => { + StatementKind::CompTime(Box::new(statement.to_ast(interner).kind)) } }; @@ -108,10 +108,7 @@ impl ExprId { ExpressionKind::Literal(Literal::FmtStr(string)) } HirExpression::Literal(HirLiteral::Unit) => ExpressionKind::Literal(Literal::Unit), - HirExpression::Block(expr) => { - let statements = vecmap(expr.statements, |statement| statement.to_ast(interner)); - ExpressionKind::Block(BlockExpression { statements }) - } + HirExpression::Block(expr) => ExpressionKind::Block(expr.into_ast(interner)), HirExpression::Prefix(prefix) => ExpressionKind::Prefix(Box::new(PrefixExpression { operator: prefix.operator, rhs: prefix.rhs.to_ast(interner), @@ -172,8 +169,12 @@ impl ExprId { let body = lambda.body.to_ast(interner); ExpressionKind::Lambda(Box::new(Lambda { parameters, return_type, body })) } - HirExpression::Quote(block) => ExpressionKind::Quote(block), HirExpression::Error => ExpressionKind::Error, + HirExpression::Comptime(block) => ExpressionKind::Comptime(block.into_ast(interner)), + HirExpression::Quote(block) => ExpressionKind::Quote(block), + + // A macro was evaluated here! + HirExpression::Unquote(block) => ExpressionKind::Block(block), }; Expression::new(kind, span) @@ -353,3 +354,10 @@ impl HirArrayLiteral { } } } + +impl HirBlockExpression { + fn into_ast(self, interner: &NodeInterner) -> BlockExpression { + let statements = vecmap(self.statements, |statement| statement.to_ast(interner)); + BlockExpression { statements } + } +} diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 03839c2f0cd..c6d508a581e 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -1,12 +1,12 @@ -use std::{borrow::Cow, collections::hash_map::Entry, rc::Rc}; +use std::{collections::hash_map::Entry, rc::Rc}; use acvm::FieldElement; use im::Vector; -use iter_extended::{try_vecmap, vecmap}; +use iter_extended::try_vecmap; use noirc_errors::Location; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -use crate::ast::{BinaryOpKind, BlockExpression, FunctionKind, IntegerBitSize, Signedness}; +use crate::ast::{BinaryOpKind, FunctionKind, IntegerBitSize, Signedness}; use crate::{ hir_def::{ expr::{ @@ -22,12 +22,16 @@ use crate::{ }, macros_api::{HirExpression, HirLiteral, HirStatement, NodeInterner}, node_interner::{DefinitionId, DefinitionKind, ExprId, FuncId, StmtId}, + Shared, Type, TypeBinding, TypeBindings, TypeVariableKind, }; -use crate::{Shared, Type, TypeBinding, TypeBindings, TypeVariableKind}; + +use super::errors::{IResult, InterpreterError}; +use super::value::Value; + #[allow(unused)] -pub(crate) struct Interpreter<'interner> { +pub struct Interpreter<'interner> { /// To expand macros the Interpreter may mutate hir nodes within the NodeInterner - interner: &'interner mut NodeInterner, + pub(super) interner: &'interner mut NodeInterner, /// Each value currently in scope in the interpreter. /// Each element of the Vec represents a scope with every scope together making @@ -49,72 +53,6 @@ pub(crate) struct Interpreter<'interner> { in_comptime_context: bool, } -#[allow(unused)] -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum Value { - Unit, - Bool(bool), - Field(FieldElement), - I8(i8), - I32(i32), - I64(i64), - U8(u8), - U32(u32), - U64(u64), - String(Rc), - Function(FuncId, Type), - Closure(HirLambda, Vec, Type), - Tuple(Vec), - Struct(HashMap, Value>, Type), - Pointer(Shared), - Array(Vector, Type), - Slice(Vector, Type), - Code(Rc), -} - -/// The possible errors that can halt the interpreter. -#[allow(unused)] -#[derive(Debug)] -pub(crate) enum InterpreterError { - ArgumentCountMismatch { expected: usize, actual: usize, call_location: Location }, - TypeMismatch { expected: Type, value: Value, location: Location }, - NoValueForId { id: DefinitionId, location: Location }, - IntegerOutOfRangeForType { value: FieldElement, typ: Type, location: Location }, - ErrorNodeEncountered { location: Location }, - NonFunctionCalled { value: Value, location: Location }, - NonBoolUsedInIf { value: Value, location: Location }, - NonBoolUsedInConstrain { value: Value, location: Location }, - FailingConstraint { message: Option, location: Location }, - NoMethodFound { object: Value, typ: Type, location: Location }, - NonIntegerUsedInLoop { value: Value, location: Location }, - NonPointerDereferenced { value: Value, location: Location }, - NonTupleOrStructInMemberAccess { value: Value, location: Location }, - NonArrayIndexed { value: Value, location: Location }, - NonIntegerUsedAsIndex { value: Value, location: Location }, - NonIntegerIntegerLiteral { typ: Type, location: Location }, - NonIntegerArrayLength { typ: Type, location: Location }, - NonNumericCasted { value: Value, location: Location }, - IndexOutOfBounds { index: usize, length: usize, location: Location }, - ExpectedStructToHaveField { value: Value, field_name: String, location: Location }, - TypeUnsupported { typ: Type, location: Location }, - InvalidValueForUnary { value: Value, operator: &'static str, location: Location }, - InvalidValuesForBinary { lhs: Value, rhs: Value, operator: &'static str, location: Location }, - CastToNonNumericType { typ: Type, location: Location }, - - // Perhaps this should be unreachable! due to type checking also preventing this error? - // Currently it and the Continue variant are the only interpreter errors without a Location field - BreakNotInLoop, - ContinueNotInLoop, - - // These cases are not errors but prevent us from running more code - // until the loop can be resumed properly. - Break, - Continue, -} - -#[allow(unused)] -type IResult = std::result::Result; - #[allow(unused)] impl<'a> Interpreter<'a> { pub(crate) fn new(interner: &'a mut NodeInterner) -> Self { @@ -193,14 +131,14 @@ impl<'a> Interpreter<'a> { /// Enters a function, pushing a new scope and resetting any required state. /// Returns the previous values of the internal state, to be reset when /// `exit_function` is called. - fn enter_function(&mut self) -> (bool, Vec>) { + pub(super) fn enter_function(&mut self) -> (bool, Vec>) { // Drain every scope except the global scope let scope = self.scopes.drain(1..).collect(); self.push_scope(); (std::mem::take(&mut self.in_loop), scope) } - fn exit_function(&mut self, mut state: (bool, Vec>)) { + pub(super) fn exit_function(&mut self, mut state: (bool, Vec>)) { self.in_loop = state.0; // Keep only the global scope @@ -208,11 +146,11 @@ impl<'a> Interpreter<'a> { self.scopes.append(&mut state.1); } - fn push_scope(&mut self) { + pub(super) fn push_scope(&mut self) { self.scopes.push(HashMap::default()); } - fn pop_scope(&mut self) { + pub(super) fn pop_scope(&mut self) { self.scopes.pop(); } @@ -375,6 +313,13 @@ impl<'a> Interpreter<'a> { HirExpression::Tuple(tuple) => self.evaluate_tuple(tuple), HirExpression::Lambda(lambda) => self.evaluate_lambda(lambda, id), HirExpression::Quote(block) => Ok(Value::Code(Rc::new(block))), + HirExpression::Comptime(block) => self.evaluate_block(block), + HirExpression::Unquote(block) => { + // An Unquote expression being found is indicative of a macro being + // expanded within another comptime fn which we don't currently support. + let location = self.interner.expr_location(&id); + Err(InterpreterError::UnquoteFoundDuringEvaluation { location }) + } HirExpression::Error => { let location = self.interner.expr_location(&id); Err(InterpreterError::ErrorNodeEncountered { location }) @@ -390,7 +335,7 @@ impl<'a> Interpreter<'a> { let typ = self.interner.id_type(id); Ok(Value::Function(*function_id, typ)) } - DefinitionKind::Local(_) => dbg!(self.lookup(&ident)), + DefinitionKind::Local(_) => self.lookup(&ident), DefinitionKind::Global(global_id) => { let let_ = self.interner.get_global_let_statement(*global_id).unwrap(); self.evaluate_let(let_)?; @@ -503,7 +448,7 @@ impl<'a> Interpreter<'a> { } } - fn evaluate_block(&mut self, mut block: HirBlockExpression) -> IResult { + pub(super) fn evaluate_block(&mut self, mut block: HirBlockExpression) -> IResult { let last_statement = block.statements.pop(); self.push_scope(); @@ -1075,10 +1020,10 @@ impl<'a> Interpreter<'a> { HirStatement::Constrain(constrain) => self.evaluate_constrain(constrain), HirStatement::Assign(assign) => self.evaluate_assign(assign), HirStatement::For(for_) => self.evaluate_for(for_), - HirStatement::Break => self.evaluate_break(), - HirStatement::Continue => self.evaluate_continue(), + HirStatement::Break => self.evaluate_break(statement), + HirStatement::Continue => self.evaluate_continue(statement), HirStatement::Expression(expression) => self.evaluate(expression), - HirStatement::Comptime(statement) => self.evaluate_comptime(statement), + HirStatement::CompTime(statement) => self.evaluate_comptime(statement), HirStatement::Semi(expression) => { self.evaluate(expression)?; Ok(Value::Unit) @@ -1236,59 +1181,28 @@ impl<'a> Interpreter<'a> { Ok(Value::Unit) } - fn evaluate_break(&mut self) -> IResult { + fn evaluate_break(&mut self, id: StmtId) -> IResult { if self.in_loop { Err(InterpreterError::Break) } else { - Err(InterpreterError::BreakNotInLoop) + let location = self.interner.statement_location(id); + Err(InterpreterError::BreakNotInLoop { location }) } } - fn evaluate_continue(&mut self) -> IResult { + fn evaluate_continue(&mut self, id: StmtId) -> IResult { if self.in_loop { Err(InterpreterError::Continue) } else { - Err(InterpreterError::ContinueNotInLoop) + let location = self.interner.statement_location(id); + Err(InterpreterError::ContinueNotInLoop { location }) } } - fn evaluate_comptime(&mut self, statement: StmtId) -> IResult { + pub(super) fn evaluate_comptime(&mut self, statement: StmtId) -> IResult { let was_in_comptime = std::mem::replace(&mut self.in_comptime_context, true); let result = self.evaluate_statement(statement); self.in_comptime_context = was_in_comptime; result } } - -impl Value { - fn get_type(&self) -> Cow { - Cow::Owned(match self { - Value::Unit => Type::Unit, - Value::Bool(_) => Type::Bool, - Value::Field(_) => Type::FieldElement, - Value::I8(_) => Type::Integer(Signedness::Signed, IntegerBitSize::Eight), - Value::I32(_) => Type::Integer(Signedness::Signed, IntegerBitSize::ThirtyTwo), - Value::I64(_) => Type::Integer(Signedness::Signed, IntegerBitSize::SixtyFour), - Value::U8(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), - Value::U32(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo), - Value::U64(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour), - Value::String(value) => { - let length = Type::Constant(value.len() as u64); - Type::String(Box::new(length)) - } - Value::Function(_, typ) => return Cow::Borrowed(typ), - Value::Closure(_, _, typ) => return Cow::Borrowed(typ), - Value::Tuple(fields) => { - Type::Tuple(vecmap(fields, |field| field.get_type().into_owned())) - } - Value::Struct(_, typ) => return Cow::Borrowed(typ), - Value::Array(_, typ) => return Cow::Borrowed(typ), - Value::Slice(_, typ) => return Cow::Borrowed(typ), - Value::Code(_) => Type::Code, - Value::Pointer(element) => { - let element = element.borrow().get_type().into_owned(); - Type::MutableReference(Box::new(element)) - } - }) - } -} diff --git a/compiler/noirc_frontend/src/hir/comptime/mod.rs b/compiler/noirc_frontend/src/hir/comptime/mod.rs index 83aaddaa405..26e05d675b3 100644 --- a/compiler/noirc_frontend/src/hir/comptime/mod.rs +++ b/compiler/noirc_frontend/src/hir/comptime/mod.rs @@ -1,3 +1,8 @@ +mod errors; mod hir_to_ast; mod interpreter; +mod scan; mod tests; +mod value; + +pub use interpreter::Interpreter; diff --git a/compiler/noirc_frontend/src/hir/comptime/scan.rs b/compiler/noirc_frontend/src/hir/comptime/scan.rs new file mode 100644 index 00000000000..d4fa355627f --- /dev/null +++ b/compiler/noirc_frontend/src/hir/comptime/scan.rs @@ -0,0 +1,207 @@ +//! This module is for the scanning of the Hir by the interpreter. +//! In this initial step, the Hir is scanned for `CompTime` nodes +//! without actually executing anything until such a node is found. +//! Once such a node is found, the interpreter will call the relevant +//! evaluate method on that node type, insert the result into the Ast, +//! and continue scanning the rest of the program. +//! +//! Since it mostly just needs to recur on the Hir looking for CompTime +//! nodes, this pass is fairly simple. The only thing it really needs to +//! ensure to do is to push and pop scopes on the interpreter as needed +//! so that any variables defined within e.g. an `if` statement containing +//! a `CompTime` block aren't accessible outside of the `if`. +use crate::{ + hir_def::{ + expr::{ + HirArrayLiteral, HirBlockExpression, HirCallExpression, HirConstructorExpression, + HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, + HirMethodCallExpression, + }, + stmt::HirForStatement, + }, + macros_api::{HirExpression, HirLiteral, HirStatement}, + node_interner::{ExprId, FuncId, StmtId}, +}; + +use super::{ + errors::{IResult, InterpreterError}, + interpreter::Interpreter, +}; + +#[allow(dead_code)] +impl<'interner> Interpreter<'interner> { + /// Scan through a function, evaluating any CompTime nodes found. + /// These nodes will be modified in place, replaced with the + /// result of their evaluation. + pub fn scan_function(&mut self, function: FuncId) -> IResult<()> { + let function = self.interner.function(&function); + + let state = self.enter_function(); + self.scan_expression(function.as_expr())?; + self.exit_function(state); + Ok(()) + } + + fn scan_expression(&mut self, expr: ExprId) -> IResult<()> { + match self.interner.expression(&expr) { + HirExpression::Ident(_) => Ok(()), + HirExpression::Literal(literal) => self.scan_literal(literal), + HirExpression::Block(block) => self.scan_block(block), + HirExpression::Prefix(prefix) => self.scan_expression(prefix.rhs), + HirExpression::Infix(infix) => self.scan_infix(infix), + HirExpression::Index(index) => self.scan_index(index), + HirExpression::Constructor(constructor) => self.scan_constructor(constructor), + HirExpression::MemberAccess(member_access) => self.scan_expression(member_access.lhs), + HirExpression::Call(call) => self.scan_call(call), + HirExpression::MethodCall(method_call) => self.scan_method_call(method_call), + HirExpression::Cast(cast) => self.scan_expression(cast.lhs), + HirExpression::If(if_) => self.scan_if(if_), + HirExpression::Tuple(tuple) => self.scan_tuple(tuple), + HirExpression::Lambda(lambda) => self.scan_lambda(lambda), + HirExpression::Comptime(block) => { + let location = self.interner.expr_location(&expr); + let new_expr = + self.evaluate_block(block)?.into_expression(self.interner, location)?; + let new_expr = self.interner.expression(&new_expr); + self.interner.replace_expr(&expr, new_expr); + Ok(()) + } + HirExpression::Quote(_) => { + // This error could be detected much earlier in the compiler pipeline but + // it just makes sense for the comptime code to handle comptime things. + let location = self.interner.expr_location(&expr); + Err(InterpreterError::QuoteInRuntimeCode { location }) + } + HirExpression::Error => Ok(()), + + // Unquote should only be inserted by the comptime interpreter while expanding macros + // and is removed by the Hir -> Ast conversion pass which converts it into a normal block. + // If we find one now during scanning it most likely means the Hir -> Ast conversion + // missed it somehow. In the future we may allow users to manually write unquote + // expressions in their code but for now this is unreachable. + HirExpression::Unquote(block) => { + unreachable!("Found unquote block while scanning: {block}") + } + } + } + + fn scan_literal(&mut self, literal: HirLiteral) -> IResult<()> { + match literal { + HirLiteral::Array(elements) | HirLiteral::Slice(elements) => match elements { + HirArrayLiteral::Standard(elements) => { + for element in elements { + self.scan_expression(element)?; + } + Ok(()) + } + HirArrayLiteral::Repeated { repeated_element, length: _ } => { + self.scan_expression(repeated_element) + } + }, + HirLiteral::Bool(_) + | HirLiteral::Integer(_, _) + | HirLiteral::Str(_) + | HirLiteral::FmtStr(_, _) + | HirLiteral::Unit => Ok(()), + } + } + + fn scan_block(&mut self, block: HirBlockExpression) -> IResult<()> { + self.push_scope(); + for statement in &block.statements { + self.scan_statement(*statement)?; + } + self.pop_scope(); + Ok(()) + } + + fn scan_infix(&mut self, infix: HirInfixExpression) -> IResult<()> { + self.scan_expression(infix.lhs)?; + self.scan_expression(infix.rhs) + } + + fn scan_index(&mut self, index: HirIndexExpression) -> IResult<()> { + self.scan_expression(index.collection)?; + self.scan_expression(index.index) + } + + fn scan_constructor(&mut self, constructor: HirConstructorExpression) -> IResult<()> { + for (_, field) in constructor.fields { + self.scan_expression(field)?; + } + Ok(()) + } + + fn scan_call(&mut self, call: HirCallExpression) -> IResult<()> { + self.scan_expression(call.func)?; + for arg in call.arguments { + self.scan_expression(arg)?; + } + Ok(()) + } + + fn scan_method_call(&mut self, method_call: HirMethodCallExpression) -> IResult<()> { + self.scan_expression(method_call.object)?; + for arg in method_call.arguments { + self.scan_expression(arg)?; + } + Ok(()) + } + + fn scan_if(&mut self, if_: HirIfExpression) -> IResult<()> { + self.scan_expression(if_.condition)?; + + self.push_scope(); + self.scan_expression(if_.consequence)?; + self.pop_scope(); + + if let Some(alternative) = if_.alternative { + self.push_scope(); + self.scan_expression(alternative)?; + self.pop_scope(); + } + Ok(()) + } + + fn scan_tuple(&mut self, tuple: Vec) -> IResult<()> { + for field in tuple { + self.scan_expression(field)?; + } + Ok(()) + } + + fn scan_lambda(&mut self, lambda: HirLambda) -> IResult<()> { + self.scan_expression(lambda.body) + } + + fn scan_statement(&mut self, statement: StmtId) -> IResult<()> { + match self.interner.statement(&statement) { + HirStatement::Let(let_) => self.scan_expression(let_.expression), + HirStatement::Constrain(constrain) => self.scan_expression(constrain.0), + HirStatement::Assign(assign) => self.scan_expression(assign.expression), + HirStatement::For(for_) => self.scan_for(for_), + HirStatement::Break => Ok(()), + HirStatement::Continue => Ok(()), + HirStatement::Expression(expression) => self.scan_expression(expression), + HirStatement::Semi(semi) => self.scan_expression(semi), + HirStatement::Error => Ok(()), + HirStatement::CompTime(comptime) => { + let location = self.interner.statement_location(comptime); + let new_expr = + self.evaluate_comptime(comptime)?.into_expression(self.interner, location)?; + self.interner.replace_statement(statement, HirStatement::Expression(new_expr)); + Ok(()) + } + } + } + + fn scan_for(&mut self, for_: HirForStatement) -> IResult<()> { + // We don't need to set self.in_loop since we're not actually evaluating this loop. + // We just need to push a scope so that if there's a `comptime { .. }` expr inside this + // loop, any variables it defines aren't accessible outside of it. + self.push_scope(); + self.scan_expression(for_.block)?; + self.pop_scope(); + Ok(()) + } +} diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs index 016e7079886..1a84dae4a87 100644 --- a/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -2,7 +2,9 @@ use noirc_errors::Location; -use super::interpreter::{Interpreter, InterpreterError, Value}; +use super::errors::InterpreterError; +use super::interpreter::Interpreter; +use super::value::Value; use crate::hir::type_check::test::type_check_src_code; fn interpret_helper(src: &str, func_namespace: Vec) -> Result { diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs new file mode 100644 index 00000000000..89102344b09 --- /dev/null +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -0,0 +1,170 @@ +use std::{borrow::Cow, rc::Rc}; + +use acvm::FieldElement; +use im::Vector; +use iter_extended::{try_vecmap, vecmap}; +use noirc_errors::Location; + +use crate::{ + ast::{BlockExpression, Ident, IntegerBitSize, Signedness}, + hir_def::expr::{HirArrayLiteral, HirConstructorExpression, HirIdent, HirLambda, ImplKind}, + macros_api::{HirExpression, HirLiteral, NodeInterner}, + node_interner::{ExprId, FuncId}, + Shared, Type, +}; +use rustc_hash::FxHashMap as HashMap; + +use super::errors::{IResult, InterpreterError}; + +#[allow(unused)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Value { + Unit, + Bool(bool), + Field(FieldElement), + I8(i8), + I32(i32), + I64(i64), + U8(u8), + U32(u32), + U64(u64), + String(Rc), + Function(FuncId, Type), + Closure(HirLambda, Vec, Type), + Tuple(Vec), + Struct(HashMap, Value>, Type), + Pointer(Shared), + Array(Vector, Type), + Slice(Vector, Type), + Code(Rc), +} + +impl Value { + pub(crate) fn get_type(&self) -> Cow { + Cow::Owned(match self { + Value::Unit => Type::Unit, + Value::Bool(_) => Type::Bool, + Value::Field(_) => Type::FieldElement, + Value::I8(_) => Type::Integer(Signedness::Signed, IntegerBitSize::Eight), + Value::I32(_) => Type::Integer(Signedness::Signed, IntegerBitSize::ThirtyTwo), + Value::I64(_) => Type::Integer(Signedness::Signed, IntegerBitSize::SixtyFour), + Value::U8(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + Value::U32(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo), + Value::U64(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour), + Value::String(value) => { + let length = Type::Constant(value.len() as u64); + Type::String(Box::new(length)) + } + Value::Function(_, typ) => return Cow::Borrowed(typ), + Value::Closure(_, _, typ) => return Cow::Borrowed(typ), + Value::Tuple(fields) => { + Type::Tuple(vecmap(fields, |field| field.get_type().into_owned())) + } + Value::Struct(_, typ) => return Cow::Borrowed(typ), + Value::Array(_, typ) => return Cow::Borrowed(typ), + Value::Slice(_, typ) => return Cow::Borrowed(typ), + Value::Code(_) => Type::Code, + Value::Pointer(element) => { + let element = element.borrow().get_type().into_owned(); + Type::MutableReference(Box::new(element)) + } + }) + } + + pub(crate) fn into_expression( + self, + interner: &mut NodeInterner, + location: Location, + ) -> IResult { + let typ = self.get_type().into_owned(); + + let expression = match self { + Value::Unit => HirExpression::Literal(HirLiteral::Unit), + Value::Bool(value) => HirExpression::Literal(HirLiteral::Bool(value)), + Value::Field(value) => HirExpression::Literal(HirLiteral::Integer(value, false)), + Value::I8(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + HirExpression::Literal(HirLiteral::Integer(value, negative)) + } + Value::I32(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + HirExpression::Literal(HirLiteral::Integer(value, negative)) + } + Value::I64(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + HirExpression::Literal(HirLiteral::Integer(value, negative)) + } + Value::U8(value) => { + HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) + } + Value::U32(value) => { + HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) + } + Value::U64(value) => { + HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) + } + Value::String(value) => HirExpression::Literal(HirLiteral::Str(unwrap_rc(value))), + Value::Function(id, _typ) => { + let id = interner.function_definition_id(id); + let impl_kind = ImplKind::NotATraitMethod; + HirExpression::Ident(HirIdent { location, id, impl_kind }) + } + Value::Closure(_lambda, _env, _typ) => { + // TODO: How should a closure's environment be inlined? + let item = "returning closures from a comptime fn"; + return Err(InterpreterError::Unimplemented { item, location }); + } + Value::Tuple(fields) => { + let fields = try_vecmap(fields, |field| field.into_expression(interner, location))?; + HirExpression::Tuple(fields) + } + Value::Struct(fields, typ) => { + let fields = try_vecmap(fields, |(name, field)| { + let field = field.into_expression(interner, location)?; + Ok((Ident::new(unwrap_rc(name), location.span), field)) + })?; + + let (r#type, struct_generics) = match typ.follow_bindings() { + Type::Struct(def, generics) => (def, generics), + _ => return Err(InterpreterError::NonStructInConstructor { typ, location }), + }; + + HirExpression::Constructor(HirConstructorExpression { + r#type, + struct_generics, + fields, + }) + } + Value::Array(elements, _) => { + let elements = + try_vecmap(elements, |elements| elements.into_expression(interner, location))?; + HirExpression::Literal(HirLiteral::Array(HirArrayLiteral::Standard(elements))) + } + Value::Slice(elements, _) => { + let elements = + try_vecmap(elements, |elements| elements.into_expression(interner, location))?; + HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements))) + } + Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)), + Value::Pointer(_) => { + return Err(InterpreterError::CannotInlineMacro { value: self, location }) + } + }; + + let id = interner.push_expr(expression); + interner.push_expr_location(id, location.span, location.file); + interner.push_expr_type(id, typ); + Ok(id) + } +} + +/// Unwraps an Rc value without cloning the inner value if the reference count is 1. Clones otherwise. +fn unwrap_rc(rc: Rc) -> T { + Rc::try_unwrap(rc).unwrap_or_else(|rc| (*rc).clone()) +} diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 4c6b5ab5885..7805f36cdb2 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -1,6 +1,7 @@ use super::dc_mod::collect_defs; use super::errors::{DefCollectorErrorKind, DuplicateType}; use crate::graph::CrateId; +use crate::hir::comptime::Interpreter; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleId}; use crate::hir::resolution::errors::ResolverError; @@ -30,6 +31,15 @@ use std::collections::{BTreeMap, HashMap}; use std::vec; +#[derive(Default)] +pub struct ResolvedModule { + pub globals: Vec<(FileId, GlobalId)>, + pub functions: Vec<(FileId, FuncId)>, + pub trait_impl_functions: Vec<(FileId, FuncId)>, + + pub errors: Vec<(CompilationError, FileId)>, +} + /// Stores all of the unresolved functions in a particular file/mod #[derive(Clone)] pub struct UnresolvedFunctions { @@ -304,6 +314,8 @@ impl DefCollector { } } + let mut resolved_module = ResolvedModule { errors, ..Default::default() }; + // We must first resolve and intern the globals before we can resolve any stmts inside each function. // Each function uses its own resolver with a newly created ScopeForest, and must be resolved again to be within a function's scope // @@ -312,21 +324,29 @@ impl DefCollector { let (literal_globals, other_globals) = filter_literal_globals(def_collector.collected_globals); - let mut resolved_globals = resolve_globals(context, literal_globals, crate_id); + resolved_module.resolve_globals(context, literal_globals, crate_id); - errors.extend(resolve_type_aliases( + resolved_module.errors.extend(resolve_type_aliases( context, def_collector.collected_type_aliases, crate_id, )); - errors.extend(resolve_traits(context, def_collector.collected_traits, crate_id)); + resolved_module.errors.extend(resolve_traits( + context, + def_collector.collected_traits, + crate_id, + )); // Must resolve structs before we resolve globals. - errors.extend(resolve_structs(context, def_collector.collected_types, crate_id)); + resolved_module.errors.extend(resolve_structs( + context, + def_collector.collected_types, + crate_id, + )); // Bind trait impls to their trait. Collect trait functions, that have a // default implementation, which hasn't been overridden. - errors.extend(collect_trait_impls( + resolved_module.errors.extend(collect_trait_impls( context, crate_id, &mut def_collector.collected_traits_impls, @@ -339,54 +359,55 @@ impl DefCollector { // // These are resolved after trait impls so that struct methods are chosen // over trait methods if there are name conflicts. - errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls)); + resolved_module.errors.extend(collect_impls( + context, + crate_id, + &def_collector.collected_impls, + )); // We must wait to resolve non-integer globals until after we resolve structs since struct // globals will need to reference the struct type they're initialized to to ensure they are valid. - resolved_globals.extend(resolve_globals(context, other_globals, crate_id)); - errors.extend(resolved_globals.errors); + resolved_module.resolve_globals(context, other_globals, crate_id); // Resolve each function in the crate. This is now possible since imports have been resolved - let mut functions = Vec::new(); - functions.extend(resolve_free_functions( + resolved_module.functions = resolve_free_functions( &mut context.def_interner, crate_id, &context.def_maps, def_collector.collected_functions, None, - &mut errors, - )); + &mut resolved_module.errors, + ); - functions.extend(resolve_impls( + resolved_module.functions.extend(resolve_impls( &mut context.def_interner, crate_id, &context.def_maps, def_collector.collected_impls, - &mut errors, + &mut resolved_module.errors, )); - let impl_functions = resolve_trait_impls( + resolved_module.trait_impl_functions = resolve_trait_impls( context, def_collector.collected_traits_impls, crate_id, - &mut errors, + &mut resolved_module.errors, ); for macro_processor in macro_processors { macro_processor.process_typed_ast(&crate_id, context).unwrap_or_else( |(macro_err, file_id)| { - errors.push((macro_err.into(), file_id)); + resolved_module.errors.push((macro_err.into(), file_id)); }, ); } - errors.extend(context.def_interner.check_for_dependency_cycles()); + resolved_module.errors.extend(context.def_interner.check_for_dependency_cycles()); - errors.extend(type_check_globals(&mut context.def_interner, resolved_globals.globals)); - errors.extend(type_check_functions(&mut context.def_interner, functions)); - errors.extend(type_check_trait_impl_signatures(&mut context.def_interner, &impl_functions)); - errors.extend(type_check_functions(&mut context.def_interner, impl_functions)); - errors + resolved_module.type_check(context); + resolved_module.evaluate_comptime(&mut context.def_interner); + + resolved_module.errors } } @@ -444,48 +465,59 @@ fn filter_literal_globals( }) } -fn type_check_globals( - interner: &mut NodeInterner, - global_ids: Vec<(FileId, GlobalId)>, -) -> Vec<(CompilationError, fm::FileId)> { - global_ids - .into_iter() - .flat_map(|(file_id, global_id)| { - TypeChecker::check_global(global_id, interner) - .iter() - .cloned() - .map(|e| (e.into(), file_id)) - .collect::>() - }) - .collect() -} +impl ResolvedModule { + fn type_check(&mut self, context: &mut Context) { + self.type_check_globals(&mut context.def_interner); + self.type_check_functions(&mut context.def_interner); + self.type_check_trait_impl_function(&mut context.def_interner); + } -fn type_check_functions( - interner: &mut NodeInterner, - file_func_ids: Vec<(FileId, FuncId)>, -) -> Vec<(CompilationError, fm::FileId)> { - file_func_ids - .into_iter() - .flat_map(|(file, func)| { - type_check_func(interner, func) - .into_iter() - .map(|e| (e.into(), file)) - .collect::>() - }) - .collect() -} + fn type_check_globals(&mut self, interner: &mut NodeInterner) { + for (file_id, global_id) in self.globals.iter() { + for error in TypeChecker::check_global(*global_id, interner) { + self.errors.push((error.into(), *file_id)); + } + } + } + + fn type_check_functions(&mut self, interner: &mut NodeInterner) { + for (file, func) in self.functions.iter() { + for error in type_check_func(interner, *func) { + self.errors.push((error.into(), *file)); + } + } + } + + fn type_check_trait_impl_function(&mut self, interner: &mut NodeInterner) { + for (file, func) in self.trait_impl_functions.iter() { + for error in check_trait_impl_method_matches_declaration(interner, *func) { + self.errors.push((error.into(), *file)); + } + for error in type_check_func(interner, *func) { + self.errors.push((error.into(), *file)); + } + } + } + + /// Evaluate all `comptime` expressions in this module + fn evaluate_comptime(&self, interner: &mut NodeInterner) { + let mut interpreter = Interpreter::new(interner); + + for (_file, function) in &self.functions { + // .unwrap() is temporary here until we can convert + // from InterpreterError to (CompilationError, FileId) + interpreter.scan_function(*function).unwrap(); + } + } -fn type_check_trait_impl_signatures( - interner: &mut NodeInterner, - file_func_ids: &[(FileId, FuncId)], -) -> Vec<(CompilationError, fm::FileId)> { - file_func_ids - .iter() - .flat_map(|(file, func)| { - check_trait_impl_method_matches_declaration(interner, *func) - .into_iter() - .map(|e| (e.into(), *file)) - .collect::>() - }) - .collect() + fn resolve_globals( + &mut self, + context: &mut Context, + literal_globals: Vec, + crate_id: CrateId, + ) { + let globals = resolve_globals(context, literal_globals, crate_id); + self.globals.extend(globals.globals); + self.errors.extend(globals.errors); + } } diff --git a/compiler/noirc_frontend/src/hir/resolution/globals.rs b/compiler/noirc_frontend/src/hir/resolution/globals.rs index 9fb31271727..bcda4e75d3d 100644 --- a/compiler/noirc_frontend/src/hir/resolution/globals.rs +++ b/compiler/noirc_frontend/src/hir/resolution/globals.rs @@ -11,18 +11,12 @@ use crate::{ use fm::FileId; use iter_extended::vecmap; +#[derive(Default)] pub(crate) struct ResolvedGlobals { pub(crate) globals: Vec<(FileId, GlobalId)>, pub(crate) errors: Vec<(CompilationError, FileId)>, } -impl ResolvedGlobals { - pub(crate) fn extend(&mut self, oth: Self) { - self.globals.extend(oth.globals); - self.errors.extend(oth.errors); - } -} - pub(crate) fn resolve_globals( context: &mut Context, globals: Vec, diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 0e69b3bdeba..8d2cb17189b 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -1272,9 +1272,9 @@ impl<'a> Resolver<'a> { HirStatement::Continue } StatementKind::Error => HirStatement::Error, - StatementKind::Comptime(statement) => { + StatementKind::CompTime(statement) => { let statement = self.resolve_stmt(*statement, span); - HirStatement::Comptime(self.interner.push_stmt(statement)) + HirStatement::CompTime(self.interner.push_stmt(statement)) } } } @@ -1323,7 +1323,9 @@ impl<'a> Resolver<'a> { pub fn intern_stmt(&mut self, stmt: Statement) -> StmtId { let hir_stmt = self.resolve_stmt(stmt.kind, stmt.span); - self.interner.push_stmt(hir_stmt) + let id = self.interner.push_stmt(hir_stmt); + self.interner.push_statement_location(id, stmt.span, self.file); + id } fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue { @@ -1531,7 +1533,9 @@ impl<'a> Resolver<'a> { collection: self.resolve_expression(indexed_expr.collection), index: self.resolve_expression(indexed_expr.index), }), - ExpressionKind::Block(block_expr) => self.resolve_block(block_expr), + ExpressionKind::Block(block_expr) => { + HirExpression::Block(self.resolve_block(block_expr)) + } ExpressionKind::Constructor(constructor) => { let span = constructor.type_name.span(); @@ -1598,6 +1602,7 @@ impl<'a> Resolver<'a> { // The quoted expression isn't resolved since we don't want errors if variables aren't defined ExpressionKind::Quote(block) => HirExpression::Quote(block), + ExpressionKind::Comptime(block) => HirExpression::Comptime(self.resolve_block(block)), }; // If these lines are ever changed, make sure to change the early return @@ -1930,14 +1935,14 @@ impl<'a> Resolver<'a> { Ok(path_resolution.module_def_id) } - fn resolve_block(&mut self, block_expr: BlockExpression) -> HirExpression { + fn resolve_block(&mut self, block_expr: BlockExpression) -> HirBlockExpression { let statements = self.in_new_scope(|this| vecmap(block_expr.statements, |stmt| this.intern_stmt(stmt))); - HirExpression::Block(HirBlockExpression { statements }) + HirBlockExpression { statements } } pub fn intern_block(&mut self, block: BlockExpression) -> ExprId { - let hir_block = self.resolve_block(block); + let hir_block = HirExpression::Block(self.resolve_block(block)); self.interner.push_expr(hir_block) } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 62330732be4..0bc7673e105 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -6,8 +6,8 @@ use crate::{ hir::{resolution::resolver::verify_mutable_reference, type_check::errors::Source}, hir_def::{ expr::{ - self, HirArrayLiteral, HirBinaryOp, HirExpression, HirIdent, HirLiteral, - HirMethodCallExpression, HirMethodReference, HirPrefixExpression, ImplKind, + self, HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirExpression, HirIdent, + HirLiteral, HirMethodCallExpression, HirMethodReference, HirPrefixExpression, ImplKind, }, types::Type, }, @@ -271,34 +271,7 @@ impl<'interner> TypeChecker<'interner> { let span = self.interner.expr_span(expr_id); self.check_cast(lhs_type, cast_expr.r#type, span) } - HirExpression::Block(block_expr) => { - let mut block_type = Type::Unit; - - let statements = block_expr.statements(); - for (i, stmt) in statements.iter().enumerate() { - let expr_type = self.check_statement(stmt); - - if let crate::hir_def::stmt::HirStatement::Semi(expr) = - self.interner.statement(stmt) - { - let inner_expr_type = self.interner.id_type(expr); - let span = self.interner.expr_span(&expr); - - self.unify(&inner_expr_type, &Type::Unit, || { - TypeCheckError::UnusedResultError { - expr_type: inner_expr_type.clone(), - expr_span: span, - } - }); - } - - if i + 1 == statements.len() { - block_type = expr_type; - } - } - - block_type - } + HirExpression::Block(block_expr) => self.check_block(block_expr), HirExpression::Prefix(prefix_expr) => { let rhs_type = self.check_expression(&prefix_expr.rhs); let span = self.interner.expr_span(&prefix_expr.rhs); @@ -336,12 +309,44 @@ impl<'interner> TypeChecker<'interner> { Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } HirExpression::Quote(_) => Type::Code, + HirExpression::Comptime(block) => self.check_block(block), + + // Unquote should be inserted & removed by the comptime interpreter. + // Even if we allowed it here, we wouldn't know what type to give to the result. + HirExpression::Unquote(block) => { + unreachable!("Unquote remaining during type checking {block}") + } }; self.interner.push_expr_type(*expr_id, typ.clone()); typ } + fn check_block(&mut self, block: HirBlockExpression) -> Type { + let mut block_type = Type::Unit; + + let statements = block.statements(); + for (i, stmt) in statements.iter().enumerate() { + let expr_type = self.check_statement(stmt); + + if let crate::hir_def::stmt::HirStatement::Semi(expr) = self.interner.statement(stmt) { + let inner_expr_type = self.interner.id_type(expr); + let span = self.interner.expr_span(&expr); + + self.unify(&inner_expr_type, &Type::Unit, || TypeCheckError::UnusedResultError { + expr_type: inner_expr_type.clone(), + expr_span: span, + }); + } + + if i + 1 == statements.len() { + block_type = expr_type; + } + } + + block_type + } + /// Returns the type of the given identifier fn check_ident(&mut self, ident: HirIdent, expr_id: &ExprId) -> Type { let mut bindings = TypeBindings::new(); diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index f5f6e1e8180..064fefc8ae9 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -51,7 +51,7 @@ impl<'interner> TypeChecker<'interner> { HirStatement::Constrain(constrain_stmt) => self.check_constrain_stmt(constrain_stmt), HirStatement::Assign(assign_stmt) => self.check_assign_stmt(assign_stmt, stmt_id), HirStatement::For(for_loop) => self.check_for_loop(for_loop), - HirStatement::Comptime(statement) => return self.check_statement(&statement), + HirStatement::CompTime(statement) => return self.check_statement(&statement), HirStatement::Break | HirStatement::Continue | HirStatement::Error => (), } Type::Unit diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index d88b65d1fce..bf7d9b7b4ba 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -31,8 +31,10 @@ pub enum HirExpression { If(HirIfExpression), Tuple(Vec), Lambda(HirLambda), - Error, Quote(crate::ast::BlockExpression), + Unquote(crate::ast::BlockExpression), + Comptime(HirBlockExpression), + Error, } impl HirExpression { diff --git a/compiler/noirc_frontend/src/hir_def/stmt.rs b/compiler/noirc_frontend/src/hir_def/stmt.rs index 7e22e5ee9c0..48e7d7344e3 100644 --- a/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -20,7 +20,7 @@ pub enum HirStatement { Continue, Expression(ExprId), Semi(ExprId), - Comptime(StmtId), + CompTime(StmtId), Error, } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 74a0dd855c0..d92b6c65d7a 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -521,6 +521,12 @@ impl<'interner> Monomorphizer<'interner> { } HirExpression::Error => unreachable!("Encountered Error node during monomorphization"), HirExpression::Quote(_) => unreachable!("quote expression remaining in runtime code"), + HirExpression::Unquote(_) => { + unreachable!("unquote expression remaining in runtime code") + } + HirExpression::Comptime(_) => { + unreachable!("comptime expression remaining in runtime code") + } }; Ok(expr) @@ -626,7 +632,7 @@ impl<'interner> Monomorphizer<'interner> { HirStatement::Error => unreachable!(), // All `comptime` statements & expressions should be removed before runtime. - HirStatement::Comptime(_) => unreachable!("comptime statement in runtime code"), + HirStatement::CompTime(_) => unreachable!("comptime statement in runtime code"), } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index b0e68be4868..9d3a79820dc 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -921,10 +921,18 @@ impl NodeInterner { self.id_location(expr_id) } - pub fn statement_span(&self, stmt_id: &StmtId) -> Span { + pub fn statement_span(&self, stmt_id: StmtId) -> Span { self.id_location(stmt_id).span } + pub fn statement_location(&self, stmt_id: StmtId) -> Location { + self.id_location(stmt_id) + } + + pub fn push_statement_location(&mut self, id: StmtId, span: Span, file: FileId) { + self.id_to_location.insert(id.into(), Location::new(span, file)); + } + pub fn get_struct(&self, id: StructId) -> Shared { self.structs[&id].clone() } diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 603193d1593..858e5c4838c 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -540,7 +540,7 @@ where StatementKind::Expression(Expression::new(ExpressionKind::Block(block), span)) }), ))) - .map(|statement| StatementKind::Comptime(Box::new(statement))) + .map(|statement| StatementKind::CompTime(Box::new(statement))) } /// Comptime in an expression position only accepts entire blocks @@ -548,7 +548,7 @@ fn comptime_expr<'a, S>(statement: S) -> impl NoirParser + 'a where S: NoirParser + 'a, { - keyword(Keyword::CompTime).ignore_then(block(statement)).map(ExpressionKind::Block) + keyword(Keyword::CompTime).ignore_then(block(statement)).map(ExpressionKind::Comptime) } fn declaration<'a, P>(expr_parser: P) -> impl NoirParser + 'a diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 31bf2245b1f..ac3d7bbc4cc 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -780,7 +780,7 @@ mod test { HirStatement::Error => panic!("Invalid HirStatement!"), HirStatement::Break => panic!("Unexpected break"), HirStatement::Continue => panic!("Unexpected continue"), - HirStatement::Comptime(_) => panic!("Unexpected comptime"), + HirStatement::CompTime(_) => panic!("Unexpected comptime"), }; let expr = interner.expression(&expr_id); diff --git a/tooling/nargo_fmt/src/rewrite/expr.rs b/tooling/nargo_fmt/src/rewrite/expr.rs index e4616c99aaa..6b7dca6c5c7 100644 --- a/tooling/nargo_fmt/src/rewrite/expr.rs +++ b/tooling/nargo_fmt/src/rewrite/expr.rs @@ -159,6 +159,9 @@ pub(crate) fn rewrite( } ExpressionKind::Lambda(_) | ExpressionKind::Variable(_) => visitor.slice(span).to_string(), ExpressionKind::Quote(block) => format!("quote {}", rewrite_block(visitor, block, span)), + ExpressionKind::Comptime(block) => { + format!("comptime {}", rewrite_block(visitor, block, span)) + } ExpressionKind::Error => unreachable!(), } } diff --git a/tooling/nargo_fmt/src/visitor/stmt.rs b/tooling/nargo_fmt/src/visitor/stmt.rs index e41827c94a1..869977d5f3c 100644 --- a/tooling/nargo_fmt/src/visitor/stmt.rs +++ b/tooling/nargo_fmt/src/visitor/stmt.rs @@ -103,7 +103,7 @@ impl super::FmtVisitor<'_> { StatementKind::Error => unreachable!(), StatementKind::Break => self.push_rewrite("break;".into(), span), StatementKind::Continue => self.push_rewrite("continue;".into(), span), - StatementKind::Comptime(statement) => self.visit_stmt(*statement, span, is_last), + StatementKind::CompTime(statement) => self.visit_stmt(*statement, span, is_last), } } }