Skip to content

Commit

Permalink
feat: Implement generic traits (#4000)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #3471
Resolves #3474

## Summary\*

Implements support for generics on the trait directly. E.g. `trait
Into<T> { ... }`

## Additional Context

The old `trait_generics` test has been renamed to `trait_impl_generics`
- I think this is more accurate. There is a new test in `trait_generics`
now which tests the generic traits added by this PR.

I've discovered two new bugs in writing this PR, which are commented in
the `trait_generics` test. I'll make issues for them now.

## Documentation\*

Check one:
- [ ] No documentation needed.
- [x] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [ ] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Jan 16, 2024
1 parent b3aed17 commit 916fd15
Show file tree
Hide file tree
Showing 22 changed files with 754 additions and 370 deletions.
132 changes: 88 additions & 44 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TraitId, Type
use crate::parser::{ParserError, SortedModule};
use crate::{
ExpressionKind, Ident, LetStatement, Literal, NoirFunction, NoirStruct, NoirTrait,
NoirTypeAlias, Path, PathKind, Type, UnresolvedGenerics, UnresolvedTraitConstraint,
UnresolvedType,
NoirTypeAlias, Path, PathKind, Type, TypeBindings, UnresolvedGenerics,
UnresolvedTraitConstraint, UnresolvedType,
};
use fm::FileId;
use iter_extended::vecmap;
Expand Down Expand Up @@ -90,6 +90,7 @@ pub struct UnresolvedTraitImpl {
pub file_id: FileId,
pub module_id: LocalModuleId,
pub trait_id: Option<TraitId>,
pub trait_generics: Vec<UnresolvedType>,
pub trait_path: Path,
pub object_type: UnresolvedType,
pub methods: UnresolvedFunctions,
Expand Down Expand Up @@ -456,19 +457,44 @@ fn type_check_functions(
}

// TODO(vitkov): Move this out of here and into type_check
#[allow(clippy::too_many_arguments)]
pub(crate) fn check_methods_signatures(
resolver: &mut Resolver,
impl_methods: &Vec<(FileId, FuncId)>,
trait_id: TraitId,
trait_name_span: Span,
// These are the generics on the trait itself from the impl.
// E.g. in `impl Foo<A, B> for Bar<B, C>`, this is `vec![A, B]`.
trait_generics: Vec<UnresolvedType>,
trait_impl_generic_count: usize,
file_id: FileId,
errors: &mut Vec<(CompilationError, FileId)>,
) {
let self_type = resolver.get_self_type().expect("trait impl must have a Self type").clone();
let trait_generics = vecmap(trait_generics, |typ| resolver.resolve_type(typ));

// Temporarily bind the trait's Self type to self_type so we can type check
let the_trait = resolver.interner.get_trait_mut(trait_id);
the_trait.self_type_typevar.bind(self_type);

if trait_generics.len() != the_trait.generics.len() {
let error = DefCollectorErrorKind::MismatchGenericCount {
actual_generic_count: trait_generics.len(),
expected_generic_count: the_trait.generics.len(),
// Preferring to use 'here' over a more precise term like 'this reference'
// to try to make the error easier to understand for newer users.
location: "here it",
origin: the_trait.name.to_string(),
span: trait_name_span,
};
errors.push((error.into(), file_id));
}

// We also need to bind the traits generics to the trait's generics on the impl
for ((_, generic), binding) in the_trait.generics.iter().zip(trait_generics) {
generic.bind(binding);
}

// Temporarily take the trait's methods so we can use both them and a mutable reference
// to the interner within the loop.
let trait_methods = std::mem::take(&mut the_trait.methods);
Expand All @@ -482,49 +508,44 @@ pub(crate) fn check_methods_signatures(
if let Some(trait_method) =
trait_methods.iter().find(|method| method.name.0.contents == func_name)
{
let mut typecheck_errors = Vec::new();
let impl_method = resolver.interner.function_meta(func_id);

let (impl_function_type, _) = impl_method.typ.instantiate(resolver.interner);

let impl_method_generic_count =
impl_method.typ.generic_count() - trait_impl_generic_count;

// We subtract 1 here to account for the implicit generic `Self` type that is on all
// traits (and thus trait methods) but is not required (or allowed) for users to specify.
let trait_method_generic_count = trait_method.generics().len() - 1;
let the_trait = resolver.interner.get_trait(trait_id);
let trait_method_generic_count =
trait_method.generics().len() - 1 - the_trait.generics.len();

if impl_method_generic_count != trait_method_generic_count {
let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics {
impl_method_generic_count,
trait_method_generic_count,
trait_name: resolver.interner.get_trait(trait_id).name.to_string(),
method_name: func_name.to_string(),
let trait_name = resolver.interner.get_trait(trait_id).name.clone();

let error = DefCollectorErrorKind::MismatchGenericCount {
actual_generic_count: impl_method_generic_count,
expected_generic_count: trait_method_generic_count,
origin: format!("{}::{}", trait_name, func_name),
location: "this method",
span: impl_method.location.span,
};
errors.push((error.into(), *file_id));
}

if let Type::Function(impl_params, _, _) = impl_function_type {
if trait_method.arguments().len() == impl_params.len() {
// Check the parameters of the impl method against the parameters of the trait method
let args = trait_method.arguments().iter();
let args_and_params = args.zip(&impl_params).zip(&impl_method.parameters.0);

for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in
args_and_params.enumerate()
{
expected.unify(actual, &mut typecheck_errors, || {
TypeCheckError::TraitMethodParameterTypeMismatch {
method_name: func_name.to_string(),
expected_typ: expected.to_string(),
actual_typ: actual.to_string(),
parameter_span: hir_pattern.span(),
parameter_index: parameter_index + 1,
}
});
}
} else {
// This instantiation is technically not needed. We could bind each generic in the
// trait function to the impl's corresponding generic but to do so we'd have to rely
// on the trait function's generics being first in the generic list, since the same
// list also contains the generic `Self` variable, and any generics on the trait itself.
//
// Instantiating the impl method's generics here instead is a bit less precise but
// doesn't rely on any orderings that may be changed.
let impl_function_type = impl_method.typ.instantiate(resolver.interner).0;

let mut bindings = TypeBindings::new();
let mut typecheck_errors = Vec::new();

if let Type::Function(impl_params, impl_return, _) = impl_function_type.as_monotype() {
if trait_method.arguments().len() != impl_params.len() {
let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters {
actual_num_parameters: impl_method.parameters.0.len(),
expected_num_parameters: trait_method.arguments().len(),
Expand All @@ -534,28 +555,51 @@ pub(crate) fn check_methods_signatures(
};
errors.push((error.into(), *file_id));
}
}

// Check that impl method return type matches trait return type:
let resolved_return_type =
resolver.resolve_type(impl_method.return_type.get_type().into_owned());
// Check the parameters of the impl method against the parameters of the trait method
let args = trait_method.arguments().iter();
let args_and_params = args.zip(impl_params).zip(&impl_method.parameters.0);

// TODO: This is not right since it may bind generic return types
trait_method.return_type().unify(&resolved_return_type, &mut typecheck_errors, || {
let impl_method = resolver.interner.function_meta(func_id);
let ret_type_span = impl_method.return_type.get_type().span;
let expr_span = ret_type_span.expect("return type must always have a span");
for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in
args_and_params.enumerate()
{
if expected.try_unify(actual, &mut bindings).is_err() {
typecheck_errors.push(TypeCheckError::TraitMethodParameterTypeMismatch {
method_name: func_name.to_string(),
expected_typ: expected.to_string(),
actual_typ: actual.to_string(),
parameter_span: hir_pattern.span(),
parameter_index: parameter_index + 1,
});
}
}

let expected_typ = trait_method.return_type().to_string();
let expr_typ = impl_method.return_type().to_string();
TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span }
});
if trait_method.return_type().try_unify(impl_return, &mut bindings).is_err() {
let impl_method = resolver.interner.function_meta(func_id);
let ret_type_span = impl_method.return_type.get_type().span;
let expr_span = ret_type_span.expect("return type must always have a span");

let expected_typ = trait_method.return_type().to_string();
let expr_typ = impl_method.return_type().to_string();
let error = TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span };
typecheck_errors.push(error);
}
} else {
unreachable!(
"impl_function_type is not a function type, it is: {impl_function_type}"
);
}

errors.extend(typecheck_errors.iter().cloned().map(|e| (e.into(), *file_id)));
}
}

// Now unbind `Self` and the trait's generics
let the_trait = resolver.interner.get_trait_mut(trait_id);
the_trait.set_methods(trait_methods);
the_trait.self_type_typevar.unbind(the_trait.self_type_typevar_id);

for (old_id, generic) in &the_trait.generics {
generic.unbind(*old_id);
}
}
6 changes: 3 additions & 3 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub fn collect_defs(

errors.extend(collector.collect_functions(context, ast.functions, crate_id));

errors.extend(collector.collect_trait_impls(context, ast.trait_impls, crate_id));
collector.collect_trait_impls(context, ast.trait_impls, crate_id);

collector.collect_impls(context, ast.impls, crate_id);

Expand Down Expand Up @@ -144,7 +144,7 @@ impl<'a> ModCollector<'a> {
context: &mut Context,
impls: Vec<NoirTraitImpl>,
krate: CrateId,
) -> Vec<(CompilationError, fm::FileId)> {
) {
for trait_impl in impls {
let trait_name = trait_impl.trait_name.clone();

Expand All @@ -168,11 +168,11 @@ impl<'a> ModCollector<'a> {
generics: trait_impl.impl_generics,
where_clause: trait_impl.where_clause,
trait_id: None, // will be filled later
trait_generics: trait_impl.trait_generics,
};

self.def_collector.collected_traits_impls.push(unresolved_trait_impl);
}
vec![]
}

fn collect_trait_impl_function_overrides(
Expand Down
26 changes: 13 additions & 13 deletions compiler/noirc_frontend/src/hir/def_collector/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ pub enum DefCollectorErrorKind {
method_name: String,
span: Span,
},
#[error("Mismatched number of generics in impl method")]
MismatchTraitImplementationNumGenerics {
impl_method_generic_count: usize,
trait_method_generic_count: usize,
trait_name: String,
method_name: String,
#[error("Mismatched number of generics in {location}")]
MismatchGenericCount {
actual_generic_count: usize,
expected_generic_count: usize,
location: &'static str,
origin: String,
span: Span,
},
#[error("Method is not defined in trait")]
Expand Down Expand Up @@ -188,16 +188,16 @@ impl From<DefCollectorErrorKind> for Diagnostic {
"`{trait_name}::{method_name}` expects {expected_num_parameters} parameter{plural}, but this method has {actual_num_parameters}");
Diagnostic::simple_error(primary_message, "".to_string(), span)
}
DefCollectorErrorKind::MismatchTraitImplementationNumGenerics {
impl_method_generic_count,
trait_method_generic_count,
trait_name,
method_name,
DefCollectorErrorKind::MismatchGenericCount {
actual_generic_count,
expected_generic_count,
location,
origin,
span,
} => {
let plural = if trait_method_generic_count == 1 { "" } else { "s" };
let plural = if expected_generic_count == 1 { "" } else { "s" };
let primary_message = format!(
"`{trait_name}::{method_name}` expects {trait_method_generic_count} generic{plural}, but this method has {impl_method_generic_count}");
"`{origin}` expects {expected_generic_count} generic{plural}, but {location} has {actual_generic_count}");
Diagnostic::simple_error(primary_message, "".to_string(), span)
}
DefCollectorErrorKind::MethodNotInTrait { trait_name, impl_method } => {
Expand Down
8 changes: 4 additions & 4 deletions compiler/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ pub enum ResolverError {
NonStructWithGenerics { span: Span },
#[error("Cannot apply generics on Self type")]
GenericsOnSelfType { span: Span },
#[error("Incorrect amount of arguments to generic type constructor")]
IncorrectGenericCount { span: Span, struct_type: String, actual: usize, expected: usize },
#[error("Incorrect amount of arguments to {item_name}")]
IncorrectGenericCount { span: Span, item_name: String, actual: usize, expected: usize },
#[error("{0}")]
ParserError(Box<ParserError>),
#[error("Function is not defined in a contract yet sets its contract visibility")]
Expand Down Expand Up @@ -259,12 +259,12 @@ impl From<ResolverError> for Diagnostic {
"Use an explicit type name or apply the generics at the start of the impl instead".into(),
span,
),
ResolverError::IncorrectGenericCount { span, struct_type, actual, expected } => {
ResolverError::IncorrectGenericCount { span, item_name, actual, expected } => {
let expected_plural = if expected == 1 { "" } else { "s" };
let actual_plural = if actual == 1 { "is" } else { "are" };

Diagnostic::simple_error(
format!("The struct type {struct_type} has {expected} generic{expected_plural} but {actual} {actual_plural} given here"),
format!("`{item_name}` has {expected} generic argument{expected_plural} but {actual} {actual_plural} given here"),
"Incorrect number of generic arguments".into(),
span,
)
Expand Down
Loading

0 comments on commit 916fd15

Please sign in to comment.