[bug] Fix prograss bar is not displayed and updated as expected #683
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
jax.pure_callback()
may not work as expected for certain use cases, such as updating a progress bar, whereasjax.debug.callback()
can be used effectively.jax.pure_callback()
is intended for pure functions without side effects. If JAX determines that the callback's result isn't used in the computation, it might optimize the callback call out, which means updates to a progress bar (a side effect) wouldn't occur.jax.debug.callback()
, on the other hand, does not assume the function is pure and always executes the callback as part of the computation. This consistency ensures that progress bar updates and other side effects happen reliably, even under JAX transformations like jit and vmap.For more detail, see https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html