Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable piece-wise-constant functions #3645

Merged
merged 76 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
107b132
Initial draft of time dependent hamiltonian
lillian542 Jan 9, 2023
b0ab45d
Allow creation of TDHamiltonian by multiplication of fn and Observable
lillian542 Jan 9, 2023
ffd3d19
Import TDHamiltonian as qml.ops.TDHamiltonian
lillian542 Jan 9, 2023
1f27249
Remove top-level import due to circular imports
lillian542 Jan 9, 2023
7375525
Reorganize H_drift and H_ts
lillian542 Jan 9, 2023
f2bcafe
Add docstrings
lillian542 Jan 9, 2023
2f6aa9c
Fix addition bug for Tensor and Observable
lillian542 Jan 9, 2023
1d93837
Update docstring
lillian542 Jan 10, 2023
15bcc0c
Rename TDHamiltonian to ParametrizedHamiltonian
lillian542 Jan 10, 2023
ba483aa
Rename file parametrized_hamiltonian.py
lillian542 Jan 10, 2023
90253be
Calling H(params, t) returns Operator instead of matrix
lillian542 Jan 10, 2023
63e63e1
Remove inheritance from Observable
lillian542 Jan 10, 2023
51b620d
Change variable names and docstring comments to reflect switch from T…
lillian542 Jan 10, 2023
ab95f94
Docstring example
lillian542 Jan 10, 2023
cd94471
Move from qubit module to math_op module
lillian542 Jan 10, 2023
d20726b
Fix bug when calling ParametrizedHamiltonian if H_fixed is None
lillian542 Jan 10, 2023
daa4336
Update __add__ method
lillian542 Jan 10, 2023
36b6b29
Add tests
lillian542 Jan 10, 2023
faa7451
Update tests_passing_pylint
lillian542 Jan 10, 2023
2aa0b33
update tests for pylint
lillian542 Jan 10, 2023
87904ae
Merge branch 'master' into time_dependent_hamiltonian
AlbertMitjans Jan 11, 2023
0061dcc
Switch from isfunction to callable in Observable.__mul__
lillian542 Jan 11, 2023
e55d15e
Return 0 instead of None if _get_terms is empty
lillian542 Jan 11, 2023
5896c77
Apply docstring suggestions from code review
lillian542 Jan 11, 2023
50e1add
Clean up based on code review suggestions
lillian542 Jan 11, 2023
78e2c99
Remove assumption that ops are Observables
lillian542 Jan 11, 2023
38aca1c
Switch from pH.H_fixed to pH.H_fixed()
lillian542 Jan 11, 2023
f972765
Support addition as Operator+ParametrizedHamiltonian
lillian542 Jan 11, 2023
1f20492
Support creating ParametrizedHamiltonian via qml.ops.dot
lillian542 Jan 11, 2023
607240c
Test for qutrit ParametrizedHamiltonian
lillian542 Jan 11, 2023
7991645
Add wires argument to __call__
lillian542 Jan 11, 2023
e97381a
Merge branch 'master' into time_dependent_hamiltonian
AlbertMitjans Jan 12, 2023
9a6a74a
Incorporate code review suggestions
lillian542 Jan 12, 2023
aefdd42
Add pwc_from_array
lillian542 Jan 16, 2023
973e9ff
Add pwc_from_function
lillian542 Jan 16, 2023
4ce9e94
Merge branch 'time_dependent_hamiltonian' into pwc_functions
lillian542 Jan 16, 2023
467afed
Examples in docstrings
lillian542 Jan 16, 2023
b8625ae
Change call signature on pwc_from_array to (dt, index)
lillian542 Jan 17, 2023
ee2586b
Change call signature on pwc_from_smooth to (dt, num_bins)
lillian542 Jan 17, 2023
227af84
Deal with jax.numpy import
lillian542 Jan 17, 2023
cb647b1
Merge branch 'master' into pwc_functions
AlbertMitjans Jan 18, 2023
f00fa93
Merge branch 'master' into pwc_functions
lillian542 Jan 19, 2023
ca6bae4
Update pennylane/operation.py
lillian542 Jan 19, 2023
fc813ea
Remove unintentional edits
lillian542 Jan 20, 2023
c496006
Remove unintended changes
lillian542 Jan 20, 2023
eb02b90
Switch arg from dt to t
lillian542 Jan 20, 2023
3ff5f14
Merge branch 'master' into pwc_functions
lillian542 Jan 20, 2023
0defffb
Move from math to pulse module
lillian542 Jan 20, 2023
803d1b6
Return 0 outside of time interval
lillian542 Jan 20, 2023
39d4585
Fix bug where first bin is twice as large as subsequent bins and exte…
lillian542 Jan 20, 2023
ac80dc8
Add tests
lillian542 Jan 20, 2023
1376235
Merge branch 'master' into pwc_functions
AlbertMitjans Jan 23, 2023
09a8665
Clarifying change in docstring
lillian542 Jan 23, 2023
f25309e
Merge branch 'pwc_functions' of github.com:PennyLaneAI/pennylane into…
lillian542 Jan 23, 2023
1673aaf
Reorganize tests
lillian542 Jan 23, 2023
240a551
Update tests
lillian542 Jan 23, 2023
9c7a63f
Merge branch 'master' into pwc_functions
lillian542 Jan 23, 2023
50e6026
changelog
lillian542 Jan 24, 2023
c0217e9
Update tests
lillian542 Jan 24, 2023
c90850b
Rename t in pwc to timespan to differentiate from t in (params, t) fr…
lillian542 Jan 24, 2023
aaa813c
Rename t in pwc to timespan to differentiate from t in (params, t) fr…
lillian542 Jan 24, 2023
8a498ec
Reinstate import tests
lillian542 Jan 24, 2023
336b5be
Update docstring
lillian542 Jan 24, 2023
85b0989
Mark as xfail for now
lillian542 Jan 24, 2023
9225e80
Update tests
lillian542 Jan 25, 2023
1baeb7e
Merge branch 'master' into pwc_functions
lillian542 Jan 25, 2023
64c4dfc
test codecov
lillian542 Jan 25, 2023
61183ea
clean up
lillian542 Jan 25, 2023
3bdb11b
Apply suggestions from code review
lillian542 Jan 25, 2023
52277ab
Merge branch 'master' into pwc_functions
lillian542 Jan 25, 2023
e5c8144
Apply suggestions from code review
lillian542 Jan 25, 2023
37db408
Move files pt1
lillian542 Jan 25, 2023
6d6fe7a
Update example
lillian542 Jan 25, 2023
89f67ad
Fix import statement
lillian542 Jan 25, 2023
6a1db1f
Use jnp.concatenate instead of appending to list
lillian542 Jan 25, 2023
efaf832
Merge branch 'master' into pwc_functions
AlbertMitjans Jan 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,55 @@
4: ─────────────────╰RY(0.81)─╰SWAP─────────────────╰RY(0.06)─╰SWAP─────────────────┤ State
```


* Added `pwc` as a convenience function for defining a `ParametrizedHamiltonian`.
This function can be used to create a callable coefficient by setting
the timespan over which the function should be non-zero. The resulting callable
can be passed an array of parameters and a time.
[(#3645)](https://github.com/PennyLaneAI/pennylane/pull/3645)

```pycon
>>> timespan = (2, 4)
>>> f = pwc(timespan)
>>> f * qml.PauliX(0)
ParametrizedHamiltonian: terms=1
```
The `params` array will be used as bin values evenly distributed over the timespan,
and the parameter `t` will determine which of the bins is returned.

```pycon
>>> f(params=[1.2, 2.3, 3.4, 4.5], t=3.9)
DeviceArray(4.5, dtype=float32)
>>> f(params=[1.2, 2.3, 3.4, 4.5], t=6) # zero outside the range (2, 4)
DeviceArray(0., dtype=float32)
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
```

* Added `pwc_from_function` as a decorator for defining a `ParametrizedHamiltonian`.
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
This function can be used to decorate a function and create a piecewise constant
approximation of it.
[(#3645)](https://github.com/PennyLaneAI/pennylane/pull/3645)

```pycon
>>> @pwc_from_function(t=(2, 4), num_bins=10)
... def f1(p, t):
... return p * t
```
The resulting function approximates the same of `p**2 * t` on the interval `t=(2, 4)`
in 10 bins, and returns zero outside the interval.

```pycon
# t=2 and t=2.1 are within the same bin
>>> f1(3, 2), f1(3, 2.1)
(DeviceArray(6., dtype=float32), DeviceArray(6., dtype=float32))
# next bin
>>> f1(3, 2.2)
DeviceArray(6.6666665, dtype=float32)
# outside the interval t=(2, 4)
>>> f1(3, 5)
DeviceArray(0., dtype=float32)
```


<h3>Improvements</h3>

* `qml.purity` is added as a measurement process for purity
Expand Down
2 changes: 1 addition & 1 deletion pennylane/pulse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
This module contains classes and functions used in pulse programming.
"""

from .convenience_functions import constant, rect
from .convenience_functions import constant, rect, pwc, pwc_from_function
from .parametrized_evolution import ParametrizedEvolution
from .parametrized_hamiltonian import ParametrizedHamiltonian
130 changes: 130 additions & 0 deletions pennylane/pulse/convenience_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""This file contains convenience functions for pulse programming."""
from typing import Callable, List, Tuple, Union
import numpy as np

has_jax = True
try:
Expand Down Expand Up @@ -115,3 +116,132 @@ def f(p, t):
return _f(p, t)

return f


def pwc(timespan):
"""Creates a function that is piecewise-constant in time.

Args:
timespan(Union[float, tuple(float, float)]: The timespan defining the region where the function is non-zero.
If an integer is provided, the timespan is defined as ``(0, timespan)``.

Returns:
func: a function that takes two arguments, an array of trainable parameters and a `float` defining the
time at which the function is evaluated. When called, the function uses the array of parameters to
create evenly sized bins within the ``timespan``, with each bin value set by an element of the array.
It then selects the value of the parameter array corresponding to the specified time, based on the
assigned binning.

**Example**

>>> timespan = (1, 3)
>>> f1 = pwc(timespan)

The resulting function ``f1`` has the call signature ``f1(params, t)``. If passed an array of parameters and
a time, it will assign the array as the constants in the piecewise function, and select the constant corresponding
to the specified time, based on the time interval defined by ``timespan``.

>>> params = [10, 11, 12, 13, 14]
>>> f1(params, 2)
Array(12, dtype=int32)

>>> f1(params, 2.1) # same bin
Array(12, dtype=int32)

>>> f1(params, 2.5) # next bin
Array(13, dtype=int32)
"""
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax"
)

if isinstance(timespan, tuple):
t1, t2 = timespan
else:
t1 = 0
t2 = timespan

def func(params, t):
num_bins = len(params)
params = jnp.concatenate([jnp.array(params), jnp.zeros(1)])
# get idx from timestamp, then set idx=0 if idx is out of bounds for the array
idx = num_bins / (t2 - t1) * (t - t1)
idx = jnp.where((idx >= 0) & (idx <= num_bins), jnp.array(idx, dtype=int), -1)

return params[idx]

return func


def pwc_from_function(timespan, num_bins):
"""
Decorator to turn a smooth function into a piecewise constant function.

Args:
timespan(Union[float, tuple(float)]): The timespan defining the region where the function is non-zero.
If an integer is provided, the timespan is defined as ``(0, timespan)``.
num_bins(int): number of bins for time-binning the function

Returns:
a function that takes some smooth function ``f(params, t)`` and converts it to a
piecewise constant function spanning time ``t`` in `num_bins` bins.

**Example**

.. code-block:: python3

def smooth_function(params, t):
return params[0] * t + params[1]

timespan = 10
num_bins = 10

binned_function = pwc_from_function(timespan, num_bins)(f0)

>>> binned_function([2, 4], 3), smooth_function([2, 4], 3) # t = 3
(DeviceArray(10.666666, dtype=float32), DeviceArray(10, dtype=int32))

>>> binned_function([2, 4], 3.2), smooth_function([2, 4], 3.2) # t = 3.2
(DeviceArray(10.666666, dtype=float32), DeviceArray(10.4, dtype=float32))

>>> binned_function([2, 4], 4.5), smooth_function([2, 4], 4.5) # t = 4.5
(DeviceArray(12.888889, dtype=float32), DeviceArray(13., dtype=float32))

The same effect can be achieved by decorating the smooth function:

>>> @pwc_from_function(timespan, num_bins)
... def fn(params, t):
... return params[0] * t + params[1]
>>> fn([2, 4], 3)
DeviceArray(10.666666, dtype=float32)

"""
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax"
)

if isinstance(timespan, tuple):
t1, t2 = timespan
else:
t1 = 0
t2 = timespan

def inner(fn):
time_bins = np.linspace(t1, t2, num_bins)

def wrapper(params, t):
constants = jnp.array(list(fn(params, time_bins)) + [0])

idx = num_bins / (t2 - t1) * (t - t1)
# check interval is within 0 to num_bins, then cast to int, to avoid casting outcomes between -1 and 0 as 0
idx = jnp.where((idx >= 0) & (idx <= num_bins), jnp.array(idx, dtype=int), -1)

return constants[idx]

return wrapper

return inner
Loading