Skip to content

Commit

Permalink
fix and enable pulse gradient with broadcasting on new device (#4620)
Browse files Browse the repository at this point in the history
**Context:**
Enabling this was missed in the DQ2 switch-over. We originally said that
`ParametrizedEvolution` should fail to apply to broadcasted states, but
`pulse_generator` and `stoch_pulse_grad` have custom handling for
certain cases, so we ought not be so decisive in `apply_operation`. That
said, if the inputted state is batched but the operator is not, then
we're in trouble.

**Description of the Change:**
- Only raise the batched-state-ParametrizedEvolution error when the op
is not itself batched
- put shots on tapes created by `pulse_generator` (missed in gradient
uses tape shots upgrade)
- change `test_jax` at the bottom of each gradient test file to use
`"backprop"` instead of `None` for the `gradient_fn`. This worked as-is
on `DefaultQubitJax` because it has a [passthru_interface
set](https://github.com/PennyLaneAI/pennylane/blob/e909e56a197cbdea772449e972a12da4931cf2b8/pennylane/interfaces/execution.py#L581),
but really it's using backprop. Maybe that boolean needs more work, but
this seemed like a niche-enough test case, so I fixed the test instead
of the code.

**Benefits:**
pulse gradients work with the new device and `use_broadcasting=True`!
  • Loading branch information
timmysilv authored and mudit2812 committed Sep 22, 2023
1 parent 240fbe2 commit 09baad0
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 26 deletions.
4 changes: 3 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
accessible by the short name `default.qubit.legacy`, or directly via `qml.devices.DefaultQubitLegacy`.
[(#4594)](https://github.com/PennyLaneAI/pennylane/pull/4594)
[(#4436)](https://github.com/PennyLaneAI/pennylane/pull/4436)
[(#4620)](https://github.com/PennyLaneAI/pennylane/pull/4620)

<h3>Improvements 🛠</h3>

Expand Down Expand Up @@ -141,9 +142,10 @@
`DefaultQubitJax` in the old API.
[(#4596)](https://github.com/PennyLaneAI/pennylane/pull/4596)

* DefaultQubit2 dispatches to a faster implementation for applying `ParameterizedEvolution` to a state
* DefaultQubit2 dispatches to a faster implementation for applying `ParametrizedEvolution` to a state
when it is more efficient to evolve the state than the operation matrix.
[(#4598)](https://github.com/PennyLaneAI/pennylane/pull/4598)
[(#4620)](https://github.com/PennyLaneAI/pennylane/pull/4620)

* `ShotAdaptiveOptimizer` has been updated to pass shots to QNode executions instead of overriding
device shots before execution. This makes it compatible with the new device API.
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ def apply_parametrized_evolution(
):
"""Apply ParametrizedEvolution by evolving the state rather than the operator matrix
if we are operating on more than half of the subsystem"""
if is_state_batched:
if is_state_batched and op.batch_size is None:
raise RuntimeError(
"ParameterizedEvolution does not support batching, but received a batched state"
"ParameterizedEvolution does not support standard broadcasting, but received a batched state"
)

# shape(state) is static (not a tracer), we can use an if statement
Expand Down
3 changes: 2 additions & 1 deletion pennylane/gradients/pulse_generator_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def _insert_op(tape, ops, op_idx):
ops_pre = tape.operations[:op_idx]
ops_post = tape.operations[op_idx:]
return [
qml.tape.QuantumScript(ops_pre + [insert] + ops_post, tape.measurements) for insert in ops
qml.tape.QuantumScript(ops_pre + [insert] + ops_post, tape.measurements, shots=tape.shots)
for insert in ops
]


Expand Down
2 changes: 1 addition & 1 deletion tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def test_batched_state_raises_an_error(self):
]
)

with pytest.raises(RuntimeError, match="does not support batching"):
with pytest.raises(RuntimeError, match="does not support standard broadcasting"):
_ = apply_operation(op, initial_state, is_state_batched=True)


Expand Down
17 changes: 9 additions & 8 deletions tests/gradients/core/test_pulse_generator_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,9 +965,8 @@ def test_all_zero_diff_methods_multiple_returns_tape(self):
assert np.allclose(res_pulse_gen[1][2], 0)


# TODO: add default.qubit once it supports PRNG key
@pytest.mark.jax
@pytest.mark.parametrize("dev_name", ["default.qubit.jax"])
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.jax"])
class TestPulseGeneratorTape:
"""Test that differentiating tapes with ``pulse_generator`` works."""

Expand All @@ -979,7 +978,8 @@ def test_single_pulse_single_term(self, dev_name, shots, tol):
import jax.numpy as jnp

prng_key = jax.random.PRNGKey(8251)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=prng_key)
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: prng_key})

H = jnp.polyval * X(0)
x = jnp.array([0.4, 0.2, 0.1])
Expand Down Expand Up @@ -1015,7 +1015,8 @@ def test_single_pulse_multi_term(self, dev_name, shots, tol):

prng_key = jax.random.PRNGKey(8251)
dev = qml.device(dev_name, wires=1, shots=None)
dev_shots = qml.device(dev_name, wires=1, shots=shots, prng_key=prng_key)
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev_shots = qml.device(dev_name, wires=1, shots=shots, **{key: prng_key})

H = 0.1 * Z(0) + jnp.polyval * X(0) + qml.pulse.constant * Y(0)
x = jnp.array([0.4, 0.2, 0.1])
Expand Down Expand Up @@ -1095,7 +1096,8 @@ def test_multi_pulse(self, dev_name, shots, tol):

prng_key = jax.random.PRNGKey(8251)
dev = qml.device(dev_name, wires=1, shots=None)
dev_shots = qml.device(dev_name, wires=1, shots=shots, prng_key=prng_key)
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev_shots = qml.device(dev_name, wires=1, shots=shots, **{key: prng_key})

H0 = 0.1 * Z(0) + jnp.polyval * X(0)
H1 = 0.2 * Y(0) + qml.pulse.constant * Y(0) + jnp.polyval * Z(0)
Expand Down Expand Up @@ -1447,9 +1449,8 @@ def circuit(param1, param2):
assert tracker.totals["executions"] == 1 + 2 # one forward pass, two shifted tapes


# TODO: port ParametrizedEvolution to new default.qubit
@pytest.mark.jax
@pytest.mark.parametrize("dev_name", ["default.qubit.jax"])
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.jax"])
class TestPulseGeneratorDiff:
"""Test that pulse_generator is differentiable, i.e. that computing
the derivative with pulse_generator is differentiable a second time,
Expand All @@ -1473,7 +1474,7 @@ def fun(params):
tape = qml.tape.QuantumScript([op], [qml.expval(Z(0))])
tape.trainable_params = [0]
_tapes, fn = pulse_generator(tape)
return fn(qml.execute(_tapes, dev, None))
return fn(qml.execute(_tapes, dev, "backprop"))

params = [jnp.array(0.4)]
p = params[0] * T
Expand Down
35 changes: 22 additions & 13 deletions tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,9 +1352,8 @@ def circuit(params):
assert tracker.totals["executions"] == 4 # two shifted tapes, two splitting times


# TODO: add default.qubit once it supports PRNG key
@pytest.mark.jax
@pytest.mark.parametrize("dev_name", ["default.qubit.jax"])
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.jax"])
class TestStochPulseGradIntegration:
"""Test that stoch_pulse_grad integrates correctly with QNodes and ML interfaces."""

Expand All @@ -1367,7 +1366,8 @@ def test_simple_qnode_expval(self, dev_name, num_split_times, shots, tol):
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand All @@ -1394,7 +1394,8 @@ def test_simple_qnode_expval_two_evolves(self, dev_name, num_split_times, shots,
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T_x = 0.1
T_y = 0.2
ham_x = qml.pulse.constant * qml.PauliX(0)
Expand Down Expand Up @@ -1425,7 +1426,8 @@ def test_simple_qnode_probs(self, dev_name, num_split_times, shots, tol):
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand All @@ -1452,7 +1454,8 @@ def test_simple_qnode_probs_expval(self, dev_name, num_split_times, shots, tol):
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand Down Expand Up @@ -1535,8 +1538,13 @@ def ansatz(params):
)
qnode_backprop = qml.QNode(ansatz, dev, interface="jax")

grad_pulse_grad = jax.grad(qnode_pulse_grad)(params)
assert dev.num_executions == 1 + 2 * 3 * num_split_times
with qml.Tracker(dev) as tracker:
grad_pulse_grad = jax.grad(qnode_pulse_grad)(params)
assert (
tracker.totals["executions"] == (1 + 2 * 3 * num_split_times)
if dev_name == "default.qubit.jax"
else 1
)
grad_backprop = jax.grad(qnode_backprop)(params)

assert all(
Expand All @@ -1552,7 +1560,8 @@ def test_multi_return_broadcasting_multi_shots_raises(self, dev_name):

jax.config.update("jax_enable_x64", True)
shots = [100, 100]
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand Down Expand Up @@ -1582,7 +1591,8 @@ def test_qnode_probs_expval_broadcasting(self, dev_name, num_split_times, shots,
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device(dev_name, wires=1, shots=shots, prng_key=jax.random.PRNGKey(74))
key = "prng_key" if dev_name == "default.qubit.jax" else "seed"
dev = qml.device(dev_name, wires=1, shots=shots, **{key: jax.random.PRNGKey(74)})
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

Expand Down Expand Up @@ -1757,9 +1767,8 @@ def ansatz(params):
jax.clear_caches()


# TODO: port ParametrizedEvolution to new default.qubit
@pytest.mark.jax
@pytest.mark.parametrize("dev_name", ["default.qubit.jax"])
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.jax"])
class TestStochPulseGradDiff:
"""Test that stoch_pulse_grad is differentiable."""

Expand All @@ -1780,7 +1789,7 @@ def fun(params):
tape = qml.tape.QuantumScript([op], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = stoch_pulse_grad(tape)
return fn(qml.execute(tapes, dev, None))
return fn(qml.execute(tapes, dev, "backprop"))

params = [jnp.array(0.4)]
p = params[0] * T
Expand Down

0 comments on commit 09baad0

Please sign in to comment.