diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 81b0a21d144..a67925c4ad6 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -27,6 +27,9 @@ * The `qml.Adder` and `qml.PhaseAdder` templates are added to perform in-place modular addition. [(#6109)](https://github.com/PennyLaneAI/pennylane/pull/6109) +* The `qml.Multiplier` and `qml.OutMultiplier` templates are added to perform modular multiplication. + [(#6112)](https://github.com/PennyLaneAI/pennylane/pull/6112) +

Creating spin Hamiltonians 🧑‍🎨

* The function ``transverse_ising`` is added to generate transverse-field Ising Hamiltonian. diff --git a/pennylane/templates/subroutines/__init__.py b/pennylane/templates/subroutines/__init__.py index 2c60d72bbf4..5b3ec3fc887 100644 --- a/pennylane/templates/subroutines/__init__.py +++ b/pennylane/templates/subroutines/__init__.py @@ -47,3 +47,5 @@ from .qrom import QROM from .phase_adder import PhaseAdder from .adder import Adder +from .multiplier import Multiplier +from .out_multiplier import OutMultiplier diff --git a/pennylane/templates/subroutines/adder.py b/pennylane/templates/subroutines/adder.py index 0d4a776dc7c..32a621c4985 100644 --- a/pennylane/templates/subroutines/adder.py +++ b/pennylane/templates/subroutines/adder.py @@ -34,9 +34,8 @@ class Adder(Operation): .. note:: - Note that :math:`x` must be smaller than :math:`mod` to get the correct result. Also, when - :math:`mod \neq 2^{\text{len(x\_wires)}}` we need :math:`x < 2^{\text{len(x\_wires)}}/2`, - which means that we need one extra wire in ``x_wires``. + Note that :math:`x` must be smaller than :math:`mod` to get the correct result. + Args: k (int): the number that needs to be added diff --git a/pennylane/templates/subroutines/multiplier.py b/pennylane/templates/subroutines/multiplier.py new file mode 100644 index 00000000000..dbc3ff96325 --- /dev/null +++ b/pennylane/templates/subroutines/multiplier.py @@ -0,0 +1,197 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains the Multiplier template. +""" + +import numpy as np + +import pennylane as qml +from pennylane.operation import Operation + + +def _mul_out_k_mod(k, x_wires, mod, work_wire_aux, wires_aux): + """Performs :math:`x \times k` in the registers wires wires_aux""" + op_list = [] + + op_list.append(qml.QFT(wires=wires_aux)) + op_list.append( + qml.ControlledSequence(qml.PhaseAdder(k, wires_aux, mod, work_wire_aux), control=x_wires) + ) + op_list.append(qml.adjoint(qml.QFT(wires=wires_aux))) + return op_list + + +class Multiplier(Operation): + r"""Performs the in-place modular multiplication operation. + + This operator performs the modular multiplication by an integer :math:`k` modulo :math:`mod` in + the computational basis: + + .. math:: + + \text{Multiplier}(k,mod) |x \rangle = | x \cdot k \; \text{modulo} \; \text{mod} \rangle. + + The implementation is based on the quantum Fourier transform method presented in + `arXiv:2311.08555 `_. + + .. note:: + + Note that :math:`x` must be smaller than :math:`mod` to get the correct result. Also, it + is required that :math:`k` has inverse, :math:`k^-1`, modulo :math:`mod`. That means + :math:`k*k^-1 modulo mod is equal to 1`, which will only be possible if :math:`k` and + :math:`mod` are coprime. Furthermore, if :math:`mod \neq 2^{len(x\_wires)}`, two more + auxiliaries must be added. + + Args: + k (int): the number that needs to be multiplied + x_wires (Sequence[int]): the wires the operation acts on + mod (int): the modulus for performing the multiplication, default value is :math:`2^{len(x\_wires)}` + work_wires (Sequence[int]): the auxiliary wires to be used for performing the multiplication + + **Example** + + This example performs the multiplication of two integers :math:`x=3` and :math:`k=4` modulo :math:`mod=7`. + + .. code-block:: + + x = 3 + k = 4 + mod = 7 + + x_wires =[0,1,2] + work_wires=[3,4,5,6,7] + + dev = qml.device("default.qubit", shots=1) + @qml.qnode(dev) + def circuit(x, k, mod, wires_m, work_wires): + qml.BasisEmbedding(x, wires=wires_m) + qml.Multiplier(k, x_wires, mod, work_wires) + return qml.sample(wires=wires_m) + + .. code-block:: pycon + + >>> print(circuit(x, k, mod, x_wires, work_wires)) + [1 0 1] + + The result :math:`[1 0 1]`, is the ket representation of + :math:`3 \cdot 4 \, \text{modulo} \, 12 = 5`. + """ + + grad_method = None + + def __init__( + self, k, x_wires, mod=None, work_wires=None, id=None + ): # pylint: disable=too-many-arguments + if any(wire in work_wires for wire in x_wires): + raise ValueError("None of the wire in work_wires should be included in x_wires.") + + if mod is None: + mod = 2 ** len(x_wires) + if mod != 2 ** len(x_wires) and len(work_wires) < (len(x_wires) + 2): + raise ValueError("Multiplier needs as many work_wires as x_wires plus two.") + if len(work_wires) < len(x_wires): + raise ValueError("Multiplier needs as many work_wires as x_wires.") + if (not hasattr(x_wires, "__len__")) or (mod > 2 ** len(x_wires)): + raise ValueError("Multiplier must have enough wires to represent mod.") + + k = k % mod + if np.gcd(k, mod) != 1: + raise ValueError("The operator cannot be built because k has no inverse modulo mod.") + + self.hyperparameters["k"] = k + self.hyperparameters["mod"] = mod + self.hyperparameters["work_wires"] = qml.wires.Wires(work_wires) + self.hyperparameters["x_wires"] = qml.wires.Wires(x_wires) + all_wires = qml.wires.Wires(x_wires) + qml.wires.Wires(work_wires) + super().__init__(wires=all_wires, id=id) + + @property + def num_params(self): + return 0 + + def _flatten(self): + metadata = tuple((key, value) for key, value in self.hyperparameters.items()) + return tuple(), metadata + + @classmethod + def _unflatten(cls, data, metadata): + hyperparams_dict = dict(metadata) + return cls(**hyperparams_dict) + + def map_wires(self, wire_map: dict): + new_dict = { + key: [wire_map.get(w, w) for w in self.hyperparameters[key]] + for key in ["x_wires", "work_wires"] + } + + return Multiplier( + self.hyperparameters["k"], + new_dict["x_wires"], + self.hyperparameters["mod"], + new_dict["work_wires"], + ) + + @property + def wires(self): + """All wires involved in the operation.""" + return self.hyperparameters["x_wires"] + self.hyperparameters["work_wires"] + + def decomposition(self): # pylint: disable=arguments-differ + return self.compute_decomposition(**self.hyperparameters) + + @classmethod + def _primitive_bind_call(cls, *args, **kwargs): + return cls._primitive.bind(*args, **kwargs) + + @staticmethod + def compute_decomposition(k, x_wires, mod, work_wires): # pylint: disable=arguments-differ + r"""Representation of the operator as a product of other operators. + Args: + k (int): the number that needs to be multiplied + x_wires (Sequence[int]): the wires the operation acts on + mod (int): the modulus for performing the multiplication, default value is :math:`2^{len(x\_wires)}` + work_wires (Sequence[int]): the auxiliary wires to be used for performing the multiplication + Returns: + list[.Operator]: Decomposition of the operator + + **Example** + + >>> qml.Multiplier.compute_decomposition(k=3, mod=8, x_wires=[0,1,2], work_wires=[3,4,5]) + [QFT(wires=[3, 4, 5]), + ControlledSequence(PhaseAdder(wires=[3, 4 , 5 , None]), control=[0, 1, 2]), + Adjoint(QFT(wires=[3, 4, 5])), + SWAP(wires=[0, 3]), + SWAP(wires=[1, 4]), + SWAP(wires=[2, 5]), + Adjoint(Adjoint(QFT(wires=[3, 4, 5]))), + Adjoint(ControlledSequence(PhaseAdder(wires=[3, 4, 5, None]), control=[0, 1, 2])), + Adjoint(QFT(wires=[3, 4, 5]))] + """ + + op_list = [] + if mod != 2 ** len(x_wires): + work_wire_aux = work_wires[:1] + wires_aux = work_wires[1:] + wires_aux_swap = wires_aux[1:] + else: + work_wire_aux = None + wires_aux = work_wires[: len(x_wires)] + wires_aux_swap = wires_aux + op_list.extend(_mul_out_k_mod(k, x_wires, mod, work_wire_aux, wires_aux)) + for x_wire, aux_wire in zip(x_wires, wires_aux_swap): + op_list.append(qml.SWAP(wires=[x_wire, aux_wire])) + inv_k = pow(k, -1, mod) + op_list.extend(qml.adjoint(_mul_out_k_mod)(inv_k, x_wires, mod, work_wire_aux, wires_aux)) + return op_list diff --git a/pennylane/templates/subroutines/out_multiplier.py b/pennylane/templates/subroutines/out_multiplier.py new file mode 100644 index 00000000000..009763e51e2 --- /dev/null +++ b/pennylane/templates/subroutines/out_multiplier.py @@ -0,0 +1,196 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains the OutMultiplier template. +""" + +import pennylane as qml +from pennylane.operation import Operation + + +class OutMultiplier(Operation): + r"""Performs the out-place modular multiplication operation. + + This operator performs the modular multiplication of integers :math:`x` and :math:`y` modulo + :math:`mod` in the computational basis: + + .. math:: + \text{OutMultiplier}(mod) |x \rangle |y \rangle |b \rangle = |x \rangle |y \rangle |b + x \cdot y \; \text{modulo} \; mod \rangle, + + The implementation is based on the quantum Fourier transform method presented in + `arXiv:2311.08555 `_. + + .. note:: + + Note that :math:`x` and :math:`y` must be smaller than :math:`mod` to get the correct result. + + Args: + x_wires (Sequence[int]): the wires that store the integer :math:`x` + y_wires (Sequence[int]): the wires that store the integer :math:`y` + output_wires (Sequence[int]): the wires that store the multiplication result + mod (int): the modulus for performing the multiplication, default value is :math:`2^{len(output\_wires)}` + work_wires (Sequence[int]): the auxiliary wires to use for the multiplication modulo + + **Example** + + This example performs the multiplication of two integers :math:`x=2` and :math:`y=7` modulo :math:`mod=12`. + + .. code-block:: + + x = 2 + y = 7 + mod = 12 + + x_wires = [0, 1] + y_wires = [2, 3, 4] + output_wires = [6, 7, 8, 9] + work_wires = [5, 10] + + dev = qml.device("default.qubit", shots=1) + @qml.qnode(dev) + def circuit(): + qml.BasisEmbedding(x, wires=x_wires) + qml.BasisEmbedding(y, wires=y_wires) + qml.OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires) + return qml.sample(wires=output_wires) + + .. code-block:: pycon + + >>> print(circuit()) + [0 0 1 0] + + The result :math:`[0 0 1 0]`, is the ket representation of + :math:`2 \cdot 7 \, \text{modulo} \, 12 = 2`. + """ + + grad_method = None + + def __init__( + self, x_wires, y_wires, output_wires, mod=None, work_wires=None, id=None + ): # pylint: disable=too-many-arguments + + if mod is None: + mod = 2 ** len(output_wires) + if mod != 2 ** len(output_wires) and work_wires is None: + raise ValueError( + f"If mod is not 2^{len(output_wires)}, two work wires should be provided." + ) + if (not hasattr(output_wires, "__len__")) or (mod > 2 ** (len(output_wires))): + raise ValueError("OutMultiplier must have enough wires to represent mod.") + + if work_wires is not None: + if any(wire in work_wires for wire in x_wires): + raise ValueError("None of the wires in work_wires should be included in x_wires.") + if any(wire in work_wires for wire in y_wires): + raise ValueError("None of the wires in work_wires should be included in y_wires.") + + if any(wire in y_wires for wire in x_wires): + raise ValueError("None of the wires in y_wires should be included in x_wires.") + if any(wire in x_wires for wire in output_wires): + raise ValueError("None of the wires in x_wires should be included in output_wires.") + if any(wire in y_wires for wire in output_wires): + raise ValueError("None of the wires in y_wires should be included in output_wires.") + + wires_list = ["x_wires", "y_wires", "output_wires", "work_wires"] + + for key in wires_list: + self.hyperparameters[key] = qml.wires.Wires(locals()[key]) + self.hyperparameters["mod"] = mod + all_wires = sum(self.hyperparameters[key] for key in wires_list) + super().__init__(wires=all_wires, id=id) + + @property + def num_params(self): + return 0 + + def _flatten(self): + metadata = tuple((key, value) for key, value in self.hyperparameters.items()) + return tuple(), metadata + + @classmethod + def _unflatten(cls, data, metadata): + hyperparams_dict = dict(metadata) + return cls(**hyperparams_dict) + + def map_wires(self, wire_map: dict): + new_dict = { + key: [wire_map.get(w, w) for w in self.hyperparameters[key]] + for key in ["x_wires", "y_wires", "output_wires", "work_wires"] + } + + return OutMultiplier( + new_dict["x_wires"], + new_dict["y_wires"], + new_dict["output_wires"], + self.hyperparameters["mod"], + new_dict["work_wires"], + ) + + @property + def wires(self): + """All wires involved in the operation.""" + return ( + self.hyperparameters["x_wires"] + + self.hyperparameters["y_wires"] + + self.hyperparameters["output_wires"] + + self.hyperparameters["work_wires"] + ) + + def decomposition(self): # pylint: disable=arguments-differ + return self.compute_decomposition(**self.hyperparameters) + + @classmethod + def _primitive_bind_call(cls, *args, **kwargs): + return cls._primitive.bind(*args, **kwargs) + + @staticmethod + def compute_decomposition( + x_wires, y_wires, output_wires, mod, work_wires + ): # pylint: disable=arguments-differ + r"""Representation of the operator as a product of other operators. + Args: + x_wires (Sequence[int]): the wires that store the integer :math:`x` + y_wires (Sequence[int]): the wires that store the integer :math:`y` + output_wires (Sequence[int]): the wires that store the multiplication result + mod (int): the modulus for performing the multiplication, default value is :math:`2^{len(output\_wires)}` + work_wires (Sequence[int]): the auxiliary wires to use for the multiplication modulo + Returns: + list[.Operator]: Decomposition of the operator + + **Example** + + >>> qml.OutMultiplier.compute_decomposition(x_wires=[0,1], y_wires=[2,3], output_wires=[5,6], mod=4, work_wires=[4,7]) + [QFT(wires=[5, 6]), + ControlledSequence(ControlledSequence(PhaseAdder(wires=[5, 6]), control=[0, 1]), control=[2, 3]), + Adjoint(QFT(wires=[5, 6]))] + """ + op_list = [] + if mod != 2 ** len(output_wires): + qft_output_wires = work_wires[:1] + output_wires + work_wire = work_wires[1:] + else: + qft_output_wires = output_wires + work_wire = None + op_list.append(qml.QFT(wires=qft_output_wires)) + op_list.append( + qml.ControlledSequence( + qml.ControlledSequence( + qml.PhaseAdder(1, qft_output_wires, mod, work_wire), control=x_wires + ), + control=y_wires, + ) + ) + op_list.append(qml.adjoint(qml.QFT)(wires=qft_output_wires)) + + return op_list diff --git a/tests/capture/test_templates.py b/tests/capture/test_templates.py index 787ad612281..cf202b3b087 100644 --- a/tests/capture/test_templates.py +++ b/tests/capture/test_templates.py @@ -259,6 +259,8 @@ def fn(*args): qml.QROM, qml.PhaseAdder, qml.Adder, + qml.Multiplier, + qml.OutMultiplier, ] @@ -758,6 +760,77 @@ def qfunc(): assert len(q) == 1 qml.assert_equal(q.queue[0], qml.Adder(**kwargs)) + @pytest.mark.usefixtures("new_opmath_only") + def test_multiplier(self): + """Test the primitive bind call of Multiplier.""" + + kwargs = { + "k": 3, + "x_wires": [0, 1], + "mod": None, + "work_wires": [2, 3], + } + + def qfunc(): + qml.Multiplier(**kwargs) + + # Validate inputs + qfunc() + + # Actually test primitive bind + jaxpr = jax.make_jaxpr(qfunc)() + + assert len(jaxpr.eqns) == 1 + + eqn = jaxpr.eqns[0] + assert eqn.primitive == qml.Multiplier._primitive + assert eqn.invars == jaxpr.jaxpr.invars + assert eqn.params == kwargs + assert len(eqn.outvars) == 1 + assert isinstance(eqn.outvars[0], jax.core.DropVar) + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + + assert len(q) == 1 + qml.assert_equal(q.queue[0], qml.Multiplier(**kwargs)) + + @pytest.mark.usefixtures("new_opmath_only") + def test_out_multiplier(self): + """Test the primitive bind call of OutMultiplier.""" + + kwargs = { + "x_wires": [0, 1], + "y_wires": [2, 3], + "output_wires": [4, 5], + "mod": None, + "work_wires": None, + } + + def qfunc(): + qml.OutMultiplier(**kwargs) + + # Validate inputs + qfunc() + + # Actually test primitive bind + jaxpr = jax.make_jaxpr(qfunc)() + + assert len(jaxpr.eqns) == 1 + + eqn = jaxpr.eqns[0] + assert eqn.primitive == qml.OutMultiplier._primitive + assert eqn.invars == jaxpr.jaxpr.invars + assert eqn.params == kwargs + assert len(eqn.outvars) == 1 + assert isinstance(eqn.outvars[0], jax.core.DropVar) + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + + assert len(q) == 1 + qml.assert_equal(q.queue[0], qml.OutMultiplier(**kwargs)) + @pytest.mark.parametrize( "template, kwargs", [ diff --git a/tests/templates/test_subroutines/test_multiplier.py b/tests/templates/test_subroutines/test_multiplier.py new file mode 100644 index 00000000000..4cb9cadee5f --- /dev/null +++ b/tests/templates/test_subroutines/test_multiplier.py @@ -0,0 +1,202 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the Multiplier template. +""" + +import numpy as np +import pytest + +import pennylane as qml +from pennylane.templates.subroutines.multiplier import _mul_out_k_mod + + +def test_standard_validity_Multiplier(): + """Check the operation using the assert_valid function.""" + k = 6 + mod = 11 + x_wires = [0, 1, 2, 3] + work_wires = [4, 5, 6, 7, 8, 9] + op = qml.Multiplier(k, x_wires=x_wires, mod=mod, work_wires=work_wires) + qml.ops.functions.assert_valid(op) + + +def test_mul_out_k_mod(): + """Test the _mul_out_k_mod function.""" + + op = _mul_out_k_mod(2, [0, 1], 4, None, [4, 5]) + assert op[0].name == "QFT" + assert op[1].name == "ControlledSequence" + assert op[2].name == "Adjoint(QFT)" + print(op[1].base) + assert qml.equal(op[1].base, qml.PhaseAdder(2, x_wires=[4, 5])) + + +class TestMultiplier: + """Test the qml.Multiplier template.""" + + @pytest.mark.parametrize( + ("k", "x_wires", "mod", "work_wires", "x"), + [ + ( + 5, + [0, 1, 2], + 8, + [4, 5, 6, 7, 8], + 3, + ), + ( + 1, + [0, 1, 2], + 3, + [3, 4, 5, 6, 7], + 2, + ), + ( + -12, + [0, 1, 2, 3, 4], + 23, + [5, 6, 7, 8, 9, 10, 11], + 1, + ), + ( + 5, + [0, 1, 2, 3, 4], + None, + [5, 6, 7, 8, 9, 10, 11], + 0, + ), + ( + 5, + [0, 1, 2, 3, 4], + None, + [5, 6, 7, 8, 9], + 1, + ), + ], + ) + def test_operation_result( + self, k, x_wires, mod, work_wires, x + ): # pylint: disable=too-many-arguments + """Test the correctness of the Multiplier template output.""" + dev = qml.device("default.qubit", shots=1) + + @qml.qnode(dev) + def circuit(x): + qml.BasisEmbedding(x, wires=x_wires) + qml.Multiplier(k, x_wires, mod, work_wires) + return qml.sample(wires=x_wires) + + if mod is None: + mod = 2 ** len(x_wires) + + assert np.allclose( + sum(bit * (2**i) for i, bit in enumerate(reversed(circuit(x)))), (x * k) % mod + ) + + @pytest.mark.parametrize( + ("k", "x_wires", "mod", "work_wires", "msg_match"), + [ + ( + 6, + [0, 1], + 7, + [3, 4, 5, 6], + "Multiplier must have enough wires to represent mod.", + ), + ( + 2, + [0, 1, 2], + 6, + [3, 4, 5, 6, 7], + "The operator cannot be built because k has no inverse modulo mod", + ), + ( + 3, + [0, 1, 2, 3, 4], + 11, + [4, 5], + "None of the wire in work_wires should be included in x_wires.", + ), + ( + 3, + [0, 1, 2, 3, 4], + 11, + [5, 6, 7, 8, 9, 10], + "Multiplier needs as many work_wires as x_wires plus two.", + ), + ( + 3, + [0, 1, 2, 3], + 16, + [5, 6, 7], + "Multiplier needs as many work_wires as x_wires.", + ), + ], + ) + def test_operation_and_wires_error( + self, k, x_wires, mod, work_wires, msg_match + ): # pylint: disable=too-many-arguments + """Test an error is raised when k or mod don't meet the requirements""" + with pytest.raises(ValueError, match=msg_match): + qml.Multiplier(k, x_wires, mod, work_wires) + + def test_decomposition(self): + """Test that compute_decomposition and decomposition work as expected.""" + k, x_wires, mod, work_wires = 4, [0, 1, 2], 7, [3, 4, 5, 6, 7] + multiplier_decomposition = qml.Multiplier( + k, x_wires, mod, work_wires + ).compute_decomposition(k, x_wires, mod, work_wires) + op_list = [] + if mod != 2 ** len(x_wires): + work_wire_aux = work_wires[:1] + wires_aux = work_wires[1:] + wires_aux_swap = wires_aux[1:] + else: + work_wire_aux = None + wires_aux = work_wires[:3] + wires_aux_swap = wires_aux + op_list.extend(_mul_out_k_mod(k, x_wires, mod, work_wire_aux, wires_aux)) + for x_wire, aux_wire in zip(x_wires, wires_aux_swap): + op_list.append(qml.SWAP(wires=[x_wire, aux_wire])) + inv_k = pow(k, -1, mod) + op_list.extend(qml.adjoint(_mul_out_k_mod)(inv_k, x_wires, mod, work_wire_aux, wires_aux)) + + for op1, op2 in zip(multiplier_decomposition, op_list): + qml.assert_equal(op1, op2) + + @pytest.mark.jax + def test_jit_compatible(self): + """Test that the template is compatible with the JIT compiler.""" + + import jax + + jax.config.update("jax_enable_x64", True) + x = 2 + k = 6 + mod = 7 + x_wires = [0, 1, 2] + work_wires = [4, 5, 6, 7, 8] + dev = qml.device("default.qubit", shots=1) + + @jax.jit + @qml.qnode(dev) + def circuit(): + qml.BasisEmbedding(x, wires=x_wires) + qml.Multiplier(k, x_wires, mod, work_wires) + return qml.sample(wires=x_wires) + + assert jax.numpy.allclose( + sum(bit * (2**i) for i, bit in enumerate(reversed(circuit()))), (x * k) % mod + ) diff --git a/tests/templates/test_subroutines/test_out_multiplier.py b/tests/templates/test_subroutines/test_out_multiplier.py new file mode 100644 index 00000000000..9541474b7f3 --- /dev/null +++ b/tests/templates/test_subroutines/test_out_multiplier.py @@ -0,0 +1,247 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the OutMultiplier template. +""" + +import pytest + +import pennylane as qml +from pennylane import numpy as np +from pennylane.templates.subroutines.out_multiplier import OutMultiplier + + +def test_standard_validity_OutMultiplier(): + """Check the operation using the assert_valid function.""" + mod = 12 + x_wires = [0, 1] + y_wires = [2, 3, 4] + output_wires = [6, 7, 8, 9] + work_wires = [5, 10] + op = OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires) + qml.ops.functions.assert_valid(op) + + +class TestOutMultiplier: + """Test the qml.OutMultiplier template.""" + + @pytest.mark.parametrize( + ("x_wires", "y_wires", "output_wires", "mod", "work_wires", "x", "y"), + [ + ( + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + 7, + [9, 10], + 2, + 3, + ), + ( + [0, 1], + [3, 4, 5], + [6, 7, 8, 2], + 14, + [9, 10], + 1, + 2, + ), + ( + [0, 1, 2], + [3, 4], + [5, 6, 7, 8], + 8, + [9, 10], + 3, + 3, + ), + ( + [0, 1, 2, 3], + [4, 5], + [6, 7, 8, 9, 10], + 22, + [11, 12], + 0, + 0, + ), + ( + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + None, + [9, 10], + 1, + 3, + ), + ( + [0, 1], + [3, 4, 5], + [6, 7, 8], + None, + None, + 3, + 3, + ), + ], + ) + def test_operation_result( + self, x_wires, y_wires, output_wires, mod, work_wires, x, y + ): # pylint: disable=too-many-arguments + """Test the correctness of the OutMultiplier template output.""" + dev = qml.device("default.qubit", shots=1) + + @qml.qnode(dev) + def circuit(x, y): + qml.BasisEmbedding(x, wires=x_wires) + qml.BasisEmbedding(y, wires=y_wires) + OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires) + return qml.sample(wires=output_wires) + + if mod is None: + mod = 2 ** len(output_wires) + + assert np.allclose( + sum(bit * (2**i) for i, bit in enumerate(reversed(circuit(x, y)))), (x * y) % mod + ) + + @pytest.mark.parametrize( + ("x_wires", "y_wires", "output_wires", "mod", "work_wires", "msg_match"), + [ + ( + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + 7, + [1, 10], + "None of the wires in work_wires should be included in x_wires.", + ), + ( + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + 7, + [3, 10], + "None of the wires in work_wires should be included in y_wires.", + ), + ( + [0, 1, 2], + [2, 4, 5], + [6, 7, 8], + 7, + [9, 10], + "None of the wires in y_wires should be included in x_wires.", + ), + ( + [0, 1, 2], + [3, 7, 5], + [6, 7, 8], + 7, + [9, 10], + "None of the wires in y_wires should be included in output_wires.", + ), + ( + [0, 1, 7], + [3, 4, 5], + [6, 7, 8], + 7, + [9, 10], + "None of the wires in x_wires should be included in output_wires.", + ), + ( + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + 9, + [9, 10], + "OutMultiplier must have enough wires to represent mod.", + ), + ( + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + 9, + None, + "If mod is not", + ), + ], + ) + def test_wires_error( + self, x_wires, y_wires, output_wires, mod, work_wires, msg_match + ): # pylint: disable=too-many-arguments + """Test an error is raised when some work_wires don't meet the requirements""" + with pytest.raises(ValueError, match=msg_match): + OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires) + + def test_decomposition(self): + """Test that compute_decomposition and decomposition work as expected.""" + x_wires, y_wires, output_wires, mod, work_wires = ( + [0, 1, 2], + [3, 5], + [6, 8], + 3, + [9, 10], + ) + multiplier_decomposition = OutMultiplier( + x_wires, y_wires, output_wires, mod, work_wires + ).compute_decomposition(x_wires, y_wires, output_wires, mod, work_wires) + op_list = [] + if mod != 2 ** len(output_wires): + qft_output_wires = work_wires[:1] + output_wires + work_wire = work_wires[1:] + else: + qft_output_wires = output_wires + work_wire = None + op_list.append(qml.QFT(wires=qft_output_wires)) + op_list.append( + qml.ControlledSequence( + qml.ControlledSequence( + qml.PhaseAdder(1, qft_output_wires, mod, work_wire), control=x_wires + ), + control=y_wires, + ) + ) + op_list.append(qml.adjoint(qml.QFT)(wires=qft_output_wires)) + + for op1, op2 in zip(multiplier_decomposition, op_list): + qml.assert_equal(op1, op2) + + @pytest.mark.jax + def test_jit_compatible(self): + """Test that the template is compatible with the JIT compiler.""" + + import jax + + jax.config.update("jax_enable_x64", True) + + x, y = 2, 3 + x_list = [1, 0] + y_list = [1, 1] + mod = 12 + x_wires = [0, 1] + y_wires = [2, 3] + output_wires = [6, 7, 8, 9] + work_wires = [5, 10] + dev = qml.device("default.qubit", shots=1) + + @jax.jit + @qml.qnode(dev) + def circuit(): + qml.BasisEmbedding(x_list, wires=x_wires) + qml.BasisEmbedding(y_list, wires=y_wires) + OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires) + return qml.sample(wires=output_wires) + + assert jax.numpy.allclose( + sum(bit * (2**i) for i, bit in enumerate(reversed(circuit()))), (x * y) % mod + )