Implement support for ArrayExists expression #4973

Merged: 13 commits, Mar 21, 2022
1 change: 1 addition & 0 deletions docs/configs.md
@@ -148,6 +148,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.And"></a>spark.rapids.sql.expression.And|`and`|Logical AND|true|None|
<a name="sql.expression.AnsiCast"></a>spark.rapids.sql.expression.AnsiCast| |Convert a column of one type of data into another type|true|None|
<a name="sql.expression.ArrayContains"></a>spark.rapids.sql.expression.ArrayContains|`array_contains`|Returns a boolean if the array contains the passed in key|true|None|
<a name="sql.expression.ArrayExists"></a>spark.rapids.sql.expression.ArrayExists|`exists`|Return true if any element satisfies the predicate LambdaFunction|true|None|
<a name="sql.expression.ArrayMax"></a>spark.rapids.sql.expression.ArrayMax|`array_max`|Returns the maximum value in the array|true|None|
<a name="sql.expression.ArrayMin"></a>spark.rapids.sql.expression.ArrayMin|`array_min`|Returns the minimum value in the array|true|None|
<a name="sql.expression.ArrayTransform"></a>spark.rapids.sql.expression.ArrayTransform|`transform`|Transform elements in an array using the transform function. This is similar to a `map` in functional programming|true|None|
218 changes: 143 additions & 75 deletions docs/supported_ops.md
@@ -2006,12 +2006,12 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="2">ArrayMax</td>
<td rowSpan="2">`array_max`</td>
<td rowSpan="2">Returns the maximum value in the array</td>
<td rowSpan="2">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td rowSpan="3">ArrayExists</td>
<td rowSpan="3">`exists`</td>
<td rowSpan="3">Return true if any element satisfies the predicate LambdaFunction</td>
<td rowSpan="3">None</td>
<td rowSpan="3">project</td>
<td>argument</td>
<td> </td>
<td> </td>
<td> </td>
@@ -2026,31 +2026,52 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>function</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<th>Expression</th>
@@ -2079,6 +2100,53 @@ are limited.
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">ArrayMax</td>
<td rowSpan="2">`array_max`</td>
<td rowSpan="2">Returns the maximum value in the array</td>
<td rowSpan="2">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td rowSpan="2">ArrayMin</td>
<td rowSpan="2">`array_min`</td>
<td rowSpan="2">Returns the minimum value in the array</td>
@@ -2374,6 +2442,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">AtLeastNNonNulls</td>
<td rowSpan="2"> </td>
<td rowSpan="2">Checks if number of non null/Nan values is greater than a given value</td>
@@ -2421,32 +2515,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">Atan</td>
<td rowSpan="4">`atan`</td>
<td rowSpan="4">Inverse tangent</td>
@@ -2743,6 +2811,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">BitLength</td>
<td rowSpan="2">`bit_length`</td>
<td rowSpan="2">The bit length of string data</td>
@@ -2790,32 +2884,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">BitwiseAnd</td>
<td rowSpan="6">`&`</td>
<td rowSpan="6">Returns the bitwise AND of the operands</td>
25 changes: 25 additions & 0 deletions integration_tests/src/main/python/array_test.py
@@ -331,3 +331,28 @@ def test_get_array_struct_fields(data_gen):
        max_length=6)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, array_struct_gen).selectExpr('a.child0'))

@pytest.mark.parametrize('data_gen', [ArrayGen(string_gen), ArrayGen(int_gen)])
@pytest.mark.parametrize('threeVL', [
    pytest.param(False, id='3VL:off'),
    pytest.param(True, id='3VL:on'),
])
def test_array_exists(data_gen, threeVL):
    def do_it(spark):
        columns = ['a']
        element_type = data_gen.data_type.elementType
        if isinstance(element_type, IntegralType):
            columns.extend([
                'exists(a, item -> item % 2 = 0) as exists_even',
                'exists(a, item -> item < 0) as exists_negative',
                'exists(a, item -> item >= 0) as exists_non_negative'
            ])

        if isinstance(element_type, StringType):
            columns.extend(['exists(a, entry -> length(entry) > 5) as exists_longer_than_5'])

        return unary_op_df(spark, data_gen).selectExpr(columns)

    assert_gpu_and_cpu_are_equal_collect(do_it, conf= {
        'spark.sql.legacy.followThreeValuedLogicInArrayExists' : threeVL,
    })
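
Note on what the `threeVL` parameter exercises: the test above toggles `spark.sql.legacy.followThreeValuedLogicInArrayExists`, which changes what `exists` returns when the predicate evaluates to null for some element and to true for none. The pure-Python sketch below models that behaviour as I understand it from Spark's CPU implementation. It is an illustration only: `array_exists` and `is_even` are hypothetical names that do not appear in this PR or in Spark, and Python's `None` stands in for SQL NULL.

```python
from typing import Callable, Optional, Sequence, TypeVar

T = TypeVar("T")


def array_exists(
    arr: Optional[Sequence[Optional[T]]],
    pred: Callable[[Optional[T]], Optional[bool]],
    three_valued_logic: bool,
) -> Optional[bool]:
    """Hypothetical model of SQL exists(arr, pred); None plays the role of NULL."""
    if arr is None:
        # A null array yields a null result.
        return None
    found_null = False
    for item in arr:
        result = pred(item)
        if result is True:
            # Any element satisfying the predicate makes the result true.
            return True
        if result is None:
            found_null = True
    # No element was true. With three-valued logic on, an unknown (null)
    # predicate result makes the overall answer unknown; with it off, nulls
    # are treated like false.
    if three_valued_logic and found_null:
        return None
    return False


def is_even(item: Optional[int]) -> Optional[bool]:
    """Stand-in for `item -> item % 2 = 0`: a null input gives a null result."""
    return None if item is None else item % 2 == 0


print(array_exists([1, None, 3], is_even, three_valued_logic=True))   # None
print(array_exists([1, None, 3], is_even, three_valued_logic=False))  # False
print(array_exists([1, None, 4], is_even, three_valued_logic=True))   # True
```

The two settings only disagree on arrays that contain no matching element but at least one element for which the predicate is null, which is why the test needs generated arrays that include nulls to be meaningful.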
Original file line number Diff line number Diff line change
@@ -2857,6 +2857,25 @@ object GpuOverrides extends Logging {
          GpuArrayTransform(childExprs.head.convertToGpu(), childExprs(1).convertToGpu())
        }
      }),
    expr[ArrayExists](
      "Return true if any element satisfies the predicate LambdaFunction",
      ExprChecks.projectOnly(TypeSig.BOOLEAN, TypeSig.BOOLEAN,
        Seq(
          ParamCheck("argument",
            TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
              TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
            TypeSig.ARRAY.nested(TypeSig.all)),
          ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
      (in, conf, p, r) => new ExprMeta[ArrayExists](in, conf, p, r) {
        override def convertToGpu(): GpuExpression = {
          GpuArrayExists(
            childExprs.head.convertToGpu(),
            childExprs(1).convertToGpu(),
            SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC)
          )
        }
      }),

    expr[TransformKeys](
      "Transform keys in a map using a transform function",
      ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +