Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lightning qubit uses parameter shift if metric tensor applied #5624

Merged
merged 16 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.36.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@

<h3>Bug fixes 🐛</h3>

* Patches the QNode so that parameter-shift will be considered best with lightning if
`qml.metric_tensor` is in the transform program.
trbromley marked this conversation as resolved.
Show resolved Hide resolved
[(#5624)](https://github.com/PennyLaneAI/pennylane/pull/5624)

* Stopped printing the ID of `qcut.MeasureNode` and `qcut.PrepareNode` in tape drawing.
[(#5613)](https://github.com/PennyLaneAI/pennylane/pull/5613)

Expand Down
50 changes: 20 additions & 30 deletions pennylane/gradients/metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,19 +469,14 @@ def _metric_tensor_cov_matrix(tape, argnum, diag_approx): # pylint: disable=too
# Create a quantum tape with all operations
# prior to the parametrized layer, and the rotations
# to measure in the basis of the parametrized layer generators.
with qml.queuing.AnnotatedQueue() as layer_q:
for op in queue:
# TODO: Maybe there are gates that do not affect the
# generators of interest and thus need not be applied.
qml.apply(op)
# TODO: Maybe there are gates that do not affect the
# generators of interest and thus need not be applied.

for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
if param_in_argnum:
o.diagonalizing_gates()

qml.probs(wires=tape.wires)
for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
if param_in_argnum:
queue.extend(o.diagonalizing_gates())

layer_tape = qml.tape.QuantumScript.from_queue(layer_q)
layer_tape = qml.tape.QuantumScript(queue, [qml.probs(wires=tape.wires)], shots=tape.shots)
metric_tensor_tapes.append(layer_tape)

def processing_fn(probs):
Expand Down Expand Up @@ -573,7 +568,7 @@ def _get_gen_op(op, allow_nonunitary, aux_wire):
) from e


def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire, shots):
r"""Obtain the tapes for the first term of all tensor entries
belonging to an off-diagonal block.

Expand Down Expand Up @@ -610,23 +605,16 @@ def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
for diffed_op_j, par_idx_j in zip(layer_j.ops, layer_j.param_inds):
gen_op_j = _get_gen_op(WrappedObj(diffed_op_j), allow_nonunitary, aux_wire)

with qml.queuing.AnnotatedQueue() as q:
# Initialize auxiliary wire
qml.Hadamard(wires=aux_wire)
# Apply backward cone of first layer
for op in layer_i.pre_ops:
qml.apply(op)
# Controlled-generator operation of first diff'ed op
qml.apply(gen_op_i)
# Apply first layer and operations between layers
for op in ops_between_cgens:
qml.apply(op)
# Controlled-generator operation of second diff'ed op
qml.apply(gen_op_j)
# Measure X on auxiliary wire
qml.expval(qml.X(aux_wire))

tapes.append(qml.tape.QuantumScript.from_queue(q))
ops = [
qml.Hadamard(wires=aux_wire),
*layer_i.pre_ops,
gen_op_i,
*ops_between_cgens,
gen_op_j,
]
new_tape = qml.tape.QuantumScript(ops, [qml.expval(qml.X(aux_wire))], shots=shots)

tapes.append(new_tape)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
# Memorize to which metric entry this tape belongs
ids.append((par_idx_i, par_idx_j))

Expand Down Expand Up @@ -707,7 +695,9 @@ def _metric_tensor_hadamard(
block_sizes.append(len(layer_i.param_inds))

for layer_j in layers[idx_i + 1 :]:
_tapes, _ids = _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire)
_tapes, _ids = _get_first_term_tapes(
layer_i, layer_j, allow_nonunitary, aux_wire, shots=tape.shots
)
first_term_tapes.extend(_tapes)
ids.extend(_ids)

Expand Down
14 changes: 12 additions & 2 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,9 @@ def __init__(
self.gradient_kwargs = {}
self._tape_cached = False

self._transform_program = qml.transforms.core.TransformProgram()
self._update_gradient_fn()
functools.update_wrapper(self, func)
self._transform_program = qml.transforms.core.TransformProgram()

def __copy__(self):
copied_qnode = QNode.__new__(QNode)
Expand Down Expand Up @@ -592,8 +592,17 @@ def _update_gradient_fn(self, shots=None, tape=None):
return
if tape is None and shots:
tape = qml.tape.QuantumScript([], [], shots=shots)

diff_method = self.diff_method
if (
self.device.name == "lightning.qubit"
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
and qml.metric_tensor in self.transform_program
and self.diff_method == "best"
):
diff_method = "parameter-shift"

self.gradient_fn, self.gradient_kwargs, self.device = self.get_gradient_fn(
self._original_device, self.interface, self.diff_method, tape=tape
self._original_device, self.interface, diff_method, tape=tape
)
self.gradient_kwargs.update(self._user_gradient_kwargs or {})

Expand Down Expand Up @@ -714,6 +723,7 @@ def get_best_method(device, interface, tape=None):
"""
config = _make_execution_config(None, "best")
if isinstance(device, qml.devices.Device):

if device.supports_derivatives(config, circuit=tape):
new_config = device.preprocess(config)[1]
return new_config.gradient_method, {}, device
Expand Down
1 change: 1 addition & 0 deletions tests/gradients/core/test_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def test_dtype_jax(self, dtype1, dtype2):
determined by the dtype of the dy."""
import jax

jax.config.update("jax_enable_x64", True)
dtype = dtype1
dtype1 = getattr(jax.numpy, dtype1)
dtype2 = getattr(jax.numpy, dtype2)
Expand Down
110 changes: 66 additions & 44 deletions tests/gradients/core/test_metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def test_no_trainable_params_tape(self):
mt_tapes, post_processing = qml.metric_tensor(tape)
res = post_processing(qml.execute(mt_tapes, dev, None))

assert mt_tapes == []
assert mt_tapes == [] # pylint: disable=use-implicit-booleaness-not-comparison
assert res == ()


Expand Down Expand Up @@ -1091,8 +1091,13 @@ def qnode(*params):

def mt(*params):
state = qnode(*params)
rqnode = lambda *params: np.real(qnode(*params))
iqnode = lambda *params: np.imag(qnode(*params))

def rqnode(*params):
return np.real(qnode(*params))

def iqnode(*params):
return np.imag(qnode(*params))

rjac = qml.jacobian(rqnode)(*params)
ijac = qml.jacobian(iqnode)(*params)

Expand Down Expand Up @@ -1125,9 +1130,11 @@ class TestFullMetricTensor:
@pytest.mark.autograd
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "autograd"])
def test_correct_output_autograd(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_autograd(self, dev_name, ansatz, params, interface):

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.autograd", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

@qml.qnode(dev, interface=interface)
def circuit(*params):
Expand All @@ -1145,14 +1152,20 @@ def circuit(*params):
@pytest.mark.jax
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
def test_correct_output_jax(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_jax(self, dev_name, ansatz, params, interface):
import jax
from jax import numpy as jnp

if ansatz == fubini_ansatz2:
pytest.xfail("Issue involving trainable indices to be resolved.")
if ansatz == fubini_ansatz3 and dev_name == "lightning.qubit":
pytest.xfail("Issue invovling trainable_params to be resolved.")

jax.config.update("jax_enable_x64", True)

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(jnp.array(p) for p in params)

Expand All @@ -1176,10 +1189,11 @@ def circuit(*params):
@pytest.mark.jax
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
def test_jax_argnum_error(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_jax_argnum_error(self, dev_name, ansatz, params, interface):
from jax import numpy as jnp

dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(jnp.array(p) for p in params)

Expand All @@ -1198,11 +1212,12 @@ def circuit(*params):
@pytest.mark.torch
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "torch"])
def test_correct_output_torch(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_torch(self, dev_name, ansatz, params, interface):
import torch

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.torch", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(torch.tensor(p, dtype=torch.float64, requires_grad=True) for p in params)

Expand All @@ -1222,11 +1237,12 @@ def circuit(*params):
@pytest.mark.tf
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "tf"])
def test_correct_output_tf(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_tf(self, dev_name, ansatz, params, interface):
import tensorflow as tf

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.tf", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(tf.Variable(p, dtype=tf.float64) for p in params)

Expand Down Expand Up @@ -1254,17 +1270,18 @@ def diffability_ansatz_0(weights, wires=None):
qml.RZ(weights[2], wires=1)


expected_diag_jac_0 = lambda weights: np.array(
[
[0, 0, 0],
[0, 0, 0],
def expected_diag_jac_0(weights):
return np.array(
[
np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
0,
],
]
)
[0, 0, 0],
[0, 0, 0],
[
np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
0,
],
]
)


def diffability_ansatz_1(weights, wires=None):
Expand All @@ -1275,17 +1292,18 @@ def diffability_ansatz_1(weights, wires=None):
qml.RZ(weights[2], wires=1)


expected_diag_jac_1 = lambda weights: np.array(
[
[0, 0, 0],
[-np.sin(2 * weights[0]) / 4, 0, 0],
def expected_diag_jac_1(weights):
return np.array(
[
np.cos(weights[0]) * np.cos(weights[1]) ** 2 * np.sin(weights[0]) / 2,
np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
0,
],
]
)
[0, 0, 0],
[-np.sin(2 * weights[0]) / 4, 0, 0],
[
np.cos(weights[0]) * np.cos(weights[1]) ** 2 * np.sin(weights[0]) / 2,
np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
0,
],
]
)


def diffability_ansatz_2(weights, wires=None):
Expand All @@ -1296,17 +1314,19 @@ def diffability_ansatz_2(weights, wires=None):
qml.RZ(weights[2], wires=1)


expected_diag_jac_2 = lambda weights: np.array(
[
[0, 0, 0],
[0, 0, 0],
def expected_diag_jac_2(weights):
return np.array(
[
np.cos(weights[1]) ** 2 * np.sin(2 * weights[0]) / 4,
np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
0,
],
]
)
[0, 0, 0],
[0, 0, 0],
[
np.cos(weights[1]) ** 2 * np.sin(2 * weights[0]) / 4,
np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
0,
],
]
)


weights_diff = np.array([0.432, 0.12, -0.292], requires_grad=True)

Expand Down Expand Up @@ -1466,7 +1486,9 @@ def test_autograd(self, diff_method, tol, ansatz, weights, interface):
def cost_full(*weights):
return np.array(qml.metric_tensor(qnode, approx=None)(*weights))

_cost_full = lambda *weights: np.array(autodiff_metric_tensor(ansatz, 3)(*weights))
def _cost_full(*weights):
return np.array(autodiff_metric_tensor(ansatz, 3)(*weights))

_c = _cost_full(*weights)
c = cost_full(*weights)
assert all(
Expand Down
Loading