Skip to content

Commit

Permalink
Improve naming
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Sep 26, 2024
1 parent 0d84633 commit a50616c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 51 deletions.
16 changes: 5 additions & 11 deletions ndonnx/_logic_in_data/_typed_array/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..dtypes import TY_ARRAY, DType
from ..schema import DTypeInfo, flatten_components
from . import onnx, py_scalars
from .funcs import astypedarray, typed_where
from .funcs import astyarray, where
from .typed_array import TyArrayBase
from .utils import safe_cast

Expand Down Expand Up @@ -183,7 +183,7 @@ def __init__(self, is_nat: onnx.TyArrayBool, data: onnx.TyArrayInt64, unit: Unit
def __ndx_astype__(self, dtype: DType[TY_ARRAY]) -> TY_ARRAY | NotImplementedType:
res_type = dtype._tyarr_class
if isinstance(dtype, onnx.CoreIntegerDTypes):
data = typed_where(self.is_nat, _NAT_SENTINEL, self.data)
data = where(self.is_nat, _NAT_SENTINEL, self.data)
return data.astype(dtype)
if isinstance(dtype, TimeDelta):
powers = {
Expand All @@ -193,9 +193,7 @@ def __ndx_astype__(self, dtype: DType[TY_ARRAY]) -> TY_ARRAY | NotImplementedTyp
"ns": 9,
}
power = powers[dtype.unit] - powers[self.dtype.unit]
data = typed_where(
self.is_nat, astypedarray(np.iinfo(np.int64).min), self.data
)
data = where(self.is_nat, astyarray(np.iinfo(np.int64).min), self.data)

if power > 0:
data = data * np.pow(10, power)
Expand Down Expand Up @@ -243,9 +241,7 @@ def unwrap_numpy(self) -> np.ndarray:
def __ndx_astype__(self, dtype: DType[TY_ARRAY]) -> TY_ARRAY | NotImplementedType:
res_type = dtype._tyarr_class
if isinstance(dtype, onnx.CoreIntegerDTypes):
data = typed_where(
self.is_nat, astypedarray(np.iinfo(np.int64).min), self.data
)
data = where(self.is_nat, astyarray(np.iinfo(np.int64).min), self.data)
return data.astype(dtype)
if isinstance(dtype, DateTime):
powers = {
Expand All @@ -255,9 +251,7 @@ def __ndx_astype__(self, dtype: DType[TY_ARRAY]) -> TY_ARRAY | NotImplementedTyp
"ns": 9,
}
power = powers[dtype.unit] - powers[self.dtype.unit]
data = typed_where(
self.is_nat, astypedarray(np.iinfo(np.int64).min), self.data
)
data = where(self.is_nat, astyarray(np.iinfo(np.int64).min), self.data)

if power > 0:
data = data * np.pow(10, power)
Expand Down
52 changes: 23 additions & 29 deletions ndonnx/_logic_in_data/_typed_array/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,7 @@
from .utils import safe_cast


def typed_where(cond: TyArrayBase, x: TyArrayBase, y: TyArrayBase) -> TyArrayBase:
from . import TyArrayBool

# TODO: Masked condition
if not isinstance(cond, TyArrayBool):
raise TypeError("'cond' must be a boolean data type.")

ret = x.__ndx_where__(cond, y)
if ret is NotImplemented:
ret = y.__ndx_rwhere__(cond, x)
if ret is NotImplemented:
raise TypeError(
f"Unsupported operand data types for 'where': `{x.dtype}` and `{y.dtype}`"
)
return ret


def astypedarray(
def astyarray(
val: int | float | str | np.ndarray | TyArrayBase | Var,
dtype: None | DType = None,
use_py_scalars=False,
Expand Down Expand Up @@ -69,6 +52,28 @@ def astypedarray(
return arr


#########################################################################
# Free functions implemented via `__ndx_*__` methods on the typed array #
#########################################################################


def where(cond: TyArrayBase, x: TyArrayBase, y: TyArrayBase) -> TyArrayBase:
from . import TyArrayBool

# TODO: Masked condition
if not isinstance(cond, TyArrayBool):
raise TypeError("'cond' must be a boolean data type.")

ret = x.__ndx_where__(cond, y)
if ret is NotImplemented:
ret = y.__ndx_rwhere__(cond, x)
if ret is NotImplemented:
raise TypeError(
f"Unsupported operand data types for 'where': `{x.dtype}` and `{y.dtype}`"
)
return ret


def maximum(x1: TyArrayBase, x2: TyArrayBase, /) -> TyArrayBase:
res = x1.__ndx_maximum__(x2)
if res is NotImplemented:
Expand All @@ -78,14 +83,3 @@ def maximum(x1: TyArrayBase, x2: TyArrayBase, /) -> TyArrayBase:
f"Unsupported operand data types for 'max': `{x1.dtype}` and `{x2.dtype}`"
)
return res


def sum(
x: TyArrayBase,
/,
*,
axis: int | tuple[int, ...] | None = None,
dtype: DType | None = None,
keepdims: bool = False,
) -> TyArrayBase:
return x.sum(axis=axis, dtype=dtype, keepdims=keepdims)
4 changes: 2 additions & 2 deletions ndonnx/_logic_in_data/_typed_array/masked_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..dtypes import TY_ARRAY, DType
from ..schema import DTypeInfo, Schema, flatten_components
from . import onnx
from .funcs import astypedarray
from .funcs import astyarray
from .typed_array import TyArrayBase
from .utils import safe_cast

Expand Down Expand Up @@ -321,7 +321,7 @@ def __setitem__(self, index: SetitemIndex, value: Self) -> None:
if self.mask is None:
shape = self.dynamic_shape
self.mask = safe_cast(
onnx.TyArrayBool, astypedarray(False).broadcast_to(shape)
onnx.TyArrayBool, astyarray(False).broadcast_to(shape)
)
self.mask[index] = new_mask
else:
Expand Down
4 changes: 4 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/typed_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def __repr__(self) -> str:
@abstractmethod
def __ndx_value_repr__(self) -> dict[str, str]:
"""A string representation of the fields to be used in ``Array.__repr__```."""
# Note: It is unfortunate that this part of the API relies on
# the rather useless `dict[str, str]` type hint. `TypedDict`
# is not a viable solution (?) since it does not play nicely
# with the subtyping.

@abstractmethod
def __getitem__(self, index: GetitemIndex) -> Self: ...
Expand Down
8 changes: 4 additions & 4 deletions ndonnx/_logic_in_data/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from spox import Var

from ._typed_array import TyArrayBase
from ._typed_array.funcs import astypedarray
from ._typed_array.funcs import astyarray
from .dtypes import DType

StrictShape = tuple[int, ...]
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self, shape=None, dtype=None, value=None, var=None):
if isinstance(value, np.ndarray):
raise NotImplementedError
if isinstance(value, int | float):
ty_arr = astypedarray(value, use_py_scalars=False, dtype=dtype)
ty_arr = astyarray(value, use_py_scalars=False, dtype=dtype)
self._data = ty_arr
return

Expand Down Expand Up @@ -278,7 +278,7 @@ def asarray(
return obj
if isinstance(obj, bool | int | float):
obj = np.array(obj)
data = astypedarray(obj)
data = astyarray(obj)
if dtype:
data = data.astype(dtype)
return Array._from_data(data)
Expand All @@ -289,7 +289,7 @@ def _as_array(
) -> Array:
if isinstance(val, Array):
return val
ty_arr = astypedarray(val, use_py_scalars=use_py_scalars)
ty_arr = astyarray(val, use_py_scalars=use_py_scalars)
return Array._from_data(ty_arr)


Expand Down
8 changes: 3 additions & 5 deletions ndonnx/_logic_in_data/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,13 @@ def sum(
dtype: DType | None = None,
keepdims: bool = False,
) -> Array:
from ._typed_array.funcs import sum

return Array._from_data(sum(x._data))
return Array._from_data(x._data.sum(axis=axis, dtype=dtype, keepdims=keepdims))


def where(cond: Array, a: Array, b: Array) -> Array:
from ._typed_array.funcs import typed_where
from ._typed_array import funcs as tyfuncs

data = typed_where(cond._data, a._data, b._data)
data = tyfuncs.where(cond._data, a._data, b._data)
return Array._from_data(data)


Expand Down

0 comments on commit a50616c

Please sign in to comment.