Skip to content

Commit

Permalink
Shape hints
Browse files Browse the repository at this point in the history
Signed-off-by: Iaroslav Igoshev <iaroslav.igoshev@intel.com>
  • Loading branch information
YarShev committed Mar 12, 2024
1 parent fe3a229 commit 67d4447
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 44 deletions.
92 changes: 61 additions & 31 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,42 +1797,70 @@ def isin_func(df, values):

# String map partitions operations

str_capitalize = Map.register(_str_map("capitalize"), dtypes="copy")
str_center = Map.register(_str_map("center"), dtypes="copy")
str_contains = Map.register(_str_map("contains"), dtypes=np.bool_)
str_count = Map.register(_str_map("count"), dtypes=int)
str_endswith = Map.register(_str_map("endswith"), dtypes=np.bool_)
str_find = Map.register(_str_map("find"), dtypes=np.int64)
str_findall = Map.register(_str_map("findall"), dtypes="copy")
str_get = Map.register(_str_map("get"), dtypes="copy")
str_index = Map.register(_str_map("index"), dtypes=np.int64)
str_isalnum = Map.register(_str_map("isalnum"), dtypes=np.bool_)
str_isalpha = Map.register(_str_map("isalpha"), dtypes=np.bool_)
str_isdecimal = Map.register(_str_map("isdecimal"), dtypes=np.bool_)
str_isdigit = Map.register(_str_map("isdigit"), dtypes=np.bool_)
str_islower = Map.register(_str_map("islower"), dtypes=np.bool_)
str_isnumeric = Map.register(_str_map("isnumeric"), dtypes=np.bool_)
str_isspace = Map.register(_str_map("isspace"), dtypes=np.bool_)
str_istitle = Map.register(_str_map("istitle"), dtypes=np.bool_)
str_isupper = Map.register(_str_map("isupper"), dtypes=np.bool_)
str_join = Map.register(_str_map("join"), dtypes="copy")
str_len = Map.register(_str_map("len"), dtypes=int)
str_ljust = Map.register(_str_map("ljust"), dtypes="copy")
str_lower = Map.register(_str_map("lower"), dtypes="copy")
str_lstrip = Map.register(_str_map("lstrip"), dtypes="copy")
str_match = Map.register(_str_map("match"), dtypes="copy")
str_normalize = Map.register(_str_map("normalize"), dtypes="copy")
str_pad = Map.register(_str_map("pad"), dtypes="copy")
_str_partition = Map.register(_str_map("partition"), dtypes="copy")
str_capitalize = Map.register(
_str_map("capitalize"), dtypes="copy", shape_hint="column"
)
str_center = Map.register(_str_map("center"), dtypes="copy", shape_hint="column")
str_contains = Map.register(
_str_map("contains"), dtypes=np.bool_, shape_hint="column"
)
str_count = Map.register(_str_map("count"), dtypes=int, shape_hint="column")
str_endswith = Map.register(
_str_map("endswith"), dtypes=np.bool_, shape_hint="column"
)
str_find = Map.register(_str_map("find"), dtypes=np.int64, shape_hint="column")
str_findall = Map.register(_str_map("findall"), dtypes="copy", shape_hint="column")
str_get = Map.register(_str_map("get"), dtypes="copy", shape_hint="column")
str_index = Map.register(_str_map("index"), dtypes=np.int64, shape_hint="column")
str_isalnum = Map.register(
_str_map("isalnum"), dtypes=np.bool_, shape_hint="column"
)
str_isalpha = Map.register(
_str_map("isalpha"), dtypes=np.bool_, shape_hint="column"
)
str_isdecimal = Map.register(
_str_map("isdecimal"), dtypes=np.bool_, shape_hint="column"
)
str_isdigit = Map.register(
_str_map("isdigit"), dtypes=np.bool_, shape_hint="column"
)
str_islower = Map.register(
_str_map("islower"), dtypes=np.bool_, shape_hint="column"
)
str_isnumeric = Map.register(
_str_map("isnumeric"), dtypes=np.bool_, shape_hint="column"
)
str_isspace = Map.register(
_str_map("isspace"), dtypes=np.bool_, shape_hint="column"
)
str_istitle = Map.register(
_str_map("istitle"), dtypes=np.bool_, shape_hint="column"
)
str_isupper = Map.register(
_str_map("isupper"), dtypes=np.bool_, shape_hint="column"
)
str_join = Map.register(_str_map("join"), dtypes="copy", shape_hint="column")
str_len = Map.register(_str_map("len"), dtypes=int, shape_hint="column")
str_ljust = Map.register(_str_map("ljust"), dtypes="copy", shape_hint="column")
str_lower = Map.register(_str_map("lower"), dtypes="copy", shape_hint="column")
str_lstrip = Map.register(_str_map("lstrip"), dtypes="copy", shape_hint="column")
str_match = Map.register(_str_map("match"), dtypes="copy", shape_hint="column")
str_normalize = Map.register(
_str_map("normalize"), dtypes="copy", shape_hint="column"
)
str_pad = Map.register(_str_map("pad"), dtypes="copy", shape_hint="column")
_str_partition = Map.register(
_str_map("partition"), dtypes="copy", shape_hint="column"
)

def str_partition(self, sep=" ", expand=True):
# For `expand`, need an operator that can create more columns than before
if expand:
return super().str_partition(sep=sep, expand=expand)
return self._str_partition(sep=sep, expand=False)

str_repeat = Map.register(_str_map("repeat"), dtypes="copy")
_str_extract = Map.register(_str_map("extract"), dtypes="copy")
str_repeat = Map.register(_str_map("repeat"), dtypes="copy", shape_hint="column")
_str_extract = Map.register(_str_map("extract"), dtypes="copy", shape_hint="column")

def str_extract(self, pat, flags, expand):
regex = re.compile(pat, flags=flags)
Expand Down Expand Up @@ -1970,12 +1998,14 @@ def searchsorted(df):

# END Dt map partitions operations

def astype(self, col_dtypes, errors: str = "raise"):
def astype(self, col_dtypes, errors: str = "raise", shape_hint=None):
# `errors` parameter needs to be part of the function signature because
# other query compilers may not take care of error handling at the API
# layer. This query compiler assumes there won't be any errors due to
# invalid type keys.
return self.__constructor__(self._modin_frame.astype(col_dtypes, errors=errors))
return self.__constructor__(
self._modin_frame.astype(col_dtypes, errors=errors), shape_hint=shape_hint
)

def infer_objects(self):
return self.__constructor__(self._modin_frame.infer_objects())
Expand Down
34 changes: 22 additions & 12 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,12 +588,12 @@ def __constructor__(self):
"""
return type(self)

def abs(self): # noqa: RT01, D200
def abs(self, **kwargs): # noqa: RT01, D200
"""
Return a `BasePandasDataset` with absolute numeric value of each element.
"""
self._validate_dtypes(numeric_only=True)
return self.__constructor__(query_compiler=self._query_compiler.abs())
return self.__constructor__(query_compiler=self._query_compiler.abs(**kwargs))

def _set_index(self, new_index):
"""
Expand Down Expand Up @@ -995,7 +995,9 @@ def asof(self, where, subset=None): # noqa: PR01, RT01, D200
result = result.squeeze()
return result

def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200
def astype(
self, dtype, copy=None, errors="raise", **kwargs
): # noqa: PR01, RT01, D200
"""
Cast a Modin object to a specified dtype `dtype`.
"""
Expand Down Expand Up @@ -1043,7 +1045,9 @@ def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200
copy = True

if copy:
new_query_compiler = self._query_compiler.astype(col_dtypes, errors=errors)
new_query_compiler = self._query_compiler.astype(
col_dtypes, errors=errors, **kwargs
)
return self._create_or_update_from_compiler(new_query_compiler)
return self

Expand Down Expand Up @@ -1937,11 +1941,11 @@ def isin(self, values, **kwargs): # noqa: PR01, RT01, D200
)
)

def isna(self): # noqa: RT01, D200
def isna(self, **kwargs): # noqa: RT01, D200
"""
Detect missing values.
"""
return self.__constructor__(query_compiler=self._query_compiler.isna())
return self.__constructor__(query_compiler=self._query_compiler.isna(**kwargs))

isnull = isna

Expand Down Expand Up @@ -2201,11 +2205,11 @@ def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200
"""
return self._binary_op("ne", other, axis=axis, level=level, dtypes=np.bool_)

def notna(self): # noqa: RT01, D200
def notna(self, **kwargs): # noqa: RT01, D200
"""
Detect existing (non-missing) values.
"""
return self.__constructor__(query_compiler=self._query_compiler.notna())
return self.__constructor__(query_compiler=self._query_compiler.notna(**kwargs))

notnull = notna

Expand Down Expand Up @@ -3700,6 +3704,7 @@ def value_counts(
sort: bool = True,
ascending: bool = False,
dropna: bool = True,
**kwargs,
):
if subset is None:
subset = self._query_compiler.columns
Expand All @@ -3724,6 +3729,7 @@ def value_counts(
# )
# https://pandas.pydata.org/pandas-docs/version/2.0/whatsnew/v2.0.0.html#value-counts-sets-the-resulting-name-to-count
counted_values.name = "proportion" if normalize else "count"
counted_values._query_compiler._shape_hint = kwargs.get("shape_hint", None)
return counted_values

def var(
Expand Down Expand Up @@ -4008,7 +4014,7 @@ def _getitem_slice(self, key: slice):
def __gt__(self, right):
return self.gt(right)

def __invert__(self):
def __invert__(self, **kwargs):
"""
Apply bitwise inverse to each element of the `BasePandasDataset`.
Expand All @@ -4027,7 +4033,9 @@ def __invert__(self):
)
)
)
return self.__constructor__(query_compiler=self._query_compiler.invert())
return self.__constructor__(
query_compiler=self._query_compiler.invert(**kwargs)
)

@_doc_binary_op(
operation="less than or equal comparison",
Expand Down Expand Up @@ -4081,7 +4089,7 @@ def __matmul__(self, other):
def __ne__(self, other):
return self.ne(other)

def __neg__(self):
def __neg__(self, **kwargs):
"""
Change the sign for every value of self.
Expand All @@ -4090,7 +4098,9 @@ def __neg__(self):
BasePandasDataset
"""
self._validate_dtypes(numeric_only=True)
return self.__constructor__(query_compiler=self._query_compiler.negative())
return self.__constructor__(
query_compiler=self._query_compiler.negative(**kwargs)
)

def __nonzero__(self):
"""
Expand Down
52 changes: 51 additions & 1 deletion modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,27 @@ def __iter__(self):
"""
return self._to_pandas().__iter__()

def __invert__(self):
"""
Apply bitwise inverse to each element of the `Series`.
Returns
-------
Series
New Series containing bitwise inverse to each value.
"""
return super(Series, self).__invert__(shape_hint="column")

def __neg__(self):
"""
Change the sign for every value of self.
Returns
-------
Series
"""
return super(Series, self).__neg__(shape_hint="column")

@_doc_binary_op(operation="modulo", bin_op="mod")
def __mod__(self, right):
return self.mod(right)
Expand Down Expand Up @@ -512,6 +533,12 @@ def values(self): # noqa: RT01, D200
data = pd.Categorical(data, dtype=self.dtype)
return data

def abs(self): # noqa: RT01, D200
"""
Return a `Series` with absolute numeric value of each element.
"""
return super(Series, self).abs(shape_hint="column")

def add(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200
"""
Return Addition of series and other, element-wise (binary operator add).
Expand Down Expand Up @@ -1180,6 +1207,14 @@ def info(
show_counts=show_counts,
)

def astype(self, dtype, copy=None, errors="raise"):
"""
Whether elements in `Series` are contained in `values`.
"""
return super(Series, self).astype(
dtype, copy=copy, errors=errors, shape_hint="column"
)

def isin(self, values): # noqa: PR01, RT01, D200
"""
Whether elements in `Series` are contained in `values`.
Expand All @@ -1194,7 +1229,7 @@ def isna(self):
-------
The result of detecting missing values.
"""
return super(Series, self).isna()
return super(Series, self).isna(shape_hint=False)

def isnull(self):
"""
Expand Down Expand Up @@ -1365,6 +1400,12 @@ def ne(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).ne(new_other, level=level, axis=axis)

def notna(self): # noqa: RT01, D200
"""
Detect existing (non-missing) values.
"""
return super(Series, self).notna(shape_hint="column")

def nlargest(self, n=5, keep="first"): # noqa: PR01, RT01, D200
"""
Return the largest `n` elements.
Expand Down Expand Up @@ -1528,6 +1569,14 @@ def ravel(self, order="C"): # noqa: PR01, RT01, D200

return data

def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200
"""
Round a `Series` to a variable number of decimal places.
"""
return super(Series, self).round(
decimals=decimals, *args, shape_hint="column", **kwargs
)

@_inherit_docstrings(pandas.Series.reindex, apilink="pandas.Series.reindex")
def reindex(
self,
Expand Down Expand Up @@ -2085,6 +2134,7 @@ def value_counts(
sort=sort,
ascending=ascending,
dropna=dropna,
shape_hint="column",
)
return counted_values

Expand Down

0 comments on commit 67d4447

Please sign in to comment.