Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix and enable pulse gradient with broadcasting on new device #4620

Merged
merged 4 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
accessible by the short name `default.qubit.legacy`, or directly via `qml.devices.DefaultQubitLegacy`.
[(#4594)](https://github.com/PennyLaneAI/pennylane/pull/4594)
[(#4436)](https://github.com/PennyLaneAI/pennylane/pull/4436)
[(#4620)](https://github.com/PennyLaneAI/pennylane/pull/4620)

<h3>Improvements 🛠</h3>

Expand Down Expand Up @@ -94,9 +95,10 @@
`DefaultQubitJax` in the old API.
[(#4596)](https://github.com/PennyLaneAI/pennylane/pull/4596)

* DefaultQubit2 dispatches to a faster implementation for applying `ParameterizedEvolution` to a state
* DefaultQubit2 dispatches to a faster implementation for applying `ParametrizedEvolution` to a state
when it is more efficient to evolve the state than the operation matrix.
[(#4598)](https://github.com/PennyLaneAI/pennylane/pull/4598)
[(#4620)](https://github.com/PennyLaneAI/pennylane/pull/4620)

* `ShotAdaptiveOptimizer` has been updated to pass shots to QNode executions instead of overriding
device shots before execution. This makes it compatible with the new device API.
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ def apply_parametrized_evolution(
):
"""Apply ParametrizedEvolution by evolving the state rather than the operator matrix
if we are operating on more than half of the subsystem"""
if is_state_batched:
if is_state_batched and op.batch_size is None:
raise RuntimeError(
"ParameterizedEvolution does not support batching, but received a batched state"
"ParameterizedEvolution does not support standard broadcasting, but received a batched state"
)

# shape(state) is static (not a tracer), we can use an if statement
Expand Down
3 changes: 2 additions & 1 deletion pennylane/gradients/pulse_generator_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def _insert_op(tape, ops, op_idx):
ops_pre = tape.operations[:op_idx]
ops_post = tape.operations[op_idx:]
return [
qml.tape.QuantumScript(ops_pre + [insert] + ops_post, tape.measurements) for insert in ops
qml.tape.QuantumScript(ops_pre + [insert] + ops_post, tape.measurements, shots=tape.shots)
for insert in ops
]


Expand Down
2 changes: 1 addition & 1 deletion tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def test_batched_state_raises_an_error(self):
]
)

with pytest.raises(RuntimeError, match="does not support batching"):
with pytest.raises(RuntimeError, match="does not support standard broadcasting"):
_ = apply_operation(op, initial_state, is_state_batched=True)


Expand Down
17 changes: 9 additions & 8 deletions tests/gradients/core/test_pulse_generator_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,9 +965,8 @@ def test_all_zero_diff_methods_multiple_returns_tape(self):
assert np.allclose(res_pulse_gen[1][2], 0)


# TODO: add default.qubit once it supports PRNG key
@pytest.mark.jax
@pytest.mark.parametrize("dev_name", ["default.qubit.jax"])
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.jax"])
class TestPulseGeneratorTape:
"""Test that differentiating tapes with ``pulse_generator`` works."""

Expand All @@ -979,7 +978,8 @@ def test_single_pulse_single_term(self, dev_name, shots, tol):
import jax.numpy as jnp

prng_key = jax.random.PRNGKey(8251)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=prng_key)
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: prng_key})

H = jnp.polyval * X(0)
x = jnp.array([0.4, 0.2, 0.1])
Expand Down Expand Up @@ -1015,7 +1015,8 @@ def test_single_pulse_multi_term(self, dev_name, shots, tol):

prng_key = jax.random.PRNGKey(8251)
dev = qml.device(dev_name, wires=1, shots=None)
dev_shots = qml.device(dev_name, wires=1, shots=shots, prng_key=prng_key)
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev_shots = qml.device(dev_name, wires=1, shots=shots, **{key: prng_key})

H = 0.1 * Z(0) + jnp.polyval * X(0) + qml.pulse.constant * Y(0)
x = jnp.array([0.4, 0.2, 0.1])
Expand Down Expand Up @@ -1095,7 +1096,8 @@ def test_multi_pulse(self, dev_name, shots, tol):

prng_key = jax.random.PRNGKey(8251)
dev = qml.device(dev_name, wires=1, shots=None)
dev_shots = qml.device(dev_name, wires=1, shots=shots, prng_key=prng_key)
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev_shots = qml.device(dev_name, wires=1, shots=shots, **{key: prng_key})

H0 = 0.1 * Z(0) + jnp.polyval * X(0)
H1 = 0.2 * Y(0) + qml.pulse.constant * Y(0) + jnp.polyval * Z(0)
Expand Down Expand Up @@ -1447,9 +1449,8 @@ def circuit(param1, param2):
assert tracker.totals["executions"] == 1 + 2 # one forward pass, two shifted tapes


# TODO: port ParametrizedEvolution to new default.qubit
@pytest.mark.jax
@pytest.mark.parametrize("dev_name", ["default.qubit.jax"])
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.jax"])
class TestPulseGeneratorDiff:
"""Test that pulse_generator is differentiable, i.e. that computing
the derivative with pulse_generator is differentiable a second time,
Expand All @@ -1473,7 +1474,7 @@ def fun(params):
tape = qml.tape.QuantumScript([op], [qml.expval(Z(0))])
tape.trainable_params = [0]
_tapes, fn = pulse_generator(tape)
return fn(qml.execute(_tapes, dev, None))
return fn(qml.execute(_tapes, dev, "backprop"))

params = [jnp.array(0.4)]
p = params[0] * T
Expand Down
35 changes: 22 additions & 13 deletions tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,9 +1352,8 @@ def circuit(params):
assert tracker.totals["executions"] == 4 # two shifted tapes, two splitting times


# TODO: add default.qubit once it supports PRNG key
@pytest.mark.jax
@pytest.mark.parametrize("dev_name", ["default.qubit.jax"])
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.jax"])
class TestStochPulseGradIntegration:
"""Test that stoch_pulse_grad integrates correctly with QNodes and ML interfaces."""

Expand All @@ -1367,7 +1366,8 @@ def test_simple_qnode_expval(self, dev_name, num_split_times, shots, tol):
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand All @@ -1394,7 +1394,8 @@ def test_simple_qnode_expval_two_evolves(self, dev_name, num_split_times, shots,
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T_x = 0.1
T_y = 0.2
ham_x = qml.pulse.constant * qml.PauliX(0)
Expand Down Expand Up @@ -1425,7 +1426,8 @@ def test_simple_qnode_probs(self, dev_name, num_split_times, shots, tol):
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand All @@ -1452,7 +1454,8 @@ def test_simple_qnode_probs_expval(self, dev_name, num_split_times, shots, tol):
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand Down Expand Up @@ -1535,8 +1538,13 @@ def ansatz(params):
)
qnode_backprop = qml.QNode(ansatz, dev, interface="jax")

grad_pulse_grad = jax.grad(qnode_pulse_grad)(params)
assert dev.num_executions == 1 + 2 * 3 * num_split_times
with qml.Tracker(dev) as tracker:
grad_pulse_grad = jax.grad(qnode_pulse_grad)(params)
assert (
tracker.totals["executions"] == (1 + 2 * 3 * num_split_times)
if dev_name == "default.qubit.jax"
else 1
)
grad_backprop = jax.grad(qnode_backprop)(params)

assert all(
Expand All @@ -1552,7 +1560,8 @@ def test_multi_return_broadcasting_multi_shots_raises(self, dev_name):

jax.config.update("jax_enable_x64", True)
shots = [100, 100]
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand Down Expand Up @@ -1582,7 +1591,8 @@ def test_qnode_probs_expval_broadcasting(self, dev_name, num_split_times, shots,
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand Down Expand Up @@ -1757,9 +1767,8 @@ def ansatz(params):
jax.clear_caches()


# TODO: port ParametrizedEvolution to new default.qubit
@pytest.mark.jax
@pytest.mark.parametrize("dev_name", ["default.qubit.jax"])
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.jax"])
class TestStochPulseGradDiff:
"""Test that stoch_pulse_grad is differentiable."""

Expand All @@ -1780,7 +1789,7 @@ def fun(params):
tape = qml.tape.QuantumScript([op], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = stoch_pulse_grad(tape)
return fn(qml.execute(tapes, dev, None))
return fn(qml.execute(tapes, dev, "backprop"))

params = [jnp.array(0.4)]
p = params[0] * T
Expand Down
Loading