From be1dbb1c9c48bd858a000ec7f0a33ac9a8350e7b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 16 May 2024 14:52:06 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 634537157 --- mctx/_src/policies.py | 2 +- mctx/_src/tests/tree_test.py | 2 +- requirements/requirements.txt | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mctx/_src/policies.py b/mctx/_src/policies.py index 624d92b..06f9670 100644 --- a/mctx/_src/policies.py +++ b/mctx/_src/policies.py @@ -477,7 +477,7 @@ def _broadcast_where(decision_leaf, chance_leaf): expanded_is_decision, decision_leaf, chance_leaf) - output = jax.tree_map(_broadcast_where, + output = jax.tree.map(_broadcast_where, output_if_decision_node, output_if_chance_node) return output, new_state diff --git a/mctx/_src/tests/tree_test.py b/mctx/_src/tests/tree_test.py index 370fae3..e4a5843 100644 --- a/mctx/_src/tests/tree_test.py +++ b/mctx/_src/tests/tree_test.py @@ -202,7 +202,7 @@ def run_policy(): invalid_actions=invalid_actions, **tree["algorithm_config"]) - policy_output = jax.jit(run_policy)() + policy_output = jax.jit(run_policy)() # pylint: disable=not-callable logging.info("Done search.") return tree_to_pytree(policy_output.search_tree) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 404ac71..49fce1f 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,3 +1,3 @@ chex>=0.0.8 -jax>=0.1.55 -jaxlib>=0.1.37 +jax>=0.4.25 +jaxlib>=0.4.25