Skip to content

Commit

Permalink
Restore output of unsupported column types (NVIDIA#2321)
Browse files Browse the repository at this point in the history
* Restore output of unsupported column types

Fixes NVIDIA#2320

Signed-off-by: Gera Shegalov <gera@apache.org>

* attribute references grouped by type

* original messages

* expression id
  • Loading branch information
gerashegalov authored May 3, 2021
1 parent 32c4b73 commit bd8f761
Showing 1 changed file with 33 additions and 26 deletions.
59 changes: 33 additions & 26 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.time.ZoneId

import ai.rapids.cudf.DType

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Expression, UnaryExpression, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.{Attribute, CaseWhen, Expression, UnaryExpression, WindowSpecDefinition}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -486,6 +486,29 @@ abstract class TypeChecks[RET] {
def support(dataType: TypeEnum.Value): RET

val shown: Boolean = true

/**
 * Renders a map from data type to the set of column names having that type as a
 * single human-readable string, e.g. "IntegerType [a, b], MapType [c]".
 */
private def stringifyTypeAttributeMap(groupedByType: Map[DataType, Set[String]]): String = {
  val renderedEntries = groupedByType.map { case (dataType, nameSet) =>
    s"$dataType ${nameSet.mkString("[", ", ", "]")}"
  }
  renderedEntries.mkString(", ")
}

/**
 * Tags the given metadata node as not working on the GPU if any of the fields
 * has a data type that the plugin does not support.
 *
 * @param meta the plan/expression metadata node to tag via willNotWorkOnGpu
 * @param sig the type signature describing which data types the plugin supports
 * @param allowDecimal whether decimal type support is enabled in the config
 * @param fields the schema fields (or attributes rendered as StructFields) to check
 * @param msgFormat a format string with a single %s placeholder that receives the
 *                  rendered "type [names...]" listing of unsupported columns
 */
protected def tagUnsupportedTypes(
    meta: RapidsMeta[_, _, _],
    sig: TypeSig,
    allowDecimal: Boolean,
    fields: Seq[StructField],
    msgFormat: String
): Unit = {
  // Build the map eagerly with `map` rather than `mapValues`: `mapValues`
  // produces a lazy view (deprecated in Scala 2.13) that would recompute the
  // name sets on every access.
  val unsupportedOutputTypes: Map[DataType, Set[String]] = fields
    .filterNot(attr => sig.isSupportedByPlugin(attr.dataType, allowDecimal))
    .groupBy(_.dataType)
    .map { case (dataType, attrs) => dataType -> attrs.map(_.name).toSet }

  if (unsupportedOutputTypes.nonEmpty) {
    meta.willNotWorkOnGpu(msgFormat.format(stringifyTypeAttributeMap(unsupportedOutputTypes)))
  }
}
}

/**
Expand Down Expand Up @@ -571,15 +594,8 @@ class FileFormatChecks private (
fileType: FileFormatType,
op: FileFormatOp): Unit = {
val allowDecimal = meta.conf.decimalTypeEnabled

val unsupportedOutputTypes = schema.fields
.filterNot(attr => sig.isSupportedByPlugin(attr.dataType, allowDecimal))
.toSet

if (unsupportedOutputTypes.nonEmpty) {
meta.willNotWorkOnGpu("unsupported data types " +
unsupportedOutputTypes.mkString(", ") + s" in $op for $fileType")
}
tagUnsupportedTypes(meta, sig, allowDecimal, schema.fields,
s"unsupported data types %s in $op for $fileType")
}

override def support(dataType: TypeEnum.Value): SupportLevel =
Expand Down Expand Up @@ -631,23 +647,14 @@ class ExecChecks private(
val plan = meta.wrapped.asInstanceOf[SparkPlan]
val allowDecimal = meta.conf.decimalTypeEnabled

val unsupportedOutputTypes = plan.output
.filterNot(attr => check.isSupportedByPlugin(attr.dataType, allowDecimal))
.toSet

if (unsupportedOutputTypes.nonEmpty) {
meta.willNotWorkOnGpu("unsupported data types in output: " +
unsupportedOutputTypes.mkString(", "))
}
// expression.toString to capture ids in not-on-GPU tags
def toStructField(a: Attribute) = StructField(name = a.toString(), dataType = a.dataType)

val unsupportedInputTypes = plan.children.flatMap { childPlan =>
childPlan.output.filterNot(attr => check.isSupportedByPlugin(attr.dataType, allowDecimal))
}.toSet

if (unsupportedInputTypes.nonEmpty) {
meta.willNotWorkOnGpu("unsupported data types in input: " +
unsupportedInputTypes.mkString(", "))
}
tagUnsupportedTypes(meta, check, allowDecimal, plan.output.map(toStructField),
"unsupported data types in output: %s")
tagUnsupportedTypes(meta, check, allowDecimal,
plan.children.flatMap(_.output.map(toStructField)),
"unsupported data types in input: %s")
}

override def support(dataType: TypeEnum.Value): SupportLevel =
Expand Down

0 comments on commit bd8f761

Please sign in to comment.