diff --git a/doc/_static/transforms_order.png b/doc/_static/transforms_order.png new file mode 100644 index 00000000000..e7c026865ed Binary files /dev/null and b/doc/_static/transforms_order.png differ diff --git a/doc/_static/transforms_order.svg b/doc/_static/transforms_order.svg new file mode 100644 index 00000000000..09197228d00 --- /dev/null +++ b/doc/_static/transforms_order.svg @@ -0,0 +1,251 @@ + + + + + + + + + + + + + + + + + + tapes + results + + transform 1 + + transform 2 + + gradient + + device + + final + + ML boundary + + + diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 68ecb16137c..f14d7b6ff34 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -139,6 +139,10 @@ * Raise a more informative error when calling `adjoint_jacobian` with trainable state-prep operations. [(#5026)](https://github.com/PennyLaneAI/pennylane/pull/5026) +* Adds `qml.workflow.get_transform_program` and `qml.workflow.construct_batch` to inspect the transform program and batch of tapes + at different stages. + [(#5084)](https://github.com/PennyLaneAI/pennylane/pull/5084) + * `CRX`, `CRY`, `CRZ`, `CROT`, and `ControlledPhaseShift` (i.e. `CPhaseShift`) now inherit from `ControlledOp`, giving them additional properties such as `control_wire` and `control_values`. Calling `qml.ctrl` on `RX`, `RY`, `RZ`, `Rot`, and `PhaseShift` with a single control wire will return gates of types `CRX`, `CRY`, etc. as opposed to a general `Controlled` operator. [(#5069)](https://github.com/PennyLaneAI/pennylane/pull/5069) diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index 4d8cd0fe9c8..20a267ad7c3 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -32,7 +32,6 @@ import pennylane as qml from pennylane import Device, DeviceError -from pennylane.workflow import set_shots from pennylane.math import multiply as qmlmul from pennylane.math import sum as qmlsum from pennylane.measurements import ( @@ -1036,7 +1035,7 @@ def classical_shadow(self, obs, circuit): n_snapshots = self.shots seed = obs.seed - with set_shots(self, shots=1): + with qml.workflow.set_shots(self, shots=1): # slow implementation but works for all devices n_qubits = len(wires) mapped_wires = np.array(self.map_wires(wires)) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 740623fa38f..ade6ec28344 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -159,6 +159,16 @@ def adjoint_state_measurements( ) +def adjoint_ops(op: qml.operation.Operator) -> bool: + """Specify whether or not an Operator is supported by adjoint differentiation.""" + return op.num_params == 0 or (op.num_params == 1 and op.has_generator) + + +def adjoint_observables(obs: qml.operation.Operator) -> bool: + """Specifies whether or not an observable is compatible with adjoint differentiation on DefaultQubit.""" + return obs.has_matrix + + def _add_adjoint_transforms(program: TransformProgram, device_vjp=False) -> None: """Private helper function for ``preprocess`` that adds the transforms specific for adjoint differentiation. @@ -171,14 +181,6 @@ def _add_adjoint_transforms(program: TransformProgram, device_vjp=False) -> None """ - def adjoint_ops(op: qml.operation.Operator) -> bool: - """Specify whether or not an Operator is supported by adjoint differentiation.""" - return op.num_params == 0 or op.num_params == 1 and op.has_generator - - def adjoint_observables(obs: qml.operation.Operator) -> bool: - """Specifies whether or not an observable is compatible with adjoint differentiation on DefaultQubit.""" - return obs.has_matrix - name = "adjoint + default.qubit" program.add_transform(no_sampling, name=name) program.add_transform( diff --git a/pennylane/templates/subroutines/permute.py b/pennylane/templates/subroutines/permute.py index 394fa2f6418..569d977f195 100644 --- a/pennylane/templates/subroutines/permute.py +++ b/pennylane/templates/subroutines/permute.py @@ -138,6 +138,9 @@ def circuit(): """ + def __repr__(self): + return f"Permute({self.hyperparameters['permutation']}, wires={self.wires.tolist()})" + num_wires = AnyWires grad_method = None diff --git a/pennylane/transforms/core/transform.py b/pennylane/transforms/core/transform.py index acbd57931f3..f2ec5d8d08c 100644 --- a/pennylane/transforms/core/transform.py +++ b/pennylane/transforms/core/transform.py @@ -23,7 +23,7 @@ def transform( quantum_transform, expand_transform=None, classical_cotransform=None, - is_informative=None, + is_informative=False, final_transform=False, ): """Generalizes a function that transforms tapes to work with additional circuit-like objects such as a @@ -45,14 +45,15 @@ def transform( * The transform must have the following structure (type hinting is optional): ``my_quantum_transform(tape: qml.tape.QuantumTape, ...) -> ( Sequence[qml.tape.QuantumTape], Callable)`` - expand_transform (Callable): An optional expand transform is applied directly before the input + Keyword Args: + expand_transform=None (Optional[Callable]): An optional expand transform is applied directly before the input quantum transform. It must be a function that satisfies the same requirements as ``quantum_transform``. - classical_cotransform (Callable): A classical co-transform is a function to post-process the classical + classical_cotransform=None (Optional[Callable]): A classical co-transform is a function to post-process the classical jacobian and the quantum jacobian and has the signature: ``my_cotransform(qjac, cjac, tape) -> tensor_like`` - is_informative (bool): Whether or not a transform is informative. If true the transform is queued at the end + is_informative=False (bool): Whether or not a transform is informative. If true the transform is queued at the end of the transform program and the tapes or qnode aren't executed. - final_transform (bool): Whether or not the transform is terminal. If true the transform is queued at the end + final_transform=False (bool): Whether or not the transform is terminal. If true the transform is queued at the end of the transform program. ``is_informative`` supersedes ``final_transform``. Returns: @@ -177,15 +178,13 @@ def qnode_circuit(a): ) # 3: CHeck the classical co-transform - if classical_cotransform is not None: - if not callable(classical_cotransform): - raise TransformError("The classical co-transform must be a valid Python function.") + if classical_cotransform is not None and not callable(classical_cotransform): + raise TransformError("The classical co-transform must be a valid Python function.") - dispatcher = TransformDispatcher( + return TransformDispatcher( quantum_transform, expand_transform=expand_transform, classical_cotransform=classical_cotransform, is_informative=is_informative, final_transform=final_transform, ) - return dispatcher diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py index 147fddce7e8..99a80098c5e 100644 --- a/pennylane/transforms/core/transform_dispatcher.py +++ b/pennylane/transforms/core/transform_dispatcher.py @@ -343,6 +343,9 @@ def __init__( self._is_informative = is_informative self._final_transform = is_informative or final_transform + def __repr__(self): + return f"<{self._transform.__name__}({self._args}, {self._kwargs})>" + def __iter__(self): return iter( ( diff --git a/pennylane/transforms/core/transform_program.py b/pennylane/transforms/core/transform_program.py index ba175e91d53..857cdc3b8ae 100644 --- a/pennylane/transforms/core/transform_program.py +++ b/pennylane/transforms/core/transform_program.py @@ -15,7 +15,7 @@ This module contains the ``TransformProgram`` class. """ from functools import partial -from typing import Callable, List, Tuple, Optional, Sequence +from typing import Callable, List, Tuple, Optional, Sequence, Union import pennylane as qml from pennylane.typing import Result, ResultBatch @@ -117,6 +117,35 @@ class TransformProgram: .. seealso:: :func:`~.pennylane.transform` + **Implemented Dunder methods** + + Programs have several implemented dunder methods for easy manipulation. + + >>> program = TransformProgram() + >>> program.add_transform(qml.compile) + >>> program.add_transform(qml.transforms.cancel_inverses) + >>> [t for t in program] # Iteration + [, ] + >>> program[0] + + >>> program[::-1] + TransformProgram(cancel_inverses, compile) + >>> len(program) + 2 + >>> True if program else False + True + >>> True if TransformProgram() else False + False + >>> program2 = copy.copy(program) + >>> program2 == program + True + >>> qml.compile in program + True + >>> qml.transforms.hamiltonian_expand in program + False + >>> program + program + TransformProgram(compile, cancel_inverses, compile, cancel_inverses) + """ def __init__(self, initial_program: Optional[Sequence] = None): @@ -132,9 +161,11 @@ def __len__(self): """int: Return the number transforms in the program.""" return len(self._transform_program) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Union["TransformProgram", "TransformContainer"]: """(TransformContainer, List[TransformContainer]): Return the indexed transform container from underlying transform program""" + if isinstance(idx, slice): + return TransformProgram(self._transform_program[idx]) return self._transform_program[idx] def __bool__(self): @@ -155,12 +186,19 @@ def __repr__(self): contents = ", ".join(f"{transform_c.transform.__name__}" for transform_c in self) return f"TransformProgram({contents})" - def __eq__(self, other): + def __eq__(self, other) -> bool: if not isinstance(other, TransformProgram): return False return self._transform_program == other._transform_program + def __contains__(self, obj): + if isinstance(obj, TransformContainer): + return obj in self._transform_program + if isinstance(obj, TransformDispatcher): + return any(obj.transform == t.transform for t in self) + return False + def push_back(self, transform_container: TransformContainer): """Add a transform (container) to the end of the program. @@ -172,7 +210,10 @@ def push_back(self, transform_container: TransformContainer): # Program can only contain one informative transform and at the end of the program if self.has_final_transform: - raise TransformError("The transform program already has a terminal transform.") + if transform_container.final_transform: + raise TransformError("The transform program already has a terminal transform.") + self._transform_program.insert(-1, transform_container) + return self._transform_program.append(transform_container) def insert_front(self, transform_container: TransformContainer): @@ -290,7 +331,7 @@ def is_informative(self) -> bool: @property def has_final_transform(self) -> bool: """``True`` if the transform program has a terminal transform.""" - return self[-1].final_transform if self else False + return self[-1].final_transform if self else False # pylint: disable=no-member def has_classical_cotransform(self) -> bool: """Check if the transform program has some classical cotransforms. diff --git a/pennylane/workflow/__init__.py b/pennylane/workflow/__init__.py index 231206b6182..b859a2e5ba0 100644 --- a/pennylane/workflow/__init__.py +++ b/pennylane/workflow/__init__.py @@ -25,6 +25,8 @@ ~execute ~workflow.cache_execute ~workflow.set_shots + ~workflow.construct_batch + ~workflow.get_transform_program Supported interfaces ~~~~~~~~~~~~~~~~~~~~ @@ -55,3 +57,4 @@ from .set_shots import set_shots from .execution import execute, cache_execute, SUPPORTED_INTERFACES, INTERFACE_MAP from .qnode import QNode, qnode +from .construct_batch import construct_batch, get_transform_program diff --git a/pennylane/workflow/construct_batch.py b/pennylane/workflow/construct_batch.py new file mode 100644 index 00000000000..4a5f0d5066c --- /dev/null +++ b/pennylane/workflow/construct_batch.py @@ -0,0 +1,301 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a function extracting the tapes at postprocessing at any stage of a transform program. + +""" +from functools import wraps +import inspect +from typing import Union, Callable, Tuple + +import pennylane as qml +from .qnode import QNode, _make_execution_config, _get_device_shots + + +def null_postprocessing(results): + """A postprocessing function with null behavior.""" + return results[0] + + +def expand_fn_transform(expand_fn: Callable) -> "qml.transforms.core.TransformDispatcher": + """Construct a transform from a tape-to-tape function. + + Args: + expand_fn (Callable): a function from a single tape to a single tape + + Returns: + + .TransformDispatcher: Returns a transform dispatcher object that that can transform any + circuit-like object in PennyLane. + + >>> device = qml.device('default.qubit.legacy', wires=2) + >>> my_transform = qml.transforms.core.expand_fn_transform(device.expand_fn) + >>> my_transform + + """ + + @wraps(expand_fn) + def wrapped_expand_fn(tape, *args, **kwargs): + return (expand_fn(tape, *args, **kwargs),), null_postprocessing + + return qml.transform(wrapped_expand_fn) + + +def _get_full_transform_program(qnode: QNode) -> "qml.transforms.core.TransformProgram": + program = qml.transforms.core.TransformProgram(qnode.transform_program) + if getattr(qnode.gradient_fn, "expand_transform", False): + program.add_transform( + qml.transform(qnode.gradient_fn.expand_transform), + **qnode.gradient_kwargs, + ) + if isinstance(qnode.device, qml.devices.Device): + config = _make_execution_config(qnode) + return program + qnode.device.preprocess(config)[0] + program.add_transform(qml.transform(qnode.device.batch_transform)) + program.add_transform(expand_fn_transform(qnode.device.expand_fn)) + return program + + +def get_transform_program(qnode: "QNode", level=None) -> "qml.transforms.core.TransformProgram": + """Extract a transform program at a designated level. + + Args: + qnode (QNode): the qnode to get the transform program for. + level (None, str, int, slice): And indication of what transforms to use from the full program. + + * ``None``: use the full transform program + * ``str``: Acceptable keys are ``"user"``, ``"device"``, ``"top"`` and ``"gradient"`` + * ``int``: How many transforms to include, starting from the front of the program + * ``slice``: a slice to select out components of the transform program. + + Returns: + TransformProgram: the transform program corresponding to the requested level. + + .. details:: + :title: Usage Details + + The transforms are organized as: + + .. image:: ../../_static/transforms_order.png + :align: center + :width: 800px + :target: javascript:void(0); + + where ``transform1`` is first applied to the ``QNode`` followed by ``transform2``. First user transforms are run on the tapes, + followed by the gradient expansion, followed by the device expansion. "Final" transforms, like ``param_shift`` and ``metric_tensor``, + always occur at the end of the program. + + .. code-block:: python + + dev = qml.device('default.qubit') + + @qml.metric_tensor # final transform + @qml.transforms.merge_rotations # transform 2 + @qml.transforms.cancel_inverses # transform 1 + @qml.qnode(dev, diff_method="parameter-shift", shifts=np.pi / 4) + def circuit(): + return qml.expval(qml.PauliZ(0)) + + By default, we get the full transform program. This can be manually specified by ``level=None``. + + >>> qml.workflow.get_transform_program(circuit) + TransformProgram(cancel_inverses, merge_rotations, _expand_metric_tensor, + _expand_transform_param_shift, validate_device_wires, defer_measurements, + decompose, validate_measurements, validate_observables, metric_tensor) + + The ``"user"`` transforms are the ones manually applied to the qnode, :class:`~.cancel_inverses` and + :class:`~.merge_rotations`. + + >>> qml.workflow.get_transform_program(circuit, level="user") + TransformProgram(cancel_inverses, merge_rotations) + + The ``_expand_transform_param_shift`` is the ``"gradient"`` transform. This expands all trainable + operations to a state where the parameter shift transform can operate on them. For example, it will decompose + any parametrized templates into operators that have generators. + + >>> qml.workflow.get_transform_program(circuit, level="gradient") + TransformProgram(cancel_inverses, merge_rotations, _expand_transform_param_shift) + + ``"device"`` includes all transforms except for a ``"final"`` transform, if it exists. This usually + corresponds to the circuits that will be sent to the device to execute. + + >>> qml.workflow.get_transform_program(circuit, level="device") + TransformProgram(cancel_inverses, merge_rotations, _expand_transform_param_shift, + validate_device_wires, defer_measurements, decompose, validate_measurements, + validate_observables) + + ``"top"`` and ``0`` both return empty transform programs. + + >>> qml.workflow.get_transform_program(circuit, level="top") + TransformProgram() + >>> qml.workflow.get_transform_program(circuit, level=0) + TransformProgram() + + The ``level`` can also be any integer, corresponding to a number of transforms in the program. + + >>> qml.workflow.get_transform_program(circuit, level=2) + TransformProgram(cancel_inverses, merge_rotations) + + ``level`` can also accept a ``slice`` object to select out any arbitrary subset of the + transform program. This allows you to select different starting transforms or strides. + For example, you can skip the first transform or reverse the order: + + >>> qml.workflow.get_transform_program(circuit, level=slice(1,3)) + TransformProgram(merge_rotations, _expand_transform_param_shift) + >>> qml.workflow.get_transform_program(circuit, level=slice(None, None, -1)) + TransformProgram(metric_tensor, validate_observables, validate_measurements, + decompose, defer_measurements, validate_device_wires, _expand_transform_param_shift, + _expand_metric_tensor, merge_rotations, cancel_inverses) + + """ + full_transform_program = _get_full_transform_program(qnode) + + num_user = len(qnode.transform_program) + if qnode.transform_program.has_final_transform: + # final transform is placed after device transforms + num_user -= 1 + + if level == "device": + level = -1 if full_transform_program.has_final_transform else None + elif level == "top": + level = 0 + elif level == "user": + level = num_user + elif level == "gradient": + if getattr(qnode.gradient_fn, "expand_transform", False): + level = slice(0, num_user + 1) + else: + level = slice(0, num_user) + elif isinstance(level, str): + raise ValueError( + f"level {level} not recognized. Acceptable strings are 'device', 'top', 'user', and 'gradient'." + ) + if level is None or isinstance(level, int): + level = slice(0, level) + return full_transform_program[level] + + +def construct_batch(qnode: QNode, level: Union[None, str, int, slice] = "user") -> Callable: + """Construct the batch of tapes and post processing for a designated stage in the transform program. + + Args: + qnode (QNode): the qnode we want to get the tapes and post-processing for. + level (None, str, int, slice): And indication of what transforms to use from the full program. + + * ``None``: use the full transform program + * ``str``: Acceptable keys are ``"top"``, ``"user"``, ``"device"``, and ``"gradient"`` + * ``int``: How many transforms to include, starting from the front of the program + * ``slice``: a slice to select out components of the transform program. + + Returns: + Callable: a function with the same call signature as the initial quantum function. This function returns + a batch (tuple) of tapes and postprocessing function. + + .. seealso:: :func:`pennylane.workflow.get_transform_program` to inspect the contents of the transform program for a specified level. + + + .. details:: + :title: Usage Details + + Suppose we have a QNode with several user transforms. + + .. code-block:: python + + @qml.transforms.undo_swaps + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @qml.qnode(qml.device('default.qubit'), diff_method="parameter-shift", shifts=np.pi / 4) + def circuit(x): + qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0,1)) + qml.RX(x, wires=0) + qml.RX(-x, wires=0) + qml.SWAP((0,1)) + qml.PauliX(0) + qml.PauliX(0) + return qml.expval(qml.PauliX(0) + qml.PauliY(0)) + + We can inspect what the device will execute with: + + >>> batch, fn = construct_batch(circuit, level="device")(1.23) + >>> batch[0].circuit + [RY(tensor(1., requires_grad=True), wires=[1]), + RX(tensor(2., requires_grad=True), wires=[0]), + expval( (1) [X0] + + (1) [Y0])] + + These tapes can be natively executed by the device, though with non-backprop devices the parameters + will need to be converted to numpy with :func:`~.convert_to_numpy_parameters`. + + >>> fn(dev.execute(batch)) + (tensor(-0.90929743, requires_grad=True),) + + Or what the parameter shift gradient transform will be applied to: + + >>> batch, fn = construct_batch(circuit, level="gradient")(1.23) + >>> batch[0].circuit + [RY(tensor(1., requires_grad=True), wires=[1]), + RX(tensor(2., requires_grad=True), wires=[0]), + expval( (1) [X0] + + (1) [Y0])] + + We can inspect what was directly captured from the qfunc with ``level=0``. + + >>> batch, fn = construct_batch(circuit, level=0)(1.23) + >>> batch[0].circuit + [RandomLayers(tensor([[1., 2.]], requires_grad=True), wires=[0, 1]), + RX(1.23, wires=[0]), + RX(-1.23, wires=[0]), + SWAP(wires=[0, 1]), + PauliX(wires=[0]), + PauliX(wires=[0]), + expval( (1) [X0] + + (1) [Y0])] + + And iterate though stages in the transform program with different integers. + If we request ``level=1``, the ``cancel_inverses`` transform has been applied. + + >>> batch, fn = construct_batch(circuit, level=1)(1.23) + >>> batch[0].circuit + [RandomLayers(tensor([[1., 2.]], requires_grad=True), wires=[0, 1]), + RX(1.23, wires=[0]), + RX(-1.23, wires=[0]), + SWAP(wires=[0, 1]), + expval( (1) [X0] + + (1) [Y0])] + + We can also slice into a subset of the transform program. ``slice(1, None)`` would skip the first user + transform ``cancel_inverses``: + + >>> batch, fn = construct_batch(circuit, level=slice(1,None))(1.23) + >>> batch[0].circuit + [RY(tensor(1., requires_grad=True), wires=[1]), + RX(tensor(2., requires_grad=True), wires=[0]), + PauliX(wires=[0]), + PauliX(wires=[0]), + expval( (1) [X0] + + (1) [Y0])] + + """ + program = get_transform_program(qnode, level=level) + + def batch_constructor(*args, **kwargs) -> Tuple[Tuple["qml.tape.QuantumTape", Callable]]: + """Create a batch of tapes and a post processing function.""" + if "shots" in inspect.signature(qnode.func).parameters: + shots = _get_device_shots(qnode.device) + else: + shots = kwargs.pop("shots", _get_device_shots(qnode.device)) + + initial_tape = qml.tape.make_qscript(qnode.func, shots=shots)(*args, **kwargs) + return program((initial_tape,)) + + return batch_constructor diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index b1348624e09..3055e920932 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -62,6 +62,26 @@ def _get_device_shots(device) -> Shots: return device.shots +def _make_execution_config(circuit: "QNode") -> "qml.devices.ExecutionConfig": + if circuit.gradient_fn is None: + _gradient_method = None + elif isinstance(circuit.gradient_fn, str): + _gradient_method = circuit.gradient_fn + else: + _gradient_method = "gradient-transform" + grad_on_execution = circuit.execute_kwargs.get("grad_on_execution") + if circuit.interface == "jax": + grad_on_execution = False + elif grad_on_execution == "best": + grad_on_execution = None + return qml.devices.ExecutionConfig( + interface=circuit.interface, + gradient_method=_gradient_method, + grad_on_execution=grad_on_execution, + use_device_jacobian_product=circuit.execute_kwargs["device_vjp"], + ) + + class QNode: """Represents a quantum node in the hybrid computational graph. @@ -990,24 +1010,7 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: config = None # Add the device program to the QNode program if isinstance(self.device, qml.devices.Device): - if self.gradient_fn is None: - _gradient_method = None - elif isinstance(self.gradient_fn, str): - _gradient_method = self.gradient_fn - else: - _gradient_method = "gradient-transform" - grad_on_execution = self.execute_kwargs.get("grad_on_execution") - if self.interface == "jax": - grad_on_execution = False - elif grad_on_execution == "best": - grad_on_execution = None - - config = qml.devices.ExecutionConfig( - interface=self.interface, - gradient_method=_gradient_method, - grad_on_execution=grad_on_execution, - use_device_jacobian_product=self.execute_kwargs["device_vjp"], - ) + config = _make_execution_config(self) device_transform_program, config = self.device.preprocess(execution_config=config) full_transform_program = self.transform_program + device_transform_program else: diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index d98a20954ea..35a76028f42 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -1725,6 +1725,8 @@ def test_postselection_valid_finite_shots( if use_jit and (interface != "jax" or isinstance(shots, tuple)): pytest.skip("Cannot JIT in non-JAX interfaces, or with shot vectors.") + np.random.seed(42) + dev = qml.device("default.qubit") param = qml.math.asarray(param, like=interface) diff --git a/tests/templates/test_subroutines/test_permute.py b/tests/templates/test_subroutines/test_permute.py index 6eefab22851..4bfa0dbc47b 100644 --- a/tests/templates/test_subroutines/test_permute.py +++ b/tests/templates/test_subroutines/test_permute.py @@ -26,6 +26,11 @@ def test_standard_validity(): qml.ops.functions.assert_valid(op) +def test_repr(): + op = qml.Permute([2, 1, 0], wires=(0, 1, 2)) + assert repr(op) == "Permute((2, 1, 0), wires=[0, 1, 2])" + + class TestDecomposition: """Tests that the template defines the correct decomposition.""" diff --git a/tests/test_hermitian_edge_cases.py b/tests/test_hermitian_edge_cases.py index b3f2365ebf2..d628cf268ff 100644 --- a/tests/test_hermitian_edge_cases.py +++ b/tests/test_hermitian_edge_cases.py @@ -91,7 +91,7 @@ def circuit(): @pytest.mark.parametrize("w1, w2", list(itertools.permutations(range(4), 2))) def test_hermitian_two_wires_permuted(self, w1, w2, shots, theta): """Test that an hermitian expectation with various wires permuted works""" - dev = qml.device("default.qubit", wires=4, shots=shots) + dev = qml.device("default.qubit", wires=4, shots=shots, seed=123545) theta = 0.543 A = np.array( diff --git a/tests/transforms/test_experimental/test_transform_dispatcher.py b/tests/transforms/test_experimental/test_transform_dispatcher.py index b7a5160bb52..b1ef83bb65c 100644 --- a/tests/transforms/test_experimental/test_transform_dispatcher.py +++ b/tests/transforms/test_experimental/test_transform_dispatcher.py @@ -150,6 +150,73 @@ def fn(results): return [tape], fn +class TestTransformContainer: + """Tests for the TransformContainer dataclass.""" + + def test_repr(self): + """Tests for the repr of a transform container.""" + t1 = qml.transforms.core.TransformContainer( + qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} + ) + assert repr(t1) == "" + + def test_equality(self): + """Tests that we can compare TransformContainer objects with the '==' and '!=' operators.""" + + t1 = TransformContainer( + qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} + ) + t2 = TransformContainer( + qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} + ) + t3 = TransformContainer( + qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]} + ) + t4 = TransformContainer( + qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 2} + ) + + t5 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,)) + t6 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-7,)) + + # test for equality of identical transformers + assert t1 == t2 + + # test for inequality of different transformers + assert t1 != t3 + assert t2 != t3 + assert t1 != 2 + assert t1 != t4 + assert t5 != t6 + assert t5 != t1 + + # Test equality with the same args + t5_copy = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,)) + assert t5 == t5_copy + + def test_the_transform_container_attributes(self): + """Test the transform container attributes.""" + container = qml.transforms.core.TransformContainer( + first_valid_transform, args=[0], kwargs={}, classical_cotransform=None + ) + + q_transform, args, kwargs, cotransform, is_informative, final_transform = container + + assert q_transform is first_valid_transform + assert args == [0] + assert kwargs == {} + assert cotransform is None + assert not is_informative + assert not final_transform + + assert container.transform is first_valid_transform + assert container.args == [0] + assert not container.kwargs + assert container.classical_cotransform is None + assert not container.is_informative + assert not container.final_transform + + class TestTransformDispatcher: # pylint: disable=too-many-public-methods """Test the transform function (validate and dispatch).""" @@ -194,7 +261,7 @@ def qnode_circuit(a): assert isinstance( qnode_transformed.transform_program.pop_front(), qml.transforms.core.TransformContainer ) - assert not dispatched_transform.is_informative + assert dispatched_transform.is_informative is False def test_integration_dispatcher_with_informative_transform(self): """Test that no error is raised with the transform function and that the transform dispatcher returns @@ -276,40 +343,6 @@ def qnode_circuit(a): # pylint:disable=unused-variable qml.RZ(a, wires=1) return qml.expval(qml.PauliZ(wires=0)) - def test_equality(self): - """Tests that we can compare TransformContainer objects with the '==' and '!=' operators.""" - - t1 = TransformContainer( - qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} - ) - t2 = TransformContainer( - qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} - ) - t3 = TransformContainer( - qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]} - ) - t4 = TransformContainer( - qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 2} - ) - - t5 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,)) - t6 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-7,)) - - # test for equality of identical transformers - assert t1 == t2 - - # test for inequality of different transformers - assert t1 != t3 - assert t2 != t3 - assert t1 != 2 - assert t1 != t4 - assert t5 != t6 - assert t5 != t1 - - # Test equality with the same args - t5_copy = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,)) - assert t5 == t5_copy - def test_queuing_qfunc_transform(self): """Test that queuing works with the transformed quantum function.""" @@ -437,28 +470,6 @@ def test_dispatched_transform_attribute(self): assert dispatched_transform.expand_transform is None assert dispatched_transform.classical_cotransform is None - def test_the_transform_container_attributes(self): - """Test the transform container attributes.""" - container = qml.transforms.core.TransformContainer( - first_valid_transform, args=[0], kwargs={}, classical_cotransform=None - ) - - q_transform, args, kwargs, cotransform, is_informative, final_transform = container - - assert q_transform is first_valid_transform - assert args == [0] - assert kwargs == {} - assert cotransform is None - assert not is_informative - assert not final_transform - - assert container.transform is first_valid_transform - assert container.args == [0] - assert not container.kwargs - assert container.classical_cotransform is None - assert not container.is_informative - assert not container.final_transform - @pytest.mark.parametrize("valid_transform", valid_transforms) def test_custom_qnode_transform(self, valid_transform): """Test that the custom qnode transform is correctly executed""" diff --git a/tests/transforms/test_experimental/test_transform_program.py b/tests/transforms/test_experimental/test_transform_program.py index 4f89fe305d9..c928918bc3c 100644 --- a/tests/transforms/test_experimental/test_transform_program.py +++ b/tests/transforms/test_experimental/test_transform_program.py @@ -133,6 +133,43 @@ def test_iter_program(self): assert isinstance(elem, TransformContainer) assert elem.transform is first_valid_transform + def test_getitem(self): + """Tests for the getitem dunder.""" + + t0 = TransformContainer(transform=first_valid_transform) + t1 = TransformContainer(transform=second_valid_transform) + t2 = TransformContainer(transform=informative_transform) + program = TransformProgram([t0, t1, t2]) + + assert program[0] == t0 + assert program[1] == t1 + assert program[2] == t2 + + assert program[:2] == TransformProgram([t0, t1]) + assert program[::-1] == TransformProgram([t2, t1, t0]) + + def test_contains(self): + """Test that we can check whether a transform or transform container exists in a transform.""" + + t0 = TransformContainer(transform=first_valid_transform) + t1 = TransformContainer(transform=second_valid_transform) + t2 = TransformContainer(transform=informative_transform) + program = TransformProgram([t0, t1, t2]) + + assert t0 in program + assert t1 in program + assert t2 in program + assert qml.compile not in program + + assert t0 in program + assert t1 in program + assert t2 in program + + t_not = TransformContainer(transform=qml.compile) + assert t_not not in program + + assert "a" not in program + def test_add_single_programs(self): """Test adding two transform programs""" transform_program1 = TransformProgram() @@ -272,6 +309,35 @@ def test_repr_program(self): + ")" ) + def test_equality(self): + """Tests that we can compare TransformProgram objects with the '==' and '!=' operators.""" + t1 = TransformContainer( + qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} + ) + t2 = TransformContainer( + qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} + ) + t3 = TransformContainer( + qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]} + ) + + p1 = TransformProgram([t1, t3]) + p2 = TransformProgram([t2, t3]) + p3 = TransformProgram([t3, t2]) + + # test for equality of identical objects + assert p1 == p2 + # test for inequality of different objects + assert p1 != p3 + assert p1 != t1 + + # Test inequality with different transforms + t4 = TransformContainer( + qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (2, 3)]} + ) + p4 = TransformProgram([t1, t4]) + assert p1 != p4 + class TestTransformProgram: """Test the transform program class and its method.""" @@ -465,52 +531,36 @@ def test_insert_transform_with_expand(self): assert transform_program[1].transform is first_valid_transform def test_valid_transforms(self): - """Test that it is only possible to create valid transforms.""" + """Test adding transforms to a program with a terminal transform.""" transform_program = TransformProgram() transform1 = TransformContainer(transform=first_valid_transform, is_informative=True) transform_program.push_back(transform1) + t_normal = TransformContainer(transform=second_valid_transform) + transform_program.push_back(t_normal) + print(transform_program) + assert len(transform_program) == 2 + assert transform_program[0] is t_normal + assert transform_program[1] is transform1 + + t_normal2 = TransformContainer(transform=first_valid_transform) + transform_program.push_back(t_normal2) + assert transform_program[0] is t_normal + assert transform_program[1] is t_normal2 + assert transform_program[2] is transform1 + with pytest.raises( TransformError, match="The transform program already has a terminal transform." ): transform_program.push_back(transform1) - transform2 = TransformContainer(transform=second_valid_transform, is_informative=False) + transform2 = TransformContainer(transform=second_valid_transform, final_transform=True) with pytest.raises( TransformError, match="The transform program already has a terminal transform." ): transform_program.push_back(transform2) - def test_equality(self): - """Tests that we can compare TransformProgram objects with the '==' and '!=' operators.""" - t1 = TransformContainer( - qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} - ) - t2 = TransformContainer( - qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1} - ) - t3 = TransformContainer( - qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]} - ) - - p1 = TransformProgram([t1, t3]) - p2 = TransformProgram([t2, t3]) - p3 = TransformProgram([t3, t2]) - - # test for equality of identical objects - assert p1 == p2 - # test for inequality of different objects - assert p1 != p3 - assert p1 != t1 - - # Test inequality with different transforms - t4 = TransformContainer( - qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (2, 3)]} - ) - p4 = TransformProgram([t1, t4]) - assert p1 != p4 - class TestTransformProgramCall: """Tests for calling a TransformProgram on a batch of quantum tapes.""" diff --git a/tests/workflow/test_construct_batch.py b/tests/workflow/test_construct_batch.py new file mode 100644 index 00000000000..bc6113d5b41 --- /dev/null +++ b/tests/workflow/test_construct_batch.py @@ -0,0 +1,436 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains tests for the `qml.workflow.get_transform_program` getter and `construct_batch`. + +""" +from functools import partial +import pytest + +import numpy as np + +import pennylane as qml +from pennylane.transforms.core.transform_dispatcher import TransformContainer +from pennylane.transforms.core.transform_program import TransformProgram +from pennylane.workflow import get_transform_program, construct_batch +from pennylane.workflow.construct_batch import expand_fn_transform + + +def test_expand_fn_transform(): + """Tests the expand_fn_transform.""" + + def my_expand_fn(tape, op1, op2=qml.S(0), op3=qml.S(0)): + """my docstring.""" + return qml.tape.QuantumScript( + tape.operations + [op1, op2, op3], tape.measurements, tape.shots + ) + + t = expand_fn_transform(my_expand_fn) + + assert isinstance(t, qml.transforms.core.TransformDispatcher) + tape = qml.tape.QuantumScript([qml.S(0)], [qml.expval(qml.PauliZ(0))], shots=50) + + batch, fn = t(tape, qml.PauliX(0), op3=qml.T(0)) + assert len(batch) == 1 + expected = qml.tape.QuantumScript( + [qml.S(0), qml.PauliX(0), qml.S(0), qml.T(0)], [qml.expval(qml.PauliZ(0))], shots=50 + ) + assert qml.equal(batch[0], expected) + assert fn(("a",)) == "a" + + assert repr(t) == "" + assert t.__doc__ == "my docstring." + + +class TestTransformProgramGetter: + def test_bad_string_key(self): + """Test a value error is raised if a bad string key is provided.""" + + @qml.qnode(qml.device("default.qubit")) + def circuit(): + return qml.state() + + with pytest.raises(ValueError, match=r"level bah not recognized."): + get_transform_program(circuit, level="bah") + + def test_get_transform_program_gradient_fn_transform(self): + """Tests for the transform program when the gradient_fn is a transform.""" + + dev = qml.device("default.qubit", wires=4) + + @partial(qml.transforms.compile, num_passes=2) + @partial(qml.transforms.merge_rotations, atol=1e-5) + @qml.transforms.cancel_inverses + @qml.qnode(dev, diff_method="parameter-shift", shifts=2) + def circuit(): + return qml.expval(qml.PauliZ(0)) + + expected_p0 = TransformContainer(qml.transforms.cancel_inverses.transform) + expected_p1 = TransformContainer( + qml.transforms.merge_rotations.transform, kwargs={"atol": 1e-5} + ) + expected_p2 = TransformContainer(qml.transforms.compile.transform, kwargs={"num_passes": 2}) + + ps_expand_fn = TransformContainer( + qml.gradients.param_shift.expand_transform, kwargs={"shifts": 2} + ) + + p0 = get_transform_program(circuit, level=0) + assert isinstance(p0, TransformProgram) + assert len(p0) == 0 + + p0 = get_transform_program(circuit, level="top") + assert isinstance(p0, TransformProgram) + assert len(p0) == 0 + + p_grad = get_transform_program(circuit, level="gradient") + assert isinstance(p_grad, TransformProgram) + assert len(p_grad) == 4 + assert p_grad == TransformProgram([expected_p0, expected_p1, expected_p2, ps_expand_fn]) + + p_dev = get_transform_program(circuit, level="device") + assert isinstance(p_grad, TransformProgram) + p_default = get_transform_program(circuit) + p_none = get_transform_program(circuit, None) + assert p_dev == p_default + assert p_none == p_dev + assert len(p_dev) == 9 + assert p_dev == p_grad + dev.preprocess()[0] + + # slicing + p_sliced = get_transform_program(circuit, slice(2, 7, 2)) + assert len(p_sliced) == 3 + assert p_sliced[0].transform == qml.compile.transform + assert p_sliced[1].transform == qml.devices.preprocess.validate_device_wires.transform + assert p_sliced[2].transform == qml.devices.preprocess.decompose.transform + + def test_gradient_fn_device_gradient(self): + """Test that if level="gradient" but the gradient does not have preprocessing, the program is strictly user transforms.""" + + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("default.qubit"), diff_method="backprop") + def circuit(): + return qml.state() + + prog = get_transform_program(circuit, level="gradient") + assert len(prog) == 1 + assert qml.transforms.cancel_inverses in prog + + def test_get_transform_program_device_gradient(self): + """Test the trnsform program contents when using a device derivative.""" + + dev = qml.device("default.qubit") + + @qml.transforms.sum_expand + @qml.qnode(dev, diff_method="adjoint", device_vjp=False) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.PauliZ(0)) + + full_prog = get_transform_program(circuit) + assert len(full_prog) == 13 + + config = qml.devices.ExecutionConfig( + gradient_method="adjoint", use_device_jacobian_product=False + ) + dev_program = dev.preprocess(config)[0] + + expected = TransformProgram() + expected.add_transform(qml.transforms.sum_expand) + expected += dev_program + assert full_prog == expected + + def test_get_transform_program_legacy_device_interface(self): + """Test the contents of the transform program with the legacy device interface.""" + + dev = qml.device("default.qubit.legacy", wires=5) + + @qml.transforms.merge_rotations + @qml.qnode(dev, diff_method="backprop") + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + program = get_transform_program(circuit) + + m1 = TransformContainer(qml.transforms.merge_rotations.transform) + m2 = TransformContainer(dev.batch_transform) + assert program[0:2] == TransformProgram([m1, m2]) + + # a little hard to check the contents of a expand_fn_transform + # this is the best proxy I can find + assert program[2].transform.__wrapped__ == dev.expand_fn + + def test_get_transform_program_final_transform(self): + """Test that gradient preprocessing and device transform occur before a final transform.""" + + @qml.metric_tensor + @qml.compile + @qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift") + def circuit(): + qml.IsingXX(1.234, wires=(0, 1)) + return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0)) + + user_program = get_transform_program(circuit, level="user") + assert len(user_program) == 2 + assert user_program[0].transform == qml.compile.transform + assert user_program[1].transform == qml.metric_tensor.expand_transform + + grad_program = get_transform_program(circuit, level="gradient") + assert len(grad_program) == 3 + assert grad_program[0].transform == qml.compile.transform + assert grad_program[1].transform == qml.metric_tensor.expand_transform + assert grad_program[2].transform == qml.gradients.param_shift.expand_transform + + dev_program = get_transform_program(circuit, level="device") + assert len(dev_program) == 3 + len(circuit.device.preprocess()[0]) # currently 8 + assert qml.metric_tensor not in dev_program + + full = get_transform_program(circuit) + assert full[-1].transform == qml.metric_tensor.transform + + +@qml.transforms.merge_rotations +@qml.transforms.cancel_inverses +@qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift") +def circuit1(weights, order): + qml.RandomLayers(weights, wires=(0, 1)) + qml.Permute(order, wires=(0, 1, 2)) + qml.PauliX(0) + qml.PauliX(0) + qml.RX(0.1, wires=0) + qml.RX(-0.1, wires=0) + return qml.expval(qml.PauliX(0)) + + +class TestConstructBatch: + """Tests for the construct_batch function.""" + + def test_level_zero(self): + """Test that level zero is purely the queued circuit.""" + + order = [2, 1, 0] + weights = np.array([[1.0, 20]]) + batch, fn = construct_batch(circuit1, level=0)(weights, order, shots=10) + + assert len(batch) == 1 + expected_ops = [ + qml.RandomLayers(weights, wires=(0, 1)), + qml.Permute(order, wires=(0, 1, 2)), + qml.PauliX(0), + qml.PauliX(0), + qml.RX(0.1, wires=0), + qml.RX(-0.1, wires=0), + ] + + expected = qml.tape.QuantumScript( + expected_ops, + [qml.expval(qml.PauliX(0))], + shots=10, + ) + assert qml.equal(batch[0], expected) + + assert fn(("a",)) == ("a",) + + def test_first_transform(self): + """Test that the first user transform can be selected by level=1""" + + weights = np.array([[1.0, 2.0]]) + order = [2, 1, 0] + + batch, fn = construct_batch(circuit1, level=1)(weights, order=order, shots=50) + assert len(batch) == 1 + + expected_ops = [ + qml.RandomLayers(weights, wires=(0, 1)), + qml.Permute(order, wires=(0, 1, 2)), + # cancel inverses + qml.RX(0.1, wires=0), + qml.RX(-0.1, wires=0), + ] + + expected = qml.tape.QuantumScript(expected_ops, [qml.expval(qml.PauliX(0))], shots=50) + assert qml.equal(batch[0], expected) + assert fn(("a",)) == ("a",) + + @pytest.mark.parametrize("level", (2, "user")) + def test_all_user_transforms(self, level): + """Test that all user transforms can be selected and run.""" + + weights = np.array([[1.0, 2.0]]) + order = [2, 1, 0] + + batch, fn = construct_batch(circuit1, level=level)(weights, order=order, shots=50) + assert len(batch) == 1 + + expected_ops = [ + qml.RandomLayers(weights, wires=(0, 1)), + qml.Permute(order, wires=(0, 1, 2)), + # cancel inverses + # merge rotations + ] + + expected = qml.tape.QuantumScript(expected_ops, [qml.expval(qml.PauliX(0))], shots=50) + assert qml.equal(batch[0], expected) + assert fn(("a",)) == ("a",) + + @pytest.mark.parametrize("level", (3, "gradient")) + def test_gradient_transforms(self, level): + """Test that the gradient transform can be selected with an integer or keyword.""" + weights = qml.numpy.array([[1.0, 2.0]], requires_grad=True) + order = [2, 1, 0] + batch, fn = construct_batch(circuit1, level=level)(weights=weights, order=order) + + expected = qml.tape.QuantumScript( + [ + qml.RY(qml.numpy.array(1), 0), + qml.RX(qml.numpy.array(2), 1), + qml.Permute(order, (0, 1, 2)), + ], + [qml.expval(qml.PauliX(0))], + ) + assert qml.equal(batch[0], expected) + assert len(batch) == 1 + assert fn(("a",)) == ("a",) + + @pytest.mark.parametrize("level", ("device", None)) + def test_device_transforms(self, level): + """Test that all device transforms can be run with the device keyword.""" + + weights = np.array([[1.0, 2.0]]) + order = [2, 1, 0] + + batch, fn = construct_batch(circuit1, level=level)(weights, order) + + expected = qml.tape.QuantumScript( + [qml.RY(1, 0), qml.RX(2, 1), qml.SWAP((0, 2))], [qml.expval(qml.PauliX(0))] + ) + assert qml.equal(batch[0], expected) + assert len(batch) == 1 + assert fn(("a",)) == ("a",) + + @pytest.mark.parametrize("level", ("device", None)) + def test_device_transforms_legacy_interface(self, level): + """Test that the device transforms can be selected with level=device or None.""" + + @qml.transforms.merge_rotations + @qml.qnode(qml.device("default.qubit.legacy", wires=2, shots=50)) + def circuit(order): + qml.Permute(order, wires=(0, 1, 2)) + qml.RX(0.5, wires=0) + qml.RX(-0.5, wires=0) + return qml.expval(qml.PauliX(0) + qml.PauliY(0)) + + batch, fn = construct_batch(circuit, level=level)((2, 1, 0)) + + expected0 = qml.tape.QuantumScript( + [qml.SWAP((0, 2))], [qml.expval(qml.PauliX(0))], shots=50 + ) + assert qml.equal(expected0, batch[0]) + expected1 = qml.tape.QuantumScript( + [qml.SWAP((0, 2))], [qml.expval(qml.PauliY(0))], shots=50 + ) + assert qml.equal(expected1, batch[1]) + assert len(batch) == 2 + + assert fn((1.0, 2.0)) == (3.0,) + + def test_final_transform(self): + """Test that the final transform is included when level=None.""" + + @qml.gradients.param_shift + @qml.transforms.merge_rotations + @qml.qnode(qml.device("default.qubit")) + def circuit(x): + qml.RX(x, 0) + qml.RX(x, 0) + return qml.expval(qml.PauliZ(0)) + + batch, fn = construct_batch(circuit, level=None)(0.5) + assert len(batch) == 2 + expected0 = qml.tape.QuantumScript( + [qml.RX(1.0 + np.pi / 2, 0)], [qml.expval(qml.PauliZ(0))] + ) + assert qml.equal(batch[0], expected0) + expected1 = qml.tape.QuantumScript( + [qml.RX(1.0 - np.pi / 2, 0)], [qml.expval(qml.PauliZ(0))] + ) + assert qml.equal(batch[1], expected1) + + dummy_res = (1.0, 2.0) + expected_res = (1.0 - 2.0) / 2 + assert qml.numpy.allclose(fn(dummy_res)[0], expected_res) + + def test_user_transform_multiple_tapes(self): + """Test a user transform that creates multiple tapes.""" + + @qml.transforms.split_non_commuting + @qml.qnode(qml.device("default.qubit", shots=10)) + def circuit(): + qml.S(0) + return qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(1)) + + batch, fn = construct_batch(circuit, level="user")() + + assert len(batch) == 2 + expected0 = qml.tape.QuantumScript( + [qml.S(0)], [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(1))], shots=10 + ) + assert qml.equal(expected0, batch[0]) + + expected1 = qml.tape.QuantumScript([qml.S(0)], [qml.expval(qml.PauliZ(0))], shots=10) + assert qml.equal(expected1, batch[1]) + + dummy_res = (("x0", "x1"), "z0") + expected_res = (("x0", "z0", "x1"),) + assert fn(dummy_res) == expected_res + + def test_slicing_level(self): + """Test that the level can be a slice.""" + + @qml.transforms.merge_rotations + @qml.qnode(qml.device("default.qubit")) + def circuit(x): + qml.RX(x, 0) + qml.RX(x, 0) + return qml.expval(qml.PauliZ(0)) + + # by slicing starting at one, we do not run the merge rotations transform + batch, fn = construct_batch(circuit, slice(1, None))(0.5) + + assert len(batch) == 1 + expected = qml.tape.QuantumScript( + [qml.RX(0.5, 0), qml.RX(0.5, 0)], [qml.expval(qml.PauliZ(0))] + ) + assert qml.equal(batch[0], expected) + assert fn(("a",)) == ("a",) + + def test_qfunc_with_shots_arg(self): + """Test that the tape uses device shots only when qfunc has a shots kwarg""" + + dev = qml.device("default.qubit", shots=100) + + @qml.qnode(dev) + def circuit(shots): + for _ in range(shots): + qml.S(0) + return qml.expval(qml.PauliZ(0)) + + batch, fn = construct_batch(circuit, level=None)(shots=2) + assert len(batch) == 1 + expected = qml.tape.QuantumScript( + [qml.S(0), qml.S(0)], [qml.expval(qml.PauliZ(0))], shots=100 + ) + assert qml.equal(batch[0], expected) + assert fn(("a",)) == ("a",)