Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support division on ANSI interval types #5134

Merged
merged 9 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,3 +954,59 @@ def test_day_time_interval_multiply_number(data_gen):
('_c2', data_gen)]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("_c1 * _c2"))


@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=0, max_exp=5, special_cases=[])], ids=idfn)
def test_day_time_interval_division_number_no_overflow1(data_gen):
gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-5000 * 365 * 86400), max_value=timedelta(seconds=5000 * 365 * 86400))),
('_c2', data_gen)]
assert_gpu_and_cpu_are_equal_collect(
# avoid dividing by 0
lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type))))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=-5, max_exp=0, special_cases=[])], ids=idfn)
def test_day_time_interval_division_number_no_overflow2(data_gen):
gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-20 * 86400), max_value=timedelta(seconds=20 * 86400))),
('_c2', data_gen)]
assert_gpu_and_cpu_are_equal_collect(
# avoid dividing by 0
lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type))))

def _get_overflow_df_2cols(spark, data_types, values, expr):
return spark.createDataFrame(
SparkContext.getOrCreate().parallelize([values]),
StructType([
StructField('a', data_types[0]),
StructField('b', data_types[1])
])
).selectExpr(expr)

# test interval division overflow, such as interval / 0, Long.MinValue / -1 ...
@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(ByteType(), [timedelta(seconds=1), 0]),
(ShortType(), [timedelta(seconds=1), 0]),
(IntegerType(), [timedelta(seconds=1), 0]),
(LongType(), [timedelta(seconds=1), 0]),
(FloatType(), [timedelta(seconds=1), 0.0]),
(FloatType(), [timedelta(seconds=1), -0.0]),
(DoubleType(), [timedelta(seconds=1), 0.0]),
(DoubleType(), [timedelta(seconds=1), -0.0]),
(FloatType(), [timedelta(seconds=1), float('NaN')]),
(DoubleType(), [timedelta(seconds=1), float('NaN')]),
(FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
(LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
(FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
(DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
(IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
], ids=idfn)
def test_day_time_interval_division_overflow(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='ArithmeticException')
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.shims.{GpuMultiplyDTInterval, GpuMultiplyYMInterval, GpuTimeAdd}
import org.apache.spark.sql.rapids.shims.{GpuDivideDTInterval, GpuDivideYMInterval, GpuMultiplyDTInterval, GpuMultiplyYMInterval, GpuTimeAdd}
import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, DecimalType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -203,6 +203,28 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
(a, conf, p, r) => new BinaryExprMeta[MultiplyDTInterval](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuMultiplyDTInterval(lhs, rhs)
}),
GpuOverrides.expr[DivideYMInterval](
"Year-month interval * operator",
ExprChecks.binaryProject(
TypeSig.YEARMONTH,
TypeSig.YEARMONTH,
("lhs", TypeSig.YEARMONTH, TypeSig.YEARMONTH),
("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
(a, conf, p, r) => new BinaryExprMeta[DivideYMInterval](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuDivideYMInterval(lhs, rhs)
}),
GpuOverrides.expr[DivideDTInterval](
"Day-time interval * operator",
ExprChecks.binaryProject(
TypeSig.DAYTIME,
TypeSig.DAYTIME,
("lhs", TypeSig.DAYTIME, TypeSig.DAYTIME),
("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
(a, conf, p, r) => new BinaryExprMeta[DivideDTInterval](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuDivideDTInterval(lhs, rhs)
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
super.getExprs ++ map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.apache.spark.sql.rapids.shims

import java.math.BigInteger

import ai.rapids.cudf.{BinaryOperable, ColumnVector, DType, RoundMode, Scalar}
import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar}

Expand Down Expand Up @@ -184,6 +186,113 @@ object IntervalUtils extends Arm {
z.castTo(DType.INT64)
}
}

def getDouble(s: Scalar): Double = {
s.getType match {
case DType.INT8 => s.getByte.toDouble
case DType.INT16 => s.getShort.toDouble
case DType.INT32 => s.getInt.toDouble
case DType.INT64 => s.getLong.toDouble
case DType.FLOAT32 => s.getFloat.toDouble
case DType.FLOAT64 => s.getDouble
case t => throw new IllegalArgumentException(s"Unexpected type $t")
}
}

def getLong(s: Scalar): Long = {
s.getType match {
case DType.INT8 => s.getByte.toLong
case DType.INT16 => s.getShort.toLong
case DType.INT32 => s.getInt.toLong
case DType.INT64 => s.getLong
case t => throw new IllegalArgumentException(s"Unexpected type $t")
}
}

def hasZero(cv: BinaryOperable): Boolean = {
cv match {
case s: Scalar => getDouble(s) == 0.0d
case cv: ColumnVector =>
withResource(Scalar.fromInt(0)) { zero =>
withResource(cv.equalTo(zero)) { isZero =>
BoolUtils.isAnyValidTrue(isZero)
}
}
}
}

/**
* Divide p by q rounding with the HALF_UP mode.
* Logic is round(p.asDecimal / q, HALF_UP).
*
* It's equivalent to
* com.google.common.math.LongMath.divide(p, q, RoundingMode.HALF_UP):
* long div = p / q; // throws if q == 0
* long rem = p - q * div; // equals p % q
* // signum is 1 if p and q are both non-negative or both negative, and -1 otherwise.
* int signum = 1 | (int) ((p xor q) >> (Long.SIZE - 1));
* long absRem = abs(rem);
* long cmpRemToHalfDivisor = absRem - (abs(q) - absRem);
* increment = cmpRemToHalfDivisor >= 0
* return increment ? div + signum : div;
*
* @param p integer(int/long) cv or scalar
* @param q integer(byte/short/int/long) cv or Scalar, if p is scala, q will not be scala
* @return decimal 128 cv of p / q
*/
def divWithHalfUpModeWithOverflowCheck(p: BinaryOperable, q: BinaryOperable): ColumnVector = {
// 1. overflow check q is 0
if (IntervalUtils.hasZero(q)) {
throw new ArithmeticException("overflow: interval / zero")
}

// 2. overflow check (p == min(int min or long min) && q == -1)
val min = p.getType match {
case DType.INT32 => Int.MinValue.toLong
case DType.INT64 => Long.MinValue
case t => throw new IllegalArgumentException(s"Unexpected type $t")
}
withResource(Scalar.fromLong(min)) { minScalar =>
withResource(Scalar.fromLong(-1L)) { negOneScalar =>
(p, q) match {
case (lCv: ColumnVector, rCv: ColumnVector) =>
withResource(lCv.equalTo(minScalar)) { isMin =>
withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
withResource(isMin.and(isNegOne)) { invalid =>
if (BoolUtils.isAnyValidTrue(invalid)) {
throw new ArithmeticException("overflow occurs")
}
}
}
}
case (lCv: ColumnVector, rS: Scalar) =>
withResource(lCv.equalTo(minScalar)) { isMin =>
if (getLong(rS) == -1L && BoolUtils.isAnyValidTrue(isMin)) {
throw new ArithmeticException("overflow occurs")
}
}
case (lS: Scalar, rCv: ColumnVector) =>
withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
if (getLong(lS) == min && BoolUtils.isAnyValidTrue(isNegOne)) {
throw new ArithmeticException("overflow occurs")
}
}
case (lS: Scalar, rS: Scalar) =>
getLong(lS) == min && getLong(rS) == -1L
}
}
}

// 3. round(p.asDecimal / q)
val dT = DType.create(DType.DTypeEnum.DECIMAL128, -1)
val leftDecimal = p match {
case pCv: ColumnVector => pCv.castTo(dT)
jlowe marked this conversation as resolved.
Show resolved Hide resolved
case pS: Scalar => Scalar.fromDecimal(-1, new BigInteger((getLong(pS) * 10L).toString))
}
withResource(leftDecimal.div(q, dT)) { t =>
t.round(RoundMode.HALF_UP)
}
}
}

/**
Expand Down Expand Up @@ -321,3 +430,139 @@ case class GpuMultiplyDTInterval(
override def dataType: DataType = DayTimeIntervalType()

}

/**
* Divide a day-time interval by a numeric with a HALF_UP mode:
* e.g.: 3 / 2 = 1.5 will be rounded to 2; -3 / 2 = -1.5 will be rounded to -2
* year-month interval / number(byte, short, int, long, float, double)
* Note not support year-month interval / decimal
* Year-month interval's internal type is int, the value of int is 12 * year + month
* left expression is interval, right expression is number
* Rewrite from Spark code:
* https://github.com/apache/spark/blob/v3.2.1/sql/catalyst/src/main/scala/
* org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala#L615
*
*/
case class GpuDivideYMInterval(
interval: Expression,
num: Expression) extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def left: Expression = interval

override def right: Expression = num

override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
doColumnar(interval.getBase, numScalar.getBase, num.dataType)
}

override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
doColumnar(interval.getBase, num.getBase, num.dataType)
}

override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
doColumnar(intervalScalar.getBase, num.getBase, num.dataType)
}

override def doColumnar(numRows: Int, intervalScalar: GpuScalar,
numScalar: GpuScalar): ColumnVector = {
withResource(GpuColumnVector.from(intervalScalar, numRows, interval.dataType)) { expandedLhs =>
doColumnar(expandedLhs, numScalar)
}
}

private def doColumnar(interval: BinaryOperable, numOperable: BinaryOperable,
numType: DataType): ColumnVector = {

numType match {
case ByteType | ShortType | IntegerType | LongType =>
// interval is long; num is byte, short, int or long
// For overflow check: num is 0; interval == Long.Min && num == -1
withResource(IntervalUtils.divWithHalfUpModeWithOverflowCheck(interval, numOperable)) {
// overflow already checked, directly cast without overflow check
decimalRet => decimalRet.castTo(DType.INT32)
}

case FloatType | DoubleType => // num is float or double
withResource(interval.div(numOperable, DType.FLOAT64)) { double =>
// check overflow, then round to int
IntervalUtils.roundDoubleToIntWithOverflowCheck(double)
}
case _ => throw new IllegalArgumentException(
s"Not support num type $numType in GpuDivideYMInterval")
}
}

override def toString: String = s"$interval / $num"

override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType)

override def dataType: DataType = YearMonthIntervalType()
}

/**
* Divide a day-time interval by a numeric with a HALF_UP mode:
* e.g.: 3 / 2 = 1.5 will be rounded to 2; -3 / 2 = -1.5 will be rounded to -2
* day-time interval / number(byte, short, int, long, float, double)
* Note not support day-time interval / decimal
* Day-time interval's interval type is long, the value of long is the total microseconds
* Rewrite from Spark code:
* https://github.com/apache/spark/blob/v3.2.1/sql/catalyst/src/main/scala/
* org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala#L693
*/
case class GpuDivideDTInterval(
interval: Expression,
num: Expression)
extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def left: Expression = interval

override def right: Expression = num

override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
doColumnar(interval.getBase, numScalar.getBase, num.dataType)
}

override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
doColumnar(interval.getBase, num.getBase, num.dataType)
}

override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
doColumnar(intervalScalar.getBase, num.getBase, num.dataType)
}

override def doColumnar(numRows: Int, intervalScalar: GpuScalar,
numScalar: GpuScalar): ColumnVector = {
withResource(GpuColumnVector.from(intervalScalar, numRows, interval.dataType)) { expandedLhs =>
doColumnar(expandedLhs, numScalar)
}
}

private def doColumnar(interval: BinaryOperable, numOperable: BinaryOperable,
numType: DataType): ColumnVector = {
numType match {
case ByteType | ShortType | IntegerType | LongType =>
// interval is long; num is byte, short, int or long
// For overflow check: num is 0; interval == Long.Min && num == -1
withResource(IntervalUtils.divWithHalfUpModeWithOverflowCheck(interval, numOperable)) {
// overflow already checked, directly cast without overflow check
decimalRet => decimalRet.castTo(DType.INT64)
}

case FloatType | DoubleType =>
// interval is long; num is float or double
withResource(interval.div(numOperable, DType.FLOAT64)) { double =>
// check overflow, then round to long
IntervalUtils.roundDoubleToLongWithOverflowCheck(double)
}
case _ => throw new IllegalArgumentException(
s"Not support num type $numType in GpuDivideDTInterval")
}
}

override def toString: String = s"$interval / $num"

override def inputTypes: Seq[AbstractDataType] = Seq(DayTimeIntervalType, NumericType)

override def dataType: DataType = DayTimeIntervalType()

}
Loading