Skip to content

Commit

Permalink
Update MeasureNode and PrepareNodeoperation identifiers (#2224)
Browse files Browse the repository at this point in the history
* update node ids, update contract tensors test, add unit test

* Add test

* fix copy location in contract tensor test

Co-authored-by: trbromley <brotho02@gmail.com>
  • Loading branch information
anthayes92 and trbromley authored Feb 24, 2022
1 parent 1fdbcb5 commit 60bdfd1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
24 changes: 17 additions & 7 deletions pennylane/transforms/qcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

import copy
import string
import uuid
from itertools import product
from typing import List, Sequence, Tuple

from networkx import MultiDiGraph, weakly_connected_components

import pennylane as qml
from networkx import MultiDiGraph, weakly_connected_components
from pennylane import apply, expval
from pennylane.grouping import string_to_pauli_word
from pennylane.measure import MeasurementProcess
Expand All @@ -39,13 +39,23 @@ class MeasureNode(Operation):
num_wires = 1
grad_method = None

def __init__(self, *params, wires=None, do_queue=True, id=None):
id = str(uuid.uuid4())

super().__init__(*params, wires=wires, do_queue=do_queue, id=id)


class PrepareNode(Operation):
"""Placeholder node for state preparations"""

num_wires = 1
grad_method = None

def __init__(self, *params, wires=None, do_queue=True, id=None):
id = str(uuid.uuid4())

super().__init__(*params, wires=wires, do_queue=do_queue, id=id)


def replace_wire_cut_node(node: WireCut, graph: MultiDiGraph):
"""
Expand Down Expand Up @@ -272,7 +282,7 @@ def fragment_graph(graph: MultiDiGraph) -> Tuple[Tuple[MultiDiGraph], MultiDiGra
for node1, node2, wire in graph.edges:
if isinstance(node1, MeasureNode):
assert isinstance(node2, PrepareNode)
cut_edges.append((node1, node2, wire))
cut_edges.append((node1, node2))
graph_copy.remove_edge(node1, node2, key=wire)

subgraph_nodes = weakly_connected_components(graph_copy)
Expand All @@ -281,14 +291,14 @@ def fragment_graph(graph: MultiDiGraph) -> Tuple[Tuple[MultiDiGraph], MultiDiGra
communication_graph = MultiDiGraph()
communication_graph.add_nodes_from(range(len(subgraphs)))

for node1, node2, wire in cut_edges:
for node1, node2 in cut_edges:
for i, subgraph in enumerate(subgraphs):
if subgraph.has_node(node1):
start_fragment = i
if subgraph.has_node(node2):
end_fragment = i

communication_graph.add_edge(start_fragment, end_fragment, pair=(node1, node2, wire))
communication_graph.add_edge(start_fragment, end_fragment, pair=(node1, node2))

return subgraphs, communication_graph

Expand Down Expand Up @@ -658,7 +668,7 @@ def contract_tensors(
for pred_edge in pred_edges.values():
meas_op, prep_op = pred_edge["pair"]

if p is prep_op:
if p.id is prep_op.id:
symb = get_symbol(ctr)
ctr += 1
tensor_indxs[i] += symb
Expand All @@ -672,7 +682,7 @@ def contract_tensors(
for succ_edge in succ_edges.values():
meas_op, _ = succ_edge["pair"]

if m is meas_op:
if m.id is meas_op.id:
symb = meas_map[meas_op]
tensor_indxs[i] += symb

Expand Down
22 changes: 18 additions & 4 deletions tests/transforms/test_qcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import sys
from itertools import product

import pennylane as qml
import pytest
from networkx import MultiDiGraph
from scipy.stats import unitary_group

import pennylane as qml
from pennylane import numpy as np
from pennylane.transforms import qcut
from pennylane.wires import Wires
from scipy.stats import unitary_group

I, X, Y, Z = (
np.eye(2),
Expand Down Expand Up @@ -162,6 +161,20 @@ def compare_measurements(meas1, meas2):
assert obs1.wires.tolist() == obs2.wires.tolist()


def test_node_ids(monkeypatch):
"""
Tests that the `MeasureNode` and `PrepareNode` return the correct id
"""
with monkeypatch.context() as m:
m.setattr("uuid.uuid4", lambda: "some_string")

mn = qcut.MeasureNode(wires=0)
pn = qcut.PrepareNode(wires=0)

assert mn.id == "some_string"
assert pn.id == "some_string"


class TestTapeToGraph:
"""
Tests conversion of tapes to graph representations that are amenable to
Expand Down Expand Up @@ -1297,9 +1310,10 @@ class TestContractTensors:
"""Tests for the contract_tensors function"""

t = [np.arange(4), np.arange(4, 8)]
# make copies of nodes to ensure id comparisons work correctly
m = [[qcut.MeasureNode(wires=0)], []]
p = [[], [qcut.PrepareNode(wires=0)]]
edge_dict = {"pair": (m[0][0], p[1][0])}
edge_dict = {"pair": (copy.copy(m)[0][0], copy.copy(p)[1][0])}
g = MultiDiGraph([(0, 1, edge_dict)])
expected_result = np.dot(*t)

Expand Down

0 comments on commit 60bdfd1

Please sign in to comment.