Skip to content

Commit

Permalink
Merge pull request #19629 from jakevdp:key-reuse-pjit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604404276
  • Loading branch information
jax authors committed Feb 5, 2024
2 parents 206398a + f453442 commit 69a9f7f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
7 changes: 5 additions & 2 deletions jax/experimental/key_reuse/_forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,11 @@ def _pjit_key_type_signature(eqn, args_consumed):
jaxpr = eqn.params['jaxpr']
forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
if var in eqn.invars[:i]}
return get_jaxpr_type_signature(jaxpr.jaxpr, consumed_inputs=args_consumed,
forwarded_inputs=forwarded_inputs)
sig = get_jaxpr_type_signature(jaxpr.jaxpr)
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
# Double consumption detected: re-trace with context for better errors.
get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed, forwarded_inputs)
return sig

key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature

Expand Down
6 changes: 5 additions & 1 deletion jax/experimental/key_reuse/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ def _pjit_key_type_signature(eqn, args_consumed):
non_literal_invars = [v for v in eqn.invars if not isinstance(v, core.Literal)]
if len(set(non_literal_invars)) != len(non_literal_invars):
raise ValueError(f"pjit with duplicate inputs: {eqn.invars=}")
return get_jaxpr_type_signature(jaxpr.jaxpr, consumed_inputs=args_consumed)
sig = get_jaxpr_type_signature(jaxpr.jaxpr)
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
# Double consumption detected: re-trace with context for better errors.
get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed)
return sig

key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature

Expand Down
13 changes: 13 additions & 0 deletions tests/key_reuse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
random_bits_error = "In random_bits, key values .+ are already consumed.*"
random_split_error = "In random_split, key values .+ are already consumed.*"
generic_error = ".*key values .+ are already consumed.*"
pjit_error = "In pjit, key values a are already consumed."

def check_key_reuse(self, f, *args):
if self.use_forwarding:
Expand Down Expand Up @@ -782,6 +783,18 @@ def body_fun(i):
with self.assertRaisesRegex(KeyReuseError, "while_loop cond function leads to key reuse"):
self.check_key_reuse(f, 0)

def test_pjit_consumed_input(self):
@jax.jit
def g(key, x): # doesn't consume key
return x

def f(seed):
key = jax.random.key(seed)
x = jax.random.bits(key)
return g(key, x)

self.check_key_reuse(f, 0)


class KeyReuseIntegrationTestSimple(KeyReuseIntegrationTest):
use_forwarding = False
Expand Down

0 comments on commit 69a9f7f

Please sign in to comment.