Skip to content

Commit

Permalink
fix: Try to move constant terms to one side for arithmetic generics (#…
Browse files Browse the repository at this point in the history
…6008)

# Description

## Problem\*

Resolves #6006

## Summary\*

Previously we were failing for constraints like `a + 3 = b + 1` when we
could instead move the constant terms to one side: `a + 2 = b` then
solve with `b := a + 2`.

## Additional Context

## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Sep 12, 2024
1 parent 21425de commit 4d8fe28
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 3 deletions.
14 changes: 12 additions & 2 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1630,8 +1630,18 @@ impl Type {

(InfixExpr(lhs_a, op_a, rhs_a), InfixExpr(lhs_b, op_b, rhs_b)) => {
if op_a == op_b {
lhs_a.try_unify(lhs_b, bindings)?;
rhs_a.try_unify(rhs_b, bindings)
// We need to preserve the original bindings since if syntactic equality
// fails we fall back to other equality strategies.
let mut new_bindings = bindings.clone();
let lhs_result = lhs_a.try_unify(lhs_b, &mut new_bindings);
let rhs_result = rhs_a.try_unify(rhs_b, &mut new_bindings);

if lhs_result.is_ok() && rhs_result.is_ok() {
*bindings = new_bindings;
Ok(())
} else {
lhs.try_unify_by_moving_constant_terms(&rhs, bindings)
}
} else {
Err(UnificationError)
}
Expand Down
42 changes: 41 additions & 1 deletion compiler/noirc_frontend/src/hir_def/types/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::BTreeSet;

use crate::{BinaryTypeOperator, Type};
use crate::{BinaryTypeOperator, Type, TypeBindings, UnificationError};

impl Type {
/// Try to canonicalize the representation of this type.
Expand Down Expand Up @@ -212,4 +212,44 @@ impl Type {
_ => None,
}
}

/// Try to unify equations like `(..) + 3 = (..) + 1`
/// by transforming them to `(..) + 2 = (..)`
pub(super) fn try_unify_by_moving_constant_terms(
&self,
other: &Type,
bindings: &mut TypeBindings,
) -> Result<(), UnificationError> {
if let Type::InfixExpr(lhs_a, op_a, rhs_a) = self {
if let Some(inverse) = op_a.inverse() {
if let Some(rhs_a) = rhs_a.evaluate_to_u32() {
let rhs_a = Box::new(Type::Constant(rhs_a));
let new_other = Type::InfixExpr(Box::new(other.clone()), inverse, rhs_a);

let mut tmp_bindings = bindings.clone();
if lhs_a.try_unify(&new_other, &mut tmp_bindings).is_ok() {
*bindings = tmp_bindings;
return Ok(());
}
}
}
}

if let Type::InfixExpr(lhs_b, op_b, rhs_b) = other {
if let Some(inverse) = op_b.inverse() {
if let Some(rhs_b) = rhs_b.evaluate_to_u32() {
let rhs_b = Box::new(Type::Constant(rhs_b));
let new_self = Type::InfixExpr(Box::new(self.clone()), inverse, rhs_b);

let mut tmp_bindings = bindings.clone();
if new_self.try_unify(lhs_b, &mut tmp_bindings).is_ok() {
*bindings = tmp_bindings;
return Ok(());
}
}
}
}

Err(UnificationError)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "arithmetic_generics_move_constant_terms"
type = "bin"
authors = [""]
compiler_version = ">=0.33.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
trait FromCallData<let N: u32, let M: u32> {
fn from_calldata(calldata: [Field; N]) -> (Self, [Field; M]);
}

struct Point { x: Field, y: Field }

impl <let N: u32> FromCallData<N, N - 1> for Field {
fn from_calldata(calldata: [Field; N]) -> (Self, [Field; (N - 1)]) {
let slice = calldata.as_slice();
let (value, slice) = slice.pop_front();
(value, slice.as_array())
}
}

impl <let N: u32> FromCallData<N, N - 2> for Point {
fn from_calldata(calldata: [Field; N]) -> (Self, [Field; (N - 2)]) {
let (x, calldata) = FromCallData::from_calldata(calldata);
let (y, calldata) = FromCallData::from_calldata(calldata);
(Self { x, y }, calldata)
}
}

fn main() {
let calldata = [1, 2];
let _: (Point, _) = FromCallData::from_calldata(calldata);
}

0 comments on commit 4d8fe28

Please sign in to comment.