[bug] Fixing compatibility issues with jax (#691)
* Add code for processing and computing with progress bar

* fix

* Revert "Add code for processing and computing with progress bar"

This reverts commit 12977af.

* Add JAX configuration for disabling async dispatch

* Update controls.py

* Update brainpy-changelog.md

---------

Co-authored-by: Chaoming Wang <adaduo@outlook.com>
Routhleck and chaoming0625 authored Oct 25, 2024
1 parent a51e3a7 commit 2a5adea
Showing 4 changed files with 92 additions and 13 deletions.
74 changes: 74 additions & 0 deletions brainpy-changelog.md
@@ -2,6 +2,80 @@


## brainpy>2.3.x
### Version 2.6.1
#### Breaking Changes
- Fixing compatibility issues between `numpy` and `jax`

#### What's Changed
* [doc] Add Chinese version of `operator_custom_with_cupy.ipynb` and rename its title by @Routhleck in https://github.com/brainpy/BrainPy/pull/659
* Fix "amsgrad" is used before being defined when initializing the AdamW optimizer by @CloudyDory in https://github.com/brainpy/BrainPy/pull/660
* fix issue #661 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/662
* fix flax RNN interoperation, fix #663 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/665
* [fix] Replace jax.experimental.host_callback with jax.pure_callback by @Routhleck in https://github.com/brainpy/BrainPy/pull/670
* [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24 by @Routhleck in https://github.com/brainpy/BrainPy/pull/669
* [math] Fix `CustomOpByNumba` on `multiple_results=True` by @Routhleck in https://github.com/brainpy/BrainPy/pull/671
* [math] Implementing event-driven sparse matrix @ matrix operators by @Routhleck in https://github.com/brainpy/BrainPy/pull/613
* [math] Add getting JIT connect matrix method for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/672
* [math] Add get JIT weight matrix methods (Uniform & Normal) for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/673
* support `Integrator.to_math_expr()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/674
* [bug] Replace `collections.Iterable` with `collections.abc.Iterable` by @Routhleck in https://github.com/brainpy/BrainPy/pull/677
* Fix surrogate gradient function and numpy 2.0 compatibility by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/679
* :arrow_up: Bump docker/build-push-action from 5 to 6 by @dependabot in https://github.com/brainpy/BrainPy/pull/678
* fix the incorrect verbose of `clear_name_cache()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/681
* [bug] Fix progress bar not being displayed and updated as expected by @Routhleck in https://github.com/brainpy/BrainPy/pull/683
* Fix autograd by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/687


**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.6.0...V2.6.1

### Version 2.6.0

#### New Features

This release provides several new features, including:

- ``MLIR`` registered operator customization interface in ``brainpy.math.XLACustomOp``.
- Operator customization with CuPy JIT interface.
- Bug fixes.



#### What's Changed
* [doc] Fix the wrong path of more examples of `operator customized with taichi.ipynb` by @Routhleck in https://github.com/brainpy/BrainPy/pull/612
* [docs] Add colab link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/614
* Update requirements-doc.txt to fix doc building temporally by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/617
* [math] Rebase operator customization using MLIR registration interface by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/618
* [docs] Add kaggle link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/619
* update requirements by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/620
* require `brainpylib>=0.2.6` for `jax>=0.4.24` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/622
* [tools] add `brainpy.tools.compose` and `brainpy.tools.pipe` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/624
* doc hierarchy update by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/630
* Standardizing and generalizing object-oriented transformations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/628
* fix #626 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/631
* Fix delayvar not correct in concat mode by @CloudyDory in https://github.com/brainpy/BrainPy/pull/632
* [dependency] remove hard dependency of `taichi` and `numba` by @Routhleck in https://github.com/brainpy/BrainPy/pull/635
* `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/639
* add `brainpy.math.surrogate.Surrogate` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/638
* Enable brainpy object as pytree so that it can be applied with ``jax.jit`` etc. directly by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/625
* Fix ci by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/640
* Clean taichi AOT caches by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/643
* [ci] Fix windows pytest fatal exception by @Routhleck in https://github.com/brainpy/BrainPy/pull/644
* [math] Support more than 8 parameters of taichi gpu custom operator definition by @Routhleck in https://github.com/brainpy/BrainPy/pull/642
* Doc for ``brainpylib>=0.3.0`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/645
* Find back updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/646
* Update installation instruction by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/651
* Fix delay bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/650
* update doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/652
* [math] Add new customize operators with `cupy` by @Routhleck in https://github.com/brainpy/BrainPy/pull/653
* [math] Fix taichi custom operator on gpu backend by @Routhleck in https://github.com/brainpy/BrainPy/pull/655
* update cupy operator custom doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/656
* version 2.6.0 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/657
* Upgrade CI by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/658

## New Contributors
* @CloudyDory made their first contribution in https://github.com/brainpy/BrainPy/pull/632

**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.5.0...V2.6.0


### Version 2.5.0
5 changes: 5 additions & 0 deletions brainpy/__init__.py
@@ -153,3 +153,8 @@

del deprecation_getattr2

+# jax config
+import os
+os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
+import jax
+jax.config.update('jax_cpu_enable_async_dispatch', False)
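A note on the two settings above: the `XLA_FLAGS` environment variable only takes effect if it is set before `jax` is first imported, while `jax.config.update` can be called at any time. A minimal standalone sketch of the async-dispatch switch (outside BrainPy) is:

```python
import jax
import jax.numpy as jnp

# With async dispatch enabled (the default), CPU computations return
# immediately and run in the background, so host-side effects such as
# progress-bar callbacks can race ahead of the computation. Disabling
# it makes CPU dispatch synchronous.
jax.config.update('jax_cpu_enable_async_dispatch', False)

x = jnp.arange(1000.0)
y = jnp.sin(x) * 2.0  # completes before the next Python line runs
first = float(y[0])
```

This trades some pipelining performance for deterministic ordering between the computation and host-side Python code, which is what the progress-bar fix in this commit relies on.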
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/controls.py
@@ -915,7 +915,7 @@ def fun2scan(carry, x):
      dyn_vars[k]._value = dyn_vars_data[k]
    carry, results = body_fun(carry, x)
    if progress_bar:
-      jax.pure_callback(lambda *arg: bar.update(), ())
+      jax.debug.callback(lambda *arg: bar.update(), ())
    carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
    return (dyn_vars.dict_data(), carry), results
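The rationale for this one-line change: `jax.pure_callback` is intended for *pure* functions whose results feed back into the traced computation, so the compiler is free to elide or reorder it, which can silently drop an impure progress-bar update; `jax.debug.callback` is designed for side effects and is guaranteed to execute. A sketch with a hypothetical host-side counter standing in for the real progress bar:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the progress bar: a host-side list we append to.
ticks = []

def body(carry, x):
    # jax.debug.callback is meant for host-side effects and is not
    # optimized away, unlike jax.pure_callback.
    jax.debug.callback(lambda: ticks.append(1))
    return carry + x, carry

@jax.jit
def run(xs):
    total, _ = jax.lax.scan(body, jnp.asarray(0.0), xs)
    return total

total = float(run(jnp.ones(5)))  # blocks until the scan (and callbacks) finish
```

After `run` completes, the callback has fired once per scan step, so `ticks` has one entry per iteration.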

24 changes: 12 additions & 12 deletions examples/dynamics_simulation/hh_model.py
@@ -43,16 +43,16 @@ def __init__(self, size):
self.KNa.add_elem()


-# hh = HH(1)
-# I, length = bp.inputs.section_input(values=[0, 5, 0],
-#                                     durations=[100, 500, 100],
-#                                     return_length=True)
-# runner = bp.DSRunner(
-#   hh,
-#   monitors=['V', 'INa.p', 'INa.q', 'IK.p'],
-#   inputs=[hh.input, I, 'iter'],
-# )
-# runner.run(length)
-#
-# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
+hh = HH(1)
+I, length = bp.inputs.section_input(values=[0, 5, 0],
+                                    durations=[100, 500, 100],
+                                    return_length=True)
+runner = bp.DSRunner(
+  hh,
+  monitors=['V', 'INa.p', 'INa.q', 'IK.p'],
+  inputs=[hh.input, I, 'iter'],
+)
+runner.run(length)
+
+bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
