Skip to content

Commit

Permalink
Add SortMergeJoin conditional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowe committed Jan 7, 2022
1 parent 8728ff4 commit 6a4233c
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,12 +581,61 @@ def do_join(spark):
return left.join(right, (left.a == right.r_a) & (left.b >= right.r_b), 'Inner')
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', join_ast_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter'], ids=idfn)
def test_sortmerge_join_with_condition_ast(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
return left.join(right, (left.a == right.r_a) & (left.b >= right.r_b), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@allow_non_gpu('GreaterThan', 'ShuffleExchangeExec', 'SortMergeJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['LeftSemi', 'LeftAnti'], ids=idfn)
def test_sortmerge_join_with_condition_join_type_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
return left.join(right, (left.a == right.r_a) & (left.b >= right.r_b), join_type)
assert_gpu_fallback_collect(do_join, 'SortMergeJoinExec', conf=_sortmerge_join_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@allow_non_gpu('GreaterThan', 'Log', 'ShuffleExchangeExec', 'SortMergeJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter'], ids=idfn)
def test_sortmerge_join_with_condition_ast_op_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
# AST does not support cast or logarithm yet
return left.join(right, (left.a == right.r_a) & (left.b > f.log(right.r_b)), join_type)
assert_gpu_fallback_collect(do_join, 'SortMergeJoinExec', conf=_sortmerge_join_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@allow_non_gpu('GreaterThan', 'ShuffleExchangeExec', 'SortMergeJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', join_no_ast_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_sortmerge_join_with_condition_ast_type_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
return left.join(right, (left.a == right.r_a) & (left.b > right.r_b), join_type)
assert_gpu_fallback_collect(do_join, 'SortMergeJoinExec', conf=_sortmerge_join_conf)


_mixed_df1_with_nulls = [('a', RepeatSeqGen(LongGen(nullable=(True, 20.0)), length= 10)),
('b', IntegerGen()), ('c', LongGen())]
_mixed_df2_with_nulls = [('a', RepeatSeqGen(LongGen(nullable=(True, 20.0)), length= 10)),
('b', StringGen()), ('c', BooleanGen())]


@ignore_order
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'FullOuter', 'Cross'], ids=idfn)
def test_broadcast_join_mixed(join_type):
Expand Down

0 comments on commit 6a4233c

Please sign in to comment.