Skip to content

Commit

Permalink
[SPARK-33338][SQL] GROUP BY using literal map should not fail
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR aims to make `semanticEquals` work correctly on `GetMapValue` expressions whose literal maps are backed by `ArrayBasedMapData` and `GenericArrayData`.

### Why are the changes needed?

This is a regression from Apache Spark 1.6.x.
```scala
scala> sc.version
res1: String = 1.6.3

scala> sqlContext.sql("SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k]").show
+---+
|_c0|
+---+
| v1|
+---+
```

Apache Spark 2.x ~ 3.0.1 raise `RuntimeException` for the following queries.
```sql
CREATE TABLE t USING ORC AS SELECT map('k1', 'v1') m, 'k1' k
SELECT map('k1', 'v1')[k] FROM t GROUP BY 1
SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k]
SELECT map('k1', 'v1')[k] a FROM t GROUP BY a
```

**BEFORE**
```scala
Caused by: java.lang.RuntimeException: Couldn't find k#3 in [keys: [k1], values: [v1][k#3]#6]
	at scala.sys.package$.error(package.scala:27)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:85)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:79)
	at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
```

**AFTER**
```sql
spark-sql> SELECT map('k1', 'v1')[k] FROM t GROUP BY 1;
v1
Time taken: 1.278 seconds, Fetched 1 row(s)
spark-sql> SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k];
v1
Time taken: 0.313 seconds, Fetched 1 row(s)
spark-sql> SELECT map('k1', 'v1')[k] a FROM t GROUP BY a;
v1
Time taken: 0.265 seconds, Fetched 1 row(s)
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs with the newly added test case.

Closes #30246 from dongjoon-hyun/SPARK-33338.

Authored-by: Dongjoon Hyun <dhyun@apple.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
(cherry picked from commit 42c0b17)
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
dongjoon-hyun committed Nov 4, 2020
1 parent 6c5d008 commit 2bde026
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
(value, o.value) match {
case (null, null) => true
case (a: Array[Byte], b: Array[Byte]) => util.Arrays.equals(a, b)
case (a: ArrayBasedMapData, b: ArrayBasedMapData) =>
a.keyArray == b.keyArray && a.valueArray == b.valueArray
case (a, b) => a != null && a.equals(b)
}
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -396,4 +397,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}

test("SPARK-33338: semanticEquals should handle static GetMapValue correctly") {
  // Single-entry key/value arrays used as the contents of both map literals.
  val keys = Array(UTF8String.fromString("key"))
  val values = Array(UTF8String.fromString("value"))
  val mapType = MapType(StringType, StringType)

  // Builds a fresh GetMapValue over a newly allocated map literal on every call, so the
  // two expressions compared below never share the same ArrayBasedMapData instance.
  def lookup(): GetMapValue = {
    val data = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
    GetMapValue(Literal.create(data, mapType), Literal("a"))
  }

  val first = lookup()
  val second = lookup()
  assert(first.semanticEquals(second))
}
}
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3102,6 +3102,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
"JOIN ( SELECT * FROM B JOIN C USING (id)) USING (id)"), Row(0))
}
}

test("SPARK-33338: GROUP BY using literal map should not fail") {
  // Regression test: grouping by a lookup into a literal map must succeed whether the
  // grouping key is given as an ordinal, as the full expression, or as a select-list alias.
  withTempDir { dir =>
    // Drop `t` when the block exits so the table created below does not leak into
    // other tests that use the same catalog.
    withTable("t") {
      sql(s"CREATE TABLE t USING ORC LOCATION '${dir.toURI}' AS SELECT map('k1', 'v1') m, 'k1' k")
      Seq(
        "SELECT map('k1', 'v1')[k] FROM t GROUP BY 1",
        "SELECT map('k1', 'v1')[k] FROM t GROUP BY map('k1', 'v1')[k]",
        "SELECT map('k1', 'v1')[k] a FROM t GROUP BY a"
      ).foreach { statement =>
        // The table holds one row with k = 'k1', so every query yields the single value "v1".
        checkAnswer(sql(statement), Row("v1"))
      }
    }
  }
}
}

/** Case class with a single optional string field, `bar`. */
case class Foo(bar: Option[String])

0 comments on commit 2bde026

Please sign in to comment.