-
Notifications
You must be signed in to change notification settings - Fork 586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jax interface for all devices. #1076
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1076 +/- ##
==========================================
- Coverage 97.74% 97.71% -0.04%
==========================================
Files 153 154 +1
Lines 11590 11637 +47
==========================================
+ Hits 11329 11371 +42
- Misses 261 266 +5 ☔ View full report in Codecov by Sentry. |
Co-authored-by: Josh Izaac <josh146@gmail.com>
pennylane/tape/interfaces/jax.py
Outdated
return_type = self.observables[0].return_type | ||
if return_type is not Variance and return_type is not Expectation: | ||
raise ValueError( | ||
f"Only Variance and Expectation returns are support, given {return_type}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, maybe mention that this is the JAX interface speaking here, and not general PennyLane?
"""Test that the device provides the correct | ||
result for a simple circuit with a device using a different interface.""" | ||
if not qml.tape_mode_active(): | ||
pytest.skip("Tape mode only test") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we still need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I don't include it the tests fails. :(
|
||
weights = jnp.array([0.1, 0.2]) | ||
val = jax.jacrev(circuit)(weights) | ||
assert "DeviceArray" in val.__repr__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also here, what you are testing is not the value, but the return type, right? Maybe mention in test names? test_jacobian
sounds kind of like more!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would one have to test shape/values of the return value as well? I'm just thinking of a situation where the return value is some trivial "zero" or so because something went wrong, but still has the right type...?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, I should add more checks in this test.
def loss(weights, a): | ||
# the following global variable is defined simply for testing | ||
# purposes, so that we can easily extract the transformed QNode | ||
# for verification. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, which global variable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, new_circuit
, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this comment should be deleted. new_circuit was treated as a global variable before and checked outside this method. However, we don't want to do that (the similar checks for the autograd interface don't apply here). I'll delete the comment.
assert grad[1].shape == a.shape | ||
|
||
# compare against the expected values | ||
tol = 1e-5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe not so important, but I remember we tried to use global tol
fixtures for closeness checks...?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just left some comments for now. Awesome addition!
Co-authored-by: Maria Schuld <mariaschuld@gmail.com>
@josh146 this should be ready to be merged |
Overwriting the codecov because untested lines are falsely reported due to how jitting works. |
Ported over from #1065
Context:
Previously, only the
default.qubit.jax
device supported the JAX interface. This PR adds JAX interface support to all devices including built-in gradient support for non-default simulators.Description of the Change:
The recent addition of
jax.experimental.host_callback
allows us to finally dojax -> numpy -> numpy -> jax
within ajax.jit
function! This PR adds the needed scaffolding to do it.Benefits:
Everyone can start using JAX more and I won't take any more excuses.
Possible Drawbacks:
None.
Related GitHub Issues:
#943
TODOs