Skip to content

Commit

Permalink
[math] fix brainpy.math.ifelse bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Dec 7, 2023
1 parent 8107041 commit a3263fd
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 17 deletions.
5 changes: 4 additions & 1 deletion brainpy/_src/checkpoints/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import numpy as np
from jax import monitoring
from jax import process_index
from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa
from jax.experimental.multihost_utils import sync_global_devices
try:
from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa
except:
get_tensorstore_spec = GlobalAsyncCheckpointManager = None

try:
import msgpack
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,7 @@ def ifelse(
raise TypeError(msg)
cache_stack(tuple(branches), dyn_vars)
if current_transform_number():
return _if_else_return2(conditions, rets)

return rets[0]
branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches]

code_scope = {'conditions': conditions, 'branches': branches}
Expand Down
25 changes: 25 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,28 @@ def f2():

self.assertTrue(f2().size == 200)

def test_grad1(self):
def F2(x):
return bm.ifelse(conditions=(x >= 10,),
branches=[lambda x: x,
lambda x: x ** 2, ],
operands=x)

self.assertTrue(bm.grad(F2)(9.0) == 18.)
self.assertTrue(bm.grad(F2)(11.0) == 1.)


def test_grad2(self):
def F3(x):
return bm.ifelse(conditions=(x >= 10, x >= 0),
branches=[lambda x: x,
lambda x: x ** 2,
lambda x: x ** 4, ],
operands=x)

self.assertTrue(bm.grad(F3)(9.0) == 18.)
self.assertTrue(bm.grad(F3)(11.0) == 1.)


class TestWhile(unittest.TestCase):
def test1(self):
Expand Down Expand Up @@ -481,3 +503,6 @@ def body(a):
file.seek(0)
out6 = file.read().strip()
self.assertTrue(out5 == out6)



7 changes: 0 additions & 7 deletions brainpy/_src/tools/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,3 @@
For more detail installation instructions, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
'''


brainpylib_install = '''
'''


7 changes: 0 additions & 7 deletions brainpy/_src/tools/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,13 @@


__all__ = [
'import_numba',
'numba_jit',
'numba_seed',
'numba_range',
'SUPPORT_NUMBA',
]


def import_numba():
if numba is None:
raise ModuleNotFoundError('Numba is needed. Please install numba through:\n\n'
'> pip install numba')
return numba


SUPPORT_NUMBA = numba is not None

Expand Down

0 comments on commit a3263fd

Please sign in to comment.