diff --git a/doc/_static/templates/subroutines/ampamp.png b/doc/_static/templates/subroutines/ampamp.png
new file mode 100644
index 00000000000..cda07abc358
Binary files /dev/null and b/doc/_static/templates/subroutines/ampamp.png differ
diff --git a/doc/introduction/templates.rst b/doc/introduction/templates.rst
index 7cd8625f3d4..fbc088ef98c 100644
--- a/doc/introduction/templates.rst
+++ b/doc/introduction/templates.rst
@@ -215,6 +215,10 @@ Other useful templates which do not belong to the previous categories can be fou
:description: :doc:`Reflection Operator <../code/api/pennylane.Reflection>`
:figure: _static/templates/subroutines/reflection.png
+.. gallery-item::
+ :description: :doc:`Amplitude Amplification <../code/api/pennylane.AmplitudeAmplification>`
+ :figure: _static/templates/subroutines/ampamp.png
+
.. gallery-item::
:description: :doc:`Interferometer <../code/api/pennylane.Interferometer>`
:figure: _static/templates/subroutines/interferometer.png
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index d8e73a0bee4..5482a23b335 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -58,16 +58,8 @@
* Added new function `qml.operation.convert_to_legacy_H` to convert `Sum`, `SProd`, and `Prod` to `Hamiltonian` instances.
[(#5309)](https://github.com/PennyLaneAI/pennylane/pull/5309)
-
Improvements ðŸ›
-
-* The `qml.is_commuting` function now accepts `Sum`, `SProd`, and `Prod` instances.
- [(#5351)](https://github.com/PennyLaneAI/pennylane/pull/5351)
-
-* Operators can now be left multiplied `x * op` by numpy arrays.
- [(#5361)](https://github.com/PennyLaneAI/pennylane/pull/5361)
-
* Create the `qml.Reflection` operator, useful for amplitude amplification and its variants.
- [(##5159)](https://github.com/PennyLaneAI/pennylane/pull/5159)
+ [(#5159)](https://github.com/PennyLaneAI/pennylane/pull/5159)
```python
@qml.prod
@@ -94,6 +86,44 @@
>>> circuit()
tensor([1.+6.123234e-17j, 0.-6.123234e-17j], requires_grad=True)
```
+
+* The `qml.AmplitudeAmplification` operator is introduced, which is a high-level interface for amplitude amplification and its variants.
+ [(#5160)](https://github.com/PennyLaneAI/pennylane/pull/5160)
+
+ ```python
+ @qml.prod
+ def generator(wires):
+ for wire in wires:
+ qml.Hadamard(wires=wire)
+
+ U = generator(wires=range(3))
+ O = qml.FlipSign(2, wires=range(3))
+
+ dev = qml.device("default.qubit")
+
+ @qml.qnode(dev)
+ def circuit():
+
+ generator(wires=range(3))
+ qml.AmplitudeAmplification(U, O, iters=5, fixed_point=True, work_wire=3)
+
+ return qml.probs(wires=range(3))
+
+ ```
+
+ ```pycon
+ >>> print(np.round(circuit(), 3))
+ [0.013, 0.013, 0.91, 0.013, 0.013, 0.013, 0.013, 0.013]
+
+ ```
+
+Improvements ðŸ›
+
+* The `qml.is_commuting` function now accepts `Sum`, `SProd`, and `Prod` instances.
+ [(#5351)](https://github.com/PennyLaneAI/pennylane/pull/5351)
+
+* Operators can now be left multiplied `x * op` by numpy arrays.
+ [(#5361)](https://github.com/PennyLaneAI/pennylane/pull/5361)
* The `molecular_hamiltonian` function calls `PySCF` directly when `method='pyscf'` is selected.
[(#5118)](https://github.com/PennyLaneAI/pennylane/pull/5118)
diff --git a/pennylane/templates/subroutines/__init__.py b/pennylane/templates/subroutines/__init__.py
index 5f9c21e0dd5..2a9b6612567 100644
--- a/pennylane/templates/subroutines/__init__.py
+++ b/pennylane/templates/subroutines/__init__.py
@@ -40,3 +40,4 @@
from .trotter import TrotterProduct
from .aqft import AQFT
from .reflection import Reflection
+from .amplitude_amplification import AmplitudeAmplification
diff --git a/pennylane/templates/subroutines/amplitude_amplification.py b/pennylane/templates/subroutines/amplitude_amplification.py
new file mode 100644
index 00000000000..afbdedffe00
--- /dev/null
+++ b/pennylane/templates/subroutines/amplitude_amplification.py
@@ -0,0 +1,179 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This submodule contains the template for Amplitude Amplification.
+"""
+
+# pylint: disable-msg=too-many-arguments
+import numpy as np
+from pennylane.operation import Operation
+import pennylane as qml
+
+
+def _get_fixed_point_angles(iters, p_min):
+ """
+ Returns the angles needed for the fixed-point amplitude amplification algorithm.
+ The angles are computed using equation (11) of `arXiv:1409.3305v2 `__.
+ """
+
+ delta = np.sqrt(1 - p_min)
+ gamma = np.cos(np.arccos(1 / delta, dtype=np.complex128) / iters, dtype=np.complex128) ** -1
+
+ alphas = [
+ 2 * np.arctan(1 / (np.tan(2 * np.pi * j / iters) * np.sqrt(1 - gamma**2)))
+ for j in range(1, iters // 2 + 1)
+ ]
+ betas = [-alphas[-j] for j in range(1, iters // 2 + 1)]
+ return alphas[: iters // 2], betas[: iters // 2]
+
+
+class AmplitudeAmplification(Operation):
+ r"""Applies amplitude amplification.
+
+ Given a state :math:`|\Psi\rangle = \alpha |\phi\rangle + \beta|\phi^{\perp}\rangle`, this
+ subroutine amplifies the amplitude of the state :math:`|\phi\rangle` such that
+
+ .. math::
+
+ \text{A}(U, O)|\Psi\rangle \sim |\phi\rangle.
+
+ The implementation of the algorithm is based on [`arXiv:quant-ph/0005055 `__].
+ The template also unlocks advanced techniques such as fixed-point quantum search
+ [`arXiv:1409.3305 `__] and oblivious amplitude amplification
+ [`arXiv:1312.1414 `__], by reflecting on a subset of wires.
+
+ Args:
+ U (Operator): the operator that prepares the state :math:`|\Psi\rangle`
+ O (Operator): the oracle that flips the sign of the state :math:`|\phi\rangle` and does nothing to the state :math:`|\phi^{\perp}\rangle`
+ iters (int): the number of iterations of the amplitude amplification subroutine, default is ``1``
+ fixed_point (bool): whether to use the fixed-point amplitude amplification algorithm, default is ``False``
+ work_wire (int): the auxiliary wire to use for the fixed-point amplitude amplification algorithm, default is ``None``
+ reflection_wires (Wires): the wires to reflect on, default is the wires of ``U``
+ p_min (int): the lower bound for the probability of success in fixed-point amplitude amplification, default is ``0.9``
+
+ Raises:
+ ValueError: ``work_wire`` must be specified if ``fixed_point == True``.
+ ValueError: ``work_wire`` must be different from the wires of the oracle ``O``.
+
+ **Example**
+
+ Amplification of state :math:`|2\rangle` using Grover's algorithm with 3 qubits.
+ The state :math:`|\Psi\rangle` is constructed as a uniform superposition of basis states.
+
+ .. code-block::
+
+ @qml.prod
+ def generator(wires):
+ for wire in wires:
+ qml.Hadamard(wires=wire)
+
+ U = generator(wires=range(3))
+ O = qml.FlipSign(2, wires=range(3))
+
+ dev = qml.device("default.qubit")
+
+ @qml.qnode(dev)
+ def circuit():
+
+ generator(wires=range(3))
+ qml.AmplitudeAmplification(U, O, iters=5, fixed_point=True, work_wire=3)
+
+ return qml.probs(wires=range(3))
+
+ .. code-block:: pycon
+
+ >>> print(np.round(circuit(),3))
+ [0.013, 0.013, 0.91, 0.013, 0.013, 0.013, 0.013, 0.013]
+ """
+
+ def _flatten(self):
+ data = (self.hyperparameters["U"], self.hyperparameters["O"])
+ metadata = tuple(
+ (key, value) for key, value in self.hyperparameters.items() if key not in ["O", "U"]
+ )
+ return data, metadata
+
+ @classmethod
+ def _unflatten(cls, data, metadata):
+ U, O = (data[0], data[1])
+ hyperparams_dict = dict(metadata)
+ return cls(U, O, **hyperparams_dict)
+
+ def __init__(
+ self, U, O, iters=1, fixed_point=False, work_wire=None, p_min=0.9, reflection_wires=None
+ ):
+ self._name = "AmplitudeAmplification"
+ if reflection_wires is None:
+ reflection_wires = U.wires
+
+ if fixed_point and work_wire is None:
+ raise qml.wires.WireError("work_wire must be specified if fixed_point == True.")
+
+ if fixed_point and len(O.wires + qml.wires.Wires(work_wire)) == len(O.wires):
+ raise ValueError("work_wire must be different from the wires of O.")
+
+ if fixed_point:
+ wires = U.wires + qml.wires.Wires(work_wire)
+ else:
+ wires = U.wires
+
+ self.hyperparameters["U"] = U
+ self.hyperparameters["O"] = O
+ self.hyperparameters["iters"] = iters
+ self.hyperparameters["fixed_point"] = fixed_point
+ self.hyperparameters["work_wire"] = work_wire
+ self.hyperparameters["p_min"] = p_min
+ self.hyperparameters["reflection_wires"] = qml.wires.Wires(reflection_wires)
+
+ super().__init__(wires=wires)
+
+ # pylint:disable=arguments-differ
+ @staticmethod
+ def compute_decomposition(**kwargs):
+ U = kwargs["U"]
+ O = kwargs["O"]
+ iters = kwargs["iters"]
+ fixed_point = kwargs["fixed_point"]
+ work_wire = kwargs["work_wire"]
+ p_min = kwargs["p_min"]
+ reflection_wires = kwargs["reflection_wires"]
+
+ ops = []
+
+ if fixed_point:
+ alphas, betas = _get_fixed_point_angles(iters, p_min)
+
+ for iter in range(iters // 2):
+ ops.append(qml.Hadamard(wires=work_wire))
+ ops.append(qml.ctrl(O, control=work_wire))
+ ops.append(qml.Hadamard(wires=work_wire))
+ ops.append(qml.PhaseShift(betas[iter], wires=work_wire))
+ ops.append(qml.Hadamard(wires=work_wire))
+ ops.append(qml.ctrl(O, control=work_wire))
+ ops.append(qml.Hadamard(wires=work_wire))
+
+ ops.append(qml.Reflection(U, -alphas[iter], reflection_wires=reflection_wires))
+ else:
+ for _ in range(iters):
+ ops.append(O)
+ ops.append(qml.Reflection(U, np.pi, reflection_wires=reflection_wires))
+
+ return ops
+
+ def queue(self, context=qml.QueuingManager):
+ for op in [self.hyperparameters["U"], self.hyperparameters["O"]]:
+ context.remove(op)
+ context.append(self)
+ return self
diff --git a/tests/templates/test_subroutines/test_amplitude_amplification.py b/tests/templates/test_subroutines/test_amplitude_amplification.py
new file mode 100644
index 00000000000..aa46bd8b4be
--- /dev/null
+++ b/tests/templates/test_subroutines/test_amplitude_amplification.py
@@ -0,0 +1,335 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for the AmplitudeAmplification template.
+"""
+
+import pytest
+import numpy as np
+import pennylane as qml
+from pennylane.templates.subroutines.amplitude_amplification import _get_fixed_point_angles
+
+
+@qml.prod
+def generator(wires):
+ for wire in wires:
+ qml.Hadamard(wire)
+
+
+@qml.prod
+def oracle(items, wires):
+ for item in items:
+ qml.FlipSign(item, wires=wires)
+
+
+class TestInitialization:
+ """Test that AmplitudeAmplification initializes correctly."""
+
+ def test_error_none_wire(self):
+ """Test that an error is raised if work_wire is None and fixed_point is True."""
+
+ U = generator(wires=range(3))
+ O = oracle([0, 2], wires=range(3))
+
+ with pytest.raises(
+ qml.wires.WireError, match="work_wire must be specified if fixed_point == True."
+ ):
+ qml.AmplitudeAmplification(U, O, iters=3, fixed_point=True)
+
+ @pytest.mark.parametrize(
+ "wires, fixed_point, work_wire",
+ (
+ ([0, 1, 2], True, 2),
+ (["a", "b"], True, "a"),
+ ),
+ )
+ def test_error_wrong_work_wire(self, wires, fixed_point, work_wire):
+ """Test that an error is raised if work_wire is part of the O wires."""
+
+ U = generator(wires=wires)
+ O = oracle([0], wires=wires)
+
+ with pytest.raises(ValueError, match="work_wire must be different from the wires of O."):
+ qml.AmplitudeAmplification(U, O, iters=3, fixed_point=fixed_point, work_wire=work_wire)
+
+
+@pytest.mark.parametrize(
+ "n_wires, items, iters",
+ (
+ (3, [0, 2], 1),
+ (3, [1, 2], 2),
+ (5, [4, 5, 7, 12], 3),
+ (5, [0, 1, 2, 3, 4], 4),
+ ),
+)
+def test_compare_grover(n_wires, items, iters):
+ """Test that Grover's algorithm gives the same result with GroverOperator and AmplitudeAmplification."""
+ U = generator(wires=range(n_wires))
+ O = oracle(items, wires=range(n_wires))
+
+ dev = qml.device("default.qubit", wires=n_wires)
+
+ @qml.qnode(dev)
+ def circuit_amplitude_amplification():
+ generator(wires=range(n_wires))
+ qml.AmplitudeAmplification(U, O, iters)
+ return qml.probs(wires=range(n_wires))
+
+ @qml.qnode(dev)
+ def circuit_grover():
+ generator(wires=range(n_wires))
+
+ for _ in range(iters):
+ oracle(items, wires=range(n_wires))
+ qml.GroverOperator(wires=range(n_wires))
+
+ return qml.probs(wires=range(n_wires))
+
+ assert np.allclose(circuit_amplitude_amplification(), circuit_grover(), atol=1e-5)
+
+
+def test_default_lightning_devices():
+ """Test that AmplitudeAmplification executes with the default.qubit and lightning.qubit simulators."""
+
+ def circuit():
+ """Test circuit"""
+ qml.Hadamard(wires=0)
+ qml.Hadamard(wires=1)
+ qml.Hadamard(wires=2)
+
+ qml.AmplitudeAmplification(
+ generator(range(3)), oracle([0], range(3)), fixed_point=True, iters=3, work_wire=3
+ )
+ return qml.probs(wires=range(3))
+
+ dev1 = qml.device("default.qubit")
+ qnode1 = qml.QNode(circuit, dev1, interface=None)
+
+ res1 = qnode1()
+
+ dev2 = qml.device("lightning.qubit", wires=4)
+ qnode2 = qml.QNode(circuit, dev2)
+
+ res2 = qnode2()
+
+ assert np.allclose(res1, res2, atol=1e-5)
+
+
+class TestDifferentiability:
+ """Test that AmplitudeAmplification is differentiable"""
+
+ @staticmethod
+ def circuit(params):
+ qml.RY(params[0], wires=0)
+ qml.AmplitudeAmplification(
+ qml.RY(params[0], wires=0),
+ qml.RZ(params[1], wires=0),
+ iters=3,
+ fixed_point=True,
+ work_wire=3,
+ )
+
+ return qml.expval(qml.PauliZ(0))
+
+ # calculated numerically with finite diff method (h = 1e-5)
+ exp_grad = np.array([-0.88109663, -0.66429297])
+
+ params = np.array([0.9, 0.1])
+
+ @pytest.mark.autograd
+ def test_qnode_autograd(self):
+ """Test that the QNode executes with Autograd."""
+
+ dev = qml.device("default.qubit")
+ qnode = qml.QNode(self.circuit, dev, interface="autograd")
+
+ params = qml.numpy.array(self.params, requires_grad=True)
+ res = qml.grad(qnode)(params)
+ print(res)
+ assert qml.math.shape(res) == (2,)
+ assert np.allclose(res, self.exp_grad, atol=1e-5)
+
+ @pytest.mark.jax
+ @pytest.mark.parametrize("use_jit", [False, True])
+ @pytest.mark.parametrize("shots", [None, 50000])
+ def test_qnode_jax(self, shots, use_jit):
+ """Test that the QNode executes and is differentiable with JAX. The shots
+ argument controls whether autodiff or parameter-shift gradients are used."""
+ import jax
+
+ jax.config.update("jax_enable_x64", True)
+
+ dev = qml.device("default.qubit", shots=shots, seed=10)
+ diff_method = "backprop" if shots is None else "parameter-shift"
+ qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method)
+ if use_jit:
+ qnode = jax.jit(qnode)
+
+ params = jax.numpy.array(self.params)
+
+ jac_fn = jax.jacobian(qnode)
+ if use_jit:
+ jac_fn = jax.jit(jac_fn)
+
+ jac = jac_fn(params)
+ assert jac.shape == (2,)
+ assert np.allclose(jac, self.exp_grad, atol=0.01)
+
+ @pytest.mark.torch
+ @pytest.mark.parametrize("shots", [None, 50000])
+ def test_qnode_torch(self, shots):
+ """Test that the QNode executes and is differentiable with Torch. The shots
+ argument controls whether autodiff or parameter-shift gradients are used."""
+ import torch
+
+ dev = qml.device("default.qubit", shots=shots, seed=10)
+ diff_method = "backprop" if shots is None else "parameter-shift"
+ qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method)
+
+ params = torch.tensor(self.params, requires_grad=True)
+ jac = torch.autograd.functional.jacobian(qnode, params)
+ assert qml.math.shape(jac) == (2,)
+ assert qml.math.allclose(jac, self.exp_grad, atol=0.01)
+
+ @pytest.mark.tf
+ @pytest.mark.parametrize("shots", [None, 50000])
+ @pytest.mark.xfail(reason="tf gradient doesn't seem to be working, returns ()")
+ def test_qnode_tf(self, shots):
+ """Test that the QNode executes and is differentiable with TensorFlow. The shots
+ argument controls whether autodiff or parameter-shift gradients are used."""
+ import tensorflow as tf
+
+ dev = qml.device("default.qubit", shots=shots, seed=10)
+ diff_method = "backprop" if shots is None else "parameter-shift"
+ qnode = qml.QNode(self.circuit, dev, interface="tf", diff_method=diff_method)
+
+ params = tf.Variable(self.params)
+ with tf.GradientTape() as tape:
+ res = qnode(params)
+
+ jac = tape.gradient(res, params)
+ assert qml.math.shape(jac) == (8,)
+ assert qml.math.allclose(res, self.exp_grad, atol=0.001)
+
+
+def test_correct_queueing():
+ """Test that operations in a circuit containing AmplitudeAmplification are correctly queued"""
+ dev = qml.device("default.qubit")
+
+ @qml.qnode(dev)
+ def circuit1():
+ qml.Hadamard(wires=0)
+ qml.Hadamard(wires=1)
+ qml.Hadamard(wires=2)
+
+ qml.AmplitudeAmplification(generator(range(3)), oracle([0], range(3)))
+ return qml.state()
+
+ @qml.qnode(dev)
+ def circuit2():
+ generator(wires=[0, 1, 2])
+
+ qml.AmplitudeAmplification(generator(range(3)), oracle([0], range(3)))
+ return qml.state()
+
+ U = generator(wires=[0, 1, 2])
+ O = oracle([0], wires=[0, 1, 2])
+
+ @qml.qnode(dev)
+ def circuit3():
+ generator(wires=[0, 1, 2])
+
+ qml.AmplitudeAmplification(U=U, O=O)
+ return qml.state()
+
+ assert np.allclose(circuit1(), circuit2())
+ assert np.allclose(circuit1(), circuit3())
+
+
+# pylint: disable=protected-access
+def test_flatten_and_unflatten():
+ """Test the _flatten and _unflatten methods for AmplitudeAmplification."""
+
+ op = qml.AmplitudeAmplification(qml.RX(0.25, wires=0), qml.PauliZ(0))
+ data, metadata = op._flatten()
+
+ assert len(data) == 2
+ assert len(metadata) == 5
+
+ new_op = type(op)._unflatten(*op._flatten())
+ assert qml.equal(op, new_op)
+ assert op is not new_op
+
+ assert hash(metadata)
+
+
+def test_amplification():
+ """Test that AmplitudeAmplification amplifies a marked element."""
+
+ U = generator(wires=range(3))
+ O = oracle([2], wires=range(3))
+
+ dev = qml.device("default.qubit")
+
+ @qml.qnode(dev)
+ def circuit():
+ generator(wires=range(3))
+ qml.AmplitudeAmplification(U, O, iters=5, fixed_point=True, work_wire=3)
+
+ return qml.probs(wires=range(3))
+
+ res = np.round(circuit(), 3)
+
+ expected = np.array([0.013, 0.013, 0.91, 0.013, 0.013, 0.013, 0.013, 0.013])
+
+ assert np.allclose(res, expected)
+
+
+@pytest.mark.parametrize(("p_min"), [0.7, 0.8, 0.9])
+def test_p_min(p_min):
+ """Test that the p_min parameter works correctly."""
+
+ dev = qml.device("default.qubit")
+
+ U = generator(wires=range(4))
+ O = oracle([0], wires=range(4))
+
+ @qml.qnode(dev)
+ def circuit():
+ generator(wires=range(4))
+
+ qml.AmplitudeAmplification(U, O, fixed_point=True, work_wire=4, p_min=p_min, iters=11)
+
+ return qml.probs(wires=range(4))
+
+ assert circuit()[0] >= p_min
+
+
+@pytest.mark.parametrize(
+ "iters, p_min",
+ (
+ (4, 0.8),
+ (5, 0.9),
+ (6, 0.95),
+ ),
+)
+def test_fixed_point_angles_function(iters, p_min):
+ """Test that the _get_fixed_point_angles function works correctly."""
+
+ alphas, betas = _get_fixed_point_angles(iters, p_min)
+
+ assert np.all(alphas[:-1] > alphas[1:])
+ assert np.all(betas[:-1] > betas[1:])
+
+ assert np.allclose(betas, np.array([-alpha for alpha in reversed(alphas)]))