
Adjoint differentiation improvements #1341

Merged
merged 44 commits into master from adjoint_diff_improvements on Jun 9, 2021

Conversation

albi3ro
Contributor

@albi3ro albi3ro commented May 19, 2021

The current adjoint differentiation performs a forward pass of the circuit, even though PennyLane will often have performed one just before calling the adjoint jacobian function.

This PR intends to make it easier to reduce the number of circuit executions when using adjoint diff. Instead of using a "one-size-fits-all" caching method, this PR introduces several different ways users could reduce forward pass executions:

  1. compute the expectation values and the jacobian at the same time inside adjoint_jacobian, via the return_obs keyword

  2. start from the current device state, when requested via the use_device_state keyword

  3. start from a user-provided state, when one is passed to the starting_state keyword

Note that starting_state takes precedence over use_device_state. A minimal sketch of how these keywords might be used is shown below.
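
For illustration, here is a minimal usage sketch of options 2 and 3 (not taken from the PR), assuming PennyLane's default.qubit device and the JacobianTape construction API available at the time:

    import pennylane as qml

    dev = qml.device("default.qubit", wires=2)

    with qml.tape.JacobianTape() as tape:
        qml.RX(0.4, wires=0)
        qml.CNOT(wires=[0, 1])
        qml.RY(0.6, wires=1)
        qml.expval(qml.PauliZ(1))

    # option 2: reuse the state already on the device after a forward pass
    dev.execute(tape)
    jac = dev.adjoint_jacobian(tape, use_device_state=True)

    # option 3: pass a precomputed state explicitly; this takes precedence over use_device_state
    jac = dev.adjoint_jacobian(tape, starting_state=dev.state)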

Eliminating the redundant forward pass can reduce the execution time of the device method call itself by a full 25%.

I am also working on editing the interfaces to take advantage of this speedup.

@github-actions
Contributor

Hello. You may have forgotten to update the changelog!
Please edit .github/CHANGELOG.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@codecov

codecov bot commented May 19, 2021

Codecov Report

Merging #1341 (347a979) into master (0b5da11) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #1341   +/-   ##
=======================================
  Coverage   98.18%   98.19%           
=======================================
  Files         157      157           
  Lines       11687    11713   +26     
=======================================
+ Hits        11475    11501   +26     
  Misses        212      212           
Impacted Files                       Coverage Δ
pennylane/interfaces/jax.py          89.74% <ø> (ø)
pennylane/_qubit_device.py           98.58% <100.00%> (+0.03%) ⬆️
pennylane/interfaces/tf.py           97.43% <100.00%> (+0.29%) ⬆️
pennylane/interfaces/torch.py        100.00% <100.00%> (ø)
pennylane/qnode.py                   97.85% <100.00%> (+0.02%) ⬆️
pennylane/tape/jacobian_tape.py      98.01% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@josh146
Member

josh146 commented May 20, 2021

🎉

@albi3ro
Contributor Author

albi3ro commented May 20, 2021

As of now, this branch is seeing decent time improvements. I changed the benchmark repo's GradientComputation_light benchmark to use the adjoint method.

The results for current master:

[  0.00%] · For pennylane commit 411c3d52 <master>:
[  0.00%] ·· Benchmarking virtualenv-py3.8-Qulacs-jax-jaxlib-networkx-qsimcirq-tensorflow-torch
[  0.05%] ··· Running (asv.core_suite.GradientComputation_light.time_gradient--).
[  0.09%] ··· asv.core_suite.GradientComputation_light.time_gradient                                                                                       4/16 failed
[  0.09%] ··· ========= ============== ============ ============ ========= ============== ============ ============= =========
              --                                                 n_layers / interface
              --------- ------------------------------------------------------------------------------------------------------
               n_wires   3 / autograd     3 / tf     3 / torch    3 / jax   6 / autograd     6 / tf      6 / torch    6 / jax
              ========= ============== ============ ============ ========= ============== ============ ============= =========
                  2       5.11±0.1ms     10.4±2ms    4.77±0.1ms    failed    7.45±0.3ms    14.2±0.2ms   7.14±0.05ms    failed
                  5       10.2±0.2ms    18.6±0.5ms   9.76±0.3ms    failed    17.2±0.3ms    31.7±0.3ms    16.6±0.6ms    failed
              ========= ============== ============ ============ ========= ============== ============ ============= =========

And for the most recent commit to this PR:

[  0.00%] · For pennylane commit 7a26a6ce <adjoint_diff_improvements>:
[  0.00%] ·· Benchmarking virtualenv-py3.8-Qulacs-jax-jaxlib-networkx-qsimcirq-tensorflow-torch
[  0.05%] ··· Running (asv.core_suite.GradientComputation_light.time_gradient--).
[  0.09%] ··· asv.core_suite.GradientComputation_light.time_gradient                                                                                       3/16 failed
[  0.09%] ··· ========= ============== ============ ============ ========= ============== ============ ============ =========
              --                                                 n_layers / interface
              --------- -----------------------------------------------------------------------------------------------------
               n_wires   3 / autograd     3 / tf     3 / torch    3 / jax   6 / autograd     6 / tf     6 / torch    6 / jax
              ========= ============== ============ ============ ========= ============== ============ ============ =========
                  2      4.35±0.05ms    9.00±0.3ms    4.98±1ms    235±2ms    6.76±0.3ms     14.5±3ms    6.74±0.1ms    failed
                  5      8.63±0.03ms    17.0±0.6ms   9.82±0.2ms    failed    15.2±0.5ms    29.7±0.7ms   16.1±0.4ms    failed
              ========= ============== ============ ============ ========= ============== ============ ============ =========

@albi3ro albi3ro requested a review from trbromley June 1, 2021 16:37
@albi3ro
Contributor Author

albi3ro commented Jun 4, 2021

@trbromley For the autograd and jax derivatives, we know that we just performed a forward pass, so the device is in the correct state.

For the tensorflow and torch interfaces, we can perform the forward pass, then go do other stuff, maybe perform a forward pass with different parameters, and only then perform the backward pass, where we call dev.adjoint_jacobian(qtape). So we cannot use the current device state, since other things may have occurred on the device since the forward pass.

Caching the device state means keeping around a copy of the statevector, which may come with some memory overhead. At some point we may switch to having this on by default, but I don't want to do so just yet given the memory considerations.
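
To make the caching idea concrete, here is a schematic sketch of the pattern being described; this is not the actual interface code, and the wrapper class and method names are purely illustrative:

    class AdjointWithCachedState:
        """Illustrative only: keep a copy of the statevector from the forward pass so that
        later executions on the device cannot corrupt the backward pass."""

        def __init__(self, device, tape):
            self.device = device
            self.tape = tape
            self._cached_state = None

        def forward(self):
            results = self.device.execute(self.tape)
            # copy now: the device may run other circuits before backward() is called
            self._cached_state = self.device.state.copy()
            return results

        def backward(self):
            # hand the cached copy to adjoint_jacobian instead of the live device state
            return self.device.adjoint_jacobian(self.tape, starting_state=self._cached_state)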

Contributor

@dwierichs dwierichs left a comment

Very cool improvements; I just left a few comments.
Just to make sure, as my experience with the interfaces is still quite limited: is there absolutely no way one could break the caching with intermediate calls to, or changes of, the device/QNode state, in the sense of recycling the wrong state?

Comment on lines +847 to +848
if starting_state is not None:
    ket = self._reshape(starting_state, [2] * self.num_wires)
Contributor

We could also check whether the array is flat and do

if starting_state is not None:
    if starting_state.shape != (2,) * self.num_wires:
        ket = self._reshape(starting_state, [2] * self.num_wires)
    else:
        ket = starting_state

or is this an unnecessary micro-optimization?

if starting_state is not None:
    ket = self._reshape(starting_state, [2] * self.num_wires)
elif use_device_state:
    ket = self._reshape(self.state, [2] * self.num_wires)
Contributor

Can't we use self._state here and skip the reshaping?

Contributor Author

Substituting for self._prerotated_state

else:
    self.reset()
    self.execute(tape)
    ket = self._reshape(self.state, [2] * self.num_wires)
Contributor

Here as well.
One could "reduce code length" slightly by doing

if starting_state is not None:
    ...
else:
    if not use_device_state:
        self.reset()
        self.execute(tape)
    ket = self._reshape(self.state, [2] * self.num_wires)

Comment on lines 896 to 897
for kk, bra_ in enumerate(bras):
    jac[kk, trainable_param_number] = 2 * dot_product_real(bra_, ket_temp)
Contributor

Could we do this with a single dot product across the 0th axis of bras?

jac[:, trainable_param_number] = 2 * np.real(np.tensordot(np.conj(bras), ket_temp, axes=self.num_wires))

Contributor Author

bras is a list of arrays rather than an array with an extra dimension. I had thought about making it an array with an additional leading dimension; I'll go ahead and try to make that work.

Contributor

Ah yes, sorry, I overlooked that! The vectorized multiplication should save more than the conversion to an array costs, I'd hope :)
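
For concreteness, a small sketch of the vectorized contraction once bras is stacked into a single array of shape (n_obs,) + (2,) * num_wires; the stand-in data below is an assumption for illustration, not the merged code:

    import numpy as np

    num_wires, n_obs = 3, 2
    shape = (2,) * num_wires

    # stand-in complex arrays playing the role of the stacked bras and the current ket
    bras = np.random.rand(n_obs, *shape) + 1j * np.random.rand(n_obs, *shape)
    ket_temp = np.random.rand(*shape) + 1j * np.random.rand(*shape)

    # contract over all wire axes at once; the result has one entry per observable
    jac_column = 2 * np.real(np.tensordot(np.conj(bras), ket_temp, axes=num_wires))
    assert jac_column.shape == (n_obs,)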

@@ -174,6 +179,9 @@ def _evaluate_grad_matrix(grad_matrix_fn):
if grad_matrix_fn in saved_grad_matrices:
    return saved_grad_matrices[grad_matrix_fn]

if self.jacobian_options.get("cache_state", False):
    self.jacobian_options["device_pd_options"] = {"starting_state": state}
Contributor

Maybe this is fantasizing about very strange use cases, but could it also make sense to cache the output of self.jacobian_options.get("cache_state", False)? That would prevent a later change of the keyword argument, made for a new forward pass, from making the device forget that caching was actually activated for an earlier execution whose state is still lying around locally.

@albi3ro
Contributor Author

albi3ro commented Jun 7, 2021

[  0.00%] · For pennylane commit d8b6011d <adjoint_diff_improvements>:
[  0.00%] ·· Benchmarking virtualenv-py3.8-Qulacs-jax-jaxlib-networkx-qsimcirq-tensorflow-torch
[  0.04%] ··· Running (asv.core_suite.GradientComputation_light.time_gradient--).
[  0.09%] ··· ...uite.GradientComputation_light.time_gradient        4/16 failed
[  0.09%] ··· ========= ========== =========== =============
               n_wires   n_layers   interface
              --------- ---------- ----------- -------------
                  2         3        autograd   4.49±0.05ms
                  2         3           tf       8.73±0.1ms
                  2         3         torch      4.42±0.3ms
                  2         3          jax         failed
                  2         6        autograd    6.48±0.1ms
                  2         6           tf       13.3±0.2ms
                  2         6         torch     6.25±0.05ms
                  2         6          jax         failed
                  5         3        autograd   8.90±0.08ms
                  5         3           tf       17.0±0.2ms
                  5         3         torch      8.39±0.1ms
                  5         3          jax         failed
                  5         6        autograd    15.3±0.2ms
                  5         6           tf       29.1±0.4ms
                  5         6         torch      14.3±0.2ms
                  5         6          jax         failed
              ========= ========== =========== =============


[  0.00%] · For pennylane commit 0dcbed53 <master>:
[  0.00%] ·· Benchmarking virtualenv-py3.8-Qulacs-jax-jaxlib-networkx-qsimcirq-tensorflow-torch
[  0.04%] ··· Running (asv.core_suite.GradientComputation_light.time_gradient--).
[  0.09%] ··· ...uite.GradientComputation_light.time_gradient        3/16 failed
[  0.09%] ··· ========= ========== =========== =============
               n_wires   n_layers   interface
              --------- ---------- ----------- -------------
                  2         3        autograd   4.81±0.05ms
                  2         3           tf       9.14±0.1ms
                  2         3         torch     4.60±0.04ms
                  2         3          jax         failed
                  2         6        autograd    7.58±0.3ms
                  2         6           tf       13.9±0.2ms
                  2         6         torch      6.98±0.3ms
                  2         6          jax         failed
                  5         3        autograd   10.1±0.09ms
                  5         3           tf       18.1±0.3ms
                  5         3         torch      9.81±0.3ms
                  5         3          jax        682±8ms
                  5         6        autograd    17.6±0.3ms
                  5         6           tf       31.7±0.1ms
                  5         6         torch     16.4±0.05ms
                  5         6          jax         failed
              ========= ========== =========== =============

I specified cache_state=True in the benchmark. Looks like all three testable interfaces show an improvement!

@albi3ro albi3ro requested a review from trbromley June 7, 2021 19:10
Contributor

@trbromley trbromley left a comment

Thanks @albi3ro! The code is on point and I am happy to give a ✔️, but I just wanted to discuss the question of cache_state a bit more. Concretely, my suggestion would be to:

  • Make cache_state only cache when using the adjoint diff method
  • Have it be on by default
  • Rename to adjoint_cache or similar

Would be great to get your thoughts though!

use_device_state (bool): use current device state to initialize. A forward pass of the same
    circuit should be the last thing the device has executed. If a ``starting_state`` is
    provided, that takes precedence.
return_obs (bool): return the expectation values alongside the jacobian as a tuple (jac, obs)
Contributor

Suggested change
return_obs (bool): return the expectation values alongside the jacobian as a tuple (jac, obs)
return_obs (bool): return the expectation values alongside the jacobian as a tuple ``(jac, obs)``

Dimensions are ``(len(observables), len(trainable_params))``.
Union[array, tuple(array)]: the derivative of the tape with respect to trainable parameters.
Dimensions are ``(len(observables), len(trainable_params))``. If ``return_obs`` keyword is True,
then returns ``jacobian, expectation_values``
Contributor

Suggested change
then returns ``jacobian, expectation_values``
then returns ``jacobian, expectation_values``.

Comment on lines 868 to 875
n_obs = len(tape.observables)
bras = np.empty([n_obs] + [2] * self.num_wires, dtype=np.complex128)
for kk in range(n_obs):
    bras[kk, ...] = self._apply_operation(ket, tape.observables[kk])

lambdas = [self._apply_operation(phi, obs) for obs in tape.observables]
# this can probably be more optimized, but at least now it works...
if return_obs:
    expectation = dot_product_real(bras, ket)
Contributor

Are the changes from phi and lambdas to ket and bras for readability?

Contributor Author

Yep

Comment on lines 126 to 128
cache_state=False (bool): for tensorflow and torch interfaces and adjoint differentiation,
    this indicates whether to save the device state after the forward pass. Doing so saves a
    forward execution. Device state automatically reused with autograd and jax interfaces.
Contributor

🤔 I was on the fence about adding cache_state as a keyword argument, feeling that it might pollute the keyword arguments of the QNode for a fairly specific use case. However, I can agree that it might be nice to have the option to turn this feature on or off due to memory limitations. One suggestion would be to rename it to something more specific, e.g., adjoint_cache, since cache_state might give the impression of something more general than it is.

I'd also suggest making cache_state=True by default. If it's off by default, I think most users will not notice and not benefit from the feature, so all the hard work in this PR will not be quite so beneficial. Any users experiencing memory issues can always turn off caching if they have a problem.

Comment on lines 126 to 128
cache_state=False (bool): for tensorflow and torch interfaces and adjoint differentiation,
    this indicates whether to save the device state after the forward pass. Doing so saves a
    forward execution. Device state automatically reused with autograd and jax interfaces.
Contributor

Actually, now I look again, when cache_state=True, we cache the state in the Torch and TF interfaces regardless of the diff_method? 🤔

Shouldn't we have it so that caching only occurs if diff_method="adjoint"? We can expand to further use cases going forward.

In other words:

  • Currently: off by default, but when on, caching occurs in the TF and Torch interfaces regardless of the diff method.
  • My suggestion: on by default, with caching occurring in the TF and Torch interfaces only when the diff method is adjoint.

I understand the caution behind off-by-default in the former, but I think it's worth following a smart-defaults philosophy.

Comment on lines +793 to +794
Also makes sure executing a second circuit before backward pass does not interfere
with answer.
Contributor

Nice to test this edge case!
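
For readers skimming the thread, a hedged sketch of the edge case being tested (not the actual test code): run a second circuit between the forward and backward passes and check that the gradient still corresponds to the first call's parameter.

    import torch
    import pennylane as qml

    dev = qml.device("default.qubit", wires=1)

    @qml.qnode(dev, interface="torch", diff_method="adjoint")
    def circuit(x):
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(0))

    x = torch.tensor(0.3, requires_grad=True)
    res = circuit(x)

    # a second execution overwrites the live device state before the backward pass
    circuit(torch.tensor(1.2))

    res.sum().backward()  # .sum() guards against a 1-element output shape
    # d<Z>/dx for RX(x) acting on |0> is -sin(x)
    assert torch.isclose(x.grad, -torch.sin(torch.tensor(0.3)))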

@@ -192,6 +192,64 @@ def test_gradient_gate_with_multiple_parameters(self, tol, dev):
        # the different methods agree
        assert np.allclose(grad_D, grad_F, atol=tol, rtol=0)

    def test_return_expectation(self, tol, dev):
Contributor

I'm still not sure about the addition of return_obs - do you have a future use-case in mind?

Contributor Author

Easy enough to add in later if a use case arises, so I'll remove it for now.

@albi3ro
Contributor Author

albi3ro commented Jun 9, 2021

@trbromley Good catch on realizing that tensorflow and torch would be caching the state even when not using adjoint diff!

Now the state gets cached by default, but only if using the adjoint diff method.

I've renamed the keyword to adjoint_cache.

I've also removed return_obs. It may be useful at some point in the future, but I can easily add it back in when a use case arises.
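
For reference, a minimal usage sketch of the renamed keyword under these defaults; the decorator form here is an illustration, not code from the PR:

    import pennylane as qml

    dev = qml.device("default.qubit", wires=2)

    # adjoint_cache defaults to True when diff_method="adjoint";
    # pass False to trade the saved forward pass for a lower memory footprint
    @qml.qnode(dev, interface="tf", diff_method="adjoint", adjoint_cache=False)
    def circuit(x):
        qml.RX(x, wires=0)
        qml.CNOT(wires=[0, 1])
        return qml.expval(qml.PauliZ(1))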

@albi3ro albi3ro requested a review from trbromley June 9, 2021 14:23
Contributor

@trbromley trbromley left a comment

Thanks @albi3ro, awesome!

@albi3ro albi3ro merged commit 9a52aa1 into master Jun 9, 2021
@albi3ro albi3ro deleted the adjoint_diff_improvements branch June 9, 2021 16:55
Labels
review-ready 👌 PRs which are ready for review by someone from the core team.
4 participants