diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index d529bc19c75..55295c67a8b 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -954,3 +954,59 @@ def test_day_time_interval_multiply_number(data_gen):
                 ('_c2', data_gen)]
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: gen_df(spark, gen_list).selectExpr("_c1 * _c2"))
+
+
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=0, max_exp=5, special_cases=[])], ids=idfn)
+def test_day_time_interval_division_number_no_overflow1(data_gen):
+    gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-5000 * 365 * 86400), max_value=timedelta(seconds=5000 * 365 * 86400))),
+                ('_c2', data_gen)]
+    assert_gpu_and_cpu_are_equal_collect(
+        # avoid dividing by 0
+        lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type))))
+
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=-5, max_exp=0, special_cases=[])], ids=idfn)
+def test_day_time_interval_division_number_no_overflow2(data_gen):
+    gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-20 * 86400), max_value=timedelta(seconds=20 * 86400))),
+                ('_c2', data_gen)]
+    assert_gpu_and_cpu_are_equal_collect(
+        # avoid dividing by 0
+        lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type))))
+
+def _get_overflow_df_2cols(spark, data_types, values, expr):
+    return spark.createDataFrame(
+        SparkContext.getOrCreate().parallelize([values]),
+        StructType([
+            StructField('a', data_types[0]),
+            StructField('b', data_types[1])
+        ])
+    ).selectExpr(expr)
+
+# test interval division overflow, such as interval / 0, Long.MinValue / -1 ...
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_type,value_pair', [
+    (ByteType(), [timedelta(seconds=1), 0]),
+    (ShortType(), [timedelta(seconds=1), 0]),
+    (IntegerType(), [timedelta(seconds=1), 0]),
+    (LongType(), [timedelta(seconds=1), 0]),
+    (FloatType(), [timedelta(seconds=1), 0.0]),
+    (FloatType(), [timedelta(seconds=1), -0.0]),
+    (DoubleType(), [timedelta(seconds=1), 0.0]),
+    (DoubleType(), [timedelta(seconds=1), -0.0]),
+    (FloatType(), [timedelta(seconds=1), float('NaN')]),
+    (DoubleType(), [timedelta(seconds=1), float('NaN')]),
+    (FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
+    (DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
+    (FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
+    (DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
+    (LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
+    (FloatType(), [timedelta(seconds=0), 0.0]),  # 0 / 0 = NaN
+    (DoubleType(), [timedelta(seconds=0), 0.0]),  # 0 / 0 = NaN
+    (IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
+], ids=idfn)
+def test_day_time_interval_division_overflow(data_type, value_pair):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
+        conf={},
+        error_message='ArithmeticException')
diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
index e5daa65acfb..34a9e6e66d2 100644
--- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
+++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids._
-import org.apache.spark.sql.rapids.shims.{GpuMultiplyDTInterval, GpuMultiplyYMInterval, GpuTimeAdd}
+import org.apache.spark.sql.rapids.shims.{GpuDivideDTInterval, GpuDivideYMInterval, GpuMultiplyDTInterval, GpuMultiplyYMInterval, GpuTimeAdd}
 import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, DecimalType, StructType}
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -203,6 +203,28 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
         (a, conf, p, r) => new BinaryExprMeta[MultiplyDTInterval](a, conf, p, r) {
           override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
             GpuMultiplyDTInterval(lhs, rhs)
+        }),
+      GpuOverrides.expr[DivideYMInterval](
+        "Year-month interval / operator",
+        ExprChecks.binaryProject(
+          TypeSig.YEARMONTH,
+          TypeSig.YEARMONTH,
+          ("lhs", TypeSig.YEARMONTH, TypeSig.YEARMONTH),
+          ("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
+        (a, conf, p, r) => new BinaryExprMeta[DivideYMInterval](a, conf, p, r) {
+          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
+            GpuDivideYMInterval(lhs, rhs)
+        }),
+      GpuOverrides.expr[DivideDTInterval](
+        "Day-time interval / operator",
+        ExprChecks.binaryProject(
+          TypeSig.DAYTIME,
+          TypeSig.DAYTIME,
+          ("lhs", TypeSig.DAYTIME, TypeSig.DAYTIME),
+          ("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
+        (a, conf, p, r) => new BinaryExprMeta[DivideDTInterval](a, conf, p, r) {
+          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
+            GpuDivideDTInterval(lhs, rhs)
         })
     ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
     super.getExprs ++ map
diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
index 41160f568be..2c923ca0791 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
@@ -16,6 +16,8 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import java.math.BigInteger
+
 import ai.rapids.cudf.{BinaryOperable, ColumnVector, DType, RoundMode, Scalar}
 import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar}
 
@@ -184,6 +186,113 @@ object IntervalUtils extends Arm {
       z.castTo(DType.INT64)
     }
   }
+
+  def getDouble(s: Scalar): Double = {
+    s.getType match {
+      case DType.INT8 => s.getByte.toDouble
+      case DType.INT16 => s.getShort.toDouble
+      case DType.INT32 => s.getInt.toDouble
+      case DType.INT64 => s.getLong.toDouble
+      case DType.FLOAT32 => s.getFloat.toDouble
+      case DType.FLOAT64 => s.getDouble
+      case t => throw new IllegalArgumentException(s"Unexpected type $t")
+    }
+  }
+
+  def getLong(s: Scalar): Long = {
+    s.getType match {
+      case DType.INT8 => s.getByte.toLong
+      case DType.INT16 => s.getShort.toLong
+      case DType.INT32 => s.getInt.toLong
+      case DType.INT64 => s.getLong
+      case t => throw new IllegalArgumentException(s"Unexpected type $t")
+    }
+  }
+
+  def hasZero(cv: BinaryOperable): Boolean = {
+    cv match {
+      case s: Scalar => getDouble(s) == 0.0d
+      case cv: ColumnVector =>
+        withResource(Scalar.fromInt(0)) { zero =>
+          withResource(cv.equalTo(zero)) { isZero =>
+            BoolUtils.isAnyValidTrue(isZero)
+          }
+        }
+    }
+  }
+
+  /**
+   * Divide p by q rounding with the HALF_UP mode.
+   * Logic is round(p.asDecimal / q, HALF_UP).
+   *
+   * It's equivalent to
+   * com.google.common.math.LongMath.divide(p, q, RoundingMode.HALF_UP):
+   *   long div = p / q; // throws if q == 0
+   *   long rem = p - q * div; // equals p % q
+   *   // signum is 1 if p and q are both non-negative or both negative, and -1 otherwise.
+   *   int signum = 1 | (int) ((p xor q) >> (Long.SIZE - 1));
+   *   long absRem = abs(rem);
+   *   long cmpRemToHalfDivisor = absRem - (abs(q) - absRem);
+   *   increment = cmpRemToHalfDivisor >= 0
+   *   return increment ? div + signum : div;
+   *
+   * @param p integer (int/long) cv or scalar
+   * @param q integer (byte/short/int/long) cv or scalar; if p is a scalar, q will not be a scalar
+   * @return decimal128 cv of p / q
+   */
+  def divWithHalfUpModeWithOverflowCheck(p: BinaryOperable, q: BinaryOperable): ColumnVector = {
+    // 1. overflow check: q contains 0
+    if (IntervalUtils.hasZero(q)) {
+      throw new ArithmeticException("overflow: interval / zero")
+    }
+
+    // 2. overflow check: p == Int.MinValue/Long.MinValue && q == -1
+    val min = p.getType match {
+      case DType.INT32 => Int.MinValue.toLong
+      case DType.INT64 => Long.MinValue
+      case t => throw new IllegalArgumentException(s"Unexpected type $t")
+    }
+    withResource(Scalar.fromLong(min)) { minScalar =>
+      withResource(Scalar.fromLong(-1L)) { negOneScalar =>
+        (p, q) match {
+          case (lCv: ColumnVector, rCv: ColumnVector) =>
+            withResource(lCv.equalTo(minScalar)) { isMin =>
+              withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
+                withResource(isMin.and(isNegOne)) { invalid =>
+                  if (BoolUtils.isAnyValidTrue(invalid)) {
+                    throw new ArithmeticException("overflow occurs")
+                  }
+                }
+              }
+            }
+          case (lCv: ColumnVector, rS: Scalar) =>
+            withResource(lCv.equalTo(minScalar)) { isMin =>
+              if (getLong(rS) == -1L && BoolUtils.isAnyValidTrue(isMin)) {
+                throw new ArithmeticException("overflow occurs")
+              }
+            }
+          case (lS: Scalar, rCv: ColumnVector) =>
+            withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
+              if (getLong(lS) == min && BoolUtils.isAnyValidTrue(isNegOne)) {
+                throw new ArithmeticException("overflow occurs")
+              }
+            }
+          case (lS: Scalar, rS: Scalar) =>
+            if (getLong(lS) == min && getLong(rS) == -1L) {
+              throw new ArithmeticException("overflow occurs")
+            }
+        }
+      }
+    }
+
+    // 3. round(p.asDecimal / q, HALF_UP)
+    // decimal128 with scale -1 keeps one fractional digit for the HALF_UP rounding below
+    val dT = DType.create(DType.DTypeEnum.DECIMAL128, -1)
+    val leftDecimal: BinaryOperable with AutoCloseable = p match {
+      case pCv: ColumnVector => pCv.castTo(dT)
+      case pS: Scalar => Scalar.fromDecimal(-1, new BigInteger((getLong(pS) * 10L).toString))
+    }
+    withResource(leftDecimal) { _ =>
+      withResource(leftDecimal.div(q, dT)) { t =>
+        t.round(RoundMode.HALF_UP)
+      }
+    }
+  }
 }
 
 /**
@@ -321,3 +430,139 @@ case class GpuMultiplyDTInterval(
 
   override def dataType: DataType = DayTimeIntervalType()
 }
+
+/**
+ * Divide a year-month interval by a numeric with a HALF_UP rounding mode:
+ *   e.g.: 3 / 2 = 1.5 will be rounded to 2; -3 / 2 = -1.5 will be rounded to -2
+ * year-month interval / number(byte, short, int, long, float, double)
+ * Note: dividing a year-month interval by a decimal is not supported
+ * Year-month interval's internal type is int; the int value is 12 * year + month
+ * The left expression is the interval, the right expression is the number
+ * Rewritten from the Spark code:
+ * https://github.com/apache/spark/blob/v3.2.1/sql/catalyst/src/main/scala/
+ *   org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala#L615
+ */
+case class GpuDivideYMInterval(
+    interval: Expression,
+    num: Expression) extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {
+
+  override def left: Expression = interval
+
+  override def right: Expression = num
+
+  override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
+    doColumnar(interval.getBase, numScalar.getBase, num.dataType)
+  }
+
+  override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
+    doColumnar(interval.getBase, num.getBase, num.dataType)
+  }
+
+  override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
+    doColumnar(intervalScalar.getBase, num.getBase, num.dataType)
+  }
+
+  override def doColumnar(numRows: Int, intervalScalar: GpuScalar,
+      numScalar: GpuScalar): ColumnVector = {
+    withResource(GpuColumnVector.from(intervalScalar, numRows, interval.dataType)) { expandedLhs =>
+      doColumnar(expandedLhs, numScalar)
+    }
+  }
+
+  private def doColumnar(interval: BinaryOperable, numOperable: BinaryOperable,
+      numType: DataType): ColumnVector = {
+
+    numType match {
+      case ByteType | ShortType | IntegerType | LongType =>
+        // interval is int; num is byte, short, int or long
+        // Overflow cases: num is 0; interval == Int.MinValue && num == -1
+        withResource(IntervalUtils.divWithHalfUpModeWithOverflowCheck(interval, numOperable)) {
+          // overflow is already checked, so cast directly without an overflow check
+          decimalRet => decimalRet.castTo(DType.INT32)
+        }
+
+      case FloatType | DoubleType => // num is float or double
+        withResource(interval.div(numOperable, DType.FLOAT64)) { double =>
+          // check overflow, then round to int
+          IntervalUtils.roundDoubleToIntWithOverflowCheck(double)
+        }
+      case _ => throw new IllegalArgumentException(
+        s"Unsupported num type $numType in GpuDivideYMInterval")
+    }
+  }
+
+  override def toString: String = s"$interval / $num"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType)
+
+  override def dataType: DataType = YearMonthIntervalType()
+}
+
+/**
+ * Divide a day-time interval by a numeric with a HALF_UP rounding mode:
+ *   e.g.: 3 / 2 = 1.5 will be rounded to 2; -3 / 2 = -1.5 will be rounded to -2
+ * day-time interval / number(byte, short, int, long, float, double)
+ * Note: dividing a day-time interval by a decimal is not supported
+ * Day-time interval's internal type is long; the long value is the total number of microseconds
+ * Rewritten from the Spark code:
+ * https://github.com/apache/spark/blob/v3.2.1/sql/catalyst/src/main/scala/
+ *   org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala#L693
+ */
+case class GpuDivideDTInterval(
+    interval: Expression,
+    num: Expression)
+  extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {
+
+  override def left: Expression = interval
+
+  override def right: Expression = num
+
+  override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
+    doColumnar(interval.getBase, numScalar.getBase, num.dataType)
+  }
+
+  override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
+    doColumnar(interval.getBase, num.getBase, num.dataType)
+  }
+
+  override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
+    doColumnar(intervalScalar.getBase, num.getBase, num.dataType)
+  }
+
+  override def doColumnar(numRows: Int, intervalScalar: GpuScalar,
+      numScalar: GpuScalar): ColumnVector = {
+    withResource(GpuColumnVector.from(intervalScalar, numRows, interval.dataType)) { expandedLhs =>
+      doColumnar(expandedLhs, numScalar)
+    }
+  }
+
+  private def doColumnar(interval: BinaryOperable, numOperable: BinaryOperable,
+      numType: DataType): ColumnVector = {
+    numType match {
+      case ByteType | ShortType | IntegerType | LongType =>
+        // interval is long; num is byte, short, int or long
+        // Overflow cases: num is 0; interval == Long.MinValue && num == -1
+        withResource(IntervalUtils.divWithHalfUpModeWithOverflowCheck(interval, numOperable)) {
+          // overflow is already checked, so cast directly without an overflow check
+          decimalRet => decimalRet.castTo(DType.INT64)
+        }
+
+      case FloatType | DoubleType =>
+        // interval is long; num is float or double
+        withResource(interval.div(numOperable, DType.FLOAT64)) { double =>
+          // check overflow, then round to long
+          IntervalUtils.roundDoubleToLongWithOverflowCheck(double)
+        }
+      case _ => throw new IllegalArgumentException(
+        s"Unsupported num type $numType in GpuDivideDTInterval")
+    }
+  }
+
+  override def toString: String = s"$interval / $num"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DayTimeIntervalType, NumericType)
+
+  override def dataType: DataType = DayTimeIntervalType()
+
+}
diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala
new file mode 100644
index 00000000000..b6426da11a0
--- /dev/null
+++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala
@@ -0,0 +1,236 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import java.time.Period
+
+import scala.util.Random
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+// Currently Pyspark does not support YearMonthIntervalType,
+// TODO move this to the integration test module, see issue
+// https://github.com/NVIDIA/spark-rapids/issues/5212
+class IntervalDivisionSuite extends SparkQueryCompareTestSuite {
+
+  testSparkResultsAreEqual(
+    "test year-month interval / num, normal case",
+    spark => {
+      val data = Seq(
+        Row(Period.ofMonths(5), 3.toByte, 2.toShort, 1, 4L, 2.5f, 2.6d),
+        Row(Period.ofMonths(5), null, null, null, null, null, null),
+        Row(null, 3.toByte, 3.toShort, 3, 3L, 1.1f, 2.7d),
+        Row(Period.ofMonths(0), 3.toByte, 3.toShort, 3, 3L, 3.1f, 3.4d),
+        Row(Period.ofMonths(0), Byte.MinValue, Short.MinValue, Int.MinValue, Long.MinValue,
+          Float.MinValue, Double.MinValue),
+        Row(Period.ofMonths(0),
+          Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue,
+          Float.MaxValue, Double.MaxValue),
+        Row(Period.ofMonths(7), 4.toByte, 3.toShort, 2, 1L,
+          Float.NegativeInfinity, Double.NegativeInfinity),
+        Row(Period.ofMonths(7), 6.toByte, 5.toShort, 4, 3L,
+          Float.PositiveInfinity, Double.PositiveInfinity)
+      )
+      val schema = StructType(Seq(
+        StructField("c_ym", YearMonthIntervalType()),
+        StructField("c_b", ByteType),
+        StructField("c_s", ShortType),
+        StructField("c_i", IntegerType),
+        StructField("c_l", LongType),
+        StructField("c_f", FloatType),
+        StructField("c_d", DoubleType)
+      ))
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }) {
+    df =>
+      df.selectExpr(
+        "c_ym / c_b", "c_ym / c_s", "c_ym / c_i", "c_ym / c_l", "c_ym / c_f", "c_ym / c_d",
+        "c_ym / cast('15' as byte)", "c_ym / cast('15' as short)", "c_ym / 15",
+        "c_ym / 15L", "c_ym / 15.1f", "c_ym / 15.1d",
+        "interval '55' month / c_b",
+        "interval '15' month / c_s",
+        "interval '25' month / c_i",
+        "interval '15' month / c_l",
+        "interval '25' month / c_f",
+        "interval '35' month / c_d")
+  }
+
+  testSparkResultsAreEqual("test year-month interval / float num, normal case",
+    spark => {
+      val numRows = 1024
+      val r = new Random(0)
+
+      def getSign: Int = if (r.nextBoolean()) 1 else -1
+
+      val data = (0 until numRows).map(i => {
+        Row(Period.ofMonths(getSign * r.nextInt(Int.MaxValue)),
+          getSign * (r.nextFloat() + 10f), // add 10 to avoid overflow
+          getSign * (r.nextDouble() + 10d))
+      })
+
+      val schema = StructType(Seq(
+        StructField("c_ym", YearMonthIntervalType()),
+        StructField("c_f", FloatType),
+        StructField("c_d", DoubleType)))
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+  ) {
+    df =>
df.selectExpr("c_ym / c_f", "c_ym / c_d") + } + + testSparkResultsAreEqual("test year-month interval / int num, normal case", + spark => { + val numRows = 1024 + val r = new Random(0) + + def getSign: Int = if (r.nextBoolean()) 1 else -1 + + val data = (0 until numRows).map(i => { + Row(Period.ofMonths(getSign * r.nextInt(Int.MaxValue)), + getSign * r.nextInt(1024 * 1024) + 1, // add 1 to avoid dividing 0 + (getSign * r.nextInt(1024 * 1024) + 1).toLong) + }) + + val schema = StructType(Seq( + StructField("c_ym", YearMonthIntervalType()), + StructField("c_i", IntegerType), + StructField("c_l", LongType))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr("c_ym / c_i", "c_ym / c_l") + } + + // both gpu and cpu will throw ArithmeticException + def testDivideYMOverflow(testCaseName: String, months: Int, num: Any): Unit = { + testBothCpuGpuExpectedException[SparkException](testCaseName + ", cv / scalar", + e => e.getMessage.contains("ArithmeticException"), + spark => { + val data = Seq(Row(Period.ofMonths(months))) + val schema = StructType(Seq( + StructField("c1", YearMonthIntervalType()))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => + num match { + case b: Byte => df.selectExpr(s"c1 / cast('$b' as Byte)") + case s: Short => df.selectExpr(s"c1 / cast('$s' as Short)") + case i: Int => df.selectExpr(s"c1 / $i") + case l: Long => df.selectExpr(s"c1 / ${l}L") + case f: Float if f.isNaN => df.selectExpr("c1 / cast('NaN' as float)") + case f: Float if f.equals(Float.PositiveInfinity) => + df.selectExpr("c1 / cast('Infinity' as float)") + case f: Float if f.equals(Float.NegativeInfinity) => + df.selectExpr("c1 / cast('-Infinity' as float)") + case f: Float => df.selectExpr(s"c1 / ${f}f") + case d: Double if d.isNaN => df.selectExpr("c1 / cast('NaN' as double)") + case d: Double if d.equals(Double.PositiveInfinity) => + df.selectExpr("c1 / cast('Infinity' as double)") + case d: Double if d.equals(Double.NegativeInfinity) => + df.selectExpr("c1 / cast('-Infinity' as double)") + case d: Double => df.selectExpr(s"c1 / ${d}d") + case _ => fail(s"Unsupported num type ${num.getClass}") + } + } + + testBothCpuGpuExpectedException[SparkException](testCaseName + ", cv / cv", + e => e.getMessage.contains("ArithmeticException"), + spark => { + // Period is the external type of year-month type + val (data, schema) = num match { + case b: Byte => (Seq(Row(Period.ofMonths(months), b)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", ByteType) + ))) + case s: Short => (Seq(Row(Period.ofMonths(months), s)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", ShortType) + ))) + case i: Int => (Seq(Row(Period.ofMonths(months), i)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", IntegerType) + ))) + case l: Long => + (Seq(Row(Period.ofMonths(months), l)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", LongType) + ))) + case f: Float => (Seq(Row(Period.ofMonths(months), f)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", FloatType) + ))) + case d: Double => (Seq(Row(Period.ofMonths(months), d)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", DoubleType) + ))) + case _ => + fail(s"Unsupported num type ${num.getClass}") + } + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => 
df.selectExpr(s"c1 / c2") + } + + testBothCpuGpuExpectedException[SparkException](testCaseName + ", scalar / cv", + e => e.getMessage.contains("ArithmeticException"), + spark => { + val (data, schema) = num match { + case b: Byte => (Seq(Row(b)), StructType(Seq(StructField("c1", ByteType)))) + case s: Short => (Seq(Row(s)), StructType(Seq(StructField("c1", ShortType)))) + case i: Int => (Seq(Row(i)), StructType(Seq(StructField("c1", IntegerType)))) + case l: Long => (Seq(Row(l)), StructType(Seq(StructField("c1", LongType)))) + case f: Float => (Seq(Row(f)), StructType(Seq(StructField("c1", FloatType)))) + case d: Double => (Seq(Row(d)), StructType(Seq(StructField("c1", DoubleType)))) + case _ => fail(s"Unsupported num type ${num.getClass}") + } + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr(s"interval '$months' month / c1") + } + } + + // divide by 0 + testDivideYMOverflow("year-month divide 0.toByte", 1, 0.toByte) + testDivideYMOverflow("year-month divide 0.toShort", 1, 0.toShort) + testDivideYMOverflow("year-month divide 0", 1, 0) + testDivideYMOverflow("year-month divide 0L", 1, 0.toLong) + testDivideYMOverflow("year-month divide 0.0f", 1, 0.0f) + testDivideYMOverflow("year-month divide 0.0d", 1, 0.0d) + testDivideYMOverflow("year-month divide -0.0f", 1, -0.0f) + testDivideYMOverflow("year-month divide -0.0d", 1, -0.0d) + + // NaN + testDivideYMOverflow("year-month divide Float.NaN", 1, Float.NaN) + testDivideYMOverflow("year-month divide Double.NaN", 1, Double.NaN) + // 0.0 / 0.0 = NaN + testDivideYMOverflow("year-month 0 divide 0.0f", 0, 0.0f) + testDivideYMOverflow("year-month 0 divide 0.0d", 0, 0.0d) + + // divide float/double overflow + testDivideYMOverflow("year-month divide overflow 1", Int.MaxValue, 0.1f) + testDivideYMOverflow("year-month divide overflow 2", Int.MaxValue, 0.1d) + testDivideYMOverflow("year-month divide overflow 3", Int.MinValue, 0.1f) + testDivideYMOverflow("year-month divide overflow 4", Int.MinValue, 0.1d) + + testDivideYMOverflow("year-month divide overflow 5", Int.MinValue, -1) +}