Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] Probabilities do not sum to one with Torch #5462

Merged
merged 21 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1a9a5d2
first possible fix (checking tests)
PietropaoloFrisoni Apr 2, 2024
1b515bc
changing import order
PietropaoloFrisoni Apr 2, 2024
d74f5eb
removing torch import
PietropaoloFrisoni Apr 2, 2024
2c1dcdb
moving change to `sampling.py`
PietropaoloFrisoni Apr 2, 2024
a903435
moving change to `sampling.py`
PietropaoloFrisoni Apr 2, 2024
4cb8b23
cleaning up
PietropaoloFrisoni Apr 2, 2024
4f3c925
updating changelog
PietropaoloFrisoni Apr 2, 2024
27c537e
Merge branch 'master' into bugfix/probs_do_not_sum_to_one_torch
PietropaoloFrisoni Apr 3, 2024
c2dc97d
removing unnecessary brackets
PietropaoloFrisoni Apr 3, 2024
a431763
Merge branch 'master' into bugfix/probs_do_not_sum_to_one_torch
PietropaoloFrisoni Apr 4, 2024
4e1bc12
renormalizing non-batched states as well for safety
PietropaoloFrisoni Apr 4, 2024
d11f2ba
Merge branch 'bugfix/probs_do_not_sum_to_one_torch' of https://github…
PietropaoloFrisoni Apr 4, 2024
d555bae
Merge branch 'master' into bugfix/probs_do_not_sum_to_one_torch
PietropaoloFrisoni Apr 8, 2024
a1e8453
adding unit test
PietropaoloFrisoni Apr 8, 2024
50a156a
cycling just once over `abs_diff`
PietropaoloFrisoni Apr 8, 2024
8e19148
Triggering CI
PietropaoloFrisoni Apr 8, 2024
a8f6cd2
Merge branch 'master' into bugfix/probs_do_not_sum_to_one_torch
PietropaoloFrisoni Apr 8, 2024
9767129
Merge branch 'master' into bugfix/probs_do_not_sum_to_one_torch
PietropaoloFrisoni Apr 8, 2024
424c93f
breaking the error case and the normal case up into two tests.
PietropaoloFrisoni Apr 8, 2024
9d4afe4
Merge branch 'master' into bugfix/probs_do_not_sum_to_one_torch
PietropaoloFrisoni Apr 8, 2024
a8116bc
Triggering CI
PietropaoloFrisoni Apr 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,9 @@

<h3>Bug fixes 🐛</h3>

* The probabilities now sum to one using the `torch` interface with `default_dtype` set to `torch.float32`.
[(#5462)](https://github.com/PennyLaneAI/pennylane/pull/5462)

* Tensorflow can now handle devices with float32 results but float64 input parameters.
[(#5446)](https://github.com/PennyLaneAI/pennylane/pull/5446)

Expand Down
24 changes: 24 additions & 0 deletions pennylane/devices/qubit/sampling.py
PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,34 @@ def sample_state(
with qml.queuing.QueuingManager.stop_recording():
probs = qml.probs(wires=wires_to_sample).process_state(flat_state, state_wires)

# when using the torch interface with float32 as default dtype,
# probabilities must be renormalized as they may not sum to one
# see https://github.com/PennyLaneAI/pennylane/issues/5444
norm = qml.math.sum(probs, axis=-1)
abs_diff = np.abs(norm - 1.0)
cutoff = 1e-07

if is_state_batched:

normalize_condition = False

for s in abs_diff:
if s != 0:
normalize_condition = True
if s > cutoff:
normalize_condition = False
break

if normalize_condition:
probs = probs / norm[:, np.newaxis] if norm.shape else probs / norm

# rng.choice doesn't support broadcasting
samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs])
else:

if 0 < abs_diff < cutoff:
probs /= norm

samples = rng.choice(basis_states, shots, p=probs)

powers_of_two = 1 << np.arange(num_wires, dtype=np.int64)[::-1]
Expand Down
57 changes: 57 additions & 0 deletions tests/devices/qubit/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,63 @@ def test_nan_shadow_expval(self, H, interface, shots):
assert qml.math.all(qml.math.isnan(r))


two_qubit_state_to_be_normalized = np.array([[0, 1.0000000005j], [-1, 0]]) / np.sqrt(2)
two_qubit_state_not_normalized = np.array([[0, 1.0000005j], [-1.00000001, 0]]) / np.sqrt(2)

batched_state_to_be_normalized = np.stack(
[
np.array([[0, 0], [0, 1.000000000009]]),
np.array([[1.00000004, 0], [1, 0]]) / np.sqrt(2),
np.array([[1, 1], [1, 0.99999995]]) / 2,
]
)
batched_state_not_normalized = np.stack(
[
np.array([[0, 0], [0, 1]]),
np.array([[1.0000004, 0], [1, 0]]) / np.sqrt(2),
np.array([[1, 1], [1, 0.9999995]]) / 2,
]
)


class TestRenormalization:
"""Test suite for renormalization functionality."""

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
def test_sample_state_renorm(self, interface):
"""Test renormalization for a non-batched state."""

state = qml.math.array(two_qubit_state_to_be_normalized, like=interface)
_ = sample_state(state, 10)

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
def test_sample_state_renorm_error(self, interface):
"""Test that renormalization does not occur if the error is too large."""

state = qml.math.array(two_qubit_state_not_normalized, like=interface)
with pytest.raises(ValueError, match="probabilities do not sum to 1"):
_ = sample_state(state, 10)

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
def test_sample_batched_state_renorm(self, interface):
"""Test renormalization for a batched state."""

state = qml.math.array(batched_state_to_be_normalized, like=interface)
_ = sample_state(state, 10, is_state_batched=True)

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
def test_sample_batched_state_renorm_error(self, interface):
"""Test that renormalization does not occur if the error is too large."""

state = qml.math.array(batched_state_not_normalized, like=interface)
with pytest.raises(ValueError, match="probabilities do not sum to 1"):
_ = sample_state(state, 10, is_state_batched=True)


class TestBroadcasting:
"""Test that measurements work when the state has a batch dim"""

Expand Down
Loading