diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index 561005da32c..9b494c7f575 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -430,6 +430,31 @@ def test_decimal_round(data_gen): 'round(a, 10)'), conf=allow_negative_scale_of_decimal_conf) +@incompat +@approximate_float +def test_non_decimal_round_overflow(): + gen = StructGen([('byte_c', byte_gen), ('short_c', short_gen), + ('int_c', int_gen), ('long_c', long_gen), + ('float_c', float_gen), ('double_c', double_gen)], nullable=False) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, gen).selectExpr( + 'round(byte_c, -2)', 'round(byte_c, -3)', + 'round(short_c, -4)', 'round(short_c, -5)', + 'round(int_c, -9)', 'round(int_c, -10)', + 'round(long_c, -19)', 'round(long_c, -20)', + 'round(float_c, -38)', 'round(float_c, -39)', + 'round(float_c, 38)', 'round(float_c, 39)', + 'round(double_c, -308)', 'round(double_c, -309)', + 'round(double_c, 308)', 'round(double_c, 309)', + 'bround(byte_c, -2)', 'bround(byte_c, -3)', + 'bround(short_c, -4)', 'bround(short_c, -5)', + 'bround(int_c, -9)', 'bround(int_c, -10)', + 'bround(long_c, -19)', 'bround(long_c, -20)', + 'bround(float_c, -38)', 'bround(float_c, -39)', + 'bround(float_c, 38)', 'bround(float_c, 39)', + 'bround(double_c, -308)', 'bround(double_c, -309)', + 'bround(double_c, 308)', 'bround(double_c, 309)')) + @approximate_float @pytest.mark.parametrize('data_gen', double_gens, ids=idfn) def test_cbrt(data_gen): diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index fe32d415981..f84aa67a9e2 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -457,14 
+457,166 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) override def doColumnar(value: GpuColumnVector, scale: GpuScalar): ColumnVector = { + val lhsValue = value.getBase + val scaleVal = scale.getValue.asInstanceOf[Int] + dataType match { case DecimalType.Fixed(_, scaleVal) => DecimalUtil.round(lhsValue, scaleVal, roundMode) - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => - val scaleVal = scale.getValue.asInstanceOf[Int] - lhsValue.round(scaleVal, roundMode) - case _ => throw new IllegalArgumentException(s"Round operator doesn't support $dataType") + case ByteType => + fixUpOverflowInts(() => Scalar.fromByte(0.toByte), scaleVal, lhsValue) + case ShortType => + fixUpOverflowInts(() => Scalar.fromShort(0.toShort), scaleVal, lhsValue) + case IntegerType => + fixUpOverflowInts(() => Scalar.fromInt(0), scaleVal, lhsValue) + case LongType => + fixUpOverflowInts(() => Scalar.fromLong(0L), scaleVal, lhsValue) + case FloatType => + fpZeroReplacement( + () => Scalar.fromFloat(0.0f), + () => Scalar.fromFloat(Float.PositiveInfinity), + () => Scalar.fromFloat(Float.NegativeInfinity), + scaleVal, lhsValue) + case DoubleType => + fpZeroReplacement( + () => Scalar.fromDouble(0.0), + () => Scalar.fromDouble(Double.PositiveInfinity), + () => Scalar.fromDouble(Double.NegativeInfinity), + scaleVal, lhsValue) + case _ => + throw new IllegalArgumentException(s"Round operator doesn't support $dataType") + } + } + + // Fixes up integral values rounded by a scale exceeding/reaching the max digits of data + // type. Under this circumstance, cuDF may produce different results to Spark. + // + // In this method, we handle round overflow, aligning the inconsistent results to Spark. + // + // For scales exceeding max digits, we can simply return zero values. + // + // For scales equaling to max digits, we need to perform round. 
Fortunately, round up + // will NOT occur on the max digits of numeric types except LongType. Therefore, we only + // need to handle round down for most types, by returning zero values. + private def fixUpOverflowInts(zeroFn: () => Scalar, + scale: Int, + lhs: ColumnVector): ColumnVector = { + // Rounding on the max digit of long values, which should be specially handled since + // it may need to round up, which will produce inconsistent results because of + // overflow. Otherwise, we only need to handle round down situations. + if (-scale == 19 && lhs.getType == DType.INT64) { + fixUpInt64OnBounds(lhs) + } else if (-scale >= DecimalUtil.getPrecisionForIntegralType(lhs.getType)) { + withResource(zeroFn()) { s => + withResource(ColumnVector.fromScalar(s, lhs.getRowCount.toInt)) { zero => + // set null mask if necessary + if (lhs.hasNulls) { + zero.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhs) + } else { + zero.incRefCount() + } + } + } + } else { + lhs.round(scale, roundMode) + } + } + + // Compared to other non-decimal numeric types, Int64(LongType) is a bit special in terms of + // rounding by the max digit. Because the bound values of LongType can be rounded up, while + // other numeric types can only be rounded down: + // + // the max value of Byte: 127 + // The first digit is up to 1, which can't be rounded up. + // the max value of Short: 32767 + // The first digit is up to 3, which can't be rounded up. + // the max value of Int32: 2147483647 + // The first digit is up to 2, which can't be rounded up. + // the max value of Float32: 3.4028235E38 + // The first digit is up to 3, which can't be rounded up. + // the max value of Float64: 1.7976931348623157E308 + // The first digit is up to 1, which can't be rounded up. + // the max value of Int64: 9223372036854775807 + // The first digit is up to 9, which can be rounded up. + // + // When rounding up 19-digit long values on the first digit, the result can be 1e19 or -1e19. 
+ // Since LongType cannot hold these two values, the 1e19 overflows as -8446744073709551616L, + // and the -1e19 overflows as 8446744073709551616L. The overflow happens in the same way for + // HALF_UP (round) and HALF_EVEN (bround). + private def fixUpInt64OnBounds(lhs: ColumnVector): ColumnVector = { + // Builds predicates on whether there is a round up on the max digit or not + val litForCmp = Seq(Scalar.fromLong(1000000000000000000L), + Scalar.fromLong(4L), + Scalar.fromLong(-4L)) + val (needRep, needNegRep) = withResource(litForCmp) { case Seq(base, four, minusFour) => + withResource(lhs.div(base)) { headDigit => + closeOnExcept(headDigit.greaterThan(four)) { posRep => + closeOnExcept(headDigit.lessThan(minusFour)) { negRep => + posRep -> negRep + } + } + } + } + // Replaces with corresponding literals + val litForRep = Seq(Scalar.fromLong(0L), + Scalar.fromLong(8446744073709551616L), + Scalar.fromLong(-8446744073709551616L)) + val repVal = withResource(litForRep) { case Seq(zero, upLit, negUpLit) => + withResource(needRep) { _ => + withResource(needNegRep) { _ => + withResource(needNegRep.ifElse(upLit, zero)) { negBranch => + needRep.ifElse(negUpLit, negBranch) + } + } + } + } + // Handles null values + withResource(repVal) { _ => + if (lhs.hasNulls) { + repVal.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhs) + } else { + repVal.incRefCount() + } + } + } + + // Fixes up floating-point values rounded by a scale exceeding the max digits of data type. Under this + // circumstance, cuDF produces different results from Spark. + // Compared to integral values, fixing up round overflow of floating-point values needs to take care + // of some special values: nan, inf, -inf. 
+ def fpZeroReplacement(zeroFn: () => Scalar, + infFn: () => Scalar, + negInfFn: () => Scalar, + scale: Int, + lhs: ColumnVector): ColumnVector = { + val maxDigits = if (dataType == FloatType) 39 else 309 + if (-scale >= maxDigits) { + // replaces common values (!Null AND !Nan AND !Inf AND !-Inf) with zero, while keeping + // all the special values unchanged + withResource(Seq(zeroFn(), infFn(), negInfFn())) { case Seq(zero, inf, negInf) => + // builds joined predicate: !Null AND !Nan AND !Inf AND !-Inf + val joinedPredicate = { + val conditions = Seq(() => lhs.isNotNan, + () => lhs.notEqualTo(inf), + () => lhs.notEqualTo(negInf)) + conditions.foldLeft(lhs.isNotNull) { case (buffer, builder) => + withResource(buffer) { _ => + withResource(builder()) { predicate => + buffer.and(predicate) + } + } + } + } + withResource(joinedPredicate) { cond => + cond.ifElse(zero, lhs) + } + } + } else if (scale >= maxDigits) { + // just returns the original values + lhs.incRefCount() + } else { + lhs.round(scale, roundMode) } }