diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 180d773ff8d..bcbcc8ad789 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1630,8 +1630,18 @@ impl Type { (InfixExpr(lhs_a, op_a, rhs_a), InfixExpr(lhs_b, op_b, rhs_b)) => { if op_a == op_b { - lhs_a.try_unify(lhs_b, bindings)?; - rhs_a.try_unify(rhs_b, bindings) + // We need to preserve the original bindings since if syntactic equality + // fails we fall back to other equality strategies. + let mut new_bindings = bindings.clone(); + let lhs_result = lhs_a.try_unify(lhs_b, &mut new_bindings); + let rhs_result = rhs_a.try_unify(rhs_b, &mut new_bindings); + + if lhs_result.is_ok() && rhs_result.is_ok() { + *bindings = new_bindings; + Ok(()) + } else { + lhs.try_unify_by_moving_constant_terms(&rhs, bindings) + } } else { Err(UnificationError) } diff --git a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs index ad07185dff1..44a7526c894 100644 --- a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs +++ b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -1,6 +1,6 @@ use std::collections::BTreeSet; -use crate::{BinaryTypeOperator, Type}; +use crate::{BinaryTypeOperator, Type, TypeBindings, UnificationError}; impl Type { /// Try to canonicalize the representation of this type. @@ -212,4 +212,44 @@ impl Type { _ => None, } } + + /// Try to unify equations like `(..) + 3 = (..) + 1` + /// by transforming them to `(..) + 2 = (..)` + pub(super) fn try_unify_by_moving_constant_terms( + &self, + other: &Type, + bindings: &mut TypeBindings, + ) -> Result<(), UnificationError> { + if let Type::InfixExpr(lhs_a, op_a, rhs_a) = self { + if let Some(inverse) = op_a.inverse() { + if let Some(rhs_a) = rhs_a.evaluate_to_u32() { + let rhs_a = Box::new(Type::Constant(rhs_a)); + let new_other = Type::InfixExpr(Box::new(other.clone()), inverse, rhs_a); + + let mut tmp_bindings = bindings.clone(); + if lhs_a.try_unify(&new_other, &mut tmp_bindings).is_ok() { + *bindings = tmp_bindings; + return Ok(()); + } + } + } + } + + if let Type::InfixExpr(lhs_b, op_b, rhs_b) = other { + if let Some(inverse) = op_b.inverse() { + if let Some(rhs_b) = rhs_b.evaluate_to_u32() { + let rhs_b = Box::new(Type::Constant(rhs_b)); + let new_self = Type::InfixExpr(Box::new(self.clone()), inverse, rhs_b); + + let mut tmp_bindings = bindings.clone(); + if new_self.try_unify(lhs_b, &mut tmp_bindings).is_ok() { + *bindings = tmp_bindings; + return Ok(()); + } + } + } + } + + Err(UnificationError) + } } diff --git a/test_programs/compile_success_empty/arithmetic_generics_move_constant_terms/Nargo.toml b/test_programs/compile_success_empty/arithmetic_generics_move_constant_terms/Nargo.toml new file mode 100644 index 00000000000..8d057a77814 --- /dev/null +++ b/test_programs/compile_success_empty/arithmetic_generics_move_constant_terms/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "arithmetic_generics_move_constant_terms" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] diff --git a/test_programs/compile_success_empty/arithmetic_generics_move_constant_terms/src/main.nr b/test_programs/compile_success_empty/arithmetic_generics_move_constant_terms/src/main.nr new file mode 100644 index 00000000000..87800459284 --- /dev/null +++ b/test_programs/compile_success_empty/arithmetic_generics_move_constant_terms/src/main.nr @@ -0,0 +1,26 @@ +trait FromCallData { + fn from_calldata(calldata: [Field; N]) -> (Self, [Field; M]); +} + +struct Point { x: Field, y: Field } + +impl FromCallData for Field { + fn from_calldata(calldata: [Field; N]) -> (Self, [Field; (N - 1)]) { + let slice = calldata.as_slice(); + let (value, slice) = slice.pop_front(); + (value, slice.as_array()) + } +} + +impl FromCallData for Point { + fn from_calldata(calldata: [Field; N]) -> (Self, [Field; (N - 2)]) { + let (x, calldata) = FromCallData::from_calldata(calldata); + let (y, calldata) = FromCallData::from_calldata(calldata); + (Self { x, y }, calldata) + } +} + +fn main() { + let calldata = [1, 2]; + let _: (Point, _) = FromCallData::from_calldata(calldata); +}