Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634537157
  • Loading branch information
Jake VanderPlas authored and MctxDev committed May 16, 2024
1 parent 4629b83 commit be1dbb1
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mctx/_src/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _broadcast_where(decision_leaf, chance_leaf):
expanded_is_decision,
decision_leaf, chance_leaf)

output = jax.tree_map(_broadcast_where,
output = jax.tree.map(_broadcast_where,
output_if_decision_node,
output_if_chance_node)
return output, new_state
Expand Down
2 changes: 1 addition & 1 deletion mctx/_src/tests/tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def run_policy():
invalid_actions=invalid_actions,
**tree["algorithm_config"])

policy_output = jax.jit(run_policy)()
policy_output = jax.jit(run_policy)() # pylint: disable=not-callable
logging.info("Done search.")

return tree_to_pytree(policy_output.search_tree)
Expand Down
4 changes: 2 additions & 2 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
chex>=0.0.8
jax>=0.1.55
jaxlib>=0.1.37
jax>=0.4.25
jaxlib>=0.4.25

0 comments on commit be1dbb1

Please sign in to comment.