diff --git a/qiskit/primitives/containers/bit_array.py b/qiskit/primitives/containers/bit_array.py index 24d52ca4e85a..8c88de6f12a5 100644 --- a/qiskit/primitives/containers/bit_array.py +++ b/qiskit/primitives/containers/bit_array.py @@ -464,6 +464,97 @@ def slice_shots(self, indices: int | Sequence[int]) -> "BitArray": arr = arr[..., indices, :] return BitArray(arr, self.num_bits) + def postselect( + self, + indices: Sequence[int] | int, + selection: Sequence[bool | int] | bool | int, + ) -> BitArray: + """Post-select this bit array based on sliced equality with a given bitstring. + + .. note:: + If this bit array contains any shape axes, it is first flattened into a long list of shots + before applying post-selection. This is done because :class:`~BitArray` cannot handle + ragged numbers of shots across axes. + + Args: + indices: A list of the indices of the cbits on which to postselect. + If this bit array was produced by a sampler, then an index ``i`` corresponds to the + :class:`~.ClassicalRegister` location ``creg[i]`` (as in :meth:`~slice_bits`). + Negative indices are allowed. + + selection: A list of binary values (will be cast to ``bool``) of length matching + ``indices``, with ``indices[i]`` corresponding to ``selection[i]``. Shots will be + discarded unless all cbits specified by ``indices`` have the values given by + ``selection``. + + Returns: + A new bit array with ``shape=(), num_bits=data.num_bits, num_shots<=data.num_shots``. + + Raises: + IndexError: If ``max(indices)`` is greater than or equal to :attr:`num_bits`. + IndexError: If ``min(indices)`` is less than negative :attr:`num_bits`. + ValueError: If the lengths of ``selection`` and ``indices`` do not match. + """ + if isinstance(indices, int): + indices = (indices,) + if isinstance(selection, (bool, int)): + selection = (selection,) + selection = np.asarray(selection, dtype=bool) + + num_indices = len(indices) + + if len(selection) != num_indices: + raise ValueError("Lengths of indices and selection do not match.") + + num_bytes = self._array.shape[-1] + indices = np.asarray(indices) + + if num_indices > 0: + if indices.max() >= self.num_bits: + raise IndexError( + f"index {int(indices.max())} out of bounds for the number of bits {self.num_bits}." + ) + if indices.min() < -self.num_bits: + raise IndexError( + f"index {int(indices.min())} out of bounds for the number of bits {self.num_bits}." + ) + + flattened = self.reshape((), self.size * self.num_shots) + + # If no conditions, keep all data, but flatten as promised: + if num_indices == 0: + return flattened + + # Make negative bit indices positive: + indices %= self.num_bits + + # Handle special-case of contradictory conditions: + if np.intersect1d(indices[selection], indices[np.logical_not(selection)]).size > 0: + return BitArray(np.empty((0, num_bytes), dtype=np.uint8), num_bits=self.num_bits) + + # Recall that creg[0] is the LSb: + byte_significance, bit_significance = np.divmod(indices, 8) + # least-significant byte is at last position: + byte_idx = (num_bytes - 1) - byte_significance + # least-significant bit is at position 0: + bit_offset = bit_significance.astype(np.uint8) + + # Get bitpacked representation of `indices` (bitmask): + bitmask = np.zeros(num_bytes, dtype=np.uint8) + np.bitwise_or.at(bitmask, byte_idx, np.uint8(1) << bit_offset) + + # Get bitpacked representation of `selection` (desired bitstring): + selection_bytes = np.zeros(num_bytes, dtype=np.uint8) + ## This assumes no contradictions present, since those were already checked for: + np.bitwise_or.at( + selection_bytes, byte_idx, np.asarray(selection, dtype=np.uint8) << bit_offset + ) + + return BitArray( + flattened._array[((flattened._array & bitmask) == selection_bytes).all(axis=-1)], + num_bits=self.num_bits, + ) + def expectation_values(self, observables: ObservablesArrayLike) -> NDArray[np.float64]: """Compute the expectation values of the provided observables, broadcasted against this bit array. diff --git a/releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml b/releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml new file mode 100644 index 000000000000..33ce17bafa8d --- /dev/null +++ b/releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml @@ -0,0 +1,11 @@ +--- +features_primitives: + - | + Added a new method :meth:`.BitArray.postselect` that returns all shots containing specified bit values. + Example usage:: + + from qiskit.primitives.containers import BitArray + + ba = BitArray.from_counts({'110': 2, '100': 4, '000': 3}) + print(ba.postselect([0,2], [0,1]).get_counts()) + # {'110': 2, '100': 4} diff --git a/test/python/primitives/containers/test_bit_array.py b/test/python/primitives/containers/test_bit_array.py index 69f02fd46daa..22fb27e0df43 100644 --- a/test/python/primitives/containers/test_bit_array.py +++ b/test/python/primitives/containers/test_bit_array.py @@ -695,3 +695,82 @@ def test_expectation_values(self): _ = ba.expectation_values("Z") with self.assertRaisesRegex(ValueError, "is not diagonal"): _ = ba.expectation_values("X" * ba.num_bits) + + def test_postselection(self): + """Test the postselection method.""" + + flat_data = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + ], + dtype=bool, + ) + + shaped_data = np.array( + [ + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + ], + [ + [1, 0, 1, 0, 1, 0, 1, 0, 1, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ], + ] + ], + dtype=bool, + ) + + for dataname, bool_array in zip(["flat", "shaped"], [flat_data, shaped_data]): + + bit_array = BitArray.from_bool_array(bool_array, order="little") + # indices value of i <-> creg[i] <-> bool_array[..., i] + + num_bits = bool_array.shape[-1] + bool_array = bool_array.reshape(-1, num_bits) + + test_cases = [ + ("basic", [0, 1], [0, 0]), + ("multibyte", [0, 9], [0, 1]), + ("repeated", [5, 5, 5], [0, 0, 0]), + ("contradict", [5, 5, 5], [1, 0, 0]), + ("unsorted", [5, 0, 9, 3], [1, 0, 1, 0]), + ("negative", [-5, 1, -2, -10], [1, 0, 1, 0]), + ("negcontradict", [4, -6], [1, 0]), + ("trivial", [], []), + ("bareindex", 6, 0), + ] + + for name, indices, selection in test_cases: + with self.subTest("_".join([dataname, name])): + postselected_bools = np.unpackbits( + bit_array.postselect(indices, selection).array[:, ::-1], + count=num_bits, + axis=-1, + bitorder="little", + ).astype(bool) + if isinstance(indices, int): + indices = (indices,) + if isinstance(selection, bool): + selection = (selection,) + answer = bool_array[np.all(bool_array[:, indices] == selection, axis=-1)] + if name in ["contradict", "negcontradict"]: + self.assertEqual(len(answer), 0) + else: + self.assertGreater(len(answer), 0) + np.testing.assert_equal(postselected_bools, answer) + + error_cases = [ + ("aboverange", [0, 6, 10], [True, True, False], IndexError), + ("belowrange", [0, 6, -11], [True, True, False], IndexError), + ("mismatch", [0, 1, 2], [False, False], ValueError), + ] + for name, indices, selection, error in error_cases: + with self.subTest(dataname + "_" + name): + with self.assertRaises(error): + bit_array.postselect(indices, selection)