diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index 8083090523a0e..910d2d2ec42f9 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -21,6 +21,7 @@
 
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
+from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 
 torch_requirement_message = "torch is required"
@@ -30,9 +31,6 @@
 except ImportError:
     have_torch = False
 
-if should_test_connect:
-    from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin
-
 
 @unittest.skipIf(
     not should_test_connect or not have_torch or is_remote_only(),
diff --git a/python/pyspark/sql/connect/functions/__init__.py b/python/pyspark/sql/connect/functions/__init__.py
index e0179d4d56cf8..087a51e8616b9 100644
--- a/python/pyspark/sql/connect/functions/__init__.py
+++ b/python/pyspark/sql/connect/functions/__init__.py
@@ -16,6 +16,8 @@
 #
 """PySpark Functions with Spark Connect"""
+from pyspark.testing import should_test_connect
 
-from pyspark.sql.connect.functions.builtin import *  # noqa: F401,F403
-from pyspark.sql.connect.functions import partitioning  # noqa: F401,F403
+if should_test_connect:
+    from pyspark.sql.connect.functions.builtin import *  # noqa: F401,F403
+    from pyspark.sql.connect.functions import partitioning  # noqa: F401,F403
diff --git a/python/pyspark/sql/connect/merge.py b/python/pyspark/sql/connect/merge.py
index 9c3b3e4370a40..295e6089e092e 100644
--- a/python/pyspark/sql/connect/merge.py
+++ b/python/pyspark/sql/connect/merge.py
@@ -14,6 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
 
 import sys
 from typing import Dict, Optional, TYPE_CHECKING, List, Callable
@@ -235,12 +238,12 @@ def _test() -> None:
     globs = pyspark.sql.connect.merge.__dict__.copy()
     globs["spark"] = (
-        PySparkSession.builder.appName("sql.connect.dataframe tests")
+        PySparkSession.builder.appName("sql.connect.merge tests")
         .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
     (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.merge,
+        pyspark.sql.connect.merge,
         globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
     )
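The hunks above share one goal: `pyspark.sql.connect` modules, and the tests that wrap them, should fail fast with the real requirement message when an optional dependency such as grpc or pandas is missing, instead of dying mid-import. The sketch below is a simplified, hypothetical reimplementation of the `check_dependencies` guard pattern, not PySpark's actual code (the real function lives in `python/pyspark/sql/connect/utils.py` and raises through PySpark's error framework); the probed package list and the error text are illustrative only.

```python
# A minimal sketch (NOT PySpark's implementation) of the fail-fast guard
# that check_dependencies(__name__) provides: probe the optional packages
# once, remember the first failure, and raise it at import time.
from typing import Optional


def _first_missing_dependency() -> Optional[str]:
    # Probe each optional package; return the first ImportError message.
    for pkg in ("pandas", "pyarrow", "grpc"):  # illustrative list
        try:
            __import__(pkg)
        except ImportError as e:
            return str(e)
    return None


def check_dependencies(mod_name: str) -> None:
    # Called as check_dependencies(__name__) before any other connect
    # import, so the module fails with one clear message rather than
    # leaving a half-imported module behind.
    message = _first_missing_dependency()
    if message is not None:
        raise ImportError(f"{mod_name} requires optional dependencies: {message}")


check_dependencies(__name__)  # guards everything defined below this line
```

The call has to be the first statement after the license header, as in the `merge.py` hunk above, so that no other import can trip over the missing package first.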
diff --git a/python/pyspark/sql/connect/observation.py b/python/pyspark/sql/connect/observation.py
index e4b9b8a2d4fba..bfb8a0a9355fe 100644
--- a/python/pyspark/sql/connect/observation.py
+++ b/python/pyspark/sql/connect/observation.py
@@ -14,6 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
 from typing import Any, Dict, Optional
 
 import uuid
diff --git a/python/pyspark/sql/connect/utils.py b/python/pyspark/sql/connect/utils.py
index ce57b490c4532..a2511836816c9 100644
--- a/python/pyspark/sql/connect/utils.py
+++ b/python/pyspark/sql/connect/utils.py
@@ -22,7 +22,7 @@
 
 
 def check_dependencies(mod_name: str) -> None:
-    if mod_name == "__main__":
+    if mod_name == "__main__" or mod_name == "pyspark.sql.connect.utils":
         from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 
         if not should_test_connect:
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
index d79bfef2426a4..9d28ec0e19702 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
@@ -18,9 +18,11 @@
 import unittest
 
 from pyspark.sql.tests.streaming.test_streaming_foreach_batch import StreamingTestsForeachBatchMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
 from pyspark.errors import PySparkPicklingError
-from pyspark.errors.exceptions.connect import SparkConnectGrpcException
+
+if should_test_connect:
+    from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 
 
 class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 509f381f97fec..60ddcb6f22a54 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -40,7 +40,6 @@
     BooleanType,
 )
 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.errors.exceptions.connect import SparkConnectException
 from pyspark.testing.connectutils import should_test_connect
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
 
@@ -61,6 +60,7 @@
         JVM_LONG_MIN,
         JVM_LONG_MAX,
     )
+    from pyspark.errors.exceptions.connect import SparkConnectException
 
 
 class SparkConnectColumnTests(SparkConnectSQLTestCase):
diff --git a/python/pyspark/sql/tests/connect/test_connect_creation.py b/python/pyspark/sql/tests/connect/test_connect_creation.py
index cf6c2e86d2f5b..5352913f6609d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_creation.py
+++ b/python/pyspark/sql/tests/connect/test_connect_creation.py
@@ -35,7 +35,6 @@
 
 from pyspark.testing.sqlutils import MyObject, PythonOnlyUDT
 from pyspark.testing.connectutils import should_test_connect
-from pyspark.errors.exceptions.connect import ParseException
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
 
 if should_test_connect:
@@ -43,6 +42,7 @@
     import numpy as np
     from pyspark.sql import functions as SF
     from pyspark.sql.connect import functions as CF
+    from pyspark.errors.exceptions.connect import ParseException
 
 
 class SparkConnectCreationTests(SparkConnectSQLTestCase):
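The test files above all converge on the same layout: the module must import cleanly even when Spark Connect's dependencies are absent, so anything that transitively pulls in grpc moves under the `if should_test_connect:` guard, and the test classes are skipped with the recorded requirement message. A minimal sketch of that layout, using the real PySpark helper names from the diff and a placeholder test body:

```python
# Sketch of the guarded test-module layout these hunks converge on.
import unittest

from pyspark.testing.connectutils import (
    should_test_connect,
    connect_requirement_message,
    ReusedConnectTestCase,  # safe to import even without grpc installed
)

if should_test_connect:
    # Only safe when grpc and the other connect dependencies are present.
    from pyspark.errors.exceptions.connect import SparkConnectException


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ExampleConnectTests(ReusedConnectTestCase):
    # Placeholder body; real suites reference the guarded names above.
    def test_placeholder(self) -> None:
        self.assertTrue(should_test_connect)
```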
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index c712e5d6efcb6..1a8c7190e31a6 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -19,11 +19,9 @@
 
 from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
 from pyspark.sql.utils import is_remote
-
 from pyspark.sql import functions as SF
-from pyspark.sql.connect import functions as CF
-
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+from pyspark.testing.connectutils import should_test_connect
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
@@ -38,6 +36,9 @@
 if have_pandas:
     import pandas as pd
 
+if should_test_connect:
+    from pyspark.sql.connect import functions as CF
+
 
 class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
     def test_cached_property_is_copied(self):
diff --git a/python/pyspark/sql/tests/connect/test_connect_error.py b/python/pyspark/sql/tests/connect/test_connect_error.py
index 685e95a69ee74..01047741f6740 100644
--- a/python/pyspark/sql/tests/connect/test_connect_error.py
+++ b/python/pyspark/sql/tests/connect/test_connect_error.py
@@ -22,13 +22,13 @@
 from pyspark.sql.types import Row
 from pyspark.testing.connectutils import should_test_connect
 from pyspark.errors import PySparkTypeError
-from pyspark.errors.exceptions.connect import AnalysisException
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
 
 if should_test_connect:
     from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
     from pyspark.sql.connect import functions as CF
     from pyspark.sql.connect.column import Column
+    from pyspark.errors.exceptions.connect import AnalysisException
 
 
 class SparkConnectErrorTests(SparkConnectSQLTestCase):
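The reason these imports keep moving inside the guard is worth spelling out: `unittest` evaluates `skipIf` decorators only after the module has been imported, so a module-level import that needs a missing package aborts test collection before any skip can apply. A toy, Spark-free demonstration of that failure mode (the module source and the dependency name are made up):

```python
# Toy demonstration: a skipIf decorator cannot rescue a module whose
# top-level imports fail, because the decorator never gets to run.
import types

# Fake "test module" whose first import requires a missing dependency,
# standing in for an unguarded `from pyspark.errors.exceptions.connect ...`.
source = (
    "import definitely_not_installed_dependency  # stand-in for grpc\n"
    "import unittest\n"
    "\n"
    "class MyTests(unittest.TestCase):\n"
    "    pass\n"
)

mod = types.ModuleType("fake_test_module")
try:
    exec(source, mod.__dict__)  # what `import fake_test_module` would do
except ImportError as e:
    # This is what the test runner sees: a collection error, not a skip.
    print(f"collection fails before any skipIf can run: {e}")
```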
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py
index 4ddefc7385839..0028ecb95830d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -27,24 +27,24 @@
     RetriesExceeded,
 )
 from pyspark.sql import SparkSession as PySparkSession
-from pyspark.sql.connect.client.retries import RetryPolicy
 from pyspark.testing.connectutils import (
     should_test_connect,
     ReusedConnectTestCase,
     connect_requirement_message,
 )
-from pyspark.errors.exceptions.connect import (
-    AnalysisException,
-    SparkConnectException,
-    SparkUpgradeException,
-)
 
 if should_test_connect:
     import grpc
     from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
     from pyspark.sql.connect.client import DefaultChannelBuilder, ChannelBuilder
     from pyspark.sql.connect.client.core import Retrying, SparkConnectClient
+    from pyspark.sql.connect.client.retries import RetryPolicy
+    from pyspark.errors.exceptions.connect import (
+        AnalysisException,
+        SparkConnectException,
+        SparkUpgradeException,
+    )
 
 
 @unittest.skipIf(is_remote_only(), "Session creation different from local mode")
@@ -282,6 +282,7 @@ def test_stop_invalid_session(self):  # SPARK-47986
         session.stop()
 
 
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectSessionWithOptionsTest(unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
@@ -303,31 +304,31 @@ def test_config(self):
         self.assertEqual(self.spark.conf.get("integer"), "1")
 
 
-class TestError(grpc.RpcError, Exception):
-    def __init__(self, code: grpc.StatusCode):
-        self._code = code
-
-    def code(self):
-        return self._code
+if should_test_connect:
+    class TestError(grpc.RpcError, Exception):
+        def __init__(self, code: grpc.StatusCode):
+            self._code = code
 
-class TestPolicy(RetryPolicy):
-    # Put a small value for initial backoff so that tests don't spend
-    # Time waiting
-    def __init__(self, initial_backoff=10, **kwargs):
-        super().__init__(initial_backoff=initial_backoff, **kwargs)
+        def code(self):
+            return self._code
 
-    def can_retry(self, exception: BaseException):
-        return isinstance(exception, TestError)
+    class TestPolicy(RetryPolicy):
+        # Put a small value for initial backoff so that tests don't spend
+        # Time waiting
+        def __init__(self, initial_backoff=10, **kwargs):
+            super().__init__(initial_backoff=initial_backoff, **kwargs)
 
+        def can_retry(self, exception: BaseException):
+            return isinstance(exception, TestError)
 
-class TestPolicySpecificError(TestPolicy):
-    def __init__(self, specific_code: grpc.StatusCode, **kwargs):
-        super().__init__(**kwargs)
-        self.specific_code = specific_code
+    class TestPolicySpecificError(TestPolicy):
+        def __init__(self, specific_code: grpc.StatusCode, **kwargs):
+            super().__init__(**kwargs)
+            self.specific_code = specific_code
 
-    def can_retry(self, exception: BaseException):
-        return exception.code() == self.specific_code
+        def can_retry(self, exception: BaseException):
+            return exception.code() == self.specific_code
 
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
diff --git a/python/pyspark/sql/tests/connect/test_connect_stat.py b/python/pyspark/sql/tests/connect/test_connect_stat.py
index a2f23b44023d3..6e3cc2f58d814 100644
--- a/python/pyspark/sql/tests/connect/test_connect_stat.py
+++ b/python/pyspark/sql/tests/connect/test_connect_stat.py
@@ -19,15 +19,15 @@
 
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.testing.connectutils import should_test_connect
-from pyspark.errors.exceptions.connect import (
-    AnalysisException,
-    SparkConnectException,
-)
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
 
 if should_test_connect:
     from pyspark.sql import functions as SF
     from pyspark.sql.connect import functions as CF
+    from pyspark.errors.exceptions.connect import (
+        AnalysisException,
+        SparkConnectException,
+    )
 
 
 class SparkConnectStatTests(SparkConnectSQLTestCase):
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index 2ea6ef8cc389d..6955e7377b4c4 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -17,6 +17,8 @@
 import unittest
 
 from pyspark.testing.connectutils import should_test_connect
+from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark import sql
@@ -24,10 +26,7 @@
 
     sql.udtf.UserDefinedTableFunction = UserDefinedTableFunction
     from pyspark.sql.connect.functions import lit, udtf
-
-from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException
+    from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException
 
 
 class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
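Moving `TestError` and the retry policies under `if should_test_connect:` matters for a subtler reason than the imports: a `class` statement executes at import time, so `class TestError(grpc.RpcError, Exception)` evaluates `grpc.RpcError` the moment the module loads, and the definition itself requires grpc. A self-contained sketch of the same conditional-definition pattern (the class name here is hypothetical):

```python
# Sketch: defining a class whose *base class* lives in an optional
# dependency. The definition must sit behind the availability flag,
# because the bases are evaluated when the class statement runs.
have_grpc = True
try:
    import grpc
except ImportError:
    have_grpc = False

if have_grpc:
    # Only defined when grpc is importable; anything that uses this class
    # must live behind the same flag or be skipped via unittest.skipIf.
    class FakeRpcError(grpc.RpcError, Exception):
        def __init__(self, code: "grpc.StatusCode") -> None:
            self._code = code

        def code(self) -> "grpc.StatusCode":
            return self._code
```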
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py
index 2a6971e896292..3221a408d153d 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot.py
@@ -18,11 +18,21 @@
 import unittest
 
 from pyspark.errors import PySparkValueError
 from pyspark.sql import Row
-from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase
-from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_plotly,
+    plotly_requirement_message,
+    have_pandas,
+    pandas_requirement_message,
+)
 
+if have_plotly and have_pandas:
+    from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase
 
-@unittest.skipIf(not have_plotly, plotly_requirement_message)
+
+@unittest.skipIf(
+    not have_plotly or not have_pandas, plotly_requirement_message or pandas_requirement_message
+)
 class DataFramePlotTestsMixin:
     def test_backend(self):
         accessor = self.spark.range(2).plot
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index 9764b4a277273..a6005b6f7c4d9 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -18,7 +18,6 @@
 import unittest
 from datetime import datetime
 
-import pyspark.sql.plot  # noqa: F401
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -26,10 +25,17 @@
     have_numpy,
     plotly_requirement_message,
     numpy_requirement_message,
+    have_pandas,
+    pandas_requirement_message,
 )
 
+if have_plotly and have_pandas:
+    import pyspark.sql.plot  # noqa: F401
 
-@unittest.skipIf(not have_plotly, plotly_requirement_message)
+
+@unittest.skipIf(
+    not have_plotly or not have_pandas, plotly_requirement_message or pandas_requirement_message
+)
 class DataFramePlotPlotlyTestsMixin:
     @property
     def sdf(self):
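Both plot test files now skip when either package is missing, and the skip reason leans on the probe convention: each `*_requirement_message` is `None` when its import succeeded and the `ImportError` text otherwise, so `or`-chaining the messages yields the first recorded failure. A tiny illustration with hard-coded probe results:

```python
# Illustration of the combined skip condition and or-chained reason.
# The probe results are hard-coded here; in PySpark they come from
# pyspark.testing.sqlutils.
plotly_requirement_message = None  # plotly imported fine
pandas_requirement_message = "No module named 'pandas'"

have_plotly = plotly_requirement_message is None
have_pandas = pandas_requirement_message is None

skip = not have_plotly or not have_pandas
reason = plotly_requirement_message or pandas_requirement_message

assert skip is True
assert reason == "No module named 'pandas'"
```

When both packages are present, both messages are `None` and `reason` is `None` too, which is harmless because the skip condition is then `False` and the reason is never shown.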
diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py
index 88853e925f801..2a20035e54898 100644
--- a/python/pyspark/testing/__init__.py
+++ b/python/pyspark/testing/__init__.py
@@ -14,6 +14,51 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import typing
+
 from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
+
+grpc_requirement_message = None
+try:
+    import grpc
+except ImportError as e:
+    grpc_requirement_message = str(e)
+have_grpc = grpc_requirement_message is None
+
+
+grpc_status_requirement_message = None
+try:
+    import grpc_status
+except ImportError as e:
+    grpc_status_requirement_message = str(e)
+have_grpc_status = grpc_status_requirement_message is None
+
+googleapis_common_protos_requirement_message = None
+try:
+    from google.rpc import error_details_pb2
+except ImportError as e:
+    googleapis_common_protos_requirement_message = str(e)
+have_googleapis_common_protos = googleapis_common_protos_requirement_message is None
+
+graphviz_requirement_message = None
+try:
+    import graphviz
+except ImportError as e:
+    graphviz_requirement_message = str(e)
+have_graphviz: bool = graphviz_requirement_message is None
+
+from pyspark.testing.utils import PySparkErrorTestUtils
+from pyspark.testing.sqlutils import pandas_requirement_message, pyarrow_requirement_message
+
+
+connect_requirement_message = (
+    pandas_requirement_message
+    or pyarrow_requirement_message
+    or grpc_requirement_message
+    or googleapis_common_protos_requirement_message
+    or grpc_status_requirement_message
+)
+should_test_connect: str = typing.cast(str, connect_requirement_message is None)
+
 __all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 2f18cd8a6ccdc..7dea8a2103c3d 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -23,35 +23,18 @@
 import uuid
 import contextlib
 
-grpc_requirement_message = None
-try:
-    import grpc
-except ImportError as e:
-    grpc_requirement_message = str(e)
-have_grpc = grpc_requirement_message is None
-
-
-grpc_status_requirement_message = None
-try:
-    import grpc_status
-except ImportError as e:
-    grpc_status_requirement_message = str(e)
-have_grpc_status = grpc_status_requirement_message is None
-
-googleapis_common_protos_requirement_message = None
-try:
-    from google.rpc import error_details_pb2
-except ImportError as e:
-    googleapis_common_protos_requirement_message = str(e)
-have_googleapis_common_protos = googleapis_common_protos_requirement_message is None
-
-graphviz_requirement_message = None
-try:
-    import graphviz
-except ImportError as e:
-    graphviz_requirement_message = str(e)
-have_graphviz: bool = graphviz_requirement_message is None
-
+from pyspark.testing import (
+    grpc_requirement_message,
+    have_grpc,
+    grpc_status_requirement_message,
+    have_grpc_status,
+    googleapis_common_protos_requirement_message,
+    have_googleapis_common_protos,
+    graphviz_requirement_message,
+    have_graphviz,
+    connect_requirement_message,
+    should_test_connect,
+)
 from pyspark import Row, SparkConf
 from pyspark.util import is_remote_only
 from pyspark.testing.utils import PySparkErrorTestUtils
@@ -64,15 +47,6 @@
 from pyspark.sql.session import SparkSession as PySparkSession
 
 
-connect_requirement_message = (
-    pandas_requirement_message
-    or pyarrow_requirement_message
-    or grpc_requirement_message
-    or googleapis_common_protos_requirement_message
-    or grpc_status_requirement_message
-)
-should_test_connect: str = typing.cast(str, connect_requirement_message is None)
-
 if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame
     from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan
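Centralizing the probes in `pyspark/testing/__init__.py` gives `pyspark.sql.connect.functions` a dependency-light flag to import, instead of reaching into `connectutils`, which drags in session and test machinery. The sketch below condenses the pattern; unlike the real code it skips the minimum-version checks hiding behind `pandas_requirement_message` and `pyarrow_requirement_message`, and it uses a plain `bool` where the moved code carries over its historical `typing.cast(str, ...)` annotation.

```python
# Condensed sketch of the centralized probe/aggregate pattern now owned by
# pyspark/testing/__init__.py: each optional dependency records its
# ImportError message; the connect requirement is the first failure in the
# chain, and should_test_connect is True only when every probe passed.
from typing import Optional


def probe(module_name: str) -> Optional[str]:
    # Return None on success, or the ImportError text on failure.
    try:
        __import__(module_name)
        return None
    except ImportError as e:
        return str(e)


grpc_requirement_message = probe("grpc")
grpc_status_requirement_message = probe("grpc_status")
pandas_requirement_message = probe("pandas")
pyarrow_requirement_message = probe("pyarrow")

# `or` yields the first non-None message, i.e. the first missing dependency.
connect_requirement_message = (
    pandas_requirement_message
    or pyarrow_requirement_message
    or grpc_requirement_message
    or grpc_status_requirement_message
)
should_test_connect = connect_requirement_message is None
```

Because `connectutils` now re-exports these names unchanged, every existing `from pyspark.testing.connectutils import should_test_connect` call site in the test suite keeps working without modification.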