Skip to content

Commit

Permalink
fix: disambiguate field or int static trait method call (#6112)
Browse files Browse the repository at this point in the history
# Description

## Problem

Resolves #6106

## Summary

## Additional Context

Let me know if this isn't the correct way to solve this.

## Documentation

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
asterite authored Sep 20, 2024
1 parent ad83302 commit 5b27ea4
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 17 deletions.
55 changes: 39 additions & 16 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,16 @@ pub enum ImplSearchErrorKind {
#[derive(Default, Debug, Clone)]
pub struct Methods {
pub direct: Vec<FuncId>,
pub trait_impl_methods: Vec<FuncId>,
pub trait_impl_methods: Vec<TraitImplMethod>,
}

#[derive(Debug, Clone)]
pub struct TraitImplMethod {
// This type is only stored for primitive types to be able to
// select the correct static methods between multiple options keyed
// under TypeMethodKey::FieldOrInt
pub typ: Option<Type>,
pub method: FuncId,
}

/// All the information from a function that is filled out during definition collection rather than
Expand Down Expand Up @@ -1383,7 +1392,7 @@ impl NodeInterner {
.or_default()
.entry(method_name)
.or_default()
.add_method(method_id, is_trait_method);
.add_method(method_id, None, is_trait_method);
None
}
Type::Error => None,
Expand All @@ -1395,12 +1404,16 @@ impl NodeInterner {
let key = get_type_method_key(self_type).unwrap_or_else(|| {
unreachable!("Cannot add a method to the unsupported type '{}'", other)
});
// Only remember the actual type if it's FieldOrInt,
// so later we can disambiguate on calls like `u32::call`.
let typ =
if key == TypeMethodKey::FieldOrInt { Some(self_type.clone()) } else { None };
self.primitive_methods
.entry(key)
.or_default()
.entry(method_name)
.or_default()
.add_method(method_id, is_trait_method);
.add_method(method_id, typ, is_trait_method);
None
}
}
Expand Down Expand Up @@ -2246,23 +2259,29 @@ impl Methods {
if self.direct.len() == 1 {
Some(self.direct[0])
} else if self.direct.is_empty() && self.trait_impl_methods.len() == 1 {
Some(self.trait_impl_methods[0])
Some(self.trait_impl_methods[0].method)
} else {
None
}
}

fn add_method(&mut self, method: FuncId, is_trait_method: bool) {
fn add_method(&mut self, method: FuncId, typ: Option<Type>, is_trait_method: bool) {
if is_trait_method {
self.trait_impl_methods.push(method);
let trait_impl_method = TraitImplMethod { typ, method };
self.trait_impl_methods.push(trait_impl_method);
} else {
self.direct.push(method);
}
}

/// Iterate through each method, starting with the direct methods
pub fn iter(&self) -> impl Iterator<Item = FuncId> + '_ {
self.direct.iter().copied().chain(self.trait_impl_methods.iter().copied())
pub fn iter(&self) -> impl Iterator<Item = (FuncId, &Option<Type>)> + '_ {
let trait_impl_methods = self.trait_impl_methods.iter().map(|m| (m.method, &m.typ));
let direct = self.direct.iter().copied().map(|func_id| {
let typ: &Option<Type> = &None;
(func_id, typ)
});
direct.chain(trait_impl_methods)
}

/// Select the 1 matching method with an object type matching `typ`
Expand All @@ -2274,28 +2293,32 @@ impl Methods {
) -> Option<FuncId> {
// When adding methods we always check they do not overlap, so there should be
// at most 1 matching method in this list.
for method in self.iter() {
for (method, method_type) in self.iter() {
match interner.function_meta(&method).typ.instantiate(interner).0 {
Type::Function(args, _, _, _) => {
if has_self_param {
if let Some(object) = args.first() {
let mut bindings = TypeBindings::new();

if object.try_unify(typ, &mut bindings).is_ok() {
Type::apply_type_bindings(bindings);
if object.unify(typ).is_ok() {
return Some(method);
}
}
} else {
// Just return the first method whose name matches since we
// can't match object types on static methods.
return Some(method);
// If we recorded the concrete type this trait impl method belongs to,
// and it matches typ, it's an exact match and we return that.
if let Some(method_type) = method_type {
if method_type.unify(typ).is_ok() {
return Some(method);
}
} else {
return Some(method);
}
}
}
Type::Error => (),
other => unreachable!("Expected function type, found {other}"),
}
}

None
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "field_or_int_static_trait_method"
type = "bin"
authors = [""]
compiler_version = ">=0.32.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
trait Read {
fn read(data: [Field; 1]) -> Self;
}

impl Read for Field {
fn read(data: [Field; 1]) -> Self {
data[0] * 10
}
}

impl Read for u32 {
fn read(data: [Field; 1]) -> Self {
data[0] as u32
}
}

fn main() {
let data = [1];

let value: u32 = u32::read(data);
assert_eq(value, 1);

let value: Field = Field::read(data);
assert_eq(value, 10);
}
2 changes: 1 addition & 1 deletion tooling/lsp/src/requests/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ impl<'a> NodeFinder<'a> {
};

for (name, methods) in methods_by_name {
for func_id in methods.iter() {
for (func_id, _method_type) in methods.iter() {
if name_matches(name, prefix) {
let completion_items = self.function_completion_items(
name,
Expand Down

0 comments on commit 5b27ea4

Please sign in to comment.