From 105af3a198673524a7788b0a4f9733c23f4a7ca9 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Mon, 28 Mar 2022 12:01:18 +0800 Subject: [PATCH 1/8] Support Division on ANSI interval types Signed-off-by: Chong Gao --- .../src/main/python/arithmetic_ops_test.py | 54 ++ .../rapids/shims/Spark320PlusShims.scala | 2 +- .../shims/ShimVectorizedColumnReader.scala | 4 +- .../spark/rapids/shims/GpuTypeShims.scala | 8 +- .../spark/rapids/shims/RapidsErrorUtils.scala | 6 +- .../spark/rapids/shims/Spark33XShims.scala | 37 +- .../rapids/shims/intervalExpressions.scala | 486 ++++++++++++++++++ .../sql/rapids/collectionOperations.scala | 3 +- .../sql/rapids/complexTypeExtractors.scala | 2 +- .../spark/rapids/IntervalDivisionSuite.scala | 284 ++++++++++ .../rapids/SparkQueryCompareTestSuite.scala | 99 ++++ 11 files changed, 976 insertions(+), 9 deletions(-) create mode 100644 sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala create mode 100644 tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index dd6637bedc9..324aa1becf1 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -867,3 +867,57 @@ def test_subtraction_overflow_with_ansi_enabled(data, tp, expr): assert_gpu_and_cpu_are_equal_collect( func=lambda spark: _get_overflow_df(spark, data, tp, expr), conf=ansi_enabled_conf) + +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=0, max_exp=5, special_cases=[])], ids=idfn) +def test_day_time_interval_division_number_no_overflow1(data_gen): + gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-5000 * 365 * 86400), max_value=timedelta(seconds=5000 * 365 * 86400))), + ('_c2', data_gen)] + assert_gpu_and_cpu_are_equal_collect( + # avoid dividing by 0 + lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type)))) + +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=-5, max_exp=0, special_cases=[])], ids=idfn) +def test_day_time_interval_division_number_no_overflow2(data_gen): + gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-20 * 86400), max_value=timedelta(seconds=20 * 86400))), + ('_c2', data_gen)] + assert_gpu_and_cpu_are_equal_collect( + # avoid dividing by 0 + lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type)))) + +def _get_overflow_df_2cols(spark, data_types, values, expr): + return spark.createDataFrame( + SparkContext.getOrCreate().parallelize([values]), + StructType([ + StructField('a', data_types[0]), + StructField('b', data_types[1]) + ]) + ).selectExpr(expr) + +# test interval division overflow, such as interval / 0, Long.MinValue / -1 ... 
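+# Both the CPU and GPU runs below are expected to fail, with 'ArithmeticException'
+# appearing in the error message (see error_message in assert_gpu_and_cpu_error).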
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +@pytest.mark.parametrize('data_type,value_pair', [ + (ByteType(), [timedelta(seconds=1), 0]), + (ShortType(), [timedelta(seconds=1), 0]), + (IntegerType(), [timedelta(seconds=1), 0]), + (LongType(), [timedelta(seconds=1), 0]), + (FloatType(), [timedelta(seconds=1), 0.0]), + (FloatType(), [timedelta(seconds=1), -0.0]), + (DoubleType(), [timedelta(seconds=1), 0.0]), + (DoubleType(), [timedelta(seconds=1), -0.0]), + (FloatType(), [timedelta(seconds=1), float('NaN')]), + (DoubleType(), [timedelta(seconds=1), float('NaN')]), + (FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]), + (DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]), + (FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]), + (DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]), + (LongType(), [MIN_DAY_TIME_INTERVAL, -1]), + (FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN + (DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN +], ids=idfn) +def test_day_time_interval_division_overflow(data_type, value_pair): + assert_gpu_and_cpu_error( + df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(), + conf={}, + error_message='ArithmeticException') diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 10eb99692e6..dce845f400c 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -250,7 +250,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { ExprChecks.projectAndAst( TypeSig.astTypes, (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.CALENDAR - + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT + TypeSig.DAYTIME) + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT + TypeSig.YEARMONTH + TypeSig.DAYTIME) .nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT), TypeSig.all), diff --git a/sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala b/sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala index dcfbd319744..110680082e9 100644 --- a/sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala +++ b/sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala @@ -64,4 +64,6 @@ class ShimVectorizedColumnReader( datetimeRebaseMode, TimeZone.getDefault.getID, // use default zone because of no rebase int96RebaseMode, - TimeZone.getDefault.getID) // use default zone because of will throw exception if rebase + TimeZone.getDefault.getID, + // TODO: wait https://github.com/NVIDIA/spark-rapids/pull/5124 to merge + null) // use default zone because of will throw exception if rebase diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala index 8ddd8d58af4..b72de8ad800 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala +++ 
b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala @@ -135,6 +135,7 @@ object GpuTypeShims { */ def supportToScalarForType(t: DataType): Boolean = { t match { + case _: YearMonthIntervalType => true case _: DayTimeIntervalType => true case _ => false } @@ -143,8 +144,13 @@ object GpuTypeShims { /** * Convert the given value to Scalar */ - def toScalarForType(t: DataType, v: Any) = { + def toScalarForType(t: DataType, v: Any): Scalar = { t match { + case _: YearMonthIntervalType => v match { + case i: Int => Scalar.fromInt(i) + case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" + + s" for IntType, expecting int") + } case _: DayTimeIntervalType => v match { case l: Long => Scalar.fromLong(l) case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" + diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala index 76b0c40182b..69107c96840 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala @@ -28,9 +28,11 @@ object RapidsErrorUtils { } } - def mapKeyNotExistError(key: String, isElementAtF: Boolean = false): NoSuchElementException = { + // TODO: wait https://github.com/NVIDIA/spark-rapids/pull/5133 to fix + def mapKeyNotExistError(key: String, context: String, + isElementAtF: Boolean = false): NoSuchElementException = { // For now, the default argument is false. The caller sets the correct value accordingly. - QueryExecutionErrors.mapKeyNotExistError(key, isElementAtF) + QueryExecutionErrors.mapKeyNotExistError(key, isElementAtF, context) } def sqlArrayIndexNotStartAtOneError(): ArrayIndexOutOfBoundsException = { diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala index 2aa561889db..0221db967aa 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids._ import org.apache.spark.sql.rapids.execution.GpuShuffleMeta -import org.apache.spark.sql.rapids.shims.GpuTimeAdd +import org.apache.spark.sql.rapids.shims.{GpuDivideDTInterval, GpuDivideYMInterval, GpuTimeAdd} import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, DecimalType, StructType} import org.apache.spark.unsafe.types.CalendarInterval @@ -312,6 +312,39 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims { (a, conf, p, r) => new BinaryAstExprMeta[LessThanOrEqual](a, conf, p, r) { override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = GpuLessThanOrEqual(lhs, rhs) + }), + GpuOverrides.expr[Alias]( + "Gives a column a name", + ExprChecks.unaryProjectAndAstInputMatchesOutput( + TypeSig.astTypes, + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT + + TypeSig.DECIMAL_128 + TypeSig.YEARMONTH + TypeSig.DAYTIME).nested(), + TypeSig.all), + (a, conf, p, r) => new UnaryAstExprMeta[Alias](a, conf, p, r) { + override def convertToGpu(child: Expression): GpuExpression = + GpuAlias(child, 
a.name)(a.exprId, a.qualifier, a.explicitMetadata) + }), + GpuOverrides.expr[DivideYMInterval]( + "Year-month interval * operator", + ExprChecks.binaryProject( + TypeSig.YEARMONTH, + TypeSig.YEARMONTH, + ("lhs", TypeSig.YEARMONTH, TypeSig.YEARMONTH), + ("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)), + (a, conf, p, r) => new BinaryExprMeta[DivideYMInterval](a, conf, p, r) { + override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = + GpuDivideYMInterval(lhs, rhs) + }), + GpuOverrides.expr[DivideDTInterval]( + "Day-time interval * operator", + ExprChecks.binaryProject( + TypeSig.DAYTIME, + TypeSig.DAYTIME, + ("lhs", TypeSig.DAYTIME, TypeSig.DAYTIME), + ("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)), + (a, conf, p, r) => new BinaryExprMeta[DivideDTInterval](a, conf, p, r) { + override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = + GpuDivideDTInterval(lhs, rhs) }) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap super.getExprs ++ map @@ -456,7 +489,7 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims { "The backend for most select, withColumn and dropColumn statements", ExecChecks( (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + - TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(), + TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.YEARMONTH + TypeSig.DAYTIME).nested(), TypeSig.all), (proj, conf, p, r) => new GpuProjectExecMeta(proj, conf, p, r)), GpuOverrides.exec[FilterExec]( diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala new file mode 100644 index 00000000000..8ebccd919b7 --- /dev/null +++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala @@ -0,0 +1,486 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids.shims + +import ai.rapids.cudf.{BinaryOperable, ColumnVector, DType, RoundMode, Scalar} +import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar} + +import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant} +import org.apache.spark.sql.types._ + +object IntervalUtils extends Arm { + + /** + * throws exception if any value in `longCv` exceeds the int limits + */ + def toIntWithCheck(longCv: ColumnVector): ColumnVector = { + withResource(longCv.castTo(DType.INT32)) { intResult => + withResource(longCv.notEqualTo(intResult)) { notEquals => + if (BoolUtils.isAnyValidTrue(notEquals)) { + throw new ArithmeticException("overflow occurs") + } else { + intResult.incRefCount() + } + } + } + } + + def getLong(s: Scalar): Long = { + s.getType match { + case DType.INT8 => s.getByte.toLong + case DType.INT16 => s.getShort.toLong + case DType.INT32 => s.getInt.toLong + case DType.INT64 => s.getLong + case _ => throw new IllegalArgumentException() + } + } + + /** + * check infinity, -infinity and NaN for double cv + */ + def checkDoubleInfNan(doubleCv: ColumnVector): Unit = { + // check infinity + withResource(Scalar.fromDouble(Double.PositiveInfinity)) { positiveInfScalar => + withResource(doubleCv.equalTo(positiveInfScalar)) { equalsInfinity => + if (BoolUtils.isAnyValidTrue(equalsInfinity)) { + throw new ArithmeticException("Has infinity") + } + } + } + + // check -infinity + withResource(Scalar.fromDouble(Double.NegativeInfinity)) { negativeInfScalar => + withResource(doubleCv.equalTo(negativeInfScalar)) { equalsInfinity => + if (BoolUtils.isAnyValidTrue(equalsInfinity)) { + throw new ArithmeticException("Has -infinity") + } + } + } + + // check NaN + withResource(doubleCv.isNan) { isNan => + if (BoolUtils.isAnyValidTrue(isNan)) { + throw new ArithmeticException("Has NaN") + } + } + } + + /** + * round double cv to int with overflow check + * equivalent to + * com.google.common.math.DoubleMath.roundToInt(double value, RoundingMode.HALF_UP) + */ + def roundToIntWithOverflowCheck(doubleCv: ColumnVector): ColumnVector = { + // check Inf -Inf NaN + checkDoubleInfNan(doubleCv) + + withResource(doubleCv.round(RoundMode.HALF_UP)) { roundedDouble => + // throws exception if the result exceeds int limits + withResource(roundedDouble.castTo(DType.INT64)) { long => + toIntWithCheck(long) + } + } + } + + /** + * round double cv to long with overflow check + * equivalent to + * com.google.common.math.DoubleMath.roundToLong(double value, RoundingMode.HALF_UP) + */ + def roundToLongWithOverflowCheck(doubleCv: ColumnVector): ColumnVector = { + + // check Inf -Inf NaN + checkDoubleInfNan(doubleCv) + + roundToLongWithCheck(doubleCv) + } + + /** + * check if double Cv exceeds long limits + * Rewrite from + * com.google.common.math.DoubleMath.roundToLong: + * z = roundIntermediate(double value, mode) + * checkInRange(MIN_LONG_AS_DOUBLE - z < 1.0 & z < MAX_LONG_AS_DOUBLE_PLUS_ONE) + * return z.toLong + */ + def roundToLongWithCheck(doubleCv: ColumnVector): ColumnVector = { + val MIN_LONG_AS_DOUBLE: Double = -9.223372036854776E18 + val MAX_LONG_AS_DOUBLE_PLUS_ONE: Double = 9.223372036854776E18 + + withResource(doubleCv.round(RoundMode.HALF_UP)) { z => + withResource(Scalar.fromDouble(MAX_LONG_AS_DOUBLE_PLUS_ONE)) { max => + withResource(z.greaterOrEqualTo(max)) { invalid => + if (BoolUtils.isAnyValidTrue(invalid)) { + throw new ArithmeticException("Round double to long overflow") + } + } + } + + 
withResource(Scalar.fromDouble(MIN_LONG_AS_DOUBLE)) { min => + withResource(min.sub(z)) { diff => + withResource(Scalar.fromDouble(1.0d)) { one => + withResource(diff.greaterOrEqualTo(one)) { invalid => + if (BoolUtils.isAnyValidTrue(invalid)) { + throw new ArithmeticException("Round double to long overflow") + } + } + } + } + } + + z.castTo(DType.INT64) + } + } + + /** + * Returns the result of dividing p by q, rounding using the HALF_UP RoundingMode. + * + * It's equivalent to + * com.google.common.math.LongMath.divide(p, q, RoundingMode.HALF_UP): + * long div = p / q; // throws if q == 0 + * long rem = p - q * div; // equals p % q + * // signum is 1 if p and q are both nonnegative or both negative, and -1 otherwise. + * int signum = 1 | (int) ((p ^ q) >> (Long.SIZE - 1)); + * long absRem = abs(rem); + * long cmpRemToHalfDivisor = absRem - (abs(q) - absRem); + * increment = cmpRemToHalfDivisor >= 0 + * return increment ? div + signum : div; + * + * Note a different from LongMath.divide: + * LongMath.divide(Long.MIN_VALUE, -1L) = Long.MIN_VALUE, it's incorrect, + * Negative / Negative should be positive. + * Should throw when Long.MIN_VALUE / -1L + * + * @param p long cv or scalar + * @param q long cv or Scalar + * @return q / q with HALF_UP rounding mode + */ + def div(p: BinaryOperable, q: BinaryOperable): ColumnVector = { + // check (p == Long.MIN_VALUE && q == -1) + withResource(Scalar.fromLong(Long.MinValue)) { minScalar => + withResource(Scalar.fromLong(-1L)) { negOneScalar => + (p, q) match { + case (lCv: ColumnVector, rCv: ColumnVector) => + withResource(lCv.equalTo(minScalar)) { isMin => + withResource(rCv.equalTo(negOneScalar)) { isOne => + if (BoolUtils.isAnyValidTrue(isMin) && BoolUtils.isAnyValidTrue(isOne)) { + throw new ArithmeticException("overflow occurs") + } + } + } + case (lCv: ColumnVector, rS: Scalar) => + withResource(lCv.equalTo(minScalar)) { isMin => + if (rS.getLong == -1L && BoolUtils.isAnyValidTrue(isMin)) { + throw new ArithmeticException("overflow occurs") + } + } + case (lS: Scalar, rCv: ColumnVector) => + withResource(rCv.equalTo(negOneScalar)) { isOne => + if (lS.getLong == Long.MinValue && BoolUtils.isAnyValidTrue(isOne)) { + throw new ArithmeticException("overflow occurs") + } + } + case (lS: Scalar, rS: Scalar) => + lS.getLong == Long.MinValue && rS.getLong == -1L + } + } + } + + val absRemCv = withResource(p.mod(q)) { rem => + rem.abs() + } + + // cmpRemToHalfDivisor = absRem - (abs(q) - absRem); + // increment = (2rem >= q), 2rem == q means exactly on the half mark + // subtracting two nonnegative longs can't overflow, + // cmpRemToHalfDivisor has the same sign as compare(abs(rem), abs(q) / 2). + // Note: can handle special case is: 1L / Long.MinValue + val cmpRemToHalfDivisorCv = withResource(absRemCv) { absRem => + val absQCv = q match { + case qCv: ColumnVector => qCv.abs + case qS: Scalar => Scalar.fromLong(math.abs(IntervalUtils.getLong(qS))) + } + withResource(absQCv) { absQ => + withResource(absQ.sub(absRem)) { qSubRem => + absRem.sub(qSubRem) + } + } + } + + val incrementCv = withResource(cmpRemToHalfDivisorCv) { cmpRemToHalfDivisor => + withResource(Scalar.fromLong(0L)) { zero => + cmpRemToHalfDivisor.greaterOrEqualTo(zero) + } + } + + withResource(incrementCv) { increment => + // signum is 1 if p and q are both nonnegative or both negative, and -1 otherwise. 
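+      // ((p ^ q) >> 63) is 0 when p and q have the same sign and all ones (-1) otherwise,
+      // so OR-ing it with 1 yields +1 or -1 without branching.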
+ // val signum = 1 | ((p ^ q) >> (Long.SIZE - 1)).toInt + val signNumCv = withResource(p.bitXor(q)) { pXorQ => + withResource(Scalar.fromInt(63)) { shift => + withResource(pXorQ.shiftRight(shift)) { signNum => + withResource(Scalar.fromLong(1L)) { one => + one.bitOr(signNum) + } + } + } + } + + // round the result + withResource(signNumCv) { signNum => + withResource(p.div(q)) { div => + withResource(div.add(signNum)) { incremented => + increment.ifElse(incremented, div) + } + } + } + } + } + + def hasZero(num: BinaryOperable, dataType: DataType): Boolean = { + dataType match { + case ByteType | ShortType | IntegerType | LongType => + num match { + case s: Scalar => IntervalUtils.getLong(s) == 0L + case cv: ColumnVector => + withResource(Scalar.fromInt(0)) { zero => + withResource(cv.equalTo(zero)) { equals => + BoolUtils.isAnyValidTrue(equals) + } + } + } + case FloatType => + num match { + case s: Scalar => s.getFloat == 0.0f + case cv: ColumnVector => + withResource(Scalar.fromFloat(0.0f)) { zero => + withResource(cv.equalTo(zero)) { equals => + BoolUtils.isAnyValidTrue(equals) + } + } + } + case DoubleType => + num match { + case s: Scalar => s.getFloat == 0.0d + case cv: ColumnVector => + withResource(Scalar.fromDouble(0.0d)) { zero => + withResource(cv.equalTo(zero)) { equals => + BoolUtils.isAnyValidTrue(equals) + } + } + } + case other => throw new IllegalStateException(s"unexpected type $other") + } + } +} + +/** + * Divide a day-time interval by a numeric: + * year-month interval / number(byte, short, int, long, float, double) + * Note not support year-month interval / decimal + * Year-month interval's internal type is int, the value of int is 12 * year + month + * left expression is interval, right expression is number + * Rewrite from Spark code: + * https://github.com/apache/spark/blob/v3.2.1/sql/catalyst/src/main/scala/ + * org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala#L615 + * + */ +case class GpuDivideYMInterval( + interval: Expression, + num: Expression) extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant { + + override def left: Expression = interval + + override def right: Expression = num + + override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = { + doColumnar(interval.getBase, numScalar.getBase, num.dataType) + } + + override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = { + doColumnar(interval.getBase, num.getBase, num.dataType) + } + + override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = { + doColumnar(intervalScalar.getBase, num.getBase, num.dataType) + } + + override def doColumnar(numRows: Int, intervalScalar: GpuScalar, + numScalar: GpuScalar): ColumnVector = { + withResource(GpuColumnVector.from(intervalScalar, numRows, interval.dataType)) { expandedLhs => + doColumnar(expandedLhs, numScalar) + } + } + + private def doColumnar(interval: BinaryOperable, numOperable: BinaryOperable, + numType: DataType): ColumnVector = { + if (IntervalUtils.hasZero(numOperable, numType)) { + throw new ArithmeticException("overflow: interval / 0") + } + + numType match { + case ByteType | ShortType | IntegerType | LongType => // num is byte, short, int or long + // interval.toLong / num.toLong + val longRetCv = (interval, numOperable) match { + case (pCv: ColumnVector, qCv: ColumnVector) => + withResource(pCv.castTo(DType.INT64)) { p => + withResource(qCv.castTo(DType.INT64)) { q => + IntervalUtils.div(p, q) + } + } + case (pCv: 
ColumnVector, s: Scalar) => + withResource(pCv.castTo(DType.INT64)) { p => + withResource(Scalar.fromLong(IntervalUtils.getLong(s))) { q => + IntervalUtils.div(p, q) + } + } + case (s: Scalar, qCv: ColumnVector) => + withResource(Scalar.fromLong(IntervalUtils.getLong(s))) { p => + withResource(qCv.castTo(DType.INT64)) { q => + IntervalUtils.div(p, q) + } + } + } + + withResource(longRetCv) { longRet => + // check overflow + IntervalUtils.toIntWithCheck(longRet) + } + + case FloatType | DoubleType => // num is float or double + // compute interval.asDouble / num + val doubleCv = interval match { + case intervalVector: ColumnVector => + withResource(intervalVector.castTo(DType.FLOAT64)) { intervalD => + intervalD.div(numOperable) + } + case intervalScalar: Scalar => + withResource(Scalar.fromDouble(intervalScalar.getInt.toDouble)) { intervalD => + intervalD.div(numOperable) + } + } + + withResource(doubleCv) { double => + // check overflow, then round to int + IntervalUtils.roundToIntWithOverflowCheck(double) + } + case _ => throw new IllegalArgumentException( + s"Not support num type $numType in GpuDivideYMInterval") + } + } + + override def toString: String = s"$interval / $num" + + override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType) + + override def dataType: DataType = YearMonthIntervalType() +} + +/** + * Divide a day-time interval by a numeric + * day-time interval / number(byte, short, int, long, float, double) + * Note not support day-time interval / decimal + * Day-time interval's interval type is long, the value of long is the total microseconds + * Rewrite from Spark code: + * https://github.com/apache/spark/blob/v3.2.1/sql/catalyst/src/main/scala/ + * org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala#L693 + */ +case class GpuDivideDTInterval( + interval: Expression, + num: Expression) + extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant { + + override def left: Expression = interval + + override def right: Expression = num + + override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = { + doColumnar(interval.getBase, numScalar.getBase, num.dataType) + } + + override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = { + doColumnar(interval.getBase, num.getBase, num.dataType) + } + + override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = { + doColumnar(intervalScalar.getBase, num.getBase, num.dataType) + } + + override def doColumnar(numRows: Int, intervalScalar: GpuScalar, + numScalar: GpuScalar): ColumnVector = { + withResource(GpuColumnVector.from(intervalScalar, numRows, interval.dataType)) { expandedLhs => + doColumnar(expandedLhs, numScalar) + } + } + + private def doColumnar(interval: BinaryOperable, numOperable: BinaryOperable, + numType: DataType): ColumnVector = { + if (IntervalUtils.hasZero(numOperable, numType)) { + throw new ArithmeticException("overflow: interval / 0") + } + + numType match { + case ByteType | ShortType | IntegerType | LongType => // num is byte, short, int or long + (interval, numOperable) match { + case (pCv: ColumnVector, qCv: ColumnVector) => + withResource(pCv.castTo(DType.INT64)) { p => + withResource(qCv.castTo(DType.INT64)) { q => + IntervalUtils.div(p, q) + } + } + case (pCv: ColumnVector, s: Scalar) => + withResource(pCv.castTo(DType.INT64)) { p => + withResource(Scalar.fromLong(IntervalUtils.getLong(s))) { q => + IntervalUtils.div(p, q) + } + } + case (s: Scalar, qCv: 
ColumnVector) => + withResource(Scalar.fromLong(IntervalUtils.getLong(s))) { p => + withResource(qCv.castTo(DType.INT64)) { q => + IntervalUtils.div(p, q) + } + } + } + case FloatType | DoubleType => // num is float or double + // compute interval.asDouble / num + val doubleCv = interval match { + case intervalVector: ColumnVector => + withResource(intervalVector.castTo(DType.FLOAT64)) { intervalD => + intervalD.div(numOperable) + } + case intervalScalar: Scalar => + withResource(Scalar.fromDouble(intervalScalar.getLong.toDouble)) { intervalD => + intervalD.div(numOperable) + } + } + + withResource(doubleCv) { double => + // check overflow, then round to long + IntervalUtils.roundToLongWithOverflowCheck(double) + } + case _ => throw new IllegalArgumentException( + s"Not support num type $numType in GpuDivideDTInterval") + } + } + + override def toString: String = s"$interval / $num" + + override def inputTypes: Seq[AbstractDataType] = Seq(DayTimeIntervalType, NumericType) + + override def dataType: DataType = DayTimeIntervalType() + +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index fdd1d798e23..9ddf565bcd9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -231,7 +231,8 @@ case class GpuElementAt(left: Expression, right: Expression, failOnError: Boolea if (!exist.isValid || exist.getBoolean) { map.getMapValue(key) } else { - throw RapidsErrorUtils.mapKeyNotExistError(keyS.getValue.toString, true) + throw RapidsErrorUtils.mapKeyNotExistError( + keyS.getValue.toString, origin.context, true) } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 99882b66c27..3bbc2aa3769 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -212,7 +212,7 @@ case class GpuGetMapValue(child: Expression, key: Expression, failOnError: Boole withResource(lhs.getBase.getMapKeyExistence(rhs.getBase)) { keyExistenceColumn => withResource(keyExistenceColumn.all) { exist => if (exist.isValid && !exist.getBoolean) { - throw RapidsErrorUtils.mapKeyNotExistError(rhs.getValue.toString) + throw RapidsErrorUtils.mapKeyNotExistError(rhs.getValue.toString, origin.context) } } } diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala new file mode 100644 index 00000000000..e2832bff838 --- /dev/null +++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala @@ -0,0 +1,284 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import java.time.Period + +import scala.util.Random + +import org.apache.spark.SparkException +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +// TODO Currently Pyspark does not support YearMonthIntervalType, +// should move this to integration tests after Spark supports YearMonthIntervalType +class IntervalDivisionSuite extends SparkQueryCompareTestSuite { + + testSparkResultsAreEqual( + "test year-month interval / num, normal case", + spark => { + val data = Seq( + Row(Period.ofMonths(5), 3.toByte, 2.toShort, 1, 4L, 2.5f, 2.6d), + Row(Period.ofMonths(5), null, null, null, null, null, null), + Row(null, 3.toByte, 3.toShort, 3, 3L, 1.1f, 2.7d), + Row(Period.ofMonths(0), 3.toByte, 3.toShort, 3, 3L, 3.1f, 3.4d), + Row(Period.ofMonths(0), Byte.MinValue, Short.MinValue, Int.MinValue, Long.MinValue, + Float.MinValue, Double.MinValue), + Row(Period.ofMonths(0), + Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, + Float.MaxValue, Double.MaxValue), + Row(Period.ofMonths(7), 4.toByte, 3.toShort, 2, 1L, + Float.NegativeInfinity, Double.NegativeInfinity), + Row(Period.ofMonths(7), 6.toByte, 5.toShort, 4, 3L, + Float.PositiveInfinity, Double.PositiveInfinity), + ) + val schema = StructType(Seq( + StructField("c_ym", YearMonthIntervalType()), + StructField("c_b", ByteType), + StructField("c_s", ShortType), + StructField("c_i", IntegerType), + StructField("c_l", LongType), + StructField("c_f", FloatType), + StructField("c_d", DoubleType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.selectExpr( + "c_ym / c_b", "c_ym / c_s", "c_ym / c_i", "c_ym / c_l", "c_ym / c_f", "c_ym / c_d", + "c_ym / cast('15' as byte)", "c_ym / cast('15' as short)", "c_ym / 15", + "c_ym / 15L", "c_ym / 15.1f", "c_ym / 15.1d", + "interval '55' month / c_b", + "interval '15' month / c_s", + "interval '25' month / c_i", + "interval '15' month / c_l", + "interval '25' month / c_f", + "interval '35' month / c_d") + } + + testSparkResultsAreEqual("test year-month interval / float num, normal case", + spark => { + val numRows = 1024 + val r = new Random(0) + + def getSign: Int = if (r.nextBoolean()) 1 else -1 + + val data = (0 until numRows).map(i => { + Row(Period.ofMonths(getSign * r.nextInt(Int.MaxValue)), + getSign * (r.nextFloat() + 10f), // add 10 to overflow + getSign * (r.nextDouble() + 10d)) + }) + + val schema = StructType(Seq( + StructField("c_ym", YearMonthIntervalType()), + StructField("c_f", FloatType), + StructField("c_d", DoubleType))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr("c_ym / c_f", "c_ym / c_d") + } + + testSparkResultsAreEqual("test year-month interval / int num, normal case", + spark => { + val numRows = 1024 + val r = new Random(0) + + def getSign: Int = if (r.nextBoolean()) 1 else -1 + + val data = (0 until numRows).map(i => { + Row(Period.ofMonths(getSign * r.nextInt(Int.MaxValue)), + getSign * r.nextInt(1024 * 1024) + 1, // add 1 to avoid dividing 0 + (getSign * r.nextInt(1024 * 1024) + 1).toLong) + }) + + val schema = StructType(Seq( + StructField("c_ym", YearMonthIntervalType()), + StructField("c_i", IntegerType), + StructField("c_l", LongType))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr("c_ym / c_i", "c_ym / c_l") + } + + // both gpu and cpu will throw 
ArithmeticException + def testDivideYMOverflow(testCaseName: String, months: Int, num: Number): Unit = { + testBothCpuGpuExpectedException[SparkException](testCaseName + ", cv / scalar", + e => e.getMessage.contains("ArithmeticException"), + spark => { + val data = Seq(Row(Period.ofMonths(months))) + val schema = StructType(Seq( + StructField("c1", YearMonthIntervalType()))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => + if (num.isInstanceOf[Byte]) { + df.selectExpr(s"c1 / cast('$num' as Byte)") + } else if (num.isInstanceOf[Short]) { + df.selectExpr(s"c1 / cast('$num' as Short)") + } else if (num.isInstanceOf[Int]) { + df.selectExpr(s"c1 / $num") + } else if (num.isInstanceOf[Long]) { + df.selectExpr(s"c1 / ${num}L") + } else if (num.isInstanceOf[Float]) { + if (num.equals(Float.NaN)) { + df.selectExpr("c1 / cast('NaN' as float)") + } else if (num.equals(Float.PositiveInfinity)) { + df.selectExpr("c1 / cast('Infinity' as float)") + } + else if (num.equals(Float.NegativeInfinity)) { + df.selectExpr("c1 / cast('-Infinity' as float)") + } else { + df.selectExpr(s"c1 / ${num}f") + } + } else if (num.isInstanceOf[Double]) { + if (num.equals(Double.NaN)) { + df.selectExpr("c1 / cast('NaN' as double)") + } else if (num.equals(Double.PositiveInfinity)) { + df.selectExpr("c1 / cast('Infinity' as double)") + } + else if (num.equals(Double.NegativeInfinity)) { + df.selectExpr("c1 / cast('-Infinity' as double)") + } else { + df.selectExpr(s"c1 / ${num}d") + } + } else { + throw new IllegalStateException("") + } + } + + testBothCpuGpuExpectedException[SparkException](testCaseName + ", cv / cv", + e => e.getMessage.contains("ArithmeticException"), + spark => { + // Period is the external type of year-month type + val (data, schema) = + if (num.isInstanceOf[Byte]) { + (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Byte])), + StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", ByteType) + ))) + } else if (num.isInstanceOf[Short]) { + (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Short])), + StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", ShortType) + ))) + } else if (num.isInstanceOf[Int]) { + (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Int])), + StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", IntegerType) + ))) + } else if (num.isInstanceOf[Long]) { + (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Long])), + StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", LongType) + ))) + } else if (num.isInstanceOf[Float]) { + (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Float])), + StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", FloatType) + ))) + } else if (num.isInstanceOf[Double]) { + (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Double])), + StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", DoubleType) + ))) + } else { + throw new IllegalStateException("") + } + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr(s"c1 / c2") + } + + testBothCpuGpuExpectedException[SparkException](testCaseName + ", scalar / cv", + e => e.getMessage.contains("ArithmeticException"), + spark => { + val (data, schema) = + if (num.isInstanceOf[Byte]) { + (Seq(Row(num)), + StructType(Seq( + StructField("c1", ByteType) + ))) + } else if (num.isInstanceOf[Short]) { + (Seq(Row(num)), + StructType(Seq( + StructField("c1", 
ShortType) + ))) + } else if (num.isInstanceOf[Int]) { + (Seq(Row(num)), + StructType(Seq( + StructField("c1", IntegerType) + ))) + } else if (num.isInstanceOf[Long]) { + (Seq(Row(num)), + StructType(Seq( + StructField("c1", LongType) + ))) + } else if (num.isInstanceOf[Float]) { + (Seq(Row(num)), + StructType(Seq( + StructField("c1", FloatType) + ))) + } else if (num.isInstanceOf[Double]) { + (Seq(Row(num)), + StructType(Seq( + StructField("c1", DoubleType) + ))) + } else { + throw new IllegalStateException("") + } + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr(s"interval '$months' month / c1") + } + } + + // divide by 0 + testDivideYMOverflow("year-month divide 0.toByte", 1, 0.toByte) + testDivideYMOverflow("year-month divide 0.toShort", 1, 0.toShort) + testDivideYMOverflow("year-month divide 0", 1, 0) + testDivideYMOverflow("year-month divide 0L", 1, 0.toLong) + testDivideYMOverflow("year-month divide 0.0f", 1, 0.0f) + testDivideYMOverflow("year-month divide 0.0d", 1, 0.0d) + testDivideYMOverflow("year-month divide -0.0f", 1, -0.0f) + testDivideYMOverflow("year-month divide -0.0d", 1, -0.0d) + + // NaN + testDivideYMOverflow("year-month divide Float.NaN", 1, Float.NaN) + testDivideYMOverflow("year-month divide Double.NaN", 1, Double.NaN) + // 0.0 / 0.0 = NaN + testDivideYMOverflow("year-month 0 divide 0.0f", 0, 0.0f) + testDivideYMOverflow("year-month 0 divide 0.0d", 0, 0.0d) + + // divide float/double overflow + testDivideYMOverflow("year-month divide overflow 1", Int.MaxValue, 0.1f) + testDivideYMOverflow("year-month divide overflow 2", Int.MaxValue, 0.1d) + testDivideYMOverflow("year-month divide overflow 3", Int.MinValue, 0.1f) + testDivideYMOverflow("year-month divide overflow 4", Int.MinValue, 0.1d) +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index 1e8b266dd66..eb6fb07f7bd 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -907,6 +907,105 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { } } + def testBothCpuGpuExpectedException[T <: Throwable]( + testName: String, + expectedException: T => Boolean, + df: SparkSession => DataFrame, + conf: SparkConf = new SparkConf(), + repart: Integer = 1, + sort: Boolean = false, + maxFloatDiff: Double = 0.0, + incompat: Boolean = false, + execsAllowedNonGpu: Seq[String] = Seq.empty, + sortBeforeRepart: Boolean = false) + (fun: DataFrame => DataFrame)(implicit classTag: ClassTag[T]): Unit = { + // test cpu throws an exception + testCpuExpectedException(testName + " ,cpu", expectedException, df, conf, repart, sort, + maxFloatDiff, incompat, execsAllowedNonGpu, sortBeforeRepart)(fun) + + // then test gpu throws an exception + testGpuExpectedException(testName + " ,gpu", expectedException, df, conf, repart, sort, + maxFloatDiff, incompat, execsAllowedNonGpu, sortBeforeRepart)(fun) + } + + def testCpuExpectedException[T <: Throwable]( + testName: String, + expectedException: T => Boolean, + df: SparkSession => DataFrame, + conf: SparkConf = new SparkConf(), + repart: Integer = 1, + sort: Boolean = false, + maxFloatDiff: Double = 0.0, + incompat: Boolean = false, + execsAllowedNonGpu: Seq[String] = Seq.empty, + sortBeforeRepart: Boolean = false) + (fun: DataFrame => DataFrame)(implicit classTag: ClassTag[T]): Unit = { + val 
clazz = classTag.runtimeClass + val (testConf, qualifiedTestName) = + setupTestConfAndQualifierName(testName, incompat, sort, conf, execsAllowedNonGpu, + maxFloatDiff, sortBeforeRepart) + + test(qualifiedTestName) { + val t = Try({ + withCpuSparkSession( session => { + var data = df(session) + if (repart > 0) { + // repartition the data so it is turned into a projection, + // not folded into the table scan exec + data = data.repartition(repart) + } + fun(data).collect() + }, testConf) + fail("Cpu not throw an exception") + }) + t match { + case Failure(e) if clazz.isAssignableFrom(e.getClass) => + assert(expectedException(e.asInstanceOf[T])) + case Failure(e) => throw e + case _ => fail("Expected an exception, but got none") + } + } + } + + def testGpuExpectedException[T <: Throwable]( + testName: String, + expectedException: T => Boolean, + df: SparkSession => DataFrame, + conf: SparkConf = new SparkConf(), + repart: Integer = 1, + sort: Boolean = false, + maxFloatDiff: Double = 0.0, + incompat: Boolean = false, + execsAllowedNonGpu: Seq[String] = Seq.empty, + sortBeforeRepart: Boolean = false) + (fun: DataFrame => DataFrame)(implicit classTag: ClassTag[T]): Unit = { + val clazz = classTag.runtimeClass + val (testConf, qualifiedTestName) = + setupTestConfAndQualifierName(testName, incompat, sort, conf, execsAllowedNonGpu, + maxFloatDiff, sortBeforeRepart) + + test(qualifiedTestName) { + val t = Try({ + withGpuSparkSession( session => { + var data = df(session) + if (repart > 0) { + // repartition the data so it is turned into a projection, + // not folded into the table scan exec + data = data.repartition(repart) + } + fun(data).collect() + }, testConf) + fail("Gpu not throw an exception") + }) + t match { + case Failure(e) if clazz.isAssignableFrom(e.getClass) => + assert(expectedException(e.asInstanceOf[T])) + case Failure(e) => throw e + case _ => fail("Expected an exception, but got none") + } + } + } + def INCOMPAT_testSparkResultsAreEqual2( testName: String, dfA: SparkSession => DataFrame, From 7be1af2f17d232f697f30c1387ded0ae90db591d Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Sat, 2 Apr 2022 15:50:00 +0800 Subject: [PATCH 2/8] Fix code style --- .../scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala index e2832bff838..12c746aa4af 100644 --- a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala +++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala @@ -44,7 +44,7 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { Row(Period.ofMonths(7), 4.toByte, 3.toShort, 2, 1L, Float.NegativeInfinity, Double.NegativeInfinity), Row(Period.ofMonths(7), 6.toByte, 5.toShort, 4, 3L, - Float.PositiveInfinity, Double.PositiveInfinity), + Float.PositiveInfinity, Double.PositiveInfinity) ) val schema = StructType(Seq( StructField("c_ym", YearMonthIntervalType()), From 1185c630fad07b6b171e2d3cc200252f8b84c658 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Mon, 11 Apr 2022 20:30:39 +0800 Subject: [PATCH 3/8] Refine --- .../spark/sql/rapids/shims/intervalExpressions.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala 
b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala index 8ebccd919b7..5079cf9c758 100644 --- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala +++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala @@ -176,9 +176,11 @@ object IntervalUtils extends Arm { (p, q) match { case (lCv: ColumnVector, rCv: ColumnVector) => withResource(lCv.equalTo(minScalar)) { isMin => - withResource(rCv.equalTo(negOneScalar)) { isOne => - if (BoolUtils.isAnyValidTrue(isMin) && BoolUtils.isAnyValidTrue(isOne)) { - throw new ArithmeticException("overflow occurs") + withResource(rCv.equalTo(negOneScalar)) { isNegOne => + withResource(isMin.and(isNegOne)) { invalid => + if (BoolUtils.isAnyValidTrue(invalid)) { + throw new ArithmeticException("overflow occurs") + } } } } From 5a48dffa083332d1364d23327d66cfb932e83467 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Thu, 14 Apr 2022 17:39:19 +0800 Subject: [PATCH 4/8] Add test case and Fix --- .../src/main/python/arithmetic_ops_test.py | 1 + .../sql/rapids/shims/intervalExpressions.scala | 13 +++++++++---- .../nvidia/spark/rapids/IntervalDivisionSuite.scala | 2 ++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index 3d17e69f301..55295c67a8b 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -1003,6 +1003,7 @@ def _get_overflow_df_2cols(spark, data_types, values, expr): (LongType(), [MIN_DAY_TIME_INTERVAL, -1]), (FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN (DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN + (IntegerType(), [timedelta(microseconds=LONG_MIN), -1]) ], ids=idfn) def test_day_time_interval_division_overflow(data_type, value_pair): assert_gpu_and_cpu_error( diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala index 2426fecd69d..4d535d7f193 100644 --- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala +++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala @@ -246,8 +246,13 @@ object IntervalUtils extends Arm { throw new ArithmeticException("overflow: interval / zero") } - // 2. overflow check (p == Long.MIN_VALUE && q == -1) - withResource(Scalar.fromLong(Long.MinValue)) { minScalar => + // 2. 
overflow check (p == min(int min or long min) && q == -1) + val min = p.getType match { + case DType.INT32 => Int.MinValue.toLong + case DType.INT64 => Long.MinValue + case _ => throw new IllegalStateException() + } + withResource(Scalar.fromLong(min)) { minScalar => withResource(Scalar.fromLong(-1L)) { negOneScalar => (p, q) match { case (lCv: ColumnVector, rCv: ColumnVector) => @@ -268,12 +273,12 @@ object IntervalUtils extends Arm { } case (lS: Scalar, rCv: ColumnVector) => withResource(rCv.equalTo(negOneScalar)) { isNegOne => - if (getLong(lS) == Long.MinValue && BoolUtils.isAnyValidTrue(isNegOne)) { + if (getLong(lS) == min && BoolUtils.isAnyValidTrue(isNegOne)) { throw new ArithmeticException("overflow occurs") } } case (lS: Scalar, rS: Scalar) => - getLong(lS) == Long.MinValue && getLong(rS) == -1L + getLong(lS) == min && getLong(rS) == -1L } } } diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala index d81c240a1cc..937514131a5 100644 --- a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala +++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala @@ -282,4 +282,6 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { testDivideYMOverflow("year-month divide overflow 2", Int.MaxValue, 0.1d) testDivideYMOverflow("year-month divide overflow 3", Int.MinValue, 0.1f) testDivideYMOverflow("year-month divide overflow 4", Int.MinValue, 0.1d) + + testDivideYMOverflow("year-month divide overflow 5", Int.MinValue, -1) } From b4c89bd4edf6ceb38bea256463788b3c55b114ce Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Fri, 15 Apr 2022 10:10:58 +0800 Subject: [PATCH 5/8] Add description for exceptions --- .../apache/spark/sql/rapids/shims/intervalExpressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala index 4d535d7f193..2c923ca0791 100644 --- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala +++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala @@ -195,7 +195,7 @@ object IntervalUtils extends Arm { case DType.INT64 => s.getLong.toDouble case DType.FLOAT32 => s.getFloat.toDouble case DType.FLOAT64 => s.getDouble - case _ => throw new IllegalArgumentException() + case t => throw new IllegalArgumentException(s"Unexpected type $t") } } @@ -205,7 +205,7 @@ object IntervalUtils extends Arm { case DType.INT16 => s.getShort.toLong case DType.INT32 => s.getInt.toLong case DType.INT64 => s.getLong - case _ => throw new IllegalArgumentException() + case t => throw new IllegalArgumentException(s"Unexpected type $t") } } @@ -250,7 +250,7 @@ object IntervalUtils extends Arm { val min = p.getType match { case DType.INT32 => Int.MinValue.toLong case DType.INT64 => Long.MinValue - case _ => throw new IllegalStateException() + case t => throw new IllegalArgumentException(s"Unexpected type $t") } withResource(Scalar.fromLong(min)) { minScalar => withResource(Scalar.fromLong(-1L)) { negOneScalar => From 9fb5188e1365f9c107a5094c860b7958ad511b1a Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Mon, 18 Apr 2022 10:47:44 +0800 Subject: [PATCH 6/8] Refactor --- .../spark/rapids/IntervalDivisionSuite.scala | 161 ++++++------------ 1 file changed, 53 
insertions(+), 108 deletions(-) diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala index 937514131a5..84045a0ce40 100644 --- a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala +++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala @@ -118,7 +118,7 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { } // both gpu and cpu will throw ArithmeticException - def testDivideYMOverflow(testCaseName: String, months: Int, num: Number): Unit = { + def testDivideYMOverflow(testCaseName: String, months: Int, num: Any): Unit = { testBothCpuGpuExpectedException[SparkException](testCaseName + ", cv / scalar", e => e.getMessage.contains("ArithmeticException"), spark => { @@ -129,38 +129,20 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { } ) { df => - if (num.isInstanceOf[Byte]) { - df.selectExpr(s"c1 / cast('$num' as Byte)") - } else if (num.isInstanceOf[Short]) { - df.selectExpr(s"c1 / cast('$num' as Short)") - } else if (num.isInstanceOf[Int]) { - df.selectExpr(s"c1 / $num") - } else if (num.isInstanceOf[Long]) { - df.selectExpr(s"c1 / ${num}L") - } else if (num.isInstanceOf[Float]) { - if (num.equals(Float.NaN)) { - df.selectExpr("c1 / cast('NaN' as float)") - } else if (num.equals(Float.PositiveInfinity)) { - df.selectExpr("c1 / cast('Infinity' as float)") - } - else if (num.equals(Float.NegativeInfinity)) { - df.selectExpr("c1 / cast('-Infinity' as float)") - } else { - df.selectExpr(s"c1 / ${num}f") - } - } else if (num.isInstanceOf[Double]) { - if (num.equals(Double.NaN)) { - df.selectExpr("c1 / cast('NaN' as double)") - } else if (num.equals(Double.PositiveInfinity)) { - df.selectExpr("c1 / cast('Infinity' as double)") - } - else if (num.equals(Double.NegativeInfinity)) { - df.selectExpr("c1 / cast('-Infinity' as double)") - } else { - df.selectExpr(s"c1 / ${num}d") - } - } else { - throw new IllegalStateException("") + num match { + case Byte => df.selectExpr(s"c1 / cast('$num' as Byte)") + case Short => df.selectExpr(s"c1 / cast('$num' as Short)") + case Int => df.selectExpr(s"c1 / $num") + case Long => df.selectExpr(s"c1 / ${num}L") + case Float.NaN => df.selectExpr("c1 / cast('NaN' as float)") + case Float.PositiveInfinity => df.selectExpr("c1 / cast('Infinity' as float)") + case Float.NegativeInfinity => df.selectExpr("c1 / cast('-Infinity' as float)") + case Float => df.selectExpr(s"c1 / ${num}f") + case Double.NaN => df.selectExpr("c1 / cast('NaN' as double)") + case Double.PositiveInfinity => df.selectExpr("c1 / cast('Infinity' as double)") + case Double.NegativeInfinity => df.selectExpr("c1 / cast('-Infinity' as double)") + case Double => df.selectExpr(s"c1 / ${num}d") + case _ => fail(s"Unsupported num type ${num.getClass}") } } @@ -168,46 +150,35 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { e => e.getMessage.contains("ArithmeticException"), spark => { // Period is the external type of year-month type - val (data, schema) = - if (num.isInstanceOf[Byte]) { - (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Byte])), - StructType(Seq( - StructField("c1", YearMonthIntervalType()), - StructField("c2", ByteType) - ))) - } else if (num.isInstanceOf[Short]) { - (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Short])), - StructType(Seq( - StructField("c1", YearMonthIntervalType()), - StructField("c2", ShortType) - ))) - } else if (num.isInstanceOf[Int]) { - 
(Seq(Row(Period.ofMonths(months), num.asInstanceOf[Int])), - StructType(Seq( - StructField("c1", YearMonthIntervalType()), - StructField("c2", IntegerType) - ))) - } else if (num.isInstanceOf[Long]) { - (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Long])), - StructType(Seq( - StructField("c1", YearMonthIntervalType()), - StructField("c2", LongType) - ))) - } else if (num.isInstanceOf[Float]) { - (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Float])), - StructType(Seq( - StructField("c1", YearMonthIntervalType()), - StructField("c2", FloatType) - ))) - } else if (num.isInstanceOf[Double]) { - (Seq(Row(Period.ofMonths(months), num.asInstanceOf[Double])), - StructType(Seq( - StructField("c1", YearMonthIntervalType()), - StructField("c2", DoubleType) - ))) - } else { - throw new IllegalStateException("") - } + val (data, schema) = num match { + case Byte => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", ByteType) + ))) + case Short => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", ShortType) + ))) + case Int => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", IntegerType) + ))) + case Long => + (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", LongType) + ))) + case Float => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", FloatType) + ))) + case Double => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", DoubleType) + ))) + case _ => + fail(s"Unsupported num type ${num.getClass}") + } spark.createDataFrame(spark.sparkContext.parallelize(data), schema) } @@ -218,41 +189,15 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { testBothCpuGpuExpectedException[SparkException](testCaseName + ", scalar / cv", e => e.getMessage.contains("ArithmeticException"), spark => { - val (data, schema) = - if (num.isInstanceOf[Byte]) { - (Seq(Row(num)), - StructType(Seq( - StructField("c1", ByteType) - ))) - } else if (num.isInstanceOf[Short]) { - (Seq(Row(num)), - StructType(Seq( - StructField("c1", ShortType) - ))) - } else if (num.isInstanceOf[Int]) { - (Seq(Row(num)), - StructType(Seq( - StructField("c1", IntegerType) - ))) - } else if (num.isInstanceOf[Long]) { - (Seq(Row(num)), - StructType(Seq( - StructField("c1", LongType) - ))) - } else if (num.isInstanceOf[Float]) { - (Seq(Row(num)), - StructType(Seq( - StructField("c1", FloatType) - ))) - } else if (num.isInstanceOf[Double]) { - (Seq(Row(num)), - StructType(Seq( - StructField("c1", DoubleType) - ))) - } else { - throw new IllegalStateException("") - } - + val (data, schema) = num match { + case Byte => (Seq(Row(num)), StructType(Seq(StructField("c1", ByteType)))) + case Short => (Seq(Row(num)), StructType(Seq(StructField("c1", ShortType)))) + case Int => (Seq(Row(num)), StructType(Seq(StructField("c1", IntegerType)))) + case Long => (Seq(Row(num)), StructType(Seq(StructField("c1", LongType)))) + case Float => (Seq(Row(num)), StructType(Seq(StructField("c1", FloatType)))) + case Double => (Seq(Row(num)), StructType(Seq(StructField("c1", DoubleType)))) + case _ => fail(s"Unsupported num type ${num.getClass}") + } spark.createDataFrame(spark.sparkContext.parallelize(data), schema) } ) { 
From dcd34bf6263e9fbb1a7238a330a428385788b910 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Mon, 18 Apr 2022 23:09:00 +0800 Subject: [PATCH 7/8] Fix test code issues --- .../spark/rapids/IntervalDivisionSuite.scala | 54 ++++++++++--------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala index 84045a0ce40..11aab62ae8c 100644 --- a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala +++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala @@ -130,18 +130,22 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { ) { df => num match { - case Byte => df.selectExpr(s"c1 / cast('$num' as Byte)") - case Short => df.selectExpr(s"c1 / cast('$num' as Short)") - case Int => df.selectExpr(s"c1 / $num") - case Long => df.selectExpr(s"c1 / ${num}L") - case Float.NaN => df.selectExpr("c1 / cast('NaN' as float)") - case Float.PositiveInfinity => df.selectExpr("c1 / cast('Infinity' as float)") - case Float.NegativeInfinity => df.selectExpr("c1 / cast('-Infinity' as float)") - case Float => df.selectExpr(s"c1 / ${num}f") - case Double.NaN => df.selectExpr("c1 / cast('NaN' as double)") - case Double.PositiveInfinity => df.selectExpr("c1 / cast('Infinity' as double)") - case Double.NegativeInfinity => df.selectExpr("c1 / cast('-Infinity' as double)") - case Double => df.selectExpr(s"c1 / ${num}d") + case b: Byte => df.selectExpr(s"c1 / cast('$b' as Byte)") + case s: Short => df.selectExpr(s"c1 / cast('$s' as Short)") + case i: Int => df.selectExpr(s"c1 / $i") + case l: Long => df.selectExpr(s"c1 / ${l}L") + case f: Float if f.equals(Float.NaN) => df.selectExpr("c1 / cast('NaN' as float)") + case f: Float if f.equals(Float.PositiveInfinity) => + df.selectExpr("c1 / cast('Infinity' as float)") + case f: Float if f.equals(Float.NegativeInfinity) => + df.selectExpr("c1 / cast('-Infinity' as float)") + case f: Float => df.selectExpr(s"c1 / ${f}f") + case d: Double if d.equals(Double.NaN) => df.selectExpr("c1 / cast('NaN' as double)") + case d: Double if d.equals(Double.PositiveInfinity) => + df.selectExpr("c1 / cast('Infinity' as double)") + case d: Double if d.equals(Double.NegativeInfinity) => + df.selectExpr("c1 / cast('-Infinity' as double)") + case d: Double => df.selectExpr(s"c1 / ${d}d") case _ => fail(s"Unsupported num type ${num.getClass}") } } @@ -151,28 +155,28 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { spark => { // Period is the external type of year-month type val (data, schema) = num match { - case Byte => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + case b: Byte => (Seq(Row(Period.ofMonths(months), b)), StructType(Seq( StructField("c1", YearMonthIntervalType()), StructField("c2", ByteType) ))) - case Short => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + case s: Short => (Seq(Row(Period.ofMonths(months), s)), StructType(Seq( StructField("c1", YearMonthIntervalType()), StructField("c2", ShortType) ))) - case Int => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + case i: Int => (Seq(Row(Period.ofMonths(months), i)), StructType(Seq( StructField("c1", YearMonthIntervalType()), StructField("c2", IntegerType) ))) - case Long => - (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + case l: Long => + (Seq(Row(Period.ofMonths(months), l)), StructType(Seq( StructField("c1", YearMonthIntervalType()), 
StructField("c2", LongType) ))) - case Float => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + case f: Float => (Seq(Row(Period.ofMonths(months), f)), StructType(Seq( StructField("c1", YearMonthIntervalType()), StructField("c2", FloatType) ))) - case Double => (Seq(Row(Period.ofMonths(months), num)), StructType(Seq( + case d: Double => (Seq(Row(Period.ofMonths(months), d)), StructType(Seq( StructField("c1", YearMonthIntervalType()), StructField("c2", DoubleType) ))) @@ -190,12 +194,12 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { e => e.getMessage.contains("ArithmeticException"), spark => { val (data, schema) = num match { - case Byte => (Seq(Row(num)), StructType(Seq(StructField("c1", ByteType)))) - case Short => (Seq(Row(num)), StructType(Seq(StructField("c1", ShortType)))) - case Int => (Seq(Row(num)), StructType(Seq(StructField("c1", IntegerType)))) - case Long => (Seq(Row(num)), StructType(Seq(StructField("c1", LongType)))) - case Float => (Seq(Row(num)), StructType(Seq(StructField("c1", FloatType)))) - case Double => (Seq(Row(num)), StructType(Seq(StructField("c1", DoubleType)))) + case b: Byte => (Seq(Row(b)), StructType(Seq(StructField("c1", ByteType)))) + case s: Short => (Seq(Row(s)), StructType(Seq(StructField("c1", ShortType)))) + case i: Int => (Seq(Row(i)), StructType(Seq(StructField("c1", IntegerType)))) + case l: Long => (Seq(Row(l)), StructType(Seq(StructField("c1", LongType)))) + case f: Float => (Seq(Row(f)), StructType(Seq(StructField("c1", FloatType)))) + case d: Double => (Seq(Row(d)), StructType(Seq(StructField("c1", DoubleType)))) case _ => fail(s"Unsupported num type ${num.getClass}") } spark.createDataFrame(spark.sparkContext.parallelize(data), schema) From 425e9775c679e27c517f68c67738323e6babe0b0 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 19 Apr 2022 09:09:59 +0800 Subject: [PATCH 8/8] Update for nit comments --- .../scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala index 11aab62ae8c..b6426da11a0 100644 --- a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala +++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalDivisionSuite.scala @@ -134,13 +134,13 @@ class IntervalDivisionSuite extends SparkQueryCompareTestSuite { case s: Short => df.selectExpr(s"c1 / cast('$s' as Short)") case i: Int => df.selectExpr(s"c1 / $i") case l: Long => df.selectExpr(s"c1 / ${l}L") - case f: Float if f.equals(Float.NaN) => df.selectExpr("c1 / cast('NaN' as float)") + case f: Float if f.isNaN => df.selectExpr("c1 / cast('NaN' as float)") case f: Float if f.equals(Float.PositiveInfinity) => df.selectExpr("c1 / cast('Infinity' as float)") case f: Float if f.equals(Float.NegativeInfinity) => df.selectExpr("c1 / cast('-Infinity' as float)") case f: Float => df.selectExpr(s"c1 / ${f}f") - case d: Double if d.equals(Double.NaN) => df.selectExpr("c1 / cast('NaN' as double)") + case d: Double if d.isNaN => df.selectExpr("c1 / cast('NaN' as double)") case d: Double if d.equals(Double.PositiveInfinity) => df.selectExpr("c1 / cast('Infinity' as double)") case d: Double if d.equals(Double.NegativeInfinity) =>