Jax all devices #1065
Conversation
This is awesome, I'm excited to have this in!
Most of my suggestions in the PR itself are minor, but there are some other parts of the codebase that should be modified:

- We likely need an integration test for this new functionality. Could you make a copy of `tests/tape/interfaces/test_qnode_autograd.py` as `test_qnode_jax.py`? Probably the entire file could stay as is, with minor modifications to account for the autograd -> JAX change.
- Small tweaks to the JAX page in the docs need to be made to account for this change.
```python
dev = qml.device('cirq.simulator', wires=1)

@qml.qnode(dev, interface="jax")
def circuit(x):
    qml.RX(x[1], wires=0)
    qml.Rot(x[0], x[1], x[2], wires=0)
    return qml.expval(qml.PauliZ(0))

weights = jnp.array([0.2, 0.5, 0.1])
print(circuit(weights))  # DeviceArray(...)
```
Woah. This is awesome.
```python
if len(self.observables) != 1:
    raise ValueError("Only one return type is supported currently")
return_type = self.observables[0].return_type
if return_type is not Variance and return_type is not Expectation:
    raise ValueError(
        f"Only Variance and Expectation returns are supported, given {return_type}"
    )
exec_fn = partial(self.execute_device, device=device)
```
Could this be simplified using `self.output_dim`?
```python
return host_callback.call(
    exec_fn, params, result_shape=jax.ShapeDtypeStruct(self.output_dim, JAXInterface.dtype)
)
```
However, there would still need to be one validation: ensuring that there is no mixture of `probs` and `expval`/`var`, i.e., anything that would result in a different output dim to what the tape estimates.
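A minimal sketch of what such a validation might look like. This is hypothetical, not the PR's code: `ReturnType` stands in for PennyLane's actual return-type enum, and `validate_uniform_returns` is an illustrative helper name.

```python
from enum import Enum


class ReturnType(Enum):
    # stand-ins for PennyLane's ObservableReturnTypes
    Expectation = "expval"
    Variance = "var"
    Probability = "probs"


def validate_uniform_returns(return_types):
    """Raise if scalar returns (expval/var) are mixed with probs.

    Mixing the two changes the output dimension away from what the
    tape's output_dim estimate assumes.
    """
    kinds = {
        "probs" if rt is ReturnType.Probability else "scalar"
        for rt in return_types
    }
    if len(kinds) > 1:
        raise ValueError("Cannot mix probs with expval/var in a single tape")
```

Any all-scalar or all-probs tape passes; a mixed tape raises before the shape mismatch can surface inside `host_callback.call`.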
Can I leave that to a future PR where I fix the above todo?
👍
```python
@classmethod
def apply(cls, tape):
    """Apply the JAX interface to an existing tape in-place.

    Args:
        tape (.JacobianTape): a quantum tape to apply the JAX interface to

    **Example**

    >>> with JacobianTape() as tape:
    ...     qml.RX(0.5, wires=0)
    ...     expval(qml.PauliZ(0))
    >>> JAXInterface.apply(tape)
    >>> tape
    <JAXQuantumTape: wires=<Wires = [0]>, params=1>
    """
    tape_class = getattr(tape, "__bare__", tape.__class__)
    tape.__bare__ = tape_class
    tape.__class__ = type("JAXQuantumTape", (cls, tape_class), {})
    return tape
```
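For readers puzzled by the in-place transform: the tape keeps all its data, but its `__class__` is swapped for a dynamically created subclass that mixes in the interface. A toy sketch of the same pattern, with `Interface`/`Tape` as stand-ins for `JAXInterface`/`JacobianTape`:

```python
class Interface:
    def execute(self):
        return "interface execute"


class Tape:
    def __init__(self):
        self.params = [0.5]


def apply(cls, tape):
    # Remember the original class so repeated applications stay idempotent.
    tape_class = getattr(tape, "__bare__", tape.__class__)
    tape.__bare__ = tape_class
    # Reassigning __class__ mutates the object in place: same instance,
    # same attributes, new method resolution order (cls first, then Tape).
    tape.__class__ = type("InterfaceTape", (cls, tape_class), {})
    return tape


tape = Tape()
apply(Interface, tape)
```

After the call, `tape` is an instance of both `Interface` and `Tape`, gains `execute()`, and keeps its original `params`.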
Tangent: In retrospect, this is a design decision I do not like. Even ignoring the dynamic class change, this feels more like a function than a class method?
I'm generally against in-place transformations, and this one is pretty intense. tbh I don't really understand how this works.
iirc the in-place transform was solely to placate Autograd. Everything else worked at the time without it being in-place (tf, torch), but autograd would give a really bizarre and long traceback that vanished when it was made in-place instead.
(it could have been torch that complained rather than autograd, I can't remember fully).
```python
def test_non_backprop_error(self):
    """Test that an error is raised in tape mode if the diff method is not backprop"""
```
❤️
```python
# Easiest way to test object is a device array instead of np.array
assert "DeviceArray" in res.__repr__()
```
👍
Co-authored-by: Josh Izaac <josh146@gmail.com>
Codecov Report

```
@@            Coverage Diff             @@
##           master    #1065      +/-   ##
==========================================
- Coverage   97.74%   97.71%   -0.04%
==========================================
  Files         153      154       +1
  Lines       11579    11551      -28
==========================================
- Hits        11318    11287      -31
- Misses        261      264       +3
```

Continue to review full report at Codecov.
I don't have gradient support yet, so I'll only be able to support some of the tests.

@josh146 Looks like this is still broken despite the build check passing. Can you tag your github support on this PR?
Co-authored-by: Josh Izaac <josh146@gmail.com>
```python
return qml.expval(qml.PauliZ(0))

weights = jnp.array([0.2, 0.5, 0.1])
print(circuit(weights))  # DeviceArray(...)
```
I can't suggest it, but this block is missing a closing triple backtick.
Aw thank you. I was wondering why the doc build kept failing here but worked locally.
Context:
Previously, only the `default.qubit.jax` device supported the JAX interface. This PR adds JAX interface support to all devices.

Description of the Change:
The recent addition of `jax.experimental.host_callback` allows us to finally do `jax -> numpy -> numpy -> jax` within a `jax.jit` function! This PR adds the needed scaffolding to do it.

Benefits:
Everyone can start using JAX more and I won't take any more excuses.
Possible Drawbacks:
None.
Related GitHub Issues:
#943
TODOs
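The `jax -> numpy -> jax` round-trip the description mentions can be sketched without PennyLane at all. Note this sketch uses `jax.pure_callback`, the successor to the `jax.experimental.host_callback.call` API the PR was built on; `numpy_only_execute` is a made-up stand-in for a device execution that only understands NumPy arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_only_execute(x):
    # Runs on the host with plain NumPy, outside of JAX tracing,
    # just like a non-JAX simulator device would.
    return np.cos(np.asarray(x))


@jax.jit
def circuit(x):
    # JAX can't trace into the host function, so we must declare the
    # shape/dtype of its result up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(numpy_only_execute, result_shape, x)


print(circuit(jnp.array([0.2, 0.5, 0.1])))
```

The key idea is the same as in the PR: the jitted function escapes to a host-side NumPy computation and resumes with the declared result shape.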