diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 9e96ab8a9b6ca..413d0af61a05c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -316,6 +316,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { (value, o.value) match { case (null, null) => true case (a: Array[Byte], b: Array[Byte]) => util.Arrays.equals(a, b) + case (a: ArrayBasedMapData, b: ArrayBasedMapData) => + a.keyArray == b.keyArray && a.valueArray == b.valueArray case (a, b) => a != null && a.equals(b) } case _ => false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index cdb83d3580f0a..38e32ff2518f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -471,4 +472,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } + + test("SPARK-33338: semanticEquals should handle static GetMapValue correctly") { + val keys = new Array[UTF8String](1) + val values = new Array[UTF8String](1) + keys(0) = UTF8String.fromString("key") + values(0) = UTF8String.fromString("value") + + val d1 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + val d2 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + val m1 = GetMapValue(Literal.create(d1, MapType(StringType, StringType)), Literal("a")) + val m2 = GetMapValue(Literal.create(d2, MapType(StringType, StringType)), Literal("a")) + + assert(m1.semanticEquals(m2)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0dd2a286772a5..cebbf9282f710 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3706,6 +3706,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-33338: GROUP BY using literal map should not fail") { + withTempDir { dir => + sql(s"CREATE TABLE t USING ORC LOCATION '${dir.toURI}' AS SELECT map('k1', 'v1') m, 'k1' k") + Seq( + "SELECT map('k1', 'v1')[k] FROM t GROUP BY 1", + "SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k]", + "SELECT map('k1', 'v1')[k] a FROM t GROUP BY a").foreach { statement => + checkAnswer(sql(statement), Row("v1")) + } + } + } } case class Foo(bar: Option[String])