add config spark.tispark.plan.allow_order_project_limit_pushdown

marsishandsome committed Oct 22, 2019
1 parent 8092e6e commit 93f3c2e

Showing 2 changed files with 135 additions and 75 deletions.
core/src/main/scala/com/pingcap/tispark/TiConfigConst.scala (2 additions, 0 deletions)

@@ -37,6 +37,8 @@ object TiConfigConst {
   val CACHE_EXPIRE_AFTER_ACCESS: String = "spark.tispark.statistics.expire_after_access"
   val SHOW_ROWID: String = "spark.tispark.show_rowid"
   val DB_PREFIX: String = "spark.tispark.db_prefix"
+  val ALLOW_ORDER_PROJECT_LIMIT_PUSHDOWN: String =
+    "spark.tispark.plan.allow_order_project_limit_pushdown"
 
   val SNAPSHOT_ISOLATION_LEVEL: String = "SI"
   val READ_COMMITTED_ISOLATION_LEVEL: String = "RC"
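As a usage note, here is a minimal sketch of turning the new flag on from a Spark session. Only the configuration key and its "false" default come from this diff; the session setup and application name are hypothetical, and any SparkSession with TiSpark loaded should behave the same way.

import org.apache.spark.sql.SparkSession

// Hypothetical session setup for illustration.
val spark = SparkSession.builder().appName("pushdown-demo").getOrCreate()

// The flag defaults to "false" and, in TiStrategy below, is read through
// sqlConf.getConfString at planning time, so a session-level setting takes
// effect for subsequent queries.
spark.conf.set("spark.tispark.plan.allow_order_project_limit_pushdown", "true")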
core/src/main/scala/org/apache/spark/sql/TiStrategy.scala (133 additions, 75 deletions)

@@ -23,7 +23,6 @@ import com.pingcap.tikv.expression._
 import com.pingcap.tikv.expression.visitor.{ColumnMatcher, MetaResolver}
 import com.pingcap.tikv.meta.{TiDAGRequest, TiTimestamp}
 import com.pingcap.tikv.meta.TiDAGRequest.PushDownType
-import com.pingcap.tikv.predicates.TiKVScanAnalyzer.TiKVScanPlan
 import com.pingcap.tikv.predicates.{PredicateUtils, TiKVScanAnalyzer}
 import com.pingcap.tikv.statistics.TableStatistics
 import com.pingcap.tispark.utils.TiUtil._
@@ -32,7 +31,7 @@ import com.pingcap.tispark.utils.TiUtil
 import com.pingcap.tispark.{BasicExpression, TiConfigConst, TiDBRelation}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, _}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, IntegerLiteral, NamedExpression, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, IntegerLiteral, NamedExpression, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -87,6 +86,12 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
     new TypeBlacklist(blacklistString)
   }
 
+  private def allowOrderProjectLimitPushdown(): Boolean =
+    sqlConf
+      .getConfString(TiConfigConst.ALLOW_ORDER_PROJECT_LIMIT_PUSHDOWN, "false")
+      .toLowerCase
+      .toBoolean
+
   private def allowAggregationPushdown(): Boolean =
     sqlConf.getConfString(TiConfigConst.ALLOW_AGG_PUSHDOWN, "true").toLowerCase.toBoolean
@@ -306,23 +311,10 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
   ): SparkPlan = {
     val request = new TiDAGRequest(pushDownType(), timeZoneOffset())
     request.setLimit(limit)
-    //if (shouldAddSortOrder(projectList, sortOrder)) {
     addSortOrder(request, sortOrder)
-    //}
     pruneFilterProject(projectList, filterPredicates, source, request)
   }
 
-  private def shouldAddSortOrder(expressions: Seq[NamedExpression],
-                                 orders: Seq[SortOrder]): Boolean = {
-    val exprList = expressions.map(_.toAttribute)
-    !orders.exists { order =>
-      exprList.exists { l =>
-        val r = order.child.asInstanceOf[AttributeReference].toAttribute
-        r.semanticEquals(l)
-      }
-    }
-  }
-
   private def collectLimit(limit: Int, child: LogicalPlan): SparkPlan = child match {
     case PhysicalOperation(projectList, filters, LogicalRelation(source: TiDBRelation, _, _, _))
         if filters.forall(TiUtil.isSupportedFilter(_, source, blacklist)) =>
@@ -512,66 +504,132 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
   // We follow logic similar to the original Spark's, as in SparkStrategies.scala.
   // The difference is that we need to test whether a sub-plan can be consumed
   // entirely by TiKV; if so, we don't return (don't planLater) and instead
   // plan the remainder all at once.
-  private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] =
+  private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] = {
     // TODO: This test should be done once for all children
-    plan match {
-      case logical.ReturnAnswer(rootPlan) =>
-        rootPlan match {
-          /*case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-            takeOrderedAndProject(limit, order, child, child.output) :: Nil*/
-          /*case logical.Limit(
-            IntegerLiteral(limit),
-            logical.Project(projectList, logical.Sort(order, true, child))
-          ) =>
-            takeOrderedAndProject(limit, order, child, projectList) :: Nil*/
-          case logical.Limit(IntegerLiteral(limit), child) =>
-            execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil
-          case other => planLater(other) :: Nil
-        }
-      /*case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        takeOrderedAndProject(limit, order, child, child.output) :: Nil*/
-      /* case logical.Limit(
-        IntegerLiteral(limit),
-        logical.Project(projectList, logical.Sort(order, true, child))
-      ) =>
-        takeOrderedAndProject(limit, order, child, projectList) :: Nil*/
-      // Collapse filters and projections and push plan directly
-      case PhysicalOperation(
-          projectList,
-          filters,
-          LogicalRelation(source: TiDBRelation, _, _, _)
-          ) =>
-        pruneFilterProject(projectList, filters, source) :: Nil
-
-      // The basic logic of the original Spark aggregation plan is as follows:
-      // the PhysicalAggregation extractor rewrites the original aggregation
-      // into aggregateExpressions and resultExpressions.
-      // resultExpressions contains only references [[AttributeReference]]
-      // to the result of aggregation. resultExpressions might contain projections
-      // like Add(sumResult, 1).
-      // For an aggregate like agg(expr) + 1, the rewrite process is: rewrite agg(expr) ->
-      // 1. pushdown: agg(expr) as agg1; if avg, then sum(expr), count(expr)
-      // 2. residual expr (for Spark itself): agg(agg1) as finalAgg1, whose parameter is a
-      //    reference to the pushed plan's corresponding aggregation
-      // 3. resultExpressions: finalAgg1 + 1, where finalAgg1 is the reference to the
-      //    final result of the aggregation
-      case TiAggregation(
-          groupingExpressions,
-          aggregateExpressions,
-          resultExpressions,
-          TiAggregationProjection(filters, _, `source`, projects)
-          ) if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) =>
-        val projectSet = AttributeSet((projects ++ filters).flatMap { _.references })
-        val tiColumns = buildTiColumnRefFromColumnSeq(projectSet, source)
-        val dagReq: TiDAGRequest = filterToDAGRequest(tiColumns, filters, source)
-        groupAggregateProjection(
-          tiColumns,
-          groupingExpressions,
-          aggregateExpressions,
-          resultExpressions,
-          `source`,
-          dagReq
-        )
-      case _ => Nil
-    }
+    if (!allowOrderProjectLimitPushdown()) {
+      plan match {
+        case logical.ReturnAnswer(rootPlan) =>
+          rootPlan match {
+            /*case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+              takeOrderedAndProject(limit, order, child, child.output) :: Nil*/
+            /*case logical.Limit(
+              IntegerLiteral(limit),
+              logical.Project(projectList, logical.Sort(order, true, child))
+            ) =>
+              takeOrderedAndProject(limit, order, child, projectList) :: Nil*/
+            case logical.Limit(IntegerLiteral(limit), child) =>
+              execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil
+            case other => planLater(other) :: Nil
+          }
+        /*case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+          takeOrderedAndProject(limit, order, child, child.output) :: Nil*/
+        /* case logical.Limit(
+          IntegerLiteral(limit),
+          logical.Project(projectList, logical.Sort(order, true, child))
+        ) =>
+          takeOrderedAndProject(limit, order, child, projectList) :: Nil*/
+        // Collapse filters and projections and push plan directly
+        case PhysicalOperation(
+            projectList,
+            filters,
+            LogicalRelation(source: TiDBRelation, _, _, _)
+            ) =>
+          pruneFilterProject(projectList, filters, source) :: Nil
+
+        // The basic logic of the original Spark aggregation plan is as follows:
+        // the PhysicalAggregation extractor rewrites the original aggregation
+        // into aggregateExpressions and resultExpressions.
+        // resultExpressions contains only references [[AttributeReference]]
+        // to the result of aggregation. resultExpressions might contain projections
+        // like Add(sumResult, 1).
+        // For an aggregate like agg(expr) + 1, the rewrite process is: rewrite agg(expr) ->
+        // 1. pushdown: agg(expr) as agg1; if avg, then sum(expr), count(expr)
+        // 2. residual expr (for Spark itself): agg(agg1) as finalAgg1, whose parameter is a
+        //    reference to the pushed plan's corresponding aggregation
+        // 3. resultExpressions: finalAgg1 + 1, where finalAgg1 is the reference to the
+        //    final result of the aggregation
+        case TiAggregation(
+            groupingExpressions,
+            aggregateExpressions,
+            resultExpressions,
+            TiAggregationProjection(filters, _, `source`, projects)
+            ) if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) =>
+          val projectSet = AttributeSet((projects ++ filters).flatMap { _.references })
+          val tiColumns = buildTiColumnRefFromColumnSeq(projectSet, source)
+          val dagReq: TiDAGRequest = filterToDAGRequest(tiColumns, filters, source)
+          groupAggregateProjection(
+            tiColumns,
+            groupingExpressions,
+            aggregateExpressions,
+            resultExpressions,
+            `source`,
+            dagReq
+          )
+        case _ => Nil
+      }
+    } else {
+      plan match {
+        case logical.ReturnAnswer(rootPlan) =>
+          rootPlan match {
+            case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+              takeOrderedAndProject(limit, order, child, child.output) :: Nil
+            case logical.Limit(
+                IntegerLiteral(limit),
+                logical.Project(projectList, logical.Sort(order, true, child))
+                ) =>
+              takeOrderedAndProject(limit, order, child, projectList) :: Nil
+            case logical.Limit(IntegerLiteral(limit), child) =>
+              execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil
+            case other => planLater(other) :: Nil
+          }
+        case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+          takeOrderedAndProject(limit, order, child, child.output) :: Nil
+        case logical.Limit(
+            IntegerLiteral(limit),
+            logical.Project(projectList, logical.Sort(order, true, child))
+            ) =>
+          takeOrderedAndProject(limit, order, child, projectList) :: Nil
+        // Collapse filters and projections and push plan directly
+        case PhysicalOperation(
+            projectList,
+            filters,
+            LogicalRelation(source: TiDBRelation, _, _, _)
+            ) =>
+          pruneFilterProject(projectList, filters, source) :: Nil
+
+        // The basic logic of the original Spark aggregation plan is as follows:
+        // the PhysicalAggregation extractor rewrites the original aggregation
+        // into aggregateExpressions and resultExpressions.
+        // resultExpressions contains only references [[AttributeReference]]
+        // to the result of aggregation. resultExpressions might contain projections
+        // like Add(sumResult, 1).
+        // For an aggregate like agg(expr) + 1, the rewrite process is: rewrite agg(expr) ->
+        // 1. pushdown: agg(expr) as agg1; if avg, then sum(expr), count(expr)
+        // 2. residual expr (for Spark itself): agg(agg1) as finalAgg1, whose parameter is a
+        //    reference to the pushed plan's corresponding aggregation
+        // 3. resultExpressions: finalAgg1 + 1, where finalAgg1 is the reference to the
+        //    final result of the aggregation
+        case TiAggregation(
+            groupingExpressions,
+            aggregateExpressions,
+            resultExpressions,
+            TiAggregationProjection(filters, _, `source`, projects)
+            ) if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) =>
+          val projectSet = AttributeSet((projects ++ filters).flatMap {
+            _.references
+          })
+          val tiColumns = buildTiColumnRefFromColumnSeq(projectSet, source)
+          val dagReq: TiDAGRequest = filterToDAGRequest(tiColumns, filters, source)
+          groupAggregateProjection(
+            tiColumns,
+            groupingExpressions,
+            aggregateExpressions,
+            resultExpressions,
+            `source`,
+            dagReq
+          )
+        case _ => Nil
+      }
+    }
+  }
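To make the new branch concrete, here is a hedged sketch of the query shapes it matches once the flag is enabled. The table t and columns a, b are invented for illustration; the claim that these SQL strings produce a Limit over a Sort (optionally with a Project in between) follows the usual Spark plan shapes, and takeOrderedAndProject then pushes both the limit and the sort order into the TiDAGRequest (request.setLimit plus addSortOrder), so TiKV only returns the top-N rows.

// Assumes `spark` is a SparkSession with TiSpark loaded and
// spark.tispark.plan.allow_order_project_limit_pushdown set to "true".

// Typically planned as Limit(10, Sort(a ASC, global = true, scan)),
// matching the first newly uncommented case:
spark.sql("SELECT a, b FROM t ORDER BY a LIMIT 10")

// Typically planned with a Project between the Limit and the Sort,
// Limit(10, Project(a, Sort(b ASC, global = true, scan))),
// matching the second case:
spark.sql("SELECT a FROM t ORDER BY b LIMIT 10")

// For the aggregation comment above, a worked example of the rewrite of
// avg (hypothetical query, following the three steps in the comment):
//   1. pushdown: sum(a) and count(a) are computed by TiKV (avg is split)
//   2. residual expr for Spark: the final aggregate over the pushed partials
//   3. resultExpressions: finalAgg1 + 1 on top of the final aggregate
spark.sql("SELECT avg(a) + 1 FROM t")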
