Skip to content

Commit

Permalink
Accommodate altered semantics of cudf::lists::contains()
Browse files Browse the repository at this point in the history
  • Loading branch information
mythrocks committed Dec 14, 2021
1 parent 1588f6a commit cf03bf9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
14 changes: 10 additions & 4 deletions integration_tests/src/main/python/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,15 @@ def test_orderby_array_of_structs(data_gen):
def test_array_contains(data_gen):
arr_gen = ArrayGen(data_gen)
lit = gen_scalar(data_gen, force_no_nulls=True)
assert_gpu_and_cpu_are_equal_collect(lambda spark: two_col_df(
spark, arr_gen, data_gen).select(array_contains(col('a'), lit.cast(data_gen.data_type)),
array_contains(col('a'), col('b')),
array_contains(col('a'), col('a')[5])), no_nans_conf)

def get_input(spark):
return two_col_df(spark, arr_gen, data_gen)

assert_gpu_and_cpu_are_equal_collect(lambda spark: get_input(spark).select(
array_contains(col('a'), lit.cast(data_gen.data_type)),
array_contains(col('a'), col('b')),
array_contains(col('a'), col('a')[5])
), no_nans_conf)


# Test array_contains() with a literal key that is extracted from the input array of doubles
Expand All @@ -118,6 +123,7 @@ def main_df(spark):
return df.select(array_contains(col('a'), chk_val))
assert_gpu_and_cpu_are_equal_collect(main_df)


@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, array index throws on out of range indexes")
@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
def test_get_array_item_ansi_fail(data_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,43 @@ case class GpuArrayContains(left: Expression, right: Expression)
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull
}

override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector =
lhs.getBase.listContains(rhs.getBase)
override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
val contains = lhs.getBase.listContains(rhs.getBase)
withResource(contains) { containsCV =>
val containsNull = lhs.getBase.listContainsNulls()
withResource(containsNull) { containsNullCV =>
val notContainsNull = containsNullCV.not
withResource(notContainsNull) { notContainsNullCV =>
val validity = containsCV.or(notContainsNullCV)
withResource(validity) { validityCV =>
containsCV.copyWithValidity(validityCV)
}
}
}
}
}

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector =
lhs.getBase.listContainsColumn(rhs.getBase)
override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = {
val contains = lhs.getBase.listContainsColumn(rhs.getBase)
withResource(contains) { containsCV =>
val containsNull = lhs.getBase.listContainsNulls()
withResource(containsNull) { containsNullCV =>
val notContainsNull = containsNullCV.not
withResource(notContainsNull) { notContainsNullCV =>
val validity = containsCV.or(notContainsNullCV)
withResource(validity) { validityCV =>
containsCV.copyWithValidity(validityCV)
}
}
}
}
}

override def prettyName: String = "array_contains"
}

0 comments on commit cf03bf9

Please sign in to comment.