Skip to content

Commit

Permalink
Experimental const-T-list
Browse files Browse the repository at this point in the history
  • Loading branch information
windiana42 committed Aug 19, 2023
1 parent a23e740 commit 612444a
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/pydiverse/transform/core/ops/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"Xor",
"RXor",
"Invert",
"Isin",
]


Expand Down Expand Up @@ -104,3 +105,10 @@ class Invert(ElementWise, Unary):
signatures = [
"bool -> bool",
]


class Isin(Comparison):
name = "isin"
signatures = [
"T, const-T-list -> bool",
]
2 changes: 2 additions & 0 deletions src/pydiverse/transform/core/ops/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def parse_cstypes(cst: str):
for type_ in (*base_args_types, base_rtype):
if type_.startswith("const-"):
type_ = type_[6:] # only look at type after dash
if type_.endswith("-list"):
type_ = type_[:-5] # only look at type before dash
if not type_.isalnum():
raise ValueError(f"Invalid type '{type_}'. Types must be alphanumeric.")

Expand Down
17 changes: 17 additions & 0 deletions src/pydiverse/transform/core/table_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,23 @@ def const_func(*args, **kwargs):
return TypedValue(literal_func, "float", const_value=const_func)
if isinstance(expr, str):
return TypedValue(literal_func, "str", const_value=const_func)
if isinstance(expr, list):

def assert_func(*args, **kwargs):
raise ValueError(
"List type is currently only supported for constants, "
f"not: {expr}"
)

if all(isinstance(exp, bool) for exp in expr):
return TypedValue(assert_func, "bool-list", const_value=const_func)
if all(isinstance(exp, int) for exp in expr):
return TypedValue(assert_func, "int-list", const_value=const_func)
if all(isinstance(exp, float) for exp in expr):
return TypedValue(assert_func, "float-list", const_value=const_func)
if all(isinstance(exp, str) for exp in expr):
return TypedValue(assert_func, "str-list", const_value=const_func)
raise ValueError(f"Unknown expression type: {expr}")

class AlignedExpressionEvaluator(Generic[AlignedT], DelegatingTranslator[AlignedT]):
"""
Expand Down
7 changes: 7 additions & 0 deletions src/pydiverse/transform/eager/pandas_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,3 +685,10 @@ def _rank(x: pd.Series):
@op.auto(variant="transform")
def _rank(x: pd.Series):
return x.transform("rank", method="min")


with PandasTableImpl.op(ops.Isin()) as op:

@op.auto
def _isin(x: pd.Series, *args):
return x.isin(args)
7 changes: 7 additions & 0 deletions src/pydiverse/transform/lazy/sql_table/sql_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,3 +913,10 @@ def _rank(x):
# row_number() is like rank(method="first")
# rank() is like method="min"
return sa.func.RANK(), ImplicitArrange(sql.expression.ClauseList(x))


with SQLTableImpl.op(ops.Isin()) as op:

@op.auto
def _isin(x, *args):
return x.in_(args)
6 changes: 4 additions & 2 deletions tests/test_backend_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,10 @@ def test_noop(self, df2_x, df2_y):

@tables(["df2"])
def test_simple_filter(self, df2_x, df2_y):
assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1 == 2))
assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1 != 2))
# assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1 == 2))
# assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1 != 2))
# assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1.isin(2)))
assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1.isin([2, 3])))

@tables(["df2"])
def test_chained_filters(self, df2_x, df2_y):
Expand Down

0 comments on commit 612444a

Please sign in to comment.