Skip to content

Commit

Permalink
Replace deprecated jax.experimental.host_callback with `jax.experim…
Browse files Browse the repository at this point in the history
…ental.io_callback`

Fixes #14.
  • Loading branch information
schmrlng committed Oct 10, 2024
1 parent 324ab73 commit 16b8d15
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions hj_reachability/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from flax import struct
import jax
import jax.experimental.host_callback
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -108,16 +107,30 @@ class TqdmWrapper:

def __init__(self, tqdm, reference_time, total, *args, **kwargs):
self.reference_time = reference_time
jax.experimental.host_callback.id_tap(lambda total, __: self._create_tqdm(tqdm, total, *args, **kwargs), total)
jax.experimental.io_callback(
lambda total: self._create_tqdm(tqdm, float(total), *args, **kwargs),
None,
total,
ordered=True,
)

def _create_tqdm(self, tqdm, total, *args, **kwargs):
self._tqdm = tqdm.tqdm(total=total, *args, **kwargs)

def update_to(self, n):
return jax.experimental.host_callback.id_tap(lambda n, __: self._tqdm.update(n - self._tqdm.n), n)
jax.experimental.io_callback(
lambda n: self._tqdm.update(float(n) - self._tqdm.n) and None,
None,
n,
ordered=True,
)

def close(self):
return jax.experimental.host_callback.id_tap(lambda _, __: self._tqdm.close(), None)
jax.experimental.io_callback(
lambda: self._tqdm.close(),
None,
ordered=True,
)

def __enter__(self):
return self
Expand Down

0 comments on commit 16b8d15

Please sign in to comment.