Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Micro-optimize basis state conversion #1088

Merged
merged 4 commits into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
(5, 4, 4)
```

* An improvement has been made to how `QubitDevice` generates and post-processess samples,
allowing QNode measurement statistics to work on devices with more than 32 qubits.
[(#)](https://github.com/PennyLaneAI/pennylane/pull/)

<h3>Breaking changes</h3>

* If creating a QNode from a quantum function with an argument named `shots`,
Expand Down
15 changes: 8 additions & 7 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,8 @@ def estimate_probability(self, wires=None):
samples = self._samples[:, device_wires]

# convert samples from a list of 0, 1 integers, to base 10 representation
unraveled_indices = [2] * len(device_wires)
indices = np.ravel_multi_index(samples.T, unraveled_indices)
powers_of_two = 2 ** np.arange(len(device_wires))[::-1]
chaserileyroberts marked this conversation as resolved.
Show resolved Hide resolved
indices = samples @ powers_of_two
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very clever!


# count the basis state occurrences, and construct the probability vector
basis_states, counts = np.unique(indices, return_counts=True)
Expand Down Expand Up @@ -656,9 +656,10 @@ def marginal_prob(self, prob, wires=None):
# it corresponds to the orders of the wires passed.
num_wires = len(device_wires)
basis_states = self.generate_basis_states(num_wires)
perm = np.ravel_multi_index(
basis_states[:, np.argsort(np.argsort(device_wires))].T, [2] * len(device_wires)
)
basis_states = basis_states[:, np.argsort(np.argsort(device_wires))].T
chaserileyroberts marked this conversation as resolved.
Show resolved Hide resolved

powers_of_two = 2 ** np.arange(len(device_wires))[::-1]
perm = basis_states.T @ powers_of_two
josh146 marked this conversation as resolved.
Show resolved Hide resolved
return self._gather(prob, perm)

def expval(self, observable):
Expand Down Expand Up @@ -696,8 +697,8 @@ def sample(self, observable):
# Replace the basis state in the computational basis with the correct eigenvalue.
# Extract only the columns of the basis samples required based on ``wires``.
samples = self._samples[:, np.array(device_wires)] # Add np.array here for Jax support.
unraveled_indices = [2] * len(device_wires)
indices = np.ravel_multi_index(samples.T, unraveled_indices)
powers_of_two = 2 ** np.arange(samples.shape[-1])[::-1]
indices = samples @ powers_of_two
return observable.eigvals[indices]

def adjoint_jacobian(self, tape):
Expand Down