diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index 58d2b7f1042..78bd9b3b35a 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -54,6 +54,7 @@ def test_make_array(data_gen): (s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen)) assert_gpu_and_cpu_are_equal_collect( lambda spark : binary_op_df(spark, data_gen).selectExpr( + 'array(null)', 'array(a, b)', 'array(b, a, null, {}, {})'.format(s1, s2))) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index 10309bc4007..36eb4fc15c4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -176,6 +176,8 @@ object GpuScalar extends Arm with Logging { val colType = resolveElementType(elementType) val rows = seq.map(convertElementTo(_, elementType)) ColumnVector.fromStructs(colType, rows.asInstanceOf[Seq[HostColumnVector.StructData]]: _*) + case NullType => + GpuColumnVector.columnVectorFromNull(seq.size, NullType) case u => throw new IllegalArgumentException(s"Unsupported element type ($u) to create a" + s" ColumnVector.") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ProjectExprSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ProjectExprSuite.scala index 17fc45c9429..ca9667de9a8 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ProjectExprSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ProjectExprSuite.scala @@ -61,6 +61,7 @@ class ProjectExprSuite extends SparkQueryCompareTestSuite { Array(StructField("id", IntegerType), StructField("name", StringType)))))), new Column(Literal.create(List(BigDecimal(123L, 2), BigDecimal(-1444L, 2)), ArrayType(DecimalType(10, 2))))) + .selectExpr("array(null)", "array(array(null))", "array()") } testSparkResultsAreEqual("project time", frameFromParquet("timestamp-date-test.parquet"),