Skip to content

Commit

Permalink
standardize functions and objects to export in base.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 8, 2024
1 parent eccf6d9 commit 04c8f26
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 8 deletions.
25 changes: 22 additions & 3 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
# helpers
'DimensionMismatchError',
'UnitMismatchError',
'is_dimensionless',
'is_unitless',
'display_in_unit',
'get_dim',
'get_unit',

Expand Down Expand Up @@ -698,7 +698,6 @@ def has_same_unit(obj1, obj2) -> bool:
return unit1 == unit2



def fail_for_dimension_mismatch(
obj1, obj2=None, error_message=None, **error_arrays
):
Expand Down Expand Up @@ -974,7 +973,7 @@ def array_with_unit(
return Quantity(floatval, dim=get_or_create_dimension(dim._dims), dtype=dtype)


def is_unitless(obj: Union['Quantity', 'Unit', jax.typing.ArrayLike]) -> bool:
def is_dimensionless(obj: Union['Quantity', 'Unit', 'Dimension', jax.typing.ArrayLike]) -> bool:
"""
Test if a value is dimensionless or not.
Expand All @@ -988,6 +987,26 @@ def is_unitless(obj: Union['Quantity', 'Unit', jax.typing.ArrayLike]) -> bool:
dimensionless : `bool`
``True`` if `obj` is dimensionless.
"""
if isinstance(obj, Dimension):
return obj.is_dimensionless
return _to_quantity(obj).dim.is_dimensionless


def is_unitless(obj: Union['Quantity', 'Unit', jax.typing.ArrayLike]) -> bool:
"""
Test if a value is unitless or not.
Parameters
----------
obj : `object`
The object to check.
Returns
-------
unitless : `bool`
``True`` if `obj` is unitless.
"""
assert not isinstance(obj, Dimension), f"Dimension objects are not unitless or not, but got {obj}"

return _to_quantity(obj).is_unitless

Expand Down
3 changes: 2 additions & 1 deletion brainunit/math/_fun_accept_unitless_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import brainunit as bu
import brainunit.math as bm
from brainunit import meter, DimensionMismatchError, assert_quantity
from brainunit import meter, DimensionMismatchError
from brainunit._base import assert_quantity

fun_accept_unitless_unary = [
'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2',
Expand Down
3 changes: 2 additions & 1 deletion brainunit/math/_fun_array_creation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import brainunit as bu
import brainunit.math as bm
from brainunit import second, meter, DimensionMismatchError, assert_quantity
from brainunit import second, meter, DimensionMismatchError
from brainunit._base import assert_quantity

fun_array_creation_given_shape = [
'empty', 'ones', 'zeros',
Expand Down
3 changes: 2 additions & 1 deletion brainunit/math/_fun_change_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import brainunit as bu
import brainunit.math as bm
from brainunit import meter, second, assert_quantity, volt, get_dim
from brainunit import meter, second, volt, get_dim
from brainunit._base import assert_quantity

fun_change_unit_unary = [
'reciprocal', 'var', 'nanvar', 'cbrt', 'square', 'sqrt',
Expand Down
3 changes: 2 additions & 1 deletion brainunit/math/_fun_keep_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import brainunit as bu
import brainunit.math as bm
from brainunit import assert_quantity, second, meter, ms
from brainunit import second, meter, ms
from brainunit._base import assert_quantity

fun_keep_unit_squence_inputs = [
'row_stack', 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', 'block', 'append',
Expand Down
3 changes: 2 additions & 1 deletion brainunit/math/_fun_remove_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import brainunit as bu
import brainunit.math as bm
from brainunit import assert_quantity, DimensionMismatchError
from brainunit import DimensionMismatchError
from brainunit._base import assert_quantity

fun_remove_unit_unary = [
'signbit', 'sign',
Expand Down

0 comments on commit 04c8f26

Please sign in to comment.