diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 81b0a21d144..a67925c4ad6 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -27,6 +27,9 @@
* The `qml.Adder` and `qml.PhaseAdder` templates are added to perform in-place modular addition.
[(#6109)](https://github.com/PennyLaneAI/pennylane/pull/6109)
+* The `qml.Multiplier` and `qml.OutMultiplier` templates are added to perform modular multiplication.
+ [(#6112)](https://github.com/PennyLaneAI/pennylane/pull/6112)
+
Creating spin Hamiltonians 🧑🎨
* The function ``transverse_ising`` is added to generate transverse-field Ising Hamiltonian.
diff --git a/pennylane/templates/subroutines/__init__.py b/pennylane/templates/subroutines/__init__.py
index 2c60d72bbf4..5b3ec3fc887 100644
--- a/pennylane/templates/subroutines/__init__.py
+++ b/pennylane/templates/subroutines/__init__.py
@@ -47,3 +47,5 @@
from .qrom import QROM
from .phase_adder import PhaseAdder
from .adder import Adder
+from .multiplier import Multiplier
+from .out_multiplier import OutMultiplier
diff --git a/pennylane/templates/subroutines/adder.py b/pennylane/templates/subroutines/adder.py
index 0d4a776dc7c..32a621c4985 100644
--- a/pennylane/templates/subroutines/adder.py
+++ b/pennylane/templates/subroutines/adder.py
@@ -34,9 +34,8 @@ class Adder(Operation):
.. note::
- Note that :math:`x` must be smaller than :math:`mod` to get the correct result. Also, when
- :math:`mod \neq 2^{\text{len(x\_wires)}}` we need :math:`x < 2^{\text{len(x\_wires)}}/2`,
- which means that we need one extra wire in ``x_wires``.
+ Note that :math:`x` must be smaller than :math:`mod` to get the correct result.
+
Args:
k (int): the number that needs to be added
diff --git a/pennylane/templates/subroutines/multiplier.py b/pennylane/templates/subroutines/multiplier.py
new file mode 100644
index 00000000000..dbc3ff96325
--- /dev/null
+++ b/pennylane/templates/subroutines/multiplier.py
@@ -0,0 +1,197 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Contains the Multiplier template.
+"""
+
+import numpy as np
+
+import pennylane as qml
+from pennylane.operation import Operation
+
+
+def _mul_out_k_mod(k, x_wires, mod, work_wire_aux, wires_aux):
+ """Performs :math:`x \times k` in the registers wires wires_aux"""
+ op_list = []
+
+ op_list.append(qml.QFT(wires=wires_aux))
+ op_list.append(
+ qml.ControlledSequence(qml.PhaseAdder(k, wires_aux, mod, work_wire_aux), control=x_wires)
+ )
+ op_list.append(qml.adjoint(qml.QFT(wires=wires_aux)))
+ return op_list
+
+
+class Multiplier(Operation):
+ r"""Performs the in-place modular multiplication operation.
+
+ This operator performs the modular multiplication by an integer :math:`k` modulo :math:`mod` in
+ the computational basis:
+
+ .. math::
+
+ \text{Multiplier}(k,mod) |x \rangle = | x \cdot k \; \text{modulo} \; \text{mod} \rangle.
+
+ The implementation is based on the quantum Fourier transform method presented in
+ `arXiv:2311.08555 `_.
+
+ .. note::
+
+ Note that :math:`x` must be smaller than :math:`mod` to get the correct result. Also, it
+ is required that :math:`k` has inverse, :math:`k^-1`, modulo :math:`mod`. That means
+ :math:`k*k^-1 modulo mod is equal to 1`, which will only be possible if :math:`k` and
+ :math:`mod` are coprime. Furthermore, if :math:`mod \neq 2^{len(x\_wires)}`, two more
+ auxiliaries must be added.
+
+ Args:
+ k (int): the number that needs to be multiplied
+ x_wires (Sequence[int]): the wires the operation acts on
+ mod (int): the modulus for performing the multiplication, default value is :math:`2^{len(x\_wires)}`
+ work_wires (Sequence[int]): the auxiliary wires to be used for performing the multiplication
+
+ **Example**
+
+ This example performs the multiplication of two integers :math:`x=3` and :math:`k=4` modulo :math:`mod=7`.
+
+ .. code-block::
+
+ x = 3
+ k = 4
+ mod = 7
+
+ x_wires =[0,1,2]
+ work_wires=[3,4,5,6,7]
+
+ dev = qml.device("default.qubit", shots=1)
+ @qml.qnode(dev)
+ def circuit(x, k, mod, wires_m, work_wires):
+ qml.BasisEmbedding(x, wires=wires_m)
+ qml.Multiplier(k, x_wires, mod, work_wires)
+ return qml.sample(wires=wires_m)
+
+ .. code-block:: pycon
+
+ >>> print(circuit(x, k, mod, x_wires, work_wires))
+ [1 0 1]
+
+ The result :math:`[1 0 1]`, is the ket representation of
+ :math:`3 \cdot 4 \, \text{modulo} \, 12 = 5`.
+ """
+
+ grad_method = None
+
+ def __init__(
+ self, k, x_wires, mod=None, work_wires=None, id=None
+ ): # pylint: disable=too-many-arguments
+ if any(wire in work_wires for wire in x_wires):
+ raise ValueError("None of the wire in work_wires should be included in x_wires.")
+
+ if mod is None:
+ mod = 2 ** len(x_wires)
+ if mod != 2 ** len(x_wires) and len(work_wires) < (len(x_wires) + 2):
+ raise ValueError("Multiplier needs as many work_wires as x_wires plus two.")
+ if len(work_wires) < len(x_wires):
+ raise ValueError("Multiplier needs as many work_wires as x_wires.")
+ if (not hasattr(x_wires, "__len__")) or (mod > 2 ** len(x_wires)):
+ raise ValueError("Multiplier must have enough wires to represent mod.")
+
+ k = k % mod
+ if np.gcd(k, mod) != 1:
+ raise ValueError("The operator cannot be built because k has no inverse modulo mod.")
+
+ self.hyperparameters["k"] = k
+ self.hyperparameters["mod"] = mod
+ self.hyperparameters["work_wires"] = qml.wires.Wires(work_wires)
+ self.hyperparameters["x_wires"] = qml.wires.Wires(x_wires)
+ all_wires = qml.wires.Wires(x_wires) + qml.wires.Wires(work_wires)
+ super().__init__(wires=all_wires, id=id)
+
+ @property
+ def num_params(self):
+ return 0
+
+ def _flatten(self):
+ metadata = tuple((key, value) for key, value in self.hyperparameters.items())
+ return tuple(), metadata
+
+ @classmethod
+ def _unflatten(cls, data, metadata):
+ hyperparams_dict = dict(metadata)
+ return cls(**hyperparams_dict)
+
+ def map_wires(self, wire_map: dict):
+ new_dict = {
+ key: [wire_map.get(w, w) for w in self.hyperparameters[key]]
+ for key in ["x_wires", "work_wires"]
+ }
+
+ return Multiplier(
+ self.hyperparameters["k"],
+ new_dict["x_wires"],
+ self.hyperparameters["mod"],
+ new_dict["work_wires"],
+ )
+
+ @property
+ def wires(self):
+ """All wires involved in the operation."""
+ return self.hyperparameters["x_wires"] + self.hyperparameters["work_wires"]
+
+ def decomposition(self): # pylint: disable=arguments-differ
+ return self.compute_decomposition(**self.hyperparameters)
+
+ @classmethod
+ def _primitive_bind_call(cls, *args, **kwargs):
+ return cls._primitive.bind(*args, **kwargs)
+
+ @staticmethod
+ def compute_decomposition(k, x_wires, mod, work_wires): # pylint: disable=arguments-differ
+ r"""Representation of the operator as a product of other operators.
+ Args:
+ k (int): the number that needs to be multiplied
+ x_wires (Sequence[int]): the wires the operation acts on
+ mod (int): the modulus for performing the multiplication, default value is :math:`2^{len(x\_wires)}`
+ work_wires (Sequence[int]): the auxiliary wires to be used for performing the multiplication
+ Returns:
+ list[.Operator]: Decomposition of the operator
+
+ **Example**
+
+ >>> qml.Multiplier.compute_decomposition(k=3, mod=8, x_wires=[0,1,2], work_wires=[3,4,5])
+ [QFT(wires=[3, 4, 5]),
+ ControlledSequence(PhaseAdder(wires=[3, 4 , 5 , None]), control=[0, 1, 2]),
+ Adjoint(QFT(wires=[3, 4, 5])),
+ SWAP(wires=[0, 3]),
+ SWAP(wires=[1, 4]),
+ SWAP(wires=[2, 5]),
+ Adjoint(Adjoint(QFT(wires=[3, 4, 5]))),
+ Adjoint(ControlledSequence(PhaseAdder(wires=[3, 4, 5, None]), control=[0, 1, 2])),
+ Adjoint(QFT(wires=[3, 4, 5]))]
+ """
+
+ op_list = []
+ if mod != 2 ** len(x_wires):
+ work_wire_aux = work_wires[:1]
+ wires_aux = work_wires[1:]
+ wires_aux_swap = wires_aux[1:]
+ else:
+ work_wire_aux = None
+ wires_aux = work_wires[: len(x_wires)]
+ wires_aux_swap = wires_aux
+ op_list.extend(_mul_out_k_mod(k, x_wires, mod, work_wire_aux, wires_aux))
+ for x_wire, aux_wire in zip(x_wires, wires_aux_swap):
+ op_list.append(qml.SWAP(wires=[x_wire, aux_wire]))
+ inv_k = pow(k, -1, mod)
+ op_list.extend(qml.adjoint(_mul_out_k_mod)(inv_k, x_wires, mod, work_wire_aux, wires_aux))
+ return op_list
diff --git a/pennylane/templates/subroutines/out_multiplier.py b/pennylane/templates/subroutines/out_multiplier.py
new file mode 100644
index 00000000000..009763e51e2
--- /dev/null
+++ b/pennylane/templates/subroutines/out_multiplier.py
@@ -0,0 +1,196 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Contains the OutMultiplier template.
+"""
+
+import pennylane as qml
+from pennylane.operation import Operation
+
+
+class OutMultiplier(Operation):
+ r"""Performs the out-place modular multiplication operation.
+
+ This operator performs the modular multiplication of integers :math:`x` and :math:`y` modulo
+ :math:`mod` in the computational basis:
+
+ .. math::
+ \text{OutMultiplier}(mod) |x \rangle |y \rangle |b \rangle = |x \rangle |y \rangle |b + x \cdot y \; \text{modulo} \; mod \rangle,
+
+ The implementation is based on the quantum Fourier transform method presented in
+ `arXiv:2311.08555 `_.
+
+ .. note::
+
+ Note that :math:`x` and :math:`y` must be smaller than :math:`mod` to get the correct result.
+
+ Args:
+ x_wires (Sequence[int]): the wires that store the integer :math:`x`
+ y_wires (Sequence[int]): the wires that store the integer :math:`y`
+ output_wires (Sequence[int]): the wires that store the multiplication result
+ mod (int): the modulus for performing the multiplication, default value is :math:`2^{len(output\_wires)}`
+ work_wires (Sequence[int]): the auxiliary wires to use for the multiplication modulo
+
+ **Example**
+
+ This example performs the multiplication of two integers :math:`x=2` and :math:`y=7` modulo :math:`mod=12`.
+
+ .. code-block::
+
+ x = 2
+ y = 7
+ mod = 12
+
+ x_wires = [0, 1]
+ y_wires = [2, 3, 4]
+ output_wires = [6, 7, 8, 9]
+ work_wires = [5, 10]
+
+ dev = qml.device("default.qubit", shots=1)
+ @qml.qnode(dev)
+ def circuit():
+ qml.BasisEmbedding(x, wires=x_wires)
+ qml.BasisEmbedding(y, wires=y_wires)
+ qml.OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires)
+ return qml.sample(wires=output_wires)
+
+ .. code-block:: pycon
+
+ >>> print(circuit())
+ [0 0 1 0]
+
+ The result :math:`[0 0 1 0]`, is the ket representation of
+ :math:`2 \cdot 7 \, \text{modulo} \, 12 = 2`.
+ """
+
+ grad_method = None
+
+ def __init__(
+ self, x_wires, y_wires, output_wires, mod=None, work_wires=None, id=None
+ ): # pylint: disable=too-many-arguments
+
+ if mod is None:
+ mod = 2 ** len(output_wires)
+ if mod != 2 ** len(output_wires) and work_wires is None:
+ raise ValueError(
+ f"If mod is not 2^{len(output_wires)}, two work wires should be provided."
+ )
+ if (not hasattr(output_wires, "__len__")) or (mod > 2 ** (len(output_wires))):
+ raise ValueError("OutMultiplier must have enough wires to represent mod.")
+
+ if work_wires is not None:
+ if any(wire in work_wires for wire in x_wires):
+ raise ValueError("None of the wires in work_wires should be included in x_wires.")
+ if any(wire in work_wires for wire in y_wires):
+ raise ValueError("None of the wires in work_wires should be included in y_wires.")
+
+ if any(wire in y_wires for wire in x_wires):
+ raise ValueError("None of the wires in y_wires should be included in x_wires.")
+ if any(wire in x_wires for wire in output_wires):
+ raise ValueError("None of the wires in x_wires should be included in output_wires.")
+ if any(wire in y_wires for wire in output_wires):
+ raise ValueError("None of the wires in y_wires should be included in output_wires.")
+
+ wires_list = ["x_wires", "y_wires", "output_wires", "work_wires"]
+
+ for key in wires_list:
+ self.hyperparameters[key] = qml.wires.Wires(locals()[key])
+ self.hyperparameters["mod"] = mod
+ all_wires = sum(self.hyperparameters[key] for key in wires_list)
+ super().__init__(wires=all_wires, id=id)
+
+ @property
+ def num_params(self):
+ return 0
+
+ def _flatten(self):
+ metadata = tuple((key, value) for key, value in self.hyperparameters.items())
+ return tuple(), metadata
+
+ @classmethod
+ def _unflatten(cls, data, metadata):
+ hyperparams_dict = dict(metadata)
+ return cls(**hyperparams_dict)
+
+ def map_wires(self, wire_map: dict):
+ new_dict = {
+ key: [wire_map.get(w, w) for w in self.hyperparameters[key]]
+ for key in ["x_wires", "y_wires", "output_wires", "work_wires"]
+ }
+
+ return OutMultiplier(
+ new_dict["x_wires"],
+ new_dict["y_wires"],
+ new_dict["output_wires"],
+ self.hyperparameters["mod"],
+ new_dict["work_wires"],
+ )
+
+ @property
+ def wires(self):
+ """All wires involved in the operation."""
+ return (
+ self.hyperparameters["x_wires"]
+ + self.hyperparameters["y_wires"]
+ + self.hyperparameters["output_wires"]
+ + self.hyperparameters["work_wires"]
+ )
+
+ def decomposition(self): # pylint: disable=arguments-differ
+ return self.compute_decomposition(**self.hyperparameters)
+
+ @classmethod
+ def _primitive_bind_call(cls, *args, **kwargs):
+ return cls._primitive.bind(*args, **kwargs)
+
+ @staticmethod
+ def compute_decomposition(
+ x_wires, y_wires, output_wires, mod, work_wires
+ ): # pylint: disable=arguments-differ
+ r"""Representation of the operator as a product of other operators.
+ Args:
+ x_wires (Sequence[int]): the wires that store the integer :math:`x`
+ y_wires (Sequence[int]): the wires that store the integer :math:`y`
+ output_wires (Sequence[int]): the wires that store the multiplication result
+ mod (int): the modulus for performing the multiplication, default value is :math:`2^{len(output\_wires)}`
+ work_wires (Sequence[int]): the auxiliary wires to use for the multiplication modulo
+ Returns:
+ list[.Operator]: Decomposition of the operator
+
+ **Example**
+
+ >>> qml.OutMultiplier.compute_decomposition(x_wires=[0,1], y_wires=[2,3], output_wires=[5,6], mod=4, work_wires=[4,7])
+ [QFT(wires=[5, 6]),
+ ControlledSequence(ControlledSequence(PhaseAdder(wires=[5, 6]), control=[0, 1]), control=[2, 3]),
+ Adjoint(QFT(wires=[5, 6]))]
+ """
+ op_list = []
+ if mod != 2 ** len(output_wires):
+ qft_output_wires = work_wires[:1] + output_wires
+ work_wire = work_wires[1:]
+ else:
+ qft_output_wires = output_wires
+ work_wire = None
+ op_list.append(qml.QFT(wires=qft_output_wires))
+ op_list.append(
+ qml.ControlledSequence(
+ qml.ControlledSequence(
+ qml.PhaseAdder(1, qft_output_wires, mod, work_wire), control=x_wires
+ ),
+ control=y_wires,
+ )
+ )
+ op_list.append(qml.adjoint(qml.QFT)(wires=qft_output_wires))
+
+ return op_list
diff --git a/tests/capture/test_templates.py b/tests/capture/test_templates.py
index 787ad612281..cf202b3b087 100644
--- a/tests/capture/test_templates.py
+++ b/tests/capture/test_templates.py
@@ -259,6 +259,8 @@ def fn(*args):
qml.QROM,
qml.PhaseAdder,
qml.Adder,
+ qml.Multiplier,
+ qml.OutMultiplier,
]
@@ -758,6 +760,77 @@ def qfunc():
assert len(q) == 1
qml.assert_equal(q.queue[0], qml.Adder(**kwargs))
+ @pytest.mark.usefixtures("new_opmath_only")
+ def test_multiplier(self):
+ """Test the primitive bind call of Multiplier."""
+
+ kwargs = {
+ "k": 3,
+ "x_wires": [0, 1],
+ "mod": None,
+ "work_wires": [2, 3],
+ }
+
+ def qfunc():
+ qml.Multiplier(**kwargs)
+
+ # Validate inputs
+ qfunc()
+
+ # Actually test primitive bind
+ jaxpr = jax.make_jaxpr(qfunc)()
+
+ assert len(jaxpr.eqns) == 1
+
+ eqn = jaxpr.eqns[0]
+ assert eqn.primitive == qml.Multiplier._primitive
+ assert eqn.invars == jaxpr.jaxpr.invars
+ assert eqn.params == kwargs
+ assert len(eqn.outvars) == 1
+ assert isinstance(eqn.outvars[0], jax.core.DropVar)
+
+ with qml.queuing.AnnotatedQueue() as q:
+ jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
+
+ assert len(q) == 1
+ qml.assert_equal(q.queue[0], qml.Multiplier(**kwargs))
+
+ @pytest.mark.usefixtures("new_opmath_only")
+ def test_out_multiplier(self):
+ """Test the primitive bind call of OutMultiplier."""
+
+ kwargs = {
+ "x_wires": [0, 1],
+ "y_wires": [2, 3],
+ "output_wires": [4, 5],
+ "mod": None,
+ "work_wires": None,
+ }
+
+ def qfunc():
+ qml.OutMultiplier(**kwargs)
+
+ # Validate inputs
+ qfunc()
+
+ # Actually test primitive bind
+ jaxpr = jax.make_jaxpr(qfunc)()
+
+ assert len(jaxpr.eqns) == 1
+
+ eqn = jaxpr.eqns[0]
+ assert eqn.primitive == qml.OutMultiplier._primitive
+ assert eqn.invars == jaxpr.jaxpr.invars
+ assert eqn.params == kwargs
+ assert len(eqn.outvars) == 1
+ assert isinstance(eqn.outvars[0], jax.core.DropVar)
+
+ with qml.queuing.AnnotatedQueue() as q:
+ jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
+
+ assert len(q) == 1
+ qml.assert_equal(q.queue[0], qml.OutMultiplier(**kwargs))
+
@pytest.mark.parametrize(
"template, kwargs",
[
diff --git a/tests/templates/test_subroutines/test_multiplier.py b/tests/templates/test_subroutines/test_multiplier.py
new file mode 100644
index 00000000000..4cb9cadee5f
--- /dev/null
+++ b/tests/templates/test_subroutines/test_multiplier.py
@@ -0,0 +1,202 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for the Multiplier template.
+"""
+
+import numpy as np
+import pytest
+
+import pennylane as qml
+from pennylane.templates.subroutines.multiplier import _mul_out_k_mod
+
+
+def test_standard_validity_Multiplier():
+ """Check the operation using the assert_valid function."""
+ k = 6
+ mod = 11
+ x_wires = [0, 1, 2, 3]
+ work_wires = [4, 5, 6, 7, 8, 9]
+ op = qml.Multiplier(k, x_wires=x_wires, mod=mod, work_wires=work_wires)
+ qml.ops.functions.assert_valid(op)
+
+
+def test_mul_out_k_mod():
+ """Test the _mul_out_k_mod function."""
+
+ op = _mul_out_k_mod(2, [0, 1], 4, None, [4, 5])
+ assert op[0].name == "QFT"
+ assert op[1].name == "ControlledSequence"
+ assert op[2].name == "Adjoint(QFT)"
+ print(op[1].base)
+ assert qml.equal(op[1].base, qml.PhaseAdder(2, x_wires=[4, 5]))
+
+
+class TestMultiplier:
+ """Test the qml.Multiplier template."""
+
+ @pytest.mark.parametrize(
+ ("k", "x_wires", "mod", "work_wires", "x"),
+ [
+ (
+ 5,
+ [0, 1, 2],
+ 8,
+ [4, 5, 6, 7, 8],
+ 3,
+ ),
+ (
+ 1,
+ [0, 1, 2],
+ 3,
+ [3, 4, 5, 6, 7],
+ 2,
+ ),
+ (
+ -12,
+ [0, 1, 2, 3, 4],
+ 23,
+ [5, 6, 7, 8, 9, 10, 11],
+ 1,
+ ),
+ (
+ 5,
+ [0, 1, 2, 3, 4],
+ None,
+ [5, 6, 7, 8, 9, 10, 11],
+ 0,
+ ),
+ (
+ 5,
+ [0, 1, 2, 3, 4],
+ None,
+ [5, 6, 7, 8, 9],
+ 1,
+ ),
+ ],
+ )
+ def test_operation_result(
+ self, k, x_wires, mod, work_wires, x
+ ): # pylint: disable=too-many-arguments
+ """Test the correctness of the Multiplier template output."""
+ dev = qml.device("default.qubit", shots=1)
+
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.BasisEmbedding(x, wires=x_wires)
+ qml.Multiplier(k, x_wires, mod, work_wires)
+ return qml.sample(wires=x_wires)
+
+ if mod is None:
+ mod = 2 ** len(x_wires)
+
+ assert np.allclose(
+ sum(bit * (2**i) for i, bit in enumerate(reversed(circuit(x)))), (x * k) % mod
+ )
+
+ @pytest.mark.parametrize(
+ ("k", "x_wires", "mod", "work_wires", "msg_match"),
+ [
+ (
+ 6,
+ [0, 1],
+ 7,
+ [3, 4, 5, 6],
+ "Multiplier must have enough wires to represent mod.",
+ ),
+ (
+ 2,
+ [0, 1, 2],
+ 6,
+ [3, 4, 5, 6, 7],
+ "The operator cannot be built because k has no inverse modulo mod",
+ ),
+ (
+ 3,
+ [0, 1, 2, 3, 4],
+ 11,
+ [4, 5],
+ "None of the wire in work_wires should be included in x_wires.",
+ ),
+ (
+ 3,
+ [0, 1, 2, 3, 4],
+ 11,
+ [5, 6, 7, 8, 9, 10],
+ "Multiplier needs as many work_wires as x_wires plus two.",
+ ),
+ (
+ 3,
+ [0, 1, 2, 3],
+ 16,
+ [5, 6, 7],
+ "Multiplier needs as many work_wires as x_wires.",
+ ),
+ ],
+ )
+ def test_operation_and_wires_error(
+ self, k, x_wires, mod, work_wires, msg_match
+ ): # pylint: disable=too-many-arguments
+ """Test an error is raised when k or mod don't meet the requirements"""
+ with pytest.raises(ValueError, match=msg_match):
+ qml.Multiplier(k, x_wires, mod, work_wires)
+
+ def test_decomposition(self):
+ """Test that compute_decomposition and decomposition work as expected."""
+ k, x_wires, mod, work_wires = 4, [0, 1, 2], 7, [3, 4, 5, 6, 7]
+ multiplier_decomposition = qml.Multiplier(
+ k, x_wires, mod, work_wires
+ ).compute_decomposition(k, x_wires, mod, work_wires)
+ op_list = []
+ if mod != 2 ** len(x_wires):
+ work_wire_aux = work_wires[:1]
+ wires_aux = work_wires[1:]
+ wires_aux_swap = wires_aux[1:]
+ else:
+ work_wire_aux = None
+ wires_aux = work_wires[:3]
+ wires_aux_swap = wires_aux
+ op_list.extend(_mul_out_k_mod(k, x_wires, mod, work_wire_aux, wires_aux))
+ for x_wire, aux_wire in zip(x_wires, wires_aux_swap):
+ op_list.append(qml.SWAP(wires=[x_wire, aux_wire]))
+ inv_k = pow(k, -1, mod)
+ op_list.extend(qml.adjoint(_mul_out_k_mod)(inv_k, x_wires, mod, work_wire_aux, wires_aux))
+
+ for op1, op2 in zip(multiplier_decomposition, op_list):
+ qml.assert_equal(op1, op2)
+
+ @pytest.mark.jax
+ def test_jit_compatible(self):
+ """Test that the template is compatible with the JIT compiler."""
+
+ import jax
+
+ jax.config.update("jax_enable_x64", True)
+ x = 2
+ k = 6
+ mod = 7
+ x_wires = [0, 1, 2]
+ work_wires = [4, 5, 6, 7, 8]
+ dev = qml.device("default.qubit", shots=1)
+
+ @jax.jit
+ @qml.qnode(dev)
+ def circuit():
+ qml.BasisEmbedding(x, wires=x_wires)
+ qml.Multiplier(k, x_wires, mod, work_wires)
+ return qml.sample(wires=x_wires)
+
+ assert jax.numpy.allclose(
+ sum(bit * (2**i) for i, bit in enumerate(reversed(circuit()))), (x * k) % mod
+ )
diff --git a/tests/templates/test_subroutines/test_out_multiplier.py b/tests/templates/test_subroutines/test_out_multiplier.py
new file mode 100644
index 00000000000..9541474b7f3
--- /dev/null
+++ b/tests/templates/test_subroutines/test_out_multiplier.py
@@ -0,0 +1,247 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for the OutMultiplier template.
+"""
+
+import pytest
+
+import pennylane as qml
+from pennylane import numpy as np
+from pennylane.templates.subroutines.out_multiplier import OutMultiplier
+
+
+def test_standard_validity_OutMultiplier():
+ """Check the operation using the assert_valid function."""
+ mod = 12
+ x_wires = [0, 1]
+ y_wires = [2, 3, 4]
+ output_wires = [6, 7, 8, 9]
+ work_wires = [5, 10]
+ op = OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires)
+ qml.ops.functions.assert_valid(op)
+
+
+class TestOutMultiplier:
+ """Test the qml.OutMultiplier template."""
+
+ @pytest.mark.parametrize(
+ ("x_wires", "y_wires", "output_wires", "mod", "work_wires", "x", "y"),
+ [
+ (
+ [0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8],
+ 7,
+ [9, 10],
+ 2,
+ 3,
+ ),
+ (
+ [0, 1],
+ [3, 4, 5],
+ [6, 7, 8, 2],
+ 14,
+ [9, 10],
+ 1,
+ 2,
+ ),
+ (
+ [0, 1, 2],
+ [3, 4],
+ [5, 6, 7, 8],
+ 8,
+ [9, 10],
+ 3,
+ 3,
+ ),
+ (
+ [0, 1, 2, 3],
+ [4, 5],
+ [6, 7, 8, 9, 10],
+ 22,
+ [11, 12],
+ 0,
+ 0,
+ ),
+ (
+ [0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8],
+ None,
+ [9, 10],
+ 1,
+ 3,
+ ),
+ (
+ [0, 1],
+ [3, 4, 5],
+ [6, 7, 8],
+ None,
+ None,
+ 3,
+ 3,
+ ),
+ ],
+ )
+ def test_operation_result(
+ self, x_wires, y_wires, output_wires, mod, work_wires, x, y
+ ): # pylint: disable=too-many-arguments
+ """Test the correctness of the OutMultiplier template output."""
+ dev = qml.device("default.qubit", shots=1)
+
+ @qml.qnode(dev)
+ def circuit(x, y):
+ qml.BasisEmbedding(x, wires=x_wires)
+ qml.BasisEmbedding(y, wires=y_wires)
+ OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires)
+ return qml.sample(wires=output_wires)
+
+ if mod is None:
+ mod = 2 ** len(output_wires)
+
+ assert np.allclose(
+ sum(bit * (2**i) for i, bit in enumerate(reversed(circuit(x, y)))), (x * y) % mod
+ )
+
+ @pytest.mark.parametrize(
+ ("x_wires", "y_wires", "output_wires", "mod", "work_wires", "msg_match"),
+ [
+ (
+ [0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8],
+ 7,
+ [1, 10],
+ "None of the wires in work_wires should be included in x_wires.",
+ ),
+ (
+ [0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8],
+ 7,
+ [3, 10],
+ "None of the wires in work_wires should be included in y_wires.",
+ ),
+ (
+ [0, 1, 2],
+ [2, 4, 5],
+ [6, 7, 8],
+ 7,
+ [9, 10],
+ "None of the wires in y_wires should be included in x_wires.",
+ ),
+ (
+ [0, 1, 2],
+ [3, 7, 5],
+ [6, 7, 8],
+ 7,
+ [9, 10],
+ "None of the wires in y_wires should be included in output_wires.",
+ ),
+ (
+ [0, 1, 7],
+ [3, 4, 5],
+ [6, 7, 8],
+ 7,
+ [9, 10],
+ "None of the wires in x_wires should be included in output_wires.",
+ ),
+ (
+ [0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8],
+ 9,
+ [9, 10],
+ "OutMultiplier must have enough wires to represent mod.",
+ ),
+ (
+ [0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8],
+ 9,
+ None,
+ "If mod is not",
+ ),
+ ],
+ )
+ def test_wires_error(
+ self, x_wires, y_wires, output_wires, mod, work_wires, msg_match
+ ): # pylint: disable=too-many-arguments
+ """Test an error is raised when some work_wires don't meet the requirements"""
+ with pytest.raises(ValueError, match=msg_match):
+ OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires)
+
+ def test_decomposition(self):
+ """Test that compute_decomposition and decomposition work as expected."""
+ x_wires, y_wires, output_wires, mod, work_wires = (
+ [0, 1, 2],
+ [3, 5],
+ [6, 8],
+ 3,
+ [9, 10],
+ )
+ multiplier_decomposition = OutMultiplier(
+ x_wires, y_wires, output_wires, mod, work_wires
+ ).compute_decomposition(x_wires, y_wires, output_wires, mod, work_wires)
+ op_list = []
+ if mod != 2 ** len(output_wires):
+ qft_output_wires = work_wires[:1] + output_wires
+ work_wire = work_wires[1:]
+ else:
+ qft_output_wires = output_wires
+ work_wire = None
+ op_list.append(qml.QFT(wires=qft_output_wires))
+ op_list.append(
+ qml.ControlledSequence(
+ qml.ControlledSequence(
+ qml.PhaseAdder(1, qft_output_wires, mod, work_wire), control=x_wires
+ ),
+ control=y_wires,
+ )
+ )
+ op_list.append(qml.adjoint(qml.QFT)(wires=qft_output_wires))
+
+ for op1, op2 in zip(multiplier_decomposition, op_list):
+ qml.assert_equal(op1, op2)
+
+ @pytest.mark.jax
+ def test_jit_compatible(self):
+ """Test that the template is compatible with the JIT compiler."""
+
+ import jax
+
+ jax.config.update("jax_enable_x64", True)
+
+ x, y = 2, 3
+ x_list = [1, 0]
+ y_list = [1, 1]
+ mod = 12
+ x_wires = [0, 1]
+ y_wires = [2, 3]
+ output_wires = [6, 7, 8, 9]
+ work_wires = [5, 10]
+ dev = qml.device("default.qubit", shots=1)
+
+ @jax.jit
+ @qml.qnode(dev)
+ def circuit():
+ qml.BasisEmbedding(x_list, wires=x_wires)
+ qml.BasisEmbedding(y_list, wires=y_wires)
+ OutMultiplier(x_wires, y_wires, output_wires, mod, work_wires)
+ return qml.sample(wires=output_wires)
+
+ assert jax.numpy.allclose(
+ sum(bit * (2**i) for i, bit in enumerate(reversed(circuit()))), (x * y) % mod
+ )