Skip to content

Commit

Permalink
feat: prefix operator overload trait dispatch (#5423)
Browse files Browse the repository at this point in the history
# Description

## Problem

Resolves #5000

## Summary

Allows `impl Not` and `impl Neg` to be picked up at runtime.

Note that this doesn't work yet at comptime: I'd prefer to do it in a
separate PR (mainly to keep PR smaller, but also because I'm not sure
how to test it).

## Additional Context

None.

## Documentation

Is documentation needed for this?

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: jfecher <jake@aztecprotocol.com>
  • Loading branch information
asterite and jfecher authored Jul 10, 2024
1 parent 0b74a18 commit a3bb09e
Show file tree
Hide file tree
Showing 19 changed files with 303 additions and 106 deletions.
51 changes: 39 additions & 12 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::{
HirStatement, Ident, IndexExpression, Literal, MemberAccessExpression,
MethodCallExpression, PrefixExpression,
},
node_interner::{DefinitionKind, ExprId, FuncId, ReferenceId},
node_interner::{DefinitionKind, ExprId, FuncId, ReferenceId, TraitMethodId},
token::Tokens,
QuotedType, Shared, StructType, Type,
};
Expand All @@ -39,7 +39,7 @@ impl<'context> Elaborator<'context> {
let (hir_expr, typ) = match expr.kind {
ExpressionKind::Literal(literal) => self.elaborate_literal(literal, expr.span),
ExpressionKind::Block(block) => self.elaborate_block(block),
ExpressionKind::Prefix(prefix) => self.elaborate_prefix(*prefix),
ExpressionKind::Prefix(prefix) => return self.elaborate_prefix(*prefix),
ExpressionKind::Index(index) => self.elaborate_index(*index),
ExpressionKind::Call(call) => self.elaborate_call(*call, expr.span),
ExpressionKind::MethodCall(call) => self.elaborate_method_call(*call, expr.span),
Expand Down Expand Up @@ -227,11 +227,22 @@ impl<'context> Elaborator<'context> {
(HirExpression::Literal(HirLiteral::FmtStr(str, fmt_str_idents)), typ)
}

fn elaborate_prefix(&mut self, prefix: PrefixExpression) -> (HirExpression, Type) {
fn elaborate_prefix(&mut self, prefix: PrefixExpression) -> (ExprId, Type) {
let span = prefix.rhs.span;
let (rhs, rhs_type) = self.elaborate_expression(prefix.rhs);
let ret_type = self.type_check_prefix_operand(&prefix.operator, &rhs_type, span);
(HirExpression::Prefix(HirPrefixExpression { operator: prefix.operator, rhs }), ret_type)
let trait_id = self.interner.get_prefix_operator_trait_method(&prefix.operator);

let operator = prefix.operator;
let expr =
HirExpression::Prefix(HirPrefixExpression { operator, rhs, trait_method_id: trait_id });
let expr_id = self.interner.push_expr(expr);
self.interner.push_expr_location(expr_id, span, self.file);

let result = self.prefix_operand_type_rules(&operator, &rhs_type, span);
let typ = self.handle_operand_type_rules_result(result, &rhs_type, trait_id, expr_id, span);

self.interner.push_expr_type(expr_id, typ.clone());
(expr_id, typ)
}

fn elaborate_index(&mut self, index_expr: IndexExpression) -> (HirExpression, Type) {
Expand Down Expand Up @@ -541,30 +552,46 @@ impl<'context> Elaborator<'context> {
let expr_id = self.interner.push_expr(expr);
self.interner.push_expr_location(expr_id, span, self.file);

let typ = match self.infix_operand_type_rules(&lhs_type, &operator, &rhs_type, span) {
let result = self.infix_operand_type_rules(&lhs_type, &operator, &rhs_type, span);
let typ =
self.handle_operand_type_rules_result(result, &lhs_type, Some(trait_id), expr_id, span);

self.interner.push_expr_type(expr_id, typ.clone());
(expr_id, typ)
}

fn handle_operand_type_rules_result(
&mut self,
result: Result<(Type, bool), TypeCheckError>,
operand_type: &Type,
trait_id: Option<TraitMethodId>,
expr_id: ExprId,
span: Span,
) -> Type {
match result {
Ok((typ, use_impl)) => {
if use_impl {
let trait_id =
trait_id.expect("ice: expected some trait_id when use_impl is true");

// Delay checking the trait constraint until the end of the function.
// Checking it now could bind an unbound type variable to any type
// that implements the trait.
let constraint = TraitConstraint {
typ: lhs_type.clone(),
typ: operand_type.clone(),
trait_id: trait_id.trait_id,
trait_generics: Vec::new(),
};
self.push_trait_constraint(constraint, expr_id);
self.type_check_operator_method(expr_id, trait_id, &lhs_type, span);
self.type_check_operator_method(expr_id, trait_id, operand_type, span);
}
typ
}
Err(error) => {
self.push_err(error);
Type::Error
}
};

self.interner.push_expr_type(expr_id, typ.clone());
(expr_id, typ)
}
}

fn elaborate_if(&mut self, if_expr: IfExpression) -> (HirExpression, Type) {
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ impl<'context> Elaborator<'context> {
// the interner may set `interner.ordering_type` based on the result type
// of the Cmp trait, if this is it.
if self.crate_id.is_stdlib() {
self.interner.try_add_operator_trait(trait_id);
self.interner.try_add_infix_operator_trait(trait_id);
self.interner.try_add_prefix_operator_trait(trait_id);
}
}
}
Expand Down
132 changes: 81 additions & 51 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,57 +670,6 @@ impl<'context> Elaborator<'context> {
}
}

pub(super) fn type_check_prefix_operand(
&mut self,
op: &crate::ast::UnaryOp,
rhs_type: &Type,
span: Span,
) -> Type {
let unify = |this: &mut Self, expected| {
this.unify(rhs_type, &expected, || TypeCheckError::TypeMismatch {
expr_typ: rhs_type.to_string(),
expected_typ: expected.to_string(),
expr_span: span,
});
expected
};

match op {
crate::ast::UnaryOp::Minus => {
if rhs_type.is_unsigned() {
self.push_err(TypeCheckError::InvalidUnaryOp {
kind: rhs_type.to_string(),
span,
});
}
let expected = self.polymorphic_integer_or_field();
self.unify(rhs_type, &expected, || TypeCheckError::InvalidUnaryOp {
kind: rhs_type.to_string(),
span,
});
expected
}
crate::ast::UnaryOp::Not => {
let rhs_type = rhs_type.follow_bindings();

// `!` can work on booleans or integers
if matches!(rhs_type, Type::Integer(..)) {
return rhs_type;
}

unify(self, Type::Bool)
}
crate::ast::UnaryOp::MutableReference => {
Type::MutableReference(Box::new(rhs_type.follow_bindings()))
}
crate::ast::UnaryOp::Dereference { implicitly_added: _ } => {
let element_type = self.interner.next_type_variable();
unify(self, Type::MutableReference(Box::new(element_type.clone())));
element_type
}
}
}

/// Insert as many dereference operations as necessary to automatically dereference a method
/// call object to its base value type T.
pub(super) fn insert_auto_dereferences(&mut self, object: ExprId, typ: Type) -> (ExprId, Type) {
Expand All @@ -730,6 +679,7 @@ impl<'context> Elaborator<'context> {
let object = self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression {
operator: UnaryOp::Dereference { implicitly_added: true },
rhs: object,
trait_method_id: None,
}));
self.interner.push_expr_type(object, element.as_ref().clone());
self.interner.push_expr_location(object, location.span, location.file);
Expand Down Expand Up @@ -1073,6 +1023,84 @@ impl<'context> Elaborator<'context> {
}
}

// Given a unary operator and a type, this method will produce the output type
// and a boolean indicating whether to use the trait impl corresponding to the operator
// or not. A value of false indicates the caller to use a primitive operation for this
// operator, while a true value indicates a user-provided trait impl is required.
pub(super) fn prefix_operand_type_rules(
&mut self,
op: &UnaryOp,
rhs_type: &Type,
span: Span,
) -> Result<(Type, bool), TypeCheckError> {
use Type::*;

match op {
crate::ast::UnaryOp::Minus | crate::ast::UnaryOp::Not => {
match rhs_type {
// An error type will always return an error
Error => Ok((Error, false)),
Alias(alias, args) => {
let alias = alias.borrow().get_type(args);
self.prefix_operand_type_rules(op, &alias, span)
}

// Matches on TypeVariable must be first so that we follow any type
// bindings.
TypeVariable(int, _) => {
if let TypeBinding::Bound(binding) = &*int.borrow() {
return self.prefix_operand_type_rules(op, binding, span);
}

// The `!` prefix operator is not valid for Field, so if this is a numeric
// type we constrain it to just (non-Field) integer types.
if matches!(op, crate::ast::UnaryOp::Not) && rhs_type.is_numeric() {
let integer_type = Type::polymorphic_integer(self.interner);
self.unify(rhs_type, &integer_type, || {
TypeCheckError::InvalidUnaryOp { kind: rhs_type.to_string(), span }
});
}

Ok((rhs_type.clone(), !rhs_type.is_numeric()))
}
Integer(sign_x, bit_width_x) => {
if *op == UnaryOp::Minus && *sign_x == Signedness::Unsigned {
return Err(TypeCheckError::InvalidUnaryOp {
kind: rhs_type.to_string(),
span,
});
}
Ok((Integer(*sign_x, *bit_width_x), false))
}
// The result of a Field is always a witness
FieldElement => {
if *op == UnaryOp::Not {
return Err(TypeCheckError::FieldNot { span });
}
Ok((FieldElement, false))
}

Bool => Ok((Bool, false)),

_ => Ok((rhs_type.clone(), true)),
}
}
crate::ast::UnaryOp::MutableReference => {
Ok((Type::MutableReference(Box::new(rhs_type.follow_bindings())), false))
}
crate::ast::UnaryOp::Dereference { implicitly_added: _ } => {
let element_type = self.interner.next_type_variable();
let expected = Type::MutableReference(Box::new(element_type.clone()));
self.unify(rhs_type, &expected, || TypeCheckError::TypeMismatch {
expr_typ: rhs_type.to_string(),
expected_typ: expected.to_string(),
expr_span: span,
});
Ok((element_type, false))
}
}
}

/// Prerequisite: verify_trait_constraint of the operator's trait constraint.
///
/// Although by this point the operator is expected to already have a trait impl,
Expand Down Expand Up @@ -1140,6 +1168,7 @@ impl<'context> Elaborator<'context> {
*access_lhs = this.interner.push_expr(HirExpression::Prefix(HirPrefixExpression {
operator: crate::ast::UnaryOp::Dereference { implicitly_added: true },
rhs: old_lhs,
trait_method_id: None,
}));
this.interner.push_expr_type(old_lhs, lhs_type);
this.interner.push_expr_type(*access_lhs, element);
Expand Down Expand Up @@ -1362,6 +1391,7 @@ impl<'context> Elaborator<'context> {
self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression {
operator: UnaryOp::MutableReference,
rhs: *object,
trait_method_id: None,
}));
self.interner.push_expr_type(new_object, new_type);
self.interner.push_expr_location(new_object, location.span, location.file);
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1548,14 +1548,15 @@ impl<'a> Resolver<'a> {
ExpressionKind::Prefix(prefix) => {
let operator = prefix.operator;
let rhs = self.resolve_expression(prefix.rhs);
let trait_method_id = self.interner.get_prefix_operator_trait_method(&operator);

if operator == UnaryOp::MutableReference {
if let Err(error) = verify_mutable_reference(self.interner, rhs) {
self.errors.push(error);
}
}

HirExpression::Prefix(HirPrefixExpression { operator, rhs })
HirExpression::Prefix(HirPrefixExpression { operator, rhs, trait_method_id })
}
ExpressionKind::Infix(infix) => {
let lhs = self.resolve_expression(infix.lhs);
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/hir/resolution/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ pub(crate) fn resolve_traits(
// the interner may set `interner.ordering_type` based on the result type
// of the Cmp trait, if this is it.
if crate_id.is_stdlib() {
context.def_interner.try_add_operator_trait(trait_id);
context.def_interner.try_add_infix_operator_trait(trait_id);
context.def_interner.try_add_prefix_operator_trait(trait_id);
}
}
all_errors
Expand Down
3 changes: 3 additions & 0 deletions compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ pub enum TypeCheckError {
IntegerAndFieldBinaryOperation { span: Span },
#[error("Cannot do modulo on Fields, try casting to an integer first")]
FieldModulo { span: Span },
#[error("Cannot do not (`!`) on Fields, try casting to an integer first")]
FieldNot { span: Span },
#[error("Fields cannot be compared, try casting to an integer first")]
FieldComparison { span: Span },
#[error("The bit count in a bit-shift operation must fit in a u8, try casting the right hand side into a u8 first")]
Expand Down Expand Up @@ -256,6 +258,7 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic {
| TypeCheckError::IntegerAndFieldBinaryOperation { span }
| TypeCheckError::OverflowingAssignment { span, .. }
| TypeCheckError::FieldModulo { span }
| TypeCheckError::FieldNot { span }
| TypeCheckError::ConstrainedReferenceToUnconstrained { span }
| TypeCheckError::UnconstrainedReferenceToConstrained { span }
| TypeCheckError::UnconstrainedSliceReturnToConstrained { span }
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ impl<'interner> TypeChecker<'interner> {
self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression {
operator: UnaryOp::MutableReference,
rhs: method_call.object,
trait_method_id: None,
}));
self.interner.push_expr_type(new_object, new_type);
self.interner.push_expr_location(new_object, location.span, location.file);
Expand All @@ -604,6 +605,7 @@ impl<'interner> TypeChecker<'interner> {
let object = self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression {
operator: UnaryOp::Dereference { implicitly_added: true },
rhs: object,
trait_method_id: None,
}));
self.interner.push_expr_type(object, element.as_ref().clone());
self.interner.push_expr_location(object, location.span, location.file);
Expand Down Expand Up @@ -799,9 +801,11 @@ impl<'interner> TypeChecker<'interner> {

let dereference_lhs = |this: &mut Self, lhs_type, element| {
let old_lhs = *access_lhs;

*access_lhs = this.interner.push_expr(HirExpression::Prefix(HirPrefixExpression {
operator: crate::ast::UnaryOp::Dereference { implicitly_added: true },
rhs: old_lhs,
trait_method_id: None,
}));
this.interner.push_expr_type(old_lhs, lhs_type);
this.interner.push_expr_type(*access_lhs, element);
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ pub enum HirArrayLiteral {
pub struct HirPrefixExpression {
pub operator: UnaryOp,
pub rhs: ExprId,

/// The trait method id for the operator trait method that corresponds to this operator,
/// if such a trait exists (for example, there's no trait for the dereference operator).
pub trait_method_id: Option<TraitMethodId>,
}

#[derive(Debug, Clone)]
Expand Down
Loading

0 comments on commit a3bb09e

Please sign in to comment.