Skip to content

Commit

Permalink
support lead/lag on arrays (#2435)
Browse files Browse the repository at this point in the history
Signed-off-by: Bobby Wang <wbo4958@gmail.com>
  • Loading branch information
wbo4958 authored May 19, 2021
1 parent 51c0dc2 commit c7ebbaa
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 37 deletions.
20 changes: 10 additions & 10 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -880,9 +880,9 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
</table>
Expand Down Expand Up @@ -8692,7 +8692,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested STRING, NULL, BINARY, CALENDAR, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -8734,7 +8734,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested STRING, BINARY, CALENDAR, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand All @@ -8755,7 +8755,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested STRING, NULL, BINARY, CALENDAR, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -8897,7 +8897,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested STRING, NULL, BINARY, CALENDAR, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -8939,7 +8939,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested STRING, BINARY, CALENDAR, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand All @@ -8960,7 +8960,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested STRING, NULL, BINARY, CALENDAR, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -17378,7 +17378,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -17420,7 +17420,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down
35 changes: 35 additions & 0 deletions integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,41 @@ def do_it(spark):
.withColumn('row_num', f.row_number().over(baseWindowSpec))
assert_gpu_and_cpu_are_equal_collect(do_it, conf={'spark.rapids.sql.hasNans': 'false'})


# Array generators wrapped one, two and three levels deep around every generator in
# lead_lag_data_gens; each array level is capped at max_length=10.
lead_lag_array_data_gens = (
    [ArrayGen(elem_gen, max_length=10)
     for elem_gen in lead_lag_data_gens]
    + [ArrayGen(ArrayGen(elem_gen, max_length=10), max_length=10)
       for elem_gen in lead_lag_data_gens]
    + [ArrayGen(ArrayGen(ArrayGen(elem_gen, max_length=10), max_length=10), max_length=10)
       for elem_gen in lead_lag_data_gens]
)

# lead and lag are supported for arrays, but other window operations such as min and max are not yet.
# Once they are all supported, these tests should be combined with the other window tests.
@ignore_order(local=True)
@pytest.mark.parametrize('d_gen', lead_lag_array_data_gens, ids=meta_idfn('agg:'))
@pytest.mark.parametrize('c_gen', [long_gen], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('b_gen', [long_gen], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('a_gen', [long_gen], ids=meta_idfn('partBy:'))
def test_window_aggs_for_rows_lead_lag_on_arrays(a_gen, b_gen, c_gen, d_gen):
    """Run LEAD and LAG over array-typed columns through the GPU/CPU SQL comparison helper.

    The table has a partition column ``a`` (``a_gen`` wrapped in ``RepeatSeqGen`` with
    length=20), two ordering columns ``b`` and ``c``, and two columns ``d`` and
    ``d_default`` that are both produced by ``d_gen``. The query exercises LEAD and LAG
    with an offset of 5 and no default, and with an offset of 2 using ``d_default`` as
    the default value.
    """
    table_columns = [
        ('a', RepeatSeqGen(a_gen, length=20)),
        ('b', b_gen),
        ('c', c_gen),
        ('d', d_gen),
        ('d_default', d_gen)]

    def build_table(spark):
        return gen_df(spark, table_columns, length=2048)

    # The SQL text is kept exactly as written so the submitted query is unchanged.
    lead_lag_query = '''
SELECT
LEAD(d, 5) OVER (PARTITION by a ORDER BY b,c) lead_d_5,
LEAD(d, 2, d_default) OVER (PARTITION by a ORDER BY b,c) lead_d_2_default,
LAG(d, 5) OVER (PARTITION by a ORDER BY b,c) lag_d_5,
LAG(d, 2, d_default) OVER (PARTITION by a ORDER BY b,c) lag_d_2_default
FROM window_agg_table
'''

    assert_gpu_and_cpu_are_equal_sql(build_table, "window_agg_table", lead_lag_query)


# lead and lag don't currently work for string columns, so redo the tests, but just for strings
# without lead and lag
@ignore_order
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,27 +201,29 @@ class Spark311Shims extends Spark301Shims {
// Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lead](
"Window function that returns N entries ahead of this one",
ExprChecks.windowOnly(TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all,
Seq(ParamCheck("input", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all),
ExprChecks.windowOnly((TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY).nested(), TypeSig.all,
Seq(ParamCheck("input", (TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY).nested(), TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL, TypeSig.all))),
ParamCheck("default", (TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL + TypeSig.ARRAY).nested(),
TypeSig.all))),
(lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) {
override def convertToGpu(): GpuExpression =
GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}),
// Spark 3.1.1-specific LAG expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lag](
"Window function that returns N entries behind this one",
ExprChecks.windowOnly(TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all,
Seq(ParamCheck("input", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all),
ExprChecks.windowOnly((TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY).nested(), TypeSig.all,
Seq(ParamCheck("input", (TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY).nested(), TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL, TypeSig.all))),
ParamCheck("default", (TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL + TypeSig.ARRAY).nested(),
TypeSig.all))),
(lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) {
override def convertToGpu(): GpuExpression = {
GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -840,11 +840,13 @@ object GpuOverrides {
"\"window\") of rows",
ExprChecks.windowOnly(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT +
TypeSig.ARRAY),
TypeSig.all,
Seq(ParamCheck("windowFunction",
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT +
TypeSig.ARRAY),
TypeSig.all),
ParamCheck("windowSpec",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL,
Expand Down Expand Up @@ -897,26 +899,27 @@ object GpuOverrides {
}),
expr[Lead](
"Window function that returns N entries ahead of this one",
ExprChecks.windowOnly(TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all,
Seq(ParamCheck("input", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all),
ExprChecks.windowOnly((TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY).nested(), TypeSig.all,
Seq(ParamCheck("input", (TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY).nested(), TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL, TypeSig.all))),
ParamCheck("default", (TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL + TypeSig.ARRAY).nested(),
TypeSig.all))),
(lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) {
override def convertToGpu(): GpuExpression =
GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}),
expr[Lag](
"Window function that returns N entries behind this one",
ExprChecks.windowOnly(TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all,
Seq(ParamCheck("input", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all),
ExprChecks.windowOnly((TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY).nested(), TypeSig.all,
Seq(ParamCheck("input", (TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY).nested(), TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL, TypeSig.all))),
ParamCheck("default", (TypeSig.numeric + TypeSig.BOOLEAN + TypeSig.DATE +
TypeSig.TIMESTAMP + TypeSig.NULL + TypeSig.ARRAY).nested(), TypeSig.all))),
(lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) {
override def convertToGpu(): GpuExpression =
GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
Expand Down Expand Up @@ -2937,7 +2940,8 @@ object GpuOverrides {
ExecChecks(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL) +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT
+ TypeSig.ARRAY),
TypeSig.all),
(windowOp, conf, p, r) =>
new GpuWindowExecMeta(windowOp, conf, p, r)
Expand Down

0 comments on commit c7ebbaa

Please sign in to comment.