Skip to content

Commit

Permalink
fix(frontend): Resolve object types from method calls a single time (#…
Browse files Browse the repository at this point in the history
…5131)

# Description

## Problem\*

Resolves #5065 

Probably resolves #4732 but need to test it or have aztec team test it.

## Summary\*

When working with a program like such where `MyType` implements
`MyTrait`:
```rust
fn foo<T>() -> T where T: MyTrait {
    MyTrait::new()
}
fn concise_regression() -> MyType {
    Wrapper::new(foo()).unwrap()
}
```
We should be able to infer the return type of `foo`. We currently always
push trait constraints onto a a `Vec<(TraitConstraint, ExprId)>`. We
need to do this as we can have multiple trait constraints that need to
be handled during monomorphization due to generics. However, when
working with a method call this can cause us to store an old trait
constraint that does not necessarily apply to the expression.

The nested function call in `concise_regression` initially adds a trait
constraint simply for `foo` due to the call to `Wrapper::new(foo())` and
then another constraint for `Wrapper::new(foo()).unwrap()`. The call to
`Wrapper::new(foo())` cannot be bound to anything unless we introduce an
intermediate variable. This felt like it would be overly complex and we
just need to follow the accurate trait constraint for a function call
expression.

Taking the test in the issue and this PR we have the following trait
constraints on master for the `foo` expression:
```
TraitConstraint {
        typ: '23646 -> '23647,
        trait_id: TraitId(
            ModuleId {
                krate: Root(
                    1,
                ),
                local_id: LocalModuleId(
                    Index(
                        1,
                    ),
                ),
            },
        ),
        trait_generics: [],
    },
    TraitConstraint {
        typ: '23648 -> '23649 -> '23650 -> MyType,
        trait_id: TraitId(
            ModuleId {
                krate: Root(
                    1,
                ),
                local_id: LocalModuleId(
                    Index(
                        1,
                    ),
                ),
            },
        ),
        trait_generics: [],
    }
```
This is occurring due to an unnecessary type check on a method call's
object type. This is cause a repeated trait constraint where one has
incorrect type variables that cannot be resolved.

I have altered how MethodCall's and Call's are resolved as to avoid
repeated type checks on the object type.

## Additional Context



## Documentation\*

Check one:
- [X] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [X] I have tested the changes locally.
- [X] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: jfecher <jake@aztecprotocol.com>
  • Loading branch information
vezenovm and jfecher authored May 29, 2024
1 parent 69eca9b commit 3afe023
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 46 deletions.
118 changes: 72 additions & 46 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use iter_extended::vecmap;
use noirc_errors::Span;

use crate::ast::{BinaryOpKind, IntegerBitSize, UnaryOp};
use crate::hir_def::expr::HirCallExpression;
use crate::macros_api::Signedness;
use crate::{
hir::{resolution::resolver::verify_mutable_reference, type_check::errors::Source},
Expand Down Expand Up @@ -176,56 +177,20 @@ impl<'interner> TypeChecker<'interner> {
}
HirExpression::Index(index_expr) => self.check_index_expression(expr_id, index_expr),
HirExpression::Call(call_expr) => {
// Need to setup these flags here as `self` is borrowed mutably to type check the rest of the call expression
// These flags are later used to type check calls to unconstrained functions from constrained functions
let current_func = self.current_function;
let func_mod = current_func.map(|func| self.interner.function_modifiers(&func));
let is_current_func_constrained =
func_mod.map_or(true, |func_mod| !func_mod.is_unconstrained);
let is_unconstrained_call = self.is_unconstrained_call(&call_expr.func);

self.check_if_deprecated(&call_expr.func);

let function = self.check_expression(&call_expr.func);

let args = vecmap(&call_expr.arguments, |arg| {
let typ = self.check_expression(arg);
(typ, *arg, self.interner.expr_span(arg))
});

// Check that we are not passing a mutable reference from a constrained runtime to an unconstrained runtime
if is_current_func_constrained && is_unconstrained_call {
for (typ, _, _) in args.iter() {
if matches!(&typ.follow_bindings(), Type::MutableReference(_)) {
self.errors.push(TypeCheckError::ConstrainedReferenceToUnconstrained {
span: self.interner.expr_span(expr_id),
});
return Type::Error;
}
}
}

let span = self.interner.expr_span(expr_id);
let return_type = self.bind_function_type(function, args, span);

// Check that we are not passing a slice from an unconstrained runtime to a constrained runtime
if is_current_func_constrained && is_unconstrained_call {
if return_type.contains_slice() {
self.errors.push(TypeCheckError::UnconstrainedSliceReturnToConstrained {
span: self.interner.expr_span(expr_id),
});
return Type::Error;
} else if matches!(&return_type.follow_bindings(), Type::MutableReference(_)) {
self.errors.push(TypeCheckError::UnconstrainedReferenceToConstrained {
span: self.interner.expr_span(expr_id),
});
return Type::Error;
}
};

return_type
self.check_call(&call_expr, function, args, span)
}
HirExpression::MethodCall(mut method_call) => {
let method_call_span = self.interner.expr_span(expr_id);
let object = method_call.object;
let object_span = self.interner.expr_span(&method_call.object);
let mut object_type = self.check_expression(&method_call.object).follow_bindings();
let method_name = method_call.method.0.contents.as_str();
match self.lookup_method(&object_type, method_name, expr_id) {
Expand Down Expand Up @@ -259,19 +224,42 @@ impl<'interner> TypeChecker<'interner> {
);
}

// These arguments will be given to the desugared function call.
// Compared to the method arguments, they also contain the object.
let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1);

function_args.push((object_type.clone(), object, object_span));

for arg in method_call.arguments.iter() {
let span = self.interner.expr_span(arg);
let typ = self.check_expression(arg);
function_args.push((typ, *arg, span));
}

// TODO: update object_type here?
let (_, function_call) = method_call.into_function_call(
let ((function_id, _), function_call) = method_call.into_function_call(
&method_ref,
object_type,
location,
self.interner,
);

self.interner.replace_expr(expr_id, HirExpression::Call(function_call));
let func_type = self.check_expression(&function_id);

// Type check the new call now that it has been changed from a method call
// to a function call. This way we avoid duplicating code.
self.check_expression(expr_id)
// We call `check_call` rather than `check_expression` directly as we want to avoid
// resolving the object type again once it is part of the arguments.
let typ = self.check_call(
&function_call,
func_type,
function_args,
method_call_span,
);

self.interner.replace_expr(expr_id, HirExpression::Call(function_call));

typ
}
None => Type::Error,
}
Expand Down Expand Up @@ -333,6 +321,45 @@ impl<'interner> TypeChecker<'interner> {
typ
}

fn check_call(
&mut self,
call: &HirCallExpression,
func_type: Type,
args: Vec<(Type, ExprId, Span)>,
span: Span,
) -> Type {
// Need to setup these flags here as `self` is borrowed mutably to type check the rest of the call expression
// These flags are later used to type check calls to unconstrained functions from constrained functions
let func_mod = self.current_function.map(|func| self.interner.function_modifiers(&func));
let is_current_func_constrained =
func_mod.map_or(true, |func_mod| !func_mod.is_unconstrained);

let is_unconstrained_call = self.is_unconstrained_call(&call.func);
self.check_if_deprecated(&call.func);

// Check that we are not passing a mutable reference from a constrained runtime to an unconstrained runtime
if is_current_func_constrained && is_unconstrained_call {
for (typ, _, _) in args.iter() {
if matches!(&typ.follow_bindings(), Type::MutableReference(_)) {
self.errors.push(TypeCheckError::ConstrainedReferenceToUnconstrained { span });
}
}
}

let return_type = self.bind_function_type(func_type, args, span);

// Check that we are not passing a slice from an unconstrained runtime to a constrained runtime
if is_current_func_constrained && is_unconstrained_call {
if return_type.contains_slice() {
self.errors.push(TypeCheckError::UnconstrainedSliceReturnToConstrained { span });
} else if matches!(&return_type.follow_bindings(), Type::MutableReference(_)) {
self.errors.push(TypeCheckError::UnconstrainedReferenceToConstrained { span });
}
};

return_type
}

fn check_block(&mut self, block: HirBlockExpression) -> Type {
let mut block_type = Type::Unit;

Expand Down Expand Up @@ -416,9 +443,8 @@ impl<'interner> TypeChecker<'interner> {
// Push any trait constraints required by this definition to the context
// to be checked later when the type of this variable is further constrained.
if let Some(definition) = self.interner.try_definition(ident.id) {
if let DefinitionKind::Function(function) = definition.kind {
let function = self.interner.function_meta(&function);

if let DefinitionKind::Function(func_id) = definition.kind {
let function = self.interner.function_meta(&func_id);
for mut constraint in function.trait_constraints.clone() {
constraint.apply_bindings(&bindings);
self.trait_constraints.push((constraint, *expr_id));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "regression_5065_failure"
type = "bin"
authors = [""]
compiler_version = ">=0.30.0"

[dependencies]
40 changes: 40 additions & 0 deletions test_programs/compile_failure/regression_5065_failure/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
struct Wrapper<T> {
_value: T,
}

impl<T> Wrapper<T> {
fn new(value: T) -> Self {
Self { _value: value }
}

fn unwrap(self) -> T {
self._value
}
}

trait MyTrait {
fn new() -> Self;
}

struct MyType {}

impl MyTrait for MyType {
fn new() -> Self {
MyType {}
}
}

fn foo<T>() -> T where T: MyTrait {
MyTrait::new()
}

struct BadType {}

// Check that we get "No matching impl found for `BadType: MyTrait`"
fn concise_regression() -> BadType {
Wrapper::new(foo()).unwrap()
}

fn main() {
let _ = concise_regression();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "regression_5065"
type = "bin"
authors = [""]
compiler_version = ">=0.30.0"

[dependencies]
45 changes: 45 additions & 0 deletions test_programs/compile_success_empty/regression_5065/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
struct Wrapper<T> {
_value: T,
}

impl<T> Wrapper<T> {
fn new_wrapper(value: T) -> Self {
Self { _value: value }
}

fn unwrap(self) -> T {
self._value
}
}

trait MyTrait {
fn new() -> Self;
}

struct MyType {}

impl MyTrait for MyType {
fn new() -> Self {
MyType {}
}
}

fn foo<T>() -> T where T: MyTrait {
MyTrait::new()
}

// fn verbose_but_compiles() -> MyType {
// let a = Wrapper::new_wrapper(foo());
// a.unwrap()
// }

// Check that are able to infer the return type of the call to `foo`
fn concise_regression() -> MyType {
Wrapper::new_wrapper(foo()).unwrap()
// Wrapper::unwrap(Wrapper::new_wrapper(foo()))
}

fn main() {
// let _ = verbose_but_compiles();
let _ = concise_regression();
}

0 comments on commit 3afe023

Please sign in to comment.