Skip to content

Commit

Permalink
feat: Add Eq op to logic extension (#1398)
Browse files Browse the repository at this point in the history
Co-authored-by: Alan Lawrence <alan.lawrence@quantinuum.com>
  • Loading branch information
croyzor and acl-cqc authored Aug 8, 2024
1 parent c134584 commit cd0c906
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
17 changes: 15 additions & 2 deletions hugr-core/src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ impl ConstFold for NaryLogic {
(res || inps.len() as u64 == num_args)
.then_some(vec![(0.into(), ops::Value::from_bool(res))])
}
Self::Eq => {
let inps = read_inputs(consts)?;
let res = inps.iter().copied().reduce(|a, b| a == b)?;
// If we have only some inputs, we can still fold to false, but not to true
(!res || inps.len() as u64 == num_args)
.then_some(vec![(0.into(), ops::Value::from_bool(res))])
}
}
}
}
Expand All @@ -57,6 +64,7 @@ impl ConstFold for NaryLogic {
pub enum NaryLogic {
And,
Or,
Eq,
}

impl MakeOpDef for NaryLogic {
Expand All @@ -68,6 +76,7 @@ impl MakeOpDef for NaryLogic {
match self {
NaryLogic::And => "logical 'and'",
NaryLogic::Or => "logical 'or'",
NaryLogic::Eq => "test if bools are equal",
}
.to_string()
}
Expand Down Expand Up @@ -275,7 +284,7 @@ pub(crate) mod test {
fn test_logic_extension() {
let r: Extension = extension();
assert_eq!(r.name() as &str, "logic");
assert_eq!(r.operations().count(), 3);
assert_eq!(r.operations().count(), 4);

for op in NaryLogic::iter() {
assert_eq!(
Expand All @@ -287,7 +296,7 @@ pub(crate) mod test {

#[test]
fn test_conversions() {
for def in [NaryLogic::And, NaryLogic::Or] {
for def in [NaryLogic::And, NaryLogic::Or, NaryLogic::Eq] {
let o = def.with_n_inputs(3);
let ext_op = o.clone().to_extension_op().unwrap();
let custom_op: CustomOp = ext_op.into();
Expand Down Expand Up @@ -331,6 +340,8 @@ pub(crate) mod test {
#[case(NaryLogic::Or, [], false)]
#[case(NaryLogic::Or, [false, false, true], true)]
#[case(NaryLogic::Or, [false, false, false], false)]
#[case(NaryLogic::Eq, [true, true, false, true], false)]
#[case(NaryLogic::Eq, [false, false], true)]
fn nary_const_fold(
#[case] op: NaryLogic,
#[case] ins: impl IntoIterator<Item = bool>,
Expand All @@ -355,6 +366,8 @@ pub(crate) mod test {
#[case(NaryLogic::And, [Some(false), None], Some(false))]
#[case(NaryLogic::Or, [None, Some(false)], None)]
#[case(NaryLogic::Or, [None, Some(true)], Some(true))]
#[case(NaryLogic::Eq, [None, Some(true), Some(true)], None)]
#[case(NaryLogic::Eq, [None, Some(false), Some(true)], Some(false))]
fn nary_partial_const_fold(
#[case] op: NaryLogic,
#[case] ins: impl IntoIterator<Item = Option<bool>>,
Expand Down
7 changes: 7 additions & 0 deletions specification/std_extensions/logic.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@
"signature": null,
"binary": true
},
"Eq": {
"extension": "logic",
"name": "Eq",
"description": "test if bools are equal",
"signature": null,
"binary": true
},
"Not": {
"extension": "logic",
"name": "Not",
Expand Down

0 comments on commit cd0c906

Please sign in to comment.