From 379bb298ebcb56ddd1b7ad76d746de9b7706d4ad Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 17 Dec 2021 15:49:38 +0800 Subject: [PATCH 1/7] fix cornor cases of non-decimal round Signed-off-by: sperlingxx --- .../src/main/python/arithmetic_ops_test.py | 21 +++++++++++++++++++ .../spark/sql/rapids/mathExpressions.scala | 20 +++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index 561005da32c..8e725d6d7c4 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -430,6 +430,27 @@ def test_decimal_round(data_gen): 'round(a, 10)'), conf=allow_negative_scale_of_decimal_conf) +@approximate_float +def test_round_overflow(): + gen = StructGen([('byte_c', byte_gen), ('short_c', short_gen), ('int_c', int_gen), + ('long_c', long_gen), ('float_c', float_gen), ('double_c', double_gen) + ('dec64_c', decimal_gen_12_2), ('dec128_c', decimal_gen_30_2)], nullable=False) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, gen).selectExpr( + 'round(byte_c, -3)', + 'round(short_c, -5)', + 'round(int_c, -10)', + 'round(long_c, -19)', + 'round(float_c, -39)', 'round(float_c, 39)', + 'round(double_c, -309)', 'round(double_c, 309)', + 'bround(byte_c, -3)', + 'bround(short_c, -5)', + 'bround(int_c, -10)', + 'bround(long_c, -19)', + 'bround(float_c, -39)', 'bround(float_c, 39)', + 'bround(double_c, -309)', 'bround(double_c, 309)'), + conf=allow_negative_scale_of_decimal_conf) + @approximate_float @pytest.mark.parametrize('data_gen', double_gens, ids=idfn) def test_cbrt(data_gen): diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index 0e617da6d7b..d0cd6c2444b 100644 --- 
a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -459,7 +459,25 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin DecimalUtil.round(lhsValue, scaleVal, roundMode) case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => val scaleVal = scale.getValue.asInstanceOf[Int] - lhsValue.round(scaleVal, roundMode) + if (-scaleVal >= DecimalUtil.getPrecisionForIntegralType(lhsValue.getType)) { + withResource(GpuScalar.from(0, dataType)) { zero => + ColumnVector.fromScalar(zero, lhsValue.getRowCount.toInt) + } + } else { + lhsValue.round(scaleVal, roundMode) + } + case FloatType | DoubleType => + val scaleVal = scale.getValue.asInstanceOf[Int] + val maxDigits = if (dataType == FloatType) 39 else 309 + if (-scaleVal >= maxDigits) { + withResource(GpuScalar.from(0, dataType)) { zero => + ColumnVector.fromScalar(zero, lhsValue.getRowCount.toInt) + } + } else if (scaleVal >= maxDigits) { + lhsValue.incRefCount() + } else { + lhsValue.round(scaleVal, roundMode) + } case _ => throw new IllegalArgumentException(s"Round operator doesn't support $dataType") } } From 4f01ebaa51ccf6b1507cfa0060208d1da93425f7 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Mon, 20 Dec 2021 13:47:33 +0800 Subject: [PATCH 2/7] update --- .../src/main/python/arithmetic_ops_test.py | 43 ++++++++------ .../spark/sql/rapids/mathExpressions.scala | 56 +++++++++++++++++-- 2 files changed, 76 insertions(+), 23 deletions(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index 8e725d6d7c4..cc84b08e6ca 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -430,26 +430,33 @@ def test_decimal_round(data_gen): 'round(a, 10)'), conf=allow_negative_scale_of_decimal_conf) +@incompat 
@approximate_float def test_round_overflow(): - gen = StructGen([('byte_c', byte_gen), ('short_c', short_gen), ('int_c', int_gen), - ('long_c', long_gen), ('float_c', float_gen), ('double_c', double_gen) - ('dec64_c', decimal_gen_12_2), ('dec128_c', decimal_gen_30_2)], nullable=False) - assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen).selectExpr( - 'round(byte_c, -3)', - 'round(short_c, -5)', - 'round(int_c, -10)', - 'round(long_c, -19)', - 'round(float_c, -39)', 'round(float_c, 39)', - 'round(double_c, -309)', 'round(double_c, 309)', - 'bround(byte_c, -3)', - 'bround(short_c, -5)', - 'bround(int_c, -10)', - 'bround(long_c, -19)', - 'bround(float_c, -39)', 'bround(float_c, 39)', - 'bround(double_c, -309)', 'bround(double_c, 309)'), - conf=allow_negative_scale_of_decimal_conf) + gen = StructGen([('byte_c', byte_gen), ('short_c', short_gen), + ('int_c', int_gen), ('long_c', long_gen), + ('float_c', float_gen), ('double_c', double_gen), + ('dec32_c', DecimalGen(5, 2)), ('dec64_c', decimal_gen_12_2), + ('dec128_c', decimal_gen_30_2)], nullable=False) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, gen, length=100).selectExpr( + # 'round(byte_c, -3)', + # 'round(short_c, -5)', + # 'round(int_c, -10)', + # 'round(long_c, -20)', + 'round(float_c, -39)', # 'round(float_c, 39)', + # 'round(double_c, -309)', 'round(double_c, 309)', + # 'round(dec32_c, -9)', + # 'round(dec64_c, -19)', + # 'round(dec128_c, -39)', + # 'bround(byte_c, -3)', + # 'bround(short_c, -5)', + # 'bround(int_c, -10)', + # 'bround(long_c, -20)', + # 'bround(float_c, -39)', 'bround(float_c, 39)', + # 'bround(double_c, -309)', 'bround(double_c, 309)' + ), + conf=allow_negative_scale_of_decimal_conf) @approximate_float @pytest.mark.parametrize('data_gen', double_gens, ids=idfn) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index 
d0cd6c2444b..96219b33452 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -454,14 +454,58 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin override def doColumnar(value: GpuColumnVector, scale: GpuScalar): ColumnVector = { val lhsValue = value.getBase + + val intZeroReplacement = (zeroScalar: Scalar) => { + withResource(zeroScalar) { scalar => + withResource(ColumnVector.fromScalar(scalar, lhsValue.getRowCount.toInt)) { cv => + if (lhsValue.hasNulls) { + cv.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) + } else { + cv.incRefCount() + } + } + } + } + + val fpZeroReplacement = (zeroScalar: Scalar, infScalar: Scalar, negInfScalar: Scalar) => { + withResource(zeroScalar) { scalar => + + val condition = closeOnExcept(lhsValue.isNull) { isNull => + List(() => lhsValue.equalTo(negInfScalar), + () => lhsValue.equalTo(infScalar), + () => lhsValue.isNan + ).foldLeft(isNull) { case (condition, builder) => + withResource(condition) { x => + withResource(builder()) { y => + x.or(y) + } + } + } + } + + withResource(lhsValue.isNan) { isNan => + withResource(lhsValue.equalTo(infScalar)) { isInf => + withResource(lhsValue.isNull) { isNull => + withResource(isNan.or(isNull)) { isNanOrNull => + isNanOrNull.ifElse(lhsValue, scalar) + } + } + } + } + } + } + dataType match { case DecimalType.Fixed(_, scaleVal) => DecimalUtil.round(lhsValue, scaleVal, roundMode) - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + case ByteType | ShortType | IntegerType | LongType => val scaleVal = scale.getValue.asInstanceOf[Int] if (-scaleVal >= DecimalUtil.getPrecisionForIntegralType(lhsValue.getType)) { - withResource(GpuScalar.from(0, dataType)) { zero => - ColumnVector.fromScalar(zero, lhsValue.getRowCount.toInt) + dataType match { + case ByteType => intZeroReplacement(Scalar.fromByte(0.toByte)) + case 
ShortType => intZeroReplacement(Scalar.fromShort(0.toShort)) + case IntegerType => intZeroReplacement(Scalar.fromInt(0)) + case LongType => intZeroReplacement(Scalar.fromLong(0L)) } } else { lhsValue.round(scaleVal, roundMode) @@ -470,8 +514,10 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin val scaleVal = scale.getValue.asInstanceOf[Int] val maxDigits = if (dataType == FloatType) 39 else 309 if (-scaleVal >= maxDigits) { - withResource(GpuScalar.from(0, dataType)) { zero => - ColumnVector.fromScalar(zero, lhsValue.getRowCount.toInt) + if (dataType == FloatType) { + fpZeroReplacement(Scalar.fromFloat(0.0f)) + } else { + fpZeroReplacement(Scalar.fromDouble(0.0)) } } else if (scaleVal >= maxDigits) { lhsValue.incRefCount() From 371cf2ad159e33b4235f457eb3ddc0b5eea82ce9 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Mon, 20 Dec 2021 16:15:32 +0800 Subject: [PATCH 3/7] update --- .../spark/sql/rapids/mathExpressions.scala | 83 ++++++++++++++----- 1 file changed, 63 insertions(+), 20 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index d0cd6c2444b..2ad761d137f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -453,31 +453,74 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) override def doColumnar(value: GpuColumnVector, scale: GpuScalar): ColumnVector = { + val lhsValue = value.getBase - dataType match { - case DecimalType.Fixed(_, scaleVal) => - DecimalUtil.round(lhsValue, scaleVal, roundMode) - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => - val scaleVal = scale.getValue.asInstanceOf[Int] - if (-scaleVal >= 
DecimalUtil.getPrecisionForIntegralType(lhsValue.getType)) { - withResource(GpuScalar.from(0, dataType)) { zero => - ColumnVector.fromScalar(zero, lhsValue.getRowCount.toInt) + + def intZeroReplacement(zero: Scalar): ColumnVector = { + val scaleVal = scale.getValue.asInstanceOf[Int] + if (-scaleVal >= DecimalUtil.getPrecisionForIntegralType(lhsValue.getType)) { + withResource(zero) { s => + withResource(ColumnVector.fromScalar(s, lhsValue.getRowCount.toInt)) { cv => + if (lhsValue.hasNulls) { + cv.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) + } else { + cv.incRefCount() + } } - } else { - lhsValue.round(scaleVal, roundMode) } - case FloatType | DoubleType => - val scaleVal = scale.getValue.asInstanceOf[Int] - val maxDigits = if (dataType == FloatType) 39 else 309 - if (-scaleVal >= maxDigits) { - withResource(GpuScalar.from(0, dataType)) { zero => - ColumnVector.fromScalar(zero, lhsValue.getRowCount.toInt) + } else { + lhsValue.round(scaleVal, roundMode) + } + } + + def fpZeroReplacement(zero: Scalar, inf: Scalar, negInf: Scalar): ColumnVector = { + val scaleVal = scale.getValue.asInstanceOf[Int] + val maxDigits = if (dataType == FloatType) 39 else 309 + if (-scaleVal >= maxDigits) { + withResource(Seq(zero, inf, negInf)) { _ => + val joinedCondition = List( + () => lhsValue.isNan, + () => lhsValue.equalTo(inf), + () => lhsValue.equalTo(negInf) + ).foldLeft(lhsValue.isNull) { case (cond, fn) => + withResource(cond) { _ => + withResource(fn()) { newCondition => + cond.or(newCondition) + } + } + } + withResource(joinedCondition) { cond => + cond.ifElse(lhsValue, zero) } - } else if (scaleVal >= maxDigits) { - lhsValue.incRefCount() - } else { - lhsValue.round(scaleVal, roundMode) } + } else if (scaleVal >= maxDigits) { + lhsValue.incRefCount() + } else { + lhsValue.round(scaleVal, roundMode) + } + } + + dataType match { + case DecimalType.Fixed(_, scaleVal) => + DecimalUtil.round(lhsValue, scaleVal, roundMode) + case ByteType => + 
intZeroReplacement(Scalar.fromByte(0.toByte)) + case ShortType => + intZeroReplacement(Scalar.fromShort(0.toShort)) + case IntegerType => + intZeroReplacement(Scalar.fromInt(0)) + case LongType => + intZeroReplacement(Scalar.fromLong(0L)) + case FloatType => + fpZeroReplacement( + Scalar.fromFloat(0.0f), + Scalar.fromFloat(Float.PositiveInfinity), + Scalar.fromFloat(Float.NegativeInfinity)) + case DoubleType => + fpZeroReplacement( + Scalar.fromDouble(0.0), + Scalar.fromDouble(Double.PositiveInfinity), + Scalar.fromDouble(Double.NegativeInfinity)) case _ => throw new IllegalArgumentException(s"Round operator doesn't support $dataType") } } From d1822a17de988bfbae0994672981bc8fc8d474ca Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Tue, 21 Dec 2021 13:36:09 +0800 Subject: [PATCH 4/7] cache --- .../src/main/python/arithmetic_ops_test.py | 32 +++---- .../spark/sql/rapids/mathExpressions.scala | 83 +++++++++++++------ 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index cc84b08e6ca..09a6ad52f88 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -439,22 +439,22 @@ def test_round_overflow(): ('dec32_c', DecimalGen(5, 2)), ('dec64_c', decimal_gen_12_2), ('dec128_c', decimal_gen_30_2)], nullable=False) assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen, length=100).selectExpr( - # 'round(byte_c, -3)', - # 'round(short_c, -5)', - # 'round(int_c, -10)', - # 'round(long_c, -20)', - 'round(float_c, -39)', # 'round(float_c, 39)', - # 'round(double_c, -309)', 'round(double_c, 309)', - # 'round(dec32_c, -9)', - # 'round(dec64_c, -19)', - # 'round(dec128_c, -39)', - # 'bround(byte_c, -3)', - # 'bround(short_c, -5)', - # 'bround(int_c, -10)', - # 'bround(long_c, -20)', - # 'bround(float_c, -39)', 'bround(float_c, 39)', - # 'bround(double_c, 
-309)', 'bround(double_c, 309)' + lambda spark: gen_df(spark, gen, length=8192).selectExpr( + 'round(byte_c, -2)', 'round(byte_c, -3)', + 'round(short_c, -4)', 'round(short_c, -5)', + 'round(int_c, -9)', 'round(int_c, -10)', + 'round(long_c, -19)', 'round(long_c, -20)', + 'round(float_c, -39)', 'round(float_c, 39)', + 'round(double_c, -309)', 'round(double_c, 309)', + 'round(dec32_c, -9)', + 'round(dec64_c, -19)', + # 'round(dec128_c, -34)', + 'bround(byte_c, -3)', + 'bround(short_c, -5)', + 'bround(int_c, -10)', + 'bround(long_c, -20)', + 'bround(float_c, -39)', 'bround(float_c, 39)', + 'bround(double_c, -309)', 'bround(double_c, 309)' ), conf=allow_negative_scale_of_decimal_conf) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index 2ad761d137f..d736d5081af 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -456,15 +456,18 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin val lhsValue = value.getBase - def intZeroReplacement(zero: Scalar): ColumnVector = { + def intZeroReplacement(zeroFn: () => Scalar): ColumnVector = { val scaleVal = scale.getValue.asInstanceOf[Int] - if (-scaleVal >= DecimalUtil.getPrecisionForIntegralType(lhsValue.getType)) { - withResource(zero) { s => - withResource(ColumnVector.fromScalar(s, lhsValue.getRowCount.toInt)) { cv => + + if (-scaleVal == 19 && lhsValue.getType == DType.INT64) { + longBoundReplacement(zeroFn) + } else if (-scaleVal >= DecimalUtil.getPrecisionForIntegralType(lhsValue.getType)) { + withResource(zeroFn()) { s => + withResource(ColumnVector.fromScalar(s, lhsValue.getRowCount.toInt)) { zero => if (lhsValue.hasNulls) { - cv.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) + zero.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) } else { - 
cv.incRefCount() + zero.incRefCount() } } } @@ -473,24 +476,56 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin } } - def fpZeroReplacement(zero: Scalar, inf: Scalar, negInf: Scalar): ColumnVector = { + def longBoundReplacement(zeroFn: () => Scalar): ColumnVector = { + val scalars = Seq(zeroFn(), + Scalar.fromLong(1000000000000000000L), + Scalar.fromLong(4L), Scalar.fromLong(-4L), + Scalar.fromLong(8446744073709551616L), Scalar.fromLong(-8446744073709551616L)) + withResource(scalars) { case Seq(zero, base, five, negFive, repLit, negRepLit) => + val (needPosRep, needNegRep) = withResource(lhsValue.div(base)) { headDigit => + closeOnExcept(headDigit.greaterThan(five)) { posRep => + closeOnExcept(headDigit.lessThan(negFive)) { negRep => + posRep -> negRep + } + } + } + val repVal = withResource(needPosRep) { _ => + withResource(needNegRep) { _ => + withResource(needNegRep.ifElse(repLit, zero)) { negBranch => + needPosRep.ifElse(negRepLit, negBranch) + } + } + } + withResource(repVal) { _ => + if (lhsValue.hasNulls) { + repVal.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) + } else { + repVal.incRefCount() + } + } + } + } + + def fpZeroReplacement(zeroFn: () => Scalar, + infFn: () => Scalar, + negInfFn: () => Scalar): ColumnVector = { val scaleVal = scale.getValue.asInstanceOf[Int] val maxDigits = if (dataType == FloatType) 39 else 309 if (-scaleVal >= maxDigits) { - withResource(Seq(zero, inf, negInf)) { _ => + withResource(Array(zeroFn(), infFn(), negInfFn())) { case Array(zero, inf, negInf) => val joinedCondition = List( - () => lhsValue.isNan, - () => lhsValue.equalTo(inf), - () => lhsValue.equalTo(negInf) - ).foldLeft(lhsValue.isNull) { case (cond, fn) => + () => lhsValue.isNotNan, + () => lhsValue.notEqualTo(inf), + () => lhsValue.notEqualTo(negInf) + ).foldLeft(lhsValue.isNotNull) { case (cond, fn) => withResource(cond) { _ => withResource(fn()) { newCondition => - cond.or(newCondition) + cond.and(newCondition) } } } 
withResource(joinedCondition) { cond => - cond.ifElse(lhsValue, zero) + cond.ifElse(zero, lhsValue) } } } else if (scaleVal >= maxDigits) { @@ -504,23 +539,23 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin case DecimalType.Fixed(_, scaleVal) => DecimalUtil.round(lhsValue, scaleVal, roundMode) case ByteType => - intZeroReplacement(Scalar.fromByte(0.toByte)) + intZeroReplacement(() => Scalar.fromByte(0.toByte)) case ShortType => - intZeroReplacement(Scalar.fromShort(0.toShort)) + intZeroReplacement(() => Scalar.fromShort(0.toShort)) case IntegerType => - intZeroReplacement(Scalar.fromInt(0)) + intZeroReplacement(() => Scalar.fromInt(0)) case LongType => - intZeroReplacement(Scalar.fromLong(0L)) + intZeroReplacement(() => Scalar.fromLong(0L)) case FloatType => fpZeroReplacement( - Scalar.fromFloat(0.0f), - Scalar.fromFloat(Float.PositiveInfinity), - Scalar.fromFloat(Float.NegativeInfinity)) + () => Scalar.fromFloat(0.0f), + () => Scalar.fromFloat(Float.PositiveInfinity), + () => Scalar.fromFloat(Float.NegativeInfinity)) case DoubleType => fpZeroReplacement( - Scalar.fromDouble(0.0), - Scalar.fromDouble(Double.PositiveInfinity), - Scalar.fromDouble(Double.NegativeInfinity)) + () => Scalar.fromDouble(0.0), + () => Scalar.fromDouble(Double.PositiveInfinity), + () => Scalar.fromDouble(Double.NegativeInfinity)) case _ => throw new IllegalArgumentException(s"Round operator doesn't support $dataType") } } From 3e9822e2bce1462957ea993967cd88ae22019c62 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 22 Dec 2021 18:16:23 +0800 Subject: [PATCH 5/7] fix up some cornor cases of round Signed-off-by: sperlingxx --- .../src/main/python/arithmetic_ops_test.py | 33 +++-- .../spark/sql/rapids/mathExpressions.scala | 123 +++++++++++++----- 2 files changed, 102 insertions(+), 54 deletions(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index 
09a6ad52f88..7a463d1f0cf 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -432,31 +432,28 @@ def test_decimal_round(data_gen): @incompat @approximate_float -def test_round_overflow(): +def test_non_decimal_round_overflow(): gen = StructGen([('byte_c', byte_gen), ('short_c', short_gen), ('int_c', int_gen), ('long_c', long_gen), - ('float_c', float_gen), ('double_c', double_gen), - ('dec32_c', DecimalGen(5, 2)), ('dec64_c', decimal_gen_12_2), - ('dec128_c', decimal_gen_30_2)], nullable=False) + ('float_c', float_gen), ('double_c', double_gen)], nullable=False) assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen, length=8192).selectExpr( + lambda spark: gen_df(spark, gen).selectExpr( 'round(byte_c, -2)', 'round(byte_c, -3)', 'round(short_c, -4)', 'round(short_c, -5)', 'round(int_c, -9)', 'round(int_c, -10)', 'round(long_c, -19)', 'round(long_c, -20)', - 'round(float_c, -39)', 'round(float_c, 39)', - 'round(double_c, -309)', 'round(double_c, 309)', - 'round(dec32_c, -9)', - 'round(dec64_c, -19)', - # 'round(dec128_c, -34)', - 'bround(byte_c, -3)', - 'bround(short_c, -5)', - 'bround(int_c, -10)', - 'bround(long_c, -20)', - 'bround(float_c, -39)', 'bround(float_c, 39)', - 'bround(double_c, -309)', 'bround(double_c, 309)' - ), - conf=allow_negative_scale_of_decimal_conf) + 'round(float_c, -38)', 'round(float_c, -39)', + 'round(float_c, 38)', 'round(float_c, 39)', + 'round(double_c, -308)', 'round(double_c, -309)', + 'round(double_c, 309)', 'round(double_c, 309)', + 'bround(byte_c, -2)', 'bround(byte_c, -3)', + 'bround(short_c, -4)', 'bround(short_c, -5)', + 'bround(int_c, -9)', 'bround(int_c, -10)', + 'bround(long_c, -19)', 'bround(long_c, -20)', + 'bround(float_c, -38)', 'bround(float_c, -39)', + 'bround(float_c, 38)', 'bround(float_c, 39)', + 'bround(double_c, -308)', 'bround(double_c, -309)', + 'bround(double_c, 309)', 'bround(double_c, 309)')) @approximate_float 
@pytest.mark.parametrize('data_gen', double_gens, ids=idfn) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index d736d5081af..6919c84a7d2 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -456,14 +456,28 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin val lhsValue = value.getBase - def intZeroReplacement(zeroFn: () => Scalar): ColumnVector = { + // Fixes up integral values rounded by a scale exceeding/reaching the max digits of data + // type. Under this circumstance, cuDF may produce different results to Spark. + // + // In this method, we handle round overflow, aligning the inconsistent results to Spark. + // + // For scales exceeding max digits, we can simply return zero values. + // + // For scales equaling to max digits, we need to perform round. Fortunately, round up + // will NOT occur on the max digits of numeric types except LongType. Therefore, we only + // need to handle round down for most of types, through returning zero values. + def fixUpOverflowInts(zeroFn: () => Scalar): ColumnVector = { val scaleVal = scale.getValue.asInstanceOf[Int] + // Rounding on the max digit of long values, which should be specialized handled since + // it may be needed to round up, which will produce inconsistent results because of + // overflow. Otherwise, we only need to handle round down situations. 
if (-scaleVal == 19 && lhsValue.getType == DType.INT64) { - longBoundReplacement(zeroFn) + fixUpInt64OnBounds(zeroFn) } else if (-scaleVal >= DecimalUtil.getPrecisionForIntegralType(lhsValue.getType)) { withResource(zeroFn()) { s => withResource(ColumnVector.fromScalar(s, lhsValue.getRowCount.toInt)) { zero => + // set null mask if necessary if (lhsValue.hasNulls) { zero.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) } else { @@ -476,59 +490,96 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin } } - def longBoundReplacement(zeroFn: () => Scalar): ColumnVector = { - val scalars = Seq(zeroFn(), - Scalar.fromLong(1000000000000000000L), - Scalar.fromLong(4L), Scalar.fromLong(-4L), - Scalar.fromLong(8446744073709551616L), Scalar.fromLong(-8446744073709551616L)) - withResource(scalars) { case Seq(zero, base, five, negFive, repLit, negRepLit) => - val (needPosRep, needNegRep) = withResource(lhsValue.div(base)) { headDigit => - closeOnExcept(headDigit.greaterThan(five)) { posRep => - closeOnExcept(headDigit.lessThan(negFive)) { negRep => + // Compared to other non-decimal numeric types, Int64(LongType) is a bit special in terms of + // rounding by the max digit. Because the bound values of LongType can be rounded up, while + // other numeric types can only be rounded down: + // + // the max value of Byte: 127 + // The first digit is up to 1, which can't be rounded up. + // the max value of Short: 32767 + // The first digit is up to 3, which can't be rounded up. + // the max value of Int32: 2147483647 + // The first digit is up to 2, which can't be rounded up. + // the max value of Float32: 3.4028235E38 + // The first digit is up to 3, which can't be rounded up. + // the max value of Float64: 1.7976931348623157E308 + // The first digit is up to 1, which can't be rounded up. + // the max value of Int64: 9223372036854775807 + // The first digit is up to 9, which can be rounded up. 
+ // + // When rounding up 19-digits long values on the first digit, the result can be 1e19 or -1e19. + // Since LongType can not hold these two values, the 1e19 overflows as -8446744073709551616L, + // and the -1e19 overflows as 8446744073709551616L. The overflow happens in the same way for + // HALF_UP (round) and HALF_EVEN (bround). + def fixUpInt64OnBounds(zeroFn: () => Scalar): ColumnVector = { + // Builds predicates on whether there is a round up on the max digit or not + val litForCmp = Seq(Scalar.fromLong(1000000000000000000L), + Scalar.fromLong(4L), + Scalar.fromLong(-4L)) + val (needRep, needNegRep) = withResource(litForCmp) { case Seq(base, four, minusFour) => + withResource(lhsValue.div(base)) { headDigit => + closeOnExcept(headDigit.greaterThan(four)) { posRep => + closeOnExcept(headDigit.lessThan(minusFour)) { negRep => posRep -> negRep } } } - val repVal = withResource(needPosRep) { _ => + } + // Replaces with corresponding literals + val litForRep = Seq(zeroFn(), + Scalar.fromLong(8446744073709551616L), + Scalar.fromLong(-8446744073709551616L)) + val repVal = withResource(litForRep) { case Seq(zero, upLit, negUpLit) => + withResource(needRep) { _ => withResource(needNegRep) { _ => - withResource(needNegRep.ifElse(repLit, zero)) { negBranch => - needPosRep.ifElse(negRepLit, negBranch) + withResource(needNegRep.ifElse(upLit, zero)) { negBranch => + needRep.ifElse(negUpLit, negBranch) } } } - withResource(repVal) { _ => - if (lhsValue.hasNulls) { - repVal.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) - } else { - repVal.incRefCount() - } + } + // Handles null values + withResource(repVal) { _ => + if (lhsValue.hasNulls) { + repVal.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) + } else { + repVal.incRefCount() } } } + // Fixes up float points rounded by a scale exceeding the max digits of data type. Under this + // circumstance, cuDF produces different results to Spark. 
+ // Compared to integral values, fixing up round overflow of float points needs to take care + // of some special values: nan, inf, -inf. def fpZeroReplacement(zeroFn: () => Scalar, - infFn: () => Scalar, - negInfFn: () => Scalar): ColumnVector = { + infFn: () => Scalar, + negInfFn: () => Scalar): ColumnVector = { val scaleVal = scale.getValue.asInstanceOf[Int] val maxDigits = if (dataType == FloatType) 39 else 309 if (-scaleVal >= maxDigits) { - withResource(Array(zeroFn(), infFn(), negInfFn())) { case Array(zero, inf, negInf) => - val joinedCondition = List( - () => lhsValue.isNotNan, - () => lhsValue.notEqualTo(inf), - () => lhsValue.notEqualTo(negInf) - ).foldLeft(lhsValue.isNotNull) { case (cond, fn) => - withResource(cond) { _ => - withResource(fn()) { newCondition => - cond.and(newCondition) + // replaces common values (!Null AND !Nan AND !Inf And !-Inf) with zero, while keeps + // all the special values unchanged + withResource(Seq(zeroFn(), infFn(), negInfFn())) { case Seq(zero, inf, negInf) => + // builds joined predicate: !Null AND !Nan AND !Inf And !-Inf + val joinedPredicate = { + val conditions = Seq(() => lhsValue.isNotNan, + () => lhsValue.notEqualTo(inf), + () => lhsValue.notEqualTo(negInf)) + conditions.foldLeft(lhsValue.isNotNull) { case (buffer, builder) => + withResource(buffer) { _ => + withResource(builder()) { predicate => + buffer.and(predicate) + } } } } - withResource(joinedCondition) { cond => + withResource(joinedPredicate) { cond => cond.ifElse(zero, lhsValue) } } } else if (scaleVal >= maxDigits) { + // just returns the original values lhsValue.incRefCount() } else { lhsValue.round(scaleVal, roundMode) @@ -539,13 +590,13 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin case DecimalType.Fixed(_, scaleVal) => DecimalUtil.round(lhsValue, scaleVal, roundMode) case ByteType => - intZeroReplacement(() => Scalar.fromByte(0.toByte)) + fixUpOverflowInts(() => Scalar.fromByte(0.toByte)) case ShortType => - 
intZeroReplacement(() => Scalar.fromShort(0.toShort)) + fixUpOverflowInts(() => Scalar.fromShort(0.toShort)) case IntegerType => - intZeroReplacement(() => Scalar.fromInt(0)) + fixUpOverflowInts(() => Scalar.fromInt(0)) case LongType => - intZeroReplacement(() => Scalar.fromLong(0L)) + fixUpOverflowInts(() => Scalar.fromLong(0L)) case FloatType => fpZeroReplacement( () => Scalar.fromFloat(0.0f), From c8babb50f39175f054635a9146eae512d9279abc Mon Sep 17 00:00:00 2001 From: Alfred Xu Date: Mon, 10 Jan 2022 13:41:39 +0800 Subject: [PATCH 6/7] Apply suggestions from code review Co-authored-by: Liangcai Li --- integration_tests/src/main/python/arithmetic_ops_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index 7a463d1f0cf..9b494c7f575 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -445,7 +445,7 @@ def test_non_decimal_round_overflow(): 'round(float_c, -38)', 'round(float_c, -39)', 'round(float_c, 38)', 'round(float_c, 39)', 'round(double_c, -308)', 'round(double_c, -309)', - 'round(double_c, 309)', 'round(double_c, 309)', + 'round(double_c, 308)', 'round(double_c, 309)', 'bround(byte_c, -2)', 'bround(byte_c, -3)', 'bround(short_c, -4)', 'bround(short_c, -5)', 'bround(int_c, -9)', 'bround(int_c, -10)', @@ -453,7 +453,7 @@ def test_non_decimal_round_overflow(): 'bround(float_c, -38)', 'bround(float_c, -39)', 'bround(float_c, 38)', 'bround(float_c, 39)', 'bround(double_c, -308)', 'bround(double_c, -309)', - 'bround(double_c, 309)', 'bround(double_c, 309)')) + 'bround(double_c, 308)', 'bround(double_c, 309)')) @approximate_float @pytest.mark.parametrize('data_gen', double_gens, ids=idfn) From b285be98a2982d41b8e5ea61600be225f37b076a Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Mon, 10 Jan 2022 14:31:49 +0800 Subject: [PATCH 7/7] refine 
Signed-off-by: sperlingxx --- .../spark/sql/rapids/mathExpressions.scala | 271 +++++++++--------- 1 file changed, 138 insertions(+), 133 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index 6bb271b666e..f84aa67a9e2 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -459,159 +459,164 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin override def doColumnar(value: GpuColumnVector, scale: GpuScalar): ColumnVector = { val lhsValue = value.getBase + val scaleVal = scale.getValue.asInstanceOf[Int] - // Fixes up integral values rounded by a scale exceeding/reaching the max digits of data - // type. Under this circumstance, cuDF may produce different results to Spark. - // - // In this method, we handle round overflow, aligning the inconsistent results to Spark. - // - // For scales exceeding max digits, we can simply return zero values. - // - // For scales equaling to max digits, we need to perform round. Fortunately, round up - // will NOT occur on the max digits of numeric types except LongType. Therefore, we only - // need to handle round down for most of types, through returning zero values. - def fixUpOverflowInts(zeroFn: () => Scalar): ColumnVector = { - val scaleVal = scale.getValue.asInstanceOf[Int] - - // Rounding on the max digit of long values, which should be specialized handled since - // it may be needed to round up, which will produce inconsistent results because of - // overflow. Otherwise, we only need to handle round down situations. 
- if (-scaleVal == 19 && lhsValue.getType == DType.INT64) { - fixUpInt64OnBounds(zeroFn) - } else if (-scaleVal >= DecimalUtil.getPrecisionForIntegralType(lhsValue.getType)) { - withResource(zeroFn()) { s => - withResource(ColumnVector.fromScalar(s, lhsValue.getRowCount.toInt)) { zero => - // set null mask if necessary - if (lhsValue.hasNulls) { - zero.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) - } else { - zero.incRefCount() - } + dataType match { + case DecimalType.Fixed(_, scaleVal) => + DecimalUtil.round(lhsValue, scaleVal, roundMode) + case ByteType => + fixUpOverflowInts(() => Scalar.fromByte(0.toByte), scaleVal, lhsValue) + case ShortType => + fixUpOverflowInts(() => Scalar.fromShort(0.toShort), scaleVal, lhsValue) + case IntegerType => + fixUpOverflowInts(() => Scalar.fromInt(0), scaleVal, lhsValue) + case LongType => + fixUpOverflowInts(() => Scalar.fromLong(0L), scaleVal, lhsValue) + case FloatType => + fpZeroReplacement( + () => Scalar.fromFloat(0.0f), + () => Scalar.fromFloat(Float.PositiveInfinity), + () => Scalar.fromFloat(Float.NegativeInfinity), + scaleVal, lhsValue) + case DoubleType => + fpZeroReplacement( + () => Scalar.fromDouble(0.0), + () => Scalar.fromDouble(Double.PositiveInfinity), + () => Scalar.fromDouble(Double.NegativeInfinity), + scaleVal, lhsValue) + case _ => + throw new IllegalArgumentException(s"Round operator doesn't support $dataType") + } + } + + // Fixes up integral values rounded by a scale exceeding/reaching the max digits of data + // type. Under this circumstance, cuDF may produce different results to Spark. + // + // In this method, we handle round overflow, aligning the inconsistent results to Spark. + // + // For scales exceeding max digits, we can simply return zero values. + // + // For scales equaling to max digits, we need to perform round. Fortunately, round up + // will NOT occur on the max digits of numeric types except LongType. 
Therefore, we only + // need to handle round down for most types, by returning zero values. + private def fixUpOverflowInts(zeroFn: () => Scalar, + scale: Int, + lhs: ColumnVector): ColumnVector = { + // Rounding on the max digit of long values, which should be specially handled since + // it may be needed to round up, which will produce inconsistent results because of + // overflow. Otherwise, we only need to handle round down situations. + if (-scale == 19 && lhs.getType == DType.INT64) { + fixUpInt64OnBounds(lhs) + } else if (-scale >= DecimalUtil.getPrecisionForIntegralType(lhs.getType)) { + withResource(zeroFn()) { s => + withResource(ColumnVector.fromScalar(s, lhs.getRowCount.toInt)) { zero => + // set null mask if necessary + if (lhs.hasNulls) { + zero.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhs) + } else { + zero.incRefCount() } } - } else { - lhsValue.round(scaleVal, roundMode) } + } else { + lhs.round(scale, roundMode) } + } - // Compared to other non-decimal numeric types, Int64(LongType) is a bit special in terms of - // rounding by the max digit. Because the bound values of LongType can be rounded up, while - // other numeric types can only be rounded down: - // - // the max value of Byte: 127 - // The first digit is up to 1, which can't be rounded up. - // the max value of Short: 32767 - // The first digit is up to 3, which can't be rounded up. - // the max value of Int32: 2147483647 - // The first digit is up to 2, which can't be rounded up. - // the max value of Float32: 3.4028235E38 - // The first digit is up to 3, which can't be rounded up. - // the max value of Float64: 1.7976931348623157E308 - // The first digit is up to 1, which can't be rounded up. - // the max value of Int64: 9223372036854775807 - // The first digit is up to 9, which can be rounded up. - // - // When rounding up 19-digits long values on the first digit, the result can be 1e19 or -1e19.
- // Since LongType can not hold these two values, the 1e19 overflows as -8446744073709551616L, - // and the -1e19 overflows as 8446744073709551616L. The overflow happens in the same way for - // HALF_UP (round) and HALF_EVEN (bround). - def fixUpInt64OnBounds(zeroFn: () => Scalar): ColumnVector = { - // Builds predicates on whether there is a round up on the max digit or not - val litForCmp = Seq(Scalar.fromLong(1000000000000000000L), - Scalar.fromLong(4L), - Scalar.fromLong(-4L)) - val (needRep, needNegRep) = withResource(litForCmp) { case Seq(base, four, minusFour) => - withResource(lhsValue.div(base)) { headDigit => - closeOnExcept(headDigit.greaterThan(four)) { posRep => - closeOnExcept(headDigit.lessThan(minusFour)) { negRep => - posRep -> negRep - } + // Compared to other non-decimal numeric types, Int64(LongType) is a bit special in terms of + // rounding by the max digit. Because the bound values of LongType can be rounded up, while + // other numeric types can only be rounded down: + // + // the max value of Byte: 127 + // The first digit is up to 1, which can't be rounded up. + // the max value of Short: 32767 + // The first digit is up to 3, which can't be rounded up. + // the max value of Int32: 2147483647 + // The first digit is up to 2, which can't be rounded up. + // the max value of Float32: 3.4028235E38 + // The first digit is up to 3, which can't be rounded up. + // the max value of Float64: 1.7976931348623157E308 + // The first digit is up to 1, which can't be rounded up. + // the max value of Int64: 9223372036854775807 + // The first digit is up to 9, which can be rounded up. + // + // When rounding up 19-digits long values on the first digit, the result can be 1e19 or -1e19. + // Since LongType can not hold these two values, the 1e19 overflows as -8446744073709551616L, + // and the -1e19 overflows as 8446744073709551616L. The overflow happens in the same way for + // HALF_UP (round) and HALF_EVEN (bround). 
+ private def fixUpInt64OnBounds(lhs: ColumnVector): ColumnVector = { + // Builds predicates on whether there is a round up on the max digit or not + val litForCmp = Seq(Scalar.fromLong(1000000000000000000L), + Scalar.fromLong(4L), + Scalar.fromLong(-4L)) + val (needRep, needNegRep) = withResource(litForCmp) { case Seq(base, four, minusFour) => + withResource(lhs.div(base)) { headDigit => + closeOnExcept(headDigit.greaterThan(four)) { posRep => + closeOnExcept(headDigit.lessThan(minusFour)) { negRep => + posRep -> negRep } } } - // Replaces with corresponding literals - val litForRep = Seq(zeroFn(), - Scalar.fromLong(8446744073709551616L), - Scalar.fromLong(-8446744073709551616L)) - val repVal = withResource(litForRep) { case Seq(zero, upLit, negUpLit) => - withResource(needRep) { _ => - withResource(needNegRep) { _ => - withResource(needNegRep.ifElse(upLit, zero)) { negBranch => - needRep.ifElse(negUpLit, negBranch) - } + } + // Replaces with corresponding literals + val litForRep = Seq(Scalar.fromLong(0L), + Scalar.fromLong(8446744073709551616L), + Scalar.fromLong(-8446744073709551616L)) + val repVal = withResource(litForRep) { case Seq(zero, upLit, negUpLit) => + withResource(needRep) { _ => + withResource(needNegRep) { _ => + withResource(needNegRep.ifElse(upLit, zero)) { negBranch => + needRep.ifElse(negUpLit, negBranch) } } } - // Handles null values - withResource(repVal) { _ => - if (lhsValue.hasNulls) { - repVal.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhsValue) - } else { - repVal.incRefCount() - } + } + // Handles null values + withResource(repVal) { _ => + if (lhs.hasNulls) { + repVal.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhs) + } else { + repVal.incRefCount() } } + } - // Fixes up float points rounded by a scale exceeding the max digits of data type. Under this - // circumstance, cuDF produces different results to Spark. 
- // Compared to integral values, fixing up round overflow of float points needs to take care - // of some special values: nan, inf, -inf. - def fpZeroReplacement(zeroFn: () => Scalar, - infFn: () => Scalar, - negInfFn: () => Scalar): ColumnVector = { - val scaleVal = scale.getValue.asInstanceOf[Int] - val maxDigits = if (dataType == FloatType) 39 else 309 - if (-scaleVal >= maxDigits) { - // replaces common values (!Null AND !Nan AND !Inf And !-Inf) with zero, while keeps - // all the special values unchanged - withResource(Seq(zeroFn(), infFn(), negInfFn())) { case Seq(zero, inf, negInf) => - // builds joined predicate: !Null AND !Nan AND !Inf And !-Inf - val joinedPredicate = { - val conditions = Seq(() => lhsValue.isNotNan, - () => lhsValue.notEqualTo(inf), - () => lhsValue.notEqualTo(negInf)) - conditions.foldLeft(lhsValue.isNotNull) { case (buffer, builder) => - withResource(buffer) { _ => - withResource(builder()) { predicate => - buffer.and(predicate) - } + // Fixes up float points rounded by a scale exceeding the max digits of data type. Under this + // circumstance, cuDF produces different results to Spark. + // Compared to integral values, fixing up round overflow of float points needs to take care + // of some special values: nan, inf, -inf. 
+ def fpZeroReplacement(zeroFn: () => Scalar, + infFn: () => Scalar, + negInfFn: () => Scalar, + scale: Int, + lhs: ColumnVector): ColumnVector = { + val maxDigits = if (dataType == FloatType) 39 else 309 + if (-scale >= maxDigits) { + // replaces common values (!Null AND !Nan AND !Inf AND !-Inf) with zero, while keeping + // all the special values unchanged + withResource(Seq(zeroFn(), infFn(), negInfFn())) { case Seq(zero, inf, negInf) => + // builds joined predicate: !Null AND !Nan AND !Inf AND !-Inf + val joinedPredicate = { + val conditions = Seq(() => lhs.isNotNan, + () => lhs.notEqualTo(inf), + () => lhs.notEqualTo(negInf)) + conditions.foldLeft(lhs.isNotNull) { case (buffer, builder) => + withResource(buffer) { _ => + withResource(builder()) { predicate => + buffer.and(predicate) } } } - withResource(joinedPredicate) { cond => - cond.ifElse(zero, lhsValue) - } } - } else if (scaleVal >= maxDigits) { - // just returns the original values - lhsValue.incRefCount() - } else { - lhsValue.round(scaleVal, roundMode) + withResource(joinedPredicate) { cond => + cond.ifElse(zero, lhs) + } } - } - - dataType match { - case DecimalType.Fixed(_, scaleVal) => - DecimalUtil.round(lhsValue, scaleVal, roundMode) - case ByteType => - fixUpOverflowInts(() => Scalar.fromByte(0.toByte)) - case ShortType => - fixUpOverflowInts(() => Scalar.fromShort(0.toShort)) - case IntegerType => - fixUpOverflowInts(() => Scalar.fromInt(0)) - case LongType => - fixUpOverflowInts(() => Scalar.fromLong(0L)) - case FloatType => - fpZeroReplacement( - () => Scalar.fromFloat(0.0f), - () => Scalar.fromFloat(Float.PositiveInfinity), - () => Scalar.fromFloat(Float.NegativeInfinity)) - case DoubleType => - fpZeroReplacement( - () => Scalar.fromDouble(0.0), - () => Scalar.fromDouble(Double.PositiveInfinity), - () => Scalar.fromDouble(Double.NegativeInfinity)) - case _ => throw new IllegalArgumentException(s"Round operator doesn't support $dataType") + } else if (scale >= maxDigits) { + // just returns
the original values + lhs.incRefCount() + } else { + lhs.round(scale, roundMode) } }