Skip to content

Commit

Permalink
Improve the unit display (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 authored Aug 30, 2024
1 parent 2d00f9c commit 4071263
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 46 deletions.
159 changes: 113 additions & 46 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,35 +1143,75 @@ def _assert_same_base(u1, u2):
f"But we got {u1.base} != {u1.base}.")


def _create_name(dim: Dimension, base, scale) -> str:
# def _create_name(dim: Dimension, base, scale) -> Tuple[str, bool]:
# if dim == DIMENSIONLESS:
# return f"Unit({base}^{scale})", True
# else:
# name = f"{dim}"
# # if _is_tracer(scale) or scale == 0.:
# # pass
# # else:
# # name = f"{base}^{scale} * {dim}"
# return name, False


def _find_base_unit(dim: Dimension, base, scale) -> Tuple[Optional[str], bool]:
if dim == DIMENSIONLESS:
return f"Unit({base}^{scale})"
else:
if _is_tracer(scale) or scale == 0.:
name = f"{dim}"
else:
name = f"{base}^{scale} * {dim}"
return name
u_name = f"Unit({base}^{scale})"
return u_name, False
if isinstance(base, (int, float)):
key = (dim, 0, base)
if key in _standard_units:
u_name = _standard_units[key].name
if _is_tracer(scale):
return u_name, False

return None, False


def _find_standard_unit(dim: Dimension, base, scale) -> Tuple[Optional[str], bool, bool]:
"""
:param dim:
:param base:
:param scale:
:return: Name, is full name, is dimensionless.
"""
if dim == DIMENSIONLESS:
return None, False, True
if isinstance(base, (int, float)):
if isinstance(scale, (int, float)):
key = (dim, scale, base)
if key in _standard_units:
u_name = _standard_units[key].name
return u_name, True, False
key = (dim, 0, base)
if key in _standard_units:
u_name = _standard_units[key].name
return u_name, False, False
return None, False, False


def _find_a_name(dim: Dimension, base, scale) -> Optional[str]:
def _find_a_name(dim: Dimension, base, scale) -> Tuple[Optional[str], bool]:
if dim == DIMENSIONLESS:
name = f"Unit({base}^{scale})"
return name
u_name = f"Unit({base}^{scale})"
return u_name, False

if isinstance(base, (int, float)):
if isinstance(scale, (int, float)):
key = (dim, scale, base)
if key in _standard_units:
name = _standard_units[key].name
return name
u_name = _standard_units[key].name
return u_name, True
key = (dim, 0, base)
if key in _standard_units:
name = _standard_units[key].name
u_name = _standard_units[key].name
if _is_tracer(scale):
return name
return u_name, False
else:
return f"{base}^{scale} * {name}"
return f"{base}^{scale} * {u_name}", False
return None, True


_standard_units: Dict[Tuple, 'Unit'] = {}
Expand Down Expand Up @@ -1280,7 +1320,7 @@ class Unit:
"""

__module__ = "brainunit"
__slots__ = ["_dim", "_base", "_scale", "_dispname", "_name", "iscompound"]
__slots__ = ["_dim", "_base", "_scale", "_dispname", "_name", "iscompound", "is_fullname"]
__array_priority__ = 1000

def __init__(
Expand All @@ -1291,6 +1331,7 @@ def __init__(
name: str = None,
dispname: str = None,
iscompound: bool = False,
is_fullname: bool = True
):
# The base for this unit (as the base of the exponent), i.e.
# a base of 10 means 10^3, for a "k" prefix.
Expand All @@ -1308,7 +1349,13 @@ def __init__(

# The name of this unit
if name is None:
name = _create_name(dim, base, scale)
if dim == DIMENSIONLESS:
name = f"Unit({base}^{scale})"
is_fullname = False
else:
name = dim.__repr__()
dispname = dim.__str__()
is_fullname = False
self._name = name

# The display name of this unit
Expand All @@ -1317,6 +1364,9 @@ def __init__(
# Whether this unit is a combination of other units
self.iscompound = iscompound

# whether the name is the full name
self.is_fullname = is_fullname

@property
def base(self) -> jax.typing.ArrayLike:
return self._base
Expand Down Expand Up @@ -1346,9 +1396,7 @@ def value(self) -> jax.typing.ArrayLike:

@value.setter
def value(self, value):
raise NotImplementedError(
"Cannot set the value of a Unit object."
)
raise NotImplementedError("Cannot set the value of a Unit object.")

@property
def dim(self) -> Dimension:
Expand Down Expand Up @@ -1400,6 +1448,7 @@ def copy(self):
name=self.name,
dispname=self.dispname,
iscompound=self.iscompound,
is_fullname=self.is_fullname,
)

def __deepcopy__(self, memodict):
Expand All @@ -1410,6 +1459,7 @@ def __deepcopy__(self, memodict):
name=self.name,
dispname=self.dispname,
iscompound=self.iscompound,
is_fullname=self.is_fullname,
)

def has_same_scale(self, other: 'Unit') -> bool:
Expand Down Expand Up @@ -1498,6 +1548,7 @@ def create(
base=base,
name=name,
dispname=dispname,
is_fullname=True,
)
add_standard_unit(u)
return u
Expand Down Expand Up @@ -1529,31 +1580,46 @@ def create_scaled_unit(baseunit: 'Unit', scalefactor: str) -> 'Unit':
dispname=dispname,
scale=scale,
base=baseunit.base,
is_fullname=True,
)
add_standard_unit(u)
return u

def __repr__(self):
return self.name
def __repr__(self) -> str:
if self.is_fullname:
return self.name
if self.dim.is_dimensionless:
return f'Unit({self.base}^{self.scale})'
else:
# return f'{self.base}^{self.scale} * {self.dim}'
return f'{self.base}^{self.scale} * {self.name}'

def __str__(self):
return self.dispname
def __str__(self) -> str:
if self.is_fullname:
return self.dispname
if self.dim.is_dimensionless:
return f'Unit({self.base}^{self.scale})'
else:
# return f'{self.base}^{self.scale} * {self.dim}'
return f'{self.base}^{self.scale} * {self.dispname}'

def __mul__(self, other) -> 'Unit' | Quantity:
# self * other
if isinstance(other, Unit):
_assert_same_base(self, other)
scale = self.scale + other.scale
dim = self.dim * other.dim
name = _find_a_name(dim, self.base, scale)
name, is_fullname, dimless = _find_standard_unit(dim, self.base, scale)
dispname = name
if name is None:
iscompound = False
if name is None and not dimless and not is_fullname and self.is_fullname and other.is_fullname:
name = f"{self.name} * {other.name}"
dispname = f"{self.dispname} * {other.dispname}"
iscompound = True
else:
iscompound = False
return Unit(dim, scale=scale, base=self.base, name=name, dispname=dispname, iscompound=iscompound)
is_fullname = True
return Unit(dim, scale=scale, base=self.base, name=name,
dispname=dispname, iscompound=iscompound,
is_fullname=is_fullname)

elif isinstance(other, Quantity):
return Quantity(other.mantissa, unit=(self * other.unit))
Expand Down Expand Up @@ -1582,9 +1648,10 @@ def __div__(self, other) -> 'Unit':
scale = self.scale - other.scale
dim = self.dim / other.dim
_assert_same_base(self, other)
name = _find_a_name(dim, self.base, scale)
name, is_fullname, dimless = _find_standard_unit(dim, self.base, scale)
dispname = name
if name is None:
iscompound = False
if name is None and not dimless and not is_fullname and self.is_fullname and other.is_fullname:
if self.iscompound:
dispname = f"({self.dispname})"
name = f"({self.name})"
Expand All @@ -1600,9 +1667,9 @@ def __div__(self, other) -> 'Unit':
dispname += other.dispname
name += other.name
iscompound = True
else:
iscompound = False
return Unit(dim, base=self.base, scale=scale, name=name, dispname=dispname, iscompound=iscompound)
is_fullname = True
return Unit(dim, base=self.base, scale=scale, name=name, dispname=dispname,
iscompound=iscompound, is_fullname=is_fullname)
else:
raise TypeError(f"unit {self} cannot divide by a non-unit {other}")

Expand All @@ -1611,19 +1678,19 @@ def __rdiv__(self, other) -> 'Unit' | Quantity:
if is_scalar_type(other) and other == 1:
dim = self.dim ** -1
scale = -self.scale
name = _find_a_name(dim, self.base, scale)
name, is_fullname, dimless = _find_standard_unit(dim, self.base, scale)
dispname = name
if name is None:
iscompound = False
if name is None and not dimless and not is_fullname and self.is_fullname:
if self.iscompound:
dispname = f"({self.dispname})"
name = f"({self.name})"
else:
dispname = self.dispname
name = self.name
iscompound = True
else:
iscompound = False
return Unit(dim, base=self.base, scale=scale, name=name, dispname=dispname, iscompound=iscompound)
return Unit(dim, base=self.base, scale=scale, name=name, dispname=dispname,
iscompound=iscompound, is_fullname=is_fullname)

elif isinstance(other, Unit):
return other.__div__(self)
Expand Down Expand Up @@ -1662,9 +1729,10 @@ def __pow__(self, other):
if is_scalar_type(other):
dim = self.dim ** other
scale = self.scale * other
name = _find_a_name(dim, self.base, scale)
name, is_fullname, dimless = _find_standard_unit(dim, self.base, scale)
dispname = name
if name is None:
iscompound = False
if name is None and not dimless and not is_fullname and self.is_fullname:
if self.iscompound:
dispname = f"({self.dispname})"
name = f"({self.name})"
Expand All @@ -1674,9 +1742,8 @@ def __pow__(self, other):
dispname += f"^{str(other)}"
name += f" ** {repr(other)}"
iscompound = True
else:
iscompound = False
return Unit(dim, base=self.base, scale=scale, name=name, dispname=dispname, iscompound=iscompound)
return Unit(dim, base=self.base, scale=scale, name=name, dispname=dispname,
iscompound=iscompound, is_fullname=is_fullname)
else:
raise TypeError(f"unit cannot perform an exponentiation (unit ** other) with a non-scalar, "
f"since one unit cannot contain multiple units. \n"
Expand Down Expand Up @@ -1738,7 +1805,7 @@ def __neq__(self, other) -> bool:
return not self.__eq__(other)

def __reduce__(self):
return _to_unit, (self.dim, self.scale, self.base, self.name, self.dispname, self.iscompound)
return _to_unit, (self.dim, self.scale, self.base, self.name, self.dispname, self.iscompound, self.is_fullname)


def _to_unit(*args):
Expand Down
4 changes: 4 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ def test_display(self):
assert_equal(display_in_unit(3.0 * bu.kmeter / 130.51 * bu.meter * bu.cm ** -1), '0.02298675 * 10.0^5 * meter')
assert_equal(display_in_unit(1. * bu.joule / bu.kelvin), '1. * joule / kelvin')

assert_equal(str(1. * bu.metre / ((3.0 * bu.ms) / (1. * bu.second))), '0.33333334 * kmeter')
assert_equal(str(1. * bu.metre / ((3.0 * bu.ms) / 1. * bu.second)), '0.33333334 * 10.0^3 * metre * second ** -2')
assert_equal(str((3.0 * bu.ms) / 1. * bu.second), '3. * 10.0^-3 * second2')

# def test_display2(self):
#
# @jax.jit
Expand Down

0 comments on commit 4071263

Please sign in to comment.