From b8ce6ec2334c8fea4b9563ee2429b52f4eb0a7e1 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Tue, 6 Aug 2024 20:31:38 +0100 Subject: [PATCH] Inlined extracted function. Inlined in order to not violate API boundaries. Trying to put this in a good place ends up leading to a circular import issue. --- cubed/array_api/statistical_functions.py | 35 +++++++++++++++++++++--- cubed/array_api/utility_functions.py | 30 +------------------- cubed/nan_functions.py | 24 ++++++++++++++-- 3 files changed, 53 insertions(+), 36 deletions(-) diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index eb33e2a8..c0a64c8e 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -5,9 +5,18 @@ _numeric_dtypes, _real_floating_dtypes, _real_numeric_dtypes, + _signed_integer_dtypes, + _unsigned_integer_dtypes, + complex64, + complex128, + float32, + float64, + int32, + uint32, + int64, + uint64, ) -from cubed.backend_array_api import namespace as nxp -from cubed.array_api.utility_functions import operator_default_dtype +from cubed.backend_array_api import namespace as nxp, PRECISION from cubed.core import reduction @@ -122,7 +131,16 @@ def prod( if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: raise TypeError("Only numeric or boolean dtypes are allowed in prod") if dtype is None: - dtype = operator_default_dtype(x) + if x.dtype in _signed_integer_dtypes: + dtype = int64 if PRECISION == 64 else int32 + elif x.dtype in _unsigned_integer_dtypes: + dtype = uint64 if PRECISION == 64 else uint32 + elif x.dtype == float32 and PRECISION == 64: + dtype = float64 + elif x.dtype == complex64 and PRECISION == 64: + dtype = complex128 + else: + dtype = x.dtype extra_func_kwargs = dict(dtype=dtype) return reduction( x, @@ -143,7 +161,16 @@ def sum( if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: raise TypeError("Only numeric or boolean dtypes are allowed in sum") if dtype is None: - dtype = operator_default_dtype(x) + if x.dtype in _signed_integer_dtypes: + dtype = int64 if PRECISION == 64 else int32 + elif x.dtype in _unsigned_integer_dtypes: + dtype = uint64 if PRECISION == 64 else uint32 + elif x.dtype == float32 and PRECISION == 64: + dtype = float64 + elif x.dtype == complex64 and PRECISION == 64: + dtype = complex128 + else: + dtype = x.dtype extra_func_kwargs = dict(dtype=dtype) return reduction( x, diff --git a/cubed/array_api/utility_functions.py b/cubed/array_api/utility_functions.py index 16ecb803..9825dd9b 100644 --- a/cubed/array_api/utility_functions.py +++ b/cubed/array_api/utility_functions.py @@ -1,17 +1,5 @@ from cubed.array_api.creation_functions import asarray -from cubed.array_api.dtypes import ( - _signed_integer_dtypes, - _unsigned_integer_dtypes, - int32, - uint32, - int64, - uint64, - float32, - float64, - complex64, - complex128, -) -from cubed.backend_array_api import namespace as nxp, namespace, PRECISION +from cubed.backend_array_api import namespace as nxp from cubed.core import reduction @@ -41,19 +29,3 @@ def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None) use_new_impl=use_new_impl, split_every=split_every, ) - - -def operator_default_dtype(x: namespace.ndarray) -> namespace.dtype: - """Derive the correct default data type for operators.""" - if x.dtype in _signed_integer_dtypes: - dtype = int64 if PRECISION == 64 else int32 - elif x.dtype in _unsigned_integer_dtypes: - dtype = uint64 if PRECISION == 64 else uint32 - elif x.dtype == float32 and PRECISION == 64: - dtype = float64 - elif x.dtype == complex64 and PRECISION == 64: - dtype = complex128 - else: - dtype = x.dtype - - return dtype diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index 402726a6..3aaf5b74 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -2,9 +2,18 @@ from cubed.array_api.dtypes import ( _numeric_dtypes, + _signed_integer_dtypes, + _unsigned_integer_dtypes, + complex64, + complex128, + float32, + float64, + int32, + uint32, + int64, + uint64, ) -from cubed.backend_array_api import namespace as nxp -from cubed.array_api.utility_functions import operator_default_dtype +from cubed.backend_array_api import namespace as nxp, PRECISION from cubed.core import reduction # TODO: refactor once nan functions are standardized: @@ -61,7 +70,16 @@ def nansum( if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in nansum") if dtype is None: - dtype = operator_default_dtype(x) + if x.dtype in _signed_integer_dtypes: + dtype = int64 if PRECISION == 64 else int32 + elif x.dtype in _unsigned_integer_dtypes: + dtype = uint64 if PRECISION == 64 else uint32 + elif x.dtype == float32 and PRECISION == 64: + dtype = float64 + elif x.dtype == complex64 and PRECISION == 64: + dtype = complex128 + else: + dtype = x.dtype return reduction( x, nxp.nansum,