Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test!: Improve coverage in signature and validate #643

Merged
merged 11 commits into from
Nov 8, 2023
17 changes: 12 additions & 5 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{NodeMetadata, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};
use crate::ops::{self, LeafOp, OpTag, OpTrait, OpType};
use crate::{IncomingPort, Node, OutgoingPort};

use std::iter;
Expand Down Expand Up @@ -666,6 +666,7 @@ fn wire_up<T: Dataflow + ?Sized>(
let base = data_builder.hugr_mut();

let src_parent = base.get_parent(src);
let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
let dst_parent = base.get_parent(dst);
let local_source = src_parent == dst_parent;
if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
Expand All @@ -687,7 +688,10 @@ fn wire_up<T: Dataflow + ?Sized>(
let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
.tuple_windows()
.find_map(|(ancestor, ancestor_parent)| {
(ancestor_parent == src_parent).then_some(ancestor)
(ancestor_parent == src_parent ||
// Dom edge - in CFGs
Some(ancestor_parent) == src_parent_parent)
.then_some(ancestor)
})
else {
let val_err: ValidationError = InterGraphEdgeError::NoRelation {
Expand All @@ -700,9 +704,12 @@ fn wire_up<T: Dataflow + ?Sized>(
return Err(val_err.into());
};

// TODO: Avoid adding duplicate edges
// This should be easy with https://github.com/CQCL-DEV/hugr/issues/130
base.add_other_edge(src, src_sibling)?;
if !OpTag::BasicBlock.is_superset(base.get_optype(src).tag())
&& !OpTag::BasicBlock.is_superset(base.get_optype(src_sibling).tag())
{
// Add a state order constraint unless one of the nodes is a CFG BasicBlock
base.add_other_edge(src, src_sibling)?;
}
} else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
// Don't copy linear edges.
return Err(BuildError::NoCopyLinear(typ));
Expand Down
65 changes: 65 additions & 0 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ mod test {
use crate::builder::build_traits::HugrBuilder;
use crate::builder::{DataflowSubContainer, ModuleBuilder};

use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::ValidationError;
use crate::{builder::test::NAT, type_row};
use cool_asserts::assert_matches;

Expand Down Expand Up @@ -393,4 +395,67 @@ mod test {
cfg_builder.branch(&entry, 1, &exit)?;
Ok(())
}
#[test]
fn test_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_variants = vec![type_row![]];

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?;
let [inw] = entry_b.input_wires_arr();
let entry = {
let sum = entry_b.load_const(&sum_tuple_const)?;

entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b =
cfg_builder.simple_block_builder(FunctionType::new(type_row![], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.load_const(&sum_tuple_const)?;
middle_b.finish_with_outputs(c, [inw])?
};
let exit = cfg_builder.exit_block();
cfg_builder.branch(&entry, 0, &middle)?;
cfg_builder.branch(&middle, 0, &exit)?;
assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_));

Ok(())
}

#[test]
fn test_non_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_variants = vec![type_row![]];
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let [inw] = middle_b.input_wires_arr();
let middle = {
let c = middle_b.load_const(&sum_tuple_const)?;
middle_b.finish_with_outputs(c, [inw])?
};

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
let entry = {
let sum = entry_b.load_const(&sum_tuple_const)?;
// entry block uses wire from middle block even though middle block
// does not dominate entry
entry_b.finish_with_outputs(sum, [inw])?
};
let exit = cfg_builder.exit_block();
cfg_builder.branch(&entry, 0, &middle)?;
cfg_builder.branch(&middle, 0, &exit)?;
assert_matches!(
cfg_builder.finish_prelude_hugr(),
Err(ValidationError::InterGraphEdgeError(
InterGraphEdgeError::NonDominatedAncestor { .. }
))
);

Ok(())
}
}
52 changes: 52 additions & 0 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ pub(crate) mod test {

use crate::std_extensions::logic::test::and_op;
use crate::std_extensions::quantum::test::h_gate;
use crate::types::Type;
use crate::{
builder::{
test::{n_identity, BIT, NAT, QB},
Expand Down Expand Up @@ -500,4 +501,55 @@ pub(crate) mod test {

Ok(())
}

#[test]
fn non_cfg_ancestor() -> Result<(), BuildError> {
let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]);
let mut b = DFGBuilder::new(unit_sig.clone())?;
let b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?;
let b_child_in_wire = b_child.input().out_wire(0);
b_child.finish_with_outputs([])?;
let b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?;

// DFG block has edge coming a sibling block, which is only valid for
// CFGs
let b_child_2_handle = b_child_2.finish_with_outputs([b_child_in_wire])?;

let res = b.finish_prelude_hugr_with_outputs([b_child_2_handle.out_wire(0)]);

assert_matches!(
res,
Err(BuildError::InvalidHUGR(
ValidationError::InterGraphEdgeError(InterGraphEdgeError::NonCFGAncestor { .. })
))
);
Ok(())
}

#[test]
fn no_relation_edge() -> Result<(), BuildError> {
let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]);
let mut b = DFGBuilder::new(unit_sig.clone())?;
let mut b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?;
let b_child_child =
b_child.dfg_builder(unit_sig.clone(), None, [b_child.input().out_wire(0)])?;
let b_child_child_in_wire = b_child_child.input().out_wire(0);

b_child_child.finish_with_outputs([])?;
b_child.finish_with_outputs([])?;

let mut b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?;
let b_child_2_child =
b_child_2.dfg_builder(unit_sig.clone(), None, [b_child_2.input().out_wire(0)])?;

let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]);

assert_matches!(
res.map(|h| h.handle().node()), // map to something that implements Debug
Err(BuildError::InvalidHUGR(
ValidationError::InterGraphEdgeError(InterGraphEdgeError::NoRelation { .. })
))
);
Ok(())
}
}
Loading