
GpuSequence refactor[databricks] (#4520)
* GpuSequence refactor

And update the tests

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
firestarman authored Jan 20, 2022
1 parent 321b760 commit 3c59706
Showing 3 changed files with 401 additions and 294 deletions.
236 changes: 156 additions & 80 deletions integration_tests/src/main/python/collection_ops_test.py
@@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from pyspark.sql.types import *
from spark_session import with_cpu_session
@@ -116,90 +116,166 @@ def test_sort_array_lit(data_gen, is_ascending):
lambda spark: unary_op_df(spark, data_gen, length=10).select(
f.sort_array(f.lit(array_lit), is_ascending)))

# We must restrict the length of sequence, since we may suffer the exception
# "Too long sequence: 2147483745. Should be <= 2147483632" or OOM.
sequence_integral_gens = [
ByteGen(nullable=False, min_val=-20, max_val=20, special_cases=[]),
ShortGen(nullable=False, min_val=-20, max_val=20, special_cases=[]),
IntegerGen(nullable=False, min_val=-20, max_val=20, special_cases=[]),
LongGen(nullable=False, min_val=-20, max_val=20, special_cases=[])
# For the functionality tests, the sequence length in each row should be limited
# to avoid the following exception:
# "Too long sequence: 2147483745. Should be <= 2147483632"
# In addition, the input data must follow one of the rules below
# (see the runnable sketch after this generator list):
# (step > 0 && start <= stop)
# or (step < 0 && start >= stop)
# or (step == 0 && start == stop)
sequence_normal_integral_gens = [
# (step > 0 && start <= stop)
(ByteGen(min_val=-10, max_val=20, special_cases=[]),
ByteGen(min_val=20, max_val=50, special_cases=[]),
ByteGen(min_val=1, max_val=5, special_cases=[])),
(ShortGen(min_val=-10, max_val=20, special_cases=[]),
ShortGen(min_val=20, max_val=50, special_cases=[]),
ShortGen(min_val=1, max_val=5, special_cases=[])),
(IntegerGen(min_val=-10, max_val=20, special_cases=[]),
IntegerGen(min_val=20, max_val=50, special_cases=[]),
IntegerGen(min_val=1, max_val=5, special_cases=[])),
(LongGen(min_val=-10, max_val=20, special_cases=[None]),
LongGen(min_val=20, max_val=50, special_cases=[None]),
LongGen(min_val=1, max_val=5, special_cases=[None])),
# (step < 0 && start >= stop)
(ByteGen(min_val=20, max_val=50, special_cases=[]),
ByteGen(min_val=-10, max_val=20, special_cases=[]),
ByteGen(min_val=-5, max_val=-1, special_cases=[])),
(ShortGen(min_val=20, max_val=50, special_cases=[]),
ShortGen(min_val=-10, max_val=20, special_cases=[]),
ShortGen(min_val=-5, max_val=-1, special_cases=[])),
(IntegerGen(min_val=20, max_val=50, special_cases=[]),
IntegerGen(min_val=-10, max_val=20, special_cases=[]),
IntegerGen(min_val=-5, max_val=-1, special_cases=[])),
(LongGen(min_val=20, max_val=50, special_cases=[None]),
LongGen(min_val=-10, max_val=20, special_cases=[None]),
LongGen(min_val=-5, max_val=-1, special_cases=[None])),
# (step == 0 && start == stop)
(ByteGen(min_val=20, max_val=20, special_cases=[]),
ByteGen(min_val=20, max_val=20, special_cases=[]),
ByteGen(min_val=0, max_val=0, special_cases=[])),
(ShortGen(min_val=20, max_val=20, special_cases=[]),
ShortGen(min_val=20, max_val=20, special_cases=[]),
ShortGen(min_val=0, max_val=0, special_cases=[])),
(IntegerGen(min_val=20, max_val=20, special_cases=[]),
IntegerGen(min_val=20, max_val=20, special_cases=[]),
IntegerGen(min_val=0, max_val=0, special_cases=[])),
(LongGen(min_val=20, max_val=20, special_cases=[None]),
LongGen(min_val=20, max_val=20, special_cases=[None]),
LongGen(min_val=0, max_val=0, special_cases=[None])),
]
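
As a reference for the rules above, here is a minimal pure-Python sketch of Spark's sequence() semantics for integral inputs (the sequence helper below is hypothetical, written only for illustration; it is not part of the plugin or these tests):

def sequence(start, stop, step):
    # Hypothetical model of Spark's sequence() for integral inputs.
    if not ((step > 0 and start <= stop) or
            (step < 0 and start >= stop) or
            (step == 0 and start == stop)):
        raise ValueError(
            "Illegal sequence boundaries: {} to {} by {}".format(start, stop, step))
    if start == stop:
        return [start]
    # range() excludes the end bound, so extend it by one in the step direction.
    return list(range(start, stop + (1 if step > 0 else -1), step))

assert sequence(0, 20, 5) == [0, 5, 10, 15, 20]   # step > 0 and start <= stop
assert sequence(20, 0, -5) == [20, 15, 10, 5, 0]  # step < 0 and start >= stop
assert sequence(3, 3, 0) == [3]                   # step == 0 and start == stop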

@pytest.mark.parametrize('data_gen', sequence_integral_gens, ids=idfn)
def test_sequence_without_step(data_gen):
sequence_normal_no_step_integral_gens = [(gens[0], gens[1]) for
gens in sequence_normal_integral_gens]

@pytest.mark.parametrize('start_gen,stop_gen', sequence_normal_no_step_integral_gens, ids=idfn)
def test_sequence_without_step(start_gen, stop_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark :
three_col_df(spark, data_gen, data_gen, data_gen)
.selectExpr("sequence(a, b)",
"sequence(a, 0)",
"sequence(0, b)"))

# This function is to generate the correct sequence data according to below limitations.
# (step > num.zero && start <= stop)
# || (step < num.zero && start >= stop)
# || (step == num.zero && start == stop)
def get_sequence_data(data_gen, length=2048):
rand = random.Random(0)
data_gen.start(rand)
list = []
for index in range(length):
start = data_gen.gen()
stop = data_gen.gen()
step = data_gen.gen()
# decide the direction of step
if start < stop:
step = abs(step) + 1
elif start == stop:
step = 0
else:
step = -(abs(step) + 1)
list.append(tuple([start, stop, step]))
# add special case
list.append(tuple([2, 2, 0]))
return list

def get_sequence_df(spark, data, data_type):
return spark.createDataFrame(
SparkContext.getOrCreate().parallelize(data),
StructType([StructField('a', data_type), StructField('b', data_type), StructField('c', data_type)]))

# test below case
# (2, -1, -1)
# (2, 5, 2)
# (2, 2, 0)
@pytest.mark.parametrize('data_gen', sequence_integral_gens, ids=idfn)
def test_sequence_with_step_case1(data_gen):
data = get_sequence_data(data_gen)
lambda spark: two_col_df(spark, start_gen, stop_gen).selectExpr(
"sequence(a, b)",
"sequence(a, 20)",
"sequence(20, b)"))

@pytest.mark.parametrize('start_gen,stop_gen,step_gen', sequence_normal_integral_gens, ids=idfn)
def test_sequence_with_step(start_gen, stop_gen, step_gen):
# Get a step scalar from 'step_gen' that follows the rules above.
step_gen.start(random.Random(0))
step_lit = step_gen.gen()
assert_gpu_and_cpu_are_equal_collect(
lambda spark :
get_sequence_df(spark, data, data_gen.data_type)
.selectExpr("sequence(a, b, c)"))
lambda spark: three_col_df(spark, start_gen, stop_gen, step_gen).selectExpr(
"sequence(a, b, c)",
"sequence(a, b, {})".format(step_lit),
"sequence(a, 20, c)",
"sequence(a, 20, {})".format(step_lit),
"sequence(20, b, c)",
"sequence(20, 20, c)",
"sequence(20, b, {})".format(step_lit)))

sequence_three_cols_integral_gens = [
(ByteGen(nullable=False, min_val=-10, max_val=10, special_cases=[]),
ByteGen(nullable=False, min_val=30, max_val=50, special_cases=[]),
ByteGen(nullable=False, min_val=1, max_val=10, special_cases=[])),
(ShortGen(nullable=False, min_val=-10, max_val=10, special_cases=[]),
ShortGen(nullable=False, min_val=30, max_val=50, special_cases=[]),
ShortGen(nullable=False, min_val=1, max_val=10, special_cases=[])),
(IntegerGen(nullable=False, min_val=-10, max_val=10, special_cases=[]),
IntegerGen(nullable=False, min_val=30, max_val=50, special_cases=[]),
IntegerGen(nullable=False, min_val=1, max_val=10, special_cases=[])),
(LongGen(nullable=False, min_val=-10, max_val=10, special_cases=[-10, 10]),
LongGen(nullable=False, min_val=30, max_val=50, special_cases=[30, 50]),
LongGen(nullable=False, min_val=1, max_val=10, special_cases=[1, 10])),
# Illegal sequence boundaries:
# step > 0, but start > stop
# step < 0, but start < stop
# step == 0, but start != stop
#
# All integral types share the same check implementation, so each case
# need not run over all of the types in the tests.
sequence_illegal_boundaries_integral_gens = [
# step > 0, but start > stop
(ShortGen(min_val=20, max_val=50, special_cases=[]),
ShortGen(min_val=-10, max_val=19, special_cases=[]),
ShortGen(min_val=1, max_val=5, special_cases=[])),
(LongGen(min_val=20, max_val=50, special_cases=[None]),
LongGen(min_val=-10, max_val=19, special_cases=[None]),
LongGen(min_val=1, max_val=5, special_cases=[None])),
# step < 0, but start < stop
(ByteGen(min_val=-10, max_val=19, special_cases=[]),
ByteGen(min_val=20, max_val=50, special_cases=[]),
ByteGen(min_val=-5, max_val=-1, special_cases=[])),
(IntegerGen(min_val=-10, max_val=19, special_cases=[]),
IntegerGen(min_val=20, max_val=50, special_cases=[]),
IntegerGen(min_val=-5, max_val=-1, special_cases=[])),
# step == 0, but start != stop
(IntegerGen(min_val=-10, max_val=19, special_cases=[]),
IntegerGen(min_val=20, max_val=50, special_cases=[]),
IntegerGen(min_val=0, max_val=0, special_cases=[]))
]

# Test the scalar case for the data start < stop and step > 0
@pytest.mark.parametrize('start_gen,stop_gen,step_gen', sequence_three_cols_integral_gens, ids=idfn)
def test_sequence_with_step_case2(start_gen, stop_gen, step_gen):
@pytest.mark.parametrize('start_gen,stop_gen,step_gen', sequence_illegal_boundaries_integral_gens, ids=idfn)
def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
assert_gpu_and_cpu_error(
lambda spark: three_col_df(spark, start_gen, stop_gen, step_gen).selectExpr(
"sequence(a, b, c)").collect(),
conf = {}, error_message = "Illegal sequence boundaries")
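
To reproduce the message the test matches against, the CPU check can be triggered directly (a sketch, assuming a local SparkSession named spark; only the "Illegal sequence boundaries" prefix is asserted on):

try:
    spark.sql("SELECT sequence(5, 1, 1)").collect()   # step > 0, but start > stop
except Exception as e:
    print(e)  # expected to contain "Illegal sequence boundaries"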

# Exceeds the max length of a sequence:
# "Too long sequence: xxxxxxxxxx. Should be <= 2147483632"
sequence_too_long_length_gens = [
IntegerGen(min_val=2147483633, max_val=2147483633, special_cases=[]),
LongGen(min_val=2147483635, max_val=2147483635, special_cases=[None])
]

@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
def test_sequence_too_long_sequence(stop_gen):
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row count to 1; it is enough to verify this case.
lambda spark: unary_op_df(spark, stop_gen, 1).selectExpr(
"sequence(0, a)").collect(),
conf = {}, error_message = "Too long sequence")
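
The cap in the message is the JVM's maximum rounded array length, Integer.MAX_VALUE - 15 = 2147483632. The arithmetic behind the generator values chosen above:

MAX_ROUNDED_ARRAY_LENGTH = 2147483647 - 15   # 2147483632

start, stop, step = 0, 2147483633, 1
length = (stop - start) // step + 1          # 2147483634 elements
assert length > MAX_ROUNDED_ARRAY_LENGTH     # so "Too long sequence" is raised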

def get_sequence_cases_mixed_df(spark, length=2048):
# Generate the sequence data following the 3 rules mixed in a single dataset.
# (step > num.zero && start <= stop) ||
# (step < num.zero && start >= stop) ||
# (step == num.zero && start == stop)
data_gen = IntegerGen(nullable=False, min_val=-10, max_val=10, special_cases=[])
def get_sequence_data(gen, len):
gen.start(random.Random(0))
list = []
for index in range(len):
start = gen.gen()
stop = gen.gen()
step = gen.gen()
# decide the direction of step
if start < stop:
step = abs(step) + 1
elif start == stop:
step = 0
else:
step = -(abs(step) + 1)
list.append(tuple([start, stop, step]))
# add special case
list.append(tuple([2, 2, 0]))
return list

mixed_schema = StructType([
StructField('a', data_gen.data_type),
StructField('b', data_gen.data_type),
StructField('c', data_gen.data_type)])
return spark.createDataFrame(
SparkContext.getOrCreate().parallelize(get_sequence_data(data_gen, length)),
mixed_schema)
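
The direction fix-up above is what guarantees every generated triple satisfies exactly one of the three rules. Isolated as a tiny function (illustrative inputs, not actual generator output):

def fix_step(start, stop, step):
    # Mirrors the step adjustment in get_sequence_data above.
    if start < stop:
        return abs(step) + 1      # force step > 0
    elif start == stop:
        return 0                  # force step == 0
    else:
        return -(abs(step) + 1)   # force step < 0

assert fix_step(1, 9, -4) == 5    # start < stop  -> positive step
assert fix_step(7, 7, 3) == 0     # start == stop -> zero step
assert fix_step(9, 1, 4) == -5    # start > stop  -> negative step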

# Test the 3 rules mixed in a single dataset
def test_sequence_with_step_mixed_cases():
assert_gpu_and_cpu_are_equal_collect(
lambda spark :
three_col_df(spark, start_gen, stop_gen, step_gen)
.selectExpr("sequence(a, b, c)",
"sequence(a, b, 2)",
"sequence(a, 20, c)",
"sequence(a, 20, 2)",
"sequence(0, b, c)",
"sequence(0, 4, c)",
"sequence(0, b, 3)"),)
lambda spark: get_sequence_cases_mixed_df(spark)
.selectExpr("sequence(a, b, c)"))
44 changes: 44 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/BoolUtils.scala
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, DType}

object BoolUtils extends Arm {

/**
* Whether all the valid rows in 'col' are true. An empty column returns true,
* and null rows are skipped.
*/
def isAllValidTrue(col: ColumnVector): Boolean = {
assert(DType.BOOL8 == col.getType, "input column type is not bool")
if (col.getRowCount == 0) {
return true
}

if (col.getRowCount == col.getNullCount) {
// All rows are null, which is equivalent to empty since nulls are skipped.
return true
}
withResource(col.all()) { allTrue =>
// Guaranteed there is at least one row and not all of the rows are null,
// so the result scalar must be valid
allTrue.getBoolean
}
}

}
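
The null-skipping semantics of isAllValidTrue can be summarized with a small Python model (a hypothetical helper for illustration only; the real implementation delegates to cudf's ColumnVector.all, with the empty and all-null cases short-circuited above):

def is_all_valid_true(rows):
    # None models a null row; nulls are skipped, and an empty or all-null
    # column counts as all-true, matching the Scala code above.
    valid = [r for r in rows if r is not None]
    return all(valid)   # all([]) is True, covering the empty/all-null cases

assert is_all_valid_true([])                       # empty column
assert is_all_valid_true([None, None])             # all null
assert is_all_valid_true([True, None, True])       # nulls skipped
assert not is_all_valid_true([True, False, None])  # any valid False row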
