Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Sep 16, 2021
1 parent b0096ce commit 15fb4ce
Showing 1 changed file with 24 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,41 @@

package org.apache.spark.sql.rapids.shims.v2

import java.util.concurrent.TimeUnit

import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuScalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.v2.ShimBinaryExpression

import org.apache.spark.sql.catalyst.expressions.{Expression, TimeZoneAwareExpression}
import org.apache.spark.sql.rapids.GpuTimeMath
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.CalendarInterval

case class GpuTimeAdd(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends GpuTimeMath(start, interval, timeZoneId) {
extends ShimBinaryExpression
with GpuExpression
with TimeZoneAwareExpression
with ExpectsInputTypes
with Serializable {

override def left: Expression = start
override def right: Expression = interval

override def toString: String = s"$left - $right"
override def sql: String = s"${left.sql} - ${right.sql}"

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1)

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

override def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
us.add(us_s)
}

override def inputTypes: Seq[AbstractDataType] =
Seq(AnyTimestampType, TypeCollection(CalendarIntervalType, DayTimeIntervalType))

Expand Down Expand Up @@ -87,4 +99,8 @@ case class GpuTimeAdd(start: Expression,
}
}
}

private def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
us.add(us_s)
}
}

0 comments on commit 15fb4ce

Please sign in to comment.