[SPARK-50333][SQL][FOLLOWUP] Codegen Support for CsvToStructs(`from_csv`) - remove Invoke

### What changes were proposed in this pull request?
This PR aims to implement codegen for `CsvToStructs` (`from_csv`) manually, rather than via the `Invoke` expression.
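
For illustration only (not part of the patch), here is a minimal usage sketch of the path this change affects, assuming a local `SparkSession`; the column name `g` and the schema mirror the golden explain file updated below:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_csv
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StructType}

object FromCsvDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("FromCsvDemo").getOrCreate()
    import spark.implicits._

    // Schema matching the golden explain file: id LONG, a INT, b DOUBLE.
    val schema = new StructType()
      .add("id", LongType)
      .add("a", IntegerType)
      .add("b", DoubleType)

    val df = Seq("1,2,3.0").toDF("g")
      .select(from_csv($"g", schema, Map("mode" -> "FAILFAST")).as("parsed"))

    // After this change, the analyzed plan prints from_csv(...) directly,
    // with no invoke(CsvToStructsEvaluator ...) wrapper.
    df.explain(true)
    df.show(truncate = false)

    spark.stop()
  }
}
```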

### Why are the changes needed?
Based on cloud-fan's double-check in apache#48509 (comment), I believe that restoring the manual implementation will not cause a regression.
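
As a spot-check sketch (again illustrative, reusing `df` from the snippet above), Spark's debug helpers can dump the generated Java to confirm that the evaluator is now called directly from the generated code:

```scala
// Prints the whole-stage-generated Java for the plan; the evaluator call
// should appear inline rather than behind an Invoke-based replacement.
import org.apache.spark.sql.execution.debug._
df.debugCodegen()
```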

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Updated existing UTs.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48983 from panbingkun/SPARK-50333_FOLLOWUP.

Authored-by: panbingkun <panbingkun@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
panbingkun authored and cloud-fan committed Nov 29, 2024
1 parent 0c16e93 commit 376bd4a
Showing 3 changed files with 35 additions and 20 deletions.
@@ -28,7 +28,7 @@
 import org.apache.spark.sql.types.{DataType, NullType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * The expression `CsvToStructs` will utilize the `Invoke` to call it, support codegen.
+ * The expression `CsvToStructs` will utilize it to support codegen.
  */
 case class CsvToStructsEvaluator(
     options: Map[String, String],
@@ -86,6 +86,7 @@ case class CsvToStructsEvaluator(
   }
 
   final def evaluate(csv: UTF8String): InternalRow = {
+    if (csv == null) return null
     converter(parser.parse(csv.toString))
   }
 }
@@ -23,10 +23,10 @@
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.csv._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.catalyst.expressions.csv.{CsvToStructsEvaluator, SchemaOfCsvEvaluator}
-import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
 import org.apache.spark.sql.catalyst.util.TypeUtils._
 import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.internal.SQLConf
@@ -57,17 +57,12 @@ case class CsvToStructs(
     timeZoneId: Option[String] = None,
     requiredSchema: Option[StructType] = None)
   extends UnaryExpression
-  with RuntimeReplaceable
-  with ExpectsInputTypes
-  with TimeZoneAwareExpression {
+  with TimeZoneAwareExpression
+  with ExpectsInputTypes {
 
   override def nullable: Boolean = child.nullable
 
-  override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
-
-  // The CSV input data might be missing certain fields. We force the nullability
-  // of the user-provided schema to avoid data corruptions.
-  private val nullableSchema: StructType = schema.asNullable
+  override def nullIntolerant: Boolean = true
 
   // Used in `FunctionRegistry`
   def this(child: Expression, schema: Expression, options: Map[String, String]) =
@@ -86,8 +81,6 @@
       child = child,
       timeZoneId = None)
 
-  private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
-
   override def dataType: DataType = requiredSchema.getOrElse(schema).asNullable
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
@@ -98,16 +91,37 @@
 
   override def prettyName: String = "from_csv"
 
+  // The CSV input data might be missing certain fields. We force the nullability
+  // of the user-provided schema to avoid data corruptions.
+  private val nullableSchema: StructType = schema.asNullable
+
+  @transient
+  private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
+
   @transient
   private lazy val evaluator: CsvToStructsEvaluator = CsvToStructsEvaluator(
     options, nullableSchema, nameOfCorruptRecord, timeZoneId, requiredSchema)
 
-  override def replacement: Expression = Invoke(
-    Literal.create(evaluator, ObjectType(classOf[CsvToStructsEvaluator])),
-    "evaluate",
-    dataType,
-    Seq(child),
-    Seq(child.dataType))
+  override def nullSafeEval(input: Any): Any = {
+    evaluator.evaluate(input.asInstanceOf[UTF8String])
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
+    val eval = child.genCode(ctx)
+    val resultType = CodeGenerator.boxedType(dataType)
+    val resultTerm = ctx.freshName("result")
+    ev.copy(code =
+      code"""
+         |${eval.code}
+         |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value});
+         |boolean ${ev.isNull} = $resultTerm == null;
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+         |if (!${ev.isNull}) {
+         |  ${ev.value} = $resultTerm;
+         |}
+         |""".stripMargin)
+  }
 
   override protected def withNewChildInternal(newChild: Expression): CsvToStructs =
     copy(child = newChild)
@@ -1,2 +1,2 @@
-Project [invoke(CsvToStructsEvaluator(Map(mode -> FAILFAST),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),None).evaluate(g#0)) AS from_csv(g)#0]
+Project [from_csv(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), (mode,FAILFAST), g#0, Some(America/Los_Angeles), None) AS from_csv(g)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
