Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jit with sampling multi-wire observable #5422

Merged
merged 5 commits into from
Mar 21, 2024
Merged

jit with sampling multi-wire observable #5422

merged 5 commits into from
Mar 21, 2024

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Mar 20, 2024

Context:

Jitting was failing with qml.sample with an observable on more than one wire.

import jax

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=2, shots=100)

@qml.qnode(dev, interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.sample(qml.PauliX(0) @ qml.PauliY(1))

results = jax.jit(circuit)(jax.numpy.array(0.123, dtype=jax.numpy.float64))

Description of the Change:

Update SampleMP.shape so that if there's an observable, there's only one dimension output.

Benefits:

Possible Drawbacks:

Related GitHub Issues:

Fixes #5372 [sc-58779]

Copy link

codecov bot commented Mar 20, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.65%. Comparing base (54d43b8) to head (ff81b9f).
Report is 2 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5422      +/-   ##
==========================================
+ Coverage   99.63%   99.65%   +0.01%     
==========================================
  Files         399      401       +2     
  Lines       37125    36946     -179     
==========================================
- Hits        36990    36817     -173     
+ Misses        135      129       -6     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@astralcai
Copy link
Contributor

I'm not sure if I understand the issue. Why would the shape always be 1?

@albi3ro
Copy link
Contributor Author

albi3ro commented Mar 20, 2024

I'm not sure if I understand the issue. Why would the shape always be 1?

Because with an observable, we are sampling from the observables eigenvalues. So just one single eigenvalue, even if there's multiple wires.

Copy link
Contributor

@astralcai astralcai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good 🐱

albi3ro and others added 2 commits March 21, 2024 10:55
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
@albi3ro albi3ro requested a review from mudit2812 March 21, 2024 15:06
Copy link
Contributor

@mudit2812 mudit2812 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Just one changelog comment.

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
@albi3ro albi3ro enabled auto-merge (squash) March 21, 2024 17:04
@albi3ro albi3ro merged commit 45f8e75 into master Mar 21, 2024
40 checks passed
@albi3ro albi3ro deleted the jit-multiwire-obs branch March 21, 2024 17:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] Sampling an observable with more than one wire is not compatible with JAX-JIT
4 participants