diff --git a/brainunit/_base.py b/brainunit/_base.py index 40b34e2..71cf3f8 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -28,6 +28,8 @@ from jax.interpreters.partial_eval import DynamicJaxprTracer from jax.tree_util import register_pytree_node_class +from ._misc import set_module_as + __all__ = [ # three base objects 'Dimension', @@ -116,6 +118,7 @@ def get_dim_for_display(d): return str(get_dim(d)) +@set_module_as('brainunit') def assert_quantity( q: Union['Quantity', jax.typing.ArrayLike], mantissa: jax.typing.ArrayLike, @@ -424,6 +427,7 @@ def __deepcopy__(self, memodict): return self +@set_module_as('brainunit') def get_or_create_dimension(*args, **kwds) -> Dimension: """ Create a new Dimension object or get a reference to an existing one. @@ -600,6 +604,7 @@ def __str__(self): return s +@set_module_as('brainunit') def get_dim(obj) -> Dimension: """ Return the dimension of any object that has them. @@ -630,6 +635,7 @@ def get_dim(obj) -> Dimension: raise TypeError(f"Object of type {type(obj)} does not have a dim") +@set_module_as('brainunit') def get_unit(obj) -> Unit: """ Return the unit of any object that has them. @@ -655,6 +661,7 @@ def get_unit(obj) -> Unit: raise TypeError(f"Object of type {type(obj)} does not have a unit") +@set_module_as('brainunit') def have_same_dim(obj1, obj2) -> bool: """Test if two values have the same dimensions. @@ -679,6 +686,7 @@ def have_same_dim(obj1, obj2) -> bool: return (dim1 is dim2) or (dim1 == dim2) +@set_module_as('brainunit') def has_same_unit(obj1, obj2) -> bool: """ Check whether two objects have the same unit. @@ -698,6 +706,7 @@ def has_same_unit(obj1, obj2) -> bool: return unit1 == unit2 +@set_module_as('brainunit') def fail_for_dimension_mismatch( obj1, obj2=None, error_message=None, **error_arrays ): @@ -779,6 +788,7 @@ def fail_for_dimension_mismatch( return dim1, dim2 +@set_module_as('brainunit') def fail_for_unit_mismatch( obj1, obj2=None, error_message=None, **error_arrays ) -> Tuple['Unit', 'Unit']: @@ -844,6 +854,7 @@ def fail_for_unit_mismatch( raise UnitMismatchError(error_message, unit1, unit2) +@set_module_as('brainunit') def display_in_unit( x: jax.typing.ArrayLike | 'Quantity', u: 'Unit', @@ -894,6 +905,7 @@ def display_in_unit( return x.in_unit(u).repr_in_unit(precision=precision) +@set_module_as('brainunit') def unit_scale_align_to_first(*args) -> List['Quantity']: """ Align the unit units of all arguments to the first one. @@ -937,6 +949,7 @@ def unit_scale_align_to_first(*args) -> List['Quantity']: return args +@set_module_as('brainunit') def array_with_unit( floatval, dim: Dimension, @@ -973,6 +986,7 @@ def array_with_unit( return Quantity(floatval, dim=get_or_create_dimension(dim._dims), dtype=dtype) +@set_module_as('brainunit') def is_dimensionless(obj: Union['Quantity', 'Unit', 'Dimension', jax.typing.ArrayLike]) -> bool: """ Test if a value is dimensionless or not. @@ -992,6 +1006,7 @@ def is_dimensionless(obj: Union['Quantity', 'Unit', 'Dimension', jax.typing.Arra return _to_quantity(obj).dim.is_dimensionless +@set_module_as('brainunit') def is_unitless(obj: Union['Quantity', 'Unit', jax.typing.ArrayLike]) -> bool: """ Test if a value is unitless or not. @@ -1007,10 +1022,10 @@ def is_unitless(obj: Union['Quantity', 'Unit', jax.typing.ArrayLike]) -> bool: ``True`` if `obj` is unitless. """ assert not isinstance(obj, Dimension), f"Dimension objects are not unitless or not, but got {obj}" - return _to_quantity(obj).is_unitless +@set_module_as('brainunit') def is_scalar_type(obj) -> bool: """ Tells you if the object is a 1d number type. @@ -3116,6 +3131,7 @@ def top_replace(s): return list(top_replace(seq)) +@set_module_as('brainunit') def check_dims(**au): """ Decorator to check dimensions of arguments passed to a function @@ -3359,6 +3375,7 @@ def new_f(*args, **kwds): return do_check_units +@set_module_as('brainunit') def check_units(**au): """ Decorator to check units of arguments passed to a function diff --git a/brainunit/_unit_constants.py b/brainunit/_unit_constants.py index 5d4aed4..4755ceb 100644 --- a/brainunit/_unit_constants.py +++ b/brainunit/_unit_constants.py @@ -19,7 +19,7 @@ have to be imported explicitly. You can use ``import ... as ...`` to import them with shorter names, e.g.:: - from brainpy.math.units import faraday_constant as F + from brainunit import faraday_constant as F The available constants are: