From a23e740aeb9ad918a8c6241f8f666b312da4d918 Mon Sep 17 00:00:00 2001 From: windiana42 Date: Mon, 10 Jul 2023 12:25:53 +0200 Subject: [PATCH] Use NotImplementedError for backend limitations. --- .../transform/core/expressions/translator.py | 5 ++++- src/pydiverse/transform/eager/pandas_table.py | 2 +- tests/test_backend_equivalence.py | 20 +++++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/pydiverse/transform/core/expressions/translator.py b/src/pydiverse/transform/core/expressions/translator.py index 2f2fae3..c1c1df4 100644 --- a/src/pydiverse/transform/core/expressions/translator.py +++ b/src/pydiverse/transform/core/expressions/translator.py @@ -35,7 +35,10 @@ def translate(self, expr, **kwargs) -> T: try: return bottom_up_replace(expr, lambda e: self._translate(e, **kwargs)) except Exception as e: - raise ValueError( + etype = ValueError + if isinstance(e, NotImplementedError): + etype = NotImplementedError + raise etype( "An exception occured while trying to translate the expression" f" '{expr}':\n{e}" ) from e diff --git a/src/pydiverse/transform/eager/pandas_table.py b/src/pydiverse/transform/eager/pandas_table.py index 953bcc4..4a190b2 100644 --- a/src/pydiverse/transform/eager/pandas_table.py +++ b/src/pydiverse/transform/eager/pandas_table.py @@ -444,7 +444,7 @@ def _translate(self, expr, **kwargs): return ((c1, c2),) if expr.name == "__and__": return tuple(itertools.chain(*expr.args)) - raise Exception( + raise NotImplementedError( f"Invalid ON clause element: {expr}. Only a conjunction of equalities" " is supported by pandas (ands of equals)." ) diff --git a/tests/test_backend_equivalence.py b/tests/test_backend_equivalence.py index c1fd1a7..760f74a 100644 --- a/tests/test_backend_equivalence.py +++ b/tests/test_backend_equivalence.py @@ -23,6 +23,7 @@ collect, filter, group_by, + inner_join, join, left_join, mutate, @@ -559,6 +560,25 @@ def self_join_3(t): assert_result_equal(df3_x, df3_y, self_join_3) + @tables(["df1", "df2"], expect_not_implemented=["pandas"]) + def test_inequality_join(self, df1_x, df1_y, df2_x, df2_y): + assert_result_equal( + (df1_x, df2_x), + (df1_y, df2_y), + lambda t, u: t + >> select() + >> inner_join(u, (t.col1 >= u.col1) & (t.col1 <= u.col1)) + >> full_sort(), + ) + + assert_result_equal( + (df1_x, df2_x), + (df1_y, df2_y), + lambda t, u: t + >> inner_join(u >> select(), ~((t.col1 < u.col1) | (t.col1 > u.col1))) + >> full_sort(), + ) + class TestFilter: @tables(["df2"])