Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GpuIf and GpuCoalesce support array and struct types #2839

Merged
merged 15 commits into from
Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -3391,9 +3391,9 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -3412,9 +3412,9 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -7594,7 +7594,7 @@ Accelerator support is described below.
<td rowSpan="8">None</td>
<td rowSpan="4">project</td>
<td>predicate</td>
<td><em>PS (literal values are not supported)</em></td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
Expand Down Expand Up @@ -7629,9 +7629,9 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -7650,9 +7650,9 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -7671,9 +7671,9 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
30 changes: 26 additions & 4 deletions integration_tests/src/main/python/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,24 @@

all_gens = all_gen + [NullGen()]
all_nested_gens = array_gens_sample + struct_gens_sample
all_nested_gens_nonempty_struct = array_gens_sample + nonempty_struct_gens_sample

@pytest.mark.parametrize('data_gen', all_gens, ids=idfn)
# Create dedicated data gens of nested type for 'if' tests here with two exclusions:
# 1) Excludes the nested 'NullGen' because it seems to be impossible to convert the
# 'NullType' to a SQL type string. But the top level NullGen is handled specially
# in 'gen_scalars_for_sql'.
# 2) Excludes the empty struct gen 'Struct()' because it leads to an error as below
# in both cpu and gpu runs.
# E: java.lang.AssertionError: assertion failed: each serializer expression should contain\
# at least one `BoundReference`
if_array_gens_sample = [ArrayGen(sub_gen) for sub_gen in all_gen] + nested_array_gens_sample
if_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(all_gen)])
if_struct_gens_sample = [if_struct_gen,
StructGen([['child0', byte_gen], ['child1', if_struct_gen]]),
StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]
if_nested_gens = if_array_gens_sample + if_struct_gens_sample

@pytest.mark.parametrize('data_gen', all_gens + if_nested_gens, ids=idfn)
def test_if_else(data_gen):
(s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
null_lit = get_null_lit_string(data_gen.data_type)
Expand All @@ -37,7 +53,8 @@ def test_if_else(data_gen):
'IF(a, b, {})'.format(s2),
'IF(a, {}, {})'.format(s1, s2),
'IF(a, b, {})'.format(null_lit),
'IF(a, {}, c)'.format(null_lit)))
'IF(a, {}, c)'.format(null_lit)),
conf = allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', all_gens + all_nested_gens, ids=idfn)
def test_case_when(data_gen):
Expand Down Expand Up @@ -94,7 +111,11 @@ def test_nvl(data_gen):
'nvl(a, {})'.format(null_lit)))

#nvl is translated into a 2 param version of coalesce
@pytest.mark.parametrize('data_gen', all_gens, ids=idfn)
# Exclude the empty struct gen 'Struct()' because it leads to an error as below
# in both cpu and gpu runs.
# E: java.lang.AssertionError: assertion failed: each serializer expression should contain\
# at least one `BoundReference`
@pytest.mark.parametrize('data_gen', all_gens + all_nested_gens_nonempty_struct, ids=idfn)
def test_coalesce(data_gen):
num_cols = 20
s1 = gen_scalar(data_gen, force_no_nulls=not isinstance(data_gen, NullGen))
Expand All @@ -106,7 +127,8 @@ def test_coalesce(data_gen):
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : gen_df(spark, gen).select(
f.coalesce(*command_args)))
f.coalesce(*command_args)),
conf = allow_negative_scale_of_decimal_conf)

def test_coalesce_constant_output():
# Coalesce can allow a constant value as output. Technically Spark should mark this
Expand Down
35 changes: 28 additions & 7 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,11 @@ def to_cast_string(spark_type):
return 'STRING'
elif isinstance(spark_type, DecimalType):
return 'DECIMAL({}, {})'.format(spark_type.precision, spark_type.scale)
elif isinstance(spark_type, ArrayType):
return 'ARRAY<{}>'.format(to_cast_string(spark_type.elementType))
elif isinstance(spark_type, StructType):
children = [fd.name + ':' + to_cast_string(fd.dataType) for fd in spark_type.fields]
return 'STRUCT<{}>'.format(','.join(children))
else:
raise RuntimeError('CAST TO TYPE {} NOT SUPPORTED YET'.format(spark_type))

Expand All @@ -773,26 +778,41 @@ def get_null_lit_string(spark_type):
string_type = to_cast_string(spark_type)
return 'CAST(null as {})'.format(string_type)

def _convert_to_sql(t, data):
def _convert_to_sql(spark_type, data):
if isinstance(data, str):
d = "'" + data.replace("'", "\\'") + "'"
elif isinstance(data, datetime):
d = "'" + data.strftime('%Y-%m-%d T%H:%M:%S.%f').zfill(26) + "'"
elif isinstance(data, date):
d = "'" + data.strftime('%Y-%m-%d').zfill(10) + "'"
elif isinstance(data, list):
assert isinstance(spark_type, ArrayType)
d = "array({})".format(",".join([_convert_to_sql(spark_type.elementType, x) for x in data]))
elif isinstance(data, tuple):
assert isinstance(spark_type, StructType) and len(data) == len(spark_type.fields)
# Format of each child: 'name',data
children = ["'{}'".format(fd.name) + ',' + _convert_to_sql(fd.dataType, x)
for fd, x in zip(spark_type.fields, data)]
d = "named_struct({})".format(','.join(children))
elif not data:
# data is None
d = "null"
else:
d = str(data)
d = "'{}'".format(str(data))

return 'CAST({} as {})'.format(d, t)
if isinstance(spark_type, NullType):
return d
else:
return 'CAST({} as {})'.format(d, to_cast_string(spark_type))

def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
"""Generate scalar values, but strings that can be used in selectExpr or SQL"""
src = _gen_scalars_common(data_gen, count, seed=seed)
if isinstance(data_gen, NullGen):
assert not force_no_nulls
return ('null' for i in range(0, count))
string_type = to_cast_string(data_gen.data_type)
return (_convert_to_sql(string_type, src.gen(force_no_nulls=force_no_nulls)) for i in range(0, count))
spark_type = data_gen.data_type
return (_convert_to_sql(spark_type, src.gen(force_no_nulls=force_no_nulls)) for i in range(0, count))

byte_gen = ByteGen()
short_gen = ShortGen()
Expand Down Expand Up @@ -865,11 +885,12 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
all_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(all_basic_gens)])

# Some struct gens, but not all because of nesting
struct_gens_sample = [all_basic_struct_gen,
StructGen([]),
nonempty_struct_gens_sample = [all_basic_struct_gen,
StructGen([['child0', byte_gen], ['child1', all_basic_struct_gen]]),
StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]

struct_gens_sample = nonempty_struct_gens_sample + [StructGen([])]

simple_string_to_string_map_gen = MapGen(StringGen(pattern='key_[0-9]', nullable=False),
StringGen(), max_length=10)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1202,9 +1202,10 @@ object GpuOverrides {
expr[Coalesce] (
"Returns the first non-null argument if exists. Otherwise, null",
ExprChecks.projectNotLambda(
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all,
(_commonTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
repeatingParamCheck = Some(RepeatingParamCheck("param",
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
(_commonTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all))),
(a, conf, p, r) => new ExprMeta[Coalesce](a, conf, p, r) {
override def convertToGpu(): GpuExpression = GpuCoalesce(childExprs.map(_.convertToGpu()))
Expand Down Expand Up @@ -1738,13 +1739,15 @@ object GpuOverrides {
}),
expr[If](
"IF expression",
ExprChecks.projectNotLambda(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
ExprChecks.projectNotLambda(
(_commonTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(ParamCheck("predicate", TypeSig.psNote(TypeEnum.BOOLEAN,
"literal values are not supported"), TypeSig.BOOLEAN),
ParamCheck("trueValue", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
Seq(ParamCheck("predicate", TypeSig.BOOLEAN, TypeSig.BOOLEAN),
ParamCheck("trueValue",
(_commonTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("falseValue", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
ParamCheck("falseValue",
(_commonTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all))),
(a, conf, p, r) => new ExprMeta[If](a, conf, p, r) {
override def convertToGpu(): GpuExpression = {
Expand Down