diff --git a/spark_expectations/core/expectations.py b/spark_expectations/core/expectations.py
index cc877fd..8c2ce26 100644
--- a/spark_expectations/core/expectations.py
+++ b/spark_expectations/core/expectations.py
@@ -1,14 +1,10 @@
 import functools
 from dataclasses import dataclass
 from typing import Dict, Optional, Any, Union
+import packaging.version as package_version
+from pyspark import version as spark_version
 from pyspark import StorageLevel
 from pyspark.sql import DataFrame, SparkSession
-
-try:
-    from pyspark.sql import connect
-except ImportError:
-    # Spark/Databricks connect is available only in Databricks runtime 14 and above
-    pass
 from spark_expectations import _log
 from spark_expectations.config.user_config import Constants as user_config
 from spark_expectations.core.context import SparkExpectationsContext
@@ -28,6 +24,17 @@ from spark_expectations.utils.regulate_flow import SparkExpectationsRegulateFlow
 
 
+min_spark_version_for_connect = "3.4.0"
+installed_spark_version = spark_version.__version__
+is_spark_connect_supported = False
+if package_version.parse(installed_spark_version) >= package_version.parse(
+    min_spark_version_for_connect
+):
+    from pyspark.sql import connect
+
+    is_spark_connect_supported = True
+
+
 @dataclass
 class SparkExpectations:
     """
@@ -52,8 +59,11 @@ class SparkExpectations:
 
     def __post_init__(self) -> None:
         # Databricks runtime 14 and above could pass either instance of a Dataframe depending on how data was read
-        if isinstance(self.rules_df, DataFrame) or isinstance(
-            self.rules_df, connect.dataframe.DataFrame
+        if (
+            is_spark_connect_supported is True
+            and isinstance(self.rules_df, (DataFrame, connect.dataframe.DataFrame))
+        ) or (
+            is_spark_connect_supported is False and isinstance(self.rules_df, DataFrame)
         ):
             try:
                 self.spark: Optional[SparkSession] = self.rules_df.sparkSession
@@ -362,8 +372,12 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame:
                         self._context.get_run_id,
                     )
 
-                    if isinstance(_df, DataFrame) or isinstance(
-                        _df, connect.dataframe.DataFrame
+                    if (
+                        is_spark_connect_supported is True
+                        and isinstance(_df, (DataFrame, connect.dataframe.DataFrame))
+                    ) or (
+                        is_spark_connect_supported is False
+                        and isinstance(_df, DataFrame)
                     ):
                         _log.info("The function dataframe is created")
                         self._context.set_table_name(table_name)
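
Note: the version gate introduced above can be sketched in isolation. The snippet
below is illustrative only; `is_supported_dataframe` is a hypothetical helper, not
part of this patch. It assumes, as the patch does, that `pyspark.sql.connect`
ships with PySpark 3.4.0 and later:

    import packaging.version as package_version
    from pyspark import version as spark_version
    from pyspark.sql import DataFrame

    MIN_SPARK_VERSION_FOR_CONNECT = "3.4.0"

    # Import pyspark.sql.connect only when the installed PySpark is new
    # enough to ship it; on older versions the import would raise
    # ImportError, so the flag stays False and connect is never referenced.
    is_spark_connect_supported = False
    if package_version.parse(spark_version.__version__) >= package_version.parse(
        MIN_SPARK_VERSION_FOR_CONNECT
    ):
        from pyspark.sql import connect

        is_spark_connect_supported = True

    def is_supported_dataframe(df: object) -> bool:
        # Accept a Spark Connect DataFrame only when the gated import
        # above succeeded; otherwise fall back to the classic check.
        if is_spark_connect_supported:
            return isinstance(df, (DataFrame, connect.dataframe.DataFrame))
        return isinstance(df, DataFrame)

A helper along these lines would also give the branch a single call site, where the
patch currently repeats the two-arm isinstance check in both `__post_init__` and
`wrapper`.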