Skip to content

Commit

Permalink
upgrade brainunit.math and fix all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 9, 2024
1 parent 04c8f26 commit 09bfd9f
Show file tree
Hide file tree
Showing 14 changed files with 342 additions and 402 deletions.
9 changes: 5 additions & 4 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,7 @@ def __imod__(self, other):

def __eq__(self, other) -> bool:
if isinstance(other, Unit):
return (other.dim == self.dim) and (other.scale == self.scale)
return (other.dim == self.dim) and (other.scale == self.scale) and (other.base == self.base)
else:
return False

Expand Down Expand Up @@ -1844,9 +1844,10 @@ def to_decimal_num(self, unit: Unit) -> jax.typing.ArrayLike:
The decimal number of the quantity based on the given unit.
"""
assert isinstance(unit, Unit), f"Expected a Unit, but got {unit}."
assert unit.has_same_dim(self.unit), (f"Cannot convert to the decimal number using a unit with different "
f"dimensions. The quantity has the unit {self.unit}, but the given "
f"unit is {unit}")
if not unit.has_same_dim(self.unit):
raise UnitMismatchError(f"Cannot convert to the decimal number using a unit with different "
f"dimensions. The quantity has the unit {self.unit}, but the given "
f"unit is {unit}")
if not unit.has_same_scale(self.unit):
return self._mantissa * (self.unit.value / unit.value)
else:
Expand Down
38 changes: 20 additions & 18 deletions brainunit/math/_einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ._fun_array_creation import asarray, zeros_like
from ._fun_keep_unit import reshape, transpose, expand_dims, tile, where, squeeze
from ._misc import shape
from .._base import Quantity, DIMENSIONLESS
from .._base import Quantity, UNITLESS
from .._misc import set_module_as

T = TypeVar('T')
Expand Down Expand Up @@ -793,12 +793,12 @@ def _partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T
def _sum(x, axes, preferred_element_type):
if jax.dtypes.result_type(x, preferred_element_type) != x.dtype:
x = x.astype(preferred_element_type)
dim = None
unit = UNITLESS
if isinstance(x, Quantity):
dim = x.dim
x = x.value
unit = x.unit
x = x.mantissa
x = jax.lax.reduce(x, np.array(0, x.dtype), jax.lax.add if x.dtype != jnp.bool_ else jax.lax.bitwise_or, axes)
return Quantity(x, dim=dim) if dim is not None else x
return x if unit.is_unitless else Quantity(x, unit=unit)


def _sum_uniques(operand, names, uniques, preferred_element_type):
Expand All @@ -816,22 +816,24 @@ def _dot_general(
precision: jax.lax.PrecisionLike = None,
preferred_element_type: jax.typing.DTypeLike | None = None
) -> jax.Array | Quantity:
dim = DIMENSIONLESS
unit = UNITLESS
if isinstance(lhs, Quantity):
dim = dim * lhs.dim
lhs = lhs.value
unit = unit * lhs.unit
lhs = lhs.mantissa
if isinstance(rhs, Quantity):
dim = dim * rhs.dim
rhs = rhs.value
r = jax.lax.dot_general(lhs,
rhs,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type)
if dim == DIMENSIONLESS:
unit = unit * rhs.unit
rhs = rhs.mantissa
r = jax.lax.dot_general(
lhs,
rhs,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type
)
if unit.is_unitless:
return r
else:
return Quantity(r, dim=dim)
return Quantity(r, unit=unit)


def _delta(dtype: jax.typing.DTypeLike, shape_, axes: Sequence[int]) -> jax.Array:
Expand Down Expand Up @@ -979,7 +981,7 @@ def _einsum(
ret = operands[0].mantissa if isinstance(operands[0], Quantity) else operands[0]
ret = jax.lax.convert_element_type(ret, preferred_element_type)
if isinstance(operands[0], Quantity):
return Quantity(ret, dim=operands[0].dim)
return Quantity(ret, unit=operands[0].unit)
return ret


Expand Down
21 changes: 6 additions & 15 deletions brainunit/math/_einops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import brainunit as bu
from brainunit.math._einops import einrearrange, einreduce, einrepeat, _enumerate_directions
from brainunit.math._einops_parsing import EinopsError
from brainunit._base import assert_quantity

REDUCTIONS = ("min", "max", "sum", "mean", "prod")

Expand Down Expand Up @@ -347,16 +348,6 @@ def test_reduction_imperatives_booleans():
assert np.array_equal(expected_result_all, bu.math.as_numpy(res_all))


def assert_quantity(q, values, unit=None):
values = jnp.asarray(values)
if unit is None:
assert jnp.allclose(q, values), f"Values do not match: {q.mantissa} != {values}"
return
else:
assert bu.have_same_dim(q.dim, unit), f"Dimension mismatch: ({q}) ({unit})"
if not jnp.allclose(q.mantissa, values):
raise AssertionError(f"Values do not match: {q.mantissa} != {values}")


def test_einsum():
a = jnp.array([1, 2, 3])
Expand Down Expand Up @@ -411,30 +402,30 @@ def test_einsum():
q2 = np.random.rand(2, 3) * bu.second
q3 = np.random.rand(5) * bu.kilogram
result = bu.math.einsum('ab,ab,c->', q1, q2, q3)
expected = jnp.einsum('ab,ab,c->', q1.value, q2.value, q3.value)
expected = jnp.einsum('ab,ab,c->', q1.mantissa, q2.mantissa, q3.mantissa)
assert_quantity(result, expected, bu.meter * bu.second * bu.kilogram)

# Case 'ab,cd,ef->abcdef'
q1 = np.random.rand(2, 3) * bu.meter
q2 = np.random.rand(4, 5) * bu.second
q3 = np.random.rand(6, 7) * bu.kilogram
result = bu.math.einsum('ab,cd,ef->abcdef', q1, q2, q3)
expected = jnp.einsum('ab,cd,ef->abcdef', q1.value, q2.value, q3.value)
expected = jnp.einsum('ab,cd,ef->abcdef', q1.mantissa, q2.mantissa, q3.mantissa)
assert_quantity(result, expected, bu.meter * bu.second * bu.kilogram)

# Case 'eb,cb,fb->cef'
q1 = np.random.rand(8, 2) * bu.meter
q2 = np.random.rand(6, 2) * bu.second
q3 = np.random.rand(5, 2) * bu.kilogram
result = bu.math.einsum('eb,cb,fb->cef', q1, q2, q3)
expected = jnp.einsum('eb,cb,fb->cef', q1.value, q2.value, q3.value)
expected = jnp.einsum('eb,cb,fb->cef', q1.mantissa, q2.mantissa, q3.mantissa)
assert_quantity(result, expected, bu.meter * bu.second * bu.kilogram)

# Case 'ab,ab'
q1 = np.random.rand(2, 3) * bu.meter
q2 = np.random.rand(2, 3) * bu.second
result = bu.math.einsum('ab,ab', q1, q2)
expected = jnp.einsum('ab,ab', q1.value, q2.value)
expected = jnp.einsum('ab,ab', q1.mantissa, q2.mantissa)
assert_quantity(result, expected, bu.meter * bu.second)

# Case 'aab,fa,df,ecc->bde'
Expand All @@ -443,7 +434,7 @@ def test_einsum():
q3 = np.random.rand(4, 5) * bu.kilogram
q4 = np.random.rand(2, 3, 3) * bu.ampere
result = bu.math.einsum('aab,fa,df,ecc->bde', q1, q2, q3, q4)
expected = jnp.einsum('aab,fa,df,ecc->bde', q1.value, q2.value, q3.value, q4.value)
expected = jnp.einsum('aab,fa,df,ecc->bde', q1.mantissa, q2.mantissa, q3.mantissa, q4.mantissa)
print()
print(result)
print(expected)
Expand Down
53 changes: 14 additions & 39 deletions brainunit/math/_fun_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
# limitations under the License.
# ==============================================================================
from __future__ import annotations

from typing import (Union, Optional, Tuple, Any, Callable)
from typing import Union, Optional, Tuple, Any, Callable

import jax
import jax.numpy as jnp

from .._base import (Quantity, Unit, fail_for_dimension_mismatch, fail_for_unit_mismatch)
from .._base import Quantity, Unit
from .._misc import set_module_as

__all__ = [
Expand Down Expand Up @@ -62,13 +61,7 @@ def _fun_accept_unitless_unary(
x = x.mantissa
return func(x, *args, **kwargs)
else:
fail_for_unit_mismatch(
x,
unit_to_scale,
error_message="Unit mismatch: {value} != {unit_to_scale}",
value=x,
unit_to_scale=unit_to_scale
)
assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}'
return func(x.to_decimal_num(unit_to_scale), *args, **kwargs)
else:
assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
Expand Down Expand Up @@ -765,17 +758,11 @@ def _fun_accept_unitless_return_keep_unit(
if unit_to_scale is None:
assert x.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
x = x.value
x = x.mantissa
return func(x, *args, **kwargs)
else:
fail_for_dimension_mismatch(
x,
unit_to_scale,
error_message="Unit mismatch: {value} != {unit_to_scale}",
value=x,
unit_to_scale=unit_to_scale
)
r = func(x.to_value(unit_to_scale), *args, **kwargs)
assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}'
r = func(x.to_decimal_num(unit_to_scale), *args, **kwargs)
return jax.tree.map(lambda a: a * unit_to_scale, r)
else:
# assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
Expand Down Expand Up @@ -1002,30 +989,18 @@ def _fun_accept_unitless_binary(
if unit_to_scale is None:
assert x.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
x = x.value
x = x.mantissa
else:
fail_for_dimension_mismatch(
x,
unit_to_scale,
error_message="Unit mismatch: {value} != {unit_to_scale}",
value=x,
unit_to_scale=unit_to_scale
)
x = x.to_value(unit_to_scale)
assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}'
x = x.to_decimal_num(unit_to_scale)
if isinstance(y, Quantity):
if unit_to_scale is None:
assert y.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
y = y.value
y = y.mantissa
else:
fail_for_dimension_mismatch(
y,
unit_to_scale,
error_message="Unit mismatch: {value} != {unit_to_scale}",
value=y,
unit_to_scale=unit_to_scale
)
y = y.to_value(unit_to_scale)
assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}'
y = y.to_decimal_num(unit_to_scale)
return func(x, y, *args, **kwargs)


Expand Down Expand Up @@ -1379,10 +1354,10 @@ def invert(
def _fun_unitless_binary(func, x, y, *args, **kwargs):
if isinstance(x, Quantity):
assert x.is_unitless, f'Expected unitless array, got {x}'
x = x.value
x = x.mantissa
if isinstance(y, Quantity):
assert y.is_unitless, f'Expected unitless array, got {y}'
y = y.value
y = y.mantissa
return func(x, y, *args, **kwargs)


Expand Down
14 changes: 7 additions & 7 deletions brainunit/math/_fun_accept_unitless_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import brainunit as bu
import brainunit.math as bm
from brainunit import meter, DimensionMismatchError
from brainunit import meter
from brainunit._base import assert_quantity

fun_accept_unitless_unary = [
Expand Down Expand Up @@ -61,13 +61,13 @@ def test_fun_accept_unitless_unary_1(self, value):
(bu.nA, bu.amp)]:
q = value * unit
result = fun(q, unit_to_scale=unit2scale)
expected = jnp_fun(q / unit2scale)
expected = jnp_fun(q.to_decimal_num(unit2scale))
assert_quantity(result, expected)

with pytest.raises(AssertionError):
result = fun(q)

with pytest.raises(DimensionMismatchError):
with pytest.raises(bu.UnitMismatchError):
result = fun(q, unit_to_scale=bu.nS)

@parameterized.product(
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_fun_accept_unitless_binary_2_results(self, value):
with pytest.raises(AssertionError):
result1, result2 = bm_fun(q)

with pytest.raises(DimensionMismatchError):
with pytest.raises(bu.UnitMismatchError):
result1, result2 = bm_fun(q, unit_to_scale=bu.second)

@parameterized.product(
Expand All @@ -119,13 +119,13 @@ def test_func_accept_unitless_binary(self, value):
q1 = value1 * meter
q2 = value2 * meter
result = bm_fun(q1, q2, unit_to_scale=bu.dametre)
expected = jnp_fun(jnp.array(value1) / bu.dametre.mantissa, jnp.array(value2) / bu.dametre.mantissa)
expected = jnp_fun(q1.to_decimal_num(bu.dametre), q2.to_decimal_num(bu.dametre))
assert_quantity(result, expected)

with pytest.raises(AssertionError):
result = bm_fun(q1, q2)

with pytest.raises(DimensionMismatchError):
with pytest.raises(bu.UnitMismatchError):
result = bm_fun(q1, q2, unit_to_scale=bu.second)

@parameterized.product(
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_fun_accept_unitless_unary_can_return_quantity(self, value):
with pytest.raises(AssertionError):
result = bm_fun(q)

with pytest.raises(DimensionMismatchError):
with pytest.raises(bu.UnitMismatchError):
result = bm_fun(q, unit_to_scale=bu.second)

@parameterized.product(
Expand Down
Loading

0 comments on commit 09bfd9f

Please sign in to comment.