Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: "Types in a binary operation should match, but found T and T" #4648

Merged
merged 11 commits into from
Mar 29, 2024
7 changes: 2 additions & 5 deletions compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ pub enum TypeCheckError {
VariableMustBeMutable { name: String, span: Span },
#[error("No method named '{method_name}' found for type '{object_type}'")]
UnresolvedMethodCall { method_name: String, object_type: Type, span: Span },
#[error("Comparisons are invalid on Field types. Try casting the operands to a sized integer type first")]
InvalidComparisonOnField { span: Span },
#[error("Integers must have the same signedness LHS is {sign_x:?}, RHS is {sign_y:?}")]
IntegerSignedness { sign_x: Signedness, sign_y: Signedness, span: Span },
#[error("Integers must have the same bit width LHS is {bit_width_x}, RHS is {bit_width_y}")]
Expand All @@ -76,7 +74,7 @@ pub enum TypeCheckError {
#[error("{kind} cannot be used in a unary operation")]
InvalidUnaryOp { kind: String, span: Span },
#[error("Bitwise operations are invalid on Field types. Try casting the operands to a sized integer type first.")]
InvalidBitwiseOperationOnField { span: Span },
FieldBitwiseOp { span: Span },
#[error("Integer cannot be used with type {typ}")]
IntegerTypeMismatch { typ: Type, span: Span },
#[error("Cannot use an integer and a Field in a binary operation, try converting the Field into an integer first")]
Expand Down Expand Up @@ -224,12 +222,11 @@ impl From<TypeCheckError> for Diagnostic {
| TypeCheckError::TupleIndexOutOfBounds { span, .. }
| TypeCheckError::VariableMustBeMutable { span, .. }
| TypeCheckError::UnresolvedMethodCall { span, .. }
| TypeCheckError::InvalidComparisonOnField { span }
| TypeCheckError::IntegerSignedness { span, .. }
| TypeCheckError::IntegerBitWidth { span, .. }
| TypeCheckError::InvalidInfixOp { span, .. }
| TypeCheckError::InvalidUnaryOp { span, .. }
| TypeCheckError::InvalidBitwiseOperationOnField { span, .. }
| TypeCheckError::FieldBitwiseOp { span, .. }
| TypeCheckError::IntegerTypeMismatch { span, .. }
| TypeCheckError::FieldComparison { span, .. }
| TypeCheckError::AmbiguousBitWidth { span, .. }
Expand Down
85 changes: 46 additions & 39 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,16 @@ impl<'interner> TypeChecker<'interner> {
Ok((typ, use_impl)) => {
if use_impl {
let id = infix_expr.trait_method_id;
// Assume operators have no trait generics
self.verify_trait_constraint(
&lhs_type,
id.trait_id,
&[],
*expr_id,
span,
);

// Delay checking the trait constraint until the end of the function.
// Checking it now could bind an unbound type variable to any type
// that implements the trait.
let constraint = crate::hir_def::traits::TraitConstraint {
typ: lhs_type.clone(),
trait_id: id.trait_id,
trait_generics: Vec::new(),
};
self.trait_constraints.push((constraint, *expr_id));
self.typecheck_operator_method(*expr_id, id, &lhs_type, span);
}
typ
Expand Down Expand Up @@ -836,6 +838,10 @@ impl<'interner> TypeChecker<'interner> {
match (lhs_type, rhs_type) {
// Avoid reporting errors multiple times
(Error, _) | (_, Error) => Ok((Bool, false)),
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.comparator_operand_type_rules(&alias, other, op, span)
}

// Matches on TypeVariable must be first to follow any type
// bindings.
Expand All @@ -844,12 +850,8 @@ impl<'interner> TypeChecker<'interner> {
return self.comparator_operand_type_rules(other, binding, op, span);
}

self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((Bool, false))
}
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.comparator_operand_type_rules(&alias, other, op, span)
let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((Bool, use_impl))
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if sign_x != sign_y {
Expand Down Expand Up @@ -1079,36 +1081,43 @@ impl<'interner> TypeChecker<'interner> {
}
}

/// Handles the TypeVariable case for checking binary operators.
/// Returns true if we should use the impl for the operator instead of the primitive
/// version of it.
fn bind_type_variables_for_infix(
&mut self,
lhs_type: &Type,
op: &HirBinaryOp,
rhs_type: &Type,
span: Span,
) {
) -> bool {
self.unify(lhs_type, rhs_type, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::Binary,
span,
});

// In addition to unifying both types, we also have to bind either
// the lhs or rhs to an integer type variable. This ensures if both lhs
// and rhs are type variables, that they will have the correct integer
// type variable kind instead of TypeVariableKind::Normal.
let target = if op.kind.is_valid_for_field_type() {
Type::polymorphic_integer_or_field(self.interner)
} else {
Type::polymorphic_integer(self.interner)
};
let use_impl = !lhs_type.is_numeric();

// If this operator isn't valid for fields we have to possibly narrow
// TypeVariableKind::IntegerOrField to TypeVariableKind::Integer.
// Doing so also ensures a type error if Field is used.
// The is_numeric check is to allow impls for custom types to bypass this.
if !op.kind.is_valid_for_field_type() && lhs_type.is_numeric() {
let target = Type::polymorphic_integer(self.interner);

use BinaryOpKind::*;
use TypeCheckError::*;
self.unify(lhs_type, &target, || match op.kind {
Less | LessEqual | Greater | GreaterEqual => FieldComparison { span },
And | Or | Xor | ShiftRight | ShiftLeft => FieldBitwiseOp { span },
Modulo => FieldModulo { span },
other => unreachable!("Operator {other:?} should be valid for Field"),
});
}

self.unify(lhs_type, &target, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::Binary,
span,
});
use_impl
}

// Given a binary operator and another type. This method will produce the output type
Expand All @@ -1130,6 +1139,10 @@ impl<'interner> TypeChecker<'interner> {
match (lhs_type, rhs_type) {
// An error type on either side will always return an error
(Error, _) | (_, Error) => Ok((Error, false)),
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.infix_operand_type_rules(&alias, op, other, span)
}

// Matches on TypeVariable must be first so that we follow any type
// bindings.
Expand All @@ -1138,14 +1151,8 @@ impl<'interner> TypeChecker<'interner> {
return self.infix_operand_type_rules(binding, op, other, span);
}

self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);

// Both types are unified so the choice of which to return is arbitrary
Ok((other.clone(), false))
}
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.infix_operand_type_rules(&alias, op, other, span)
let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((other.clone(), use_impl))
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if sign_x != sign_y {
Expand All @@ -1170,7 +1177,7 @@ impl<'interner> TypeChecker<'interner> {
if op.kind == BinaryOpKind::Modulo {
return Err(TypeCheckError::FieldModulo { span });
} else {
return Err(TypeCheckError::InvalidBitwiseOperationOnField { span });
return Err(TypeCheckError::FieldBitwiseOp { span });
}
}
Ok((FieldElement, false))
Expand Down
50 changes: 24 additions & 26 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,13 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type

let function_last_type = type_checker.check_function_body(function_body_id);

// Verify any remaining trait constraints arising from the function body
for (constraint, expr_id) in std::mem::take(&mut type_checker.trait_constraints) {
let span = type_checker.interner.expr_span(&expr_id);
type_checker.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}

errors.append(&mut type_checker.errors);

// Now remove all the `where` clause constraints we added
for constraint in &expected_trait_constraints {
interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id);
}

// Check declared return type and actual return type
if !can_ignore_ret {
let (expr_span, empty_function) = function_info(interner, function_body_id);
let func_span = interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
let (expr_span, empty_function) = function_info(type_checker.interner, function_body_id);
let func_span = type_checker.interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
if let Type::TraitAsType(trait_id, _, generics) = &declared_return_type {
if interner
if type_checker
.interner
.lookup_trait_implementation(&function_last_type, *trait_id, generics)
.is_err()
{
Expand All @@ -126,7 +108,7 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
type_checker.interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
Expand All @@ -137,16 +119,32 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
};

if empty_function {
error = error.add_context(
"implicitly returns `()` as its body has no tail or `return` expression",
);
error = error.add_context("implicitly returns `()` as its body has no tail or `return` expression");
}
error
},
);
}
}

// Verify any remaining trait constraints arising from the function body
for (constraint, expr_id) in std::mem::take(&mut type_checker.trait_constraints) {
let span = type_checker.interner.expr_span(&expr_id);
type_checker.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}

// Now remove all the `where` clause constraints we added
for constraint in &expected_trait_constraints {
type_checker.interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id);
}

errors.append(&mut type_checker.errors);
errors
}

Expand Down
10 changes: 10 additions & 0 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@
TypeBinding::Bound(binding) => binding.is_bindable(),
TypeBinding::Unbound(_) => true,
},
Type::Alias(alias, args) => alias.borrow().get_type(args).is_bindable(),
_ => false,
}
}
Expand All @@ -605,6 +606,15 @@
matches!(self.follow_bindings(), Type::Integer(Signedness::Unsigned, _))
}

pub fn is_numeric(&self) -> bool {
use Type::*;
use TypeVariableKind as K;
matches!(
self.follow_bindings(),
FieldElement | Integer(..) | Bool | TypeVariable(_, K::Integer | K::IntegerOrField)
)
}

fn contains_numeric_typevar(&self, target_id: TypeVariableId) -> bool {
// True if the given type is a NamedGeneric with the target_id
let named_generic_id_matches_target = |typ: &Type| {
Expand Down Expand Up @@ -1510,7 +1520,7 @@
Type::Tuple(fields)
}
Type::Forall(typevars, typ) => {
// Trying to substitute_helper a variable de, substitute_bound_typevarsfined within a nested Forall

Check warning on line 1523 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (typevarsfined)
// is usually impossible and indicative of an error in the type checker somewhere.
for var in typevars {
assert!(!type_bindings.contains_key(&var.id()));
Expand Down
12 changes: 6 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1033,19 +1033,19 @@ mod test {
fn resolve_complex_closures() {
let src = r#"
fn main(x: Field) -> pub Field {
let closure_without_captures = |x| x + x;
let closure_without_captures = |x: Field| -> Field { x + x };
let a = closure_without_captures(1);

let closure_capturing_a_param = |y| y + x;
let closure_capturing_a_param = |y: Field| -> Field { y + x };
let b = closure_capturing_a_param(2);

let closure_capturing_a_local_var = |y| y + b;
let closure_capturing_a_local_var = |y: Field| -> Field { y + b };
let c = closure_capturing_a_local_var(3);

let closure_with_transitive_captures = |y| {
let closure_with_transitive_captures = |y: Field| -> Field {
let d = 5;
let nested_closure = |z| {
let doubly_nested_closure = |w| w + x + b;
let nested_closure = |z: Field| -> Field {
let doubly_nested_closure = |w: Field| -> Field { w + x + b };
a + z + y + d + x + doubly_nested_closure(4) + x + y
};
let res = nested_closure(5);
Expand Down
14 changes: 7 additions & 7 deletions noir_stdlib/src/cmp.nr
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ trait Eq {

impl Eq for Field { fn eq(self, other: Field) -> bool { self == other } }

impl Eq for u1 { fn eq(self, other: u1) -> bool { self == other } }
impl Eq for u8 { fn eq(self, other: u8) -> bool { self == other } }
impl Eq for u32 { fn eq(self, other: u32) -> bool { self == other } }
impl Eq for u64 { fn eq(self, other: u64) -> bool { self == other } }
impl Eq for u32 { fn eq(self, other: u32) -> bool { self == other } }
impl Eq for u8 { fn eq(self, other: u8) -> bool { self == other } }
impl Eq for u1 { fn eq(self, other: u1) -> bool { self == other } }

impl Eq for i8 { fn eq(self, other: i8) -> bool { self == other } }
impl Eq for i32 { fn eq(self, other: i32) -> bool { self == other } }
Expand Down Expand Up @@ -107,8 +107,8 @@ trait Ord {

// Note: Field deliberately does not implement Ord

impl Ord for u8 {
fn cmp(self, other: u8) -> Ordering {
impl Ord for u64 {
fn cmp(self, other: u64) -> Ordering {
if self < other {
Ordering::less()
} else if self > other {
Expand All @@ -131,8 +131,8 @@ impl Ord for u32 {
}
}

impl Ord for u64 {
fn cmp(self, other: u64) -> Ordering {
impl Ord for u8 {
fn cmp(self, other: u8) -> Ordering {
if self < other {
Ordering::less()
} else if self > other {
Expand Down
Loading
Loading