Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only update object's queue metadata if already in the queue #2612

Merged
merged 12 commits into from
May 30, 2022
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@
* Sparse Hamiltonians representation has changed from COOrdinate (COO) to Compressed Sparse Row (CSR) format. The CSR representation is more performant for arithmetic operations and matrix vector products. This change decreases the `expval()` calculation time, for `qml.SparseHamiltonian`, specially for large workflows. Also, the CRS format consumes less memory for the `qml.SparseHamiltonian` storage.
[(#2561)](https://github.com/PennyLaneAI/pennylane/pull/2561)

* A new method `safe_update_info` is added to `qml.QueuingContext`. This method is substituted
for `qml.QueuingContext.update_info` in a variety of places.
[(#2612)](https://github.com/PennyLaneAI/pennylane/pull/2612)

* `BasisEmbedding` can accept an int as argument instead of a list of bits (optionally). Example: `qml.BasisEmbedding(4, wires = range(4))` is now equivalent to `qml.BasisEmbedding([0,1,0,0], wires = range(4))` (because 4=0b100).
[(#2601)](https://github.com/PennyLaneAI/pennylane/pull/2601)

Expand Down
7 changes: 1 addition & 6 deletions pennylane/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,7 @@ def expand(self):
def queue(self, context=qml.QueuingContext):
"""Append the measurement process to an annotated queue."""
if self.obs is not None:
try:
context.update_info(self.obs, owner=self)
except qml.queuing.QueuingError:
self.obs.queue(context=context)
context.update_info(self.obs, owner=self)

context.safe_update_info(self.obs, owner=self)
context.append(self, owns=self.obs)
else:
context.append(self)
Expand Down
13 changes: 4 additions & 9 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,11 +1741,7 @@ def queue(self, context=qml.QueuingContext, init=False): # pylint: disable=argu
else:
raise ValueError("Can only perform tensor products between observables.")

try:
context.update_info(o, owner=self)
except qml.queuing.QueuingError:
o.queue(context=context)
context.update_info(o, owner=self)
context.safe_update_info(o, owner=self)

context.append(self, owns=tuple(constituents))
return self
Expand Down Expand Up @@ -1849,16 +1845,15 @@ def __matmul__(self, other):
owning_info = qml.QueuingContext.get_info(self)["owns"] + (other,)

# update the annotated queue information
qml.QueuingContext.update_info(self, owns=owning_info)
qml.QueuingContext.update_info(other, owner=self)
qml.QueuingContext.safe_update_info(self, owns=owning_info)
qml.QueuingContext.safe_update_info(other, owner=self)

return self

def __rmatmul__(self, other):
if isinstance(other, Observable):
self.obs[:0] = [other]
if qml.QueuingContext.recording():
qml.QueuingContext.update_info(other, owner=self)
qml.QueuingContext.safe_update_info(other, owner=self)
return self

raise ValueError("Can only perform tensor products between observables.")
Expand Down
7 changes: 1 addition & 6 deletions pennylane/ops/qubit/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,11 +633,6 @@ def __isub__(self, H):
def queue(self, context=qml.QueuingContext):
"""Queues a qml.Hamiltonian instance"""
for o in self.ops:
try:
context.update_info(o, owner=self)
except QueuingError:
o.queue(context=context)
context.update_info(o, owner=self)

context.safe_update_info(o, owner=self)
context.append(self, owns=tuple(self.ops))
return self
25 changes: 24 additions & 1 deletion pennylane/queuing.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,27 @@ def update_info(cls, obj, **kwargs):
if cls.recording():
cls.active_context()._update_info(obj, **kwargs) # pylint: disable=protected-access

# pylint: disable=protected-access
@classmethod
def safe_update_info(cls, obj, **kwargs):
"""Updates information of an object in the active queue if it is already in the queue.

Args:
obj: the object with metadata to be updated
"""
if cls.recording():
cls.active_context()._safe_update_info(obj, **kwargs)

@abc.abstractmethod
def _safe_update_info(self, obj, **kwargs):
"""Updates information of an object in the queue instance only if the object is in the queue.
If the object is not in the queue, nothing is done and no errors are raised.
"""

@abc.abstractmethod
def _update_info(self, obj, **kwargs):
"""Updates information of an object in the queue instance."""
"""Updates information of an object in the queue instance. Raises a ``QueuingError`` if the object
is not in the queue."""

@classmethod
def get_info(cls, obj):
Expand Down Expand Up @@ -222,6 +240,10 @@ def _append(self, obj, **kwargs):
def _remove(self, obj):
del self._queue[obj]

def _safe_update_info(self, obj, **kwargs):
if obj in self._queue:
self._queue[obj].update(kwargs)

def _update_info(self, obj, **kwargs):
if obj not in self._queue:
raise QueuingError(f"Object {obj} not in the queue.")
Expand All @@ -240,6 +262,7 @@ def _get_info(self, obj):
append = _append
remove = _remove
update_info = _update_info
safe_update_info = _safe_update_info
get_info = _get_info

@property
Expand Down
22 changes: 10 additions & 12 deletions tests/ops/qubit/test_hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,19 +788,12 @@ def test_arithmetic_errors(self):
with pytest.raises(ValueError, match="Cannot subtract"):
H -= A

def test_hamiltonian_queue(self):
"""Tests that Hamiltonian are queued correctly"""

# Outside of tape
def test_hamiltonian_queue_outside(self):
"""Tests that Hamiltonian are queued correctly when components are defined outside the recording context."""

queue = [
qml.Hadamard(wires=1),
qml.PauliX(wires=0),
qml.PauliZ(0),
qml.PauliZ(2),
qml.PauliZ(0) @ qml.PauliZ(2),
qml.PauliX(1),
qml.PauliZ(1),
qml.Hamiltonian(
[1, 3, 1], [qml.PauliX(1), qml.PauliZ(0) @ qml.PauliZ(2), qml.PauliZ(1)]
),
Expand All @@ -813,9 +806,14 @@ def test_hamiltonian_queue(self):
qml.PauliX(wires=0)
qml.expval(H)

assert np.all([q1.compare(q2) for q1, q2 in zip(tape.queue, queue)])
assert len(tape.queue) == 3
assert isinstance(tape.queue[0], qml.Hadamard)
assert isinstance(tape.queue[1], qml.PauliX)
assert isinstance(tape.queue[2], qml.measurements.MeasurementProcess)
assert H.compare(tape.queue[2].obs)

# Inside of tape
def test_hamiltonian_queue_inside(self):
"""Tests that Hamiltonian are queued correctly when components are instantiated inside the recording context."""

queue = [
qml.Hadamard(wires=1),
Expand Down Expand Up @@ -1278,7 +1276,7 @@ def test_grouping_does_not_alter_queue(self):
with qml.tape.QuantumTape() as tape:
H = qml.Hamiltonian(coeffs, obs, grouping_type="qwc")

assert tape.queue == [a, b, c, H]
assert tape.queue == [H]

def test_grouping_method_can_be_set(self):
r"""Tests that the grouping method can be controlled by kwargs.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,21 +656,21 @@ def test_annotating_tensor_return_type(self, op1, op2, stat_func, return_type):
)
def test_queueing_tensor_observable(self, op1, op2, stat_func, return_type):
"""Test that if the constituent components of a tensor operation are not
found in the queue for annotation, that they are queued first and then annotated."""
found in the queue for annotation, they are not queued or annotated."""
A = op1(0)
B = op2(1)

with AnnotatedQueue() as q:
tensor_op = A @ B
stat_func(tensor_op)

assert q.queue[:-1] == [A, B, tensor_op]
assert len(q._queue) == 2

assert q.queue[0] is tensor_op
meas_proc = q.queue[-1]
assert isinstance(meas_proc, MeasurementProcess)
assert meas_proc.return_type == return_type

assert q._get_info(A) == {"owner": tensor_op}
assert q._get_info(B) == {"owner": tensor_op}
assert q._get_info(tensor_op) == {"owns": (A, B), "owner": meas_proc}


Expand Down
8 changes: 2 additions & 6 deletions tests/test_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,13 +861,9 @@ def test_queuing_defined_outside(self):
with qml.tape.QuantumTape() as tape:
T.queue()

assert len(tape.queue) == 3
assert tape.queue[0] is op1
assert tape.queue[1] is op2
assert tape.queue[2] is T
assert len(tape.queue) == 1
assert tape.queue[0] is T

assert tape._queue[op1] == {"owner": T}
assert tape._queue[op2] == {"owner": T}
assert tape._queue[T] == {"owns": (op1, op2)}

def test_queuing(self):
Expand Down
44 changes: 42 additions & 2 deletions tests/test_queuing.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,15 @@ def test_update_info(self):
q.append(A, inv=True)
assert QueuingContext.get_info(A) == {"inv": True}

assert q._get_info(A) == {"inv": True}
qml.QueuingContext.update_info(A, key="value1")

# should pass silently because no longer recording
qml.QueuingContext.update_info(A, key="value2")

assert q._get_info(A) == {"inv": True, "key": "value1"}

q._update_info(A, inv=False, owner=None)
assert q._get_info(A) == {"inv": False, "owner": None}
assert q._get_info(A) == {"inv": False, "owner": None, "key": "value1"}

def test_update_error(self):
"""Test that an exception is raised if get_info is called
Expand All @@ -234,6 +239,41 @@ def test_update_error(self):
with pytest.raises(QueuingError, match="not in the queue"):
q._update_info(B, inv=True)

def test_safe_update_info_queued(self):
"""Test the `safe_update_info` method if the object is already queued."""
op = qml.RX(0.5, wires=1)

with AnnotatedQueue() as q:
q.append(op, key="value1")
assert q.get_info(op) == {"key": "value1"}
qml.QueuingContext.safe_update_info(op, key="value2")

qml.QueuingContext.safe_update_info(op, key="no changes here")
assert q.get_info(op) == {"key": "value2"}

q.safe_update_info(op, key="value3")
assert q.get_info(op) == {"key": "value3"}

q._safe_update_info(op, key="value4")
assert q.get_info(op) == {"key": "value4"}

def test_safe_update_info_not_queued(self):
"""Tests the safe_update_info method passes silently if the object is
not already queued."""
op = qml.RX(0.5, wires=1)

with AnnotatedQueue() as q:
qml.QueuingContext.safe_update_info(op, key="value2")
qml.QueuingContext.safe_update_info(op, key="no changes here")

assert len(q.queue) == 0

q.safe_update_info(op, key="value3")
assert len(q.queue) == 0

q._safe_update_info(op, key="value4")
assert len(q.queue) == 0

def test_append_annotating_object(self):
"""Test appending an object that writes annotations when queuing itself"""

Expand Down