diff --git a/examples/async_learner_actor.py b/examples/async_learner_actor.py index ea2ef19..fad227c 100644 --- a/examples/async_learner_actor.py +++ b/examples/async_learner_actor.py @@ -222,7 +222,7 @@ def stats_callback(type: str, payload: dict) -> dict: def main(_): - devices = jax.local_devices()[:1] + devices = jax.local_devices() num_devices = len(devices) sharding = jax.sharding.PositionalSharding(devices) assert FLAGS.batch_size % num_devices == 0