From 08141a6e51c2df0e96d618ad8faf028044de385e Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 12 Sep 2024 17:10:10 +0800 Subject: [PATCH] [SPARK-49608][SQL] Improve SQL function `sort_array` --- .../expressions/collectionOperations.scala | 262 +++++++++++------- 1 file changed, 160 insertions(+), 102 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5d5aece35383e..150f5d120bbcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.util.SQLOpenHashSet import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String} +import org.apache.spark.util.SparkClassUtils /** * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit @@ -1024,6 +1025,126 @@ case class MapSort(base: Expression) : MapSort = copy(base = newChild) } +/** + * This expression converts data of `ArrayData` to an array of java type. + * + * NOTE: When the data type of expression is `ArrayType`, and the expression is foldable, + * the `ConstantFolding` can do constant folding optimization automatically, + * (avoiding frequent calls to `ArrayData.to{XXX}Array()`). + */ +case class ToJavaArray(array: Expression, needBoxedPrimitive: Boolean = false) + extends UnaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with QueryErrorsBase { + + override def checkInputDataTypes(): TypeCheckResult = array.dataType match { + case ArrayType(_, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(0), + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(array), + "inputType" -> toSQLType(array.dataType)) + ) + } + + override def inputTypes: Seq[AbstractDataType] = Seq(array.dataType) + override def dataType: DataType = { + if (canPerformFast) { + elementType match { + case ByteType => ObjectType(classOf[Array[Byte]]) + case ShortType => ObjectType(classOf[Array[Short]]) + case IntegerType => ObjectType(classOf[Array[Int]]) + case LongType => ObjectType(classOf[Array[Long]]) + case FloatType => ObjectType(classOf[Array[Float]]) + case DoubleType => ObjectType(classOf[Array[Double]]) + } + } else if (isPrimitiveType) { + elementType match { + case BooleanType => ObjectType(classOf[Array[java.lang.Boolean]]) + case ByteType => ObjectType(classOf[Array[java.lang.Byte]]) + case ShortType => ObjectType(classOf[Array[java.lang.Short]]) + case IntegerType => ObjectType(classOf[Array[java.lang.Integer]]) + case LongType => ObjectType(classOf[Array[java.lang.Long]]) + case FloatType => ObjectType(classOf[Array[java.lang.Float]]) + case DoubleType => ObjectType(classOf[Array[java.lang.Double]]) + } + } else { + ObjectType(classOf[Array[Object]]) + } + } + + override def child: Expression = array + override def prettyName: String = "to_java_array" + + @transient lazy val elementType: DataType = + array.dataType.asInstanceOf[ArrayType].elementType + private def resultArrayElementNullable: Boolean = + array.dataType.asInstanceOf[ArrayType].containsNull + private def isPrimitiveType: Boolean = CodeGenerator.isPrimitiveType(elementType) + private def canPerformFast = isPrimitiveType && elementType != BooleanType && + !resultArrayElementNullable && !needBoxedPrimitive + + private def toJavaArray(array: Any): Any = { + val arrayData = array.asInstanceOf[ArrayData] + if (canPerformFast) { + elementType match { + case ByteType => arrayData.toByteArray() + case ShortType => arrayData.toShortArray() + case IntegerType => arrayData.toIntArray() + case LongType => arrayData.toLongArray() + case FloatType => arrayData.toFloatArray() + case DoubleType => arrayData.toDoubleArray() + } + } else if (isPrimitiveType) { + elementType match { + case BooleanType => arrayData.toArray[java.lang.Boolean](BooleanType) + case ByteType => arrayData.toArray[java.lang.Byte](ByteType) + case ShortType => arrayData.toArray[java.lang.Short](ShortType) + case IntegerType => arrayData.toArray[java.lang.Integer](IntegerType) + case LongType => arrayData.toArray[java.lang.Long](LongType) + case FloatType => arrayData.toArray[java.lang.Float](FloatType) + case DoubleType => arrayData.toArray[java.lang.Double](DoubleType) + } + } else { + arrayData.toObjectArray(elementType) + } + } + + private def toJavaArrayCodegen( + ctx: CodegenContext, + ev: ExprCode, + array: String): String = { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + if (canPerformFast) { + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s"""${ev.value} = $array.to${primitiveTypeName}Array();""" + } else if (isPrimitiveType) { + val boxedJavaType = CodeGenerator.boxedType(elementType) + val classTagTerm = ctx.addReferenceObj("classTagTerm", + ClassTag(SparkClassUtils.classForName(s"java.lang.$boxedJavaType"))) + s"""${ev.value} = ($boxedJavaType[]) $array.toArray($elementTypeTerm, $classTagTerm);""" + } else { + s"""${ev.value} = $array.toObjectArray($elementTypeTerm);""" + } + } + + override def nullSafeEval(array: Any): Any = { + toJavaArray(array) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, array => toJavaArrayCodegen(ctx, ev, array)) + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(array = newChild) +} + /** * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. @@ -1047,7 +1168,11 @@ case class MapSort(base: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with NullIntolerant with QueryErrorsBase { + extends BinaryExpression + with ExpectsInputTypes + with NullIntolerant + with RuntimeReplaceable + with QueryErrorsBase { def this(e: Expression) = this(e, Literal(true)) @@ -1090,41 +1215,34 @@ case class SortArray(base: Expression, ascendingOrder: Expression) ) } - @transient private lazy val lt: Comparator[Any] = { + @transient private lazy val lt: Comparator[Any] = new Comparator[Any] with Serializable { val ordering = base.dataType match { case _ @ ArrayType(n, _) => PhysicalDataType.ordering(n) } - (o1: Any, o2: Any) => { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - -1 - } else if (o2 == null) { - 1 - } else { - ordering.compare(o1, o2) + + override def compare(o1: Any, o2: Any): Int = + (o1, o2) match { + case (null, null) => 0 + case (null, _) => -1 + case (_, null) => 1 + case _ => ordering.compare(o1, o2) } - } } - @transient private lazy val gt: Comparator[Any] = { + @transient private lazy val gt: Comparator[Any] = new Comparator[Any] with Serializable { val ordering = base.dataType match { case _ @ ArrayType(n, _) => PhysicalDataType.ordering(n) } - (o1: Any, o2: Any) => { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - 1 - } else if (o2 == null) { - -1 - } else { - ordering.compare(o2, o1) + override def compare(o1: Any, o2: Any): Int = + (o1, o2) match { + case (null, null) => 0 + case (null, _) => 1 + case (_, null) => -1 + case _ => ordering.compare(o2, o1) } - } } @transient lazy val elementType: DataType = @@ -1133,88 +1251,28 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private def resultArrayElementNullable: Boolean = base.dataType.asInstanceOf[ArrayType].containsNull - private def sortEval(array: Any, ascending: Boolean): Any = { - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementType != NullType) { - java.util.Arrays.sort(data, if (ascending) lt else gt) - } - new GenericArrayData(data.asInstanceOf[Array[Any]]) - } + @transient private lazy val isPrimitiveType: Boolean = CodeGenerator.isPrimitiveType(elementType) + @transient private lazy val canPerformFastSort: Boolean = isPrimitiveType && + elementType != BooleanType && !resultArrayElementNullable + @transient private lazy val ascending: Boolean = ascendingOrder.eval().asInstanceOf[Boolean] + @transient private lazy val comparatorObjectType = ObjectType(classOf[Comparator[Object]]) - private def sortCodegen( - ctx: CodegenContext, - ev: ExprCode, - base: String, - order: String): String = { - val genericArrayData = classOf[GenericArrayData].getName - val unsafeArrayData = classOf[UnsafeArrayData].getName - val array = ctx.freshName("array") - val c = ctx.freshName("c") - if (elementType == NullType) { - s"${ev.value} = $base.copy();" + override def replacement: Expression = { + val toJavaArray = ToJavaArray(base, !ascending) + val (arguments, inputTypes) = if (canPerformFastSort && ascending) { + (Seq(toJavaArray), Seq(toJavaArray.dataType)) + } else if (isPrimitiveType) { + (Seq(toJavaArray, Literal(ascending, BooleanType)), Seq(toJavaArray.dataType, BooleanType)) } else { - val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) - val sortOrder = ctx.freshName("sortOrder") - val o1 = ctx.freshName("o1") - val o2 = ctx.freshName("o2") - val jt = CodeGenerator.javaType(elementType) - val comp = if (CodeGenerator.isPrimitiveType(elementType)) { - val bt = CodeGenerator.boxedType(elementType) - val v1 = ctx.freshName("v1") - val v2 = ctx.freshName("v2") - s""" - |$jt $v1 = (($bt) $o1).${jt}Value(); - |$jt $v2 = (($bt) $o2).${jt}Value(); - |int $c = ${ctx.genComp(elementType, v1, v2)}; - """.stripMargin - } else { - s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" - } - val canPerformFastSort = CodeGenerator.isPrimitiveType(elementType) && - elementType != BooleanType && !resultArrayElementNullable - val nonNullPrimitiveAscendingSort = if (canPerformFastSort) { - val javaType = CodeGenerator.javaType(elementType) - val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |if ($order) { - | $javaType[] $array = $base.to${primitiveTypeName}Array(); - | java.util.Arrays.sort($array); - | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); - |} else - """.stripMargin - } else { - "" - } - s""" - |$nonNullPrimitiveAscendingSort - |{ - | Object[] $array = $base.toObjectArray($elementTypeTerm); - | final int $sortOrder = $order ? 1 : -1; - | java.util.Arrays.sort($array, new java.util.Comparator() { - | @Override public int compare(Object $o1, Object $o2) { - | if ($o1 == null && $o2 == null) { - | return 0; - | } else if ($o1 == null) { - | return -$sortOrder; - | } else if ($o2 == null) { - | return $sortOrder; - | } - | $comp - | return $sortOrder * $c; - | } - | }); - | ${ev.value} = new $genericArrayData($array); - |} - """.stripMargin + (Seq(toJavaArray, Literal(if (ascending) lt else gt, comparatorObjectType)), + Seq(toJavaArray.dataType, comparatorObjectType)) } - } - - override def nullSafeEval(array: Any, ascending: Any): Any = { - sortEval(array, ascending.asInstanceOf[Boolean]) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + StaticInvoke( + classOf[SortArrayExpressionUtils], + dataType, + "sort", + arguments, + inputTypes) } override def prettyName: String = "sort_array"