Skip to content

Commit

Permalink
feat: Add TraitConstraint type (#5499)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #5480
Resolves #5481

## Summary\*

Adds:
- The `TraitConstraint` type
- `impl Eq for TraitConstraint`
- `impl Hash for TraitConstraint`
- `Quoted::as_trait_constraint`

## Additional Context

Ran into the type error when calling trait impls issue again while
working on this. Hence why it is a draft.

## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [x] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Jul 17, 2024
1 parent 6a7f593 commit 30cb65a
Show file tree
Hide file tree
Showing 23 changed files with 323 additions and 120 deletions.
1 change: 1 addition & 0 deletions aztec_macros/src/transforms/note_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub fn generate_note_interface_impl(module: &mut SortedModule) -> Result<(), Azt
generics: vec![],
methods: vec![],
where_clause: vec![],
is_comptime: false,
};
module.impls.push(default_impl.clone());
module.impls.last_mut().unwrap()
Expand Down
1 change: 1 addition & 0 deletions aztec_macros/src/transforms/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ pub fn generate_storage_implementation(
methods: vec![(init, Span::default())],

where_clause: vec![],
is_comptime: false,
};
module.impls.push(storage_impl);

Expand Down
13 changes: 1 addition & 12 deletions compiler/noirc_frontend/src/ast/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,7 @@ pub struct NoirStruct {
pub generics: UnresolvedGenerics,
pub fields: Vec<(Ident, UnresolvedType)>,
pub span: Span,
}

impl NoirStruct {
pub fn new(
name: Ident,
attributes: Vec<SecondaryAttribute>,
generics: UnresolvedGenerics,
fields: Vec<(Ident, UnresolvedType)>,
span: Span,
) -> NoirStruct {
NoirStruct { name, attributes, generics, fields, span }
}
pub is_comptime: bool,
}

impl Display for NoirStruct {
Expand Down
5 changes: 4 additions & 1 deletion compiler/noirc_frontend/src/ast/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub struct TypeImpl {
pub generics: UnresolvedGenerics,
pub where_clause: Vec<UnresolvedTraitConstraint>,
pub methods: Vec<(NoirFunction, Span)>,
pub is_comptime: bool,
}

/// Ast node for an implementation of a trait for a particular type
Expand All @@ -69,6 +70,8 @@ pub struct NoirTraitImpl {
pub where_clause: Vec<UnresolvedTraitConstraint>,

pub items: Vec<TraitImplItem>,

pub is_comptime: bool,
}

/// Represents a simple trait constraint such as `where Foo: TraitY<U, V>`
Expand All @@ -84,7 +87,7 @@ pub struct UnresolvedTraitConstraint {
}

/// Represents a single trait bound, such as `TraitX` or `TraitY<U, V>`
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TraitBound {
pub trait_path: Path,
pub trait_id: Option<TraitId>, // initially None, gets assigned during DC
Expand Down
152 changes: 85 additions & 67 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1636,17 +1636,25 @@ impl<'context> Elaborator<'context> {
function_sets.push(UnresolvedFunctions { functions, file_id, trait_id, self_type });
}

let (comptime_trait_impls, trait_impls) =
items.trait_impls.into_iter().partition(|trait_impl| trait_impl.is_comptime);

let (comptime_structs, structs) =
items.types.into_iter().partition(|typ| typ.1.struct_def.is_comptime);

let comptime = CollectedItems {
functions: comptime_function_sets,
types: BTreeMap::new(),
types: comptime_structs,
type_aliases: BTreeMap::new(),
traits: BTreeMap::new(),
trait_impls: Vec::new(),
trait_impls: comptime_trait_impls,
globals: Vec::new(),
impls: rustc_hash::FxHashMap::default(),
};

items.functions = function_sets;
items.trait_impls = trait_impls;
items.types = structs;
(comptime, items)
}

Expand All @@ -1657,75 +1665,85 @@ impl<'context> Elaborator<'context> {
location: Location,
) {
for item in items {
match item {
TopLevelStatement::Function(function) => {
let id = self.interner.push_empty_fn();
let module = self.module_id();
self.interner.push_function(id, &function.def, module, location);
let functions = vec![(self.local_module, id, function)];
generated_items.functions.push(UnresolvedFunctions {
file_id: self.file,
functions,
trait_id: None,
self_type: None,
});
}
TopLevelStatement::TraitImpl(mut trait_impl) => {
let methods = dc_mod::collect_trait_impl_functions(
self.interner,
&mut trait_impl,
self.crate_id,
self.file,
self.local_module,
);
self.add_item(item, generated_items, location);
}
}

generated_items.trait_impls.push(UnresolvedTraitImpl {
file_id: self.file,
module_id: self.local_module,
trait_generics: trait_impl.trait_generics,
trait_path: trait_impl.trait_name,
object_type: trait_impl.object_type,
methods,
generics: trait_impl.impl_generics,
where_clause: trait_impl.where_clause,

// These last fields are filled in later
trait_id: None,
impl_id: None,
resolved_object_type: None,
resolved_generics: Vec::new(),
resolved_trait_generics: Vec::new(),
});
}
TopLevelStatement::Global(global) => {
let (global, error) = dc_mod::collect_global(
self.interner,
self.def_maps.get_mut(&self.crate_id).unwrap(),
global,
self.file,
self.local_module,
);
fn add_item(
&mut self,
item: TopLevelStatement,
generated_items: &mut CollectedItems,
location: Location,
) {
match item {
TopLevelStatement::Function(function) => {
let id = self.interner.push_empty_fn();
let module = self.module_id();
self.interner.push_function(id, &function.def, module, location);
let functions = vec![(self.local_module, id, function)];
generated_items.functions.push(UnresolvedFunctions {
file_id: self.file,
functions,
trait_id: None,
self_type: None,
});
}
TopLevelStatement::TraitImpl(mut trait_impl) => {
let methods = dc_mod::collect_trait_impl_functions(
self.interner,
&mut trait_impl,
self.crate_id,
self.file,
self.local_module,
);

generated_items.globals.push(global);
if let Some(error) = error {
self.errors.push(error);
}
}
// Assume that an error has already been issued
TopLevelStatement::Error => (),

TopLevelStatement::Module(_)
| TopLevelStatement::Import(_)
| TopLevelStatement::Struct(_)
| TopLevelStatement::Trait(_)
| TopLevelStatement::Impl(_)
| TopLevelStatement::TypeAlias(_)
| TopLevelStatement::SubModule(_) => {
let item = item.to_string();
let error = InterpreterError::UnsupportedTopLevelItemUnquote { item, location };
self.errors.push(error.into_compilation_error_pair());
generated_items.trait_impls.push(UnresolvedTraitImpl {
file_id: self.file,
module_id: self.local_module,
trait_generics: trait_impl.trait_generics,
trait_path: trait_impl.trait_name,
object_type: trait_impl.object_type,
methods,
generics: trait_impl.impl_generics,
where_clause: trait_impl.where_clause,
is_comptime: trait_impl.is_comptime,

// These last fields are filled in later
trait_id: None,
impl_id: None,
resolved_object_type: None,
resolved_generics: Vec::new(),
resolved_trait_generics: Vec::new(),
});
}
TopLevelStatement::Global(global) => {
let (global, error) = dc_mod::collect_global(
self.interner,
self.def_maps.get_mut(&self.crate_id).unwrap(),
global,
self.file,
self.local_module,
);

generated_items.globals.push(global);
if let Some(error) = error {
self.errors.push(error);
}
}
// Assume that an error has already been issued
TopLevelStatement::Error => (),

TopLevelStatement::Module(_)
| TopLevelStatement::Import(_)
| TopLevelStatement::Struct(_)
| TopLevelStatement::Trait(_)
| TopLevelStatement::Impl(_)
| TopLevelStatement::TypeAlias(_)
| TopLevelStatement::SubModule(_) => {
let item = item.to_string();
let error = InterpreterError::UnsupportedTopLevelItemUnquote { item, location };
self.errors.push(error.into_compilation_error_pair());
}
}
}

Expand Down
93 changes: 86 additions & 7 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use std::rc::Rc;
use std::{
hash::{Hash, Hasher},
rc::Rc,
};

use chumsky::Parser;
use noirc_errors::Location;

use crate::{
ast::IntegerBitSize,
ast::{IntegerBitSize, TraitBound},
hir::comptime::{errors::IResult, InterpreterError, Value},
macros_api::{NodeInterner, Signedness},
parser,
token::{SpannedToken, Token, Tokens},
QuotedType, Type,
};
Expand All @@ -29,6 +34,9 @@ pub(super) fn call_builtin(
"struct_def_as_type" => struct_def_as_type(interner, arguments, location),
"struct_def_fields" => struct_def_fields(interner, arguments, location),
"struct_def_generics" => struct_def_generics(interner, arguments, location),
"trait_constraint_eq" => trait_constraint_eq(interner, arguments, location),
"trait_constraint_hash" => trait_constraint_hash(interner, arguments, location),
"quoted_as_trait_constraint" => quoted_as_trait_constraint(interner, arguments, location),
_ => {
let item = format!("Comptime evaluation for builtin function {name}");
Err(InterpreterError::Unimplemented { item, location })
Expand Down Expand Up @@ -79,6 +87,26 @@ fn get_u32(value: Value, location: Location) -> IResult<u32> {
}
}

fn get_trait_constraint(value: Value, location: Location) -> IResult<TraitBound> {
match value {
Value::TraitConstraint(bound) => Ok(bound),
value => {
let expected = Type::Quoted(QuotedType::TraitConstraint);
Err(InterpreterError::TypeMismatch { expected, value, location })
}
}
}

fn get_quoted(value: Value, location: Location) -> IResult<Rc<Tokens>> {
match value {
Value::Code(tokens) => Ok(tokens),
value => {
let expected = Type::Quoted(QuotedType::Quoted);
Err(InterpreterError::TypeMismatch { expected, value, location })
}
}
}

fn array_len(
interner: &NodeInterner,
mut arguments: Vec<(Value, Location)>,
Expand Down Expand Up @@ -231,7 +259,7 @@ fn slice_remove(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(2, &arguments, location)?;

let index = get_u32(arguments.pop().unwrap().0, location)? as usize;
Expand All @@ -257,7 +285,7 @@ fn slice_push_front(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(2, &arguments, location)?;

let (element, _) = arguments.pop().unwrap();
Expand All @@ -270,7 +298,7 @@ fn slice_pop_front(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(1, &arguments, location)?;

let (mut values, typ) = get_slice(interner, arguments.pop().unwrap().0, location)?;
Expand All @@ -284,7 +312,7 @@ fn slice_pop_back(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(1, &arguments, location)?;

let (mut values, typ) = get_slice(interner, arguments.pop().unwrap().0, location)?;
Expand All @@ -298,7 +326,7 @@ fn slice_insert(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(3, &arguments, location)?;

let (element, _) = arguments.pop().unwrap();
Expand All @@ -307,3 +335,54 @@ fn slice_insert(
values.insert(index as usize, element);
Ok(Value::Slice(values, typ))
}

// fn as_trait_constraint(quoted: Quoted) -> TraitConstraint
fn quoted_as_trait_constraint(
_interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
check_argument_count(1, &arguments, location)?;

let tokens = get_quoted(arguments.pop().unwrap().0, location)?;
let quoted = tokens.as_ref().clone();

let trait_bound = parser::trait_bound().parse(quoted).map_err(|mut errors| {
let error = errors.swap_remove(0);
let rule = "a trait constraint";
InterpreterError::FailedToParseMacro { error, tokens, rule, file: location.file }
})?;

Ok(Value::TraitConstraint(trait_bound))
}

// fn constraint_hash(constraint: TraitConstraint) -> Field
fn trait_constraint_hash(
_interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
check_argument_count(1, &arguments, location)?;

let bound = get_trait_constraint(arguments.pop().unwrap().0, location)?;

let mut hasher = std::collections::hash_map::DefaultHasher::new();
bound.hash(&mut hasher);
let hash = hasher.finish();

Ok(Value::Field((hash as u128).into()))
}

// fn constraint_eq(constraint_a: TraitConstraint, constraint_b: TraitConstraint) -> bool
fn trait_constraint_eq(
_interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
check_argument_count(2, &arguments, location)?;

let constraint_b = get_trait_constraint(arguments.pop().unwrap().0, location)?;
let constraint_a = get_trait_constraint(arguments.pop().unwrap().0, location)?;

Ok(Value::Bool(constraint_a == constraint_b))
}
Loading

0 comments on commit 30cb65a

Please sign in to comment.