Skip to content

Commit

Permalink
Update docs to better describe support for floating point aggregation…
Browse files Browse the repository at this point in the history
… and NaNs (#3467)

Signed-off-by: Jason Lowe <jlowe@nvidia.com>
  • Loading branch information
jlowe authored Sep 14, 2021
1 parent 126c7d3 commit f5e3d0d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 26 deletions.
36 changes: 18 additions & 18 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -1924,8 +1924,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>NaN literals are not supported. Columnar input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>NaN literals are not supported. Columnar input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down Expand Up @@ -14593,8 +14593,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down Expand Up @@ -14636,8 +14636,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down Expand Up @@ -14679,8 +14679,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down Expand Up @@ -14726,8 +14726,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down Expand Up @@ -14769,8 +14769,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down Expand Up @@ -14812,8 +14812,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down Expand Up @@ -14859,8 +14859,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down Expand Up @@ -14923,8 +14923,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td><em>PS<br/>Input must not contain NaNs and spark.rapids.sql.hasNans must be false.</em></td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,9 @@ object GpuOverrides extends Logging {
}
}

private val nanAggPsNote = "Input must not contain NaNs and" +
s" ${RapidsConf.HAS_NANS} must be false."

def expr[INPUT <: Expression](
desc: String,
pluginChecks: ExprChecks,
Expand Down Expand Up @@ -2036,7 +2039,9 @@ object GpuOverrides extends Logging {
TypeSig.all,
Seq(ParamCheck(
"pivotColumn",
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64,
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.all),
ParamCheck("valueColumn",
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64,
Expand Down Expand Up @@ -2075,13 +2080,19 @@ object GpuOverrides extends Logging {
ExprChecks.fullAgg(
TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.orderable,
Seq(ParamCheck("input",
TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.orderable))
(TypeSig.commonCudfTypes + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.orderable))
).asInstanceOf[ExprChecksImpl].contexts
++
ExprChecks.windowOnly(
TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL, TypeSig.orderable,
Seq(ParamCheck("input",
TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL, TypeSig.orderable))
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.orderable))
).asInstanceOf[ExprChecksImpl].contexts
),
(max, conf, p, r) => new AggExprMeta[Max](max, conf, p, r) {
Expand All @@ -2098,13 +2109,19 @@ object GpuOverrides extends Logging {
ExprChecks.fullAgg(
TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.orderable,
Seq(ParamCheck("input",
TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.orderable))
(TypeSig.commonCudfTypes + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.orderable))
).asInstanceOf[ExprChecksImpl].contexts
++
ExprChecks.windowOnly(
TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL, TypeSig.orderable,
Seq(ParamCheck("input",
TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL, TypeSig.orderable))
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.orderable))
).asInstanceOf[ExprChecksImpl].contexts
),
(a, conf, p, r) => new AggExprMeta[Min](a, conf, p, r) {
Expand Down Expand Up @@ -2503,7 +2520,9 @@ object GpuOverrides extends Logging {
TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL - TypeSig.STRING,
TypeSig.orderable,
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL -
TypeSig.STRING),
TypeSig.STRING)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.ARRAY.nested(TypeSig.orderable)),
(in, conf, p, r) => new UnaryExprMeta[ArrayMin](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
Expand All @@ -2521,7 +2540,9 @@ object GpuOverrides extends Logging {
TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL - TypeSig.STRING,
TypeSig.orderable,
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL
- TypeSig.STRING),
- TypeSig.STRING)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.ARRAY.nested(TypeSig.orderable)),
(in, conf, p, r) => new UnaryExprMeta[ArrayMax](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
Expand All @@ -2545,7 +2566,12 @@ object GpuOverrides extends Logging {
TypeSig.BOOLEAN,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL),
TypeSig.ARRAY.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all)),
("key", TypeSig.commonCudfTypes
.withPsNote(TypeEnum.DOUBLE, "NaN literals are not supported. Columnar input" +
s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false.")
.withPsNote(TypeEnum.FLOAT, "NaN literals are not supported. Columnar input" +
s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false."),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ArrayContains](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// do not support literal arrays as LHS
Expand Down Expand Up @@ -3444,6 +3470,7 @@ object GpuOverrides extends Logging {
val postColToRowProjection = TreeNodeTag[Seq[NamedExpression]](
"rapids.gpu.postColToRowProcessing")
}

/** Tag the initial plan when AQE is enabled */
case class GpuQueryStagePrepOverrides() extends Rule[SparkPlan] with Logging {
override def apply(plan: SparkPlan) :SparkPlan = {
Expand Down

0 comments on commit f5e3d0d

Please sign in to comment.