-
Notifications
You must be signed in to change notification settings - Fork 603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make Measurements Pytrees #4607
Conversation
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## master #4607 +/- ##
=======================================
Coverage 99.62% 99.62%
=======================================
Files 375 375
Lines 33404 33443 +39
=======================================
+ Hits 33279 33318 +39
Misses 125 125
☔ View full report in Codecov by Sentry. |
Considering the size of this PR, I'd assume it will get merged before #4544 . So I will update the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, cheers @albi3ro .
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks awesome, also helped highlight some little things to touch up!
quick question: (un/)flattening should help with copying and replacing operators, right? Curious because we replace measurements in-place in device preprocessing, and I'm wondering if that code can benefit from this change
Not necessarily in the scope of the PR, but I was wondering whether we would like to leverage |
Potentially we could rewrite a lot of |
PennyLane measurements are automatically registered as `Pytrees <https://jax.readthedocs.io/en/latest/pytrees.html>`_ . | ||
|
||
The :class:`~.MeasurementProcess` definitions are sufficient for all PL measurements. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So does this mean that users creating custom measurements do not need to do anything extra?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops. Actually that comment is wrong. They need to be overridden if the measurement process has extra metadata.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll fix that in the tape pytree PR.
This PR registers all `MeasurementProcess` objects as jax pytrees. [sc-40588]
**Context:** PennyLane 0.33.0 will introduce measurements as pytrees PennyLaneAI/pennylane#4607 Most measurements have therefore no leaves, this breaks a Catalyst assumptions for capturing the program. **Description of the Change:** - Unflatten the return to get it - Flatten the return again but this time with `is_leaf` true for measurement processes. **Benefits:** Catalyst is up to date with PennyLane master **Possible Drawbacks:** Potential slow down Benchmark: ``` import pennylane as qml from jax import numpy as jnp from catalyst import qjit dev = qml.device("lightning.qubit", wires=3) import timeit def my_function_v1(): @qjit @qml.qnode(device=dev) def circuit(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) qml.CNOT(wires=[0, 1]) return [tuple([qml.expval(qml.PauliZ(wires=0))]), jnp.sin(y)], {"expval": qml.expval(qml.PauliZ(wires=1))}, tuple([qml.probs(wires=[0,1]), qml.expval(qml.PauliZ(wires=2))]) def benchmark_function(func): setup_code = f"from __main__ import {func.__name__}" stmt = f"{func.__name__}()" execution_time = timeit.timeit(stmt, setup_code, number=100) return execution_time if __name__ == "__main__": time_v1 = benchmark_function(my_function_v1) print(f"Version execution time: {time_v1:.6f} seconds") ``` PL master: For 100 runs: Version execution time: 6.361672 seconds PL 0.32.0: For 100 runs: Version execution time: 6.301733 seconds Diff: around 0.06s for hundred runs
This PR registers all
MeasurementProcess
objects as jax pytrees. [sc-40588]