Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bitarray postselect #12693

Merged
merged 51 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
7495d59
define BitArray.postselect()
aeddins-ibm Jun 28, 2024
757eda5
add test for BitArray.postselect()
aeddins-ibm Jun 28, 2024
cb0bb69
lint
aeddins-ibm Jun 29, 2024
6079906
remove redundant docstring text
aeddins-ibm Jun 29, 2024
2e79728
Update qiskit/primitives/containers/bit_array.py
aeddins-ibm Jul 2, 2024
033aa63
docstring ticks (BitArray.postselect())
aeddins-ibm Jul 2, 2024
efdad57
Simpler tests for BitArray.postselect
aeddins-ibm Jul 2, 2024
8cb4920
lint
aeddins-ibm Jul 2, 2024
b2b0a54
add release note
aeddins-ibm Jul 2, 2024
e080cf5
check postselect() arg lengths match
aeddins-ibm Jul 2, 2024
9760e1d
fix postselect tests
aeddins-ibm Jul 2, 2024
07a2ede
lint
aeddins-ibm Jul 2, 2024
ac720dc
Merge branch 'main' into bitarray_postselect
aeddins-ibm Jul 2, 2024
fd317bf
Fix type-hint
aeddins-ibm Jul 2, 2024
b114315
Merge branch 'bitarray_postselect' of https://github.com/aeddins-ibm/…
aeddins-ibm Jul 2, 2024
124e567
remove spurious print()
aeddins-ibm Jul 2, 2024
82c5011
lint
aeddins-ibm Jul 2, 2024
ab87d66
lint
aeddins-ibm Jul 3, 2024
d03cbb0
Merge branch 'main' into bitarray_postselect
aeddins-ibm Jul 3, 2024
09ec38a
use bitwise operations for faster postselect
aeddins-ibm Jul 3, 2024
d688a2a
remove spurious print()
aeddins-ibm Jul 3, 2024
ae959d0
Merge branch 'main' into bitarray_postselect
aeddins-ibm Jul 3, 2024
ff580df
end final line of release note
aeddins-ibm Jul 3, 2024
17ac5ec
try to fix docstring formatting
aeddins-ibm Jul 3, 2024
2961544
fix bitarray test assertion
aeddins-ibm Jul 5, 2024
745130f
disallow postselect positional kwarg
aeddins-ibm Jul 5, 2024
803765b
fix numpy dtype args
aeddins-ibm Jul 5, 2024
99824f3
Simpler kwarg: "assume_unique"
aeddins-ibm Jul 5, 2024
bdd4ede
lint (line too long)
aeddins-ibm Jul 5, 2024
6950687
simplification: remove assume_unique kwarg
aeddins-ibm Jul 5, 2024
88be7a6
improve misleading comment
aeddins-ibm Jul 5, 2024
5243067
Merge branch 'main' into bitarray_postselect
aeddins-ibm Jul 7, 2024
036fdf8
raise IndexError if indices out of range
aeddins-ibm Jul 9, 2024
3e02934
lint
aeddins-ibm Jul 9, 2024
da959ff
add negative-contradiction test
aeddins-ibm Jul 9, 2024
8f76f97
Update docstring with IndexErrors
aeddins-ibm Jul 9, 2024
bfca1bd
lint
aeddins-ibm Jul 9, 2024
8f32178
change slice_bits error from ValueError to IndexError
aeddins-ibm Jul 10, 2024
c2b0039
update slice_bits test to use IndexError
aeddins-ibm Jul 10, 2024
c4becd9
change ValueError to IndexError in slice_shots
aeddins-ibm Jul 11, 2024
50545ef
update error type in slice_shots docstring
aeddins-ibm Jul 11, 2024
95bb321
Revert ValueError to IndexError changes
aeddins-ibm Jul 11, 2024
c140f8d
fix docstring formatting
aeddins-ibm Jul 12, 2024
38155e4
allow selection to be int instead of bool
aeddins-ibm Jul 12, 2024
af369c9
Merge branch 'bitarray_postselect' of https://github.com/aeddins-ibm/…
aeddins-ibm Jul 12, 2024
c129642
In tests, give selection as type int
aeddins-ibm Jul 12, 2024
af39354
lint
aeddins-ibm Jul 12, 2024
9c5edcd
add example to release note
aeddins-ibm Jul 12, 2024
3e9684e
fix typo in test case
aeddins-ibm Jul 22, 2024
4750250
add check of test
aeddins-ibm Jul 23, 2024
d5d0009
lint
aeddins-ibm Jul 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions qiskit/primitives/containers/bit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,97 @@ def slice_shots(self, indices: int | Sequence[int]) -> "BitArray":
arr = arr[..., indices, :]
return BitArray(arr, self.num_bits)

def postselect(
self,
indices: Sequence[int] | int,
selection: Sequence[bool | int] | bool | int,
) -> BitArray:
"""Post-select this bit array based on sliced equality with a given bitstring.

.. note::
If this bit array contains any shape axes, it is first flattened into a long list of shots
before applying post-selection. This is done because :class:`~BitArray` cannot handle
ragged numbers of shots across axes.

Args:
indices: A list of the indices of the cbits on which to postselect.
If this bit array was produced by a sampler, then an index ``i`` corresponds to the
:class:`~.ClassicalRegister` location ``creg[i]`` (as in :meth:`~slice_bits`).
Negative indices are allowed.

selection: A list of binary values (will be cast to ``bool``) of length matching
``indices``, with ``indices[i]`` corresponding to ``selection[i]``. Shots will be
discarded unless all cbits specified by ``indices`` have the values given by
``selection``.

Returns:
A new bit array with ``shape=(), num_bits=data.num_bits, num_shots<=data.num_shots``.

Raises:
IndexError: If ``max(indices)`` is greater than or equal to :attr:`num_bits`.
IndexError: If ``min(indices)`` is less than negative :attr:`num_bits`.
ValueError: If the lengths of ``selection`` and ``indices`` do not match.
"""
if isinstance(indices, int):
indices = (indices,)
if isinstance(selection, (bool, int)):
selection = (selection,)
selection = np.asarray(selection, dtype=bool)

num_indices = len(indices)

if len(selection) != num_indices:
raise ValueError("Lengths of indices and selection do not match.")

num_bytes = self._array.shape[-1]
indices = np.asarray(indices)

if num_indices > 0:
if indices.max() >= self.num_bits:
raise IndexError(
f"index {int(indices.max())} out of bounds for the number of bits {self.num_bits}."
)
if indices.min() < -self.num_bits:
raise IndexError(
f"index {int(indices.min())} out of bounds for the number of bits {self.num_bits}."
)

flattened = self.reshape((), self.size * self.num_shots)

# If no conditions, keep all data, but flatten as promised:
if num_indices == 0:
return flattened

# Make negative bit indices positive:
indices %= self.num_bits

# Handle special-case of contradictory conditions:
if np.intersect1d(indices[selection], indices[np.logical_not(selection)]).size > 0:
return BitArray(np.empty((0, num_bytes), dtype=np.uint8), num_bits=self.num_bits)

# Recall that creg[0] is the LSb:
byte_significance, bit_significance = np.divmod(indices, 8)
# least-significant byte is at last position:
byte_idx = (num_bytes - 1) - byte_significance
# least-significant bit is at position 0:
bit_offset = bit_significance.astype(np.uint8)

# Get bitpacked representation of `indices` (bitmask):
bitmask = np.zeros(num_bytes, dtype=np.uint8)
np.bitwise_or.at(bitmask, byte_idx, np.uint8(1) << bit_offset)

# Get bitpacked representation of `selection` (desired bitstring):
selection_bytes = np.zeros(num_bytes, dtype=np.uint8)
## This assumes no contradictions present, since those were already checked for:
np.bitwise_or.at(
selection_bytes, byte_idx, np.asarray(selection, dtype=np.uint8) << bit_offset
)

return BitArray(
flattened._array[((flattened._array & bitmask) == selection_bytes).all(axis=-1)],
num_bits=self.num_bits,
)

def expectation_values(self, observables: ObservablesArrayLike) -> NDArray[np.float64]:
"""Compute the expectation values of the provided observables, broadcasted against
this bit array.
Expand Down
11 changes: 11 additions & 0 deletions releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
features_primitives:
- |
Added a new method :meth:`.BitArray.postselect` that returns all shots containing specified bit values.
t-imamichi marked this conversation as resolved.
Show resolved Hide resolved
Example usage::

from qiskit.primitives.containers import BitArray

ba = BitArray.from_counts({'110': 2, '100': 4, '000': 3})
print(ba.postselect([0,2], [0,1]).get_counts())
# {'110': 2, '100': 4}
79 changes: 79 additions & 0 deletions test/python/primitives/containers/test_bit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,82 @@ def test_expectation_values(self):
_ = ba.expectation_values("Z")
with self.assertRaisesRegex(ValueError, "is not diagonal"):
_ = ba.expectation_values("X" * ba.num_bits)

def test_postselection(self):
ihincks marked this conversation as resolved.
Show resolved Hide resolved
"""Test the postselection method."""

flat_data = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
],
dtype=bool,
)

shaped_data = np.array(
[
[
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
],
[
[1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
]
],
dtype=bool,
)

for dataname, bool_array in zip(["flat", "shaped"], [flat_data, shaped_data]):

bit_array = BitArray.from_bool_array(bool_array, order="little")
# indices value of i <-> creg[i] <-> bool_array[..., i]

num_bits = bool_array.shape[-1]
bool_array = bool_array.reshape(-1, num_bits)

test_cases = [
("basic", [0, 1], [0, 0]),
("multibyte", [0, 9], [0, 1]),
("repeated", [5, 5, 5], [0, 0, 0]),
("contradict", [5, 5, 5], [1, 0, 0]),
("unsorted", [5, 0, 9, 3], [1, 0, 1, 0]),
("negative", [-5, 1, -2, -10], [1, 0, 1, 0]),
("negcontradict", [4, -6], [1, 0]),
("trivial", [], []),
("bareindex", 6, 0),
]

for name, indices, selection in test_cases:
with self.subTest("_".join([dataname, name])):
postselected_bools = np.unpackbits(
bit_array.postselect(indices, selection).array[:, ::-1],
count=num_bits,
axis=-1,
bitorder="little",
).astype(bool)
if isinstance(indices, int):
indices = (indices,)
if isinstance(selection, bool):
selection = (selection,)
answer = bool_array[np.all(bool_array[:, indices] == selection, axis=-1)]
if name in ["contradict", "negcontradict"]:
self.assertEqual(len(answer), 0)
else:
self.assertGreater(len(answer), 0)
np.testing.assert_equal(postselected_bools, answer)

error_cases = [
("aboverange", [0, 6, 10], [True, True, False], IndexError),
("belowrange", [0, 6, -11], [True, True, False], IndexError),
("mismatch", [0, 1, 2], [False, False], ValueError),
]
for name, indices, selection, error in error_cases:
with self.subTest(dataname + "_" + name):
with self.assertRaises(error):
bit_array.postselect(indices, selection)
Loading