Follow on to ANSI Add (#3561)
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
revans2 authored Sep 20, 2021
1 parent f1fbb52 commit 5856f17
Showing 3 changed files with 27 additions and 18 deletions.
16 changes: 12 additions & 4 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -64,15 +64,13 @@ def test_multiplication(data_gen):
        conf=allow_negative_scale_of_decimal_conf)

# No overflow gens here because we just focus on verifying the fallback to CPU when
-# enabling ansi mode. But overflows will fail the tests because CPU runs raise
+# enabling ANSI mode. But overflows will fail the tests because CPU runs raise
# exceptions.
_no_overflow_multiply_gens = [
    ByteGen(min_val = 1, max_val = 10, special_cases=[]),
    ShortGen(min_val = 1, max_val = 100, special_cases=[]),
    IntegerGen(min_val = 1, max_val = 1000, special_cases=[]),
-    LongGen(min_val = 1, max_val = 3000, special_cases=[]),
-    float_gen, double_gen,
-    decimal_gen_scale_precision, decimal_gen_same_scale_precision, DecimalGen(8, 8)]
+    LongGen(min_val = 1, max_val = 3000, special_cases=[])]

@allow_non_gpu('ProjectExec', 'Alias', 'CheckOverflow', 'Multiply', 'PromotePrecision', 'Cast')
@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens, ids=idfn)
@@ -83,6 +81,16 @@ def test_multiplication_fallback_when_ansi_enabled(data_gen):
        'ProjectExec',
        conf={'spark.sql.ansi.enabled': 'true'})

+@pytest.mark.parametrize('data_gen', [float_gen, double_gen,
+        decimal_gen_scale_precision], ids=idfn)
+def test_multiplication_ansi_enabled(data_gen):
+    data_type = data_gen.data_type
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : binary_op_df(spark, data_gen).select(
+            f.col('a') * f.lit(100).cast(data_type),
+            f.col('a') * f.col('b')),
+        conf={'spark.sql.ansi.enabled': 'true'})

@pytest.mark.parametrize('lhs', [DecimalGen(6, 5), DecimalGen(6, 4), DecimalGen(5, 4), DecimalGen(5, 3), DecimalGen(4, 2), DecimalGen(3, -2)], ids=idfn)
@pytest.mark.parametrize('rhs', [DecimalGen(6, 3)], ids=idfn)
def test_multiplication_mixed(lhs, rhs):
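The new test_multiplication_ansi_enabled above covers exactly the types for which ANSI mode changes nothing about multiplication. As a sanity check of that asymmetry, here is a minimal local-mode Spark sketch — plain CPU Spark, no RAPIDS plugin involved; the object name and values are illustrative:

import org.apache.spark.sql.SparkSession

object AnsiMultiplyDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("ansi-multiply-demo")
      .config("spark.sql.ansi.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    // DoubleType multiplication saturates to Infinity; ANSI mode never throws.
    Seq(Double.MaxValue).toDF("a").selectExpr("a * 2.0D").show()

    // LongType multiplication wraps, and under ANSI mode that is an error.
    try {
      Seq(Long.MaxValue).toDF("a").selectExpr("a * 2L").collect()
    } catch {
      // Spark wraps the task failure; the root cause is an ArithmeticException.
      case e: Exception => println(s"ANSI overflow error: ${e.getMessage}")
    }
    spark.stop()
  }
}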
15 changes: 9 additions & 6 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1380,8 +1380,11 @@ object GpuOverrides extends Logging {
"Natural log 1 + expr",
ExprChecks.mathUnary,
(a, conf, p, r) => new UnaryExprMeta[Log1p](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuLog(GpuAdd(child, GpuLiteral(1d, DataTypes.DoubleType), SQLConf.get.ansiEnabled))
override def convertToGpu(child: Expression): GpuExpression = {
// No need for overflow checking on the GpuAdd in Double as Double handles overflow
// the same in all modes.
GpuLog(GpuAdd(child, GpuLiteral(1d, DataTypes.DoubleType), false))
}
}),
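The reasoning in the new comment is easy to verify off-GPU. A tiny standalone sketch (hypothetical object name) of why a DoubleType add can never trip an ANSI overflow error:

object DoubleOverflowDemo extends App {
  // Double addition never wraps or throws: past the representable range it
  // saturates to Infinity, and it does so identically in all SQL modes.
  println(Double.MaxValue + Double.MaxValue) // Infinity
  println(Double.MaxValue + 1.0)             // 1.7976931348623157E308, unchanged
  // So log(x + 1.0), the GPU lowering of log1p, cannot hit an integral-style
  // overflow; there is nothing for an ANSI check to catch.
  println(math.log(0.5 + 1.0))               // 0.4054651081081644
}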
    expr[Log2](
      "Log base 2",
@@ -1663,8 +1666,8 @@

        override def tagSelfForAst(): Unit = {
          super.tagSelfForAst()
-          if (ansiEnabled) {
-            willNotWorkInAst("AST Addition does not support ansi mode.")
+          if (ansiEnabled && GpuAnsi.needBasicOpOverflowCheck(a.dataType)) {
+            willNotWorkInAst("AST Addition does not support ANSI mode.")
          }
        }

@@ -1720,8 +1723,8 @@
            case _ => // NOOP
          }

-          if (SQLConf.get.ansiEnabled) {
-            willNotWorkOnGpu("GPU Multiplication does not support ansi mode")
+          if (SQLConf.get.ansiEnabled && GpuAnsi.needBasicOpOverflowCheck(a.dataType)) {
+            willNotWorkOnGpu("GPU Multiplication does not support ANSI mode")
          }
        }

14 changes: 6 additions & 8 deletions sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala
@@ -28,6 +28,11 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

+object GpuAnsi {
+  def needBasicOpOverflowCheck(dt: DataType): Boolean =
+    dt.isInstanceOf[IntegralType]
+}

case class GpuUnaryMinus(child: Expression) extends GpuUnaryExpression
    with ExpectsInputTypes with NullIntolerant {
  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
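For reference, a small sketch of what the new helper answers for Spark's numeric types (assumes the GpuAnsi object from the hunk above is on the classpath; the demo object is illustrative):

import org.apache.spark.sql.types._

object NeedOverflowCheckDemo {
  def main(args: Array[String]): Unit = {
    // Only the fixed-width integral types wrap on overflow, so only they need
    // the extra ANSI check; Float/Double saturate to Infinity and Decimal
    // carries its own overflow handling in every mode.
    val types: Seq[DataType] =
      Seq(ByteType, ShortType, IntegerType, LongType, // all true
          FloatType, DoubleType, DecimalType(10, 2))  // all false
    types.foreach(dt => println(s"$dt -> ${GpuAnsi.needBasicOpOverflowCheck(dt)}"))
  }
}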
@@ -115,7 +120,7 @@ case class GpuAdd(
  override def doColumnar(lhs: BinaryOperable, rhs: BinaryOperable): ColumnVector = {
    val ret = super.doColumnar(lhs, rhs)
    // No shims are needed, because it actually supports ANSI mode from Spark v3.0.1.
-    if (failOnError && needOverflowCheckForType(dataType)) {
+    if (failOnError && GpuAnsi.needBasicOpOverflowCheck(dataType)) {
      // Check overflow. It is true when both arguments have the opposite sign of the result.
      // Which is equal to "((x ^ r) & (y ^ r)) < 0" in the form of arithmetic.
      closeOnExcept(ret) { r =>
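The sign trick quoted in the comment above is the same check java.lang.Math.addExact performs. A standalone sketch of the predicate on plain longs (no cudf involved; the object name is illustrative):

object AddOverflowDemo {
  // Signed addition overflows exactly when both operands disagree in sign
  // with the wrapped result: ((x ^ r) & (y ^ r)) < 0.
  def addOverflows(x: Long, y: Long): Boolean = {
    val r = x + y // two's-complement wrap-around on overflow
    ((x ^ r) & (y ^ r)) < 0L
  }

  def main(args: Array[String]): Unit = {
    assert(!addOverflows(1L, 2L))
    assert(addOverflows(Long.MaxValue, 1L))             // wraps to Long.MinValue
    assert(addOverflows(Long.MinValue, -1L))            // wraps to Long.MaxValue
    assert(!addOverflows(Long.MaxValue, Long.MinValue)) // -1, no overflow
    println("overflow predicate behaves as described")
  }
}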
@@ -140,13 +145,6 @@
    }
    ret
  }
-
-  /**
-   * Spark "Add" checks the overflow only for integral types.
-   */
-  private def needOverflowCheckForType(dt: DataType): Boolean =
-    dt.isInstanceOf[IntegralType]
-
}

case class GpuSubtract(left: Expression, right: Expression) extends CudfBinaryArithmetic {
