diff --git a/src/repr/logical_expr.rs b/src/repr/logical_expr.rs index 941f6ca5..57a955e9 100644 --- a/src/repr/logical_expr.rs +++ b/src/repr/logical_expr.rs @@ -122,6 +122,19 @@ impl LogicalExpr { Box::new(helper(l.as_ref(), mapping)), Box::new(helper(r.as_ref(), mapping)), ), + LogicalSExpr::Iff(l, r) => LogicalExpr::Iff( + Box::new(helper(l.as_ref(), mapping)), + Box::new(helper(r.as_ref(), mapping)), + ), + LogicalSExpr::Xor(l, r) => LogicalExpr::Xor( + Box::new(helper(l.as_ref(), mapping)), + Box::new(helper(r.as_ref(), mapping)), + ), + LogicalSExpr::Ite(guard, thn, els) => LogicalExpr::Ite { + guard: Box::new(helper(guard.as_ref(), mapping)), + thn: Box::new(helper(thn.as_ref(), mapping)), + els: Box::new(helper(els.as_ref(), mapping)), + }, } } @@ -253,7 +266,7 @@ impl LogicalExpr { } #[test] -fn from_sexpr_e2e() { +fn from_sexpr_e2e_primitive() { // this string represents X XOR Y. it exercises each branch of the match statement // within the from_sexpr helper let x_xor_y = String::from("(And (Or (Var X) (Var Y)) (Or (Not (Var X)) (Not (Var Y))))"); @@ -272,3 +285,25 @@ fn from_sexpr_e2e() { assert_eq!(LogicalExpr::from_sexpr(&expr), manually_constructed) } + +#[test] +fn from_sexpr_e2e_complex() { + // this string uses the "complex" non-primitive s-expr items: IFF, XOR, ITE + let x_xor_y = + String::from("(Xor (Iff (Var X) (Var Y)) (Ite (Var Z) (Not (Var X)) (Not (Var Y))))"); + let expr = serde_sexpr::from_str::(&x_xor_y).unwrap(); + + let manually_constructed = LogicalExpr::Xor( + Box::new(LogicalExpr::Iff( + Box::new(LogicalExpr::Literal(0, true)), + Box::new(LogicalExpr::Literal(1, true)), + )), + Box::new(LogicalExpr::Ite { + guard: Box::new(LogicalExpr::Literal(2, true)), + thn: Box::new(LogicalExpr::Literal(0, false)), + els: Box::new(LogicalExpr::Literal(1, false)), + }), + ); + + assert_eq!(LogicalExpr::from_sexpr(&expr), manually_constructed) +} diff --git a/src/serialize/ser_logical_expr.rs b/src/serialize/ser_logical_expr.rs index 78673203..81133cf3 100644 --- a/src/serialize/ser_logical_expr.rs +++ b/src/serialize/ser_logical_expr.rs @@ -8,6 +8,9 @@ pub enum LogicalSExpr { Not(Box), Or(Box, Box), And(Box, Box), + Iff(Box, Box), + Xor(Box, Box), + Ite(Box, Box, Box), } impl LogicalSExpr { @@ -27,11 +30,22 @@ impl LogicalSExpr { LogicalSExpr::True | LogicalSExpr::False => HashSet::new(), LogicalSExpr::Var(s) => HashSet::from([s]), LogicalSExpr::Not(l) => l.unique_variables(), - LogicalSExpr::Or(a, b) | LogicalSExpr::And(a, b) => a + LogicalSExpr::Or(a, b) + | LogicalSExpr::And(a, b) + | LogicalSExpr::Iff(a, b) + | LogicalSExpr::Xor(a, b) => a .unique_variables() .union(&b.unique_variables()) .cloned() .collect::>(), + LogicalSExpr::Ite(a, b, c) => a + .unique_variables() + .union(&b.unique_variables()) + .cloned() + .collect::>() + .union(&c.unique_variables()) + .cloned() + .collect::>(), } } @@ -114,6 +128,31 @@ fn logical_expression_deserialization_boxed() { Box::new(LogicalSExpr::Var(String::from("Y"))) ) ); + + assert_eq!( + serde_sexpr::from_str::("(Iff (Var X) (Var Y))").unwrap(), + LogicalSExpr::Iff( + Box::new(LogicalSExpr::Var(String::from("X"))), + Box::new(LogicalSExpr::Var(String::from("Y"))) + ) + ); + + assert_eq!( + serde_sexpr::from_str::("(Xor (Var X) (Var Y))").unwrap(), + LogicalSExpr::Xor( + Box::new(LogicalSExpr::Var(String::from("X"))), + Box::new(LogicalSExpr::Var(String::from("Y"))) + ) + ); + + assert_eq!( + serde_sexpr::from_str::("(Ite (Var X) (Var Y) (Var Z))").unwrap(), + LogicalSExpr::Ite( + Box::new(LogicalSExpr::Var(String::from("X"))), + Box::new(LogicalSExpr::Var(String::from("Y"))), + Box::new(LogicalSExpr::Var(String::from("Z"))) + ) + ); } #[test] @@ -128,7 +167,7 @@ fn logical_expression_unique_variables_trivial() { #[test] fn logical_expression_unique_variables_handles_duplicates_and_nesting() { let expr = - serde_sexpr::from_str::("(Or (Var X) (Or (Not (Var X)) (Var Y)))").unwrap(); + serde_sexpr::from_str::("(Or (Var X) (Or (Not (Var X)) (Xor (Iff (Var X) (Var Y)) (Ite (Var Y) (Not (Var X)) (Not (Var Y))))))").unwrap(); let vars = expr.unique_variables(); assert!(vars.len() == 2);