Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix up incorrect results of rounding past the max digits of data type [databricks] #4420

Merged
merged 9 commits into from
Jan 12, 2022
25 changes: 25 additions & 0 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,31 @@ def test_decimal_round(data_gen):
'round(a, 10)'),
conf=allow_negative_scale_of_decimal_conf)

@incompat
@approximate_float
def test_non_decimal_round_overflow():
gen = StructGen([('byte_c', byte_gen), ('short_c', short_gen),
('int_c', int_gen), ('long_c', long_gen),
('float_c', float_gen), ('double_c', double_gen)], nullable=False)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen).selectExpr(
'round(byte_c, -2)', 'round(byte_c, -3)',
'round(short_c, -4)', 'round(short_c, -5)',
'round(int_c, -9)', 'round(int_c, -10)',
'round(long_c, -19)', 'round(long_c, -20)',
'round(float_c, -38)', 'round(float_c, -39)',
'round(float_c, 38)', 'round(float_c, 39)',
'round(double_c, -308)', 'round(double_c, -309)',
'round(double_c, 308)', 'round(double_c, 309)',
'bround(byte_c, -2)', 'bround(byte_c, -3)',
'bround(short_c, -4)', 'bround(short_c, -5)',
'bround(int_c, -9)', 'bround(int_c, -10)',
'bround(long_c, -19)', 'bround(long_c, -20)',
'bround(float_c, -38)', 'bround(float_c, -39)',
'bround(float_c, 38)', 'bround(float_c, 39)',
'bround(double_c, -308)', 'bround(double_c, -309)',
'bround(double_c, 308)', 'bround(double_c, 309)'))

@approximate_float
@pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
def test_cbrt(data_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,14 +457,166 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)

override def doColumnar(value: GpuColumnVector, scale: GpuScalar): ColumnVector = {

val lhsValue = value.getBase
val scaleVal = scale.getValue.asInstanceOf[Int]

dataType match {
case DecimalType.Fixed(_, scaleVal) =>
DecimalUtil.round(lhsValue, scaleVal, roundMode)
case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
val scaleVal = scale.getValue.asInstanceOf[Int]
lhsValue.round(scaleVal, roundMode)
case _ => throw new IllegalArgumentException(s"Round operator doesn't support $dataType")
case ByteType =>
fixUpOverflowInts(() => Scalar.fromByte(0.toByte), scaleVal, lhsValue)
case ShortType =>
fixUpOverflowInts(() => Scalar.fromShort(0.toShort), scaleVal, lhsValue)
case IntegerType =>
fixUpOverflowInts(() => Scalar.fromInt(0), scaleVal, lhsValue)
case LongType =>
fixUpOverflowInts(() => Scalar.fromLong(0L), scaleVal, lhsValue)
case FloatType =>
fpZeroReplacement(
() => Scalar.fromFloat(0.0f),
() => Scalar.fromFloat(Float.PositiveInfinity),
() => Scalar.fromFloat(Float.NegativeInfinity),
scaleVal, lhsValue)
case DoubleType =>
fpZeroReplacement(
() => Scalar.fromDouble(0.0),
() => Scalar.fromDouble(Double.PositiveInfinity),
() => Scalar.fromDouble(Double.NegativeInfinity),
scaleVal, lhsValue)
case _ =>
throw new IllegalArgumentException(s"Round operator doesn't support $dataType")
}
}

// Fixes up integral values rounded by a scale exceeding/reaching the max digits of data
// type. Under this circumstance, cuDF may produce different results to Spark.
//
// In this method, we handle round overflow, aligning the inconsistent results to Spark.
//
// For scales exceeding max digits, we can simply return zero values.
//
// For scales equaling to max digits, we need to perform round. Fortunately, round up
// will NOT occur on the max digits of numeric types except LongType. Therefore, we only
// need to handle round down for most of types, through returning zero values.
private def fixUpOverflowInts(zeroFn: () => Scalar,
scale: Int,
lhs: ColumnVector): ColumnVector = {
// Rounding on the max digit of long values, which should be specialized handled since
// it may be needed to round up, which will produce inconsistent results because of
// overflow. Otherwise, we only need to handle round down situations.
if (-scale == 19 && lhs.getType == DType.INT64) {
fixUpInt64OnBounds(lhs)
} else if (-scale >= DecimalUtil.getPrecisionForIntegralType(lhs.getType)) {
withResource(zeroFn()) { s =>
withResource(ColumnVector.fromScalar(s, lhs.getRowCount.toInt)) { zero =>
// set null mask if necessary
if (lhs.hasNulls) {
zero.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhs)
} else {
zero.incRefCount()
}
}
}
} else {
lhs.round(scale, roundMode)
}
}

// Compared to other non-decimal numeric types, Int64(LongType) is a bit special in terms of
// rounding by the max digit. Because the bound values of LongType can be rounded up, while
// other numeric types can only be rounded down:
//
// the max value of Byte: 127
// The first digit is up to 1, which can't be rounded up.
// the max value of Short: 32767
// The first digit is up to 3, which can't be rounded up.
// the max value of Int32: 2147483647
// The first digit is up to 2, which can't be rounded up.
// the max value of Float32: 3.4028235E38
// The first digit is up to 3, which can't be rounded up.
// the max value of Float64: 1.7976931348623157E308
// The first digit is up to 1, which can't be rounded up.
// the max value of Int64: 9223372036854775807
// The first digit is up to 9, which can be rounded up.
//
// When rounding up 19-digits long values on the first digit, the result can be 1e19 or -1e19.
// Since LongType can not hold these two values, the 1e19 overflows as -8446744073709551616L,
// and the -1e19 overflows as 8446744073709551616L. The overflow happens in the same way for
// HALF_UP (round) and HALF_EVEN (bround).
private def fixUpInt64OnBounds(lhs: ColumnVector): ColumnVector = {
// Builds predicates on whether there is a round up on the max digit or not
val litForCmp = Seq(Scalar.fromLong(1000000000000000000L),
Scalar.fromLong(4L),
Scalar.fromLong(-4L))
val (needRep, needNegRep) = withResource(litForCmp) { case Seq(base, four, minusFour) =>
withResource(lhs.div(base)) { headDigit =>
closeOnExcept(headDigit.greaterThan(four)) { posRep =>
closeOnExcept(headDigit.lessThan(minusFour)) { negRep =>
posRep -> negRep
}
}
}
}
// Replaces with corresponding literals
val litForRep = Seq(Scalar.fromLong(0L),
Scalar.fromLong(8446744073709551616L),
Scalar.fromLong(-8446744073709551616L))
val repVal = withResource(litForRep) { case Seq(zero, upLit, negUpLit) =>
withResource(needRep) { _ =>
withResource(needNegRep) { _ =>
withResource(needNegRep.ifElse(upLit, zero)) { negBranch =>
needRep.ifElse(negUpLit, negBranch)
}
}
}
}
// Handles null values
withResource(repVal) { _ =>
if (lhs.hasNulls) {
repVal.mergeAndSetValidity(BinaryOp.BITWISE_AND, lhs)
} else {
repVal.incRefCount()
}
}
}

// Fixes up float points rounded by a scale exceeding the max digits of data type. Under this
// circumstance, cuDF produces different results to Spark.
// Compared to integral values, fixing up round overflow of float points needs to take care
// of some special values: nan, inf, -inf.
def fpZeroReplacement(zeroFn: () => Scalar,
infFn: () => Scalar,
negInfFn: () => Scalar,
scale: Int,
lhs: ColumnVector): ColumnVector = {
val maxDigits = if (dataType == FloatType) 39 else 309
if (-scale >= maxDigits) {
// replaces common values (!Null AND !Nan AND !Inf And !-Inf) with zero, while keeps
// all the special values unchanged
withResource(Seq(zeroFn(), infFn(), negInfFn())) { case Seq(zero, inf, negInf) =>
// builds joined predicate: !Null AND !Nan AND !Inf And !-Inf
val joinedPredicate = {
val conditions = Seq(() => lhs.isNotNan,
() => lhs.notEqualTo(inf),
() => lhs.notEqualTo(negInf))
conditions.foldLeft(lhs.isNotNull) { case (buffer, builder) =>
withResource(buffer) { _ =>
withResource(builder()) { predicate =>
buffer.and(predicate)
}
}
}
}
withResource(joinedPredicate) { cond =>
cond.ifElse(zero, lhs)
}
}
} else if (scale >= maxDigits) {
// just returns the original values
lhs.incRefCount()
} else {
lhs.round(scale, roundMode)
}
}

Expand Down