Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix BasisEmbedding to allow only-1 bitstrings #1114

Merged
merged 15 commits into from
Mar 2, 2021
Merged
4 changes: 4 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@

<h3>Bug fixes</h3>

* Fixes a bug where `BasisEmbedding` would not accept inputs whose bits are all ones
or all zeros.
[(#1114)](https://github.com/PennyLaneAI/pennylane/pull/1114)

* The `ExpvalCost` class raises an error if instantiated
with non-expectation measurement statistics.
[(#1106)](https://github.com/PennyLaneAI/pennylane/pull/1106)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/templates/embeddings/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _preprocess(features, wires):

features = list(qml.math.toarray(features))

if set(features) != {0, 1}:
if not set(features).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {features}")

return features
Expand Down
12 changes: 8 additions & 4 deletions tests/templates/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,13 @@ def circuit(x=None):
class TestBasisEmbedding:
""" Tests the BasisEmbedding method."""

def test_state(self):
"""Checks the state."""
@pytest.mark.parametrize("state", [[0, 1],
[1, 1],
[1, 0],
[0, 0]])
def test_state(self, state):
"""Checks that the correct state is prepared."""

state = np.array([0, 1])
n_qubits = 2
dev = qml.device('default.qubit', wires=n_qubits)

Expand All @@ -345,7 +348,8 @@ def circuit(x=None):
return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

res = circuit(x=state)
assert np.allclose(res, [1, -1])
expected = [1 if s == 0 else -1 for s in state]
assert np.allclose(res, expected)

def test_too_many_input_bits_exception(self):
"""Verifies that exception thrown if there are more features than qubits."""
Expand Down