[CORE] Code refactor: simplify transformer classes (#3426)
ulysses-you authored Oct 19, 2023
1 parent a713d5e commit 6090cea
Showing 24 changed files with 54 additions and 457 deletions.
@@ -335,5 +335,5 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
     Map.empty

   override def genGenerateTransformerMetricsUpdater(
-      metrics: Map[String, SQLMetric]): MetricsUpdater = new NoopMetricsUpdater
+      metrics: Map[String, SQLMetric]): MetricsUpdater = NoopMetricsUpdater
 }

@@ -28,8 +28,7 @@ import java.util
 import scala.collection.JavaConverters._

 case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
-  extends FilterExecTransformerBase(condition, child)
-  with TransformSupport {
+  extends FilterExecTransformerBase(condition, child) {

   override protected def doValidateInternal(): ValidationResult = {
     val leftCondition = getLeftCondition
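
Both backends' filter transformers (this ClickHouse one and the Velox FilterExecTransformer below) drop their explicit `with TransformSupport`. That is safe because the shared parent now brings the trait in itself, as the FilterExecTransformerBase hunk further down shows. A self-contained illustration of the inheritance change, using stand-in names rather than the real Gluten classes:

// Stand-ins for the real Gluten traits, only to show the shape of the change.
trait TransformSupport
trait UnaryTransformSupport extends TransformSupport

// Before: the base class mixed in TransformSupport explicitly, and concrete
// subclasses often repeated the mixin ("with TransformSupport") anyway.
// After: the base class extends UnaryTransformSupport, so every subclass is
// already a TransformSupport and the repeated mixin can simply be deleted.
abstract class FilterTransformerBase extends UnaryTransformSupport
case class BackendFilterTransformer() extends FilterTransformerBase
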
@@ -38,7 +38,7 @@ object MetricsUtil extends Logging {
       case t: TransformSupport =>
         MetricsUpdaterTree(t.metricsUpdater(), t.children.map(treeifyMetricsUpdaters))
       case _ =>
-        MetricsUpdaterTree(new NoopMetricsUpdater, Seq())
+        MetricsUpdaterTree(NoopMetricsUpdater, Seq())
     }
   }

@@ -104,7 +104,7 @@ object MetricsUtil extends Logging {
           s"Updating native metrics failed due to the wrong size of metrics data: " +
             s"$numNativeMetrics")
       ()
-    } else if (mutNode.updater.isInstanceOf[NoopMetricsUpdater]) {
+    } else if (mutNode.updater == NoopMetricsUpdater) {
       ()
     } else {
       updateTransformerMetricsInternal(
@@ -155,7 +155,7 @@ object MetricsUtil extends Logging {

     mutNode.children.foreach {
       child =>
-        if (!child.updater.isInstanceOf[NoopMetricsUpdater]) {
+        if (child.updater != NoopMetricsUpdater) {
           val result = updateTransformerMetricsInternal(
             child,
             relMap,
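
Replacing `new NoopMetricsUpdater` with a bare `NoopMetricsUpdater`, together with the switch above from `isInstanceOf` checks to plain `==`/`!=`, indicates the no-op updater is now a singleton `object` rather than a class. Its definition is not part of this page; a minimal sketch of the idea, with a stand-in `MetricsUpdater` trait because the real trait's members are not shown here:

// Stand-in: the real io.glutenproject.metrics.MetricsUpdater has its own members.
trait MetricsUpdater

// A stateless no-op updater as a singleton object, so call sites can use
// reference equality (==, !=) instead of isInstanceOf[NoopMetricsUpdater].
object NoopMetricsUpdater extends MetricsUpdater

// Usage mirroring the hunks above:
//   if (mutNode.updater == NoopMetricsUpdater) ()        // skip no-op subtrees
//   if (child.updater != NoopMetricsUpdater) { ... }     // only recurse into real updaters
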
@@ -28,8 +28,7 @@ import java.util
 import scala.collection.JavaConverters._

 case class FilterExecTransformer(condition: Expression, child: SparkPlan)
-  extends FilterExecTransformerBase(condition, child)
-  with TransformSupport {
+  extends FilterExecTransformerBase(condition, child) {

   override protected def doValidateInternal(): ValidationResult = {
     val leftCondition = getLeftCondition

@@ -16,7 +16,6 @@
  */
 package io.glutenproject.execution

-import io.glutenproject.GlutenConfig
 import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, ExpressionTransformer}
 import io.glutenproject.extension.{GlutenPlan, ValidationResult}
@@ -44,8 +43,7 @@ import java.util
 import scala.collection.JavaConverters._

 abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkPlan)
-  extends UnaryExecNode
-  with TransformSupport
+  extends UnaryTransformSupport
   with PredicateHelper
   with AliasAwareOutputPartitioning
   with Logging {
@@ -65,34 +63,11 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
   // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
   private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)

-  override def supportsColumnar: Boolean = GlutenConfig.getConf.enableColumnarIterator
-
   override def isNullIntolerant(expr: Expression): Boolean = expr match {
     case e: NullIntolerant => e.children.forall(isNullIntolerant)
     case _ => false
   }

-  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = child match {
-    case c: TransformSupport =>
-      c.columnarInputRDDs
-    case _ =>
-      Seq(child.executeColumnar())
-  }
-
-  override def getBuildPlans: Seq[(SparkPlan, SparkPlan)] = child match {
-    case c: TransformSupport =>
-      c.getBuildPlans
-    case _ =>
-      Seq()
-  }
-
-  override def getStreamedLeafPlan: SparkPlan = child match {
-    case c: TransformSupport =>
-      c.getStreamedLeafPlan
-    case _ =>
-      this
-  }
-
   override def metricsUpdater(): MetricsUpdater =
     BackendsApiManager.getMetricsApiInstance.genFilterTransformerMetricsUpdater(metrics)

@@ -208,16 +183,10 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
     }
     TransformContext(inputAttributes, output, currRel)
   }
-
-  override protected def doExecute()
-      : org.apache.spark.rdd.RDD[org.apache.spark.sql.catalyst.InternalRow] = {
-    throw new UnsupportedOperationException(s"This operator doesn't support doExecute().")
-  }
 }

 case class ProjectExecTransformer private (projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode
-  with TransformSupport
+  extends UnaryTransformSupport
   with PredicateHelper
   with AliasAwareOutputPartitioning
   with Logging {
@@ -228,8 +197,6 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], ch

   val sparkConf: SparkConf = sparkContext.getConf

-  override def supportsColumnar: Boolean = GlutenConfig.getConf.enableColumnarIterator
-
   override protected def doValidateInternal(): ValidationResult = {
     val substraitContext = new SubstraitContext
     // Firstly, need to check if the Substrait plan for this operator can be successfully generated.
@@ -245,27 +212,6 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], ch
     case _ => false
   }

-  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = child match {
-    case c: TransformSupport =>
-      c.columnarInputRDDs
-    case _ =>
-      Seq(child.executeColumnar())
-  }
-
-  override def getBuildPlans: Seq[(SparkPlan, SparkPlan)] = child match {
-    case c: TransformSupport =>
-      c.getBuildPlans
-    case _ =>
-      Seq()
-  }
-
-  override def getStreamedLeafPlan: SparkPlan = child match {
-    case c: TransformSupport =>
-      c.getStreamedLeafPlan
-    case _ =>
-      this
-  }
-
   override def metricsUpdater(): MetricsUpdater =
     BackendsApiManager.getMetricsApiInstance.genProjectTransformerMetricsUpdater(metrics)

@@ -360,11 +306,6 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], ch

   override protected def outputExpressions: Seq[NamedExpression] = projectList

-  override protected def doExecute()
-      : org.apache.spark.rdd.RDD[org.apache.spark.sql.catalyst.InternalRow] = {
-    throw new UnsupportedOperationException(s"This operator doesn't support doExecute().")
-  }
-
   override protected def withNewChildInternal(newChild: SparkPlan): ProjectExecTransformer =
     copy(child = newChild)
 }
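
FilterExecTransformerBase and ProjectExecTransformer shed identical copies of supportsColumnar, columnarInputRDDs, getBuildPlans, getStreamedLeafPlan and a doExecute that only throws; ExpandExecTransformer and GenerateExecTransformer below lose the same blocks. Presumably those defaults now live once in UnaryTransformSupport, defined in one of the files not expanded on this page. A sketch of such a trait, reconstructed from the overrides deleted above (the actual trait may differ in details):

import io.glutenproject.GlutenConfig

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.vectorized.ColumnarBatch

// Reconstructed from the deleted per-operator overrides; not quoted from the new source.
trait UnaryTransformSupport extends TransformSupport with UnaryExecNode {

  override def supportsColumnar: Boolean = GlutenConfig.getConf.enableColumnarIterator

  // Chain columnar input through a transformable child, else execute the child columnar.
  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = child match {
    case c: TransformSupport => c.columnarInputRDDs
    case _ => Seq(child.executeColumnar())
  }

  override def getBuildPlans: Seq[(SparkPlan, SparkPlan)] = child match {
    case c: TransformSupport => c.getBuildPlans
    case _ => Seq()
  }

  override def getStreamedLeafPlan: SparkPlan = child match {
    case c: TransformSupport => c.getStreamedLeafPlan
    case _ => this
  }

  // Columnar transformers never run row-based.
  override protected def doExecute(): RDD[InternalRow] =
    throw new UnsupportedOperationException(s"$nodeName doesn't support doExecute().")
}

With defaults like these in place, FilterExecTransformerBase keeps only its null-intolerance logic, metrics wiring and Substrait transform code, and subclasses such as CHFilterExecTransformer no longer need their own `with TransformSupport`.
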
@@ -34,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

 import com.google.common.collect.Lists

-trait BasicScanExecTransformer extends TransformSupport with SupportFormat {
+trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat {

   // The key of merge schema option in Parquet reader.
   protected val mergeSchemaOptionKey = "mergeschema"
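
BasicScanExecTransformer moves from TransformSupport to LeafTransformSupport, and the scan transformers below (BatchScanExecTransformer, FileSourceScanExecTransformer) drop their hand-written leaf defaults. A sketch of the leaf-side counterpart, again reconstructed from the deleted overrides rather than from the new trait's source:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.vectorized.ColumnarBatch

// Reconstructed defaults for child-less transformers (scans); the real trait may differ.
trait LeafTransformSupport extends TransformSupport {
  // A scan has no upstream columnar input.
  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = Seq()

  // A leaf is its own build plan and its own streamed leaf.
  override def getBuildPlans: Seq[(SparkPlan, SparkPlan)] = Seq((this, null))
  override def getStreamedLeafPlan: SparkPlan = this
}
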
@@ -16,7 +16,6 @@
  */
 package io.glutenproject.execution

-import io.glutenproject.GlutenConfig
 import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.extension.ValidationResult
 import io.glutenproject.metrics.MetricsUpdater
@@ -25,7 +24,6 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.connector.read.{InputPartition, Scan}
-import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExecShim, FileScan}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
@@ -86,8 +84,6 @@ class BatchScanExecTransformer(
     super.doValidateInternal()
   }

-  override def supportsColumnar(): Boolean = GlutenConfig.getConf.enableColumnarIterator
-
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
     doExecuteColumnarInternal()
   }
@@ -102,18 +98,6 @@ class BatchScanExecTransformer(

   override def canEqual(other: Any): Boolean = other.isInstanceOf[BatchScanExecTransformer]

-  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
-    Seq()
-  }
-
-  override def getBuildPlans: Seq[(SparkPlan, SparkPlan)] = {
-    Seq((this, null))
-  }
-
-  override def getStreamedLeafPlan: SparkPlan = {
-    this
-  }
-
   override def metricsUpdater(): MetricsUpdater =
     BackendsApiManager.getMetricsApiInstance.genBatchScanTransformerMetricsUpdater(metrics)

@@ -27,7 +27,6 @@ import io.glutenproject.substrait.extensions.ExtensionBuilder
 import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution._
@@ -44,7 +43,7 @@ case class ExpandExecTransformer(
     output: Seq[Attribute],
     child: SparkPlan)
   extends UnaryExecNode
-  with TransformSupport {
+  with UnaryTransformSupport {

   // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
   @transient override lazy val metrics =
@@ -59,26 +58,6 @@ case class ExpandExecTransformer(
   // as UNKNOWN partitioning
   override def outputPartitioning: Partitioning = UnknownPartitioning(0)

-  override def supportsColumnar: Boolean = true
-
-  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = child match {
-    case c: TransformSupport =>
-      c.columnarInputRDDs
-    case _ =>
-      Seq(child.executeColumnar())
-  }
-
-  override def getBuildPlans: Seq[(SparkPlan, SparkPlan)] = {
-    throw new UnsupportedOperationException(s"This operator doesn't support getBuildPlans.")
-  }
-
-  override def getStreamedLeafPlan: SparkPlan = child match {
-    case c: TransformSupport =>
-      c.getStreamedLeafPlan
-    case _ =>
-      this
-  }
-
   def getRelNode(
       context: SubstraitContext,
       projections: Seq[Seq[Expression]],
@@ -252,9 +231,6 @@ case class ExpandExecTransformer(
     TransformContext(inputAttributes, output, currRel)
   }

-  override protected def doExecute(): RDD[InternalRow] =
-    throw new UnsupportedOperationException("doExecute is not supported in ColumnarExpandExec.")
-
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
     throw new UnsupportedOperationException(s"This operator doesn't support doExecuteColumnar().")
   }

@@ -16,7 +16,6 @@
  */
 package io.glutenproject.execution

-import io.glutenproject.GlutenConfig
 import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.expression.ConverterUtils
 import io.glutenproject.extension.ValidationResult
@@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, DynamicPruningExpression, Expression, PlanExpression, Predicate}
 import org.apache.spark.sql.connector.read.InputPartition
-import org.apache.spark.sql.execution.{FileSourceScanExecShim, InSubqueryExec, ScalarSubquery, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{FileSourceScanExecShim, InSubqueryExec, ScalarSubquery, SQLExecution}
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.StructType
@@ -78,14 +77,6 @@ class FileSourceScanExecTransformer(
       Map.empty[String, SQLMetric]
   }

-  override lazy val supportsColumnar: Boolean = {
-    /*
-    relation.fileFormat
-      .supportBatch(relation.sparkSession, schema) && GlutenConfig.getConf.enableColumnarIterator
-     */
-    GlutenConfig.getConf.enableColumnarIterator
-  }
-
   override def filterExprs(): Seq[Expression] = dataFilters

   override def outputAttributes(): Seq[Attribute] = output
@@ -117,18 +108,6 @@ class FileSourceScanExecTransformer(

   override def hashCode(): Int = super.hashCode()

-  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
-    Seq()
-  }
-
-  override def getBuildPlans: Seq[(SparkPlan, SparkPlan)] = {
-    Seq((this, null))
-  }
-
-  override def getStreamedLeafPlan: SparkPlan = {
-    this
-  }
-
   override protected def doValidateInternal(): ValidationResult = {
     // Bucketing table has `bucketId` in filename, should apply this in backends
     // TODO Support bucketed scan

@@ -27,11 +27,8 @@ import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
 import io.glutenproject.substrait.extensions.ExtensionBuilder
 import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.sql.execution.SparkPlan

 import com.google.protobuf.Any

@@ -47,8 +44,7 @@ case class GenerateExecTransformer(
     outer: Boolean,
     generatorOutput: Seq[Attribute],
     child: SparkPlan)
-  extends UnaryExecNode
-  with TransformSupport {
+  extends UnaryTransformSupport {

   @transient
   override lazy val metrics =
@@ -58,37 +54,9 @@ case class GenerateExecTransformer(

   override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)

-  override protected def doExecute(): RDD[InternalRow] = {
-    throw new UnsupportedOperationException(s"GenerateExecTransformer doesn't support doExecute")
-  }
-
   override protected def withNewChildInternal(newChild: SparkPlan): GenerateExecTransformer =
     copy(generator, requiredChildOutput, outer, generatorOutput, newChild)

-  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = child match {
-    case c: TransformSupport =>
-      c.columnarInputRDDs
-    case _ =>
-      Seq(child.executeColumnar())
-  }
-
-  override def getBuildPlans: Seq[(SparkPlan, SparkPlan)] = child match {
-    case c: TransformSupport =>
-      val childPlans = c.getBuildPlans
-      childPlans :+ (this, null)
-    case _ =>
-      Seq((this, null))
-  }
-
-  override def getStreamedLeafPlan: SparkPlan = child match {
-    case c: TransformSupport =>
-      c.getStreamedLeafPlan
-    case _ =>
-      this
-  }
-
-  override def supportsColumnar: Boolean = true
-
   override protected def doValidateInternal(): ValidationResult = {
     val validationResult =
       BackendsApiManager.getTransformerApiInstance.validateGenerator(generator, outer)
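
The net effect is visible in GenerateExecTransformer: after this hunk it keeps only its operator-specific logic. Roughly, simplified from the lines that survive above (Substrait rel construction and metrics wiring omitted, so this is not a verbatim copy of the resulting file):

case class GenerateExecTransformer(
    generator: Generator,
    requiredChildOutput: Seq[Attribute],
    outer: Boolean,
    generatorOutput: Seq[Attribute],
    child: SparkPlan)
  extends UnaryTransformSupport {

  // Columnar plumbing (columnarInputRDDs, getBuildPlans, doExecute, ...) now comes
  // from UnaryTransformSupport; only generator-specific behaviour remains here.
  override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)

  override protected def withNewChildInternal(newChild: SparkPlan): GenerateExecTransformer =
    copy(generator, requiredChildOutput, outer, generatorOutput, newChild)

  override protected def doValidateInternal(): ValidationResult = {
    // Delegates generator validation to the backend, as in the context lines above.
    BackendsApiManager.getTransformerApiInstance.validateGenerator(generator, outer)
  }

  // ... doTransform / getRelNode and metricsUpdater omitted for brevity.
}
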