diff --git a/doc/_static/transforms_order.png b/doc/_static/transforms_order.png
new file mode 100644
index 00000000000..e7c026865ed
Binary files /dev/null and b/doc/_static/transforms_order.png differ
diff --git a/doc/_static/transforms_order.svg b/doc/_static/transforms_order.svg
new file mode 100644
index 00000000000..09197228d00
--- /dev/null
+++ b/doc/_static/transforms_order.svg
@@ -0,0 +1,251 @@
+
+
+
+
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 68ecb16137c..f14d7b6ff34 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -139,6 +139,10 @@
* Raise a more informative error when calling `adjoint_jacobian` with trainable state-prep operations.
[(#5026)](https://github.com/PennyLaneAI/pennylane/pull/5026)
+* Adds `qml.workflow.get_transform_program` and `qml.workflow.construct_batch` to inspect the transform program and batch of tapes
+ at different stages.
+ [(#5084)](https://github.com/PennyLaneAI/pennylane/pull/5084)
+
* `CRX`, `CRY`, `CRZ`, `CROT`, and `ControlledPhaseShift` (i.e. `CPhaseShift`) now inherit from `ControlledOp`, giving them additional properties such as `control_wire` and `control_values`. Calling `qml.ctrl` on `RX`, `RY`, `RZ`, `Rot`, and `PhaseShift` with a single control wire will return gates of types `CRX`, `CRY`, etc. as opposed to a general `Controlled` operator.
[(#5069)](https://github.com/PennyLaneAI/pennylane/pull/5069)
diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py
index 4d8cd0fe9c8..20a267ad7c3 100644
--- a/pennylane/_qubit_device.py
+++ b/pennylane/_qubit_device.py
@@ -32,7 +32,6 @@
import pennylane as qml
from pennylane import Device, DeviceError
-from pennylane.workflow import set_shots
from pennylane.math import multiply as qmlmul
from pennylane.math import sum as qmlsum
from pennylane.measurements import (
@@ -1036,7 +1035,7 @@ def classical_shadow(self, obs, circuit):
n_snapshots = self.shots
seed = obs.seed
- with set_shots(self, shots=1):
+ with qml.workflow.set_shots(self, shots=1):
# slow implementation but works for all devices
n_qubits = len(wires)
mapped_wires = np.array(self.map_wires(wires))
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py
index 740623fa38f..ade6ec28344 100644
--- a/pennylane/devices/default_qubit.py
+++ b/pennylane/devices/default_qubit.py
@@ -159,6 +159,16 @@ def adjoint_state_measurements(
)
+def adjoint_ops(op: qml.operation.Operator) -> bool:
+ """Specify whether or not an Operator is supported by adjoint differentiation."""
+ return op.num_params == 0 or (op.num_params == 1 and op.has_generator)
+
+
+def adjoint_observables(obs: qml.operation.Operator) -> bool:
+ """Specifies whether or not an observable is compatible with adjoint differentiation on DefaultQubit."""
+ return obs.has_matrix
+
+
def _add_adjoint_transforms(program: TransformProgram, device_vjp=False) -> None:
"""Private helper function for ``preprocess`` that adds the transforms specific
for adjoint differentiation.
@@ -171,14 +181,6 @@ def _add_adjoint_transforms(program: TransformProgram, device_vjp=False) -> None
"""
- def adjoint_ops(op: qml.operation.Operator) -> bool:
- """Specify whether or not an Operator is supported by adjoint differentiation."""
- return op.num_params == 0 or op.num_params == 1 and op.has_generator
-
- def adjoint_observables(obs: qml.operation.Operator) -> bool:
- """Specifies whether or not an observable is compatible with adjoint differentiation on DefaultQubit."""
- return obs.has_matrix
-
name = "adjoint + default.qubit"
program.add_transform(no_sampling, name=name)
program.add_transform(
diff --git a/pennylane/templates/subroutines/permute.py b/pennylane/templates/subroutines/permute.py
index 394fa2f6418..569d977f195 100644
--- a/pennylane/templates/subroutines/permute.py
+++ b/pennylane/templates/subroutines/permute.py
@@ -138,6 +138,9 @@ def circuit():
"""
+ def __repr__(self):
+ return f"Permute({self.hyperparameters['permutation']}, wires={self.wires.tolist()})"
+
num_wires = AnyWires
grad_method = None
diff --git a/pennylane/transforms/core/transform.py b/pennylane/transforms/core/transform.py
index acbd57931f3..f2ec5d8d08c 100644
--- a/pennylane/transforms/core/transform.py
+++ b/pennylane/transforms/core/transform.py
@@ -23,7 +23,7 @@ def transform(
quantum_transform,
expand_transform=None,
classical_cotransform=None,
- is_informative=None,
+ is_informative=False,
final_transform=False,
):
"""Generalizes a function that transforms tapes to work with additional circuit-like objects such as a
@@ -45,14 +45,15 @@ def transform(
* The transform must have the following structure (type hinting is optional): ``my_quantum_transform(tape:
qml.tape.QuantumTape, ...) -> ( Sequence[qml.tape.QuantumTape], Callable)``
- expand_transform (Callable): An optional expand transform is applied directly before the input
+ Keyword Args:
+ expand_transform=None (Optional[Callable]): An optional expand transform is applied directly before the input
quantum transform. It must be a function that satisfies the same requirements as
``quantum_transform``.
- classical_cotransform (Callable): A classical co-transform is a function to post-process the classical
+ classical_cotransform=None (Optional[Callable]): A classical co-transform is a function to post-process the classical
jacobian and the quantum jacobian and has the signature: ``my_cotransform(qjac, cjac, tape) -> tensor_like``
- is_informative (bool): Whether or not a transform is informative. If true the transform is queued at the end
+ is_informative=False (bool): Whether or not a transform is informative. If true the transform is queued at the end
of the transform program and the tapes or qnode aren't executed.
- final_transform (bool): Whether or not the transform is terminal. If true the transform is queued at the end
+ final_transform=False (bool): Whether or not the transform is terminal. If true the transform is queued at the end
of the transform program. ``is_informative`` supersedes ``final_transform``.
Returns:
@@ -177,15 +178,13 @@ def qnode_circuit(a):
)
# 3: CHeck the classical co-transform
- if classical_cotransform is not None:
- if not callable(classical_cotransform):
- raise TransformError("The classical co-transform must be a valid Python function.")
+ if classical_cotransform is not None and not callable(classical_cotransform):
+ raise TransformError("The classical co-transform must be a valid Python function.")
- dispatcher = TransformDispatcher(
+ return TransformDispatcher(
quantum_transform,
expand_transform=expand_transform,
classical_cotransform=classical_cotransform,
is_informative=is_informative,
final_transform=final_transform,
)
- return dispatcher
diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py
index 147fddce7e8..99a80098c5e 100644
--- a/pennylane/transforms/core/transform_dispatcher.py
+++ b/pennylane/transforms/core/transform_dispatcher.py
@@ -343,6 +343,9 @@ def __init__(
self._is_informative = is_informative
self._final_transform = is_informative or final_transform
+ def __repr__(self):
+ return f"<{self._transform.__name__}({self._args}, {self._kwargs})>"
+
def __iter__(self):
return iter(
(
diff --git a/pennylane/transforms/core/transform_program.py b/pennylane/transforms/core/transform_program.py
index ba175e91d53..857cdc3b8ae 100644
--- a/pennylane/transforms/core/transform_program.py
+++ b/pennylane/transforms/core/transform_program.py
@@ -15,7 +15,7 @@
This module contains the ``TransformProgram`` class.
"""
from functools import partial
-from typing import Callable, List, Tuple, Optional, Sequence
+from typing import Callable, List, Tuple, Optional, Sequence, Union
import pennylane as qml
from pennylane.typing import Result, ResultBatch
@@ -117,6 +117,35 @@ class TransformProgram:
.. seealso:: :func:`~.pennylane.transform`
+ **Implemented Dunder methods**
+
+ Programs have several implemented dunder methods for easy manipulation.
+
+ >>> program = TransformProgram()
+ >>> program.add_transform(qml.compile)
+ >>> program.add_transform(qml.transforms.cancel_inverses)
+ >>> [t for t in program] # Iteration
+ [, ]
+ >>> program[0]
+
+ >>> program[::-1]
+ TransformProgram(cancel_inverses, compile)
+ >>> len(program)
+ 2
+ >>> True if program else False
+ True
+ >>> True if TransformProgram() else False
+ False
+ >>> program2 = copy.copy(program)
+ >>> program2 == program
+ True
+ >>> qml.compile in program
+ True
+ >>> qml.transforms.hamiltonian_expand in program
+ False
+ >>> program + program
+ TransformProgram(compile, cancel_inverses, compile, cancel_inverses)
+
"""
def __init__(self, initial_program: Optional[Sequence] = None):
@@ -132,9 +161,11 @@ def __len__(self):
"""int: Return the number transforms in the program."""
return len(self._transform_program)
- def __getitem__(self, idx):
+ def __getitem__(self, idx) -> Union["TransformProgram", "TransformContainer"]:
"""(TransformContainer, List[TransformContainer]): Return the indexed transform container from underlying
transform program"""
+ if isinstance(idx, slice):
+ return TransformProgram(self._transform_program[idx])
return self._transform_program[idx]
def __bool__(self):
@@ -155,12 +186,19 @@ def __repr__(self):
contents = ", ".join(f"{transform_c.transform.__name__}" for transform_c in self)
return f"TransformProgram({contents})"
- def __eq__(self, other):
+ def __eq__(self, other) -> bool:
if not isinstance(other, TransformProgram):
return False
return self._transform_program == other._transform_program
+ def __contains__(self, obj):
+ if isinstance(obj, TransformContainer):
+ return obj in self._transform_program
+ if isinstance(obj, TransformDispatcher):
+ return any(obj.transform == t.transform for t in self)
+ return False
+
def push_back(self, transform_container: TransformContainer):
"""Add a transform (container) to the end of the program.
@@ -172,7 +210,10 @@ def push_back(self, transform_container: TransformContainer):
# Program can only contain one informative transform and at the end of the program
if self.has_final_transform:
- raise TransformError("The transform program already has a terminal transform.")
+ if transform_container.final_transform:
+ raise TransformError("The transform program already has a terminal transform.")
+ self._transform_program.insert(-1, transform_container)
+ return
self._transform_program.append(transform_container)
def insert_front(self, transform_container: TransformContainer):
@@ -290,7 +331,7 @@ def is_informative(self) -> bool:
@property
def has_final_transform(self) -> bool:
"""``True`` if the transform program has a terminal transform."""
- return self[-1].final_transform if self else False
+ return self[-1].final_transform if self else False # pylint: disable=no-member
def has_classical_cotransform(self) -> bool:
"""Check if the transform program has some classical cotransforms.
diff --git a/pennylane/workflow/__init__.py b/pennylane/workflow/__init__.py
index 231206b6182..b859a2e5ba0 100644
--- a/pennylane/workflow/__init__.py
+++ b/pennylane/workflow/__init__.py
@@ -25,6 +25,8 @@
~execute
~workflow.cache_execute
~workflow.set_shots
+ ~workflow.construct_batch
+ ~workflow.get_transform_program
Supported interfaces
~~~~~~~~~~~~~~~~~~~~
@@ -55,3 +57,4 @@
from .set_shots import set_shots
from .execution import execute, cache_execute, SUPPORTED_INTERFACES, INTERFACE_MAP
from .qnode import QNode, qnode
+from .construct_batch import construct_batch, get_transform_program
diff --git a/pennylane/workflow/construct_batch.py b/pennylane/workflow/construct_batch.py
new file mode 100644
index 00000000000..4a5f0d5066c
--- /dev/null
+++ b/pennylane/workflow/construct_batch.py
@@ -0,0 +1,301 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains a function extracting the tapes at postprocessing at any stage of a transform program.
+
+"""
+from functools import wraps
+import inspect
+from typing import Union, Callable, Tuple
+
+import pennylane as qml
+from .qnode import QNode, _make_execution_config, _get_device_shots
+
+
+def null_postprocessing(results):
+ """A postprocessing function with null behavior."""
+ return results[0]
+
+
+def expand_fn_transform(expand_fn: Callable) -> "qml.transforms.core.TransformDispatcher":
+ """Construct a transform from a tape-to-tape function.
+
+ Args:
+ expand_fn (Callable): a function from a single tape to a single tape
+
+ Returns:
+
+ .TransformDispatcher: Returns a transform dispatcher object that that can transform any
+ circuit-like object in PennyLane.
+
+ >>> device = qml.device('default.qubit.legacy', wires=2)
+ >>> my_transform = qml.transforms.core.expand_fn_transform(device.expand_fn)
+ >>> my_transform
+
+ """
+
+ @wraps(expand_fn)
+ def wrapped_expand_fn(tape, *args, **kwargs):
+ return (expand_fn(tape, *args, **kwargs),), null_postprocessing
+
+ return qml.transform(wrapped_expand_fn)
+
+
+def _get_full_transform_program(qnode: QNode) -> "qml.transforms.core.TransformProgram":
+ program = qml.transforms.core.TransformProgram(qnode.transform_program)
+ if getattr(qnode.gradient_fn, "expand_transform", False):
+ program.add_transform(
+ qml.transform(qnode.gradient_fn.expand_transform),
+ **qnode.gradient_kwargs,
+ )
+ if isinstance(qnode.device, qml.devices.Device):
+ config = _make_execution_config(qnode)
+ return program + qnode.device.preprocess(config)[0]
+ program.add_transform(qml.transform(qnode.device.batch_transform))
+ program.add_transform(expand_fn_transform(qnode.device.expand_fn))
+ return program
+
+
+def get_transform_program(qnode: "QNode", level=None) -> "qml.transforms.core.TransformProgram":
+ """Extract a transform program at a designated level.
+
+ Args:
+ qnode (QNode): the qnode to get the transform program for.
+ level (None, str, int, slice): And indication of what transforms to use from the full program.
+
+ * ``None``: use the full transform program
+ * ``str``: Acceptable keys are ``"user"``, ``"device"``, ``"top"`` and ``"gradient"``
+ * ``int``: How many transforms to include, starting from the front of the program
+ * ``slice``: a slice to select out components of the transform program.
+
+ Returns:
+ TransformProgram: the transform program corresponding to the requested level.
+
+ .. details::
+ :title: Usage Details
+
+ The transforms are organized as:
+
+ .. image:: ../../_static/transforms_order.png
+ :align: center
+ :width: 800px
+ :target: javascript:void(0);
+
+ where ``transform1`` is first applied to the ``QNode`` followed by ``transform2``. First user transforms are run on the tapes,
+ followed by the gradient expansion, followed by the device expansion. "Final" transforms, like ``param_shift`` and ``metric_tensor``,
+ always occur at the end of the program.
+
+ .. code-block:: python
+
+ dev = qml.device('default.qubit')
+
+ @qml.metric_tensor # final transform
+ @qml.transforms.merge_rotations # transform 2
+ @qml.transforms.cancel_inverses # transform 1
+ @qml.qnode(dev, diff_method="parameter-shift", shifts=np.pi / 4)
+ def circuit():
+ return qml.expval(qml.PauliZ(0))
+
+ By default, we get the full transform program. This can be manually specified by ``level=None``.
+
+ >>> qml.workflow.get_transform_program(circuit)
+ TransformProgram(cancel_inverses, merge_rotations, _expand_metric_tensor,
+ _expand_transform_param_shift, validate_device_wires, defer_measurements,
+ decompose, validate_measurements, validate_observables, metric_tensor)
+
+ The ``"user"`` transforms are the ones manually applied to the qnode, :class:`~.cancel_inverses` and
+ :class:`~.merge_rotations`.
+
+ >>> qml.workflow.get_transform_program(circuit, level="user")
+ TransformProgram(cancel_inverses, merge_rotations)
+
+ The ``_expand_transform_param_shift`` is the ``"gradient"`` transform. This expands all trainable
+ operations to a state where the parameter shift transform can operate on them. For example, it will decompose
+ any parametrized templates into operators that have generators.
+
+ >>> qml.workflow.get_transform_program(circuit, level="gradient")
+ TransformProgram(cancel_inverses, merge_rotations, _expand_transform_param_shift)
+
+ ``"device"`` includes all transforms except for a ``"final"`` transform, if it exists. This usually
+ corresponds to the circuits that will be sent to the device to execute.
+
+ >>> qml.workflow.get_transform_program(circuit, level="device")
+ TransformProgram(cancel_inverses, merge_rotations, _expand_transform_param_shift,
+ validate_device_wires, defer_measurements, decompose, validate_measurements,
+ validate_observables)
+
+ ``"top"`` and ``0`` both return empty transform programs.
+
+ >>> qml.workflow.get_transform_program(circuit, level="top")
+ TransformProgram()
+ >>> qml.workflow.get_transform_program(circuit, level=0)
+ TransformProgram()
+
+ The ``level`` can also be any integer, corresponding to a number of transforms in the program.
+
+ >>> qml.workflow.get_transform_program(circuit, level=2)
+ TransformProgram(cancel_inverses, merge_rotations)
+
+ ``level`` can also accept a ``slice`` object to select out any arbitrary subset of the
+ transform program. This allows you to select different starting transforms or strides.
+ For example, you can skip the first transform or reverse the order:
+
+ >>> qml.workflow.get_transform_program(circuit, level=slice(1,3))
+ TransformProgram(merge_rotations, _expand_transform_param_shift)
+ >>> qml.workflow.get_transform_program(circuit, level=slice(None, None, -1))
+ TransformProgram(metric_tensor, validate_observables, validate_measurements,
+ decompose, defer_measurements, validate_device_wires, _expand_transform_param_shift,
+ _expand_metric_tensor, merge_rotations, cancel_inverses)
+
+ """
+ full_transform_program = _get_full_transform_program(qnode)
+
+ num_user = len(qnode.transform_program)
+ if qnode.transform_program.has_final_transform:
+ # final transform is placed after device transforms
+ num_user -= 1
+
+ if level == "device":
+ level = -1 if full_transform_program.has_final_transform else None
+ elif level == "top":
+ level = 0
+ elif level == "user":
+ level = num_user
+ elif level == "gradient":
+ if getattr(qnode.gradient_fn, "expand_transform", False):
+ level = slice(0, num_user + 1)
+ else:
+ level = slice(0, num_user)
+ elif isinstance(level, str):
+ raise ValueError(
+ f"level {level} not recognized. Acceptable strings are 'device', 'top', 'user', and 'gradient'."
+ )
+ if level is None or isinstance(level, int):
+ level = slice(0, level)
+ return full_transform_program[level]
+
+
+def construct_batch(qnode: QNode, level: Union[None, str, int, slice] = "user") -> Callable:
+ """Construct the batch of tapes and post processing for a designated stage in the transform program.
+
+ Args:
+ qnode (QNode): the qnode we want to get the tapes and post-processing for.
+ level (None, str, int, slice): And indication of what transforms to use from the full program.
+
+ * ``None``: use the full transform program
+ * ``str``: Acceptable keys are ``"top"``, ``"user"``, ``"device"``, and ``"gradient"``
+ * ``int``: How many transforms to include, starting from the front of the program
+ * ``slice``: a slice to select out components of the transform program.
+
+ Returns:
+ Callable: a function with the same call signature as the initial quantum function. This function returns
+ a batch (tuple) of tapes and postprocessing function.
+
+ .. seealso:: :func:`pennylane.workflow.get_transform_program` to inspect the contents of the transform program for a specified level.
+
+
+ .. details::
+ :title: Usage Details
+
+ Suppose we have a QNode with several user transforms.
+
+ .. code-block:: python
+
+ @qml.transforms.undo_swaps
+ @qml.transforms.merge_rotations
+ @qml.transforms.cancel_inverses
+ @qml.qnode(qml.device('default.qubit'), diff_method="parameter-shift", shifts=np.pi / 4)
+ def circuit(x):
+ qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0,1))
+ qml.RX(x, wires=0)
+ qml.RX(-x, wires=0)
+ qml.SWAP((0,1))
+ qml.PauliX(0)
+ qml.PauliX(0)
+ return qml.expval(qml.PauliX(0) + qml.PauliY(0))
+
+ We can inspect what the device will execute with:
+
+ >>> batch, fn = construct_batch(circuit, level="device")(1.23)
+ >>> batch[0].circuit
+ [RY(tensor(1., requires_grad=True), wires=[1]),
+ RX(tensor(2., requires_grad=True), wires=[0]),
+ expval( (1) [X0]
+ + (1) [Y0])]
+
+ These tapes can be natively executed by the device, though with non-backprop devices the parameters
+ will need to be converted to numpy with :func:`~.convert_to_numpy_parameters`.
+
+ >>> fn(dev.execute(batch))
+ (tensor(-0.90929743, requires_grad=True),)
+
+ Or what the parameter shift gradient transform will be applied to:
+
+ >>> batch, fn = construct_batch(circuit, level="gradient")(1.23)
+ >>> batch[0].circuit
+ [RY(tensor(1., requires_grad=True), wires=[1]),
+ RX(tensor(2., requires_grad=True), wires=[0]),
+ expval( (1) [X0]
+ + (1) [Y0])]
+
+ We can inspect what was directly captured from the qfunc with ``level=0``.
+
+ >>> batch, fn = construct_batch(circuit, level=0)(1.23)
+ >>> batch[0].circuit
+ [RandomLayers(tensor([[1., 2.]], requires_grad=True), wires=[0, 1]),
+ RX(1.23, wires=[0]),
+ RX(-1.23, wires=[0]),
+ SWAP(wires=[0, 1]),
+ PauliX(wires=[0]),
+ PauliX(wires=[0]),
+ expval( (1) [X0]
+ + (1) [Y0])]
+
+ And iterate though stages in the transform program with different integers.
+ If we request ``level=1``, the ``cancel_inverses`` transform has been applied.
+
+ >>> batch, fn = construct_batch(circuit, level=1)(1.23)
+ >>> batch[0].circuit
+ [RandomLayers(tensor([[1., 2.]], requires_grad=True), wires=[0, 1]),
+ RX(1.23, wires=[0]),
+ RX(-1.23, wires=[0]),
+ SWAP(wires=[0, 1]),
+ expval( (1) [X0]
+ + (1) [Y0])]
+
+ We can also slice into a subset of the transform program. ``slice(1, None)`` would skip the first user
+ transform ``cancel_inverses``:
+
+ >>> batch, fn = construct_batch(circuit, level=slice(1,None))(1.23)
+ >>> batch[0].circuit
+ [RY(tensor(1., requires_grad=True), wires=[1]),
+ RX(tensor(2., requires_grad=True), wires=[0]),
+ PauliX(wires=[0]),
+ PauliX(wires=[0]),
+ expval( (1) [X0]
+ + (1) [Y0])]
+
+ """
+ program = get_transform_program(qnode, level=level)
+
+ def batch_constructor(*args, **kwargs) -> Tuple[Tuple["qml.tape.QuantumTape", Callable]]:
+ """Create a batch of tapes and a post processing function."""
+ if "shots" in inspect.signature(qnode.func).parameters:
+ shots = _get_device_shots(qnode.device)
+ else:
+ shots = kwargs.pop("shots", _get_device_shots(qnode.device))
+
+ initial_tape = qml.tape.make_qscript(qnode.func, shots=shots)(*args, **kwargs)
+ return program((initial_tape,))
+
+ return batch_constructor
diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py
index b1348624e09..3055e920932 100644
--- a/pennylane/workflow/qnode.py
+++ b/pennylane/workflow/qnode.py
@@ -62,6 +62,26 @@ def _get_device_shots(device) -> Shots:
return device.shots
+def _make_execution_config(circuit: "QNode") -> "qml.devices.ExecutionConfig":
+ if circuit.gradient_fn is None:
+ _gradient_method = None
+ elif isinstance(circuit.gradient_fn, str):
+ _gradient_method = circuit.gradient_fn
+ else:
+ _gradient_method = "gradient-transform"
+ grad_on_execution = circuit.execute_kwargs.get("grad_on_execution")
+ if circuit.interface == "jax":
+ grad_on_execution = False
+ elif grad_on_execution == "best":
+ grad_on_execution = None
+ return qml.devices.ExecutionConfig(
+ interface=circuit.interface,
+ gradient_method=_gradient_method,
+ grad_on_execution=grad_on_execution,
+ use_device_jacobian_product=circuit.execute_kwargs["device_vjp"],
+ )
+
+
class QNode:
"""Represents a quantum node in the hybrid computational graph.
@@ -990,24 +1010,7 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
config = None
# Add the device program to the QNode program
if isinstance(self.device, qml.devices.Device):
- if self.gradient_fn is None:
- _gradient_method = None
- elif isinstance(self.gradient_fn, str):
- _gradient_method = self.gradient_fn
- else:
- _gradient_method = "gradient-transform"
- grad_on_execution = self.execute_kwargs.get("grad_on_execution")
- if self.interface == "jax":
- grad_on_execution = False
- elif grad_on_execution == "best":
- grad_on_execution = None
-
- config = qml.devices.ExecutionConfig(
- interface=self.interface,
- gradient_method=_gradient_method,
- grad_on_execution=grad_on_execution,
- use_device_jacobian_product=self.execute_kwargs["device_vjp"],
- )
+ config = _make_execution_config(self)
device_transform_program, config = self.device.preprocess(execution_config=config)
full_transform_program = self.transform_program + device_transform_program
else:
diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py
index d98a20954ea..35a76028f42 100644
--- a/tests/devices/default_qubit/test_default_qubit.py
+++ b/tests/devices/default_qubit/test_default_qubit.py
@@ -1725,6 +1725,8 @@ def test_postselection_valid_finite_shots(
if use_jit and (interface != "jax" or isinstance(shots, tuple)):
pytest.skip("Cannot JIT in non-JAX interfaces, or with shot vectors.")
+ np.random.seed(42)
+
dev = qml.device("default.qubit")
param = qml.math.asarray(param, like=interface)
diff --git a/tests/templates/test_subroutines/test_permute.py b/tests/templates/test_subroutines/test_permute.py
index 6eefab22851..4bfa0dbc47b 100644
--- a/tests/templates/test_subroutines/test_permute.py
+++ b/tests/templates/test_subroutines/test_permute.py
@@ -26,6 +26,11 @@ def test_standard_validity():
qml.ops.functions.assert_valid(op)
+def test_repr():
+ op = qml.Permute([2, 1, 0], wires=(0, 1, 2))
+ assert repr(op) == "Permute((2, 1, 0), wires=[0, 1, 2])"
+
+
class TestDecomposition:
"""Tests that the template defines the correct decomposition."""
diff --git a/tests/test_hermitian_edge_cases.py b/tests/test_hermitian_edge_cases.py
index b3f2365ebf2..d628cf268ff 100644
--- a/tests/test_hermitian_edge_cases.py
+++ b/tests/test_hermitian_edge_cases.py
@@ -91,7 +91,7 @@ def circuit():
@pytest.mark.parametrize("w1, w2", list(itertools.permutations(range(4), 2)))
def test_hermitian_two_wires_permuted(self, w1, w2, shots, theta):
"""Test that an hermitian expectation with various wires permuted works"""
- dev = qml.device("default.qubit", wires=4, shots=shots)
+ dev = qml.device("default.qubit", wires=4, shots=shots, seed=123545)
theta = 0.543
A = np.array(
diff --git a/tests/transforms/test_experimental/test_transform_dispatcher.py b/tests/transforms/test_experimental/test_transform_dispatcher.py
index b7a5160bb52..b1ef83bb65c 100644
--- a/tests/transforms/test_experimental/test_transform_dispatcher.py
+++ b/tests/transforms/test_experimental/test_transform_dispatcher.py
@@ -150,6 +150,73 @@ def fn(results):
return [tape], fn
+class TestTransformContainer:
+ """Tests for the TransformContainer dataclass."""
+
+ def test_repr(self):
+ """Tests for the repr of a transform container."""
+ t1 = qml.transforms.core.TransformContainer(
+ qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
+ )
+ assert repr(t1) == ""
+
+ def test_equality(self):
+ """Tests that we can compare TransformContainer objects with the '==' and '!=' operators."""
+
+ t1 = TransformContainer(
+ qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
+ )
+ t2 = TransformContainer(
+ qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
+ )
+ t3 = TransformContainer(
+ qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]}
+ )
+ t4 = TransformContainer(
+ qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 2}
+ )
+
+ t5 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,))
+ t6 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-7,))
+
+ # test for equality of identical transformers
+ assert t1 == t2
+
+ # test for inequality of different transformers
+ assert t1 != t3
+ assert t2 != t3
+ assert t1 != 2
+ assert t1 != t4
+ assert t5 != t6
+ assert t5 != t1
+
+ # Test equality with the same args
+ t5_copy = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,))
+ assert t5 == t5_copy
+
+ def test_the_transform_container_attributes(self):
+ """Test the transform container attributes."""
+ container = qml.transforms.core.TransformContainer(
+ first_valid_transform, args=[0], kwargs={}, classical_cotransform=None
+ )
+
+ q_transform, args, kwargs, cotransform, is_informative, final_transform = container
+
+ assert q_transform is first_valid_transform
+ assert args == [0]
+ assert kwargs == {}
+ assert cotransform is None
+ assert not is_informative
+ assert not final_transform
+
+ assert container.transform is first_valid_transform
+ assert container.args == [0]
+ assert not container.kwargs
+ assert container.classical_cotransform is None
+ assert not container.is_informative
+ assert not container.final_transform
+
+
class TestTransformDispatcher: # pylint: disable=too-many-public-methods
"""Test the transform function (validate and dispatch)."""
@@ -194,7 +261,7 @@ def qnode_circuit(a):
assert isinstance(
qnode_transformed.transform_program.pop_front(), qml.transforms.core.TransformContainer
)
- assert not dispatched_transform.is_informative
+ assert dispatched_transform.is_informative is False
def test_integration_dispatcher_with_informative_transform(self):
"""Test that no error is raised with the transform function and that the transform dispatcher returns
@@ -276,40 +343,6 @@ def qnode_circuit(a): # pylint:disable=unused-variable
qml.RZ(a, wires=1)
return qml.expval(qml.PauliZ(wires=0))
- def test_equality(self):
- """Tests that we can compare TransformContainer objects with the '==' and '!=' operators."""
-
- t1 = TransformContainer(
- qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
- )
- t2 = TransformContainer(
- qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
- )
- t3 = TransformContainer(
- qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]}
- )
- t4 = TransformContainer(
- qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 2}
- )
-
- t5 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,))
- t6 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-7,))
-
- # test for equality of identical transformers
- assert t1 == t2
-
- # test for inequality of different transformers
- assert t1 != t3
- assert t2 != t3
- assert t1 != 2
- assert t1 != t4
- assert t5 != t6
- assert t5 != t1
-
- # Test equality with the same args
- t5_copy = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,))
- assert t5 == t5_copy
-
def test_queuing_qfunc_transform(self):
"""Test that queuing works with the transformed quantum function."""
@@ -437,28 +470,6 @@ def test_dispatched_transform_attribute(self):
assert dispatched_transform.expand_transform is None
assert dispatched_transform.classical_cotransform is None
- def test_the_transform_container_attributes(self):
- """Test the transform container attributes."""
- container = qml.transforms.core.TransformContainer(
- first_valid_transform, args=[0], kwargs={}, classical_cotransform=None
- )
-
- q_transform, args, kwargs, cotransform, is_informative, final_transform = container
-
- assert q_transform is first_valid_transform
- assert args == [0]
- assert kwargs == {}
- assert cotransform is None
- assert not is_informative
- assert not final_transform
-
- assert container.transform is first_valid_transform
- assert container.args == [0]
- assert not container.kwargs
- assert container.classical_cotransform is None
- assert not container.is_informative
- assert not container.final_transform
-
@pytest.mark.parametrize("valid_transform", valid_transforms)
def test_custom_qnode_transform(self, valid_transform):
"""Test that the custom qnode transform is correctly executed"""
diff --git a/tests/transforms/test_experimental/test_transform_program.py b/tests/transforms/test_experimental/test_transform_program.py
index 4f89fe305d9..c928918bc3c 100644
--- a/tests/transforms/test_experimental/test_transform_program.py
+++ b/tests/transforms/test_experimental/test_transform_program.py
@@ -133,6 +133,43 @@ def test_iter_program(self):
assert isinstance(elem, TransformContainer)
assert elem.transform is first_valid_transform
+ def test_getitem(self):
+ """Tests for the getitem dunder."""
+
+ t0 = TransformContainer(transform=first_valid_transform)
+ t1 = TransformContainer(transform=second_valid_transform)
+ t2 = TransformContainer(transform=informative_transform)
+ program = TransformProgram([t0, t1, t2])
+
+ assert program[0] == t0
+ assert program[1] == t1
+ assert program[2] == t2
+
+ assert program[:2] == TransformProgram([t0, t1])
+ assert program[::-1] == TransformProgram([t2, t1, t0])
+
+ def test_contains(self):
+ """Test that we can check whether a transform or transform container exists in a transform."""
+
+ t0 = TransformContainer(transform=first_valid_transform)
+ t1 = TransformContainer(transform=second_valid_transform)
+ t2 = TransformContainer(transform=informative_transform)
+ program = TransformProgram([t0, t1, t2])
+
+ assert t0 in program
+ assert t1 in program
+ assert t2 in program
+ assert qml.compile not in program
+
+ assert t0 in program
+ assert t1 in program
+ assert t2 in program
+
+ t_not = TransformContainer(transform=qml.compile)
+ assert t_not not in program
+
+ assert "a" not in program
+
def test_add_single_programs(self):
"""Test adding two transform programs"""
transform_program1 = TransformProgram()
@@ -272,6 +309,35 @@ def test_repr_program(self):
+ ")"
)
+ def test_equality(self):
+ """Tests that we can compare TransformProgram objects with the '==' and '!=' operators."""
+ t1 = TransformContainer(
+ qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
+ )
+ t2 = TransformContainer(
+ qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
+ )
+ t3 = TransformContainer(
+ qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]}
+ )
+
+ p1 = TransformProgram([t1, t3])
+ p2 = TransformProgram([t2, t3])
+ p3 = TransformProgram([t3, t2])
+
+ # test for equality of identical objects
+ assert p1 == p2
+ # test for inequality of different objects
+ assert p1 != p3
+ assert p1 != t1
+
+ # Test inequality with different transforms
+ t4 = TransformContainer(
+ qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (2, 3)]}
+ )
+ p4 = TransformProgram([t1, t4])
+ assert p1 != p4
+
class TestTransformProgram:
"""Test the transform program class and its method."""
@@ -465,52 +531,36 @@ def test_insert_transform_with_expand(self):
assert transform_program[1].transform is first_valid_transform
def test_valid_transforms(self):
- """Test that it is only possible to create valid transforms."""
+ """Test adding transforms to a program with a terminal transform."""
transform_program = TransformProgram()
transform1 = TransformContainer(transform=first_valid_transform, is_informative=True)
transform_program.push_back(transform1)
+ t_normal = TransformContainer(transform=second_valid_transform)
+ transform_program.push_back(t_normal)
+ print(transform_program)
+ assert len(transform_program) == 2
+ assert transform_program[0] is t_normal
+ assert transform_program[1] is transform1
+
+ t_normal2 = TransformContainer(transform=first_valid_transform)
+ transform_program.push_back(t_normal2)
+ assert transform_program[0] is t_normal
+ assert transform_program[1] is t_normal2
+ assert transform_program[2] is transform1
+
with pytest.raises(
TransformError, match="The transform program already has a terminal transform."
):
transform_program.push_back(transform1)
- transform2 = TransformContainer(transform=second_valid_transform, is_informative=False)
+ transform2 = TransformContainer(transform=second_valid_transform, final_transform=True)
with pytest.raises(
TransformError, match="The transform program already has a terminal transform."
):
transform_program.push_back(transform2)
- def test_equality(self):
- """Tests that we can compare TransformProgram objects with the '==' and '!=' operators."""
- t1 = TransformContainer(
- qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
- )
- t2 = TransformContainer(
- qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
- )
- t3 = TransformContainer(
- qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]}
- )
-
- p1 = TransformProgram([t1, t3])
- p2 = TransformProgram([t2, t3])
- p3 = TransformProgram([t3, t2])
-
- # test for equality of identical objects
- assert p1 == p2
- # test for inequality of different objects
- assert p1 != p3
- assert p1 != t1
-
- # Test inequality with different transforms
- t4 = TransformContainer(
- qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (2, 3)]}
- )
- p4 = TransformProgram([t1, t4])
- assert p1 != p4
-
class TestTransformProgramCall:
"""Tests for calling a TransformProgram on a batch of quantum tapes."""
diff --git a/tests/workflow/test_construct_batch.py b/tests/workflow/test_construct_batch.py
new file mode 100644
index 00000000000..bc6113d5b41
--- /dev/null
+++ b/tests/workflow/test_construct_batch.py
@@ -0,0 +1,436 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Contains tests for the `qml.workflow.get_transform_program` getter and `construct_batch`.
+
+"""
+from functools import partial
+import pytest
+
+import numpy as np
+
+import pennylane as qml
+from pennylane.transforms.core.transform_dispatcher import TransformContainer
+from pennylane.transforms.core.transform_program import TransformProgram
+from pennylane.workflow import get_transform_program, construct_batch
+from pennylane.workflow.construct_batch import expand_fn_transform
+
+
+def test_expand_fn_transform():
+ """Tests the expand_fn_transform."""
+
+ def my_expand_fn(tape, op1, op2=qml.S(0), op3=qml.S(0)):
+ """my docstring."""
+ return qml.tape.QuantumScript(
+ tape.operations + [op1, op2, op3], tape.measurements, tape.shots
+ )
+
+ t = expand_fn_transform(my_expand_fn)
+
+ assert isinstance(t, qml.transforms.core.TransformDispatcher)
+ tape = qml.tape.QuantumScript([qml.S(0)], [qml.expval(qml.PauliZ(0))], shots=50)
+
+ batch, fn = t(tape, qml.PauliX(0), op3=qml.T(0))
+ assert len(batch) == 1
+ expected = qml.tape.QuantumScript(
+ [qml.S(0), qml.PauliX(0), qml.S(0), qml.T(0)], [qml.expval(qml.PauliZ(0))], shots=50
+ )
+ assert qml.equal(batch[0], expected)
+ assert fn(("a",)) == "a"
+
+ assert repr(t) == ""
+ assert t.__doc__ == "my docstring."
+
+
+class TestTransformProgramGetter:
+ def test_bad_string_key(self):
+ """Test a value error is raised if a bad string key is provided."""
+
+ @qml.qnode(qml.device("default.qubit"))
+ def circuit():
+ return qml.state()
+
+ with pytest.raises(ValueError, match=r"level bah not recognized."):
+ get_transform_program(circuit, level="bah")
+
+ def test_get_transform_program_gradient_fn_transform(self):
+ """Tests for the transform program when the gradient_fn is a transform."""
+
+ dev = qml.device("default.qubit", wires=4)
+
+ @partial(qml.transforms.compile, num_passes=2)
+ @partial(qml.transforms.merge_rotations, atol=1e-5)
+ @qml.transforms.cancel_inverses
+ @qml.qnode(dev, diff_method="parameter-shift", shifts=2)
+ def circuit():
+ return qml.expval(qml.PauliZ(0))
+
+ expected_p0 = TransformContainer(qml.transforms.cancel_inverses.transform)
+ expected_p1 = TransformContainer(
+ qml.transforms.merge_rotations.transform, kwargs={"atol": 1e-5}
+ )
+ expected_p2 = TransformContainer(qml.transforms.compile.transform, kwargs={"num_passes": 2})
+
+ ps_expand_fn = TransformContainer(
+ qml.gradients.param_shift.expand_transform, kwargs={"shifts": 2}
+ )
+
+ p0 = get_transform_program(circuit, level=0)
+ assert isinstance(p0, TransformProgram)
+ assert len(p0) == 0
+
+ p0 = get_transform_program(circuit, level="top")
+ assert isinstance(p0, TransformProgram)
+ assert len(p0) == 0
+
+ p_grad = get_transform_program(circuit, level="gradient")
+ assert isinstance(p_grad, TransformProgram)
+ assert len(p_grad) == 4
+ assert p_grad == TransformProgram([expected_p0, expected_p1, expected_p2, ps_expand_fn])
+
+ p_dev = get_transform_program(circuit, level="device")
+ assert isinstance(p_grad, TransformProgram)
+ p_default = get_transform_program(circuit)
+ p_none = get_transform_program(circuit, None)
+ assert p_dev == p_default
+ assert p_none == p_dev
+ assert len(p_dev) == 9
+ assert p_dev == p_grad + dev.preprocess()[0]
+
+ # slicing
+ p_sliced = get_transform_program(circuit, slice(2, 7, 2))
+ assert len(p_sliced) == 3
+ assert p_sliced[0].transform == qml.compile.transform
+ assert p_sliced[1].transform == qml.devices.preprocess.validate_device_wires.transform
+ assert p_sliced[2].transform == qml.devices.preprocess.decompose.transform
+
+ def test_gradient_fn_device_gradient(self):
+ """Test that if level="gradient" but the gradient does not have preprocessing, the program is strictly user transforms."""
+
+ @qml.transforms.cancel_inverses
+ @qml.qnode(qml.device("default.qubit"), diff_method="backprop")
+ def circuit():
+ return qml.state()
+
+ prog = get_transform_program(circuit, level="gradient")
+ assert len(prog) == 1
+ assert qml.transforms.cancel_inverses in prog
+
+ def test_get_transform_program_device_gradient(self):
+ """Test the trnsform program contents when using a device derivative."""
+
+ dev = qml.device("default.qubit")
+
+ @qml.transforms.sum_expand
+ @qml.qnode(dev, diff_method="adjoint", device_vjp=False)
+ def circuit(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.PauliZ(0))
+
+ full_prog = get_transform_program(circuit)
+ assert len(full_prog) == 13
+
+ config = qml.devices.ExecutionConfig(
+ gradient_method="adjoint", use_device_jacobian_product=False
+ )
+ dev_program = dev.preprocess(config)[0]
+
+ expected = TransformProgram()
+ expected.add_transform(qml.transforms.sum_expand)
+ expected += dev_program
+ assert full_prog == expected
+
+ def test_get_transform_program_legacy_device_interface(self):
+ """Test the contents of the transform program with the legacy device interface."""
+
+ dev = qml.device("default.qubit.legacy", wires=5)
+
+ @qml.transforms.merge_rotations
+ @qml.qnode(dev, diff_method="backprop")
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return qml.expval(qml.PauliZ(0))
+
+ program = get_transform_program(circuit)
+
+ m1 = TransformContainer(qml.transforms.merge_rotations.transform)
+ m2 = TransformContainer(dev.batch_transform)
+ assert program[0:2] == TransformProgram([m1, m2])
+
+ # a little hard to check the contents of a expand_fn_transform
+ # this is the best proxy I can find
+ assert program[2].transform.__wrapped__ == dev.expand_fn
+
+ def test_get_transform_program_final_transform(self):
+ """Test that gradient preprocessing and device transform occur before a final transform."""
+
+ @qml.metric_tensor
+ @qml.compile
+ @qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift")
+ def circuit():
+ qml.IsingXX(1.234, wires=(0, 1))
+ return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))
+
+ user_program = get_transform_program(circuit, level="user")
+ assert len(user_program) == 2
+ assert user_program[0].transform == qml.compile.transform
+ assert user_program[1].transform == qml.metric_tensor.expand_transform
+
+ grad_program = get_transform_program(circuit, level="gradient")
+ assert len(grad_program) == 3
+ assert grad_program[0].transform == qml.compile.transform
+ assert grad_program[1].transform == qml.metric_tensor.expand_transform
+ assert grad_program[2].transform == qml.gradients.param_shift.expand_transform
+
+ dev_program = get_transform_program(circuit, level="device")
+ assert len(dev_program) == 3 + len(circuit.device.preprocess()[0]) # currently 8
+ assert qml.metric_tensor not in dev_program
+
+ full = get_transform_program(circuit)
+ assert full[-1].transform == qml.metric_tensor.transform
+
+
+@qml.transforms.merge_rotations
+@qml.transforms.cancel_inverses
+@qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift")
+def circuit1(weights, order):
+ qml.RandomLayers(weights, wires=(0, 1))
+ qml.Permute(order, wires=(0, 1, 2))
+ qml.PauliX(0)
+ qml.PauliX(0)
+ qml.RX(0.1, wires=0)
+ qml.RX(-0.1, wires=0)
+ return qml.expval(qml.PauliX(0))
+
+
+class TestConstructBatch:
+ """Tests for the construct_batch function."""
+
+ def test_level_zero(self):
+ """Test that level zero is purely the queued circuit."""
+
+ order = [2, 1, 0]
+ weights = np.array([[1.0, 20]])
+ batch, fn = construct_batch(circuit1, level=0)(weights, order, shots=10)
+
+ assert len(batch) == 1
+ expected_ops = [
+ qml.RandomLayers(weights, wires=(0, 1)),
+ qml.Permute(order, wires=(0, 1, 2)),
+ qml.PauliX(0),
+ qml.PauliX(0),
+ qml.RX(0.1, wires=0),
+ qml.RX(-0.1, wires=0),
+ ]
+
+ expected = qml.tape.QuantumScript(
+ expected_ops,
+ [qml.expval(qml.PauliX(0))],
+ shots=10,
+ )
+ assert qml.equal(batch[0], expected)
+
+ assert fn(("a",)) == ("a",)
+
+ def test_first_transform(self):
+ """Test that the first user transform can be selected by level=1"""
+
+ weights = np.array([[1.0, 2.0]])
+ order = [2, 1, 0]
+
+ batch, fn = construct_batch(circuit1, level=1)(weights, order=order, shots=50)
+ assert len(batch) == 1
+
+ expected_ops = [
+ qml.RandomLayers(weights, wires=(0, 1)),
+ qml.Permute(order, wires=(0, 1, 2)),
+ # cancel inverses
+ qml.RX(0.1, wires=0),
+ qml.RX(-0.1, wires=0),
+ ]
+
+ expected = qml.tape.QuantumScript(expected_ops, [qml.expval(qml.PauliX(0))], shots=50)
+ assert qml.equal(batch[0], expected)
+ assert fn(("a",)) == ("a",)
+
+ @pytest.mark.parametrize("level", (2, "user"))
+ def test_all_user_transforms(self, level):
+ """Test that all user transforms can be selected and run."""
+
+ weights = np.array([[1.0, 2.0]])
+ order = [2, 1, 0]
+
+ batch, fn = construct_batch(circuit1, level=level)(weights, order=order, shots=50)
+ assert len(batch) == 1
+
+ expected_ops = [
+ qml.RandomLayers(weights, wires=(0, 1)),
+ qml.Permute(order, wires=(0, 1, 2)),
+ # cancel inverses
+ # merge rotations
+ ]
+
+ expected = qml.tape.QuantumScript(expected_ops, [qml.expval(qml.PauliX(0))], shots=50)
+ assert qml.equal(batch[0], expected)
+ assert fn(("a",)) == ("a",)
+
+ @pytest.mark.parametrize("level", (3, "gradient"))
+ def test_gradient_transforms(self, level):
+ """Test that the gradient transform can be selected with an integer or keyword."""
+ weights = qml.numpy.array([[1.0, 2.0]], requires_grad=True)
+ order = [2, 1, 0]
+ batch, fn = construct_batch(circuit1, level=level)(weights=weights, order=order)
+
+ expected = qml.tape.QuantumScript(
+ [
+ qml.RY(qml.numpy.array(1), 0),
+ qml.RX(qml.numpy.array(2), 1),
+ qml.Permute(order, (0, 1, 2)),
+ ],
+ [qml.expval(qml.PauliX(0))],
+ )
+ assert qml.equal(batch[0], expected)
+ assert len(batch) == 1
+ assert fn(("a",)) == ("a",)
+
+ @pytest.mark.parametrize("level", ("device", None))
+ def test_device_transforms(self, level):
+ """Test that all device transforms can be run with the device keyword."""
+
+ weights = np.array([[1.0, 2.0]])
+ order = [2, 1, 0]
+
+ batch, fn = construct_batch(circuit1, level=level)(weights, order)
+
+ expected = qml.tape.QuantumScript(
+ [qml.RY(1, 0), qml.RX(2, 1), qml.SWAP((0, 2))], [qml.expval(qml.PauliX(0))]
+ )
+ assert qml.equal(batch[0], expected)
+ assert len(batch) == 1
+ assert fn(("a",)) == ("a",)
+
+ @pytest.mark.parametrize("level", ("device", None))
+ def test_device_transforms_legacy_interface(self, level):
+ """Test that the device transforms can be selected with level=device or None."""
+
+ @qml.transforms.merge_rotations
+ @qml.qnode(qml.device("default.qubit.legacy", wires=2, shots=50))
+ def circuit(order):
+ qml.Permute(order, wires=(0, 1, 2))
+ qml.RX(0.5, wires=0)
+ qml.RX(-0.5, wires=0)
+ return qml.expval(qml.PauliX(0) + qml.PauliY(0))
+
+ batch, fn = construct_batch(circuit, level=level)((2, 1, 0))
+
+ expected0 = qml.tape.QuantumScript(
+ [qml.SWAP((0, 2))], [qml.expval(qml.PauliX(0))], shots=50
+ )
+ assert qml.equal(expected0, batch[0])
+ expected1 = qml.tape.QuantumScript(
+ [qml.SWAP((0, 2))], [qml.expval(qml.PauliY(0))], shots=50
+ )
+ assert qml.equal(expected1, batch[1])
+ assert len(batch) == 2
+
+ assert fn((1.0, 2.0)) == (3.0,)
+
+ def test_final_transform(self):
+ """Test that the final transform is included when level=None."""
+
+ @qml.gradients.param_shift
+ @qml.transforms.merge_rotations
+ @qml.qnode(qml.device("default.qubit"))
+ def circuit(x):
+ qml.RX(x, 0)
+ qml.RX(x, 0)
+ return qml.expval(qml.PauliZ(0))
+
+ batch, fn = construct_batch(circuit, level=None)(0.5)
+ assert len(batch) == 2
+ expected0 = qml.tape.QuantumScript(
+ [qml.RX(1.0 + np.pi / 2, 0)], [qml.expval(qml.PauliZ(0))]
+ )
+ assert qml.equal(batch[0], expected0)
+ expected1 = qml.tape.QuantumScript(
+ [qml.RX(1.0 - np.pi / 2, 0)], [qml.expval(qml.PauliZ(0))]
+ )
+ assert qml.equal(batch[1], expected1)
+
+ dummy_res = (1.0, 2.0)
+ expected_res = (1.0 - 2.0) / 2
+ assert qml.numpy.allclose(fn(dummy_res)[0], expected_res)
+
+ def test_user_transform_multiple_tapes(self):
+ """Test a user transform that creates multiple tapes."""
+
+ @qml.transforms.split_non_commuting
+ @qml.qnode(qml.device("default.qubit", shots=10))
+ def circuit():
+ qml.S(0)
+ return qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(1))
+
+ batch, fn = construct_batch(circuit, level="user")()
+
+ assert len(batch) == 2
+ expected0 = qml.tape.QuantumScript(
+ [qml.S(0)], [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(1))], shots=10
+ )
+ assert qml.equal(expected0, batch[0])
+
+ expected1 = qml.tape.QuantumScript([qml.S(0)], [qml.expval(qml.PauliZ(0))], shots=10)
+ assert qml.equal(expected1, batch[1])
+
+ dummy_res = (("x0", "x1"), "z0")
+ expected_res = (("x0", "z0", "x1"),)
+ assert fn(dummy_res) == expected_res
+
+ def test_slicing_level(self):
+ """Test that the level can be a slice."""
+
+ @qml.transforms.merge_rotations
+ @qml.qnode(qml.device("default.qubit"))
+ def circuit(x):
+ qml.RX(x, 0)
+ qml.RX(x, 0)
+ return qml.expval(qml.PauliZ(0))
+
+ # by slicing starting at one, we do not run the merge rotations transform
+ batch, fn = construct_batch(circuit, slice(1, None))(0.5)
+
+ assert len(batch) == 1
+ expected = qml.tape.QuantumScript(
+ [qml.RX(0.5, 0), qml.RX(0.5, 0)], [qml.expval(qml.PauliZ(0))]
+ )
+ assert qml.equal(batch[0], expected)
+ assert fn(("a",)) == ("a",)
+
+ def test_qfunc_with_shots_arg(self):
+ """Test that the tape uses device shots only when qfunc has a shots kwarg"""
+
+ dev = qml.device("default.qubit", shots=100)
+
+ @qml.qnode(dev)
+ def circuit(shots):
+ for _ in range(shots):
+ qml.S(0)
+ return qml.expval(qml.PauliZ(0))
+
+ batch, fn = construct_batch(circuit, level=None)(shots=2)
+ assert len(batch) == 1
+ expected = qml.tape.QuantumScript(
+ [qml.S(0), qml.S(0)], [qml.expval(qml.PauliZ(0))], shots=100
+ )
+ assert qml.equal(batch[0], expected)
+ assert fn(("a",)) == ("a",)