Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable piece-wise-constant functions #3645

Merged
merged 76 commits into from
Jan 26, 2023
Merged

Enable piece-wise-constant functions #3645

merged 76 commits into from
Jan 26, 2023

Conversation

lillian542
Copy link
Contributor

@lillian542 lillian542 commented Jan 17, 2023

Context:
We want to add a feature that users can use to do something like

H = pwc * H1 + pwc * H2
..
qml.Evolve(H)(params, ..)

I.e. provide a pwc that is a callable and gets thus recognized by the dunder method of operators to create ParametrizedHamiltonian.

One of the problems is that it needs information about the duration and/or how many samples there are such that it can assign based on the float parameter t in the ODE integrator the corresponding function value.

Description of the Change:
Two functions added in pennylane.math.utils:

  • pwc: takes t and returns a pwc function with call signature f(params, t) that returns a value from params[index] based on t and the interval dt.
  • pwc_from_function : takes t and num_bins, and decorates a smooth function to return a piecewise constant function with call signature f(params, t).

lillian542 and others added 30 commits January 9, 2023 16:24
Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Copy link
Contributor

@AlbertMitjans AlbertMitjans left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to approve once these comments have been addressed!

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
pennylane/__init__.py Outdated Show resolved Hide resolved
pennylane/ops/pulse/convenience_functions.py Outdated Show resolved Hide resolved
Copy link
Contributor

@Qottmann Qottmann left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, great work @lillian542 !!! I have one nitpicky performance concern:

I noticed a regression between the self-written wrapper and qml.pulse.pwc, I used the following benchmark (comparing changing pwc to pwc2 in the envelope). While the execution is the same, the compilation time increases significantly (up to factor 2). I think I already found the cause of the problem. Can you change it and re-run the benchmark? If jit-time and exec time are both more or less the same I am happy to approve.

I am not entirely sure the behavior of outputting 0 for times out of bound is good, because this can lead to the function "working" by outputting zero instead of raising an error when it is incorrectly used. Additionally the jnp.where business for the index might be another source for the regression. If a user wants pwc to only work in a certain window they should use rect. Happy to hear a third opinion @trbromley . Not a hill I would die on, but I'd like to open the discussion one more time.

import pennylane as qml
import pennylane.numpy as np
import jax.numpy as jnp
import jax

from jax.experimental.ode import odeint as jaxodeint
from functools import partial

import matplotlib.pyplot as plt

from datetime import datetime
data = qml.data.load("qchem", molname="H2O", basis="STO-3G", bondlength=1.9)[0]
H_obj = data.hamiltonian
n_wires = len(H_obj.wires)
omega = jnp.ones(n_wires)
g = jnp.ones(n_wires)

H_D = qml.ops.dot(omega, [qml.PauliZ(i) @ qml.PauliZ(i) for i in range(n_wires)])
H_D += qml.ops.dot(g, [qml.PauliY(i) @ qml.PauliY((i+1)%n_wires) + qml.PauliY((i+1)%n_wires) @ qml.PauliY(i) for i in range(n_wires)])
# TODO use official convenience functions once merged
def pwc2(duration):
    def wrapped(params, t):
        N = len(params)
        idx = jnp.array(N/(duration) * t, dtype=int) # corresponding sample
        return params[idx]

    return wrapped

from pennylane.pulse import pwc

def envelope(duration):
    # assuming p = (len(t_bins) + 1) for the frequency nu
    def wrapped(p, t):
        return 0.02*pwc(duration)(p[:-1], t) * jnp.cos(p[-1]*t)
    return wrapped

duration = 20.

fs = [envelope(duration) for i in range(n_wires)]
ops = [qml.PauliX(i) for i in range(n_wires)]

H_C = qml.ops.dot(fs, ops)

H_pulse = H_D + H_C
t_bins = 40 # number of time bins
theta = jnp.array([jnp.ones(t_bins + 1, dtype=float) for _ in range(n_wires)])

dev = qml.device("default.qubit", wires=range(n_wires))

ts = jnp.linspace(0., duration, t_bins)

@jax.jit
@qml.qnode(dev, interface="jax")
def qnode(theta, t=ts):
    qml.BasisState(data.hf_state, wires=H_obj.wires)
    qml.evolve(H_pulse)(params=theta, t=t)
    return qml.expval(H_obj)
value_and_grad = jax.jit(jax.value_and_grad(qnode, argnums=0))
time0 = datetime.now()
val, grad = jax.block_until_ready(value_and_grad(theta))
time1 = datetime.now()
print(f"grad and val compilation time: {time1 - time0}")
# run in notebook
%timeit jax.block_until_ready(value_and_grad(theta))

pennylane/pulse/convenience_functions.py Outdated Show resolved Hide resolved
Copy link
Contributor

@Qottmann Qottmann left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All concerns resolved, thanks for the quick fix @lillian542 🚅

(For the record: Lillian made a good point about negative indexing, and I also realized that combining it with rect actually wouldnt work in the way I thought lambda p, t: pwc(duration)(p, t) * rect(duration2)(p, t))

Copy link
Contributor

@AlbertMitjans AlbertMitjans left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job!!

@lillian542 lillian542 merged commit 07b3e6e into master Jan 26, 2023
@lillian542 lillian542 deleted the pwc_functions branch January 26, 2023 15:08
mudit2812 pushed a commit that referenced this pull request Apr 13, 2023
* Initial draft of time dependent hamiltonian

* Allow creation of TDHamiltonian by multiplication of fn and Observable

* Import TDHamiltonian as qml.ops.TDHamiltonian

* Remove top-level import due to circular imports

* Reorganize H_drift and H_ts

* Add docstrings

* Fix addition bug for Tensor and Observable

* Update docstring

* Rename TDHamiltonian to ParametrizedHamiltonian

* Rename file parametrized_hamiltonian.py

* Calling H(params, t) returns Operator instead of matrix

* Remove inheritance from Observable

* Change variable names and docstring comments to reflect switch from Time-Dependent to Parametrized

* Docstring example

* Move from qubit module to math_op module

* Fix bug when calling ParametrizedHamiltonian if H_fixed is None

* Update __add__ method

* Add tests

* Update tests_passing_pylint

* update tests for pylint

* Switch from isfunction to callable in Observable.__mul__

* Return 0 instead of None if _get_terms is empty

* Apply docstring suggestions from code review

Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>

* Clean up based on code review suggestions

* Remove assumption that ops are Observables

* Switch from pH.H_fixed to pH.H_fixed()

* Support addition as Operator+ParametrizedHamiltonian

* Support creating ParametrizedHamiltonian via qml.ops.dot

* Test for qutrit ParametrizedHamiltonian

* Add wires argument to __call__

* Incorporate code review suggestions

* Add pwc_from_array

* Add pwc_from_function

* Examples in docstrings

* Change call signature on pwc_from_array to (dt, index)

* Change call signature on pwc_from_smooth to (dt, num_bins)

* Deal with jax.numpy import

* Update pennylane/operation.py

* Remove unintentional edits

* Remove unintended changes

* Switch arg from dt to t

* Move from math to pulse module

* Return 0 outside of time interval

* Fix bug where first bin is twice as large as subsequent bins and extends to before t1

* Add tests

* Clarifying change in docstring

* Reorganize tests

* Update tests

* changelog

* Update tests

* Rename t in pwc to timespan to differentiate from t in (params, t) from the resulting callable

* Rename t in pwc to timespan to differentiate from t in (params, t) from the resulting callable

* Reinstate import tests

* Update docstring

* Mark as xfail for now

* Update tests

* test codecov

* clean up

* Apply suggestions from code review

* Apply suggestions from code review

* Move files pt1

* Update example

* Fix import statement

* Use jnp.concatenate instead of appending to list

Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants