Skip to content

Commit

Permalink
Inlined extracted function.
Browse files Browse the repository at this point in the history
Inlined in order to not violate API boundaries. Trying to put this in a good place ends up leading to a circular import issue.
  • Loading branch information
alxmrs committed Sep 21, 2024
1 parent dbefff4 commit b8ce6ec
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 36 deletions.
35 changes: 31 additions & 4 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@
_numeric_dtypes,
_real_floating_dtypes,
_real_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
complex64,
complex128,
float32,
float64,
int32,
uint32,
int64,
uint64,
)
from cubed.backend_array_api import namespace as nxp
from cubed.array_api.utility_functions import operator_default_dtype
from cubed.backend_array_api import namespace as nxp, PRECISION
from cubed.core import reduction


Expand Down Expand Up @@ -122,7 +131,16 @@ def prod(
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
if dtype is None:
dtype = operator_default_dtype(x)
if x.dtype in _signed_integer_dtypes:
dtype = int64 if PRECISION == 64 else int32
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64 if PRECISION == 64 else uint32
elif x.dtype == float32 and PRECISION == 64:
dtype = float64
elif x.dtype == complex64 and PRECISION == 64:
dtype = complex128
else:
dtype = x.dtype
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
Expand All @@ -143,7 +161,16 @@ def sum(
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
if dtype is None:
dtype = operator_default_dtype(x)
if x.dtype in _signed_integer_dtypes:
dtype = int64 if PRECISION == 64 else int32
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64 if PRECISION == 64 else uint32
elif x.dtype == float32 and PRECISION == 64:
dtype = float64
elif x.dtype == complex64 and PRECISION == 64:
dtype = complex128
else:
dtype = x.dtype
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
Expand Down
30 changes: 1 addition & 29 deletions cubed/array_api/utility_functions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
from cubed.array_api.creation_functions import asarray
from cubed.array_api.dtypes import (
_signed_integer_dtypes,
_unsigned_integer_dtypes,
int32,
uint32,
int64,
uint64,
float32,
float64,
complex64,
complex128,
)
from cubed.backend_array_api import namespace as nxp, namespace, PRECISION
from cubed.backend_array_api import namespace as nxp
from cubed.core import reduction


Expand Down Expand Up @@ -41,19 +29,3 @@ def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
use_new_impl=use_new_impl,
split_every=split_every,
)


def operator_default_dtype(x: namespace.ndarray) -> namespace.dtype:
"""Derive the correct default data type for operators."""
if x.dtype in _signed_integer_dtypes:
dtype = int64 if PRECISION == 64 else int32
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64 if PRECISION == 64 else uint32
elif x.dtype == float32 and PRECISION == 64:
dtype = float64
elif x.dtype == complex64 and PRECISION == 64:
dtype = complex128
else:
dtype = x.dtype

return dtype
24 changes: 21 additions & 3 deletions cubed/nan_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@

from cubed.array_api.dtypes import (
_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
complex64,
complex128,
float32,
float64,
int32,
uint32,
int64,
uint64,
)
from cubed.backend_array_api import namespace as nxp
from cubed.array_api.utility_functions import operator_default_dtype
from cubed.backend_array_api import namespace as nxp, PRECISION
from cubed.core import reduction

# TODO: refactor once nan functions are standardized:
Expand Down Expand Up @@ -61,7 +70,16 @@ def nansum(
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in nansum")
if dtype is None:
dtype = operator_default_dtype(x)
if x.dtype in _signed_integer_dtypes:
dtype = int64 if PRECISION == 64 else int32
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64 if PRECISION == 64 else uint32
elif x.dtype == float32 and PRECISION == 64:
dtype = float64
elif x.dtype == complex64 and PRECISION == 64:
dtype = complex128
else:
dtype = x.dtype
return reduction(
x,
nxp.nansum,
Expand Down

0 comments on commit b8ce6ec

Please sign in to comment.