diff --git a/brainunit/_base.py b/brainunit/_base.py index f1f6247..f64aace 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -37,8 +37,8 @@ # helpers 'DimensionMismatchError', 'UnitMismatchError', + 'is_dimensionless', 'is_unitless', - 'display_in_unit', 'get_dim', 'get_unit', @@ -698,7 +698,6 @@ def has_same_unit(obj1, obj2) -> bool: return unit1 == unit2 - def fail_for_dimension_mismatch( obj1, obj2=None, error_message=None, **error_arrays ): @@ -974,7 +973,7 @@ def array_with_unit( return Quantity(floatval, dim=get_or_create_dimension(dim._dims), dtype=dtype) -def is_unitless(obj: Union['Quantity', 'Unit', jax.typing.ArrayLike]) -> bool: +def is_dimensionless(obj: Union['Quantity', 'Unit', 'Dimension', jax.typing.ArrayLike]) -> bool: """ Test if a value is dimensionless or not. @@ -988,6 +987,26 @@ def is_unitless(obj: Union['Quantity', 'Unit', jax.typing.ArrayLike]) -> bool: dimensionless : `bool` ``True`` if `obj` is dimensionless. """ + if isinstance(obj, Dimension): + return obj.is_dimensionless + return _to_quantity(obj).dim.is_dimensionless + + +def is_unitless(obj: Union['Quantity', 'Unit', jax.typing.ArrayLike]) -> bool: + """ + Test if a value is unitless or not. + + Parameters + ---------- + obj : `object` + The object to check. + + Returns + ------- + unitless : `bool` + ``True`` if `obj` is unitless. + """ + assert not isinstance(obj, Dimension), f"Dimension objects are not unitless or not, but got {obj}" return _to_quantity(obj).is_unitless diff --git a/brainunit/math/_fun_accept_unitless_test.py b/brainunit/math/_fun_accept_unitless_test.py index f1f2848..bd12871 100644 --- a/brainunit/math/_fun_accept_unitless_test.py +++ b/brainunit/math/_fun_accept_unitless_test.py @@ -4,7 +4,8 @@ import brainunit as bu import brainunit.math as bm -from brainunit import meter, DimensionMismatchError, assert_quantity +from brainunit import meter, DimensionMismatchError +from brainunit._base import assert_quantity fun_accept_unitless_unary = [ 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', diff --git a/brainunit/math/_fun_array_creation_test.py b/brainunit/math/_fun_array_creation_test.py index 8d595b3..34894cd 100644 --- a/brainunit/math/_fun_array_creation_test.py +++ b/brainunit/math/_fun_array_creation_test.py @@ -4,7 +4,8 @@ import brainunit as bu import brainunit.math as bm -from brainunit import second, meter, DimensionMismatchError, assert_quantity +from brainunit import second, meter, DimensionMismatchError +from brainunit._base import assert_quantity fun_array_creation_given_shape = [ 'empty', 'ones', 'zeros', diff --git a/brainunit/math/_fun_change_unit_test.py b/brainunit/math/_fun_change_unit_test.py index 9dd58c4..29dc243 100644 --- a/brainunit/math/_fun_change_unit_test.py +++ b/brainunit/math/_fun_change_unit_test.py @@ -4,7 +4,8 @@ import brainunit as bu import brainunit.math as bm -from brainunit import meter, second, assert_quantity, volt, get_dim +from brainunit import meter, second, volt, get_dim +from brainunit._base import assert_quantity fun_change_unit_unary = [ 'reciprocal', 'var', 'nanvar', 'cbrt', 'square', 'sqrt', diff --git a/brainunit/math/_fun_keep_unit_test.py b/brainunit/math/_fun_keep_unit_test.py index 5e37c67..10f872f 100644 --- a/brainunit/math/_fun_keep_unit_test.py +++ b/brainunit/math/_fun_keep_unit_test.py @@ -4,7 +4,8 @@ import brainunit as bu import brainunit.math as bm -from brainunit import assert_quantity, second, meter, ms +from brainunit import second, meter, ms +from brainunit._base import assert_quantity fun_keep_unit_squence_inputs = [ 'row_stack', 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', 'block', 'append', diff --git a/brainunit/math/_fun_remove_unit_test.py b/brainunit/math/_fun_remove_unit_test.py index f791f14..76fecc0 100644 --- a/brainunit/math/_fun_remove_unit_test.py +++ b/brainunit/math/_fun_remove_unit_test.py @@ -7,7 +7,8 @@ import brainunit as bu import brainunit.math as bm -from brainunit import assert_quantity, DimensionMismatchError +from brainunit import DimensionMismatchError +from brainunit._base import assert_quantity fun_remove_unit_unary = [ 'signbit', 'sign',