Skip to content

Commit

Permalink
upgrade the display of dimensionless
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 23, 2024
1 parent 5759073 commit 25ec0d5
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
42 changes: 32 additions & 10 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ def assert_quantity(
}


def _is_tracer(x):
return isinstance(x, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer))


class Dimension:
"""
Stores the indices of the 7 basic SI unit dimension (length, mass, etc.).
Expand Down Expand Up @@ -409,7 +413,7 @@ def __ipow__(self, value):
def __eq__(self, value: 'Dimension'):
try:
return np.allclose(self._dims, value._dims)
except AttributeError:
except (AttributeError, jax.errors.TracerArrayConversionError):
# Only compare equal to another Dimensions object
return False

Expand Down Expand Up @@ -1144,26 +1148,44 @@ def _create_name(dim: Dimension, base, scale) -> str:
if dim == DIMENSIONLESS:
name = f"Unit({base ** scale})"
else:
if scale == 0.:
if _is_tracer(scale) or scale == 0.:
name = f"{dim}"
else:
name = f"{base}^{scale} * {dim}"
return name


def _find_name(dim: Dimension, base, scale) -> Tuple[str, str]:
if isinstance(base, (int, float)) and isinstance(scale, (int, float)):
key = (dim, scale, base)
if key in _standard_units:
name = _standard_units[key].name
return name, name
if isinstance(base, (int, float)):
if isinstance(scale, (int, float)):
if dim == DIMENSIONLESS:
name = f"Unit({base ** scale})"
return name, name

key = (dim, scale, base)
if key in _standard_units:
name = _standard_units[key].name
return name, name
else:
key = (dim, 0, base)
if key in _standard_units:
name = _standard_units[key].name
return name, name
name = _create_name(dim, base, scale)
return name, name


_standard_units: Dict[Tuple, 'Unit'] = {}


def add_standard_unit(u: 'Unit'):
if isinstance(u.base, (int, float)) and isinstance(u.scale, (int, float)):
key = (u.dim, u.scale, u.base)
# if key in _standard_units:
# raise ValueError(f"Unit {u} already exists: {_standard_units[key]}")
_standard_units[key] = u


class Unit:
r"""
A physical unit.
Expand Down Expand Up @@ -1478,7 +1500,7 @@ def create(
name=name,
dispname=dispname,
)
_standard_units[(dim, scale, base)] = u
add_standard_unit(u)
return u

@staticmethod
Expand Down Expand Up @@ -1509,7 +1531,7 @@ def create_scaled_unit(baseunit: 'Unit', scalefactor: str) -> 'Unit':
scale=scale,
base=baseunit.base,
)
_standard_units[(baseunit.dim, scale, baseunit.base)] = u
add_standard_unit(u)
return u

def __repr__(self):
Expand Down Expand Up @@ -2102,7 +2124,7 @@ def repr_in_unit(
'25.123 mV'
"""
value = jnp.asarray(self._mantissa)
if isinstance(value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
if _is_tracer(value):
# in the JIT mode
s = str(value)
else:
Expand Down
13 changes: 13 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,19 @@ def test_display(self):
display_in_unit(10 * nS, ohm)
assert_equal(display_in_unit(10.0, Unit(scale=1)), "1. * Unit(10.0)")


def test_display2(self):

@jax.jit
def f(s):
a = bu.ms ** s
print(a)
return a

with self.assertRaises(jax.errors.TracerBoolConversionError):
f(2)


def test_unary_operations(self):
q = Quantity(5, unit=mV)
assert_quantity(-q, -5, mV)
Expand Down
1 change: 1 addition & 0 deletions brainunit/_unit_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8146,3 +8146,4 @@
]

del base_units, scaled_units, powered_units, additional_units

0 comments on commit 25ec0d5

Please sign in to comment.