From 4ebd97097a5f09396f074b8f41e7f5156198f927 Mon Sep 17 00:00:00 2001 From: Nicolas Camenisch Date: Thu, 20 Jul 2023 14:46:15 +0200 Subject: [PATCH] Implement rank and dense_rank operations --- .../transform/core/expressions/translator.py | 5 ++ src/pydiverse/transform/core/ops/core.py | 6 +++ src/pydiverse/transform/core/ops/window.py | 38 ++++++++++++- src/pydiverse/transform/core/table_impl.py | 10 ++-- src/pydiverse/transform/core/util/util.py | 48 +++++++++++------ src/pydiverse/transform/eager/pandas_table.py | 54 +++++++++++++++++++ .../transform/lazy/sql_table/sql_table.py | 14 +++++ .../test_window_function.py | 50 ++++++++++++++--- 8 files changed, 193 insertions(+), 32 deletions(-) diff --git a/src/pydiverse/transform/core/expressions/translator.py b/src/pydiverse/transform/core/expressions/translator.py index 48d4b69..9f50691 100644 --- a/src/pydiverse/transform/core/expressions/translator.py +++ b/src/pydiverse/transform/core/expressions/translator.py @@ -70,6 +70,11 @@ def _translate(self, expr, accept_literal_col=True, **kwargs): if isinstance(expr, expressions.FunctionCall): operator = self.operator_registry.get_operator(expr.name) + + # Mutate function call arguments using operator + expr_args, expr_kwargs = operator.mutate_args(expr.args, expr.kwargs) + expr = expressions.FunctionCall(expr.name, *expr_args, **expr_kwargs) + op_args, op_kwargs, context_kwargs = self.__translate_function_arguments( expr, operator, **kwargs ) diff --git a/src/pydiverse/transform/core/ops/core.py b/src/pydiverse/transform/core/ops/core.py index ebae1e8..8aeafcb 100644 --- a/src/pydiverse/transform/core/ops/core.py +++ b/src/pydiverse/transform/core/ops/core.py @@ -82,6 +82,12 @@ def __hash__(self): def validate_signature(self, signature: registry.OperatorSignature) -> bool: pass + def mutate_args(self, args, kwargs): + """ + Allows the operator to modify the arguments passed to it before translation + """ + return args, kwargs + class OperatorExtension: """ diff --git a/src/pydiverse/transform/core/ops/window.py b/src/pydiverse/transform/core/ops/window.py index 41024cd..51ea74e 100644 --- a/src/pydiverse/transform/core/ops/window.py +++ b/src/pydiverse/transform/core/ops/window.py @@ -1,13 +1,39 @@ from __future__ import annotations -from .core import Nullary, Window +from .core import Nullary, Unary, Window __all__ = [ "Shift", "RowNumber", + "Rank", + "DenseRank", ] +class WindowImplicitArrange(Window): + """ + Like a window function, except that the expression on which this op + gets called, is used for arranging. + + Converts a call like this ``tbl.col1.nulls_first().rank()``, into a call like this + ``(tbl.col1).rank(arrange=[tbl.col1.nulls_first()]) + """ + + def mutate_args(self, args, kwargs): + if len(args) == 0: + return args, kwargs + + from pydiverse.transform.core.util.util import ordering_peeler + + arrange = args[0] + peeled_first_arg, _, _ = ordering_peeler(arrange) + args = (peeled_first_arg, *args[1:]) + + assert "arrange" not in kwargs + kwargs = {**kwargs, "arrange": [arrange]} + return args, kwargs + + class Shift(Window): name = "shift" signatures = [ @@ -21,3 +47,13 @@ class RowNumber(Window, Nullary): signatures = [ "-> int", ] + + +class Rank(WindowImplicitArrange, Unary): + name = "rank" + signatures = ["T -> int"] + + +class DenseRank(WindowImplicitArrange, Unary): + name = "dense_rank" + signatures = ["T -> int"] diff --git a/src/pydiverse/transform/core/table_impl.py b/src/pydiverse/transform/core/table_impl.py index 77dedd3..1eca987 100644 --- a/src/pydiverse/transform/core/table_impl.py +++ b/src/pydiverse/transform/core/table_impl.py @@ -382,17 +382,15 @@ def as_column(self, name, table: AbstractTableImpl): with AbstractTableImpl.op(ops.NullsFirst()) as op: @op.auto - def _nulls_first(x): - # it is just a marker not doing anything to the input - return x + def _nulls_first(_): + raise RuntimeError("This is just a marker that never should get called") with AbstractTableImpl.op(ops.NullsLast()) as op: @op.auto - def _nulls_last(x): - # it is just a marker not doing anything to the input - return x + def _nulls_last(_): + raise RuntimeError("This is just a marker that never should get called") #### ARITHMETIC OPERATORS ###################################################### diff --git a/src/pydiverse/transform/core/util/util.py b/src/pydiverse/transform/core/util/util.py index df0e562..2c22f89 100644 --- a/src/pydiverse/transform/core/util/util.py +++ b/src/pydiverse/transform/core/util/util.py @@ -29,29 +29,44 @@ def traverse(obj: T, callback: typing.Callable) -> T: return callback(obj) -def sign_peeler(expr): - """Remove unary - and + prefix and return the sign - :return: `True` for `+` and `False` for `-` - """ - num_neg = 0 +def peel_markers(expr, markers): + found_markers = [] while isinstance(expr, expressions.FunctionCall): - if expr.name == "__neg__": - num_neg += 1 - expr = expr.args[0] - elif expr.name == "__pos__": + if expr.name in markers: + found_markers.append(expr.name) + assert len(expr.args) == 1 expr = expr.args[0] else: break + return expr, found_markers + + +def sign_peeler(expr): + """ + Remove unary - and + prefix and return the sign + :return: `True` for `+` and `False` for `-` + """ + + expr, markers = peel_markers(expr, {"__neg__", "__pos__"}) + num_neg = markers.count("__neg__") return expr, num_neg % 2 == 0 -def extract_nulls_first(expr): - if expr.name == "nulls_first": - return True, expr.args[0] - if expr.name == "nulls_last": - return False, expr.args[0] +def ordering_peeler(expr): + expr, markers = peel_markers( + expr, {"__neg__", "__pos__", "nulls_first", "nulls_last"} + ) + + ascending = markers.count("__neg__") % 2 == 0 + nulls_first = False + for marker in markers: + if marker == "nulls_first": + nulls_first = True + break + if marker == "nulls_last": + break - return False, expr + return expr, ascending, nulls_first #### @@ -69,8 +84,7 @@ class OrderingDescriptor: def translate_ordering(tbl, order_list) -> list[OrderingDescriptor]: ordering = [] for arg in order_list: - col, ascending = sign_peeler(arg) - nulls_first, col = extract_nulls_first(col) + col, ascending, nulls_first = ordering_peeler(arg) if not isinstance(col, (column.Column, column.LambdaColumn)): raise ValueError( diff --git a/src/pydiverse/transform/eager/pandas_table.py b/src/pydiverse/transform/eager/pandas_table.py index f24adb9..932544d 100644 --- a/src/pydiverse/transform/eager/pandas_table.py +++ b/src/pydiverse/transform/eager/pandas_table.py @@ -274,11 +274,14 @@ def value(df, **kw): def _translate_function( self, expr, implementation, op_args, context_kwargs, *, verb=None, **kwargs ): + internal_kwargs = {} + def value(df, *, grouper=None, **kw): args = [arg.value(df, grouper=grouper, **kw) for arg in op_args] kwargs = { "_tbl": self.backend, "_df": df, + **internal_kwargs, } # Element wise operator @@ -325,6 +328,11 @@ def bound_impl(x): raise ValueError( "Window function are only allowed inside a mutate." ) + + if arrange := context_kwargs.get("arrange"): + ordering = translate_ordering(self.backend, arrange) + internal_kwargs["_ordering"] = ordering + value = self.arranged_window(value, operator, context_kwargs) override_ftype = ( @@ -636,3 +644,49 @@ def _row_number(idx: pd.Series): @op.auto(variant="transform") def _row_number(idx: pd.Series): return idx.cumcount() + 1 + + +with PandasTableImpl.op(ops.Rank()) as op: + + @op.auto + def _rank(x: pd.Series, *, _ordering: list[OrderingDescriptor]): + assert len(_ordering) == 1 + ordering = _ordering[0] + na_option = "top" if ordering.nulls_first else "bottom" + + return x.rank(method="min", ascending=ordering.asc, na_option=na_option) + + # transform variant currently disabled due to bug in pandas: + # https://github.com/pandas-dev/pandas/issues/54206 + + # @op.auto(variant="transform") + # def _rank(x: pd.Series, *, _ordering: list[OrderingDescriptor]): + # assert len(_ordering) == 1 + # ordering = _ordering[0] + # na_option = "top" if ordering.nulls_first else "bottom" + # return x.transform( + # "rank", method="min", ascending=ordering.asc, na_option=na_option + # ) + + +with PandasTableImpl.op(ops.DenseRank()) as op: + + @op.auto + def _dense_rank(x: pd.Series, *, _ordering: list[OrderingDescriptor]): + assert len(_ordering) == 1 + ordering = _ordering[0] + na_option = "top" if ordering.nulls_first else "bottom" + + return x.rank(method="dense", ascending=ordering.asc, na_option=na_option) + + # transform variant currently disabled due to bug in pandas: + # https://github.com/pandas-dev/pandas/issues/54206 + + # @op.auto(variant="transform") + # def _dense_rank(x: pd.Series, *, _ordering: list[OrderingDescriptor]): + # assert len(_ordering) == 1 + # ordering = _ordering[0] + # na_option = "top" if ordering.nulls_first else "bottom" + # return x.transform( + # "rank", method="dense", ascending=ordering.asc, na_option=na_option + # ) diff --git a/src/pydiverse/transform/lazy/sql_table/sql_table.py b/src/pydiverse/transform/lazy/sql_table/sql_table.py index 4a80b79..bca5382 100644 --- a/src/pydiverse/transform/lazy/sql_table/sql_table.py +++ b/src/pydiverse/transform/lazy/sql_table/sql_table.py @@ -854,3 +854,17 @@ def _shift(x, by, empty_value=None): @op.auto def _row_number(): return sa.func.ROW_NUMBER() + + +with SQLTableImpl.op(ops.Rank()) as op: + + @op.auto + def _rank(_): + return sa.func.rank() + + +with SQLTableImpl.op(ops.DenseRank()) as op: + + @op.auto + def _dense_rank(_): + return sa.func.dense_rank() diff --git a/tests/test_backend_equivalence/test_window_function.py b/tests/test_backend_equivalence/test_window_function.py index 01e68a1..2d3a040 100644 --- a/tests/test_backend_equivalence/test_window_function.py +++ b/tests/test_backend_equivalence/test_window_function.py @@ -242,11 +242,11 @@ def test_complex(df3_x, df3_y): # Test specific operations -@tables("df3") -def test_op_shift(df3_x, df3_y): +@tables("df4") +def test_op_shift(df4_x, df4_y): assert_result_equal( - df3_x, - df3_y, + df4_x, + df4_y, lambda t: t >> group_by(t.col1) >> mutate( @@ -256,11 +256,11 @@ def test_op_shift(df3_x, df3_y): ) -@tables("df3") -def test_op_row_number(df3_x, df3_y): +@tables("df4") +def test_op_row_number(df4_x, df4_y): assert_result_equal( - df3_x, - df3_y, + df4_x, + df4_y, lambda t: t >> group_by(t.col1) >> mutate( @@ -268,3 +268,37 @@ def test_op_row_number(df3_x, df3_y): row_number2=f.row_number(arrange=[λ.col2, λ.col3]), ), ) + + +@tables("df4") +def test_op_rank(df4_x, df4_y): + assert_result_equal( + df4_x, + df4_y, + lambda t: t + >> group_by(t.col1) + >> mutate( + rank1=t.col1.rank(), + rank2=t.col2.rank(), + rank3=t.col2.nulls_last().rank(), + rank4=t.col5.nulls_first().rank(), + rank5=(-t.col5.nulls_first()).rank(), + ), + ) + + +@tables("df4") +def test_op_dense_rank(df4_x, df4_y): + assert_result_equal( + df4_x, + df4_y, + lambda t: t + >> group_by(t.col1) + >> mutate( + rank1=t.col1.dense_rank(), + rank2=t.col2.dense_rank(), + rank3=t.col2.nulls_last().dense_rank(), + rank4=t.col5.nulls_first().dense_rank(), + rank5=(-t.col5.nulls_first()).dense_rank(), + ), + )