Skip to content

Commit

Permalink
chore: Bump HUGR dependency (#94)
Browse files Browse the repository at this point in the history
Multiple small changes to the hugr API.

Note that this doesn't update to the current HEAD, #91 will adapt to the
changes from CQCL/hugr#498 and simplify the
circuits.
The goal of this separate PR is to reduce the diff noise there.
  • Loading branch information
aborgna-q authored Sep 8, 2023
1 parent f4ed814 commit a62ce43
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ members = ["pyrs", "compile-matcher"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "abfaba6" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "5a97a635" }
portgraph = { version = "0.8", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
3 changes: 1 addition & 2 deletions src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::iter::FusedIterator;
use hugr::hugr::views::HierarchyView;
use hugr::ops::{OpTag, OpTrait};
use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers};
use portgraph::PortOffset;

use super::Circuit;

Expand Down Expand Up @@ -107,7 +106,7 @@ where
optype
.static_input()
// TODO query optype for this port once it is available in hugr.
.map(|_| PortOffset::new_incoming(sig.input.len()).into()),
.map(|_| Port::new_incoming(sig.input.len())),
)
.filter_map(|port| {
let (from, from_port) = self.circ.linked_ports(node, port).next()?;
Expand Down
10 changes: 6 additions & 4 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ pub(crate) fn wrap_json_op(op: &JsonOp) -> ExternalOp {
// .into()
let sig = op.signature();
let op = serde_yaml::to_value(op).unwrap();
let payload = TypeArg::Opaque(CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap());
let payload = TypeArg::Opaque {
arg: CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap(),
};
OpaqueOp::new(
TKET1_EXTENSION_ID,
JSON_OP_NAME,
Expand All @@ -100,17 +102,17 @@ pub(crate) fn try_unwrap_json_op(ext: &ExternalOp) -> Option<JsonOp> {
if ext.name() != format!("{TKET1_EXTENSION_ID}.{JSON_OP_NAME}") {
return None;
}
let Some(TypeArg::Opaque(op)) = ext.args().get(0) else {
let Some(TypeArg::Opaque { arg }) = ext.args().get(0) else {
// TODO: Throw an error? We should never get here if the name matches.
return None;
};
let op = serde_yaml::from_value(op.value.clone()).ok()?;
let op = serde_yaml::from_value(arg.value.clone()).ok()?;
Some(op)
}

/// Compute the signature of a json-encoded TKET1 operation.
fn json_op_signature(args: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [TypeArg::Opaque(arg)] = args else {
let [TypeArg::Opaque { arg }] = args else {
// This should have already been checked.
panic!("Wrong number of arguments");
};
Expand Down
6 changes: 4 additions & 2 deletions src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,10 @@ impl JsonEncoder {
OpType::Const(const_op) => {
// New constant, register it if it can be interpreted as a parameter.
match const_op.value() {
Value::Prim(PrimValue::Extension((v,))) => {
if let Some(f) = v.downcast_ref::<ConstF64>() {
Value::Prim {
val: PrimValue::Extension { c: (val,) },
} => {
if let Some(f) = val.downcast_ref::<ConstF64>() {
f.to_string()
} else {
return false;
Expand Down
10 changes: 5 additions & 5 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ pub fn symbolic_constant_op(s: &str) -> OpType {
let l: LeafOp = EXTENSION
.instantiate_extension_op(
&SYM_OP_ID,
vec![TypeArg::Opaque(
CustomTypeArg::new(SYM_EXPR_T, value).unwrap(),
)],
vec![TypeArg::Opaque {
arg: CustomTypeArg::new(SYM_EXPR_T, value).unwrap(),
}],
)
.unwrap()
.into();
Expand All @@ -202,11 +202,11 @@ pub(crate) fn match_symb_const_op(op: &OpType) -> Option<&str> {
{
// TODO also check extension name

let Some(TypeArg::Opaque(s)) = e.args().get(0) else {
let Some(TypeArg::Opaque { arg }) = e.args().get(0) else {
panic!("should be an opaque type arg.")
};

let serde_yaml::Value::String(s) = &s.value else {
let serde_yaml::Value::String(s) = &arg.value else {
panic!("unexpected yaml value.")
};

Expand Down
22 changes: 8 additions & 14 deletions src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,8 @@ use std::{
};

use super::{CircuitPattern, PEdge, PNode};
use hugr::{
hugr::views::{
sibling::{
ConvexChecker, InvalidReplacement,
InvalidSubgraph::{self},
},
SiblingSubgraph,
},
ops::OpType,
Hugr, Node, Port,
};
use hugr::hugr::views::sibling_subgraph::{ConvexChecker, InvalidReplacement, InvalidSubgraph};
use hugr::{hugr::views::SiblingSubgraph, ops::OpType, Hugr, Node, Port};
use itertools::Itertools;
use portmatching::{
automaton::{LineBuilder, ScopeAutomaton},
Expand Down Expand Up @@ -82,7 +73,7 @@ pub struct PatternMatch<'a, C> {
pub(super) root: Node,
}

impl<'a, C: Circuit<'a>> PatternMatch<'a, C> {
impl<'a, C: Circuit<'a> + Clone> PatternMatch<'a, C> {
/// The matcher's pattern ID of the match.
pub fn pattern_id(&self) -> PatternID {
self.pattern
Expand Down Expand Up @@ -246,7 +237,10 @@ impl PatternMatcher {
}

/// Find all convex pattern matches in a circuit.
pub fn find_matches<'a, C: Circuit<'a>>(&self, circuit: &'a C) -> Vec<PatternMatch<'a, C>> {
pub fn find_matches<'a, C: Circuit<'a> + Clone>(
&self,
circuit: &'a C,
) -> Vec<PatternMatch<'a, C>> {
let mut checker = ConvexChecker::new(circuit);
circuit
.commands()
Expand All @@ -255,7 +249,7 @@ impl PatternMatcher {
}

/// Find all convex pattern matches in a circuit rooted at a given node.
fn find_rooted_matches<'a, C: Circuit<'a>>(
fn find_rooted_matches<'a, C: Circuit<'a> + Clone>(
&self,
circ: &'a C,
root: Node,
Expand Down
2 changes: 1 addition & 1 deletion src/portmatching/pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl PyPatternMatch {
///
/// Requires references to the circuit and pattern to resolve indices
/// into these objects.
pub fn try_from_rust<'circ, C: Circuit<'circ>>(
pub fn try_from_rust<'circ, C: Circuit<'circ> + Clone>(
m: PatternMatch<'circ, C>,
circ: &C,
matcher: &PatternMatcher,
Expand Down
9 changes: 3 additions & 6 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@ pub use ecc_rewriter::ECCRewriter;

use delegate::delegate;
use derive_more::{From, Into};
use hugr::hugr::views::sibling_subgraph::InvalidReplacement;
use hugr::{
hugr::{
hugrmut::HugrMut,
views::{sibling::InvalidReplacement, SiblingSubgraph},
Rewrite, SimpleReplacementError,
},
hugr::{hugrmut::HugrMut, views::SiblingSubgraph, Rewrite, SimpleReplacementError},
Hugr, HugrView, SimpleReplacement,
};

Expand Down Expand Up @@ -55,5 +52,5 @@ impl CircuitRewrite {
/// Generate rewrite rules for circuits.
pub trait Rewriter {
/// Get the rewrite rules for a circuit.
fn get_rewrites<'a, C: Circuit<'a>>(&'a self, circ: &'a C) -> Vec<CircuitRewrite>;
fn get_rewrites<'a, C: Circuit<'a> + Clone>(&'a self, circ: &'a C) -> Vec<CircuitRewrite>;
}
2 changes: 1 addition & 1 deletion src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl ECCRewriter {
}

impl Rewriter for ECCRewriter {
fn get_rewrites<'a, C: Circuit<'a>>(&'a self, circ: &'a C) -> Vec<CircuitRewrite> {
fn get_rewrites<'a, C: Circuit<'a> + Clone>(&'a self, circ: &'a C) -> Vec<CircuitRewrite> {
let matches = self.matcher.find_matches(circ);
matches
.into_iter()
Expand Down

0 comments on commit a62ce43

Please sign in to comment.