Skip to content

Commit

Permalink
[SPARK-50435][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.…
Browse files Browse the repository at this point in the history
…tests.test_functions

### What changes were proposed in this pull request?
Use `assertDataFrameEqual` in pyspark.sql.tests.test_functions

### Why are the changes needed?
`assertDataFrameEqual` is explicitly built to handle DataFrame-specific comparisons, including schema.

So we propose to replace `assertEqual` with `assertDataFrameEqual`

Part of https://issues.apache.org/jira/browse/SPARK-50435.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#49011 from xinrong-meng/impr_test_functions.

Lead-authored-by: Xinrong Meng <xinrong@apache.org>
Co-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Xinrong Meng <xinrong@apache.org>
  • Loading branch information
xinrong-meng and HyukjinKwon committed Dec 2, 2024
1 parent dc73342 commit e7071c0
Showing 1 changed file with 92 additions and 104 deletions.
196 changes: 92 additions & 104 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pyspark.sql.column import Column
from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
from pyspark.testing.utils import have_numpy
from pyspark.testing.utils import have_numpy, assertDataFrameEqual


class FunctionsTestsMixin:
Expand Down Expand Up @@ -344,29 +344,29 @@ def test_try_parse_url(self):
[("https://spark.apache.org/path?query=1", "QUERY", "query")],
["url", "part", "key"],
)
actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect()
self.assertEqual(actual, [Row("1")])
actual = df.select(F.try_parse_url(df.url, df.part, df.key))
assertDataFrameEqual(actual, [Row("1")])
df = self.spark.createDataFrame(
[("inva lid://spark.apache.org/path?query=1", "QUERY", "query")],
["url", "part", "key"],
)
actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect()
self.assertEqual(actual, [Row(None)])
actual = df.select(F.try_parse_url(df.url, df.part, df.key))
assertDataFrameEqual(actual, [Row(None)])

def test_try_make_timestamp(self):
data = [(2024, 5, 22, 10, 30, 0)]
df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
actual = df.select(
F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second)
).collect()
self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
)
assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])

data = [(2024, 13, 22, 10, 30, 0)]
df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
actual = df.select(
F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second)
).collect()
self.assertEqual(actual, [Row(None)])
)
assertDataFrameEqual(actual, [Row(None)])

def test_try_make_timestamp_ltz(self):
# use local timezone here to avoid flakiness
Expand All @@ -378,8 +378,8 @@ def test_try_make_timestamp_ltz(self):
F.try_make_timestamp_ltz(
df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone
)
).collect()
self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 0))])
)
assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 0))])

# use local timezone here to avoid flakiness
data = [(2024, 13, 22, 10, 30, 0, datetime.datetime.now().astimezone().tzinfo.__str__())]
Expand All @@ -390,23 +390,23 @@ def test_try_make_timestamp_ltz(self):
F.try_make_timestamp_ltz(
df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone
)
).collect()
self.assertEqual(actual, [Row(None)])
)
assertDataFrameEqual(actual, [Row(None)])

def test_try_make_timestamp_ntz(self):
data = [(2024, 5, 22, 10, 30, 0)]
df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
actual = df.select(
F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second)
).collect()
self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
)
assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])

data = [(2024, 13, 22, 10, 30, 0)]
df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
actual = df.select(
F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second)
).collect()
self.assertEqual(actual, [Row(None)])
)
assertDataFrameEqual(actual, [Row(None)])

def test_string_functions(self):
string_functions = [
Expand Down Expand Up @@ -448,51 +448,51 @@ def test_string_functions(self):
)

for name in string_functions:
self.assertEqual(
df.select(getattr(F, name)("name")).first()[0],
df.select(getattr(F, name)(F.col("name"))).first()[0],
assertDataFrameEqual(
df.select(getattr(F, name)("name")),
df.select(getattr(F, name)(F.col("name"))),
)

def test_collation(self):
df = self.spark.createDataFrame([("a",), ("b",)], ["name"])
actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct().collect()
self.assertEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)
actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct()
assertDataFrameEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)

def test_try_make_interval(self):
df = self.spark.createDataFrame([(2147483647,)], ["num"])
actual = df.select(F.isnull(F.try_make_interval("num"))).collect()
self.assertEqual([Row(True)], actual)
actual = df.select(F.isnull(F.try_make_interval("num")))
assertDataFrameEqual([Row(True)], actual)

def test_octet_length_function(self):
# SPARK-36751: add octet length api for python
df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
actual = df.select(F.octet_length("cat")).collect()
self.assertEqual([Row(3), Row(4)], actual)
actual = df.select(F.octet_length("cat"))
assertDataFrameEqual([Row(3), Row(4)], actual)

def test_bit_length_function(self):
# SPARK-36751: add bit length api for python
df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
actual = df.select(F.bit_length("cat")).collect()
self.assertEqual([Row(24), Row(32)], actual)
actual = df.select(F.bit_length("cat"))
assertDataFrameEqual([Row(24), Row(32)], actual)

def test_array_contains_function(self):
df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ["data"])
actual = df.select(F.array_contains(df.data, "1").alias("b")).collect()
self.assertEqual([Row(b=True), Row(b=False)], actual)
actual = df.select(F.array_contains(df.data, "1").alias("b"))
assertDataFrameEqual([Row(b=True), Row(b=False)], actual)

def test_levenshtein_function(self):
df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"])
actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b")).collect()
self.assertEqual([Row(b=3)], actual_without_threshold)
actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b")).collect()
self.assertEqual([Row(b=-1)], actual_with_threshold)
actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b"))
assertDataFrameEqual([Row(b=3)], actual_without_threshold)
actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b"))
assertDataFrameEqual([Row(b=-1)], actual_with_threshold)

def test_between_function(self):
df = self.spark.createDataFrame(
[Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
)
self.assertEqual(
[Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()
assertDataFrameEqual(
[Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c))
)

def test_dayofweek(self):
Expand Down Expand Up @@ -608,7 +608,7 @@ def test_first_last_ignorenulls(self):
F.last(df2.id, False).alias("c"),
F.last(df2.id, True).alias("d"),
)
self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
assertDataFrameEqual([Row(a=None, b=1, c=None, d=98)], df3)

def test_approxQuantile(self):
df = self.spark.createDataFrame([Row(a=i, b=i + 10) for i in range(10)])
Expand Down Expand Up @@ -666,20 +666,20 @@ def test_sort_with_nulls_order(self):
df = self.spark.createDataFrame(
[("Tom", 80), (None, 60), ("Alice", 50)], ["name", "height"]
)
self.assertEqual(
df.select(df.name).orderBy(F.asc_nulls_first("name")).collect(),
assertDataFrameEqual(
df.select(df.name).orderBy(F.asc_nulls_first("name")),
[Row(name=None), Row(name="Alice"), Row(name="Tom")],
)
self.assertEqual(
df.select(df.name).orderBy(F.asc_nulls_last("name")).collect(),
assertDataFrameEqual(
df.select(df.name).orderBy(F.asc_nulls_last("name")),
[Row(name="Alice"), Row(name="Tom"), Row(name=None)],
)
self.assertEqual(
df.select(df.name).orderBy(F.desc_nulls_first("name")).collect(),
assertDataFrameEqual(
df.select(df.name).orderBy(F.desc_nulls_first("name")),
[Row(name=None), Row(name="Tom"), Row(name="Alice")],
)
self.assertEqual(
df.select(df.name).orderBy(F.desc_nulls_last("name")).collect(),
assertDataFrameEqual(
df.select(df.name).orderBy(F.desc_nulls_last("name")),
[Row(name="Tom"), Row(name="Alice"), Row(name=None)],
)

Expand Down Expand Up @@ -716,20 +716,16 @@ def test_slice(self):
)

expected = [Row(sliced=[2, 3]), Row(sliced=[5])]
self.assertEqual(df.select(F.slice(df.x, 2, 2).alias("sliced")).collect(), expected)
self.assertEqual(
df.select(F.slice(df.x, F.lit(2), F.lit(2)).alias("sliced")).collect(), expected
)
self.assertEqual(
df.select(F.slice("x", "index", "len").alias("sliced")).collect(), expected
)
assertDataFrameEqual(df.select(F.slice(df.x, 2, 2).alias("sliced")), expected)
assertDataFrameEqual(df.select(F.slice(df.x, F.lit(2), F.lit(2)).alias("sliced")), expected)
assertDataFrameEqual(df.select(F.slice("x", "index", "len").alias("sliced")), expected)

self.assertEqual(
df.select(F.slice(df.x, F.size(df.x) - 1, F.lit(1)).alias("sliced")).collect(),
assertDataFrameEqual(
df.select(F.slice(df.x, F.size(df.x) - 1, F.lit(1)).alias("sliced")),
[Row(sliced=[2]), Row(sliced=[4])],
)
self.assertEqual(
df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 1).alias("sliced")).collect(),
assertDataFrameEqual(
df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 1).alias("sliced")),
[Row(sliced=[1, 2]), Row(sliced=[4])],
)

Expand All @@ -738,11 +734,9 @@ def test_array_repeat(self):
df = df.withColumn("repeat_n", F.lit(3))

expected = [Row(val=[0, 0, 0])]
self.assertEqual(df.select(F.array_repeat("id", 3).alias("val")).collect(), expected)
self.assertEqual(df.select(F.array_repeat("id", F.lit(3)).alias("val")).collect(), expected)
self.assertEqual(
df.select(F.array_repeat("id", "repeat_n").alias("val")).collect(), expected
)
assertDataFrameEqual(df.select(F.array_repeat("id", 3).alias("val")), expected)
assertDataFrameEqual(df.select(F.array_repeat("id", F.lit(3)).alias("val")), expected)
assertDataFrameEqual(df.select(F.array_repeat("id", "repeat_n").alias("val")), expected)

def test_input_file_name_udf(self):
df = self.spark.read.text("python/test_support/hello/hello.txt")
Expand All @@ -754,11 +748,11 @@ def test_least(self):
df = self.spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"])

expected = [Row(least=1)]
self.assertEqual(df.select(F.least(df.a, df.b, df.c).alias("least")).collect(), expected)
self.assertEqual(
df.select(F.least(F.lit(3), F.lit(5), F.lit(1)).alias("least")).collect(), expected
assertDataFrameEqual(df.select(F.least(df.a, df.b, df.c).alias("least")), expected)
assertDataFrameEqual(
df.select(F.least(F.lit(3), F.lit(5), F.lit(1)).alias("least")), expected
)
self.assertEqual(df.select(F.least("a", "b", "c").alias("least")).collect(), expected)
assertDataFrameEqual(df.select(F.least("a", "b", "c").alias("least")), expected)

with self.assertRaises(PySparkValueError) as pe:
df.select(F.least(df.a).alias("least")).collect()
Expand Down Expand Up @@ -800,11 +794,9 @@ def test_overlay(self):
df = self.spark.createDataFrame([("SPARK_SQL", "CORE", 7, 0)], ("x", "y", "pos", "len"))

exp = [Row(ol="SPARK_CORESQL")]
self.assertEqual(df.select(F.overlay(df.x, df.y, 7, 0).alias("ol")).collect(), exp)
self.assertEqual(
df.select(F.overlay(df.x, df.y, F.lit(7), F.lit(0)).alias("ol")).collect(), exp
)
self.assertEqual(df.select(F.overlay("x", "y", "pos", "len").alias("ol")).collect(), exp)
assertDataFrameEqual(df.select(F.overlay(df.x, df.y, 7, 0).alias("ol")), exp)
assertDataFrameEqual(df.select(F.overlay(df.x, df.y, F.lit(7), F.lit(0)).alias("ol")), exp)
assertDataFrameEqual(df.select(F.overlay("x", "y", "pos", "len").alias("ol")), exp)

with self.assertRaises(PySparkTypeError) as pe:
df.select(F.overlay(df.x, df.y, 7.5, 0).alias("ol")).collect()
Expand Down Expand Up @@ -1164,8 +1156,8 @@ def test_assert_true(self):
def check_assert_true(self, tpe):
df = self.spark.range(3)

self.assertEqual(
df.select(F.assert_true(df.id < 3)).toDF("val").collect(),
assertDataFrameEqual(
df.select(F.assert_true(df.id < 3)).toDF("val"),
[Row(val=None), Row(val=None), Row(val=None)],
)

Expand Down Expand Up @@ -1302,17 +1294,17 @@ def test_np_scalar_input(self):

df = self.spark.createDataFrame([([1, 2, 3],), ([],)], ["data"])
for dtype in [np.int8, np.int16, np.int32, np.int64]:
res = df.select(F.array_contains(df.data, dtype(1)).alias("b")).collect()
self.assertEqual([Row(b=True), Row(b=False)], res)
res = df.select(F.array_position(df.data, dtype(1)).alias("c")).collect()
self.assertEqual([Row(c=1), Row(c=0)], res)
res = df.select(F.array_contains(df.data, dtype(1)).alias("b"))
assertDataFrameEqual([Row(b=True), Row(b=False)], res)
res = df.select(F.array_position(df.data, dtype(1)).alias("c"))
assertDataFrameEqual([Row(c=1), Row(c=0)], res)

df = self.spark.createDataFrame([([1.0, 2.0, 3.0],), ([],)], ["data"])
for dtype in [np.float32, np.float64]:
res = df.select(F.array_contains(df.data, dtype(1)).alias("b")).collect()
self.assertEqual([Row(b=True), Row(b=False)], res)
res = df.select(F.array_position(df.data, dtype(1)).alias("c")).collect()
self.assertEqual([Row(c=1), Row(c=0)], res)
res = df.select(F.array_contains(df.data, dtype(1)).alias("b"))
assertDataFrameEqual([Row(b=True), Row(b=False)], res)
res = df.select(F.array_position(df.data, dtype(1)).alias("c"))
assertDataFrameEqual([Row(c=1), Row(c=0)], res)

@unittest.skipIf(not have_numpy, "NumPy not installed")
def test_ndarray_input(self):
Expand Down Expand Up @@ -1729,46 +1721,42 @@ class IntEnum(Enum):

def test_nullifzero_zeroifnull(self):
df = self.spark.createDataFrame([(0,), (1,)], ["a"])
result = df.select(nullifzero(df.a).alias("r")).collect()
self.assertEqual([Row(r=None), Row(r=1)], result)
result = df.select(nullifzero(df.a).alias("r"))
assertDataFrameEqual([Row(r=None), Row(r=1)], result)

df = self.spark.createDataFrame([(None,), (1,)], ["a"])
result = df.select(zeroifnull(df.a).alias("r")).collect()
self.assertEqual([Row(r=0), Row(r=1)], result)
result = df.select(zeroifnull(df.a).alias("r"))
assertDataFrameEqual([Row(r=0), Row(r=1)], result)

def test_randstr_uniform(self):
df = self.spark.createDataFrame([(0,)], ["a"])
result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect()
self.assertEqual([Row(5)], result)
result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)")
assertDataFrameEqual([Row(5)], result)
# The random seed is optional.
result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect()
self.assertEqual([Row(5)], result)
result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)")
assertDataFrameEqual([Row(5)], result)

df = self.spark.createDataFrame([(0,)], ["a"])
result = (
df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x"))
.selectExpr("x > 5")
.collect()
)
self.assertEqual([Row(True)], result)
result = df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")).selectExpr("x > 5")
assertDataFrameEqual([Row(True)], result)
# The random seed is optional.
result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect()
self.assertEqual([Row(True)], result)
result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5")
assertDataFrameEqual([Row(True)], result)

def test_string_validation(self):
df = self.spark.createDataFrame([("abc",)], ["a"])
# test is_valid_utf8
result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r")).collect()
self.assertEqual([Row(r=True)], result_is_valid_utf8)
result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r"))
assertDataFrameEqual([Row(r=True)], result_is_valid_utf8)
# test make_valid_utf8
result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r")).collect()
self.assertEqual([Row(r="abc")], result_make_valid_utf8)
result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r"))
assertDataFrameEqual([Row(r="abc")], result_make_valid_utf8)
# test validate_utf8
result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r")).collect()
self.assertEqual([Row(r="abc")], result_validate_utf8)
result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r"))
assertDataFrameEqual([Row(r="abc")], result_validate_utf8)
# test try_validate_utf8
result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r")).collect()
self.assertEqual([Row(r="abc")], result_try_validate_utf8)
result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r"))
assertDataFrameEqual([Row(r="abc")], result_try_validate_utf8)


class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
Expand Down

0 comments on commit e7071c0

Please sign in to comment.