[SPARK-50040][PYTHON][TESTS] Make pyspark-connect tests pass without optional dependencies

### What changes were proposed in this pull request?

This PR proposes to make the pyspark-connect tests pass without optional dependencies installed.
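The overall pattern across the changed files is to defer imports that need optional packages (grpcio, grpc-status, googleapis-common-protos, pandas, pyarrow, plotly, ...) behind `should_test_connect`/`have_*` guards, so that merely importing or collecting the test modules no longer fails when those packages are absent. A minimal sketch of the pattern, with an illustrative (hypothetical) test class:

```python
import unittest

from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

if should_test_connect:
    # Only import Spark Connect pieces when grpc/pandas/pyarrow are available;
    # otherwise this module must still be importable.
    from pyspark.errors.exceptions.connect import SparkConnectException  # noqa: F401


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ExampleConnectParityTest(unittest.TestCase):  # hypothetical test class
    def test_guard_is_active(self) -> None:
        self.assertTrue(should_test_connect)


if __name__ == "__main__":
    unittest.main()
```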

### Why are the changes needed?

To make the tests pass without optional dependencies installed. See https://github.com/apache/spark/actions/runs/11420354598/job/31775990587

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

Manually ran it locally

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48561 from HyukjinKwon/SPARK-50040.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Oct 20, 2024
1 parent 164a0f4 commit ae75cac
Showing 17 changed files with 137 additions and 92 deletions.
@@ -21,6 +21,7 @@

from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

torch_requirement_message = "torch is required"
@@ -30,9 +31,6 @@
except ImportError:
have_torch = False

if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin


@unittest.skipIf(
not should_test_connect or not have_torch or is_remote_only(),
6 changes: 4 additions & 2 deletions python/pyspark/sql/connect/functions/__init__.py
@@ -16,6 +16,8 @@
#

"""PySpark Functions with Spark Connect"""
from pyspark.testing import should_test_connect

from pyspark.sql.connect.functions.builtin import * # noqa: F401,F403
from pyspark.sql.connect.functions import partitioning # noqa: F401,F403
if should_test_connect:
from pyspark.sql.connect.functions.builtin import * # noqa: F401,F403
from pyspark.sql.connect.functions import partitioning # noqa: F401,F403
7 changes: 5 additions & 2 deletions python/pyspark/sql/connect/merge.py
@@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import sys
from typing import Dict, Optional, TYPE_CHECKING, List, Callable
@@ -235,12 +238,12 @@ def _test() -> None:

globs = pyspark.sql.connect.merge.__dict__.copy()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
PySparkSession.builder.appName("sql.connect.merge tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)
(failure_count, test_count) = doctest.testmod(
pyspark.sql.merge,
pyspark.sql.connect.merge,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
4 changes: 4 additions & 0 deletions python/pyspark/sql/connect/observation.py
@@ -14,6 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

from typing import Any, Dict, Optional
import uuid

2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/utils.py
@@ -22,7 +22,7 @@


def check_dependencies(mod_name: str) -> None:
if mod_name == "__main__":
if mod_name == "__main__" or mod_name == "pyspark.sql.connect.utils":
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

if not should_test_connect:
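For context, `pyspark.sql.connect` modules call this checker before any other Connect imports (as `merge.py` and `observation.py` do above), and with this change the check also triggers when `pyspark.sql.connect.utils` itself is the module under test. A rough sketch of a module following that convention (module contents illustrative):

```python
# Hypothetical pyspark.sql.connect submodule following the convention above:
# make sure the optional dependencies (pandas, pyarrow, grpcio, ...) are
# present before importing anything that needs them.
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

# Imports below this point may assume the Connect dependencies exist.
from pyspark.sql.connect.dataframe import DataFrame  # noqa: E402,F401
```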
@@ -18,9 +18,11 @@
import unittest

from pyspark.sql.tests.streaming.test_streaming_foreach_batch import StreamingTestsForeachBatchMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
from pyspark.errors import PySparkPicklingError
from pyspark.errors.exceptions.connect import SparkConnectGrpcException

if should_test_connect:
from pyspark.errors.exceptions.connect import SparkConnectGrpcException


class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/test_connect_column.py
@@ -40,7 +40,6 @@
BooleanType,
)
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.errors.exceptions.connect import SparkConnectException
from pyspark.testing.connectutils import should_test_connect
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase

@@ -61,6 +60,7 @@
JVM_LONG_MIN,
JVM_LONG_MAX,
)
from pyspark.errors.exceptions.connect import SparkConnectException


class SparkConnectColumnTests(SparkConnectSQLTestCase):
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/test_connect_creation.py
@@ -35,14 +35,14 @@
from pyspark.testing.sqlutils import MyObject, PythonOnlyUDT

from pyspark.testing.connectutils import should_test_connect
from pyspark.errors.exceptions.connect import ParseException
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase

if should_test_connect:
import pandas as pd
import numpy as np
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
from pyspark.errors.exceptions.connect import ParseException


class SparkConnectCreationTests(SparkConnectSQLTestCase):
@@ -19,11 +19,9 @@

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
from pyspark.sql.utils import is_remote

from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF

from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
from pyspark.testing.connectutils import should_test_connect
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
@@ -38,6 +36,9 @@
if have_pandas:
import pandas as pd

if should_test_connect:
from pyspark.sql.connect import functions as CF


class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
def test_cached_property_is_copied(self):
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/test_connect_error.py
@@ -22,13 +22,13 @@
from pyspark.sql.types import Row
from pyspark.testing.connectutils import should_test_connect
from pyspark.errors import PySparkTypeError
from pyspark.errors.exceptions.connect import AnalysisException
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase

if should_test_connect:
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.sql.connect import functions as CF
from pyspark.sql.connect.column import Column
from pyspark.errors.exceptions.connect import AnalysisException


class SparkConnectErrorTests(SparkConnectSQLTestCase):
51 changes: 26 additions & 25 deletions python/pyspark/sql/tests/connect/test_connect_session.py
@@ -27,24 +27,24 @@
RetriesExceeded,
)
from pyspark.sql import SparkSession as PySparkSession
from pyspark.sql.connect.client.retries import RetryPolicy

from pyspark.testing.connectutils import (
should_test_connect,
ReusedConnectTestCase,
connect_requirement_message,
)
from pyspark.errors.exceptions.connect import (
AnalysisException,
SparkConnectException,
SparkUpgradeException,
)

if should_test_connect:
import grpc
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.sql.connect.client import DefaultChannelBuilder, ChannelBuilder
from pyspark.sql.connect.client.core import Retrying, SparkConnectClient
from pyspark.sql.connect.client.retries import RetryPolicy
from pyspark.errors.exceptions.connect import (
AnalysisException,
SparkConnectException,
SparkUpgradeException,
)


@unittest.skipIf(is_remote_only(), "Session creation different from local mode")
@@ -282,6 +282,7 @@ def test_stop_invalid_session(self):  # SPARK-47986
session.stop()


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectSessionWithOptionsTest(unittest.TestCase):
def setUp(self) -> None:
self.spark = (
@@ -303,31 +304,31 @@ def test_config(self):
self.assertEqual(self.spark.conf.get("integer"), "1")


class TestError(grpc.RpcError, Exception):
def __init__(self, code: grpc.StatusCode):
self._code = code

def code(self):
return self._code
if should_test_connect:

class TestError(grpc.RpcError, Exception):
def __init__(self, code: grpc.StatusCode):
self._code = code

class TestPolicy(RetryPolicy):
# Put a small value for initial backoff so that tests don't spend
# Time waiting
def __init__(self, initial_backoff=10, **kwargs):
super().__init__(initial_backoff=initial_backoff, **kwargs)
def code(self):
return self._code

def can_retry(self, exception: BaseException):
return isinstance(exception, TestError)
class TestPolicy(RetryPolicy):
# Put a small value for initial backoff so that tests don't spend
# Time waiting
def __init__(self, initial_backoff=10, **kwargs):
super().__init__(initial_backoff=initial_backoff, **kwargs)

def can_retry(self, exception: BaseException):
return isinstance(exception, TestError)

class TestPolicySpecificError(TestPolicy):
def __init__(self, specific_code: grpc.StatusCode, **kwargs):
super().__init__(**kwargs)
self.specific_code = specific_code
class TestPolicySpecificError(TestPolicy):
def __init__(self, specific_code: grpc.StatusCode, **kwargs):
super().__init__(**kwargs)
self.specific_code = specific_code

def can_retry(self, exception: BaseException):
return exception.code() == self.specific_code
def can_retry(self, exception: BaseException):
return exception.code() == self.specific_code


@unittest.skipIf(not should_test_connect, connect_requirement_message)
8 changes: 4 additions & 4 deletions python/pyspark/sql/tests/connect/test_connect_stat.py
@@ -19,15 +19,15 @@

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.connectutils import should_test_connect
from pyspark.errors.exceptions.connect import (
AnalysisException,
SparkConnectException,
)
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase

if should_test_connect:
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
from pyspark.errors.exceptions.connect import (
AnalysisException,
SparkConnectException,
)


class SparkConnectStatTests(SparkConnectSQLTestCase):
7 changes: 3 additions & 4 deletions python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -17,17 +17,16 @@
import unittest

from pyspark.testing.connectutils import should_test_connect
from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase

if should_test_connect:
from pyspark import sql
from pyspark.sql.connect.udtf import UserDefinedTableFunction

sql.udtf.UserDefinedTableFunction = UserDefinedTableFunction
from pyspark.sql.connect.functions import lit, udtf

from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException
from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException


class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
16 changes: 13 additions & 3 deletions python/pyspark/sql/tests/plot/test_frame_plot.py
@@ -18,11 +18,21 @@
import unittest
from pyspark.errors import PySparkValueError
from pyspark.sql import Row
from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_plotly,
plotly_requirement_message,
have_pandas,
pandas_requirement_message,
)

if have_plotly and have_pandas:
from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase

@unittest.skipIf(not have_plotly, plotly_requirement_message)

@unittest.skipIf(
not have_plotly or not have_pandas, plotly_requirement_message or pandas_requirement_message
)
class DataFramePlotTestsMixin:
def test_backend(self):
accessor = self.spark.range(2).plot
10 changes: 8 additions & 2 deletions python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -18,18 +18,24 @@
import unittest
from datetime import datetime

import pyspark.sql.plot # noqa: F401
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_plotly,
have_numpy,
plotly_requirement_message,
numpy_requirement_message,
have_pandas,
pandas_requirement_message,
)

if have_plotly and have_pandas:
import pyspark.sql.plot # noqa: F401

@unittest.skipIf(not have_plotly, plotly_requirement_message)

@unittest.skipIf(
not have_plotly or not have_pandas, plotly_requirement_message or pandas_requirement_message
)
class DataFramePlotPlotlyTestsMixin:
@property
def sdf(self):
45 changes: 45 additions & 0 deletions python/pyspark/testing/__init__.py
@@ -14,6 +14,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import typing

from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual


grpc_requirement_message = None
try:
import grpc
except ImportError as e:
grpc_requirement_message = str(e)
have_grpc = grpc_requirement_message is None


grpc_status_requirement_message = None
try:
import grpc_status
except ImportError as e:
grpc_status_requirement_message = str(e)
have_grpc_status = grpc_status_requirement_message is None

googleapis_common_protos_requirement_message = None
try:
from google.rpc import error_details_pb2
except ImportError as e:
googleapis_common_protos_requirement_message = str(e)
have_googleapis_common_protos = googleapis_common_protos_requirement_message is None

graphviz_requirement_message = None
try:
import graphviz
except ImportError as e:
graphviz_requirement_message = str(e)
have_graphviz: bool = graphviz_requirement_message is None

from pyspark.testing.utils import PySparkErrorTestUtils
from pyspark.testing.sqlutils import pandas_requirement_message, pyarrow_requirement_message


connect_requirement_message = (
pandas_requirement_message
or pyarrow_requirement_message
or grpc_requirement_message
or googleapis_common_protos_requirement_message
or grpc_status_requirement_message
)
should_test_connect: str = typing.cast(str, connect_requirement_message is None)

__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
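The new flags in `python/pyspark/testing/__init__.py` follow the same probe-and-record pattern as the existing `pandas_requirement_message`/`pyarrow_requirement_message` helpers, so further optional packages could be handled the same way; a sketch of adding one more probe (the package chosen here is purely illustrative):

```python
# Hypothetical additional probe, mirroring the grpc/grpc_status/graphviz probes above.
matplotlib_requirement_message = None
try:
    import matplotlib  # noqa: F401
except ImportError as e:
    matplotlib_requirement_message = str(e)
have_matplotlib: bool = matplotlib_requirement_message is None

# Tests would then skip themselves when the package is missing, e.g.:
#     @unittest.skipIf(not have_matplotlib, matplotlib_requirement_message)
```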