Skip to content

Commit

Permalink
[ci] Fix bug in test_ndarray.py when using the latest version of JAX (
Browse files Browse the repository at this point in the history
#708)

* Update test_ndarray.py

* Update ndarray.py
  • Loading branch information
Routhleck authored Dec 16, 2024
1 parent c289edd commit a08ad48
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _check_tracer(self):
self_value = self.value
if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'):
if len(self_value._trace.main.jaxpr_stack) == 0:
raise RuntimeError('This Array is modified during the transformation. '
raise jax.errors.UnexpectedTracerError('This Array is modified during the transformation. '
'BrainPy only supports transformations for Variable. '
'Please declare it as a Variable.') from jax.core.escaped_tracer_error(self_value, None)
return self_value
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _f(self, b):

def test_tracing(self):
print(self.f(1.))
with self.assertRaises(RuntimeError):
with self.assertRaises(jax.errors.UnexpectedTracerError):
print(self.f(bm.ones(10)))


Expand Down

0 comments on commit a08ad48

Please sign in to comment.