Skip to content

Commit

Permalink
quantum_fisher transform on non-default.qubit devices (#5423)
Browse files Browse the repository at this point in the history
**Context:**

The `qml.qinfo.quantum_fisher` transform was failing on
non-`defualt.qubit` devices:

```
import pennylane as qml
from pennylane import numpy as np

from sklearn import datasets as ds

NUM_WIRES = 4

def get_circuit(data, parameters):
    qml.IQPEmbedding(data, wires=range(NUM_WIRES), n_repeats=1)
    for i in range(NUM_WIRES):
        qml.RX(parameters[i], wires=i)
    for j in range(NUM_WIRES-1):
        qml.CNOT(wires=[j, j+1])
    return qml.expval(qml.PauliZ(0))

# use different devices here, i.e. lightning.qubit, default.qubit, lightning.kokkos
dev = qml.device("default.mixed", wires=NUM_WIRES)

def qfim(X_train, parameters):
    circuit = qml.QNode(get_circuit, dev)
    data = np.array(X_train[0], requires_grad=False)
    return qml.qinfo.transforms.quantum_fisher(circuit)(data, parameters)

X = ds.load_iris().data
parameters = np.random.random(size=NUM_WIRES, requires_grad=True)
print(qfim(X, parameters))
```

**Description of the Change:**

Use the `metric_tensor` instead of the `adjoint_metric_tensor` if the
device is not `default.qubit`.

**Benefits:**

`quantum_fisher` works with more devices.

**Possible Drawbacks:**

**Related GitHub Issues:**

Fixes #5381  [sc-58882]
  • Loading branch information
albi3ro authored Mar 20, 2024
1 parent 186faaf commit 23f1a7a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@

<h3>Bug fixes 🐛</h3>

* `qml.qinfo.quantum_fisher` now works with non-`default.qubit` devices.
[(#5423)](https://github.com/PennyLaneAI/pennylane/pull/5423)

* We no longer perform unwanted dtype promotion in the `pauli_rep` of `SProd` instances when using tensorflow.
[(#5246)](https://github.com/PennyLaneAI/pennylane/pull/5246)

Expand Down
2 changes: 1 addition & 1 deletion pennylane/qinfo/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def circ(params):
"""

if device.shots and isinstance(device, (DefaultQubitLegacy, DefaultQubit)):
if device.shots or not isinstance(device, (DefaultQubitLegacy, DefaultQubit)):
tapes, processing_fn = metric_tensor(tape, *args, **kwargs)

def processing_fn_multiply(res):
Expand Down
11 changes: 9 additions & 2 deletions tests/qinfo/test_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,19 @@ def circ(params):
res = qml.qinfo.classical_fisher(circ)(params)
assert np.allclose(res, n_wires * np.ones((n_params, n_params)), atol=1)

def test_quantum_fisher_info(self):
@pytest.mark.parametrize(
"dev",
(
qml.device("default.qubit"),
qml.device("default.mixed", wires=3),
qml.device("lightning.qubit", wires=3),
),
)
def test_quantum_fisher_info(self, dev):
"""Integration test of quantum fisher information matrix CFIM. This is just calling ``qml.metric_tensor`` or ``qml.adjoint_metric_tensor`` and multiplying by a factor of 4"""

n_wires = 2

dev = qml.device("default.qubit", wires=n_wires)
dev_hard = qml.device("default.qubit", wires=n_wires + 1, shots=1000)

def qfunc(params):
Expand Down

0 comments on commit 23f1a7a

Please sign in to comment.