Skip to content

Commit

Permalink
More lenient dtype support (#550)
Browse files Browse the repository at this point in the history
* Allow `bool` in `sum` and `prod`

* Make dtype checking more lenient in case of 'all'
  • Loading branch information
tomwhite authored Aug 13, 2024
1 parent 3dbbc9a commit 0e6e3f0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
10 changes: 8 additions & 2 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,18 @@ def __int__(self, /):
# Utility methods

def _check_allowed_dtypes(self, other, dtype_category, op):
if self.dtype not in _dtype_categories[dtype_category]:
if (
dtype_category != "all"
and self.dtype not in _dtype_categories[dtype_category]
):
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
if isinstance(other, (int, complex, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, CoreArray):
if other.dtype not in _dtype_categories[dtype_category]:
if (
dtype_category != "all"
and other.dtype not in _dtype_categories[dtype_category]
):
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
else:
return NotImplemented
Expand Down
19 changes: 13 additions & 6 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math

from cubed.array_api.dtypes import (
_boolean_dtypes,
_numeric_dtypes,
_real_floating_dtypes,
_real_numeric_dtypes,
Expand Down Expand Up @@ -124,10 +125,13 @@ def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
def prod(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in prod")
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
if dtype is None:
if x.dtype in _signed_integer_dtypes:
if x.dtype in _boolean_dtypes:
dtype = int64
elif x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
Expand All @@ -153,10 +157,13 @@ def prod(
def sum(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in sum")
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
if dtype is None:
if x.dtype in _signed_integer_dtypes:
if x.dtype in _boolean_dtypes:
dtype = int64
elif x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
Expand Down
10 changes: 10 additions & 0 deletions cubed/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from numpy.testing import assert_array_equal

import cubed.array_api as xp


# This is less strict than the spec, but is supported by implementations like NumPy
def test_prod_sum_bool():
a = xp.ones((2,), dtype=xp.bool)
assert_array_equal(xp.prod(a).compute(), xp.asarray([1], dtype=xp.int64))
assert_array_equal(xp.sum(a).compute(), xp.asarray([2], dtype=xp.int64))

0 comments on commit 0e6e3f0

Please sign in to comment.