From b6d74607033165f705a1e8fc394d93ad0d3ac71e Mon Sep 17 00:00:00 2001 From: youliang Date: Tue, 31 Oct 2023 23:48:06 -0700 Subject: [PATCH] bad commit for jax devices --- examples/async_learner_actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/async_learner_actor.py b/examples/async_learner_actor.py index ea2ef19..fad227c 100644 --- a/examples/async_learner_actor.py +++ b/examples/async_learner_actor.py @@ -222,7 +222,7 @@ def stats_callback(type: str, payload: dict) -> dict: def main(_): - devices = jax.local_devices()[:1] + devices = jax.local_devices() num_devices = len(devices) sharding = jax.sharding.PositionalSharding(devices) assert FLAGS.batch_size % num_devices == 0