Fix No Matching column bug (#1162) (#1168)
marsishandsome authored Oct 24, 2019
1 parent 82878fa commit 2d38787
Showing 3 changed files with 146 additions and 60 deletions.
2 changes: 2 additions & 0 deletions core/src/main/scala/com/pingcap/tispark/TiConfigConst.scala
@@ -39,6 +39,8 @@ object TiConfigConst {
   val WRITE_ENABLE: String = "spark.tispark.write.enable"
   val WRITE_WITHOUT_LOCK_TABLE: String = "spark.tispark.write.without_lock_table"
   val TIKV_REGION_SPLIT_SIZE_IN_MB: String = "spark.tispark.tikv.region_split_size_in_mb"
+  val ALLOW_ORDER_PROJECT_LIMIT_PUSHDOWN: String =
+    "spark.tispark.plan.allow_order_project_limit_pushdown"
 
   val SNAPSHOT_ISOLATION_LEVEL: String = "SI"
   val READ_COMMITTED_ISOLATION_LEVEL: String = "RC"
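
For orientation, here is a minimal sketch of toggling the new key at runtime. It assumes an already-running TiSpark-enabled SparkSession named `spark`; the key constant and its "false" default are the ones introduced in this commit.

// Sketch only: enable the order/project/limit pushdown for one session.
// `spark` is assumed to be a TiSpark-enabled SparkSession.
import com.pingcap.tispark.TiConfigConst

spark.conf.set(TiConfigConst.ALLOW_ORDER_PROJECT_LIMIT_PUSHDOWN, "true")

// Equivalent static setting in spark-defaults.conf:
// spark.tispark.plan.allow_order_project_limit_pushdown true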
178 changes: 118 additions & 60 deletions core/src/main/scala/org/apache/spark/sql/TiStrategy.scala
@@ -88,6 +88,12 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
     new TypeBlacklist(blacklistString)
   }
 
+  private def allowOrderProjectLimitPushdown(): Boolean =
+    sqlConf
+      .getConfString(TiConfigConst.ALLOW_ORDER_PROJECT_LIMIT_PUSHDOWN, "false")
+      .toLowerCase
+      .toBoolean
+
   private def allowAggregationPushdown(): Boolean =
     sqlConf.getConfString(TiConfigConst.ALLOW_AGG_PUSHDOWN, "true").toLowerCase.toBoolean
 
@@ -501,66 +507,118 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
   // We go through logic similar to original Spark's in SparkStrategies.scala.
   // The difference is that we need to test whether a sub-plan can be consumed
   // entirely by TiKV; if so, we don't return (don't planLater) and plan the rest at once.
-  private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] =
+  private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] = {
     // TODO: This test should be done once for all children
-    plan match {
-      case logical.ReturnAnswer(rootPlan) =>
-        rootPlan match {
-          case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-            takeOrderedAndProject(limit, order, child, child.output) :: Nil
-          case logical.Limit(
-                IntegerLiteral(limit),
-                logical.Project(projectList, logical.Sort(order, true, child))
-              ) =>
-            takeOrderedAndProject(limit, order, child, projectList) :: Nil
-          case logical.Limit(IntegerLiteral(limit), child) =>
-            execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil
-          case other => planLater(other) :: Nil
-        }
-      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        takeOrderedAndProject(limit, order, child, child.output) :: Nil
-      case logical.Limit(
-            IntegerLiteral(limit),
-            logical.Project(projectList, logical.Sort(order, true, child))
-          ) =>
-        takeOrderedAndProject(limit, order, child, projectList) :: Nil
-      // Collapse filters and projections and push plan directly
-      case PhysicalOperation(
-            projectList,
-            filters,
-            LogicalRelation(source: TiDBRelation, _, _, _)
-          ) =>
-        pruneFilterProject(projectList, filters, source) :: Nil
-
-      // The basic logic of original Spark's aggregation planning is:
-      // the PhysicalAggregation extractor rewrites the original aggregation
-      // into aggregateExpressions and resultExpressions.
-      // resultExpressions contains only references [[AttributeReference]]
-      // to the results of the aggregation, and might contain projections
-      // like Add(sumResult, 1).
-      // For an aggregate like agg(expr) + 1, the rewrite of agg(expr) is:
-      // 1. pushdown: agg(expr) as agg1; if avg, then sum(expr) and count(expr)
-      // 2. residual expr (for Spark itself): agg(agg1) as finalAgg1, whose parameter
-      //    is a reference to the pushed plan's corresponding aggregation
-      // 3. resultExpressions: finalAgg1 + 1, where finalAgg1 references the final
-      //    result of the aggregation
-      case TiAggregation(
-            groupingExpressions,
-            aggregateExpressions,
-            resultExpressions,
-            TiAggregationProjection(filters, _, `source`, projects)
-          ) if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) =>
-        val projectSet = AttributeSet((projects ++ filters).flatMap { _.references })
-        val tiColumns = buildTiColumnRefFromColumnSeq(projectSet, source)
-        val dagReq: TiDAGRequest = filterToDAGRequest(tiColumns, filters, source)
-        groupAggregateProjection(
-          tiColumns,
-          groupingExpressions,
-          aggregateExpressions,
-          resultExpressions,
-          `source`,
-          dagReq
-        )
-      case _ => Nil
-    }
+    if (!allowOrderProjectLimitPushdown()) {
+      plan match {
+        case logical.ReturnAnswer(rootPlan) =>
+          rootPlan match {
+            case logical.Limit(IntegerLiteral(limit), child) =>
+              execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil
+            case other => planLater(other) :: Nil
+          }
+        // Collapse filters and projections and push plan directly
+        case PhysicalOperation(
+              projectList,
+              filters,
+              LogicalRelation(source: TiDBRelation, _, _, _)
+            ) =>
+          pruneFilterProject(projectList, filters, source) :: Nil
+
+        // The basic logic of original Spark's aggregation planning is:
+        // the PhysicalAggregation extractor rewrites the original aggregation
+        // into aggregateExpressions and resultExpressions.
+        // resultExpressions contains only references [[AttributeReference]]
+        // to the results of the aggregation, and might contain projections
+        // like Add(sumResult, 1).
+        // For an aggregate like agg(expr) + 1, the rewrite of agg(expr) is:
+        // 1. pushdown: agg(expr) as agg1; if avg, then sum(expr) and count(expr)
+        // 2. residual expr (for Spark itself): agg(agg1) as finalAgg1, whose parameter
+        //    is a reference to the pushed plan's corresponding aggregation
+        // 3. resultExpressions: finalAgg1 + 1, where finalAgg1 references the final
+        //    result of the aggregation
+        case TiAggregation(
+              groupingExpressions,
+              aggregateExpressions,
+              resultExpressions,
+              TiAggregationProjection(filters, _, `source`, projects)
+            ) if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) =>
+          val projectSet = AttributeSet((projects ++ filters).flatMap { _.references })
+          val tiColumns = buildTiColumnRefFromColumnSeq(projectSet, source)
+          val dagReq: TiDAGRequest = filterToDAGRequest(tiColumns, filters, source)
+          groupAggregateProjection(
+            tiColumns,
+            groupingExpressions,
+            aggregateExpressions,
+            resultExpressions,
+            `source`,
+            dagReq
+          )
+        case _ => Nil
+      }
+    } else {
+      plan match {
+        case logical.ReturnAnswer(rootPlan) =>
+          rootPlan match {
+            case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+              takeOrderedAndProject(limit, order, child, child.output) :: Nil
+            case logical.Limit(
+                  IntegerLiteral(limit),
+                  logical.Project(projectList, logical.Sort(order, true, child))
+                ) =>
+              takeOrderedAndProject(limit, order, child, projectList) :: Nil
+            case logical.Limit(IntegerLiteral(limit), child) =>
+              execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil
+            case other => planLater(other) :: Nil
+          }
+        case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+          takeOrderedAndProject(limit, order, child, child.output) :: Nil
+        case logical.Limit(
+              IntegerLiteral(limit),
+              logical.Project(projectList, logical.Sort(order, true, child))
+            ) =>
+          takeOrderedAndProject(limit, order, child, projectList) :: Nil
+        // Collapse filters and projections and push plan directly
+        case PhysicalOperation(
+              projectList,
+              filters,
+              LogicalRelation(source: TiDBRelation, _, _, _)
+            ) =>
+          pruneFilterProject(projectList, filters, source) :: Nil
+
+        // The basic logic of original Spark's aggregation planning is:
+        // the PhysicalAggregation extractor rewrites the original aggregation
+        // into aggregateExpressions and resultExpressions.
+        // resultExpressions contains only references [[AttributeReference]]
+        // to the results of the aggregation, and might contain projections
+        // like Add(sumResult, 1).
+        // For an aggregate like agg(expr) + 1, the rewrite of agg(expr) is:
+        // 1. pushdown: agg(expr) as agg1; if avg, then sum(expr) and count(expr)
+        // 2. residual expr (for Spark itself): agg(agg1) as finalAgg1, whose parameter
+        //    is a reference to the pushed plan's corresponding aggregation
+        // 3. resultExpressions: finalAgg1 + 1, where finalAgg1 references the final
+        //    result of the aggregation
+        case TiAggregation(
+              groupingExpressions,
+              aggregateExpressions,
+              resultExpressions,
+              TiAggregationProjection(filters, _, `source`, projects)
+            ) if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) =>
+          val projectSet = AttributeSet((projects ++ filters).flatMap {
+            _.references
+          })
+          val tiColumns = buildTiColumnRefFromColumnSeq(projectSet, source)
+          val dagReq: TiDAGRequest = filterToDAGRequest(tiColumns, filters, source)
+          groupAggregateProjection(
+            tiColumns,
+            groupingExpressions,
+            aggregateExpressions,
+            resultExpressions,
+            `source`,
+            dagReq
+          )
+        case _ => Nil
+      }
+    }
+  }
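
Read in isolation, the change above boils down to gating which planner cases are even attempted behind one boolean. Below is a self-contained toy sketch of that dispatch shape; every type and name in it is hypothetical, standing in for Spark's logical plan nodes and TiStrategy's helpers, and only the control flow mirrors the patch.

// Toy model of the gated dispatch above. All types and names are hypothetical;
// only the control flow mirrors the commit.
object GatedPlannerSketch {
  sealed trait Plan
  final case class Scan(table: String) extends Plan
  final case class Sort(child: Plan) extends Plan
  final case class Limit(n: Int, child: Plan) extends Plan

  def doPlan(plan: Plan, allowOrderLimitPushdown: Boolean): String =
    if (!allowOrderLimitPushdown) {
      // Conservative branch (the new default): never fuse a Sort into the
      // limit; fall back to the plain collect-limit path.
      plan match {
        case Limit(n, _) => s"CollectLimit($n)"
        case _           => "planLater"
      }
    } else {
      // Opt-in branch: same behavior as before the fix.
      plan match {
        case Limit(n, Sort(_)) => s"TakeOrderedAndProject($n)"
        case Limit(n, _)       => s"CollectLimit($n)"
        case _                 => "planLater"
      }
    }

  def main(args: Array[String]): Unit = {
    val p = Limit(1, Sort(Scan("t_no_match_column")))
    println(doPlan(p, allowOrderLimitPushdown = false)) // CollectLimit(1)
    println(doPlan(p, allowOrderLimitPushdown = true))  // TakeOrderedAndProject(1)
  }
}

With the flag off (the new default), a Limit over a Sort is never fused into a single pushed-down operator; that Limit-over-Sort plan shape is exactly what the regression test below exercises.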
26 changes: 26 additions & 0 deletions core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala
@@ -19,6 +19,31 @@ import com.pingcap.tispark.TiConfigConst
 import org.apache.spark.sql.functions.{col, sum}
 
 class IssueTestSuite extends BaseTiSparkTest {
+  // https://github.com/pingcap/tispark/issues/1161
+  test("No Match Column") {
+    tidbStmt.execute("DROP TABLE IF EXISTS t_no_match_column")
+    tidbStmt.execute("""
+        |CREATE TABLE `t_no_match_column` (
+        |  `json` text DEFAULT NULL
+        |)
+      """.stripMargin)
+
+    spark.sql("""
+        |select temp44.id from (
+        |  SELECT get_json_object(t.json, '$.id') as id from t_no_match_column t
+        |) as temp44 order by temp44.id desc
+      """.stripMargin).explain(true)
+
+    spark.sql("""
+        |select temp44.id from (
+        |  SELECT get_json_object(t.json, '$.id') as id from t_no_match_column t
+        |) as temp44 order by temp44.id desc limit 1
+      """.stripMargin).explain(true)
+
+    judge("select tp_int from full_data_type_table order by tp_int limit 20")
+    judge("select tp_int g from full_data_type_table order by g limit 20")
+  }
+
   // https://github.com/pingcap/tispark/issues/1147
   test("Fix Bit Type default value bug") {
     def runBitTest(bitNumber: Int, bitString: String): Unit = {
@@ -390,6 +415,7 @@ class IssueTestSuite extends BaseTiSparkTest {
       tidbStmt.execute("drop table if exists table_group_by_bigint")
       tidbStmt.execute("drop table if exists catalog_delay")
       tidbStmt.execute("drop table if exists t_origin_default_value")
+      tidbStmt.execute("drop table if exists t_no_match_column")
     } finally {
       super.afterAll()
     }
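
For completeness, the issue-1161 query can also be replayed by hand outside the harness. A hedged sketch, assuming a TiSpark-enabled SparkSession `spark` whose catalog already resolves the database containing t_no_match_column, created as in the test above:

// Sketch: replaying the issue-1161 query manually. Assumes `spark` is a
// TiSpark-enabled SparkSession pointed at the database created above.
val df = spark.sql(
  """
    |select temp44.id from (
    |  SELECT get_json_object(t.json, '$.id') as id from t_no_match_column t
    |) as temp44 order by temp44.id desc limit 1
  """.stripMargin)

df.explain(true) // before this commit, planning this query triggered the
                 // "No Matching column" error; with the fix it plans cleanly
df.show()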
