Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable piece-wise-constant functions #3645

Merged
merged 76 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from 61 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
107b132
Initial draft of time dependent hamiltonian
lillian542 Jan 9, 2023
b0ab45d
Allow creation of TDHamiltonian by multiplication of fn and Observable
lillian542 Jan 9, 2023
ffd3d19
Import TDHamiltonian as qml.ops.TDHamiltonian
lillian542 Jan 9, 2023
1f27249
Remove top-level import due to circular imports
lillian542 Jan 9, 2023
7375525
Reorganize H_drift and H_ts
lillian542 Jan 9, 2023
f2bcafe
Add docstrings
lillian542 Jan 9, 2023
2f6aa9c
Fix addition bug for Tensor and Observable
lillian542 Jan 9, 2023
1d93837
Update docstring
lillian542 Jan 10, 2023
15bcc0c
Rename TDHamiltonian to ParametrizedHamiltonian
lillian542 Jan 10, 2023
ba483aa
Rename file parametrized_hamiltonian.py
lillian542 Jan 10, 2023
90253be
Calling H(params, t) returns Operator instead of matrix
lillian542 Jan 10, 2023
63e63e1
Remove inheritance from Observable
lillian542 Jan 10, 2023
51b620d
Change variable names and docstring comments to reflect switch from T…
lillian542 Jan 10, 2023
ab95f94
Docstring example
lillian542 Jan 10, 2023
cd94471
Move from qubit module to math_op module
lillian542 Jan 10, 2023
d20726b
Fix bug when calling ParametrizedHamiltonian if H_fixed is None
lillian542 Jan 10, 2023
daa4336
Update __add__ method
lillian542 Jan 10, 2023
36b6b29
Add tests
lillian542 Jan 10, 2023
faa7451
Update tests_passing_pylint
lillian542 Jan 10, 2023
2aa0b33
update tests for pylint
lillian542 Jan 10, 2023
87904ae
Merge branch 'master' into time_dependent_hamiltonian
AlbertMitjans Jan 11, 2023
0061dcc
Switch from isfunction to callable in Observable.__mul__
lillian542 Jan 11, 2023
e55d15e
Return 0 instead of None if _get_terms is empty
lillian542 Jan 11, 2023
5896c77
Apply docstring suggestions from code review
lillian542 Jan 11, 2023
50e1add
Clean up based on code review suggestions
lillian542 Jan 11, 2023
78e2c99
Remove assumption that ops are Observables
lillian542 Jan 11, 2023
38aca1c
Switch from pH.H_fixed to pH.H_fixed()
lillian542 Jan 11, 2023
f972765
Support addition as Operator+ParametrizedHamiltonian
lillian542 Jan 11, 2023
1f20492
Support creating ParametrizedHamiltonian via qml.ops.dot
lillian542 Jan 11, 2023
607240c
Test for qutrit ParametrizedHamiltonian
lillian542 Jan 11, 2023
7991645
Add wires argument to __call__
lillian542 Jan 11, 2023
e97381a
Merge branch 'master' into time_dependent_hamiltonian
AlbertMitjans Jan 12, 2023
9a6a74a
Incorporate code review suggestions
lillian542 Jan 12, 2023
aefdd42
Add pwc_from_array
lillian542 Jan 16, 2023
973e9ff
Add pwc_from_function
lillian542 Jan 16, 2023
4ce9e94
Merge branch 'time_dependent_hamiltonian' into pwc_functions
lillian542 Jan 16, 2023
467afed
Examples in docstrings
lillian542 Jan 16, 2023
b8625ae
Change call signature on pwc_from_array to (dt, index)
lillian542 Jan 17, 2023
ee2586b
Change call signature on pwc_from_smooth to (dt, num_bins)
lillian542 Jan 17, 2023
227af84
Deal with jax.numpy import
lillian542 Jan 17, 2023
cb647b1
Merge branch 'master' into pwc_functions
AlbertMitjans Jan 18, 2023
f00fa93
Merge branch 'master' into pwc_functions
lillian542 Jan 19, 2023
ca6bae4
Update pennylane/operation.py
lillian542 Jan 19, 2023
fc813ea
Remove unintentional edits
lillian542 Jan 20, 2023
c496006
Remove unintended changes
lillian542 Jan 20, 2023
eb02b90
Switch arg from dt to t
lillian542 Jan 20, 2023
3ff5f14
Merge branch 'master' into pwc_functions
lillian542 Jan 20, 2023
0defffb
Move from math to pulse module
lillian542 Jan 20, 2023
803d1b6
Return 0 outside of time interval
lillian542 Jan 20, 2023
39d4585
Fix bug where first bin is twice as large as subsequent bins and exte…
lillian542 Jan 20, 2023
ac80dc8
Add tests
lillian542 Jan 20, 2023
1376235
Merge branch 'master' into pwc_functions
AlbertMitjans Jan 23, 2023
09a8665
Clarifying change in docstring
lillian542 Jan 23, 2023
f25309e
Merge branch 'pwc_functions' of github.com:PennyLaneAI/pennylane into…
lillian542 Jan 23, 2023
1673aaf
Reorganize tests
lillian542 Jan 23, 2023
240a551
Update tests
lillian542 Jan 23, 2023
9c7a63f
Merge branch 'master' into pwc_functions
lillian542 Jan 23, 2023
50e6026
changelog
lillian542 Jan 24, 2023
c0217e9
Update tests
lillian542 Jan 24, 2023
c90850b
Rename t in pwc to timespan to differentiate from t in (params, t) fr…
lillian542 Jan 24, 2023
aaa813c
Rename t in pwc to timespan to differentiate from t in (params, t) fr…
lillian542 Jan 24, 2023
8a498ec
Reinstate import tests
lillian542 Jan 24, 2023
336b5be
Update docstring
lillian542 Jan 24, 2023
85b0989
Mark as xfail for now
lillian542 Jan 24, 2023
9225e80
Update tests
lillian542 Jan 25, 2023
1baeb7e
Merge branch 'master' into pwc_functions
lillian542 Jan 25, 2023
64c4dfc
test codecov
lillian542 Jan 25, 2023
61183ea
clean up
lillian542 Jan 25, 2023
3bdb11b
Apply suggestions from code review
lillian542 Jan 25, 2023
52277ab
Merge branch 'master' into pwc_functions
lillian542 Jan 25, 2023
e5c8144
Apply suggestions from code review
lillian542 Jan 25, 2023
37db408
Move files pt1
lillian542 Jan 25, 2023
6d6fe7a
Update example
lillian542 Jan 25, 2023
89f67ad
Fix import statement
lillian542 Jan 25, 2023
6a1db1f
Use jnp.concatenate instead of appending to list
lillian542 Jan 25, 2023
efaf832
Merge branch 'master' into pwc_functions
AlbertMitjans Jan 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,53 @@
>>> H = qml.ops.dot(coeffs, ops)
```

* Added `pwc` as a convenience function for defining a `ParametrizedHamiltonian`.
This function can be used to create a callable coefficient by setting
the timespan over which the function should be non-zero. The resulting callable
can be passed an array of parameters and a time.
[(#3645)](https://github.com/PennyLaneAI/pennylane/pull/3645)

```pycon
>>> timespan = (2, 4)
>>> f = pwc(timespan)
>>> f * qml.PauliX(0)
ParametrizedHamiltonian: terms=1
```
The `params` array will be used as bin values evenly distributed over the timespan,
and the parameter `t` will determine which of the bins is returned.

```pycon
>>> f(params=[1.2, 2.3, 3.4, 4.5], t=3.9)
DeviceArray(4.5, dtype=float32)
>>> f(params=[1.2, 2.3, 3.4, 4.5], t=6) # zero outside the range (2, 4)
DeviceArray(0., dtype=float32)
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
```

* Added `pwc_from_function` as a decorator for defining a `ParametrizedHamiltonian`.
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
This function can be used to decorate a function and create a piecewise constant
approximation of it.
[(#3645)](https://github.com/PennyLaneAI/pennylane/pull/3645)

```pycon
@pwc_from_function(t=(2, 4), num_bins=10)
def f1(p, t):
return p * t
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
```
The resulting function approximates the same of `p**2 * t` on the interval `t=(2, 4)`
in 10 bins, and returns zero outside the interval.

```pycon
#t=2 and t=2.1 are within the same bin
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
>>> f1(3, 2), f1(3, 2.1)
(DeviceArray(6., dtype=float32), DeviceArray(6., dtype=float32))
# next bin
>>> f1(3, 2.2)
DeviceArray(6.6666665, dtype=float32)
# outside the interval t=(2, 4)
>>>f1(3, 5)
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
DeviceArray(0., dtype=float32)
```

<h3>Improvements</h3>

* Most channels in are now fully differentiable in all interfaces.
Expand Down
2 changes: 1 addition & 1 deletion pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
shadow_expval,
)
from pennylane.ops import *
from pennylane.ops import adjoint, ctrl, exp, op_sum, pow, prod, s_prod, evolve
from pennylane.ops import adjoint, ctrl, exp, op_sum, pow, prod, s_prod, evolve, pulse
AlbertMitjans marked this conversation as resolved.
Show resolved Hide resolved
from pennylane.templates import broadcast, layer
from pennylane.templates.embeddings import *
from pennylane.templates.layers import *
Expand Down
18 changes: 18 additions & 0 deletions pennylane/ops/pulse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains classes and functions used in pulse programming.
"""

from .convenience_functions import pwc, pwc_from_function
149 changes: 149 additions & 0 deletions pennylane/ops/pulse/convenience_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains convenience functions for pulse programming."""

import numpy as np

has_jax = True
try:
import jax.numpy as jnp
except ImportError:
has_jax = False


def pwc(timespan):
"""Create a function that is piecewise-constant in time, based on the params for a TDHamiltonian.

Args:
timespan(Union[float, tuple(float, float)]: The timespan defining the region where the function is non-zero.
If an integer is provided, the timespan is defined as ``(0, timespan)``.

Returns:
func: a function that contains two arguments, one for the trainable parameters(array) and
one for time(int). When called, the function uses the array of parameters to create evenly sized bins
within the ``timespan``, with each bin value set by an element of the array. It then selects the value
the parameter array corresponding to the specified time.

**Example**

>>> t1, t2 = 1, 3
>>> f1 = pwc((t1, t2))

The resulting function ``f1`` has the call signature ``f1(params, t)``. If passed an array of parameters and
a time, it will assign the array as the constants in the piecewise function, and select the constant corresponding
to the specified time, based on the time interval defined by ``timespan``.

>>> params = [np.linspace(10, 20, 10)]
>>> f1(params, 2)
AlbertMitjans marked this conversation as resolved.
Show resolved Hide resolved
tensor(15.55555556, requires_grad=True)

>>> f1(params, 2.1) # same bin
tensor(15.55555556, requires_grad=True)

>>> f1(params, 2.5) # next bin
tensor(17.77777778, requires_grad=True)
"""
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax"
)

if isinstance(timespan, tuple):
t1, t2 = timespan
else:
t1 = 0
t2 = timespan

def func(params, t):
num_bins = len(params)
# include 0 as an additional option for function output
params = jnp.array(list(params) + [0])

# get idx from timestamp, then set idx=0 if idx is out of bounds for the array
idx = num_bins / (t2 - t1) * (t - t1)
idx = jnp.where((idx >= 0) & (idx <= num_bins), jnp.array(idx, dtype=int), -1)

return params[idx]

return func


def pwc_from_function(timespan, num_bins):
"""
Decorator to turn a smooth function into a piecewise constant function.

Args:
timespan(Union[float, tuple(float)]): The timespan defining the region where the function is non-zero.
If an integer is provided, the timespan is defined as ``(0, timespan)``.
num_bins(int): number of bins for time-binning the function

Returns:
a function that takes some smooth function ``f(params, t)`` and converts it to a
piecewise constant function spanning time ``t`` in `num_bins` bins.

**Example**

>>> def f0(params, t): return params[0] * t + params[1]
>>> timespan = 10
>>> num_bins = 10
>>> f1 = pwc_from_function(timespan, num_bins)(f0)
>>> f1([2, 4], 3), f0([2, 4], 3)
(DeviceArray(10.666666, dtype=float32), DeviceArray(10, dtype=int32))

>>> f1([2, 4], 3.2), f0([2, 4], 3.2)
(DeviceArray(10.666666, dtype=float32), DeviceArray(10.4, dtype=float32))

>>> f1([2, 4], 4.5), f0([2, 4], 4.5)
(DeviceArray(12.888889, dtype=float32), DeviceArray(13., dtype=float32))

# ToDo: can we include images in the docs for the version rendered for the website? Would be clearest way to illustrate

The same effect can be achieved by decorating the smooth function:

>>> @pwc_from_function(timespan, num_bins)
... def fn(params, t):
... return params[0] * t + params[1]
>>> fn([2, 4], 3)
DeviceArray(10.666666, dtype=float32)

"""
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax"
)

if isinstance(timespan, tuple):
t1, t2 = timespan
else:
t1 = 0
t2 = timespan

def inner(fn):
time_bins = np.linspace(t1, t2, num_bins)

def wrapper(params, t):
constants = jnp.array(list(fn(params, time_bins)) + [0])

idx = num_bins / (t2 - t1) * (t - t1)
# check interval is within 0 to num_bins, then cast to int, to avoid casting outcomes between -1 and 0 as 0
idx = jnp.where((idx >= 0) & (idx <= num_bins), jnp.array(idx, dtype=int), -1)

return constants[idx]

return wrapper

return inner
Loading