Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Run macros within comptime contexts #5576

Merged
merged 7 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions compiler/noirc_frontend/src/elaborator/comptime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::mem::replace;

use crate::{
hir_def::expr::HirIdent,
macros_api::Expression,
node_interner::{DependencyId, ExprId, FuncId},
};

use super::{Elaborator, FunctionContext, ResolverMeta};

impl<'context> Elaborator<'context> {
/// Elaborate an expression from the middle of a comptime scope.
/// When this happens we require additional information to know
/// what variables should be in scope.
pub fn elaborate_expression_from_comptime(
&mut self,
expr: Expression,
function: Option<FuncId>,
) -> ExprId {
self.function_context.push(FunctionContext::default());
let old_scope = self.scopes.end_function();
self.scopes.start_function();
let function_id = function.map(DependencyId::Function);
let old_item = replace(&mut self.current_item, function_id);

// Note: recover_generics isn't good enough here because any existing generics
// should not be in scope of this new function
let old_generics = std::mem::take(&mut self.generics);

let old_crate_and_module = function.map(|function| {
let meta = self.interner.function_meta(&function);
let old_crate = replace(&mut self.crate_id, meta.source_crate);
let old_module = replace(&mut self.local_module, meta.source_module);
self.introduce_generics_into_scope(meta.all_generics.clone());
(old_crate, old_module)
});

self.populate_scope_from_comptime_scopes();
let expr = self.elaborate_expression(expr).0;

if let Some((old_crate, old_module)) = old_crate_and_module {
self.crate_id = old_crate;
self.local_module = old_module;
}

self.generics = old_generics;
self.current_item = old_item;
self.scopes.end_function();
self.scopes.0.push(old_scope);
self.check_and_pop_function_context();
expr
}

fn populate_scope_from_comptime_scopes(&mut self) {
// Take the comptime scope to be our runtime scope.
// Iterate from global scope to the most local scope so that the
// later definitions will naturally shadow the former.
for scope in &self.comptime_scopes {
for definition_id in scope.keys() {
let definition = self.interner.definition(*definition_id);
let name = definition.name.clone();
let location = definition.location;

let scope = self.scopes.get_mut_scope();
let ident = HirIdent::non_trait_method(*definition_id, location);
let meta = ResolverMeta { ident, num_times_used: 0, warn_if_unused: false };
scope.add_key_value(name.clone(), meta);
}
}
}
}
28 changes: 18 additions & 10 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,21 @@ impl<'context> Elaborator<'context> {
}

let location = Location::new(span, self.file);
let hir_call = HirCallExpression { func, arguments, location };
let typ = self.type_check_call(&hir_call, func_type, args, span);
let is_macro_call = call.is_macro_call;
let hir_call = HirCallExpression { func, arguments, location, is_macro_call };
let mut typ = self.type_check_call(&hir_call, func_type, args, span);

if call.is_macro_call {
self.call_macro(func, comptime_args, location, typ)
.unwrap_or_else(|| (HirExpression::Error, Type::Error))
} else {
(HirExpression::Call(hir_call), typ)
if is_macro_call {
if self.in_comptime_context() {
typ = self.interner.next_type_variable();
} else {
return self
.call_macro(func, comptime_args, location, typ)
.unwrap_or_else(|| (HirExpression::Error, Type::Error));
}
}

(HirExpression::Call(hir_call), typ)
}

fn elaborate_method_call(
Expand Down Expand Up @@ -368,6 +374,7 @@ impl<'context> Elaborator<'context> {
let location = Location::new(span, self.file);
let method = method_call.method_name;
let turbofish_generics = generics.clone();
let is_macro_call = method_call.is_macro_call;
let method_call =
HirMethodCallExpression { method, object, arguments, location, generics };

Expand All @@ -377,6 +384,7 @@ impl<'context> Elaborator<'context> {
let ((function_id, function_name), function_call) = method_call.into_function_call(
&method_ref,
object_type,
is_macro_call,
location,
self.interner,
);
Expand Down Expand Up @@ -721,7 +729,7 @@ impl<'context> Elaborator<'context> {
(id, typ)
}

pub(super) fn inline_comptime_value(
pub fn inline_comptime_value(
&mut self,
value: Result<comptime::Value, InterpreterError>,
span: Span,
Expand Down Expand Up @@ -801,14 +809,14 @@ impl<'context> Elaborator<'context> {
for argument in arguments {
match interpreter.evaluate(argument) {
Ok(arg) => {
let location = interpreter.interner.expr_location(&argument);
let location = interpreter.elaborator.interner.expr_location(&argument);
comptime_args.push((arg, location));
}
Err(error) => errors.push((error.into(), file)),
}
}

let bindings = interpreter.interner.get_instantiation_bindings(func).clone();
let bindings = interpreter.elaborator.interner.get_instantiation_bindings(func).clone();
let result = interpreter.call_function(function, comptime_args, bindings, location);

if !errors.is_empty() {
Expand Down
57 changes: 38 additions & 19 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
use crate::{
ast::{FunctionKind, UnresolvedTraitConstraint},
hir::{
comptime::{self, Interpreter, InterpreterError, Value},
comptime::{Interpreter, InterpreterError, Value},
def_collector::{
dc_crate::{
filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal,
Expand Down Expand Up @@ -60,6 +60,7 @@ use crate::{
macros_api::ItemVisibility,
};

mod comptime;
mod expressions;
mod lints;
mod patterns;
Expand Down Expand Up @@ -97,9 +98,9 @@ pub struct LambdaContext {
pub struct Elaborator<'context> {
scopes: ScopeForest,

errors: Vec<(CompilationError, FileId)>,
pub(crate) errors: Vec<(CompilationError, FileId)>,

interner: &'context mut NodeInterner,
pub(crate) interner: &'context mut NodeInterner,

def_maps: &'context mut BTreeMap<CrateId, CrateDefMap>,

Expand Down Expand Up @@ -167,7 +168,7 @@ pub struct Elaborator<'context> {
/// Each value currently in scope in the comptime interpreter.
/// Each element of the Vec represents a scope with every scope together making
/// up all currently visible definitions. The first scope is always the global scope.
comptime_scopes: Vec<HashMap<DefinitionId, comptime::Value>>,
pub(crate) comptime_scopes: Vec<HashMap<DefinitionId, Value>>,

/// The scope of --debug-comptime, or None if unset
debug_comptime_in_file: Option<FileId>,
Expand Down Expand Up @@ -228,6 +229,15 @@ impl<'context> Elaborator<'context> {
items: CollectedItems,
debug_comptime_in_file: Option<FileId>,
) -> Vec<(CompilationError, FileId)> {
Self::elaborate_and_return_self(context, crate_id, items, debug_comptime_in_file).errors
}

pub fn elaborate_and_return_self(
context: &'context mut Context,
crate_id: CrateId,
items: CollectedItems,
debug_comptime_in_file: Option<FileId>,
) -> Self {
let mut this = Self::new(context, crate_id, debug_comptime_in_file);

// Filter out comptime items to execute their functions first if needed.
Expand All @@ -238,7 +248,7 @@ impl<'context> Elaborator<'context> {
let (comptime_items, runtime_items) = Self::filter_comptime_items(items);
this.elaborate_items(comptime_items);
this.elaborate_items(runtime_items);
this.errors
this
}

fn elaborate_items(&mut self, mut items: CollectedItems) {
Expand Down Expand Up @@ -339,6 +349,21 @@ impl<'context> Elaborator<'context> {
self.trait_id = None;
}

fn introduce_generics_into_scope(&mut self, all_generics: Vec<ResolvedGeneric>) {
// Introduce all numeric generics into scope
for generic in &all_generics {
if let Kind::Numeric(typ) = &generic.kind {
let definition = DefinitionKind::GenericType(generic.type_var.clone());
let ident = Ident::new(generic.name.to_string(), generic.span);
let hir_ident =
self.add_variable_decl_inner(ident, false, false, false, definition);
self.interner.push_definition_type(hir_ident.id, *typ.clone());
}
}

self.generics = all_generics;
}

fn elaborate_function(&mut self, id: FuncId) {
let func_meta = self.interner.func_meta.get_mut(&id);
let func_meta =
Expand All @@ -360,16 +385,7 @@ impl<'context> Elaborator<'context> {
self.trait_bounds = func_meta.trait_constraints.clone();
self.function_context.push(FunctionContext::default());

// Introduce all numeric generics into scope
for generic in &func_meta.all_generics {
if let Kind::Numeric(typ) = &generic.kind {
let definition = DefinitionKind::GenericType(generic.type_var.clone());
let ident = Ident::new(generic.name.to_string(), generic.span);
let hir_ident =
self.add_variable_decl_inner(ident, false, false, false, definition);
self.interner.push_definition_type(hir_ident.id, *typ.clone());
}
}
self.introduce_generics_into_scope(func_meta.all_generics.clone());

// The DefinitionIds for each parameter were already created in define_function_meta
// so we need to reintroduce the same IDs into scope here.
Expand All @@ -378,8 +394,6 @@ impl<'context> Elaborator<'context> {
self.add_existing_variable_to_scope(name, parameter.clone(), true);
}

self.generics = func_meta.all_generics.clone();

self.declare_numeric_generics(&func_meta.parameters, func_meta.return_type());
self.add_trait_constraints_to_scope(&func_meta);

Expand Down Expand Up @@ -758,6 +772,7 @@ impl<'context> Elaborator<'context> {
is_trait_function,
has_inline_attribute,
source_crate: self.crate_id,
source_module: self.local_module,
function_body: FunctionBody::Unresolved(func.kind, body, func.def.span),
};

Expand Down Expand Up @@ -1626,8 +1641,12 @@ impl<'context> Elaborator<'context> {
}
}

fn setup_interpreter(&mut self) -> Interpreter {
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id)
pub fn setup_interpreter<'local>(&'local mut self) -> Interpreter<'local, 'context> {
let current_function = match self.current_item {
Some(DependencyId::Function(function)) => Some(function),
_ => None,
};
Interpreter::new(self, self.crate_id, current_function)
}

fn debug_comptime<T: Display, F: FnMut(&mut NodeInterner) -> T>(
Expand Down
4 changes: 1 addition & 3 deletions compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use rustc_hash::FxHashSet as HashSet;
use crate::{
ast::{UnresolvedType, ERROR_IDENT},
hir::{
comptime::Interpreter,
def_collector::dc_crate::CompilationError,
resolution::errors::ResolverError,
type_check::{Source, TypeCheckError},
Expand Down Expand Up @@ -460,8 +459,7 @@ impl<'context> Elaborator<'context> {
// Comptime variables must be replaced with their values
if let Some(definition) = self.interner.try_definition(definition_id) {
if definition.comptime && !self.in_comptime_context() {
let mut interpreter =
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id);
let mut interpreter = self.setup_interpreter();
let value = interpreter.evaluate(id);
return self.inline_comptime_value(value, span);
}
Expand Down
Loading
Loading