Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jul 13, 2024
1 parent efe0164 commit f95d5f9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
15 changes: 8 additions & 7 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def get_unit(obj) -> Unit:
return obj.unit
except AttributeError:
if isinstance(obj, (numbers.Number, jax.Array, np.number, np.ndarray)):
return Unit(1, name='1', dispname='1')
return Unit(1)
try:
return Quantity(obj).unit
except TypeError:
Expand Down Expand Up @@ -1064,7 +1064,7 @@ def __init__(
# array value
if isinstance(value, Quantity):
self._dim = value.dim
self._unit = Unit(1, name='1', dispname='1') if unit is None else unit
self._unit = Unit(1, dim=value.dim, name=repr(value.dim), dispname=str(value.dim)) if unit is None else unit
self._value = jnp.array(value.value, dtype=dtype)
return

Expand All @@ -1088,7 +1088,7 @@ def __init__(

# unit
if unit is None:
self._unit = Unit(1, name='1', dispname='1')
self._unit = Unit(1, dim=dim, name=repr(dim), dispname=str(dim))
else:
self._unit = unit

Expand Down Expand Up @@ -1263,7 +1263,7 @@ def get_best_unit(self, *regs) -> 'Quantity':
The best unit for this `Array`.
"""
if self.is_unitless:
return Unit(1, name='1', dispname='1')
return Unit(1)
if len(regs):
for r in regs:
try:
Expand Down Expand Up @@ -1622,9 +1622,10 @@ def _binary_operation(
other_unit = get_unit(other)

new_dim = unit_operation(self.dim, other_dim)
new_unit = unit_operation(self.unit, other_unit)
new_unit = unit_operation(self.unit, other_unit) if self.unit is not None else other_unit
result = value_operation(self.value, other.value)
r = Quantity(result, dim=new_dim, unit=new_unit)
r = Quantity(result, dim=new_dim, unit=Unit(1, dim=new_unit.dim, name=new_unit.name, dispname=new_unit.dispname))

if inplace:
self.update_value(r.value)
return self
Expand Down Expand Up @@ -2890,7 +2891,7 @@ def __rdiv__(self, other):
u = Unit(
self.value,
dim=self.dim ** -1,
name=f"1 / {name}",
name=f"Unit(1) / {name}",
dispname=f"1 / {dispname}",
scale=-self.scale,
iscompound=True,
Expand Down
8 changes: 8 additions & 0 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,8 @@ def test_set_default_magnitude_1():
assert q1.to_value() == 3e-3
assert q2.to_value() == 3e-6

bu.set_default_magnitude(0)


def test_set_default_magnitude_2():
bu.set_default_magnitude(-3)
Expand All @@ -1523,6 +1525,8 @@ def test_set_default_magnitude_2():
assert q1.to_value() == 3e-6
assert q2.to_value() == 3e-6

bu.set_default_magnitude(0)


def test_set_default_magnitude_3():
bu.set_default_magnitude({'s': 'm'})
Expand All @@ -1547,6 +1551,8 @@ def test_set_default_magnitude_3():
assert q1.to_value() == 3e-3
assert q2.to_value() == 3e-3

bu.set_default_magnitude(0)


def test_set_default_magnitude_4():
# volt: m=2, kg=1, s=-3, A=-1
Expand All @@ -1572,6 +1578,8 @@ def test_set_default_magnitude_4():
assert q1.to_value() == 3e3
assert q2.to_value() == 3

bu.set_default_magnitude(0)

# def test_set_default_magnitude_5():
# # volt: m=2, kg=1, s=-3, A=-1 --> (if set to milivolt) _default_magnitude = {'m': -1.5, 'kg': -3, 's': -1, 'A': -3} (WRONG!)
# bu.set_default_magnitude(-3, unit=bu.volt)
Expand Down

0 comments on commit f95d5f9

Please sign in to comment.