From 67d444788455328cf1803660a631dbd2a0afbfd8 Mon Sep 17 00:00:00 2001 From: Iaroslav Igoshev Date: Tue, 12 Mar 2024 10:25:35 +0100 Subject: [PATCH] Shape hints Signed-off-by: Iaroslav Igoshev --- .../storage_formats/pandas/query_compiler.py | 92 ++++++++++++------- modin/pandas/base.py | 34 ++++--- modin/pandas/series.py | 52 ++++++++++- 3 files changed, 134 insertions(+), 44 deletions(-) diff --git a/modin/core/storage_formats/pandas/query_compiler.py b/modin/core/storage_formats/pandas/query_compiler.py index 7beeca258dc..16cc383dcbc 100644 --- a/modin/core/storage_formats/pandas/query_compiler.py +++ b/modin/core/storage_formats/pandas/query_compiler.py @@ -1797,33 +1797,61 @@ def isin_func(df, values): # String map partitions operations - str_capitalize = Map.register(_str_map("capitalize"), dtypes="copy") - str_center = Map.register(_str_map("center"), dtypes="copy") - str_contains = Map.register(_str_map("contains"), dtypes=np.bool_) - str_count = Map.register(_str_map("count"), dtypes=int) - str_endswith = Map.register(_str_map("endswith"), dtypes=np.bool_) - str_find = Map.register(_str_map("find"), dtypes=np.int64) - str_findall = Map.register(_str_map("findall"), dtypes="copy") - str_get = Map.register(_str_map("get"), dtypes="copy") - str_index = Map.register(_str_map("index"), dtypes=np.int64) - str_isalnum = Map.register(_str_map("isalnum"), dtypes=np.bool_) - str_isalpha = Map.register(_str_map("isalpha"), dtypes=np.bool_) - str_isdecimal = Map.register(_str_map("isdecimal"), dtypes=np.bool_) - str_isdigit = Map.register(_str_map("isdigit"), dtypes=np.bool_) - str_islower = Map.register(_str_map("islower"), dtypes=np.bool_) - str_isnumeric = Map.register(_str_map("isnumeric"), dtypes=np.bool_) - str_isspace = Map.register(_str_map("isspace"), dtypes=np.bool_) - str_istitle = Map.register(_str_map("istitle"), dtypes=np.bool_) - str_isupper = Map.register(_str_map("isupper"), dtypes=np.bool_) - str_join = Map.register(_str_map("join"), dtypes="copy") - 
str_len = Map.register(_str_map("len"), dtypes=int) - str_ljust = Map.register(_str_map("ljust"), dtypes="copy") - str_lower = Map.register(_str_map("lower"), dtypes="copy") - str_lstrip = Map.register(_str_map("lstrip"), dtypes="copy") - str_match = Map.register(_str_map("match"), dtypes="copy") - str_normalize = Map.register(_str_map("normalize"), dtypes="copy") - str_pad = Map.register(_str_map("pad"), dtypes="copy") - _str_partition = Map.register(_str_map("partition"), dtypes="copy") + str_capitalize = Map.register( + _str_map("capitalize"), dtypes="copy", shape_hint="column" + ) + str_center = Map.register(_str_map("center"), dtypes="copy", shape_hint="column") + str_contains = Map.register( + _str_map("contains"), dtypes=np.bool_, shape_hint="column" + ) + str_count = Map.register(_str_map("count"), dtypes=int, shape_hint="column") + str_endswith = Map.register( + _str_map("endswith"), dtypes=np.bool_, shape_hint="column" + ) + str_find = Map.register(_str_map("find"), dtypes=np.int64, shape_hint="column") + str_findall = Map.register(_str_map("findall"), dtypes="copy", shape_hint="column") + str_get = Map.register(_str_map("get"), dtypes="copy", shape_hint="column") + str_index = Map.register(_str_map("index"), dtypes=np.int64, shape_hint="column") + str_isalnum = Map.register( + _str_map("isalnum"), dtypes=np.bool_, shape_hint="column" + ) + str_isalpha = Map.register( + _str_map("isalpha"), dtypes=np.bool_, shape_hint="column" + ) + str_isdecimal = Map.register( + _str_map("isdecimal"), dtypes=np.bool_, shape_hint="column" + ) + str_isdigit = Map.register( + _str_map("isdigit"), dtypes=np.bool_, shape_hint="column" + ) + str_islower = Map.register( + _str_map("islower"), dtypes=np.bool_, shape_hint="column" + ) + str_isnumeric = Map.register( + _str_map("isnumeric"), dtypes=np.bool_, shape_hint="column" + ) + str_isspace = Map.register( + _str_map("isspace"), dtypes=np.bool_, shape_hint="column" + ) + str_istitle = Map.register( + _str_map("istitle"), 
dtypes=np.bool_, shape_hint="column" + ) + str_isupper = Map.register( + _str_map("isupper"), dtypes=np.bool_, shape_hint="column" + ) + str_join = Map.register(_str_map("join"), dtypes="copy", shape_hint="column") + str_len = Map.register(_str_map("len"), dtypes=int, shape_hint="column") + str_ljust = Map.register(_str_map("ljust"), dtypes="copy", shape_hint="column") + str_lower = Map.register(_str_map("lower"), dtypes="copy", shape_hint="column") + str_lstrip = Map.register(_str_map("lstrip"), dtypes="copy", shape_hint="column") + str_match = Map.register(_str_map("match"), dtypes="copy", shape_hint="column") + str_normalize = Map.register( + _str_map("normalize"), dtypes="copy", shape_hint="column" + ) + str_pad = Map.register(_str_map("pad"), dtypes="copy", shape_hint="column") + _str_partition = Map.register( + _str_map("partition"), dtypes="copy", shape_hint="column" + ) def str_partition(self, sep=" ", expand=True): # For `expand`, need an operator that can create more columns than before @@ -1831,8 +1859,8 @@ def str_partition(self, sep=" ", expand=True): return super().str_partition(sep=sep, expand=expand) return self._str_partition(sep=sep, expand=False) - str_repeat = Map.register(_str_map("repeat"), dtypes="copy") - _str_extract = Map.register(_str_map("extract"), dtypes="copy") + str_repeat = Map.register(_str_map("repeat"), dtypes="copy", shape_hint="column") + _str_extract = Map.register(_str_map("extract"), dtypes="copy", shape_hint="column") def str_extract(self, pat, flags, expand): regex = re.compile(pat, flags=flags) @@ -1970,12 +1998,14 @@ def searchsorted(df): # END Dt map partitions operations - def astype(self, col_dtypes, errors: str = "raise"): + def astype(self, col_dtypes, errors: str = "raise", shape_hint=None): # `errors` parameter needs to be part of the function signature because # other query compilers may not take care of error handling at the API # layer. 
This query compiler assumes there won't be any errors due to # invalid type keys. - return self.__constructor__(self._modin_frame.astype(col_dtypes, errors=errors)) + return self.__constructor__( + self._modin_frame.astype(col_dtypes, errors=errors), shape_hint=shape_hint + ) def infer_objects(self): return self.__constructor__(self._modin_frame.infer_objects()) diff --git a/modin/pandas/base.py b/modin/pandas/base.py index 6779c47dc9b..0edf325186b 100644 --- a/modin/pandas/base.py +++ b/modin/pandas/base.py @@ -588,12 +588,12 @@ def __constructor__(self): """ return type(self) - def abs(self): # noqa: RT01, D200 + def abs(self, **kwargs): # noqa: RT01, D200 """ Return a `BasePandasDataset` with absolute numeric value of each element. """ self._validate_dtypes(numeric_only=True) - return self.__constructor__(query_compiler=self._query_compiler.abs()) + return self.__constructor__(query_compiler=self._query_compiler.abs(**kwargs)) def _set_index(self, new_index): """ @@ -995,7 +995,9 @@ def asof(self, where, subset=None): # noqa: PR01, RT01, D200 result = result.squeeze() return result - def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200 + def astype( + self, dtype, copy=None, errors="raise", **kwargs + ): # noqa: PR01, RT01, D200 """ Cast a Modin object to a specified dtype `dtype`. """ @@ -1043,7 +1045,9 @@ def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200 copy = True if copy: - new_query_compiler = self._query_compiler.astype(col_dtypes, errors=errors) + new_query_compiler = self._query_compiler.astype( + col_dtypes, errors=errors, **kwargs + ) return self._create_or_update_from_compiler(new_query_compiler) return self @@ -1937,11 +1941,11 @@ def isin(self, values, **kwargs): # noqa: PR01, RT01, D200 ) ) - def isna(self): # noqa: RT01, D200 + def isna(self, **kwargs): # noqa: RT01, D200 """ Detect missing values. 
""" - return self.__constructor__(query_compiler=self._query_compiler.isna()) + return self.__constructor__(query_compiler=self._query_compiler.isna(**kwargs)) isnull = isna @@ -2201,11 +2205,11 @@ def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 """ return self._binary_op("ne", other, axis=axis, level=level, dtypes=np.bool_) - def notna(self): # noqa: RT01, D200 + def notna(self, **kwargs): # noqa: RT01, D200 """ Detect existing (non-missing) values. """ - return self.__constructor__(query_compiler=self._query_compiler.notna()) + return self.__constructor__(query_compiler=self._query_compiler.notna(**kwargs)) notnull = notna @@ -3700,6 +3704,7 @@ def value_counts( sort: bool = True, ascending: bool = False, dropna: bool = True, + **kwargs, ): if subset is None: subset = self._query_compiler.columns @@ -3724,6 +3729,7 @@ def value_counts( # ) # https://pandas.pydata.org/pandas-docs/version/2.0/whatsnew/v2.0.0.html#value-counts-sets-the-resulting-name-to-count counted_values.name = "proportion" if normalize else "count" + counted_values._query_compiler._shape_hint = kwargs.get("shape_hint", None) return counted_values def var( @@ -4008,7 +4014,7 @@ def _getitem_slice(self, key: slice): def __gt__(self, right): return self.gt(right) - def __invert__(self): + def __invert__(self, **kwargs): """ Apply bitwise inverse to each element of the `BasePandasDataset`. @@ -4027,7 +4033,9 @@ def __invert__(self): ) ) ) - return self.__constructor__(query_compiler=self._query_compiler.invert()) + return self.__constructor__( + query_compiler=self._query_compiler.invert(**kwargs) + ) @_doc_binary_op( operation="less than or equal comparison", @@ -4081,7 +4089,7 @@ def __matmul__(self, other): def __ne__(self, other): return self.ne(other) - def __neg__(self): + def __neg__(self, **kwargs): """ Change the sign for every value of self. 
@@ -4090,7 +4098,9 @@ def __neg__(self): BasePandasDataset """ self._validate_dtypes(numeric_only=True) - return self.__constructor__(query_compiler=self._query_compiler.negative()) + return self.__constructor__( + query_compiler=self._query_compiler.negative(**kwargs) + ) def __nonzero__(self): """ diff --git a/modin/pandas/series.py b/modin/pandas/series.py index d4eb8f6faef..82bd8fa0619 100644 --- a/modin/pandas/series.py +++ b/modin/pandas/series.py @@ -339,6 +339,27 @@ def __iter__(self): """ return self._to_pandas().__iter__() + def __invert__(self): + """ + Apply bitwise inverse to each element of the `Series`. + + Returns + ------- + Series + New Series containing bitwise inverse to each value. + """ + return super(Series, self).__invert__(shape_hint="column") + + def __neg__(self): + """ + Change the sign for every value of self. + + Returns + ------- + Series + """ + return super(Series, self).__neg__(shape_hint="column") + @_doc_binary_op(operation="modulo", bin_op="mod") def __mod__(self, right): return self.mod(right) @@ -512,6 +533,12 @@ def values(self): # noqa: RT01, D200 data = pd.Categorical(data, dtype=self.dtype) return data + def abs(self): # noqa: RT01, D200 + """ + Return a `Series` with absolute numeric value of each element. + """ + return super(Series, self).abs(shape_hint="column") + def add(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 """ Return Addition of series and other, element-wise (binary operator add). @@ -1180,6 +1207,14 @@ def info( show_counts=show_counts, ) + def astype(self, dtype, copy=None, errors="raise"): # noqa: PR01, RT01, D200 + """ + Cast a `Series` to a specified dtype `dtype`. + """ + return super(Series, self).astype( + dtype, copy=copy, errors=errors, shape_hint="column" + ) + def isin(self, values): # noqa: PR01, RT01, D200 """ Whether elements in `Series` are contained in `values`. @@ -1194,7 +1229,7 @@ def isna(self): ------- The result of detecting missing values.
""" - return super(Series, self).isna() + return super(Series, self).isna(shape_hint="column") def isnull(self): """ @@ -1365,6 +1400,12 @@ def ne(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D new_self, new_other = self._prepare_inter_op(other) return super(Series, new_self).ne(new_other, level=level, axis=axis) + def notna(self): # noqa: RT01, D200 + """ + Detect existing (non-missing) values. + """ + return super(Series, self).notna(shape_hint="column") + def nlargest(self, n=5, keep="first"): # noqa: PR01, RT01, D200 """ Return the largest `n` elements. @@ -1528,6 +1569,14 @@ def ravel(self, order="C"): # noqa: PR01, RT01, D200 return data + def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Round a `Series` to a variable number of decimal places. + """ + return super(Series, self).round( + decimals, *args, shape_hint="column", **kwargs + ) + @_inherit_docstrings(pandas.Series.reindex, apilink="pandas.Series.reindex") def reindex( self, @@ -2085,6 +2134,7 @@ def value_counts( sort=sort, ascending=ascending, dropna=dropna, + shape_hint="column", ) return counted_values