Skip to content

Commit

Permalink
support creating list ColumnVector for Literal(ArrayType(NullType)) (N…
Browse files Browse the repository at this point in the history
…VIDIA#2448)

Signed-off-by: Bobby Wang <wbo4958@gmail.com>
  • Loading branch information
wbo4958 authored May 19, 2021
1 parent aa128ea commit b589e2d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions integration_tests/src/main/python/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_make_array(data_gen):
(s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).selectExpr(
'array(null)',
'array(a, b)',
'array(b, a, null, {}, {})'.format(s1, s2)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ object GpuScalar extends Arm with Logging {
val colType = resolveElementType(elementType)
val rows = seq.map(convertElementTo(_, elementType))
ColumnVector.fromStructs(colType, rows.asInstanceOf[Seq[HostColumnVector.StructData]]: _*)
case NullType =>
GpuColumnVector.columnVectorFromNull(seq.size, NullType)
case u =>
throw new IllegalArgumentException(s"Unsupported element type ($u) to create a" +
s" ColumnVector.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ProjectExprSuite extends SparkQueryCompareTestSuite {
Array(StructField("id", IntegerType), StructField("name", StringType)))))),
new Column(Literal.create(List(BigDecimal(123L, 2), BigDecimal(-1444L, 2)),
ArrayType(DecimalType(10, 2)))))
.selectExpr("array(null)", "array(array(null))", "array()")
}

testSparkResultsAreEqual("project time", frameFromParquet("timestamp-date-test.parquet"),
Expand Down

0 comments on commit b589e2d

Please sign in to comment.