Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax all devices #1065

Closed
wants to merge 26 commits into from
Closed

Jax all devices #1065

wants to merge 26 commits into from

Conversation

chaserileyroberts
Copy link
Contributor

@chaserileyroberts chaserileyroberts commented Feb 3, 2021

Context:
Previously, only the default.qubit.jax device supported the JAX interface. This PR adds JAX interface support to all devices.

Description of the Change:
The recent addition of jax.experimental.host_callback allows us to finally do jax -> numpy -> numpy -> jax within a jax.jit function! This PR adds the needed scaffolding to do it.

Benefits:
Everyone can start using JAX more and I won't take any more execuses.

Possible Drawbacks:
None.

Related GitHub Issues:
#943

TODOs

  • Gradient support
  • Vmap support

@chaserileyroberts chaserileyroberts added the review-ready 👌 PRs which are ready for review by someone from the core team. label Feb 3, 2021
Copy link
Member

@josh146 josh146 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awesome, I'm excited to have this in!

Most of my suggestions in the PR itself are minor, but there are some other parts of the codebase that should be modified:

  • We likely need an integration test for this new functionality. Could you make a copy of tests/tape/interfaces/test_qnode_autograd.py as test_qnode_jax.py? Probably the entire file could stay as is, with minor modification to account for the autograd -> JAX change.

  • Small tweaks to the JAX page in the docs need to be made to account for this change

.github/CHANGELOG.md Outdated Show resolved Hide resolved
.github/CHANGELOG.md Show resolved Hide resolved
.github/CHANGELOG.md Outdated Show resolved Hide resolved
Comment on lines +12 to +20
dev = qml.device('cirq.simulator', wires=1)
@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x[1], wires=0)
qml.Rot(x[0], x[1], x[2], wires=0)
return qml.expval(qml.PauliZ(0))

weights = jnp.array([0.2, 0.5, 0.1])
print(circuit(weights)) # DeviceArray(...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woah. This is awesome.

pennylane/tape/interfaces/jax.py Outdated Show resolved Hide resolved
pennylane/tape/interfaces/jax.py Outdated Show resolved Hide resolved
Comment on lines +84 to +91
if len(self.observables) != 1:
raise ValueError("Only one return type is supported currently")
return_type = self.observables[0].return_type
if return_type is not Variance and return_type is not Expectation:
raise ValueError(
f"Only Variance and Expectation returns are support, given {return_type}"
)
exec_fn = partial(self.execute_device, device=device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be simplified using self.output_dim?

return host_callback.call(
    exec_fn, params, result_shape=jax.ShapeDtypeStruct(self.output_dim, JAXInterface.dtype)
)

However, there would still need to be one validation; a validation to ensure that there is no mixture of probs and expval/var. E.g., anything that would result in a different output dim to what the tape estimates.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I leave that to a future PR where I fix the above todo?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Comment on lines +97 to +116
@classmethod
def apply(cls, tape):
"""Apply the JAX interface to an existing tape in-place.

Args:
tape (.JacobianTape): a quantum tape to apply the JAX interface to

**Example**

>>> with JacobianTape() as tape:
... qml.RX(0.5, wires=0)
... expval(qml.PauliZ(0))
>>> JAXInterface.apply(tape)
>>> tape
<JAXQuantumTape: wires=<Wires = [0]>, params=1>
"""
tape_class = getattr(tape, "__bare__", tape.__class__)
tape.__bare__ = tape_class
tape.__class__ = type("JAXQuantumTape", (cls, tape_class), {})
return tape
Copy link
Member

@josh146 josh146 Feb 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tangent: In retrospect, this is a design decision I do not like. Even ignoring the dynamic class change, this feels more like a function than a class method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm usually always against inplace transformations, and this one is pretty intense. tbh I don't really understand how this works.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iirc the in-place transform was solely to placate Autograd. Everything else worked at the time without it being in-place (tf, torch), but autograd would give a really bizarre and long traceback that vanished when it was made in-place instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(it could have been torch that complained rather than autograd, I can't remember fully).

Comment on lines -439 to -440
def test_non_backprop_error(self):
"""Test that an error is raised in tape mode if the diff method is not backprop"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

Comment on lines +68 to +69
# Easiest way to test object is a device array instead of np.array
assert "DeviceArray" in res.__repr__()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Chase Roberts and others added 3 commits February 4, 2021 10:58
Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: Josh Izaac <josh146@gmail.com>
@codecov
Copy link

codecov bot commented Feb 4, 2021

Codecov Report

Merging #1065 (17364b0) into master (85dd193) will decrease coverage by 0.03%.
The diff coverage is 88.09%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1065      +/-   ##
==========================================
- Coverage   97.74%   97.71%   -0.04%     
==========================================
  Files         153      154       +1     
  Lines       11579    11551      -28     
==========================================
- Hits        11318    11287      -31     
- Misses        261      264       +3     
Impacted Files Coverage Δ
pennylane/devices/autograd_ops.py 100.00% <ø> (+2.38%) ⬆️
pennylane/devices/default_mixed.py 100.00% <ø> (ø)
pennylane/devices/default_qubit.py 100.00% <ø> (ø)
pennylane/devices/default_qubit_autograd.py 100.00% <ø> (ø)
pennylane/devices/default_qubit_jax.py 93.75% <ø> (ø)
pennylane/devices/default_qubit_tf.py 90.14% <ø> (ø)
pennylane/devices/jax_ops.py 97.43% <ø> (+2.31%) ⬆️
pennylane/devices/tf_ops.py 100.00% <ø> (ø)
pennylane/templates/embeddings/amplitude.py 100.00% <ø> (ø)
pennylane/vqe/vqe.py 93.78% <ø> (ø)
... and 4 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 85dd193...594edff. Read the comment docs.

@chaserileyroberts
Copy link
Contributor Author

  • We likely need an integration test for this new functionality. Could you make a copy of tests/tape/interfaces/test_qnode_autograd.py as test_qnode_jax.py? Probably the entire file could stay as is, with minor modification to account for the autograd -> JAX change.

I don't have gradient support yet, so I'll only be able to support some of the tests.

@chaserileyroberts
Copy link
Contributor Author

@josh146 Looks like this is still broken despite check on the build passing. Can you tag your github support on this PR?

return qml.expval(qml.PauliZ(0))

weights = jnp.array([0.2, 0.5, 0.1])
print(circuit(weights)) # DeviceArray(...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't suggest it, but this block is missing an end triple-back-quote

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aw thank you. I was wondering why the doc build kept failing here but worked locally.

@josh146 josh146 deleted the jax_all_devices branch April 29, 2021 09:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
review-ready 👌 PRs which are ready for review by someone from the core team.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants