-
Notifications
You must be signed in to change notification settings - Fork 604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable piece-wise-constant functions #3645
Conversation
…ime-Dependent to Parametrized
Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com> Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
…om the resulting callable
…om the resulting callable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy to approve once these comments have been addressed!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, great work @lillian542 !!! I have one nitpicky performance concern:
I noticed a regression between the self-written wrapper and qml.pulse.pwc
, I used the following benchmark (comparing changing pwc
to pwc2
in the envelope
). While the execution is the same, the compilation time increases significantly (up to factor 2). I think I already found the cause of the problem. Can you change it and re-run the benchmark? If jit-time and exec time are both more or less the same I am happy to approve.
I am not entirely sure the behavior of outputting 0
for times out of bound is good, because this can lead to the function "working" by outputting zero instead of raising an error when it is incorrectly used. Additionally the jnp.where
business for the index might be another source for the regression. If a user wants pwc
to only work in a certain window they should use rect
. Happy to hear a third opinion @trbromley . Not a hill I would die on, but I'd like to open the discussion one more time.
import pennylane as qml
import pennylane.numpy as np
import jax.numpy as jnp
import jax
from jax.experimental.ode import odeint as jaxodeint
from functools import partial
import matplotlib.pyplot as plt
from datetime import datetime
data = qml.data.load("qchem", molname="H2O", basis="STO-3G", bondlength=1.9)[0]
H_obj = data.hamiltonian
n_wires = len(H_obj.wires)
omega = jnp.ones(n_wires)
g = jnp.ones(n_wires)
H_D = qml.ops.dot(omega, [qml.PauliZ(i) @ qml.PauliZ(i) for i in range(n_wires)])
H_D += qml.ops.dot(g, [qml.PauliY(i) @ qml.PauliY((i+1)%n_wires) + qml.PauliY((i+1)%n_wires) @ qml.PauliY(i) for i in range(n_wires)])
# TODO use official convenience functions once merged
def pwc2(duration):
def wrapped(params, t):
N = len(params)
idx = jnp.array(N/(duration) * t, dtype=int) # corresponding sample
return params[idx]
return wrapped
from pennylane.pulse import pwc
def envelope(duration):
# assuming p = (len(t_bins) + 1) for the frequency nu
def wrapped(p, t):
return 0.02*pwc(duration)(p[:-1], t) * jnp.cos(p[-1]*t)
return wrapped
duration = 20.
fs = [envelope(duration) for i in range(n_wires)]
ops = [qml.PauliX(i) for i in range(n_wires)]
H_C = qml.ops.dot(fs, ops)
H_pulse = H_D + H_C
t_bins = 40 # number of time bins
theta = jnp.array([jnp.ones(t_bins + 1, dtype=float) for _ in range(n_wires)])
dev = qml.device("default.qubit", wires=range(n_wires))
ts = jnp.linspace(0., duration, t_bins)
@jax.jit
@qml.qnode(dev, interface="jax")
def qnode(theta, t=ts):
qml.BasisState(data.hf_state, wires=H_obj.wires)
qml.evolve(H_pulse)(params=theta, t=t)
return qml.expval(H_obj)
value_and_grad = jax.jit(jax.value_and_grad(qnode, argnums=0))
time0 = datetime.now()
val, grad = jax.block_until_ready(value_and_grad(theta))
time1 = datetime.now()
print(f"grad and val compilation time: {time1 - time0}")
# run in notebook
%timeit jax.block_until_ready(value_and_grad(theta))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All concerns resolved, thanks for the quick fix @lillian542 🚅
(For the record: Lillian made a good point about negative indexing, and I also realized that combining it with rect
actually wouldnt work in the way I thought lambda p, t: pwc(duration)(p, t) * rect(duration2)(p, t)
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good job!!
* Initial draft of time dependent hamiltonian * Allow creation of TDHamiltonian by multiplication of fn and Observable * Import TDHamiltonian as qml.ops.TDHamiltonian * Remove top-level import due to circular imports * Reorganize H_drift and H_ts * Add docstrings * Fix addition bug for Tensor and Observable * Update docstring * Rename TDHamiltonian to ParametrizedHamiltonian * Rename file parametrized_hamiltonian.py * Calling H(params, t) returns Operator instead of matrix * Remove inheritance from Observable * Change variable names and docstring comments to reflect switch from Time-Dependent to Parametrized * Docstring example * Move from qubit module to math_op module * Fix bug when calling ParametrizedHamiltonian if H_fixed is None * Update __add__ method * Add tests * Update tests_passing_pylint * update tests for pylint * Switch from isfunction to callable in Observable.__mul__ * Return 0 instead of None if _get_terms is empty * Apply docstring suggestions from code review Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com> Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com> * Clean up based on code review suggestions * Remove assumption that ops are Observables * Switch from pH.H_fixed to pH.H_fixed() * Support addition as Operator+ParametrizedHamiltonian * Support creating ParametrizedHamiltonian via qml.ops.dot * Test for qutrit ParametrizedHamiltonian * Add wires argument to __call__ * Incorporate code review suggestions * Add pwc_from_array * Add pwc_from_function * Examples in docstrings * Change call signature on pwc_from_array to (dt, index) * Change call signature on pwc_from_smooth to (dt, num_bins) * Deal with jax.numpy import * Update pennylane/operation.py * Remove unintentional edits * Remove unintended changes * Switch arg from dt to t * Move from math to pulse module * Return 0 outside of time interval * Fix bug where first bin is twice as large as subsequent bins and extends to before t1 * Add tests * Clarifying change in docstring * Reorganize tests * Update tests * changelog * Update tests * Rename t in pwc to timespan to differentiate from t in (params, t) from the resulting callable * Rename t in pwc to timespan to differentiate from t in (params, t) from the resulting callable * Reinstate import tests * Update docstring * Mark as xfail for now * Update tests * test codecov * clean up * Apply suggestions from code review * Apply suggestions from code review * Move files pt1 * Update example * Fix import statement * Use jnp.concatenate instead of appending to list Co-authored-by: Albert Mitjans <a.mitjanscoma@gmail.com> Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Context:
We want to add a feature that users can use to do something like
I.e. provide a pwc that is a callable and gets thus recognized by the dunder method of operators to create
ParametrizedHamiltonian
.One of the problems is that it needs information about the duration and/or how many samples there are such that it can assign based on the float parameter
t
in the ODE integrator the corresponding function value.Description of the Change:
Two functions added in
pennylane.math.utils
:pwc
: takest
and returns a pwc function with call signaturef(params, t)
that returns a value fromparams[index]
based ont
and the intervaldt
.pwc_from_function
: takest
andnum_bins
, and decorates a smooth function to return a piecewise constant function with call signaturef(params, t)
.