Adjoint differentiation improvements #1341
Conversation
Hello. You may have forgotten to update the changelog!
Codecov Report
@@            Coverage Diff            @@
##           master    #1341      +/-   ##
==========================================
+ Coverage   98.18%   98.19%   +0.01%
==========================================
  Files         157      157
  Lines       11687    11713      +26
==========================================
+ Hits        11475    11501      +26
  Misses        212      212
Continue to review full report at Codecov.
As of now, this branch is seeing decent time improvements. I changed the benchmark repo's […]. The results for current master:

[benchmark results]

And for the most recent commit to this PR:

[benchmark results]
Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>
@trbromley For the autograd and jax derivatives, we know that we just performed a forward pass, so the device is in the correct state. For the tensorflow and torch interfaces, we can perform the forward pass, then go do other stuff, maybe perform a forward pass with different parameters, and then perform the backward pass where we compute the adjoint jacobian. Caching the device state means keeping around a copy of the statevector, which may come with some memory overhead. At some point, we may switch to having this as default, but I don't want to do so just yet given the memory considerations.
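To make the deferred-backward pattern above concrete, here is a minimal sketch (TensorFlow interface with a standard ``default.qubit`` device; illustrative only, not code from this PR) of how the backward pass can happen long after the forward pass, with other executions in between:

import tensorflow as tf
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="tf", diff_method="adjoint")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = tf.Variable(0.1)
with tf.GradientTape() as tape:
    res = circuit(x)          # forward pass
circuit(tf.constant(0.5))     # an unrelated execution overwrites the device state
grad = tape.gradient(res, x)  # backward pass happens here, much later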
Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>
Very cool improvements, just left a few comments.
Just to make sure, as my experience with the interfaces is still quite limited: is there any way one could break the caching with intermediate calls or changes of the device/QNode state, in the sense of recycling the wrong state?
if starting_state is not None:
    ket = self._reshape(starting_state, [2] * self.num_wires)
We could also check whether the array is flat and do

if starting_state is not None:
    if starting_state.shape != (2,) * self.num_wires:
        ket = self._reshape(starting_state, [2] * self.num_wires)
    else:
        ket = starting_state

or is this unnecessary detail optimization?
pennylane/_qubit_device.py
Outdated
if starting_state is not None:
    ket = self._reshape(starting_state, [2] * self.num_wires)
elif use_device_state:
    ket = self._reshape(self.state, [2] * self.num_wires)
Can't we use ``self._state`` here and skip the reshaping?
Substituting ``self._prerotated_state``.
pennylane/_qubit_device.py
Outdated
else:
    self.reset()
    self.execute(tape)
    ket = self._reshape(self.state, [2] * self.num_wires)
Here as well.
One could "reduce code length" slightly by doing

if starting_state is not None:
    ...
else:
    if not use_device_state:
        self.reset()
        self.execute(tape)
    ket = self._reshape(self.state, [2] * self.num_wires)
pennylane/_qubit_device.py
Outdated
for kk, bra_ in enumerate(bras):
    jac[kk, trainable_param_number] = 2 * dot_product_real(bra_, ket_temp)
Could we do this with a single dot product across the 0th axis of ``bras``?

jac[:, trainable_param_number] = 2 * np.real(np.tensordot(np.conj(bras), ket_temp, axes=self.num_wires))
``bras`` is a list of arrays, rather than an array with an extra dimension. I had thought about making it an array with an additional dimension; I think I'll go ahead and try to make that work.
Ah, yes, sorry I overlooked that! The multiplication should save more than the conversion to an array costs, I'd hope :)
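For illustration, a self-contained sketch (plain NumPy, not the PR's code; the shapes and names here are stand-ins) checking that a single ``tensordot`` over the wire axes reproduces the per-observable loop once ``bras`` is stacked into one array:

import numpy as np

num_wires, n_obs = 3, 4
rng = np.random.default_rng(0)
shape = (2,) * num_wires
ket = rng.normal(size=shape) + 1j * rng.normal(size=shape)
bras = rng.normal(size=(n_obs,) + shape) + 1j * rng.normal(size=(n_obs,) + shape)

# loop version: real part of <bra_k|ket> for each observable k
loop = np.array([np.real(np.sum(np.conj(b) * ket)) for b in bras])

# vectorized version: contract all wire axes in one tensordot
vec = np.real(np.tensordot(np.conj(bras), ket, axes=num_wires))

assert np.allclose(loop, vec)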
@@ -174,6 +179,9 @@ def _evaluate_grad_matrix(grad_matrix_fn):
    if grad_matrix_fn in saved_grad_matrices:
        return saved_grad_matrices[grad_matrix_fn]

    if self.jacobian_options.get("cache_state", False):
        self.jacobian_options["device_pd_options"] = {"starting_state": state}
Maybe this is fantasizing about very strange use cases, but could it make sense to also cache the output of ``self.jacobian_options.get("cache_state", False)``, in order to prevent a later change of the keyword argument for a new forward pass from making the device forget that caching was actually activated for an earlier execution and the state is lying around locally?
Specified.
Thanks @albi3ro! The code is on point and I am happy to give a ✔️, but I just wanted to discuss the question of ``cache_state`` a bit more. Concretely, my suggestion would be to:

- Make ``cache_state`` only cache when using the adjoint diff method
- Have it be on by default
- Rename it to ``adjoint_cache`` or similar

Would be great to get your thoughts though!
pennylane/_qubit_device.py
Outdated
use_device_state (bool): use current device state to initialize. A forward pass of the same
    circuit should be the last thing the device has executed. If a ``starting_state`` is
    provided, that takes precedence.
return_obs (bool): return the expectation values alongside the jacobian as a tuple (jac, obs)
Suggested change:
- return_obs (bool): return the expectation values alongside the jacobian as a tuple (jac, obs)
+ return_obs (bool): return the expectation values alongside the jacobian as a tuple ``(jac, obs)``
pennylane/_qubit_device.py
Outdated
    Dimensions are ``(len(observables), len(trainable_params))``.
Union[array, tuple(array)]: the derivative of the tape with respect to trainable parameters.
    Dimensions are ``(len(observables), len(trainable_params))``. If ``return_obs`` keyword is True,
    then returns ``jacobian, expectation_values``
Suggested change:
- then returns ``jacobian, expectation_values``
+ then returns ``jacobian, expectation_values``.
pennylane/_qubit_device.py
Outdated
+ n_obs = len(tape.observables)
+ bras = np.empty([n_obs] + [2] * self.num_wires, dtype=np.complex128)
+ for kk in range(n_obs):
+     bras[kk, ...] = self._apply_operation(ket, tape.observables[kk])
+
- lambdas = [self._apply_operation(phi, obs) for obs in tape.observables]
+ # this can probably be more optimized, but at least now it works...
+ if return_obs:
+     expectation = dot_product_real(bras, ket)
Are the changes from phi and lambdas to ket and bras for readability?
Yep
pennylane/qnode.py
Outdated
cache_state=False (bool): for tensorflow and torch interfaces and adjoint differentiation,
    this indicates whether to save the device state after the forward pass. Doing so saves a
    forward execution. Device state automatically reused with autograd and jax interfaces.
🤔 I was on the fence about adding ``cache_state`` as a keyword argument, feeling that it might pollute the keyword arguments of the QNode for a fairly specific use case. However, I can agree that it might be nice to have the option to turn this feature on/off due to memory limitations. One suggestion would be to rename it to something more specific, e.g., ``adjoint_cache``, since ``cache_state`` might give the impression of something more general than it is.

I'd also suggest making ``cache_state=True`` the default. If it's off by default, I think most users will not notice it and not benefit from the feature, so all the hard work in this PR will not be quite so beneficial. Any users experiencing memory issues can always turn off caching if they have a problem.
pennylane/qnode.py
Outdated
cache_state=False (bool): for tensorflow and torch interfaces and adjoint differentiation,
    this indicates whether to save the device state after the forward pass. Doing so saves a
    forward execution. Device state automatically reused with autograd and jax interfaces.
Actually, now I look again: when ``cache_state=True``, we cache the state in the Torch and TF interfaces regardless of the ``diff_method``? 🤔

Shouldn't we have it so that caching only occurs if ``diff_method="adjoint"``? We can expand to further use cases going forward.

In other words:

- Currently: off by default, but when on, caching occurs in the TF and Torch interfaces regardless of diff method
- My suggestion: on by default, with caching occurring in the TF and Torch interfaces only when the diff method is adjoint

I understand the caution with off-by-default in the former, but I think it's worth following a smart-defaults philosophy.
Also makes sure executing a second circuit before backward pass does not interfere
with answer.
Nice to test this edge case!
@@ -192,6 +192,64 @@ def test_gradient_gate_with_multiple_parameters(self, tol, dev):
    # the different methods agree
    assert np.allclose(grad_D, grad_F, atol=tol, rtol=0)

def test_return_expectation(self, tol, dev):
I'm still not sure about the addition of ``return_obs`` - do you have a future use case in mind?
Easy enough to add in later if a use case arises, so I'll remove it for now.
Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>
@trbromley Good catch on realizing that tensorflow and torch would be caching the state even when not using adjoint diff! Now the state gets cached by default, but only when using the adjoint diff method. I've renamed the keyword to ``adjoint_cache``. I've also removed ``return_obs``.
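A hypothetical usage sketch of the behavior described here (the ``adjoint_cache`` keyword name follows the discussion above; the exact signature in the merged code may differ):

import pennylane as qml

dev = qml.device("default.qubit", wires=2)

# adjoint_cache defaults to True; shown explicitly for illustration
@qml.qnode(dev, interface="torch", diff_method="adjoint", adjoint_cache=True)
def circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))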
Thanks @albi3ro, awesome!
Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>
The current adjoint differentiation performs a forward pass of the circuit, even though PennyLane will often have performed one just before calling the adjoint jacobian function.
This PR intends to make it easier to reduce the number of circuit executions when using adjoint diff. Instead of using a "one-size-fits-all" caching method, this PR introduces several different ways users could reduce forward pass executions:

- compute expectation values and the jacobian at the same time inside of ``adjoint_jacobian`` with the ``return_obs`` keyword
- start with the current device state if requested by the ``use_device_state`` keyword
- start with the provided state if it is provided to the ``starting_state`` keyword

A note that ``starting_state`` takes precedence over ``use_device_state``.

Reducing the forward pass can reduce execution time of the device method call itself by a full 25%.
I am also working on editing the interfaces to take advantage of this speed up.
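A hedged sketch of the two state-reuse options described in this PR (assuming the PennyLane tape API of this era; ``JacobianTape`` and the exact call pattern are assumptions, and ``return_obs`` was later removed per the review discussion):

import numpy as np
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

with qml.tape.JacobianTape() as tape:
    qml.RX(0.1, wires=0)
    qml.CNOT(wires=[0, 1])
    qml.expval(qml.PauliZ(1))
tape.trainable_params = {0}

# Option 1: the device just executed this tape, so reuse its state
dev.execute(tape)
jac = dev.adjoint_jacobian(tape, use_device_state=True)

# Option 2: pass a previously saved statevector explicitly;
# starting_state takes precedence over use_device_state
saved = np.array(dev.state)
jac = dev.adjoint_jacobian(tape, starting_state=saved)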