
Commit

update branch with master (#3871)
* Allow flexible parameter shapes and multi-dimensional tape return shape in custom JVP (#3766)

* custom jvp with any parameter and tape output shape

* tests

* lint test

* remove print

* changelog

* fix test, remove old todo

* tiny

* Apply suggestions from code review

Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>

* move commentary to docstring, rename num_l

* black

* review

* [skip ci]

* string format [skip ci]

* black

* merge fix

* review

---------

Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>
Co-authored-by: Romain Moyard <rmoyard@gmail.com>

* Daily rc sync to master $(date +'%Y-%m-%d') (#3845)

* Undeprecate `ApproxTimeEvolution` (#3797)

* readd ApproxTimeEvolution

* More

* Full deprecation

* Fix (#3801)

* Register `SpecialUnitary` operation with `DefaultMixed` and `NullQubit` (#3651)

* Add SpecialUnitary, its utilities and tests

* changelog

* whitespace

* docstrings

* empty

* device test

* black

* linting and test coverage

* registered in default qubit and added tests

* registered in default mixed and added tests

* black

* changelog

* register with null qubit, add test, lint null qubit test as bonus

* black

* update docstring

* revert some changes/continued merge

* remove old namespace dependencies

* black

* lint

* docstring typo

* lint

* remove inverse testing

* change reference

* revert addition

* add ParametrizedEvolution to mixed device (#3794)

* add ParametrizedEvolution to mixed device

* black test

* make test more broad

* black tests

* Update tests/pulse/test_parametrized_evolution.py

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* sparse implementation

* Revert "Merge branch 'evolvemixed' of https://github.com/PennyLaneAI/pennylane into evolvemixed"

This reverts commit e653b50, reversing
changes made to bb2d3d6.

* Revert "sparse implementation"

This reverts commit bb2d3d6.

* review comment

* trigger ci

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* little touchups while reviewing changes in 0.29 (#3809)

* Update pulse module documentation examples (#3805)

* init

* spacing

* windows parsing in rect()

* revert jnp.array parameters, doc typo

* test docstring, remove interface warning

* doc fix?

* lint

---------

Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com>

* Raise error when using `qml.probs` in the computational basis with non-commuting observables (#3811)

* 🐛 fix (MeasurementProcess): Change if a measurement is using the computational basis.

* 🎨 style (MeasurementProcess): Remove useless properties.

* 🧪 tests (qscript): Fix tests.

* 🧪 tests (measurements): Fix tests.

* 🧪 tests (return_types): Fix tests.

* 🧪 tests (tape): Fix tests.

* 🧪 tests (circuit_graph): Fix tests.

* 🧪 tests (return_types): Fix tests.

* 🧪 tests (return_types): Fix tests.

* ✏️ chore (changelog): Add changelog entry.

* 🔧 refactor (MeasurementProcess): Change method name.

* 🔧 refactor (MeasurementProcess): Change method name.

* 🔧 refactor (MeasurementProcess): Change method name.

* Update pennylane/tape/qscript.py

Co-authored-by: Romain Moyard <rmoyard@gmail.com>

---------

Co-authored-by: Romain Moyard <rmoyard@gmail.com>

* docs: consistent naming of tape instance in docstr (#3795)

as title states. nothing fancy here

* Documentation fixes (#3808)

* fix some type hints

* minor docstring fixes

* minor docstring changes

* fix rendering issue

* Update pennylane/pulse/parametrized_hamiltonian.py

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* Fix docs of new features (#3818)

* ✏️ chore (changelog): Fix changelog entry.

* 📝 docs (BasisRotation): Fix docstrings.

* 📝 docs (xyx_decompositions): Fix docstring.

* 📝 docs (metric_tensor): Fix docstring.

* Minor fixes for v0.29 release (#3815)

* Added fixes for auto documentation

* Added minor doc changes for hadamard gradient

* Added fix for queuing diagonalizing gates

* Small fix to IsingZZ doc

* Fixed doc for max_entropy

---------

Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>

* Copy one-electron integrals inside `_chemist_transform` (#3816)

* fix inplace one-int issue in `_chemist_transform`

* add test to assert no inplace manipulation happens

* happy black

* Fix equal for measurements with observables, measurement process repr, Hermitian casting fix (#3820)

* equal with measurement processes

* not change math.asarray for autograd

* not change math.asarray for autograd

* Update doc/releases/changelog-0.29.0.md

* add coverage for measurements of observables that are equal but have different wire orders

* break up docstrings, test hermitian case

* Update exp.py

fixing docs typo in RC branch

* Docstring fixes for unchanged functions (#3824)

* Added docstring fixes for old functions

* Updated more docs

* Changing functions

* Fixed more docstrings

* Added changes to more function docs

* Fix to build sphinx

* Addressing PR comments

* fix docker link; remove old inv comment (#3823)

* Device api doc typo fixes (#3819)

* tiny doc edits

* Update pennylane/devices/experimental/device_api.py

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* seealso

* Update pennylane/devices/experimental/device_api.py

Co-authored-by: Christina Lee <christina@xanadu.ai>

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>

* un-deprecate make_tape (#3807)

* Update changelog (#3719)

* Update changelog-dev.md

* Remove

* Move pulse

* Update title

* pulse edits

* minor

* Absorb new additions

* Reorder improvements

* Add new section

* Add commas to contributors section

* Remove order requirement

* Rearrange

* Simplify pulse section

* Improve

* Improve

* Improve

* Improve

* Add additional section

* Update doc/releases/changelog-dev.md

* brackets around PR numbers

Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>

* Add suggestions

Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>

* Update doc/releases/changelog-dev.md

* qml.

* removed TODO markers

* deps and breaking changes

* deps and breaking changes

* Move addtion

* Apply suggestions from code review

* Update

* Update

* Update

* Update

* Update

* Update

* Apply suggestions from code review

* Update

* Update

* Update

* Update

* Add for Evolve

* Add

* Apply suggestions from code review

Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>

* Fix code

* minor

* minor

* minor

* minor

* minor

* minor

* minor

* minor

* minor

* minor

* Tom

* Add entry

* Update name

* Update

* Rearrange for auto

* Typo

* Remove ApproxTimeEvolution deprecation

* Typo

* minor

* minor

* Fix warning

* Maybe fix warning

* Add a few entries

* Add breaking change

* Maybe fix error

* Maybe fix warning

* minor

* minor

* add details to changelog for marginal_prob bug (#3821)

* Add subsection

* Update

* Update

* Add

---------

Co-authored-by: Isaac De Vlugt <isaacdevlugt@gmail.com>
Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* Fix sign expand (#3814)

* Tape queue to script

* Fix doc

* QNode example

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* use new pauli module

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>

* Fix (#3828)

* Remove warning. (#3829)

* Update deprecations page (#3830)

* Update grouping

* Update seed recipes

* Update QubitDevice.statistics

* Raise error if JAX or jaxlib are 0.4.4 (#3813)

* Raise error for 0.4.4 jaxlib and jax

* add tests

* black for tests

* add jax mark for test

* Remove jaxlib check

Co-authored-by: Filippo Vicentini <filippovicentini@gmail.com>

* More detailed error msg

Co-authored-by: Filippo Vicentini <filippovicentini@gmail.com>

* Update tests

* Tests and black

* Typo

* Change

---------

Co-authored-by: Filippo Vicentini <filippovicentini@gmail.com>
Co-authored-by: Romain Moyard <rmoyard@gmail.com>

* 0.29 changelog updates (#3835)

* Update codeblock type

* Update example

* Wording

* Add CTA

* Update wire number

* Remove wire

* Update wording

* Add

* Fix imports

* Add deprecation

* Remove tilde

* Merge data and other improvements

* Typo

* Typo

* Typo

* Emojis

* Remove whitespace

* remove duplicate entries

* Apply suggestions from code review

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* Docfix/circuits note gpu (#3836)

* Fix note on lightning.gpu plugin in intro/circuits.

* Update doc/releases/changelog-0.29.0.md.

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>

* Promote warning about vanilla version of NumPy to RST warning. (#3838)

* Promote warning about vanilla version of NumPy from bold format to RST warning. Make Autograd case uniform.

* Update changelog.

* fix strings that wrap around without a space (#3837)

* fix strings that wrap around without a space

* missed one

* always use trailing whitespace instead of leading

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>

* Use default.qubit.jax instead of default.qubit (#3839)

* Use default.qubit.jax instead of default.qubit

* Apply suggestions from code review

Taking the liberty to commit these changes here due to time sensitivity

* Revert "Apply suggestions from code review
"

This reverts commit d457ffb.

---------

Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Co-authored-by: Korbinian Kottmann <Korbinian.Kottmann@gmail.com>

* Update header images (#3827)

Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>

* exclude files from pr

---------

Co-authored-by: Romain Moyard <rmoyard@gmail.com>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com>
Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Utkarsh <utkarshazad98@gmail.com>
Co-authored-by: Nathan Killoran <co9olguy@users.noreply.github.com>
Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Isaac De Vlugt <isaacdevlugt@gmail.com>
Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>
Co-authored-by: Filippo Vicentini <filippovicentini@gmail.com>
Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com>
Co-authored-by: Korbinian Kottmann <Korbinian.Kottmann@gmail.com>
Co-authored-by: GitHub Actions Bot <>

* remove broken date from RC sync PR action (#3849)

I'm not sure what made this break, but it's broken. We can just get the date from when it was opened if we're curious!

Co-authored-by: Josh Izaac <josh146@gmail.com>

* ensure provided order of dataset parameters is preserved (#3856)

* ensure provided order of dataset parameters is preserved

* add another test to ensure order stability

* add another utest case without full

* changelog entry

* Fix autoray due to newly released version (#3865)

* autoray fixes

* Update pennylane/ops/qubit/parametric_ops.py

* changelog

* Update doc/releases/changelog-dev.md

* Update doc/releases/changelog-dev.md

Co-authored-by: Romain Moyard <rmoyard@gmail.com>

* use autograd, change test back to master version

---------

Co-authored-by: Romain Moyard <rmoyard@gmail.com>

* Update the definition of trainable params with Jax: right trainable parameters for JIT (#3697)

* First version

* Change params

* Changes

* ji_compute_jvp

* Working version for new return

* Change to old return system

* Update tests and jax-jit int

* Update tests

* Fix parameter shift hessian

* Update comment

* black

* Apply suggestions from code review

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Pull and black

* Update pennylane/interfaces/jax_jit_tuple.py

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Argnums in Jax transforms (#3750)

* First draft

* Fix merge

* Black and pylint

* Update

* Classical jacobians tests

* Rename argnum to argnums

* Finite diff and param shift with tests

* Typo

* Second derivative

* Update

* Test torch

* Trigger CI

* Hessian transforms

* Typo tests

* Too many branches

* Argnum renamed

* Update pennylane/fourier/qnode_spectrum.py

* Apply suggestions from code review

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Apply suggestions from code review

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* From review

* Merge introduced typo

* Typo

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Tests new return type

* Argnums = [0] is default

* Typo

* Add tests hadamard

* Changelog

* Update doc/development/deprecations.rst

* Apply suggestions from code review

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* black

* Update doc/releases/changelog-dev.md

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Update doc/development/deprecations.rst

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Update doc/development/deprecations.rst

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Update from review

* Review change

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* fix jax version in upload workflow (#3843)

Co-authored-by: Romain Moyard <rmoyard@gmail.com>
Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>
Co-authored-by: Romain Moyard <rmoyard@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Utkarsh <utkarshazad98@gmail.com>
Co-authored-by: Nathan Killoran <co9olguy@users.noreply.github.com>
Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Isaac De Vlugt <isaacdevlugt@gmail.com>
Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>
Co-authored-by: Filippo Vicentini <filippovicentini@gmail.com>
Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com>
Co-authored-by: Korbinian Kottmann <Korbinian.Kottmann@gmail.com>
Co-authored-by: Josh Izaac <josh146@gmail.com>
19 people authored Mar 6, 2023
1 parent 7907be3 commit 16cb926
Showing 45 changed files with 1,577 additions and 490 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rc_sync.yml
@@ -70,7 +70,7 @@ jobs:
source_branch: "${{ env.tmp_branch }}"
destination_branch: "master"
github_token: "${{ secrets.GITHUB_TOKEN }}"
pr_title: "Daily rc sync to master $(date +'%Y-%m-%d')"
pr_title: "Daily rc sync to master"
pr_body: "Automatic sync from the release candidate to master during a feature freeze."
pr_allow_empty: false
pr_draft: false
9 changes: 5 additions & 4 deletions .github/workflows/upload.yml
@@ -4,6 +4,7 @@ on:
types: [published]

env:
JAX_VERSION: 0.4.3
TF_VERSION: 2.10.0
TORCH_VERSION: 1.13.0

@@ -67,7 +68,7 @@ jobs:
# to the latest release. We can always fix a version later if it breaks.
- name: Conditionally install JAX
if: matrix.config.suite == 'jax'
run: pip3 install jax jaxlib
run: pip3 install jax==$JAX_VERSION jaxlib==$JAX_VERSION

- name: Install PennyLane
run: |
@@ -153,7 +154,7 @@ jobs:
run: pip3 install tensorflow~=$TF_VERSION keras~=$TF_VERSION

- name: Install JAX
run: pip3 install jax jaxlib
run: pip3 install jax==$JAX_VERSION jaxlib==$JAX_VERSION

- name: Install KaHyPar
run: pip3 install kahypar==1.1.7
@@ -290,7 +291,7 @@ jobs:
run: pip3 install tensorflow~=$TF_VERSION keras~=$TF_VERSION

- name: Install JAX
run: pip3 install jax jaxlib
run: pip3 install jax==$JAX_VERSION jaxlib==$JAX_VERSION

- name: Install PennyLane
run: |
@@ -344,7 +345,7 @@ jobs:

- name: Conditionally install Jax
if: contains(matrix.config.device, 'jax')
run: pip3 install jax jaxlib
run: pip3 install jax==$JAX_VERSION jaxlib==$JAX_VERSION

- name: Install PennyLane
run: |
Binary file modified doc/_static/header-dark-mode.png
Binary file modified doc/_static/header.png
6 changes: 6 additions & 0 deletions doc/development/deprecations.rst
@@ -6,6 +6,12 @@ Deprecations
Pending deprecations
--------------------

* The argument ``argnum`` for gradient transforms using the JAX interface is replaced by ``argnums``.

  - In v0.30, ``argnum`` is automatically converted to ``argnums`` for gradient transforms using JAX, and a warning is raised
  - In v0.31, ``argnums`` becomes the only option for gradient transforms using JAX


* The ``get_operation`` tape method is updated to return the operation index as well, changing its signature.

- The new signature is available by changing the arg ``return_op_index`` to ``True`` in v0.29
6 changes: 3 additions & 3 deletions doc/releases/changelog-0.29.0.md
@@ -52,11 +52,11 @@
return qml.expval(qml.PauliX(0) @ qml.PauliY(1))
```

Pulse-based circuits can be executed and differentiated on the `default.qubit` simulator using JAX
as an interface:
Pulse-based circuits can be executed and differentiated on the `default.qubit.jax` simulator using
JAX as an interface:

```pycon
>>> dev = qml.device("default.qubit", wires=2)
>>> dev = qml.device("default.qubit.jax", wires=2)
>>> qnode = qml.QNode(pulse_circuit, dev, interface="jax")
>>> params = (p1, p2)
>>> qnode(params, time=0.5)
25 changes: 25 additions & 0 deletions doc/releases/changelog-dev.md
@@ -14,15 +14,31 @@

<h3>Improvements</h3>

* The custom JVP rules in PennyLane now also support non-scalar and mixed-shape tape parameters, as
  well as multi-dimensional tape return types, such as broadcasted `qml.probs`.
[(#3766)](https://github.com/PennyLaneAI/pennylane/pull/3766)

* The `qchem.jordan_wigner` function is extended to support more fermionic operator orders.
[(#3754)](https://github.com/PennyLaneAI/pennylane/pull/3754)
[(#3751)](https://github.com/PennyLaneAI/pennylane/pull/3751)

* `AdaptiveOptimizer` is updated to use non-default user-defined QNode arguments.
[(#3765)](https://github.com/PennyLaneAI/pennylane/pull/3765)

* When using JAX-JIT with gradient transforms, the trainable parameters are now set correctly
  (instead of every parameter being marked as trainable), and therefore the derivatives are computed
  more efficiently.
[(#3697)](https://github.com/PennyLaneAI/pennylane/pull/3697)

<h3>Breaking changes</h3>

* Trainable parameters for the JAX interface are now only the parameters that are `JVPTracer` objects,
  as defined by setting `argnums`. Previously, all JAX tracers, including those used for JIT
  compilation, were interpreted to be trainable.
[(#3697)](https://github.com/PennyLaneAI/pennylane/pull/3697)

* The keyword argument `argnums` is now used for gradient transforms with the JAX interface, instead
  of `argnum`. `argnum` is automatically converted to `argnums` when using JAX, and will no longer be
  supported in v0.31.
[(#3697)](https://github.com/PennyLaneAI/pennylane/pull/3697)

<h3>Deprecations</h3>

<h3>Documentation</h3>
@@ -32,13 +48,22 @@

<h3>Bug fixes</h3>

* Registers `math.ndim` and `math.shape` for built-ins and autograd to accommodate Autoray 0.6.1.
  [(#3865)](https://github.com/PennyLaneAI/pennylane/pull/3865)

* Ensure that `qml.data.load` returns datasets in a stable and expected order.
[(#3856)](https://github.com/PennyLaneAI/pennylane/pull/3856)

<h3>Contributors</h3>

This release contains contributions from (in alphabetical order):

Utkarsh Azad
Soran Jahangiri
Christina Lee
Vincent Michaud-Rioux
Romain Moyard
Mudit Pandey
Matthew Silverman
Jay Soni
David Wierichs
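
The Autoray bug fix listed in the changelog above registers `ndim` and `shape` handlers for built-in
Python types. As a rough, stdlib-only illustration of what such handlers compute (this is a sketch,
not the actual PennyLane/autoray code):

```python
def shape(x):
    """Shape of a (possibly nested) list or tuple, or () for scalars --
    roughly the behaviour that registering built-ins provides."""
    if isinstance(x, (list, tuple)):
        inner = shape(x[0]) if x else ()
        return (len(x),) + inner
    return ()


def ndim(x):
    """Number of dimensions, derived from the shape."""
    return len(shape(x))


print(shape([[1, 2, 3], [4, 5, 6]]))  # (2, 3)
print(ndim(2.0))                      # 0
```

Registering dispatch functions like these lets generic numeric code ask for the shape of a plain
nested list the same way it would for a NumPy or JAX array.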
6 changes: 5 additions & 1 deletion pennylane/data/data_manager.py
@@ -197,7 +197,11 @@ def _generate_folders(node, folders):
"""

next_folders = folders[1:]
folders = set(node) if folders[0] == ["full"] else set(folders[0]).intersection(set(node))
if folders[0] == ["full"]:
folders = node
else:
values_for_this_node = set(folders[0]).intersection(set(node))
folders = [f for f in folders[0] if f in values_for_this_node]
return (
[
os.path.join(folder, child)
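
The `_generate_folders` change above replaces a set intersection, whose iteration order is arbitrary,
with a list comprehension that keeps the order in which the parameters were requested. A
self-contained sketch of the fixed behaviour (the function name and sample values are illustrative,
not PennyLane's API):

```python
def filter_preserving_order(requested, available):
    """Return the requested values that also exist in ``available``,
    keeping the caller's ordering (the fixed behaviour)."""
    available_set = set(available)
    return [r for r in requested if r in available_set]


# A plain set intersection loses the caller's ordering (set iteration
# order is arbitrary), which is the bug the list comprehension avoids.
requested = ["H2", "LiH", "HeH+"]
available = ["HeH+", "H2", "H2O"]
print(filter_preserving_order(requested, available))  # ['H2', 'HeH+']
```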
53 changes: 43 additions & 10 deletions pennylane/gradients/gradient_transform.py
@@ -273,23 +273,44 @@ def __init__(
self.hybrid = hybrid
super().__init__(transform_fn, expand_fn=expand_fn, differentiable=differentiable)

def default_qnode_wrapper(self, qnode, targs, tkwargs):
def default_qnode_wrapper(self, qnode, targs, tkwargs): # pylint: disable=too-many-statements
# Here, we overwrite the QNode execution wrapper in order
# to take into account that classical processing may be present
# inside the QNode.
hybrid = tkwargs.pop("hybrid", self.hybrid)
argnums = tkwargs.pop("argnums", None)

_wrapper = super().default_qnode_wrapper(qnode, targs, tkwargs)

cjac_fn = qml.transforms.classical_jacobian(
qnode, argnum=argnums, expand_fn=expand_invalid_trainable
)

def jacobian_wrapper(
*args, **kwargs
): # pylint: disable=too-many-return-statements, too-many-branches
if not qml.math.get_trainable_indices(args):
): # pylint: disable=too-many-return-statements, too-many-branches, too-many-statements
argnum = tkwargs.get("argnum", None)
argnums = tkwargs.get("argnums", None)

interface = qml.math.get_interface(*args)
trainable_params = qml.math.get_trainable_indices(args)

if interface == "jax" and argnum:
warnings.warn(
"argnum is deprecated with the Jax interface. You should use argnums instead."
)
tkwargs.pop("argnum")
argnums = argnum

argnums_ = None
if interface == "jax" and not trainable_params:
if argnums is None:
argnums_ = [0]

else:
argnums_ = [argnums] if isinstance(argnums, int) else argnums

params = qml.math.jax_argnums_to_tape_trainable(
qnode, argnums_, self.expand_fn, args, kwargs
)
argnums_ = qml.math.get_trainable_indices(params)
kwargs["argnums"] = argnums_

if not trainable_params and argnums is None and argnums_ is None:
warnings.warn(
"Attempted to compute the gradient of a QNode with no trainable parameters. "
"If this is unintended, please add trainable parameters in accordance with "
@@ -303,7 +324,17 @@ def jacobian_wrapper(
return qjac

kwargs.pop("shots", False)
cjac = cjac_fn(*args, **kwargs)

# Special case where we apply a Jax transform (jacobian e.g.) on the gradient transform and argnums are
# defined on the outer transform and therefore on the args.
if interface == "jax":
argnum_cjac = trainable_params or argnums
else:
argnum_cjac = None

cjac = qml.transforms.classical_jacobian(
qnode, argnum=argnum_cjac, expand_fn=self.expand_fn
)(*args, **kwargs)

if qml.active_return():
if isinstance(cjac, tuple) and len(cjac) == 1:
@@ -373,6 +404,8 @@ def jacobian_wrapper(
jacs = tuple(
qml.math.tensordot(qjac, c, [[-1], [0]]) for c in cjac if c is not None
)
if len(jacs) == 1:
return jacs[0]
return jacs

is_square = cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1]
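
The `argnum`-to-`argnums` handling in `jacobian_wrapper` above follows a standard
deprecation-forwarding pattern: warn, forward the old keyword to the new one, then normalize. A
minimal stdlib-only sketch of that pattern (illustrative only; it borrows the keyword names but does
not differentiate anything):

```python
import warnings


def jacobian(argnum=None, argnums=None):
    """Toy entry point demonstrating argnum -> argnums forwarding."""
    if argnum is not None:
        warnings.warn(
            "argnum is deprecated with the JAX interface. "
            "You should use argnums instead."
        )
        if argnums is None:
            argnums = argnum
    # Normalize: a single integer becomes a one-element list, and the
    # default is to differentiate the first argument.
    if argnums is None:
        argnums_ = [0]
    else:
        argnums_ = [argnums] if isinstance(argnums, int) else argnums
    return argnums_
```

Calling `jacobian(argnum=1)` emits the deprecation warning and then behaves exactly like
`jacobian(argnums=1)`.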
27 changes: 20 additions & 7 deletions pennylane/gradients/hessian_transform.py
@@ -157,17 +157,24 @@ def default_qnode_wrapper(self, qnode, targs, tkwargs):
# Here, we overwrite the QNode execution wrapper in order to take into account
# that classical processing may be present inside the QNode.
hybrid = tkwargs.pop("hybrid", self.hybrid)
argnums = tkwargs.get("argnums", None)

old_interface = qnode.interface

if old_interface == "auto":
qnode.interface = qml.math.get_interface(*targs, *list(tkwargs.values()))

_wrapper = super().default_qnode_wrapper(qnode, targs, tkwargs)
cjac_fn = qml.transforms.classical_jacobian(qnode, expand_fn=self.expand_fn)
cjac_fn = qml.transforms.classical_jacobian(qnode, argnum=argnums, expand_fn=self.expand_fn)

def hessian_wrapper(*args, **kwargs): # pylint: disable=too-many-branches
if argnums is not None:
argnums_ = [argnums] if isinstance(argnums, int) else argnums

def hessian_wrapper(*args, **kwargs):
if not qml.math.get_trainable_indices(args):
params = qml.math.jax_argnums_to_tape_trainable(
qnode, argnums_, self.expand_fn, args, kwargs
)
argnums_ = qml.math.get_trainable_indices(params)
kwargs["argnums"] = argnums_

if not qml.math.get_trainable_indices(args) and not argnums:
warnings.warn(
"Attempted to compute the hessian of a QNode with no trainable parameters. "
"If this is unintended, please add trainable parameters in accordance with "
@@ -187,7 +194,13 @@ def hessian_wrapper(*args, **kwargs):
qhess = (qhess,)

kwargs.pop("shots", False)
cjac = cjac_fn(*args, **kwargs)

if argnums is None and qml.math.get_interface(*args) == "jax":
cjac = qml.transforms.classical_jacobian(
qnode, argnum=qml.math.get_trainable_indices(args), expand_fn=self.expand_fn
)(*args, **kwargs)
else:
cjac = cjac_fn(*args, **kwargs)

has_single_arg = False
if not isinstance(cjac, tuple):