Skip to content

Commit

Permalink
Implement rank and dense_rank operations
Browse files Browse the repository at this point in the history
  • Loading branch information
NMAC427 committed Jul 20, 2023
1 parent f15a247 commit 4ebd970
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 32 deletions.
5 changes: 5 additions & 0 deletions src/pydiverse/transform/core/expressions/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def _translate(self, expr, accept_literal_col=True, **kwargs):

if isinstance(expr, expressions.FunctionCall):
operator = self.operator_registry.get_operator(expr.name)

# Mutate function call arguments using operator
expr_args, expr_kwargs = operator.mutate_args(expr.args, expr.kwargs)
expr = expressions.FunctionCall(expr.name, *expr_args, **expr_kwargs)

op_args, op_kwargs, context_kwargs = self.__translate_function_arguments(
expr, operator, **kwargs
)
Expand Down
6 changes: 6 additions & 0 deletions src/pydiverse/transform/core/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ def __hash__(self):
def validate_signature(self, signature: registry.OperatorSignature) -> bool:
pass

def mutate_args(self, args, kwargs):
    """
    Hook that lets an operator rewrite the arguments of a function call
    before that call gets translated.

    This default implementation is the identity: both values are handed
    back untouched. Operators that need to rearrange their arguments can
    override it and return a modified pair instead.

    :param args: positional arguments of the function call
    :param kwargs: keyword arguments of the function call
    :return: the pair ``(args, kwargs)``
    """
    unchanged = (args, kwargs)
    return unchanged


class OperatorExtension:
"""
Expand Down
38 changes: 37 additions & 1 deletion src/pydiverse/transform/core/ops/window.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,39 @@
from __future__ import annotations

from .core import Nullary, Window
from .core import Nullary, Unary, Window

__all__ = [
"Shift",
"RowNumber",
"Rank",
"DenseRank",
]


class WindowImplicitArrange(Window):
    """
    Like a window function, except that the expression on which this op
    gets called is also used for arranging.

    Converts a call like ``tbl.col1.nulls_first().rank()`` into a call like
    ``(tbl.col1).rank(arrange=[tbl.col1.nulls_first()])``.
    """

    def mutate_args(self, args, kwargs):
        """
        Turn the first positional argument into an implicit ``arrange``.

        The ordering markers are peeled off the first argument (via
        ``ordering_peeler``); the peeled expression replaces the first
        argument, and the original, un-peeled expression becomes the single
        entry of the ``arrange`` keyword argument.

        :param args: positional arguments of the function call
        :param kwargs: keyword arguments of the function call; must not
            already contain ``arrange``
        :return: the pair ``(args, kwargs)``; returned unchanged when there
            are no positional arguments
        :raises ValueError: if ``arrange`` was passed explicitly
        """
        if len(args) == 0:
            return args, kwargs

        # Raise explicitly instead of using ``assert``: asserts are removed
        # under ``python -O``, in which case a user supplied ``arrange``
        # would be overwritten silently below.
        if "arrange" in kwargs:
            raise ValueError(
                "The 'arrange' argument can't be passed explicitly to this"
                " operator; the ordering is derived from the expression it is"
                " called on."
            )

        # Function-scope import kept from the original implementation
        # (presumably to avoid a circular import - TODO confirm).
        from pydiverse.transform.core.util.util import ordering_peeler

        arrange = args[0]
        peeled_first_arg, _, _ = ordering_peeler(arrange)
        args = (peeled_first_arg, *args[1:])

        # Build a new dict so the caller's kwargs are not mutated.
        kwargs = {**kwargs, "arrange": [arrange]}
        return args, kwargs


class Shift(Window):
name = "shift"
signatures = [
Expand All @@ -21,3 +47,13 @@ class RowNumber(Window, Nullary):
signatures = [
"-> int",
]


class Rank(WindowImplicitArrange, Unary):
    """Operator registered under the name ``rank`` (signature ``T -> int``)."""

    name = "rank"
    signatures = [
        "T -> int",
    ]


class DenseRank(WindowImplicitArrange, Unary):
    """Operator registered under the name ``dense_rank`` (signature ``T -> int``)."""

    name = "dense_rank"
    signatures = [
        "T -> int",
    ]
10 changes: 4 additions & 6 deletions src/pydiverse/transform/core/table_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,17 +382,15 @@ def as_column(self, name, table: AbstractTableImpl):
with AbstractTableImpl.op(ops.NullsFirst()) as op:

@op.auto
def _nulls_first(x):
# it is just a marker not doing anything to the input
return x
def _nulls_first(_):
raise RuntimeError("This is just a marker that never should get called")


with AbstractTableImpl.op(ops.NullsLast()) as op:

@op.auto
def _nulls_last(x):
# it is just a marker not doing anything to the input
return x
def _nulls_last(_):
raise RuntimeError("This is just a marker that never should get called")


#### ARITHMETIC OPERATORS ######################################################
Expand Down
48 changes: 31 additions & 17 deletions src/pydiverse/transform/core/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,44 @@ def traverse(obj: T, callback: typing.Callable) -> T:
return callback(obj)


def sign_peeler(expr):
"""Remove unary - and + prefix and return the sign
:return: `True` for `+` and `False` for `-`
"""
num_neg = 0
def peel_markers(expr, markers):
found_markers = []
while isinstance(expr, expressions.FunctionCall):
if expr.name == "__neg__":
num_neg += 1
expr = expr.args[0]
elif expr.name == "__pos__":
if expr.name in markers:
found_markers.append(expr.name)
assert len(expr.args) == 1
expr = expr.args[0]
else:
break
return expr, found_markers


def sign_peeler(expr):
    """
    Strip leading unary ``-`` and ``+`` calls from an expression.

    :param expr: the expression to peel
    :return: a tuple ``(inner, sign)`` where ``inner`` is the expression
        without its ``__neg__`` / ``__pos__`` wrappers and ``sign`` is
        `True` for `+` (an even number of negations) and `False` for `-`
    """
    inner, peeled = peel_markers(expr, {"__neg__", "__pos__"})
    negation_count = peeled.count("__neg__")
    is_positive = negation_count % 2 == 0
    return inner, is_positive


def extract_nulls_first(expr):
if expr.name == "nulls_first":
return True, expr.args[0]
if expr.name == "nulls_last":
return False, expr.args[0]
def ordering_peeler(expr):
expr, markers = peel_markers(
expr, {"__neg__", "__pos__", "nulls_first", "nulls_last"}
)

ascending = markers.count("__neg__") % 2 == 0
nulls_first = False
for marker in markers:
if marker == "nulls_first":
nulls_first = True
break
if marker == "nulls_last":
break

return False, expr
return expr, ascending, nulls_first


####
Expand All @@ -69,8 +84,7 @@ class OrderingDescriptor:
def translate_ordering(tbl, order_list) -> list[OrderingDescriptor]:
ordering = []
for arg in order_list:
col, ascending = sign_peeler(arg)
nulls_first, col = extract_nulls_first(col)
col, ascending, nulls_first = ordering_peeler(arg)

if not isinstance(col, (column.Column, column.LambdaColumn)):
raise ValueError(
Expand Down
54 changes: 54 additions & 0 deletions src/pydiverse/transform/eager/pandas_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,14 @@ def value(df, **kw):
def _translate_function(
self, expr, implementation, op_args, context_kwargs, *, verb=None, **kwargs
):
internal_kwargs = {}

def value(df, *, grouper=None, **kw):
args = [arg.value(df, grouper=grouper, **kw) for arg in op_args]
kwargs = {
"_tbl": self.backend,
"_df": df,
**internal_kwargs,
}

# Element wise operator
Expand Down Expand Up @@ -325,6 +328,11 @@ def bound_impl(x):
raise ValueError(
"Window function are only allowed inside a mutate."
)

if arrange := context_kwargs.get("arrange"):
ordering = translate_ordering(self.backend, arrange)
internal_kwargs["_ordering"] = ordering

value = self.arranged_window(value, operator, context_kwargs)

override_ftype = (
Expand Down Expand Up @@ -636,3 +644,49 @@ def _row_number(idx: pd.Series):
@op.auto(variant="transform")
def _row_number(idx: pd.Series):
return idx.cumcount() + 1


with PandasTableImpl.op(ops.Rank()) as op:

    # Pandas implementation of the ``rank`` operator, built on
    # ``Series.rank(method="min")``.
    @op.auto
    def _rank(x: pd.Series, *, _ordering: list[OrderingDescriptor]):
        # This implementation only handles a single ordering descriptor.
        assert len(_ordering) == 1
        descriptor = _ordering[0]

        # Map the nulls-first flag onto pandas' ``na_option`` values.
        if descriptor.nulls_first:
            nulls_position = "top"
        else:
            nulls_position = "bottom"

        return x.rank(
            method="min",
            ascending=descriptor.asc,
            na_option=nulls_position,
        )

# transform variant currently disabled due to bug in pandas:
# https://github.com/pandas-dev/pandas/issues/54206

# @op.auto(variant="transform")
# def _rank(x: pd.Series, *, _ordering: list[OrderingDescriptor]):
# assert len(_ordering) == 1
# ordering = _ordering[0]
# na_option = "top" if ordering.nulls_first else "bottom"
# return x.transform(
# "rank", method="min", ascending=ordering.asc, na_option=na_option
# )


with PandasTableImpl.op(ops.DenseRank()) as op:

    # Pandas implementation of the ``dense_rank`` operator, built on
    # ``Series.rank(method="dense")``.
    @op.auto
    def _dense_rank(x: pd.Series, *, _ordering: list[OrderingDescriptor]):
        # This implementation only handles a single ordering descriptor.
        assert len(_ordering) == 1
        descriptor = _ordering[0]

        # Map the nulls-first flag onto pandas' ``na_option`` values.
        if descriptor.nulls_first:
            nulls_position = "top"
        else:
            nulls_position = "bottom"

        return x.rank(
            method="dense",
            ascending=descriptor.asc,
            na_option=nulls_position,
        )

# transform variant currently disabled due to bug in pandas:
# https://github.com/pandas-dev/pandas/issues/54206

# @op.auto(variant="transform")
# def _dense_rank(x: pd.Series, *, _ordering: list[OrderingDescriptor]):
# assert len(_ordering) == 1
# ordering = _ordering[0]
# na_option = "top" if ordering.nulls_first else "bottom"
# return x.transform(
# "rank", method="dense", ascending=ordering.asc, na_option=na_option
# )
14 changes: 14 additions & 0 deletions src/pydiverse/transform/lazy/sql_table/sql_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,3 +854,17 @@ def _shift(x, by, empty_value=None):
@op.auto
def _row_number():
return sa.func.ROW_NUMBER()


with SQLTableImpl.op(ops.Rank()) as op:

# SQL implementation of the ``rank`` operator.
@op.auto
def _rank(_):
# The positional argument is deliberately ignored; the SQL function is
# emitted without arguments.
# NOTE(review): the ordering presumably reaches the window clause through
# the ``arrange`` keyword argument handled elsewhere - confirm.
return sa.func.rank()


with SQLTableImpl.op(ops.DenseRank()) as op:

# SQL implementation of the ``dense_rank`` operator.
@op.auto
def _dense_rank(_):
# The positional argument is deliberately ignored; the SQL function is
# emitted without arguments.
# NOTE(review): the ordering presumably reaches the window clause through
# the ``arrange`` keyword argument handled elsewhere - confirm.
return sa.func.dense_rank()
50 changes: 42 additions & 8 deletions tests/test_backend_equivalence/test_window_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,11 @@ def test_complex(df3_x, df3_y):
# Test specific operations


@tables("df3")
def test_op_shift(df3_x, df3_y):
@tables("df4")
def test_op_shift(df4_x, df4_y):
assert_result_equal(
df3_x,
df3_y,
df4_x,
df4_y,
lambda t: t
>> group_by(t.col1)
>> mutate(
Expand All @@ -256,15 +256,49 @@ def test_op_shift(df3_x, df3_y):
)


@tables("df3")
def test_op_row_number(df3_x, df3_y):
@tables("df4")
def test_op_row_number(df4_x, df4_y):
assert_result_equal(
df3_x,
df3_y,
df4_x,
df4_y,
lambda t: t
>> group_by(t.col1)
>> mutate(
row_number1=f.row_number(arrange=[λ.col4]),
row_number2=f.row_number(arrange=[λ.col2, λ.col3]),
),
)


@tables("df4")
def test_op_rank(df4_x, df4_y):
# Exercises ``rank`` inside a grouped mutate on both ``df4`` tables.
# NOTE(review): ``assert_result_equal`` presumably applies the lambda to
# each table and compares the results - confirm against its definition.
assert_result_equal(
df4_x,
df4_y,
lambda t: t
>> group_by(t.col1)
>> mutate(
# Rank of the grouping column itself.
rank1=t.col1.rank(),
# No explicit ordering markers.
rank2=t.col2.rank(),
# Explicit nulls_last / nulls_first markers.
rank3=t.col2.nulls_last().rank(),
rank4=t.col5.nulls_first().rank(),
# Negation wrapped around a nulls_first marker.
rank5=(-t.col5.nulls_first()).rank(),
),
)


@tables("df4")
def test_op_dense_rank(df4_x, df4_y):
# Exercises ``dense_rank`` inside a grouped mutate on both ``df4`` tables.
# NOTE(review): ``assert_result_equal`` presumably applies the lambda to
# each table and compares the results - confirm against its definition.
assert_result_equal(
df4_x,
df4_y,
lambda t: t
>> group_by(t.col1)
>> mutate(
# Dense rank of the grouping column itself.
rank1=t.col1.dense_rank(),
# No explicit ordering markers.
rank2=t.col2.dense_rank(),
# Explicit nulls_last / nulls_first markers.
rank3=t.col2.nulls_last().dense_rank(),
rank4=t.col5.nulls_first().dense_rank(),
# Negation wrapped around a nulls_first marker.
rank5=(-t.col5.nulls_first()).dense_rank(),
),
)

0 comments on commit 4ebd970

Please sign in to comment.