Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa refactor): Implement array equality in SSA-gen #1704

Merged
merged 2 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ impl<'function> PerFunctionContext<'function> {
_ => {
self.context.failed_to_inline_a_call = true;
None
},
}
}
}

Expand Down
91 changes: 91 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ impl<'a> FunctionContext<'a> {
) -> Values {
let op = convert_operator(operator);

if op == BinaryOp::Eq && matches!(self.builder.type_of_value(lhs), Type::Array(..)) {
return self.insert_array_equality(lhs, operator, rhs);
}

if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}
Expand Down Expand Up @@ -255,6 +259,93 @@ impl<'a> FunctionContext<'a> {
result.into()
}

/// The frontend claims to support equality (==) on arrays, so we must support it in SSA here.
/// The actual BinaryOp::Eq in SSA is meant only for primitive numeric types so we encode an
/// entire equality loop on each array element. The generated IR is as follows:
///
/// ...
/// result_alloc = allocate
/// store u1 1 in result_alloc
/// jmp loop_start(0)
/// loop_start(i: Field):
/// v0 = lt i, array_len
/// jmpif v0, then: loop_body, else: loop_end
/// loop_body():
/// v1 = array_get lhs, index i
/// v2 = array_get rhs, index i
/// v3 = eq v1, v2
/// v4 = load result_alloc
/// v5 = and v4, v3
/// store v5 in result_alloc
/// v6 = add i, Field 1
/// jmp loop_start(v6)
/// loop_end():
/// result = load result_alloc
fn insert_array_equality(
&mut self,
lhs: ValueId,
operator: noirc_frontend::BinaryOpKind,
rhs: ValueId,
) -> Values {
let lhs_type = self.builder.type_of_value(lhs);
let rhs_type = self.builder.type_of_value(rhs);

let (array_length, element_type) = match (lhs_type, rhs_type) {
(
Type::Array(lhs_composite_type, lhs_length),
Type::Array(rhs_composite_type, rhs_length),
) => {
assert!(
lhs_composite_type.len() == 1 && rhs_composite_type.len() == 1,
"== is unimplemented for arrays of structs"
);
assert_eq!(lhs_composite_type[0], rhs_composite_type[0]);
assert_eq!(lhs_length, rhs_length, "Expected two arrays of equal length");
(lhs_length, lhs_composite_type[0].clone())
}
_ => unreachable!("Expected two array values"),
};

let loop_start = self.builder.insert_block();
let loop_body = self.builder.insert_block();
let loop_end = self.builder.insert_block();

// pre-loop
let result_alloc = self.builder.insert_allocate();
let true_value = self.builder.numeric_constant(1u128, Type::bool());
self.builder.insert_store(result_alloc, true_value);
let zero = self.builder.field_constant(0u128);
self.builder.terminate_with_jmp(loop_start, vec![zero]);

// loop_start
self.builder.switch_to_block(loop_start);
let i = self.builder.add_block_parameter(loop_start, Type::field());
let array_length = self.builder.field_constant(array_length as u128);
let v0 = self.builder.insert_binary(i, BinaryOp::Lt, array_length);
self.builder.terminate_with_jmpif(v0, loop_body, loop_end);

// loop body
self.builder.switch_to_block(loop_body);
let v1 = self.builder.insert_array_get(lhs, i, element_type.clone());
let v2 = self.builder.insert_array_get(rhs, i, element_type);
let v3 = self.builder.insert_binary(v1, BinaryOp::Eq, v2);
let v4 = self.builder.insert_load(result_alloc, Type::bool());
let v5 = self.builder.insert_binary(v4, BinaryOp::And, v3);
self.builder.insert_store(result_alloc, v5);
let one = self.builder.field_constant(1u128);
let v6 = self.builder.insert_binary(i, BinaryOp::Add, one);
self.builder.terminate_with_jmp(loop_start, vec![v6]);

// loop end
self.builder.switch_to_block(loop_end);
let mut result = self.builder.insert_load(result_alloc, Type::bool());

if operator_requires_not(operator) {
result = self.builder.insert_not(result);
}
result.into()
}

/// Inserts a call instruction at the end of the current block and returns the results
/// of the call.
///
Expand Down