From 612444a0ee4776b3663eaacea0ea31db14f4c3ee Mon Sep 17 00:00:00 2001 From: windiana42 Date: Sat, 19 Aug 2023 07:41:58 +0200 Subject: [PATCH] Experimental const-T-list --- src/pydiverse/transform/core/ops/logical.py | 8 ++++++++ src/pydiverse/transform/core/ops/registry.py | 2 ++ src/pydiverse/transform/core/table_impl.py | 17 +++++++++++++++++ src/pydiverse/transform/eager/pandas_table.py | 7 +++++++ .../transform/lazy/sql_table/sql_table.py | 7 +++++++ tests/test_backend_equivalence.py | 6 ++++-- 6 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/pydiverse/transform/core/ops/logical.py b/src/pydiverse/transform/core/ops/logical.py index f76b5e5..560d0b9 100644 --- a/src/pydiverse/transform/core/ops/logical.py +++ b/src/pydiverse/transform/core/ops/logical.py @@ -16,6 +16,7 @@ "Xor", "RXor", "Invert", + "Isin", ] @@ -104,3 +105,10 @@ class Invert(ElementWise, Unary): signatures = [ "bool -> bool", ] + + +class Isin(Comparison): + name = "isin" + signatures = [ + "T, const-T-list -> bool", + ] diff --git a/src/pydiverse/transform/core/ops/registry.py b/src/pydiverse/transform/core/ops/registry.py index 59ca7a1..c14726d 100644 --- a/src/pydiverse/transform/core/ops/registry.py +++ b/src/pydiverse/transform/core/ops/registry.py @@ -323,6 +323,8 @@ def parse_cstypes(cst: str): for type_ in (*base_args_types, base_rtype): if type_.startswith("const-"): type_ = type_[6:] # only look at type after dash + if type_.endswith("-list"): + type_ = type_[:-5] # only look at type before dash if not type_.isalnum(): raise ValueError(f"Invalid type '{type_}'. Types must be alphanumeric.") diff --git a/src/pydiverse/transform/core/table_impl.py b/src/pydiverse/transform/core/table_impl.py index f721be4..fe38f57 100644 --- a/src/pydiverse/transform/core/table_impl.py +++ b/src/pydiverse/transform/core/table_impl.py @@ -249,6 +249,23 @@ def const_func(*args, **kwargs): return TypedValue(literal_func, "float", const_value=const_func) if isinstance(expr, str): return TypedValue(literal_func, "str", const_value=const_func) + if isinstance(expr, list): + + def assert_func(*args, **kwargs): + raise ValueError( + "List type is currently only supported for constants, " + f"not: {expr}" + ) + + if all(isinstance(exp, bool) for exp in expr): + return TypedValue(assert_func, "bool-list", const_value=const_func) + if all(isinstance(exp, int) for exp in expr): + return TypedValue(assert_func, "int-list", const_value=const_func) + if all(isinstance(exp, float) for exp in expr): + return TypedValue(assert_func, "float-list", const_value=const_func) + if all(isinstance(exp, str) for exp in expr): + return TypedValue(assert_func, "str-list", const_value=const_func) + raise ValueError(f"Unknown expression type: {expr}") class AlignedExpressionEvaluator(Generic[AlignedT], DelegatingTranslator[AlignedT]): """ diff --git a/src/pydiverse/transform/eager/pandas_table.py b/src/pydiverse/transform/eager/pandas_table.py index 4a190b2..60fa25f 100644 --- a/src/pydiverse/transform/eager/pandas_table.py +++ b/src/pydiverse/transform/eager/pandas_table.py @@ -685,3 +685,10 @@ def _rank(x: pd.Series): @op.auto(variant="transform") def _rank(x: pd.Series): return x.transform("rank", method="min") + + +with PandasTableImpl.op(ops.Isin()) as op: + + @op.auto + def _isin(x: pd.Series, *args): + return x.isin(args) diff --git a/src/pydiverse/transform/lazy/sql_table/sql_table.py b/src/pydiverse/transform/lazy/sql_table/sql_table.py index db5505e..be10dd9 100644 --- a/src/pydiverse/transform/lazy/sql_table/sql_table.py +++ b/src/pydiverse/transform/lazy/sql_table/sql_table.py @@ -913,3 +913,10 @@ def _rank(x): # row_number() is like rank(method="first") # rank() is like method="min" return sa.func.RANK(), ImplicitArrange(sql.expression.ClauseList(x)) + + +with SQLTableImpl.op(ops.Isin()) as op: + + @op.auto + def _isin(x, *args): + return x.in_(args) diff --git a/tests/test_backend_equivalence.py b/tests/test_backend_equivalence.py index 760f74a..573478f 100644 --- a/tests/test_backend_equivalence.py +++ b/tests/test_backend_equivalence.py @@ -588,8 +588,10 @@ def test_noop(self, df2_x, df2_y): @tables(["df2"]) def test_simple_filter(self, df2_x, df2_y): - assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1 == 2)) - assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1 != 2)) + # assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1 == 2)) + # assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1 != 2)) + # assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1.isin(2))) + assert_result_equal(df2_x, df2_y, lambda t: t >> filter(t.col1.isin([2, 3]))) @tables(["df2"]) def test_chained_filters(self, df2_x, df2_y):