Skip to content

Commit

Permalink
Use NotImplementedError for backend limitations.
Browse files Browse the repository at this point in the history
  • Loading branch information
windiana42 committed Jul 10, 2023
1 parent ebd464b commit a23e740
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/pydiverse/transform/core/expressions/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def translate(self, expr, **kwargs) -> T:
try:
return bottom_up_replace(expr, lambda e: self._translate(e, **kwargs))
except Exception as e:
raise ValueError(
etype = ValueError
if isinstance(e, NotImplementedError):
etype = NotImplementedError
raise etype(
"An exception occured while trying to translate the expression"
f" '{expr}':\n{e}"
) from e
Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/transform/eager/pandas_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def _translate(self, expr, **kwargs):
return ((c1, c2),)
if expr.name == "__and__":
return tuple(itertools.chain(*expr.args))
raise Exception(
raise NotImplementedError(
f"Invalid ON clause element: {expr}. Only a conjunction of equalities"
" is supported by pandas (ands of equals)."
)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_backend_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
collect,
filter,
group_by,
inner_join,
join,
left_join,
mutate,
Expand Down Expand Up @@ -559,6 +560,25 @@ def self_join_3(t):

assert_result_equal(df3_x, df3_y, self_join_3)

@tables(["df1", "df2"], expect_not_implemented=["pandas"])
def test_inequality_join(self, df1_x, df1_y, df2_x, df2_y):
assert_result_equal(
(df1_x, df2_x),
(df1_y, df2_y),
lambda t, u: t
>> select()
>> inner_join(u, (t.col1 >= u.col1) & (t.col1 <= u.col1))
>> full_sort(),
)

assert_result_equal(
(df1_x, df2_x),
(df1_y, df2_y),
lambda t, u: t
>> inner_join(u >> select(), ~((t.col1 < u.col1) | (t.col1 > u.col1)))
>> full_sort(),
)


class TestFilter:
@tables(["df2"])
Expand Down

0 comments on commit a23e740

Please sign in to comment.