Skip to content

Commit

Permalink
Enable piece-wise-constant functions (#3645)
Browse files Browse the repository at this point in the history
* Initial draft of time dependent hamiltonian

* Allow creation of TDHamiltonian by multiplication of fn and Observable

* Import TDHamiltonian as qml.ops.TDHamiltonian

* Remove top-level import due to circular imports

* Reorganize H_drift and H_ts

* Add docstrings

* Fix addition bug for Tensor and Observable

* Update docstring

* Rename TDHamiltonian to ParametrizedHamiltonian

* Rename file parametrized_hamiltonian.py

* Calling H(params, t) returns Operator instead of matrix

* Remove inheritance from Observable

* Change variable names and docstring comments to reflect switch from Time-Dependent to Parametrized

* Docstring example

* Move from qubit module to math_op module

* Fix bug when calling ParametrizedHamiltonian if H_fixed is None

* Update __add__ method

* Add tests

* Update tests_passing_pylint

* update tests for pylint

* Switch from isfunction to callable in Observable.__mul__

* Return 0 instead of None if _get_terms is empty

* Apply docstring suggestions from code review

Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>

* Clean up based on code review suggestions

* Remove assumption that ops are Observables

* Switch from pH.H_fixed to pH.H_fixed()

* Support addition as Operator+ParametrizedHamiltonian

* Support creating ParametrizedHamiltonian via qml.ops.dot

* Test for qutrit ParametrizedHamiltonian

* Add wires argument to __call__

* Incorporate code review suggestions

* Add pwc_from_array

* Add pwc_from_function

* Examples in docstrings

* Change call signature on pwc_from_array to (dt, index)

* Change call signature on pwc_from_smooth to (dt, num_bins)

* Deal with jax.numpy import

* Update pennylane/operation.py

* Remove unintentional edits

* Remove unintended changes

* Switch arg from dt to t

* Move from math to pulse module

* Return 0 outside of time interval

* Fix bug where first bin is twice as large as subsequent bins and extends to before t1

* Add tests

* Clarifying change in docstring

* Reorganize tests

* Update tests

* changelog

* Update tests

* Rename t in pwc to timespan to differentiate from t in (params, t) from the resulting callable

* Rename t in pwc to timespan to differentiate from t in (params, t) from the resulting callable

* Reinstate import tests

* Update docstring

* Mark as xfail for now

* Update tests

* test codecov

* clean up

* Apply suggestions from code review

* Apply suggestions from code review

* Move files pt1

* Update example

* Fix import statement

* Use jnp.concatenate instead of appending to list

Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 26, 2023
1 parent d5fc880 commit 07b3e6e
Show file tree
Hide file tree
Showing 6 changed files with 649 additions and 6 deletions.
49 changes: 49 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,55 @@
4: ─────────────────╰RY(0.81)─╰SWAP─────────────────╰RY(0.06)─╰SWAP─────────────────┤ State
```


* Added `pwc` as a convenience function for defining a `ParametrizedHamiltonian`.
This function can be used to create a callable coefficient by setting
the timespan over which the function should be non-zero. The resulting callable
can be passed an array of parameters and a time.
[(#3645)](https://github.com/PennyLaneAI/pennylane/pull/3645)

```pycon
>>> timespan = (2, 4)
>>> f = pwc(timespan)
>>> f * qml.PauliX(0)
ParametrizedHamiltonian: terms=1
```
The `params` array will be used as bin values evenly distributed over the timespan,
and the parameter `t` will determine which of the bins is returned.

```pycon
>>> f(params=[1.2, 2.3, 3.4, 4.5], t=3.9)
DeviceArray(4.5, dtype=float32)
>>> f(params=[1.2, 2.3, 3.4, 4.5], t=6) # zero outside the range (2, 4)
DeviceArray(0., dtype=float32)
```

* Added `pwc_from_function` as a decorator for defining a `ParametrizedHamiltonian`.
This function can be used to decorate a function and create a piecewise constant
approximation of it.
[(#3645)](https://github.com/PennyLaneAI/pennylane/pull/3645)

```pycon
>>> @pwc_from_function(t=(2, 4), num_bins=10)
... def f1(p, t):
... return p * t
```
The resulting function approximates the same of `p**2 * t` on the interval `t=(2, 4)`
in 10 bins, and returns zero outside the interval.

```pycon
# t=2 and t=2.1 are within the same bin
>>> f1(3, 2), f1(3, 2.1)
(DeviceArray(6., dtype=float32), DeviceArray(6., dtype=float32))
# next bin
>>> f1(3, 2.2)
DeviceArray(6.6666665, dtype=float32)
# outside the interval t=(2, 4)
>>> f1(3, 5)
DeviceArray(0., dtype=float32)
```


<h3>Improvements</h3>

* `qml.purity` is added as a measurement process for purity
Expand Down
2 changes: 1 addition & 1 deletion pennylane/pulse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
This module contains classes and functions used in pulse programming.
"""

from .convenience_functions import constant, rect
from .convenience_functions import constant, rect, pwc, pwc_from_function
from .parametrized_evolution import ParametrizedEvolution
from .parametrized_hamiltonian import ParametrizedHamiltonian
130 changes: 130 additions & 0 deletions pennylane/pulse/convenience_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""This file contains convenience functions for pulse programming."""
from typing import Callable, List, Tuple, Union
import numpy as np

has_jax = True
try:
Expand Down Expand Up @@ -115,3 +116,132 @@ def f(p, t):
return _f(p, t)

return f


def pwc(timespan):
"""Creates a function that is piecewise-constant in time.
Args:
timespan(Union[float, tuple(float, float)]: The timespan defining the region where the function is non-zero.
If an integer is provided, the timespan is defined as ``(0, timespan)``.
Returns:
func: a function that takes two arguments, an array of trainable parameters and a `float` defining the
time at which the function is evaluated. When called, the function uses the array of parameters to
create evenly sized bins within the ``timespan``, with each bin value set by an element of the array.
It then selects the value of the parameter array corresponding to the specified time, based on the
assigned binning.
**Example**
>>> timespan = (1, 3)
>>> f1 = pwc(timespan)
The resulting function ``f1`` has the call signature ``f1(params, t)``. If passed an array of parameters and
a time, it will assign the array as the constants in the piecewise function, and select the constant corresponding
to the specified time, based on the time interval defined by ``timespan``.
>>> params = [10, 11, 12, 13, 14]
>>> f1(params, 2)
Array(12, dtype=int32)
>>> f1(params, 2.1) # same bin
Array(12, dtype=int32)
>>> f1(params, 2.5) # next bin
Array(13, dtype=int32)
"""
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax"
)

if isinstance(timespan, tuple):
t1, t2 = timespan
else:
t1 = 0
t2 = timespan

def func(params, t):
num_bins = len(params)
params = jnp.concatenate([jnp.array(params), jnp.zeros(1)])
# get idx from timestamp, then set idx=0 if idx is out of bounds for the array
idx = num_bins / (t2 - t1) * (t - t1)
idx = jnp.where((idx >= 0) & (idx <= num_bins), jnp.array(idx, dtype=int), -1)

return params[idx]

return func


def pwc_from_function(timespan, num_bins):
"""
Decorator to turn a smooth function into a piecewise constant function.
Args:
timespan(Union[float, tuple(float)]): The timespan defining the region where the function is non-zero.
If an integer is provided, the timespan is defined as ``(0, timespan)``.
num_bins(int): number of bins for time-binning the function
Returns:
a function that takes some smooth function ``f(params, t)`` and converts it to a
piecewise constant function spanning time ``t`` in `num_bins` bins.
**Example**
.. code-block:: python3
def smooth_function(params, t):
return params[0] * t + params[1]
timespan = 10
num_bins = 10
binned_function = pwc_from_function(timespan, num_bins)(f0)
>>> binned_function([2, 4], 3), smooth_function([2, 4], 3) # t = 3
(DeviceArray(10.666666, dtype=float32), DeviceArray(10, dtype=int32))
>>> binned_function([2, 4], 3.2), smooth_function([2, 4], 3.2) # t = 3.2
(DeviceArray(10.666666, dtype=float32), DeviceArray(10.4, dtype=float32))
>>> binned_function([2, 4], 4.5), smooth_function([2, 4], 4.5) # t = 4.5
(DeviceArray(12.888889, dtype=float32), DeviceArray(13., dtype=float32))
The same effect can be achieved by decorating the smooth function:
>>> @pwc_from_function(timespan, num_bins)
... def fn(params, t):
... return params[0] * t + params[1]
>>> fn([2, 4], 3)
DeviceArray(10.666666, dtype=float32)
"""
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax"
)

if isinstance(timespan, tuple):
t1, t2 = timespan
else:
t1 = 0
t2 = timespan

def inner(fn):
time_bins = np.linspace(t1, t2, num_bins)

def wrapper(params, t):
constants = jnp.array(list(fn(params, time_bins)) + [0])

idx = num_bins / (t2 - t1) * (t - t1)
# check interval is within 0 to num_bins, then cast to int, to avoid casting outcomes between -1 and 0 as 0
idx = jnp.where((idx >= 0) & (idx <= num_bins), jnp.array(idx, dtype=int), -1)

return constants[idx]

return wrapper

return inner
Loading

0 comments on commit 07b3e6e

Please sign in to comment.