From c6b09c0b71e772de0605a555df9e78cbc4439ed6 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 8 Oct 2024 09:37:17 +0900
Subject: [PATCH] [SPARK-49894][PYTHON][CONNECT] Refine the string representation of column field operations

### What changes were proposed in this pull request?
Refine the string representation of the column field operations `GetField`, `WithField`, and `DropFields`.

### Why are the changes needed?
To make the string representations consistent between PySpark Classic and PySpark Connect.

### Does this PR introduce _any_ user-facing change?
Yes.

before:
```
In [1]: from pyspark.sql import functions as sf

In [2]: c = sf.col("c")

In [3]: c.x
Out[3]: Column<'UnresolvedExtractValue(c, x)'>
```

after:
```
In [1]: from pyspark.sql import functions as sf

In [2]: c = sf.col("c")

In [3]: c.x
Out[3]: Column<'c['x']'>
```

### How was this patch tested?
Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48369 from zhengruifeng/py_connect_col_str.

Lead-authored-by: Ruifeng Zheng
Co-authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/sql/connect/expressions.py |  6 +-
 python/pyspark/sql/tests/test_column.py   | 71 +++++++++++++++++++++++
 2 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 0b5512b61925c..85f1b3565c696 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -809,7 +809,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         return expr
 
     def __repr__(self) -> str:
-        return f"WithField({self._structExpr}, {self._fieldName}, {self._valueExpr})"
+        return f"update_field({self._structExpr}, {self._fieldName}, {self._valueExpr})"
 
 
 class DropField(Expression):
@@ -833,7 +833,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         return expr
 
     def __repr__(self) -> str:
-        return f"DropField({self._structExpr}, {self._fieldName})"
+        return f"drop_field({self._structExpr}, {self._fieldName})"
 
 
 class UnresolvedExtractValue(Expression):
@@ -857,7 +857,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         return expr
 
     def __repr__(self) -> str:
-        return f"UnresolvedExtractValue({str(self._child)}, {str(self._extraction)})"
+        return f"{self._child}['{self._extraction}']"
 
 
 class UnresolvedRegex(Expression):
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 1972dd2804d98..5f1991973d27d 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -283,6 +283,77 @@ def test_expr_str_representation(self):
         when_cond = sf.when(expression, sf.lit(None))
         self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>")
 
+    def test_col_field_ops_representation(self):
+        # SPARK-49894: Test string representation of columns
+        c = sf.col("c")
+
+        # getField
+        self.assertEqual(str(c.x), "Column<'c['x']'>")
+        self.assertEqual(str(c.x.y), "Column<'c['x']['y']'>")
+        self.assertEqual(str(c.x.y.z), "Column<'c['x']['y']['z']'>")
+
+        self.assertEqual(str(c["x"]), "Column<'c['x']'>")
+        self.assertEqual(str(c["x"]["y"]), "Column<'c['x']['y']'>")
+        self.assertEqual(str(c["x"]["y"]["z"]), "Column<'c['x']['y']['z']'>")
+
+        self.assertEqual(str(c.getField("x")), "Column<'c['x']'>")
+        self.assertEqual(
+            str(c.getField("x").getField("y")),
+            "Column<'c['x']['y']'>",
+        )
+        self.assertEqual(
+            str(c.getField("x").getField("y").getField("z")),
"Column<'c['x']['y']['z']'>", + ) + + self.assertEqual(str(c.getItem("x")), "Column<'c['x']'>") + self.assertEqual( + str(c.getItem("x").getItem("y")), + "Column<'c['x']['y']'>", + ) + self.assertEqual( + str(c.getItem("x").getItem("y").getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + + self.assertEqual( + str(c.x["y"].getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c["x"].getField("y").getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c.getField("x").getItem("y").z), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c["x"].y.getField("z")), + "Column<'c['x']['y']['z']'>", + ) + + # WithField + self.assertEqual( + str(c.withField("x", sf.col("y"))), + "Column<'update_field(c, x, y)'>", + ) + self.assertEqual( + str(c.withField("x", sf.col("y")).withField("x", sf.col("z"))), + "Column<'update_field(update_field(c, x, y), x, z)'>", + ) + + # DropFields + self.assertEqual(str(c.dropFields("x")), "Column<'drop_field(c, x)'>") + self.assertEqual( + str(c.dropFields("x", "y")), + "Column<'drop_field(drop_field(c, x), y)'>", + ) + self.assertEqual( + str(c.dropFields("x", "y", "z")), + "Column<'drop_field(drop_field(drop_field(c, x), y), z)'>", + ) + def test_lit_time_representation(self): dt = datetime.date(2021, 3, 4) self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>")