Added in basic support for scalar structs to named_struct #2509

Merged 2 commits on May 26, 2021
Changes from all commits
10 changes: 5 additions & 5 deletions docs/supported_ops.md
@@ -4110,9 +4110,9 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
@@ -4133,7 +4133,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td> </td>
</tr>
<tr>
@@ -9590,7 +9590,7 @@ Accelerator support is described below.
<td>S</td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
3 changes: 2 additions & 1 deletion integration_tests/src/main/python/data_gen.py
@@ -853,7 +853,8 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):

# Some struct gens, but not all because of nesting
struct_gens_sample = [all_basic_struct_gen,
StructGen([['child0', byte_gen]]),
StructGen([]),
StructGen([['child0', byte_gen], ['child1', all_basic_struct_gen]]),
StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]

simple_string_to_string_map_gen = MapGen(StringGen(pattern='key_[0-9]', nullable=False),
12 changes: 10 additions & 2 deletions integration_tests/src/main/python/struct_test.py
@@ -19,6 +19,14 @@
from data_gen import *
from pyspark.sql.types import *

def test_struct_scalar_project():
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : spark.range(2).selectExpr(
            "named_struct('1', 2, '3', 4) as i",
            "named_struct('a', 'b', 'c', 'd', 'e', named_struct()) as s",
            "named_struct('a', map('foo', 10, 'bar', 11), 'arr', array(1.0, 2.0, 3.0)) as st",
            "id"))

@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", float_gen]]),
StructGen([["first", short_gen], ["second", int_gen], ["third", long_gen]]),
StructGen([["first", double_gen], ["second", date_gen], ["third", timestamp_gen]]),
@@ -32,14 +40,14 @@ def test_struct_get_item(data_gen):
'a.third'))


@pytest.mark.parametrize('data_gen', all_basic_gens + [null_gen, decimal_gen_default, decimal_gen_scale_precision, simple_string_to_string_map_gen] + single_level_array_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', all_basic_gens + [null_gen, decimal_gen_default, decimal_gen_scale_precision] + single_level_array_gens + struct_gens_sample + map_gens_sample, ids=idfn)
def test_make_struct(data_gen):
    # Spark has no good way to create a map literal without the map function
    # so we are inserting one.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : binary_op_df(spark, data_gen).selectExpr(
            'struct(a, b)',
            'named_struct("foo", b, "m", map("a", "b"), "n", null, "bar", 5, "end", a)'),
            'named_struct("foo", b, "m", map("a", "b"), "n", null, "bar", 5, "other", named_struct("z", "z"), "end", a)'),
        conf = allow_negative_scale_of_decimal_conf)


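For reference, a minimal Scala equivalent of the new test_struct_scalar_project test above (not part of this PR's diff), handy when poking at scalar struct projection from a spark-shell session. The plugin config line is illustrative and assumes the RAPIDS Accelerator jar is already on the classpath.

import org.apache.spark.sql.SparkSession

// Project a few scalar struct literals, mirroring test_struct_scalar_project.
val spark = SparkSession.builder()
  .appName("scalar-struct-literals")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin") // enable the RAPIDS Accelerator
  .getOrCreate()

spark.range(2).selectExpr(
  "named_struct('1', 2, '3', 4) as i",
  "named_struct('a', 'b', 'c', 'd', 'e', named_struct()) as s",
  "named_struct('a', map('foo', 10, 'bar', 11), 'arr', array(1.0, 2.0, 3.0)) as st",
  "id").collect().foreach(println)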
@@ -774,8 +774,9 @@ object GpuOverrides {
"Holds a static value from the query",
ExprChecks.projectNotLambda(
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.CALENDAR
+ TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL
+ TypeSig.DECIMAL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
+ TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT)
.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT),
TypeSig.all),
(lit, conf, p, r) => new LiteralExprMeta(lit, conf, p, r)),
expr[Signum](
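A note on what the widened Literal signature above enables: struct-typed literals no longer force a CPU fallback just because STRUCT was absent from the supported top-level and nested types. A small illustrative snippet (not from the PR; assumes an existing SparkSession named spark with the plugin enabled):

import org.apache.spark.sql.functions.typedLit

// typedLit builds a struct<_1:int,_2:string> literal from a Scala tuple; with
// STRUCT now in Literal's type signature this literal is a candidate for GPU
// execution instead of pushing the whole projection back to the CPU.
val df = spark.range(2).withColumn("s", typedLit((1, "a")))
df.show()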
@@ -828,7 +828,7 @@ object CreateNamedStructCheck extends ExprChecks {
  val nameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
  val sparkNameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
  val valueSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
    TypeSig.ARRAY + TypeSig.MAP).nested()
    TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested()
  val sparkValueSig: TypeSig = TypeSig.all
  val resultSig: TypeSig = TypeSig.STRUCT.nested(valueSig)
  val sparkResultSig: TypeSig = TypeSig.STRUCT.nested(sparkValueSig)
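Since CreateNamedStructCheck's valueSig now includes TypeSig.STRUCT, named_struct values may themselves be structs. A usage sketch (not part of the diff; again assumes a SparkSession named spark with the plugin enabled):

// A named_struct whose value is another named_struct, which the relaxed check
// above now treats as a candidate for GPU execution rather than an automatic fallback.
val nested = spark.range(1).selectExpr(
  "named_struct('point', named_struct('x', 1.0, 'y', 2.0), 'label', 'origin') as feature")
nested.printSchema()
nested.show(false)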
29 changes: 25 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag

import ai.rapids.cudf.{ColumnVector, DType, HostColumnVector, Scalar}
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray
import org.json4s.JsonAST.{JField, JNull, JString}

import org.apache.spark.internal.Logging
@@ -209,10 +210,15 @@ object GpuScalar extends Arm with Logging {
*/
  def from(v: Any, t: DataType): Scalar = t match {
    case nullType if v == null => nullType match {
      case ArrayType(elementType, _) => Scalar.listFromNull(resolveElementType(elementType))
      case MapType(keyType, valueType, _) => Scalar.listFromNull(
        resolveElementType(StructType(
          Seq(StructField("key", keyType), StructField("value", valueType)))))
      case ArrayType(elementType, _) =>
        Scalar.listFromNull(resolveElementType(elementType))
      case StructType(fields) =>
        Scalar.structFromNull(
          fields.map(f => resolveElementType(f.dataType)): _*)
      case MapType(keyType, valueType, _) =>
        Scalar.listFromNull(
          resolveElementType(StructType(
            Seq(StructField("key", keyType), StructField("value", valueType)))))
      case _ => Scalar.fromNull(GpuColumnVector.getNonNestedRapidsType(nullType))
    }
    case decType: DecimalType =>
@@ -297,6 +303,19 @@
      case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
        s" for ArrayType, expecting ArrayData")
    }
    case StructType(fields) => v match {
      case row: InternalRow =>
        val cvs = fields.zipWithIndex.safeMap {
          case (f, i) =>
            val dt = f.dataType
            columnVectorFromLiterals(Seq(row.get(i, dt)), dt)
        }
        withResource(cvs) { cvs =>
          Scalar.structFromColumnViews(cvs: _*)
        }
      case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
        s" for StructType, expecting InternalRow")
    }
    case MapType(keyType, valueType, _) => v match {
      case map: MapData =>
        val struct = withResource(columnVectorFromLiterals(map.keyArray().array, keyType)) { keys =>
@@ -410,6 +429,8 @@ class GpuScalar private(
throw new IllegalArgumentException("Value should not be Scalar")
}

override def toString: String = s"GPU_SCALAR $dataType $value $scalar"

/**
* Gets the internal cudf Scalar of this GpuScalar.
*
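To make the new StructType branch in GpuScalar.from easier to follow, here is a hand-rolled sketch (not part of the PR) of the same idea using only the cudf Java calls the diff relies on: each field becomes a one-row column, and the columns are wrapped into a single struct Scalar. The plugin itself uses its safeMap/withResource Arm helpers instead of the explicit try/finally shown here.

import ai.rapids.cudf.{ColumnVector, Scalar}

val childA = ColumnVector.fromInts(1)      // field "a": INT32
val childB = ColumnVector.fromStrings("x") // field "b": STRING
try {
  // The child data is copied into the scalar, so the columns can be closed afterwards,
  // just as the withResource block in the diff does.
  val structScalar = Scalar.structFromColumnViews(childA, childB)
  try {
    println(structScalar) // one struct scalar holding {a: 1, b: "x"}
  } finally {
    structScalar.close()
  }
} finally {
  childA.close()
  childB.close()
}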