Skip to content

Commit

Permalink
Capturing jaxpr.consts and passing them as positional arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Jul 29, 2024
1 parent 64ebe82 commit 37862ee
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 11 deletions.
42 changes: 31 additions & 11 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,28 +479,38 @@ def _get_cond_qfunc_prim():
cond_prim.multiple_results = True

@cond_prim.def_impl
def _(conditions, *args, jaxpr_branches):
def _(conditions, *args_and_consts, jaxpr_branches, n_consts_per_branch, n_args):

for pred, jaxpr in zip(conditions, jaxpr_branches):
args = args_and_consts[:n_args]
consts_flat = args_and_consts[n_args:]

consts_per_branch = []
start = 0
for n in n_consts_per_branch:
consts_per_branch.append(consts_flat[start : start + n])
start += n

for pred, jaxpr, consts in zip(conditions, jaxpr_branches, consts_per_branch):
if pred and jaxpr is not None:
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
return jax.core.eval_jaxpr(jaxpr.jaxpr, consts, *args)

return ()

@cond_prim.def_abstract_eval
def _(*_, jaxpr_branches):
def _(*_, jaxpr_branches, **__):

# Index 0 corresponds to the true branch
outvals_true = jaxpr_branches[0].out_avals

for idx, jaxpr_branch in enumerate(jaxpr_branches):
if idx == 0:
continue

if outvals_true and jaxpr_branch is None:
raise ValueError(
"The false branch must be provided if the true branch returns any variables"
)
if jaxpr_branch is None:
if outvals_true:
raise ValueError(
"The false branch must be provided if the true branch returns any variables"
)
continue

Check warning on line 513 in pennylane/ops/op_math/condition.py

View check run for this annotation

Codecov / codecov/patch

pennylane/ops/op_math/condition.py#L513

Added line #L513 was not covered by tests

outvals_branch = jaxpr_branch.out_avals
branch_type = "elif" if idx < len(jaxpr_branches) - 1 else "false"
Expand Down Expand Up @@ -531,21 +541,31 @@ def new_wrapper(*args, **kwargs):
)

# We extract each condition (or predicate) from the elifs argument list
# since these are traced by JAX and are passed as positional arguments to the cond primitive
# since these are traced by JAX and are passed as positional arguments to the primitive
elifs_conditions = []
jaxpr_elifs = []

for pred, elif_fn in elifs:
elifs_conditions.append(pred)
jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args))

jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false]
conditions = jax.numpy.array([condition, *elifs_conditions, True])

jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false]
jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches]

# We need to flatten the constants since JAX does not allow
# to pass lists as positional arguments
consts_flat = [const for sublist in jaxpr_consts for const in sublist]
n_consts_per_branch = [len(consts) for consts in jaxpr_consts]

return cond_prim.bind(
conditions,
*args,
*consts_flat,
jaxpr_branches=jaxpr_branches,
n_consts_per_branch=n_consts_per_branch,
n_args=len(args),
)

return new_wrapper
42 changes: 42 additions & 0 deletions tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,35 @@ def false_fn_2(arg):
return qml.expval(qml.Z(0))


@qml.qnode(dev)
def circuit_with_consts(pred, arg):
"""Quantum circuit with jaxpr constants."""

# these are captured as consts
arg1 = arg
arg2 = arg + 0.2
arg3 = arg + 0.3
arg4 = arg + 0.4
arg5 = arg + 0.5
arg6 = arg + 0.6

def true_fn():
qml.RX(arg1, 0)

def false_fn():
qml.RX(arg2, 0)
qml.RX(arg3, 0)

def elif_fn1():
qml.RX(arg4, 0)
qml.RX(arg5, 0)
qml.RX(arg6, 0)

qml.cond(pred > 0, true_fn, false_fn, elifs=((pred == 0, elif_fn1),))()

return qml.expval(qml.Z(0))


class TestCondCircuits:
"""Tests for conditional quantum circuits."""

Expand Down Expand Up @@ -399,3 +428,16 @@ def test_circuit_multiple_cond(self, tmp_pred, tmp_arg, expected):
"""Test circuit with returned operators in the branches."""
result = circuit_multiple_cond(tmp_pred, tmp_arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

@pytest.mark.parametrize(
"pred, arg, expected",
[
(1, 0.5, 0.87758256), # RX(0.5)
(-1, 0.5, 0.0707372), # RX(0.7) -> RX(0.8)
(0, 0.5, -0.9899925), # RX(0.9) -> RX(1.0) -> RX(1.1)
],
)
def test_circuit_consts(self, pred, arg, expected):
"""Test circuit with jaxpr constants."""
result = circuit_with_consts(pred, arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

0 comments on commit 37862ee

Please sign in to comment.