Skip to content

Commit

Permalink
feat(traits): Implement trait bounds typechecker + monomorphizer pass…
Browse files Browse the repository at this point in the history
…es (#2717)

Co-authored-by: jfecher <jfecher11@gmail.com>
  • Loading branch information
alexvitkov and jfecher authored Sep 25, 2023
1 parent df7b42c commit 5ca99b1
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 48 deletions.
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ impl<'a> Resolver<'a> {
) -> Vec<TraitConstraint> {
vecmap(where_clause, |constraint| TraitConstraint {
typ: self.resolve_type(constraint.typ.clone()),
trait_id: constraint.trait_bound.trait_id,
trait_id: constraint.trait_bound.trait_id.unwrap_or_else(TraitId::dummy_id),
})
}

Expand Down
117 changes: 81 additions & 36 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use crate::{
hir_def::{
expr::{
self, HirArrayLiteral, HirBinaryOp, HirExpression, HirLiteral, HirMethodCallExpression,
HirPrefixExpression,
HirMethodReference, HirPrefixExpression,
},
types::Type,
},
node_interner::{DefinitionKind, ExprId, FuncId},
node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId},
Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp,
};

Expand Down Expand Up @@ -144,7 +144,7 @@ impl<'interner> TypeChecker<'interner> {
let object_type = self.check_expression(&method_call.object).follow_bindings();
let method_name = method_call.method.0.contents.as_str();
match self.lookup_method(&object_type, method_name, expr_id) {
Some(method_id) => {
Some(method_ref) => {
let mut args = vec![(
object_type,
method_call.object,
Expand All @@ -160,22 +160,27 @@ impl<'interner> TypeChecker<'interner> {
// so that the backend doesn't need to worry about methods
let location = method_call.location;

// Automatically add `&mut` if the method expects a mutable reference and
// the object is not already one.
if method_id != FuncId::dummy_id() {
let func_meta = self.interner.function_meta(&method_id);
self.try_add_mutable_reference_to_object(
&mut method_call,
&func_meta.typ,
&mut args,
);
if let HirMethodReference::FuncId(func_id) = method_ref {
// Automatically add `&mut` if the method expects a mutable reference and
// the object is not already one.
if func_id != FuncId::dummy_id() {
let func_meta = self.interner.function_meta(&func_id);
self.try_add_mutable_reference_to_object(
&mut method_call,
&func_meta.typ,
&mut args,
);
}
}

let (function_id, function_call) =
method_call.into_function_call(method_id, location, self.interner);
let (function_id, function_call) = method_call.into_function_call(
method_ref,
location,
self.interner,
);

let span = self.interner.expr_span(expr_id);
let ret = self.check_method_call(&function_id, &method_id, args, span);
let ret = self.check_method_call(&function_id, method_ref, args, span);

self.interner.replace_expr(expr_id, function_call);
ret
Expand Down Expand Up @@ -286,6 +291,7 @@ impl<'interner> TypeChecker<'interner> {

Type::Function(params, Box::new(lambda.return_type), Box::new(env_type))
}
HirExpression::TraitMethodReference(_) => unreachable!("unexpected TraitMethodReference - they should be added after initial type checking"),
};

self.interner.push_expr_type(expr_id, typ.clone());
Expand Down Expand Up @@ -477,34 +483,46 @@ impl<'interner> TypeChecker<'interner> {
fn check_method_call(
&mut self,
function_ident_id: &ExprId,
func_id: &FuncId,
method_ref: HirMethodReference,
arguments: Vec<(Type, ExprId, Span)>,
span: Span,
) -> Type {
if func_id == &FuncId::dummy_id() {
Type::Error
} else {
let func_meta = self.interner.function_meta(func_id);
let (fntyp, param_len) = match method_ref {
HirMethodReference::FuncId(func_id) => {
if func_id == FuncId::dummy_id() {
return Type::Error;
}

// Check function call arity is correct
let param_len = func_meta.parameters.len();
let arg_len = arguments.len();
let func_meta = self.interner.function_meta(&func_id);
let param_len = func_meta.parameters.len();

if param_len != arg_len {
self.errors.push(TypeCheckError::ArityMisMatch {
expected: param_len as u16,
found: arg_len as u16,
span,
});
(func_meta.typ, param_len)
}
HirMethodReference::TraitMethodId(method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let the_trait = the_trait.borrow();
let method = &the_trait.methods[method.method_index];

let (function_type, instantiation_bindings) = func_meta.typ.instantiate(self.interner);
(method.get_type(), method.arguments.len())
}
};

self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings);
self.interner.push_expr_type(function_ident_id, function_type.clone());
let arg_len = arguments.len();

self.bind_function_type(function_type, arguments, span)
if param_len != arg_len {
self.errors.push(TypeCheckError::ArityMisMatch {
expected: param_len as u16,
found: arg_len as u16,
span,
});
}

let (function_type, instantiation_bindings) = fntyp.instantiate(self.interner);

self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings);
self.interner.push_expr_type(function_ident_id, function_type.clone());

self.bind_function_type(function_type, arguments, span)
}

fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type {
Expand Down Expand Up @@ -818,11 +836,11 @@ impl<'interner> TypeChecker<'interner> {
object_type: &Type,
method_name: &str,
expr_id: &ExprId,
) -> Option<FuncId> {
) -> Option<HirMethodReference> {
match object_type {
Type::Struct(typ, _args) => {
match self.interner.lookup_method(typ.borrow().id, method_name) {
Some(method_id) => Some(method_id),
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
Expand All @@ -833,6 +851,33 @@ impl<'interner> TypeChecker<'interner> {
}
}
}
Type::NamedGeneric(_, _) => {
let func_meta = self.interner.function_meta(
&self.current_function.expect("unexpected method outside a function"),
);

for constraint in func_meta.trait_constraints {
if *object_type == constraint.typ {
let the_trait = self.interner.get_trait(constraint.trait_id);
let the_trait = the_trait.borrow();

for (method_index, method) in the_trait.methods.iter().enumerate() {
if method.name.0.contents == method_name {
let trait_method =
TraitMethodId { trait_id: constraint.trait_id, method_index };
return Some(HirMethodReference::TraitMethodId(trait_method));
}
}
}
}

self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
object_type: object_type.clone(),
span: self.interner.expr_span(expr_id),
});
None
}
// Mutable references to another type should resolve to methods of their element type.
// This may be a struct or a primitive type.
Type::MutableReference(element) => self.lookup_method(element, method_name, expr_id),
Expand All @@ -843,7 +888,7 @@ impl<'interner> TypeChecker<'interner> {
// In the future we could support methods for non-struct types if we have a context
// (in the interner?) essentially resembling HashMap<Type, Methods>
other => match self.interner.lookup_primitive_method(other, method_name) {
Some(method_id) => Some(method_id),
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
Expand Down
11 changes: 9 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct TypeChecker<'interner> {
delayed_type_checks: Vec<TypeCheckFn>,
interner: &'interner mut NodeInterner,
errors: Vec<TypeCheckError>,
current_function: Option<FuncId>,
}

/// Type checks a function and assigns the
Expand All @@ -40,6 +41,7 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
let function_body_id = function_body.as_expr();

let mut type_checker = TypeChecker::new(interner);
type_checker.current_function = Some(func_id);

// Bind each parameter to its annotated type.
// This is locally obvious, but it must be bound here so that the
Expand Down Expand Up @@ -111,7 +113,7 @@ fn function_info(interner: &NodeInterner, function_body_id: &ExprId) -> (noirc_e

impl<'interner> TypeChecker<'interner> {
fn new(interner: &'interner mut NodeInterner) -> Self {
Self { delayed_type_checks: Vec::new(), interner, errors: vec![] }
Self { delayed_type_checks: Vec::new(), interner, errors: vec![], current_function: None }
}

pub fn push_delayed_type_check(&mut self, f: TypeCheckFn) {
Expand All @@ -127,7 +129,12 @@ impl<'interner> TypeChecker<'interner> {
}

pub fn check_global(id: &StmtId, interner: &'interner mut NodeInterner) -> Vec<TypeCheckError> {
let mut this = Self { delayed_type_checks: Vec::new(), interner, errors: vec![] };
let mut this = Self {
delayed_type_checks: Vec::new(),
interner,
errors: vec![],
current_function: None,
};
this.check_statement(id);
this.errors
}
Expand Down
32 changes: 26 additions & 6 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use acvm::FieldElement;
use fm::FileId;
use noirc_errors::Location;

use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId};
use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId, TraitMethodId};
use crate::{BinaryOp, BinaryOpKind, Ident, Shared, UnaryOp};

use super::stmt::HirPattern;
Expand Down Expand Up @@ -30,6 +30,7 @@ pub enum HirExpression {
If(HirIfExpression),
Tuple(Vec<ExprId>),
Lambda(HirLambda),
TraitMethodReference(TraitMethodId),
Error,
}

Expand Down Expand Up @@ -150,20 +151,39 @@ pub struct HirMethodCallExpression {
pub location: Location,
}

#[derive(Debug, Copy, Clone)]
pub enum HirMethodReference {
/// A method can be defined in a regular `impl` block, in which case
/// it's syntax sugar for a normal function call, and can be
/// translated to one during type checking
FuncId(FuncId),

/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphizaiton.
TraitMethodId(TraitMethodId),
}

impl HirMethodCallExpression {
pub fn into_function_call(
mut self,
func: FuncId,
method: HirMethodReference,
location: Location,
interner: &mut NodeInterner,
) -> (ExprId, HirExpression) {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);

let id = interner.function_definition_id(func);
let ident = HirExpression::Ident(HirIdent { location, id });
let func = interner.push_expr(ident);

let expr = match method {
HirMethodReference::FuncId(func_id) => {
let id = interner.function_definition_id(func_id);
HirExpression::Ident(HirIdent { location, id })
}
HirMethodReference::TraitMethodId(method_id) => {
HirExpression::TraitMethodReference(method_id)
}
};
let func = interner.push_expr(expr);
(func, HirExpression::Call(HirCallExpression { func, arguments, location }))
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub struct TraitImpl {
#[derive(Debug, Clone)]
pub struct TraitConstraint {
pub typ: Type,
pub trait_id: Option<TraitId>,
pub trait_id: TraitId,
// pub trait_generics: Generics, TODO
}

Expand Down
57 changes: 55 additions & 2 deletions compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
use acvm::FieldElement;
use iter_extended::{btree_map, vecmap};
use noirc_printable_type::PrintableType;
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::{
collections::{BTreeMap, HashMap, VecDeque},
unreachable,
};

use crate::{
hir_def::{
Expand All @@ -20,7 +23,7 @@ use crate::{
stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement},
types,
},
node_interner::{self, DefinitionKind, NodeInterner, StmtId},
node_interner::{self, DefinitionKind, NodeInterner, StmtId, TraitImplKey, TraitMethodId},
token::FunctionAttribute,
ContractFunctionType, FunctionKind, Type, TypeBinding, TypeBindings, TypeVariableKind,
Visibility,
Expand Down Expand Up @@ -375,6 +378,17 @@ impl<'interner> Monomorphizer<'interner> {

HirExpression::Lambda(lambda) => self.lambda(lambda, expr),

HirExpression::TraitMethodReference(method) => {
if let Type::Function(args, _, _) = self.interner.id_type(expr) {
let self_type = args[0].clone();
self.resolve_trait_method_reference(self_type, expr, method)
} else {
unreachable!(
"Calling a non-function, this should've been caught in typechecking"
);
}
}

HirExpression::MethodCall(_) => {
unreachable!("Encountered HirExpression::MethodCall during monomorphization")
}
Expand Down Expand Up @@ -777,6 +791,45 @@ impl<'interner> Monomorphizer<'interner> {
}
}

fn resolve_trait_method_reference(
&mut self,
self_type: HirType,
expr_id: node_interner::ExprId,
method: TraitMethodId,
) -> ast::Expression {
let function_type = self.interner.id_type(expr_id);

// the substitute() here is to replace all internal occurences of the 'Self' typevar
// with whatever 'Self' is currently bound to, so we don't lose type information
// if we need to rebind the trait.
let trait_impl = self
.interner
.get_trait_implementation(&TraitImplKey {
typ: self_type.follow_bindings(),
trait_id: method.trait_id,
})
.expect("ICE: missing trait impl - should be caught during type checking");

let hir_func_id = trait_impl.borrow().methods[method.method_index];

let func_def = self.lookup_function(hir_func_id, expr_id, &function_type);
let func_id = match func_def {
Definition::Function(func_id) => func_id,
_ => unreachable!(),
};

let the_trait = self.interner.get_trait(method.trait_id);
let the_trait = the_trait.borrow();

ast::Expression::Ident(ast::Ident {
definition: Definition::Function(func_id),
mutable: false,
location: None,
name: the_trait.methods[method.method_index].name.0.contents.clone(),
typ: self.convert_type(&function_type),
})
}

fn function_call(
&mut self,
call: HirCallExpression,
Expand Down
Loading

0 comments on commit 5ca99b1

Please sign in to comment.