Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiplication on ANSI interval types #5105

Merged
merged 10 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,3 +946,11 @@ def test_subtraction_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
def test_unary_positive_day_time_interval():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, DayTimeIntervalGen()).selectExpr('+a'))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=-3, max_exp=5, special_cases=[0.0])], ids=idfn)
def test_day_time_interval_multiply_number(data_gen):
gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-20 * 86400), max_value=timedelta(seconds=20 * 86400))),
('_c2', data_gen)]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("_c1 * _c2"))
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
ExprChecks.projectAndAst(
TypeSig.astTypes,
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.CALENDAR
+ TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT + TypeSig.DAYTIME)
.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT),
+ TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT + TypeSig.ansiIntervals)
.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT),
TypeSig.all),
(lit, conf, p, r) => new LiteralExprMeta(lit, conf, p, r)),
GpuOverrides.expr[TimeAdd](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ object GpuTypeShims {
*/
def supportToScalarForType(t: DataType): Boolean = {
t match {
case _: YearMonthIntervalType => true
case _: DayTimeIntervalType => true
case _ => false
}
Expand All @@ -143,8 +144,13 @@ object GpuTypeShims {
/**
* Convert the given value to Scalar
*/
def toScalarForType(t: DataType, v: Any) = {
def toScalarForType(t: DataType, v: Any): Scalar = {
t match {
case _: YearMonthIntervalType => v match {
case i: Int => Scalar.fromInt(i)
case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
s" for IntType, expecting int")
}
case _: DayTimeIntervalType => v match {
case l: Long => Scalar.fromLong(l)
case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.shims.GpuTimeAdd
import org.apache.spark.sql.rapids.shims.{GpuMultiplyDTInterval, GpuMultiplyYMInterval, GpuTimeAdd}
import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, DecimalType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -181,7 +181,29 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {

// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, ansiEnabled)
})
}),
GpuOverrides.expr[MultiplyYMInterval](
"Year-month interval * number",
ExprChecks.binaryProject(
TypeSig.YEARMONTH,
TypeSig.YEARMONTH,
("lhs", TypeSig.YEARMONTH, TypeSig.YEARMONTH),
("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
(a, conf, p, r) => new BinaryExprMeta[MultiplyYMInterval](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuMultiplyYMInterval(lhs, rhs)
}),
GpuOverrides.expr[MultiplyDTInterval](
"Day-time interval * number",
ExprChecks.binaryProject(
TypeSig.DAYTIME,
TypeSig.DAYTIME,
("lhs", TypeSig.DAYTIME, TypeSig.DAYTIME),
("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
(a, conf, p, r) => new BinaryExprMeta[MultiplyDTInterval](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuMultiplyDTInterval(lhs, rhs)
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
super.getExprs ++ map
}
Expand Down
Loading