From 15fb4cebb4ce46fb06df889c14776efa90433a23 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Fri, 17 Sep 2021 05:29:48 +0800 Subject: [PATCH] resolve comments --- .../rapids/shims/v2/datetimeExpressions.scala | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/sql-plugin/src/main/320/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala b/sql-plugin/src/main/320/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala index 997e1e4283c..e271319f01b 100644 --- a/sql-plugin/src/main/320/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala +++ b/sql-plugin/src/main/320/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala @@ -16,12 +16,14 @@ package org.apache.spark.sql.rapids.shims.v2 +import java.util.concurrent.TimeUnit + import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar} -import com.nvidia.spark.rapids.{GpuColumnVector, GpuScalar} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.shims.v2.ShimBinaryExpression -import org.apache.spark.sql.catalyst.expressions.{Expression, TimeZoneAwareExpression} -import org.apache.spark.sql.rapids.GpuTimeMath +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.CalendarInterval @@ -29,16 +31,26 @@ import org.apache.spark.unsafe.types.CalendarInterval case class GpuTimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None) - extends GpuTimeMath(start, interval, timeZoneId) { + extends ShimBinaryExpression + with GpuExpression + with TimeZoneAwareExpression + with ExpectsInputTypes + with Serializable { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left - $right" + override def sql: String = s"${left.sql} - ${right.sql}" + + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess + + val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } - override def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = { - us.add(us_s) - } - override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, TypeCollection(CalendarIntervalType, DayTimeIntervalType)) @@ -87,4 +99,8 @@ case class GpuTimeAdd(start: Expression, } } } + + private def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = { + us.add(us_s) + } }