Commit
Added support for the case where a dataframe of instance pyspark.sql.connect.dataframe.DataFrame is passed as input
dgunt2 committed Sep 9, 2024
1 parent 64ba1be commit 79b04d6
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions spark_expectations/core/expectations.py
@@ -1,14 +1,10 @@
 import functools
 from dataclasses import dataclass
 from typing import Dict, Optional, Any, Union
+import packaging.version as package_version
+from pyspark import version as spark_version
 from pyspark import StorageLevel
 from pyspark.sql import DataFrame, SparkSession
-
-try:
-    from pyspark.sql import connect
-except ImportError:
-    # Spark/Databricks connect is available only in Databricks runtime 14 and above
-    pass
 from spark_expectations import _log
 from spark_expectations.config.user_config import Constants as user_config
 from spark_expectations.core.context import SparkExpectationsContext
@@ -28,6 +24,17 @@
 from spark_expectations.utils.regulate_flow import SparkExpectationsRegulateFlow
 
 
+min_spark_version_for_connect = "3.4.0"
+installed_spark_version = spark_version.__version__
+is_spark_connect_supported = False
+if package_version.parse(installed_spark_version) >= package_version.parse(
+    min_spark_version_for_connect
+):
+    from pyspark.sql import connect
+
+    is_spark_connect_supported = True
+
+
 @dataclass
 class SparkExpectations:
     """
@@ -52,8 +59,11 @@ class SparkExpectations:
 
     def __post_init__(self) -> None:
         # Databricks runtime 14 and above could pass either instance of a Dataframe depending on how data was read
-        if isinstance(self.rules_df, DataFrame) or isinstance(
-            self.rules_df, connect.dataframe.DataFrame
+        if (
+            is_spark_connect_supported is True
+            and isinstance(self.rules_df, (DataFrame, connect.dataframe.DataFrame))
+        ) or (
+            is_spark_connect_supported is False and isinstance(self.rules_df, DataFrame)
         ):
             try:
                 self.spark: Optional[SparkSession] = self.rules_df.sparkSession
@@ -362,8 +372,12 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame:
                 self._context.get_run_id,
             )
 
-            if isinstance(_df, DataFrame) or isinstance(
-                _df, connect.dataframe.DataFrame
+            if (
+                is_spark_connect_supported is True
+                and isinstance(_df, (DataFrame, connect.dataframe.DataFrame))
+            ) or (
+                is_spark_connect_supported is False
+                and isinstance(_df, DataFrame)
             ):
                 _log.info("The function dataframe is created")
                 self._context.set_table_name(table_name)
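Editor's note: below is a minimal standalone sketch of the version-gating pattern this commit introduces, for readers who want the idea in isolation. The helper name is_supported_dataframe and the _dataframe_types tuple are illustrative assumptions, not part of the commit; the version comparison via packaging.version, the pyspark version module, and the connect.dataframe.DataFrame reference are taken directly from the diff above.

import packaging.version as package_version
from pyspark import version as spark_version
from pyspark.sql import DataFrame

# pyspark.sql.connect only ships from PySpark 3.4.0 onward, so gate the
# import on the installed version instead of catching ImportError.
if package_version.parse(spark_version.__version__) >= package_version.parse("3.4.0"):
    from pyspark.sql import connect

    _dataframe_types: tuple = (DataFrame, connect.dataframe.DataFrame)
else:
    _dataframe_types = (DataFrame,)


def is_supported_dataframe(df: object) -> bool:
    # True for a classic DataFrame and, when the installed PySpark ships the
    # connect module, for a Spark Connect DataFrame as well.
    return isinstance(df, _dataframe_types)

With this change, rules_df (and the dataframe returned by a decorated function) may be either kind of DataFrame, so callers on Databricks runtime 14 and above using Spark Connect no longer need to convert anything before passing it in.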

0 comments on commit 79b04d6