diff --git a/doc/releases/changelog-0.36.0.md b/doc/releases/changelog-0.36.0.md
index 17d8775df55..2798cba12aa 100644
--- a/doc/releases/changelog-0.36.0.md
+++ b/doc/releases/changelog-0.36.0.md
@@ -722,6 +722,10 @@

 Bug fixes 🐛

+* Patched the `QNode` so that `parameter-shift` is considered the best differentiation method
+  on `lightning.qubit` if `qml.metric_tensor` is in the transform program.
+  [(#5624)](https://github.com/PennyLaneAI/pennylane/pull/5624)
+
 * Stopped printing the ID of `qcut.MeasureNode` and `qcut.PrepareNode` in tape drawing.
   [(#5613)](https://github.com/PennyLaneAI/pennylane/pull/5613)
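For illustration, a minimal sketch of the behavior the changelog entry above describes; the device size, circuit, and parameter values here are invented for the example, not taken from the patch:

    import pennylane as qml
    from pennylane import numpy as np

    dev = qml.device("lightning.qubit", wires=2)

    @qml.qnode(dev, diff_method="best")
    def circuit(weights):
        qml.RX(weights[0], wires=0)
        qml.CNOT(wires=[0, 1])
        qml.RY(weights[1], wires=1)
        return qml.expval(qml.PauliZ(1))

    weights = np.array([0.1, 0.2], requires_grad=True)

    # Applying the transform puts qml.metric_tensor into the QNode's transform
    # program; the patched _update_gradient_fn then resolves "best" to
    # "parameter-shift" instead of lightning's adjoint method.
    mt = qml.metric_tensor(circuit)(weights)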
diff --git a/pennylane/gradients/metric_tensor.py b/pennylane/gradients/metric_tensor.py
index 62e16562475..fbde615c029 100644
--- a/pennylane/gradients/metric_tensor.py
+++ b/pennylane/gradients/metric_tensor.py
@@ -469,19 +469,14 @@ def _metric_tensor_cov_matrix(tape, argnum, diag_approx):  # pylint: disable=too
         # Create a quantum tape with all operations
         # prior to the parametrized layer, and the rotations
         # to measure in the basis of the parametrized layer generators.
-        with qml.queuing.AnnotatedQueue() as layer_q:
-            for op in queue:
-                # TODO: Maybe there are gates that do not affect the
-                # generators of interest and thus need not be applied.
-                qml.apply(op)
+        # TODO: Maybe there are gates that do not affect the
+        # generators of interest and thus need not be applied.

-            for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
-                if param_in_argnum:
-                    o.diagonalizing_gates()
-
-            qml.probs(wires=tape.wires)
+        for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
+            if param_in_argnum:
+                queue.extend(o.diagonalizing_gates())

-        layer_tape = qml.tape.QuantumScript.from_queue(layer_q)
+        layer_tape = qml.tape.QuantumScript(queue, [qml.probs(wires=tape.wires)], shots=tape.shots)
         metric_tensor_tapes.append(layer_tape)

     def processing_fn(probs):
@@ -573,7 +568,7 @@ def _get_gen_op(op, allow_nonunitary, aux_wire):
         ) from e


-def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
+def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire, shots):
     r"""Obtain the tapes for the first term of all tensor entries
     belonging to an off-diagonal block.
@@ -610,23 +605,16 @@ def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
     for diffed_op_j, par_idx_j in zip(layer_j.ops, layer_j.param_inds):
         gen_op_j = _get_gen_op(WrappedObj(diffed_op_j), allow_nonunitary, aux_wire)

-        with qml.queuing.AnnotatedQueue() as q:
-            # Initialize auxiliary wire
-            qml.Hadamard(wires=aux_wire)
-            # Apply backward cone of first layer
-            for op in layer_i.pre_ops:
-                qml.apply(op)
-            # Controlled-generator operation of first diff'ed op
-            qml.apply(gen_op_i)
-            # Apply first layer and operations between layers
-            for op in ops_between_cgens:
-                qml.apply(op)
-            # Controlled-generator operation of second diff'ed op
-            qml.apply(gen_op_j)
-            # Measure X on auxiliary wire
-            qml.expval(qml.X(aux_wire))
-
-        tapes.append(qml.tape.QuantumScript.from_queue(q))
+        ops = [
+            qml.Hadamard(wires=aux_wire),
+            *layer_i.pre_ops,
+            gen_op_i,
+            *ops_between_cgens,
+            gen_op_j,
+        ]
+        new_tape = qml.tape.QuantumScript(ops, [qml.expval(qml.X(aux_wire))], shots=shots)
+
+        tapes.append(new_tape)

         # Memorize to which metric entry this tape belongs
         ids.append((par_idx_i, par_idx_j))
@@ -707,7 +695,9 @@ def _metric_tensor_hadamard(
             block_sizes.append(len(layer_i.param_inds))

         for layer_j in layers[idx_i + 1 :]:
-            _tapes, _ids = _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire)
+            _tapes, _ids = _get_first_term_tapes(
+                layer_i, layer_j, allow_nonunitary, aux_wire, shots=tape.shots
+            )
             first_term_tapes.extend(_tapes)
             ids.extend(_ids)
diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py
index 7941fa89fe3..2cf3f89ec42 100644
--- a/pennylane/workflow/qnode.py
+++ b/pennylane/workflow/qnode.py
@@ -527,9 +527,9 @@ def __init__(
         self.gradient_kwargs = {}
         self._tape_cached = False

+        self._transform_program = qml.transforms.core.TransformProgram()
         self._update_gradient_fn()
         functools.update_wrapper(self, func)
-        self._transform_program = qml.transforms.core.TransformProgram()

     def __copy__(self):
         copied_qnode = QNode.__new__(QNode)
@@ -592,8 +592,17 @@ def _update_gradient_fn(self, shots=None, tape=None):
             return
         if tape is None and shots:
             tape = qml.tape.QuantumScript([], [], shots=shots)
+
+        diff_method = self.diff_method
+        if (
+            self.device.name == "lightning.qubit"
+            and qml.metric_tensor in self.transform_program
+            and self.diff_method == "best"
+        ):
+            diff_method = "parameter-shift"
+
         self.gradient_fn, self.gradient_kwargs, self.device = self.get_gradient_fn(
-            self._original_device, self.interface, self.diff_method, tape=tape
+            self._original_device, self.interface, diff_method, tape=tape
         )
         self.gradient_kwargs.update(self._user_gradient_kwargs or {})
@@ -714,6 +723,7 @@ def get_best_method(device, interface, tape=None):
         """
         config = _make_execution_config(None, "best")
         if isinstance(device, qml.devices.Device):
+
             if device.supports_derivatives(config, circuit=tape):
                 new_config = device.preprocess(config)[1]
                 return new_config.gradient_method, {}, device
diff --git a/tests/gradients/core/test_jvp.py b/tests/gradients/core/test_jvp.py
index 7445b5f8a2b..54bdb051572 100644
--- a/tests/gradients/core/test_jvp.py
+++ b/tests/gradients/core/test_jvp.py
@@ -284,6 +284,7 @@ def test_dtype_jax(self, dtype1, dtype2):
         determined by the dtype of the dy."""
         import jax

+        jax.config.update("jax_enable_x64", True)
         dtype = dtype1
         dtype1 = getattr(jax.numpy, dtype1)
         dtype2 = getattr(jax.numpy, dtype2)
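The metric_tensor.py hunks above replace queue-based tape recording with direct `QuantumScript` construction, which is also what allows `tape.shots` to be forwarded into the generated tapes; likewise, the reordering in `QNode.__init__` is needed because `_update_gradient_fn` now inspects `self.transform_program`, so the program must exist before it runs. A minimal sketch of the old and new tape-construction patterns, with operations and a shot count invented for the example:

    import pennylane as qml

    ops = [qml.Hadamard(wires=0), qml.CNOT(wires=[0, 1])]

    # Old pattern: record operations into an AnnotatedQueue, then build the tape.
    with qml.queuing.AnnotatedQueue() as q:
        for op in ops:
            qml.apply(op)
        qml.probs(wires=[0, 1])
    queued_tape = qml.tape.QuantumScript.from_queue(q)

    # New pattern: pass operations and measurements directly; shots can be
    # given explicitly, which is how the patch propagates tape.shots.
    direct_tape = qml.tape.QuantumScript(ops, [qml.probs(wires=[0, 1])], shots=100)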
diff --git a/tests/gradients/core/test_metric_tensor.py b/tests/gradients/core/test_metric_tensor.py
index bd234062736..01b43bc8177 100644
--- a/tests/gradients/core/test_metric_tensor.py
+++ b/tests/gradients/core/test_metric_tensor.py
@@ -913,7 +913,7 @@ def test_no_trainable_params_tape(self):
         mt_tapes, post_processing = qml.metric_tensor(tape)
         res = post_processing(qml.execute(mt_tapes, dev, None))

-        assert mt_tapes == []
+        assert mt_tapes == []  # pylint: disable=use-implicit-booleaness-not-comparison
         assert res == ()

@@ -1091,8 +1091,13 @@ def qnode(*params):

     def mt(*params):
         state = qnode(*params)
-        rqnode = lambda *params: np.real(qnode(*params))
-        iqnode = lambda *params: np.imag(qnode(*params))
+
+        def rqnode(*params):
+            return np.real(qnode(*params))
+
+        def iqnode(*params):
+            return np.imag(qnode(*params))
+
         rjac = qml.jacobian(rqnode)(*params)
         ijac = qml.jacobian(iqnode)(*params)

@@ -1125,9 +1130,11 @@ class TestFullMetricTensor:
     @pytest.mark.autograd
     @pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
     @pytest.mark.parametrize("interface", ["auto", "autograd"])
-    def test_correct_output_autograd(self, ansatz, params, interface):
+    @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+    def test_correct_output_autograd(self, dev_name, ansatz, params, interface):
+
         expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
-        dev = qml.device("default.qubit.autograd", wires=self.num_wires + 1)
+        dev = qml.device(dev_name, wires=self.num_wires + 1)

         @qml.qnode(dev, interface=interface)
         def circuit(*params):
@@ -1145,14 +1152,20 @@ def circuit(*params):
     @pytest.mark.jax
     @pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
     @pytest.mark.parametrize("interface", ["auto", "jax"])
-    def test_correct_output_jax(self, ansatz, params, interface):
+    @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+    def test_correct_output_jax(self, dev_name, ansatz, params, interface):
         import jax
         from jax import numpy as jnp

+        if ansatz == fubini_ansatz2:
+            pytest.xfail("Issue involving trainable indices to be resolved.")
+        if ansatz == fubini_ansatz3 and dev_name == "lightning.qubit":
+            pytest.xfail("Issue involving trainable_params to be resolved.")
+
         jax.config.update("jax_enable_x64", True)

         expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
-        dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
+        dev = qml.device(dev_name, wires=self.num_wires + 1)

         params = tuple(jnp.array(p) for p in params)
@@ -1176,10 +1189,11 @@ def circuit(*params):
     @pytest.mark.jax
     @pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
     @pytest.mark.parametrize("interface", ["auto", "jax"])
-    def test_jax_argnum_error(self, ansatz, params, interface):
+    @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+    def test_jax_argnum_error(self, dev_name, ansatz, params, interface):
         from jax import numpy as jnp

-        dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
+        dev = qml.device(dev_name, wires=self.num_wires + 1)

         params = tuple(jnp.array(p) for p in params)
@@ -1198,11 +1212,12 @@ def circuit(*params):
     @pytest.mark.torch
     @pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
     @pytest.mark.parametrize("interface", ["auto", "torch"])
-    def test_correct_output_torch(self, ansatz, params, interface):
+    @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
+    def test_correct_output_torch(self, dev_name, ansatz, params, interface):
         import torch

         expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
qml.device("default.qubit.torch", wires=self.num_wires + 1) + dev = qml.device(dev_name, wires=self.num_wires + 1) params = tuple(torch.tensor(p, dtype=torch.float64, requires_grad=True) for p in params) @@ -1222,11 +1237,12 @@ def circuit(*params): @pytest.mark.tf @pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params)) @pytest.mark.parametrize("interface", ["auto", "tf"]) - def test_correct_output_tf(self, ansatz, params, interface): + @pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit")) + def test_correct_output_tf(self, dev_name, ansatz, params, interface): import tensorflow as tf expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params) - dev = qml.device("default.qubit.tf", wires=self.num_wires + 1) + dev = qml.device(dev_name, wires=self.num_wires + 1) params = tuple(tf.Variable(p, dtype=tf.float64) for p in params) @@ -1254,17 +1270,18 @@ def diffability_ansatz_0(weights, wires=None): qml.RZ(weights[2], wires=1) -expected_diag_jac_0 = lambda weights: np.array( - [ - [0, 0, 0], - [0, 0, 0], +def expected_diag_jac_0(weights): + return np.array( [ - np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2, - np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2, - 0, - ], - ] -) + [0, 0, 0], + [0, 0, 0], + [ + np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2, + np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2, + 0, + ], + ] + ) def diffability_ansatz_1(weights, wires=None): @@ -1275,17 +1292,18 @@ def diffability_ansatz_1(weights, wires=None): qml.RZ(weights[2], wires=1) -expected_diag_jac_1 = lambda weights: np.array( - [ - [0, 0, 0], - [-np.sin(2 * weights[0]) / 4, 0, 0], +def expected_diag_jac_1(weights): + return np.array( [ - np.cos(weights[0]) * np.cos(weights[1]) ** 2 * np.sin(weights[0]) / 2, - np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4, - 0, - ], - ] -) + [0, 0, 0], + [-np.sin(2 * weights[0]) / 4, 0, 0], + [ + np.cos(weights[0]) * np.cos(weights[1]) ** 2 * np.sin(weights[0]) / 2, + np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4, + 0, + ], + ] + ) def diffability_ansatz_2(weights, wires=None): @@ -1296,17 +1314,19 @@ def diffability_ansatz_2(weights, wires=None): qml.RZ(weights[2], wires=1) -expected_diag_jac_2 = lambda weights: np.array( - [ - [0, 0, 0], - [0, 0, 0], +def expected_diag_jac_2(weights): + return np.array( [ - np.cos(weights[1]) ** 2 * np.sin(2 * weights[0]) / 4, - np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4, - 0, - ], - ] -) + [0, 0, 0], + [0, 0, 0], + [ + np.cos(weights[1]) ** 2 * np.sin(2 * weights[0]) / 4, + np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4, + 0, + ], + ] + ) + weights_diff = np.array([0.432, 0.12, -0.292], requires_grad=True) @@ -1466,7 +1486,9 @@ def test_autograd(self, diff_method, tol, ansatz, weights, interface): def cost_full(*weights): return np.array(qml.metric_tensor(qnode, approx=None)(*weights)) - _cost_full = lambda *weights: np.array(autodiff_metric_tensor(ansatz, 3)(*weights)) + def _cost_full(*weights): + return np.array(autodiff_metric_tensor(ansatz, 3)(*weights)) + _c = _cost_full(*weights) c = cost_full(*weights) assert all(