Skip to content

Commit

Permalink
feat: overload == to call Eq trait implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench committed Dec 21, 2023
1 parent b8c7078 commit 1ffe9f3
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 98 deletions.
93 changes: 0 additions & 93 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,6 @@ impl<'a> FunctionContext<'a> {
};
self.builder.insert_shift_right(lhs, rhs, bit_size)
}
BinaryOpKind::Equal | BinaryOpKind::NotEqual
if matches!(result_type, Type::Array(..)) =>
{
return self.insert_array_equality(lhs, operator, rhs, location)
}
_ => {
let op = convert_operator(operator);
if operator_requires_swapped_operands(operator) {
Expand Down Expand Up @@ -551,94 +546,6 @@ impl<'a> FunctionContext<'a> {
result.into()
}

/// The frontend claims to support equality (==) on arrays, so we must support it in SSA here.
/// The actual BinaryOp::Eq in SSA is meant only for primitive numeric types so we encode an
/// entire equality loop on each array element. The generated IR is as follows:
///
/// ...
/// result_alloc = allocate
/// store u1 1 in result_alloc
/// jmp loop_start(0)
/// loop_start(i: Field):
/// v0 = lt i, array_len
/// jmpif v0, then: loop_body, else: loop_end
/// loop_body():
/// v1 = array_get lhs, index i
/// v2 = array_get rhs, index i
/// v3 = eq v1, v2
/// v4 = load result_alloc
/// v5 = and v4, v3
/// store v5 in result_alloc
/// v6 = add i, Field 1
/// jmp loop_start(v6)
/// loop_end():
/// result = load result_alloc
fn insert_array_equality(
&mut self,
lhs: ValueId,
operator: noirc_frontend::BinaryOpKind,
rhs: ValueId,
location: Location,
) -> Values {
let lhs_type = self.builder.type_of_value(lhs);
let rhs_type = self.builder.type_of_value(rhs);

let (array_length, element_type) = match (lhs_type, rhs_type) {
(
Type::Array(lhs_composite_type, lhs_length),
Type::Array(rhs_composite_type, rhs_length),
) => {
assert!(
lhs_composite_type.len() == 1 && rhs_composite_type.len() == 1,
"== is unimplemented for arrays of structs"
);
assert_eq!(lhs_composite_type[0], rhs_composite_type[0]);
assert_eq!(lhs_length, rhs_length, "Expected two arrays of equal length");
(lhs_length, lhs_composite_type[0].clone())
}
_ => unreachable!("Expected two array values"),
};

let loop_start = self.builder.insert_block();
let loop_body = self.builder.insert_block();
let loop_end = self.builder.insert_block();

// pre-loop
let result_alloc = self.builder.set_location(location).insert_allocate(Type::bool());
let true_value = self.builder.numeric_constant(1u128, Type::bool());
self.builder.insert_store(result_alloc, true_value);
let zero = self.builder.field_constant(0u128);
self.builder.terminate_with_jmp(loop_start, vec![zero]);

// loop_start
self.builder.switch_to_block(loop_start);
let i = self.builder.add_block_parameter(loop_start, Type::field());
let array_length = self.builder.field_constant(array_length as u128);
let v0 = self.builder.insert_binary(i, BinaryOp::Lt, array_length);
self.builder.terminate_with_jmpif(v0, loop_body, loop_end);

// loop body
self.builder.switch_to_block(loop_body);
let v1 = self.builder.insert_array_get(lhs, i, element_type.clone());
let v2 = self.builder.insert_array_get(rhs, i, element_type);
let v3 = self.builder.insert_binary(v1, BinaryOp::Eq, v2);
let v4 = self.builder.insert_load(result_alloc, Type::bool());
let v5 = self.builder.insert_binary(v4, BinaryOp::And, v3);
self.builder.insert_store(result_alloc, v5);
let one = self.builder.field_constant(1u128);
let v6 = self.builder.insert_binary(i, BinaryOp::Add, one);
self.builder.terminate_with_jmp(loop_start, vec![v6]);

// loop end
self.builder.switch_to_block(loop_end);
let mut result = self.builder.insert_load(result_alloc, Type::bool());

if operator_requires_not(operator) {
result = self.builder.insert_not(result);
}
result.into()
}

/// Inserts a call instruction at the end of the current block and returns the results
/// of the call.
///
Expand Down
24 changes: 19 additions & 5 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,25 @@ impl<'interner> TypeChecker<'interner> {
let rhs_span = self.interner.expr_span(&infix_expr.rhs);
let span = lhs_span.merge(rhs_span);

self.infix_operand_type_rules(&lhs_type, &infix_expr.operator, &rhs_type, span)
.unwrap_or_else(|error| {
self.errors.push(error);
Type::Error
})
if matches!(lhs_type, Type::Array(_, _) | Type::Struct(_, _)) {
// Replace with call to type's implementation of the `Eq` trait
let method_call = HirExpression::MethodCall(HirMethodCallExpression {
method: "eq".into(),
object: infix_expr.lhs,
arguments: vec![infix_expr.rhs],
location: self.interner.expr_location(expr_id),
});
self.interner.replace_expr(expr_id, method_call);

// We want to convert this method call into a pure function call so need to check again.
self.check_expression(expr_id)
} else {
self.infix_operand_type_rules(&lhs_type, &infix_expr.operator, &rhs_type, span)
.unwrap_or_else(|error| {
self.errors.push(error);
Type::Error
})
}
}
HirExpression::Index(index_expr) => self.check_index_expression(expr_id, index_expr),
HirExpression::Call(call_expr) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
name = "trait_equality_overload"
type = "bin"
authors = [""]
[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[x]
inner = "42"
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use dep::std::ops::Eq;

struct FieldWrapper {
inner: Field
}

impl Eq for FieldWrapper {
fn eq(self, other: FieldWrapper) -> bool {
self.inner == other.inner
}
}

fn main(x: FieldWrapper) {
let y = FieldWrapper { inner: 42 };
assert(x == y);
}

0 comments on commit 1ffe9f3

Please sign in to comment.