Skip to content

Commit

Permalink
[SPARK-48364][SQL] Add AbstractMapType type casting and fix RaiseErro…
Browse files Browse the repository at this point in the history
…r parameter map to work with collated strings

### What changes were proposed in this pull request?
Following up on the introduction of AbstractMapType (apache#46458) and changes that introduce collation awareness for RaiseError expression (apache#46461), this PR should add the appropriate type casting rules for AbstractMapType.

### Why are the changes needed?
Fix the CI failure for the `Support RaiseError misc expression with collation` test when ANSI is off.

### Does this PR introduce _any_ user-facing change?
Yes, type casting is now allowed for map types with collated strings.

### How was this patch tested?
Extended suite `CollationSQLExpressionsANSIOffSuite` with ANSI disabled.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46661 from uros-db/fix-abstract-map.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and cloud-fan committed May 22, 2024
1 parent e04d3d7 commit 6be3560
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveS
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}

object CollationTypeCasts extends TypeCoercionRule {
override val transform: PartialFunction[Expression, Expression] = {
Expand Down Expand Up @@ -85,6 +85,11 @@ object CollationTypeCasts extends TypeCoercionRule {
private def extractStringType(dt: DataType): StringType = dt match {
case st: StringType => st
case ArrayType(et, _) => extractStringType(et)
case MapType(kt, vt, _) => if (hasStringType(kt)) {
extractStringType(kt)
} else {
extractStringType(vt)
}
}

/**
Expand All @@ -102,6 +107,14 @@ object CollationTypeCasts extends TypeCoercionRule {
case st: StringType if st.collationId != castType.collationId => castType
case ArrayType(arrType, nullable) =>
castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull
case MapType(keyType, valueType, nullable) =>
val newKeyType = castStringType(keyType, castType).getOrElse(keyType)
val newValueType = castStringType(valueType, castType).getOrElse(valueType)
if (newKeyType != keyType || newValueType != valueType) {
MapType(newKeyType, newValueType, nullable)
} else {
null
}
case _ => null
}
Option(ret)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractStringType, StringTypeAnyCollation}
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence

Expand Down Expand Up @@ -1048,6 +1048,15 @@ object TypeCoercion extends TypeCoercionBase {
}
}

case (MapType(fromKeyType, fromValueType, fn), AbstractMapType(toKeyType, toValueType)) =>
val newKeyType = implicitCast(fromKeyType, toKeyType).orNull
val newValueType = implicitCast(fromValueType, toValueType).orNull
if (newKeyType != null && newValueType != null) {
MapType(newKeyType, newValueType, fn)
} else {
null
}

case _ => null
}
Option(ret)
Expand Down Expand Up @@ -1080,10 +1089,10 @@ object TypeCoercion extends TypeCoercionBase {
/**
* Whether the data type contains StringType.
*/
@tailrec
def hasStringType(dt: DataType): Boolean = dt match {
case _: StringType => true
case ArrayType(et, _) => hasStringType(et)
case MapType(kt, vt, _) => hasStringType(kt) || hasStringType(vt)
// Add StructType if we support string promotion for struct fields in the future.
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType:
override def foldable: Boolean = false
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, MapType(StringType, StringType))
Seq(StringTypeAnyCollation, AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation))

override def left: Expression = errorClass
override def right: Expression = errorParms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import java.text.SimpleDateFormat

import scala.collection.immutable.Seq

import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -1636,3 +1636,9 @@ class CollationSQLExpressionsSuite

}
// scalastyle:on nonascii

class CollationSQLExpressionsANSIOffSuite extends CollationSQLExpressionsSuite {
override protected def sparkConf: SparkConf =
super.sparkConf.set(SQLConf.ANSI_ENABLED, false)

}
25 changes: 2 additions & 23 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}

class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
Expand Down Expand Up @@ -954,10 +955,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
parameters = Map(
"functionName" -> "`=`",
"dataType" -> toSQLType(MapType(
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
StringType
)),
"dataType" -> toSQLType(AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation)),
"sqlExpr" -> "\"(m = m)\""),
context = ExpectedContext(ctx, query.length - ctx.length, query.length - 1))
}
Expand Down Expand Up @@ -1010,25 +1008,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
|select map('a' collate utf8_binary_lcase, 1, 'b' collate utf8_binary_lcase, 2)
|['A' collate utf8_binary_lcase]
|""".stripMargin), Seq(Row(1)))
val ctx = "map('aaa' collate utf8_binary_lcase, 1, 'AAA' collate utf8_binary_lcase, 2)['AaA']"
val query = s"select $ctx"
checkError(
exception = intercept[AnalysisException](sql(query)),
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> "\"map(collate(aaa), 1, collate(AAA), 2)[AaA]\"",
"paramIndex" -> "second",
"inputSql" -> "\"AaA\"",
"inputType" -> toSQLType(StringType),
"requiredType" -> toSQLType(StringType(
CollationFactory.collationNameToId("UTF8_BINARY_LCASE")))
),
context = ExpectedContext(
fragment = ctx,
start = query.length - ctx.length,
stop = query.length - 1
)
)
}

test("window aggregates should respect collation") {
Expand Down

0 comments on commit 6be3560

Please sign in to comment.