Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 11, 2024
1 parent fc2b978 commit 017eb6f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 29 deletions.
19 changes: 10 additions & 9 deletions brainunit/math/_compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],
@set_module_as('brainunit.math')
def array_split(ary: Union[Quantity, bst.typing.ArrayLike],
indices_or_sections: Union[int, bst.typing.ArrayLike],
axis: Optional[int] = 0) -> list[Quantity] | list[Array]:
axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]:
'''
Split an array into multiple sub-arrays.
Expand Down Expand Up @@ -1678,7 +1678,7 @@ def f(x, *args, **kwargs):
def prod(x: Union[Quantity, bst.typing.ArrayLike],
axis: Optional[int] = None,
dtype: Optional[bst.typing.DTypeLike] = None,
out: Optional[...] = None,
out: None = None,
keepdims: Optional[bool] = False,
initial: Union[Quantity, bst.typing.ArrayLike] = None,
where: Union[Quantity, bst.typing.ArrayLike] = None,
Expand Down Expand Up @@ -1711,8 +1711,8 @@ def prod(x: Union[Quantity, bst.typing.ArrayLike],
def nanprod(x: Union[Quantity, bst.typing.ArrayLike],
axis: Optional[int] = None,
dtype: Optional[bst.typing.DTypeLike] = None,
out: Optional[...] = None,
keepdims: Optional[...] = False,
out: None = None,
keepdims: None = False,
initial: Union[Quantity, bst.typing.ArrayLike] = None,
where: Union[Quantity, bst.typing.ArrayLike] = None):
'''
Expand Down Expand Up @@ -1743,7 +1743,7 @@ def nanprod(x: Union[Quantity, bst.typing.ArrayLike],
def cumprod(x: Union[Quantity, bst.typing.ArrayLike],
axis: Optional[int] = None,
dtype: Optional[bst.typing.DTypeLike] = None,
out: Optional[...] = None) -> Union[Quantity, bst.typing.ArrayLike]:
out: None = None) -> Union[Quantity, bst.typing.ArrayLike]:
'''
Return the cumulative product of elements along a given axis.
Expand All @@ -1766,7 +1766,7 @@ def cumprod(x: Union[Quantity, bst.typing.ArrayLike],
def nancumprod(x: Union[Quantity, bst.typing.ArrayLike],
axis: Optional[int] = None,
dtype: Optional[bst.typing.DTypeLike] = None,
out: Optional[...] = None) -> Union[Quantity, bst.typing.ArrayLike]:
out: None = None) -> Union[Quantity, bst.typing.ArrayLike]:
'''
Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one.
Expand Down Expand Up @@ -2029,9 +2029,10 @@ def float_power(x: Union[Quantity, bst.typing.ArrayLike],
Returns:
out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array.
'''
assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent'
if isinstance(y, Quantity):
assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent'
if isinstance(x, Quantity):
return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y.unit))
return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y))
elif isinstance(x, (jax.Array, np.ndarray)):
return jnp.float_power(x, y)
else:
Expand Down Expand Up @@ -3228,7 +3229,7 @@ def wrap_elementwise_bit_operation_binary(func):
def f(x, y, *args, **kwargs):
if isinstance(x, Quantity) or isinstance(y, Quantity):
raise ValueError(f'Expected integers, got {x} and {y}')
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)):
elif isinstance(x, bst.typing.ArrayLike) and isinstance(y, bst.typing.ArrayLike):
return func(x, y, *args, **kwargs)
else:
raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}')
Expand Down
40 changes: 20 additions & 20 deletions brainunit/math/_compat_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_full_like(self):
self.assertTrue(jnp.all(result == 4))

q = [1, 2, 3] * bu.second
result_q = bu.math.full_like(q, 4 * bu.second)
result_q = bu.math.full_like(q, 4, unit=bu.second)
assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), bu.second)

def test_diag(self):
Expand All @@ -97,7 +97,7 @@ def test_diag(self):
self.assertTrue(jnp.all(result == jnp.diag(array)))

q = [1, 2, 3] * bu.second
result_q = bu.math.diag(q)
result_q = bu.math.diag(q, unit=bu.second)
assert_quantity(result_q, jnp.diag(jnp.array([1, 2, 3])), bu.second)

def test_tril(self):
Expand All @@ -107,7 +107,7 @@ def test_tril(self):
self.assertTrue(jnp.all(result == jnp.tril(array)))

q = jnp.ones((3, 3)) * bu.second
result_q = bu.math.tril(q)
result_q = bu.math.tril(q, unit=bu.second)
assert_quantity(result_q, jnp.tril(jnp.ones((3, 3))), bu.second)

def test_triu(self):
Expand All @@ -117,7 +117,7 @@ def test_triu(self):
self.assertTrue(jnp.all(result == jnp.triu(array)))

q = jnp.ones((3, 3)) * bu.second
result_q = bu.math.triu(q)
result_q = bu.math.triu(q, unit=bu.second)
assert_quantity(result_q, jnp.triu(jnp.ones((3, 3))), bu.second)

def test_empty_like(self):
Expand Down Expand Up @@ -1810,22 +1810,6 @@ def test_invert(self):
q = [0b1100] * bu.second
result_q = bu.math.invert(q)

def test_left_shift(self):
result = bu.math.left_shift(jnp.array([0b0100]), 2)
self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b0100]), 2)))

with pytest.raises(ValueError):
q = [0b0100] * bu.second
result_q = bu.math.left_shift(q, 2)

def test_right_shift(self):
result = bu.math.right_shift(jnp.array([0b0100]), 2)
self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b0100]), 2)))

with pytest.raises(ValueError):
q = [0b0100] * bu.second
result_q = bu.math.right_shift(q, 2)


class TestElementwiseBitOperationsBinary(unittest.TestCase):

Expand Down Expand Up @@ -1856,6 +1840,22 @@ def test_bitwise_xor(self):
q2 = [0b1010] * bu.second
result_q = bu.math.bitwise_xor(q1, q2)

def test_left_shift(self):
result = bu.math.left_shift(jnp.array([0b1100]), 2)
self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b1100]), 2)))

with pytest.raises(ValueError):
q = [0b1100] * bu.second
result_q = bu.math.left_shift(q, 2)

def test_right_shift(self):
result = bu.math.right_shift(jnp.array([0b1100]), 2)
self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b1100]), 2)))

with pytest.raises(ValueError):
q = [0b1100] * bu.second
result_q = bu.math.right_shift(q, 2)


class TestLogicFuncsUnary(unittest.TestCase):
def test_all(self):
Expand Down

0 comments on commit 017eb6f

Please sign in to comment.