Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 23, 2024
1 parent a0ea781 commit e5633d8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 64 deletions.
64 changes: 2 additions & 62 deletions brainunit/math/_fun_accept_unitless_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,15 @@
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh',
]
fun_accept_unitless_unary_2_results = [
'modf', 'frexp',
]

fun_accept_unitless_binary = [
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
'corrcoef', 'correlate', 'cov',
]
fun_accept_unitless_binary_ldexp = [
'ldexp',
]
fun_accept_unitless_unary_can_return_quantity = [
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix',
]

fun_elementwise_bit_operation_unary = [
'bitwise_not', 'invert',
]
Expand Down Expand Up @@ -70,36 +65,6 @@ def test_fun_accept_unitless_unary_1(self, value):
with pytest.raises(bu.UnitMismatchError):
result = fun(q, unit_to_scale=bu.nS)

@parameterized.product(
value=[(1.123, 2.567, 3.891), (1.23, 2.34, 3.45)]
)
def test_fun_accept_unitless_binary_2_results(self, value):
bm_fun_list = [getattr(bm, fun) for fun in fun_accept_unitless_unary_2_results]
jnp_fun_list = [getattr(jnp, fun) for fun in fun_accept_unitless_unary_2_results]

for bm_fun, jnp_fun in zip(bm_fun_list, jnp_fun_list):
print(f'fun: {bm_fun.__name__}')
result1, result2 = bm_fun(jnp.array(value))
expected1, expected2 = jnp_fun(jnp.array(value))
assert_quantity(result1, expected1)
assert_quantity(result2, expected2)

q = value * meter
result1, result2 = bm_fun(q, unit_to_scale=meter)
expected1, expected2 = jnp_fun(jnp.array(value))
if bm_fun.__name__ == 'modf':
assert_quantity(result1, expected1, meter)
assert_quantity(result2, expected2, meter)
else:
assert_quantity(result1, expected1)
assert_quantity(result2, expected2)

with pytest.raises(AssertionError):
result1, result2 = bm_fun(q)

with pytest.raises(bu.UnitMismatchError):
result1, result2 = bm_fun(q, unit_to_scale=bu.second)

@parameterized.product(
value=[[(1.0, 2.0), (3.0, 4.0), ],
[(1.23, 2.34, 3.45), (4.56, 5.67, 6.78)]]
Expand Down Expand Up @@ -153,31 +118,6 @@ def test_func_accept_unitless_binary_ldexp(self, value):
with pytest.raises(AssertionError):
result = bm_fun(q1, q2)

@parameterized.product(
value=[(1.123, 2.567, 3.891), (1.23, 2.34, 3.45)]
)
def test_fun_accept_unitless_unary_can_return_quantity(self, value):
bm_fun_list = [getattr(bm, fun) for fun in fun_accept_unitless_unary_can_return_quantity]
jnp_fun_list = [getattr(jnp, fun) for fun in fun_accept_unitless_unary_can_return_quantity]

for bm_fun, jnp_fun in zip(bm_fun_list, jnp_fun_list):
print(f'fun: {bm_fun.__name__}')

result = bm_fun(jnp.array(value))
expected = jnp_fun(jnp.array(value))
assert_quantity(result, expected)

q = value * meter
result = bm_fun(q, unit_to_scale=meter)
expected = jnp_fun(jnp.array(value))
assert_quantity(result, expected, meter)

with pytest.raises(AssertionError):
result = bm_fun(q)

with pytest.raises(bu.UnitMismatchError):
result = bm_fun(q, unit_to_scale=bu.second)

@parameterized.product(
value=[(1, 2), (1, 2, 3)]
)
Expand Down
4 changes: 3 additions & 1 deletion brainunit/math/_fun_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3500,4 +3500,6 @@ def modf(
-------
The fractional and integral parts of the input, both with the same dimension.
"""
return _fun_keep_unit_unary(jnp.modf, x)
if isinstance(x, Quantity):
return jax.tree.map(lambda y: Quantity(y, unit=x.unit), jnp.modf(x.mantissa))
return jnp.modf(x)
52 changes: 51 additions & 1 deletion brainunit/math/_fun_keep_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std',
'nanmedian', 'nanmean', 'nanstd', 'diff', 'nan_to_num',
]
fun_accept_unitless_unary_can_return_quantity = [
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix',
]
fun_keep_unit_math_binary = [
'fmod', 'mod', 'remainder',
'maximum', 'minimum', 'fmax', 'fmin',
Expand All @@ -49,6 +53,9 @@
fun_keep_unit_math_unary_misc = [
'trace', 'lcm', 'gcd', 'copysign', 'rot90', 'intersect1d',
]
fun_accept_unitless_unary_2_results = [
'modf',
]


class TestFunKeepUnitSquenceInputs(parameterized.TestCase):
Expand Down Expand Up @@ -503,7 +510,7 @@ def test_extract(self):
a = array * bu.second
result_q = bu.math.extract(q > 1, a)
expected_q = jnp.extract(q > 1, jnp.array([1, 2, 3])) * bu.second
assert bu.math.allclose(result_q , expected_q)
assert bu.math.allclose(result_q, expected_q)

def test_take(self):
array = jnp.array([4, 3, 5, 7, 6, 8])
Expand Down Expand Up @@ -682,6 +689,49 @@ def test_fun_keep_unit_quantile(self, value, q, unit):
expected = jnp_fun(jnp.array(value), q)
assert_quantity(result, expected, unit=unit)

@parameterized.product(
value=[(1.123, 2.567, 3.891), (1.23, 2.34, 3.45)]
)
def test_fun_accept_unitless_binary_2_results(self, value):
bm_fun_list = [getattr(bm, fun) for fun in fun_accept_unitless_unary_2_results]
jnp_fun_list = [getattr(jnp, fun) for fun in fun_accept_unitless_unary_2_results]

for fun in fun_accept_unitless_unary_2_results:
bm_fun = getattr(bm, fun)
jnp_fun = getattr(jnp, fun)

print(f'fun: {bm_fun.__name__}')
result1, result2 = bm_fun(jnp.array(value))
expected1, expected2 = jnp_fun(jnp.array(value))
assert_quantity(result1, expected1)
assert_quantity(result2, expected2)

for unit in [meter, ms]:
q = value * unit
result1, result2 = bm_fun(q)
expected1, expected2 = jnp_fun(jnp.array(value))
assert_quantity(result1, expected1, unit)
assert_quantity(result2, expected2, unit)

@parameterized.product(
value=[(1.123, 2.567, 3.891), (1.23, 2.34, 3.45)]
)
def test_fun_accept_unitless_unary_can_return_quantity(self, value):
for fun in fun_accept_unitless_unary_can_return_quantity:
bm_fun = getattr(bm, fun)
jnp_fun = getattr(jnp, fun)

print(f'fun: {bm_fun.__name__}')
result = bm_fun(jnp.array(value))
expected = jnp_fun(jnp.array(value))
assert_quantity(result, expected)

for unit in [meter, ms]:
q = value * unit
result = bm_fun(q)
expected = jnp_fun(jnp.array(value))
assert_quantity(result, expected, unit)


class TestFunKeepUnitMathFunMisc(parameterized.TestCase):
def test_trace(self):
Expand Down

0 comments on commit e5633d8

Please sign in to comment.