Skip to content

Commit

Permalink
[SPARK-49608][SQL] Improve SQL function sort_array
Browse files Browse the repository at this point in the history
  • Loading branch information
panbingkun committed Sep 12, 2024
1 parent 8023504 commit 08141a6
Showing 1 changed file with 160 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import org.apache.spark.sql.util.SQLOpenHashSet
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
import org.apache.spark.util.SparkClassUtils

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
Expand Down Expand Up @@ -1024,6 +1025,126 @@ case class MapSort(base: Expression)
: MapSort = copy(base = newChild)
}

/**
 * Converts the `ArrayData` produced by `array` into a plain JVM array, exposed to the
 * rest of the plan as an [[ObjectType]].
 *
 * The runtime class of the resulting array is chosen from the Java representation of the
 * element type (see `CodeGenerator.javaType`), not from the Catalyst type itself, so that
 * types which are physically stored as primitives (e.g. dates stored as `int`) take the
 * same path in `dataType`, in interpreted evaluation and in generated code:
 *  - an unboxed primitive array (`int[]`, `long[]`, ...) when the element is a non-boolean
 *    primitive, the input array cannot contain nulls and `needBoxedPrimitive` is false;
 *  - an array of the boxed type (`Integer[]`, `Boolean[]`, ...) for the remaining
 *    primitive cases;
 *  - `Object[]` for every non-primitive element type.
 *
 * NOTE: When the data type of expression is `ArrayType`, and the expression is foldable,
 * the `ConstantFolding` can do constant folding optimization automatically,
 * (avoiding frequent calls to `ArrayData.to{XXX}Array()`).
 *
 * @param array expression that must evaluate to an `ArrayType` value
 * @param needBoxedPrimitive when true, primitive elements are always returned boxed, even
 *                           if the unboxed fast path would otherwise be applicable
 */
case class ToJavaArray(array: Expression, needBoxedPrimitive: Boolean = false)
  extends UnaryExpression
  with ImplicitCastInputTypes
  with NullIntolerant
  with QueryErrorsBase {

  override def checkInputDataTypes(): TypeCheckResult = array.dataType match {
    case ArrayType(_, _) =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(0),
          "requiredType" -> toSQLType(ArrayType),
          "inputSql" -> toSQLExpr(array),
          "inputType" -> toSQLType(array.dataType))
      )
  }

  override def inputTypes: Seq[AbstractDataType] = Seq(array.dataType)

  /**
   * The `ObjectType` wrapping the JVM array class this expression produces. The three
   * branches mirror the ones in `toJavaArray` and `toJavaArrayCodegen`.
   */
  override def dataType: DataType = {
    if (canPerformFast) {
      elementJavaType match {
        case CodeGenerator.JAVA_BYTE => ObjectType(classOf[Array[Byte]])
        case CodeGenerator.JAVA_SHORT => ObjectType(classOf[Array[Short]])
        case CodeGenerator.JAVA_INT => ObjectType(classOf[Array[Int]])
        case CodeGenerator.JAVA_LONG => ObjectType(classOf[Array[Long]])
        case CodeGenerator.JAVA_FLOAT => ObjectType(classOf[Array[Float]])
        case CodeGenerator.JAVA_DOUBLE => ObjectType(classOf[Array[Double]])
      }
    } else if (isPrimitiveType) {
      elementJavaType match {
        case CodeGenerator.JAVA_BOOLEAN => ObjectType(classOf[Array[java.lang.Boolean]])
        case CodeGenerator.JAVA_BYTE => ObjectType(classOf[Array[java.lang.Byte]])
        case CodeGenerator.JAVA_SHORT => ObjectType(classOf[Array[java.lang.Short]])
        case CodeGenerator.JAVA_INT => ObjectType(classOf[Array[java.lang.Integer]])
        case CodeGenerator.JAVA_LONG => ObjectType(classOf[Array[java.lang.Long]])
        case CodeGenerator.JAVA_FLOAT => ObjectType(classOf[Array[java.lang.Float]])
        case CodeGenerator.JAVA_DOUBLE => ObjectType(classOf[Array[java.lang.Double]])
      }
    } else {
      ObjectType(classOf[Array[Object]])
    }
  }

  override def child: Expression = array
  override def prettyName: String = "to_java_array"

  @transient lazy val elementType: DataType =
    array.dataType.asInstanceOf[ArrayType].elementType

  private def resultArrayElementNullable: Boolean =
    array.dataType.asInstanceOf[ArrayType].containsNull

  // Java-level representation of the element ("int", "long", ...). Every dispatch below
  // keys on this rather than on `elementType`, because `isPrimitiveType` is itself decided
  // from the Java type and is therefore also true for Catalyst types other than the seven
  // numeric/boolean ones; matching on the Catalyst type would raise a MatchError for them.
  private def elementJavaType: String = CodeGenerator.javaType(elementType)

  private def isPrimitiveType: Boolean = CodeGenerator.isPrimitiveType(elementType)

  // The unboxed path is only valid when no element can be null and the caller did not ask
  // for boxed values. Booleans are always routed to the boxed path.
  private def canPerformFast: Boolean = isPrimitiveType &&
    elementJavaType != CodeGenerator.JAVA_BOOLEAN &&
    !resultArrayElementNullable && !needBoxedPrimitive

  /** Interpreted conversion; `array` is the non-null `ArrayData` value of the child. */
  private def toJavaArray(array: Any): Any = {
    val arrayData = array.asInstanceOf[ArrayData]
    if (canPerformFast) {
      elementJavaType match {
        case CodeGenerator.JAVA_BYTE => arrayData.toByteArray()
        case CodeGenerator.JAVA_SHORT => arrayData.toShortArray()
        case CodeGenerator.JAVA_INT => arrayData.toIntArray()
        case CodeGenerator.JAVA_LONG => arrayData.toLongArray()
        case CodeGenerator.JAVA_FLOAT => arrayData.toFloatArray()
        case CodeGenerator.JAVA_DOUBLE => arrayData.toDoubleArray()
      }
    } else if (isPrimitiveType) {
      elementJavaType match {
        case CodeGenerator.JAVA_BOOLEAN => arrayData.toArray[java.lang.Boolean](BooleanType)
        case CodeGenerator.JAVA_BYTE => arrayData.toArray[java.lang.Byte](ByteType)
        case CodeGenerator.JAVA_SHORT => arrayData.toArray[java.lang.Short](ShortType)
        case CodeGenerator.JAVA_INT => arrayData.toArray[java.lang.Integer](IntegerType)
        case CodeGenerator.JAVA_LONG => arrayData.toArray[java.lang.Long](LongType)
        case CodeGenerator.JAVA_FLOAT => arrayData.toArray[java.lang.Float](FloatType)
        case CodeGenerator.JAVA_DOUBLE => arrayData.toArray[java.lang.Double](DoubleType)
      }
    } else {
      arrayData.toObjectArray(elementType)
    }
  }

  /**
   * Generated-code counterpart of `toJavaArray`.
   *
   * @param ctx   codegen context used to register reference objects
   * @param ev    the `ExprCode` whose `value` receives the resulting array
   * @param array name of the Java variable holding the input `ArrayData`
   * @return a single Java statement assigning the converted array to `ev.value`
   */
  private def toJavaArrayCodegen(
      ctx: CodegenContext,
      ev: ExprCode,
      array: String): String = {
    if (canPerformFast) {
      // The unboxed path needs no reference objects, so none are registered here.
      val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
      s"""${ev.value} = $array.to${primitiveTypeName}Array();"""
    } else if (isPrimitiveType) {
      val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType)
      val boxedJavaType = CodeGenerator.boxedType(elementType)
      // `toArray` needs a ClassTag at runtime to allocate an array of the boxed class.
      val classTagTerm = ctx.addReferenceObj("classTagTerm",
        ClassTag(SparkClassUtils.classForName(s"java.lang.$boxedJavaType")))
      s"""${ev.value} = ($boxedJavaType[]) $array.toArray($elementTypeTerm, $classTagTerm);"""
    } else {
      val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType)
      s"""${ev.value} = $array.toObjectArray($elementTypeTerm);"""
    }
  }

  override def nullSafeEval(array: Any): Any = {
    toJavaArray(array)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, array => toJavaArrayCodegen(ctx, ev, array))
  }

  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(array = newChild)
}

/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
Expand All @@ -1047,7 +1168,11 @@ case class MapSort(base: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ExpectsInputTypes with NullIntolerant with QueryErrorsBase {
extends BinaryExpression
with ExpectsInputTypes
with NullIntolerant
with RuntimeReplaceable
with QueryErrorsBase {

def this(e: Expression) = this(e, Literal(true))

Expand Down Expand Up @@ -1090,41 +1215,34 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
)
}

@transient private lazy val lt: Comparator[Any] = {
@transient private lazy val lt: Comparator[Any] = new Comparator[Any] with Serializable {
val ordering = base.dataType match {
case _ @ ArrayType(n, _) =>
PhysicalDataType.ordering(n)
}
(o1: Any, o2: Any) => {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
-1
} else if (o2 == null) {
1
} else {
ordering.compare(o1, o2)

override def compare(o1: Any, o2: Any): Int =
(o1, o2) match {
case (null, null) => 0
case (null, _) => -1
case (_, null) => 1
case _ => ordering.compare(o1, o2)
}
}
}

@transient private lazy val gt: Comparator[Any] = {
@transient private lazy val gt: Comparator[Any] = new Comparator[Any] with Serializable {
val ordering = base.dataType match {
case _ @ ArrayType(n, _) =>
PhysicalDataType.ordering(n)
}

(o1: Any, o2: Any) => {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
1
} else if (o2 == null) {
-1
} else {
ordering.compare(o2, o1)
override def compare(o1: Any, o2: Any): Int =
(o1, o2) match {
case (null, null) => 0
case (null, _) => 1
case (_, null) => -1
case _ => ordering.compare(o2, o1)
}
}
}

@transient lazy val elementType: DataType =
Expand All @@ -1133,88 +1251,28 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
private def resultArrayElementNullable: Boolean =
base.dataType.asInstanceOf[ArrayType].containsNull

private def sortEval(array: Any, ascending: Boolean): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementType != NullType) {
java.util.Arrays.sort(data, if (ascending) lt else gt)
}
new GenericArrayData(data.asInstanceOf[Array[Any]])
}
@transient private lazy val isPrimitiveType: Boolean = CodeGenerator.isPrimitiveType(elementType)
@transient private lazy val canPerformFastSort: Boolean = isPrimitiveType &&
elementType != BooleanType && !resultArrayElementNullable
@transient private lazy val ascending: Boolean = ascendingOrder.eval().asInstanceOf[Boolean]
@transient private lazy val comparatorObjectType = ObjectType(classOf[Comparator[Object]])

private def sortCodegen(
ctx: CodegenContext,
ev: ExprCode,
base: String,
order: String): String = {
val genericArrayData = classOf[GenericArrayData].getName
val unsafeArrayData = classOf[UnsafeArrayData].getName
val array = ctx.freshName("array")
val c = ctx.freshName("c")
if (elementType == NullType) {
s"${ev.value} = $base.copy();"
override def replacement: Expression = {
val toJavaArray = ToJavaArray(base, !ascending)
val (arguments, inputTypes) = if (canPerformFastSort && ascending) {
(Seq(toJavaArray), Seq(toJavaArray.dataType))
} else if (isPrimitiveType) {
(Seq(toJavaArray, Literal(ascending, BooleanType)), Seq(toJavaArray.dataType, BooleanType))
} else {
val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType)
val sortOrder = ctx.freshName("sortOrder")
val o1 = ctx.freshName("o1")
val o2 = ctx.freshName("o2")
val jt = CodeGenerator.javaType(elementType)
val comp = if (CodeGenerator.isPrimitiveType(elementType)) {
val bt = CodeGenerator.boxedType(elementType)
val v1 = ctx.freshName("v1")
val v2 = ctx.freshName("v2")
s"""
|$jt $v1 = (($bt) $o1).${jt}Value();
|$jt $v2 = (($bt) $o2).${jt}Value();
|int $c = ${ctx.genComp(elementType, v1, v2)};
""".stripMargin
} else {
s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};"
}
val canPerformFastSort = CodeGenerator.isPrimitiveType(elementType) &&
elementType != BooleanType && !resultArrayElementNullable
val nonNullPrimitiveAscendingSort = if (canPerformFastSort) {
val javaType = CodeGenerator.javaType(elementType)
val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
|if ($order) {
| $javaType[] $array = $base.to${primitiveTypeName}Array();
| java.util.Arrays.sort($array);
| ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array);
|} else
""".stripMargin
} else {
""
}
s"""
|$nonNullPrimitiveAscendingSort
|{
| Object[] $array = $base.toObjectArray($elementTypeTerm);
| final int $sortOrder = $order ? 1 : -1;
| java.util.Arrays.sort($array, new java.util.Comparator() {
| @Override public int compare(Object $o1, Object $o2) {
| if ($o1 == null && $o2 == null) {
| return 0;
| } else if ($o1 == null) {
| return -$sortOrder;
| } else if ($o2 == null) {
| return $sortOrder;
| }
| $comp
| return $sortOrder * $c;
| }
| });
| ${ev.value} = new $genericArrayData($array);
|}
""".stripMargin
(Seq(toJavaArray, Literal(if (ascending) lt else gt, comparatorObjectType)),
Seq(toJavaArray.dataType, comparatorObjectType))
}
}

override def nullSafeEval(array: Any, ascending: Any): Any = {
sortEval(array, ascending.asInstanceOf[Boolean])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order))
StaticInvoke(
classOf[SortArrayExpressionUtils],
dataType,
"sort",
arguments,
inputTypes)
}

override def prettyName: String = "sort_array"
Expand Down

0 comments on commit 08141a6

Please sign in to comment.