diff --git a/Makefile b/Makefile index be653ed59..4e0b4e640 100644 --- a/Makefile +++ b/Makefile @@ -52,10 +52,10 @@ sign-off: # kedro-datasets related only test-no-spark: - cd kedro-datasets && pytest tests --no-cov --ignore tests/spark --numprocesses 4 --dist loadfile + cd kedro-datasets && pytest tests --no-cov --ignore tests/spark --ignore tests/databricks --numprocesses 4 --dist loadfile test-no-spark-sequential: - cd kedro-datasets && pytest tests --no-cov --ignore tests/spark + cd kedro-datasets && pytest tests --no-cov --ignore tests/spark --ignore tests/databricks # kedro-datasets/snowflake tests skipped from default scope test-snowflake-only: diff --git a/kedro-datasets/.gitignore b/kedro-datasets/.gitignore index d20ee9733..721e13f70 100644 --- a/kedro-datasets/.gitignore +++ b/kedro-datasets/.gitignore @@ -145,3 +145,6 @@ kedro.db kedro/html docs/tmp-build-artifacts docs/build +spark-warehouse +metastore_db/ +derby.log diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py new file mode 100644 index 000000000..d416ac291 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -0,0 +1,8 @@ +"""Provides interface to Unity Catalog Tables.""" + +__all__ = ["ManagedTableDataSet"] + +from contextlib import suppress + +with suppress(ImportError): + from .managed_table_dataset import ManagedTableDataSet diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py new file mode 100644 index 000000000..01ec15a6f --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -0,0 +1,432 @@ +"""``ManagedTableDataSet`` implementation to access managed delta tables +in Databricks. +""" +import logging +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import pandas as pd +from kedro.io.core import ( + AbstractVersionedDataSet, + DataSetError, + Version, + VersionNotFoundError, +) +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import StructType +from pyspark.sql.utils import AnalysisException, ParseException + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ManagedTable: # pylint: disable=too-many-instance-attributes + """Stores the definition of a managed table""" + + # regex for tables, catalogs and schemas + _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,}\b" + _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] + _VALID_DATAFRAME_TYPES = ["spark", "pandas"] + database: str + catalog: Optional[str] + table: str + write_mode: str + dataframe_type: str + primary_key: Optional[str] + owner_group: str + partition_columns: Union[str, List[str]] + json_schema: StructType + + def __post_init__(self): + """Run validation methods if declared. + The validation method can be a simple check + that raises DataSetError. + The validation is performed by calling a function named: + `validate_(self, value) -> raises DataSetError` + """ + for name in self.__dataclass_fields__.keys(): # pylint: disable=no-member + method = getattr(self, f"_validate_{name}", None) + if method: + method() + + def _validate_table(self): + """Validates table name + + Raises: + DataSetError: If the table name does not conform to naming constraints. 
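+                Only letters, digits, underscores and hyphens are accepted
+                (see ``_NAMING_REGEX``); for example ``my_table-01`` is valid.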
+ """ + if not re.fullmatch(self._NAMING_REGEX, self.table): + raise DataSetError("table does not conform to naming") + + def _validate_database(self): + """Validates database name + + Raises: + DataSetError: If the dataset name does not conform to naming constraints. + """ + if not re.fullmatch(self._NAMING_REGEX, self.database): + raise DataSetError("database does not conform to naming") + + def _validate_catalog(self): + """Validates catalog name + + Raises: + DataSetError: If the catalog name does not conform to naming constraints. + """ + if self.catalog: + if not re.fullmatch(self._NAMING_REGEX, self.catalog): + raise DataSetError("catalog does not conform to naming") + + def _validate_write_mode(self): + """Validates the write mode + + Raises: + DataSetError: If an invalid `write_mode` is passed. + """ + if self.write_mode not in self._VALID_WRITE_MODES: + valid_modes = ", ".join(self._VALID_WRITE_MODES) + raise DataSetError( + f"Invalid `write_mode` provided: {self.write_mode}. " + f"`write_mode` must be one of: {valid_modes}" + ) + + def _validate_dataframe_type(self): + """Validates the dataframe type + + Raises: + DataSetError: If an invalid `dataframe_type` is passed + """ + if self.dataframe_type not in self._VALID_DATAFRAME_TYPES: + valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) + raise DataSetError(f"`dataframe_type` must be one of {valid_types}") + + def _validate_primary_key(self): + """Validates the primary key of the table + + Raises: + DataSetError: If no `primary_key` is specified. + """ + if self.primary_key is None or len(self.primary_key) == 0: + if self.write_mode == "upsert": + raise DataSetError( + f"`primary_key` must be provided for" + f"`write_mode` {self.write_mode}" + ) + + def full_table_location(self) -> str: + """Returns the full table location + + Returns: + str: table location in the format catalog.database.table + """ + full_table_location = None + if self.catalog and self.database and self.table: + full_table_location = f"`{self.catalog}`.`{self.database}`.`{self.table}`" + elif self.database and self.table: + full_table_location = f"`{self.database}`.`{self.table}`" + return full_table_location + + def schema(self) -> StructType: + """Returns the Spark schema of the table if it exists + + Returns: + StructType: + """ + schema = None + try: + if self.json_schema is not None: + schema = StructType.fromJson(self.json_schema) + except (KeyError, ValueError) as exc: + raise DataSetError(exc) from exc + return schema + + +class ManagedTableDataSet(AbstractVersionedDataSet): + """``ManagedTableDataSet`` loads and saves data into managed delta tables on Databricks. + Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. + When saving data, you can specify one of three modes: overwrite(default), append, + or upsert. Upsert requires you to specify the primary_column parameter which + will be used as part of the join condition. This dataset works best with + the databricks kedro starter. That starter comes with hooks that allow this + dataset to function properly. Follow the instructions in that starter to + setup your project for this dataset. + + Example usage for the + `YAML API `_: + + .. code-block:: yaml + + names_and_ages@spark: + type: databricks.ManagedTableDataSet + table: names_and_ages + + names_and_ages@pandas: + type: databricks.ManagedTableDataSet + table: names_and_ages + dataframe_type: pandas + + Example usage for the + `Python API `_: + .. 
code-block:: python + + from pyspark.sql import SparkSession + from pyspark.sql.types import (StructField, StringType, + IntegerType, StructType) + from kedro_datasets.databricks import ManagedTableDataSet + schema = StructType([StructField("name", StringType(), True), + StructField("age", IntegerType(), True)]) + data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] + spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) + data_set = ManagedTableDataSet(table="names_and_ages") + data_set.save(spark_df) + reloaded = data_set.load() + reloaded.take(4) + """ + + # this dataset cannot be used with ``ParallelRunner``, + # therefore it has the attribute ``_SINGLE_PROCESS = True`` + # for parallelism within a Spark pipeline please consider + # using ``ThreadRunner`` instead + _SINGLE_PROCESS = True + + def __init__( # pylint: disable=R0913 + self, + table: str, + catalog: str = None, + database: str = "default", + write_mode: str = "overwrite", + dataframe_type: str = "spark", + primary_key: Optional[Union[str, List[str]]] = None, + version: Version = None, + *, + # the following parameters are used by project hooks + # to create or update table properties + schema: Dict[str, Any] = None, + partition_columns: List[str] = None, + owner_group: str = None, + ) -> None: + """Creates a new instance of ``ManagedTableDataSet`` + + Args: + table (str): the name of the table + catalog (str, optional): the name of the catalog in Unity. + Defaults to None. + database (str, optional): the name of the database. + (also referred to as schema). Defaults to "default". + write_mode (str, optional): the mode to write the data into the table. + Options are:["overwrite", "append", "upsert"]. + "upsert" mode requires primary_key field to be populated. + Defaults to "overwrite". + dataframe_type (str, optional): "pandas" or "spark" dataframe. + Defaults to "spark". + primary_key (Union[str, List[str]], optional): the primary key of the table. + Can be in the form of a list. Defaults to None. + version (Version, optional): kedro.io.core.Version instance to load the data. + Defaults to None. + schema (Dict[str, Any], optional): the schema of the table in JSON form. + Dataframes will be truncated to match the schema if provided. + Used by the hooks to create the table if the schema is provided + Defaults to None. + partition_columns (List[str], optional): the columns to use for partitioning the table. + Used by the hooks. Defaults to None. + owner_group (str, optional): if table access control is enabled in your workspace, + specifying owner_group will transfer ownership of the table and database to + this owner. All databases should have the same owner_group. Defaults to None. 
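+
+        Example ``schema`` value (illustrative only; the dict mirrors what
+        ``StructType.jsonValue()`` produces and is parsed with
+        ``StructType.fromJson``, as exercised in the tests)::
+
+            {"type": "struct", "fields": [
+                {"name": "name", "type": "string", "nullable": True, "metadata": {}},
+                {"name": "age", "type": "integer", "nullable": True, "metadata": {}},
+            ]}
+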
+ Raises: + DataSetError: Invalid configuration supplied (through ManagedTable validation) + """ + + self._table = ManagedTable( + database=database, + catalog=catalog, + table=table, + write_mode=write_mode, + dataframe_type=dataframe_type, + primary_key=primary_key, + owner_group=owner_group, + partition_columns=partition_columns, + json_schema=schema, + ) + + self._version = version + + super().__init__( + filepath=None, + version=version, + exists_function=self._exists, + ) + + @staticmethod + def _get_spark() -> SparkSession: + return SparkSession.builder.getOrCreate() + + def _load(self) -> Union[DataFrame, pd.DataFrame]: + """Loads the version of data in the format defined in the init + (spark|pandas dataframe) + + Raises: + VersionNotFoundError: if the version defined in + the init doesn't exist + + Returns: + Union[DataFrame, pd.DataFrame]: Returns a dataframe + in the format defined in the init + """ + if self._version and self._version.load >= 0: + try: + data = ( + self._get_spark() + .read.format("delta") + .option("versionAsOf", self._version.load) + .table(self._table.full_table_location()) + ) + except Exception as exc: + raise VersionNotFoundError(self._version.load) from exc + else: + data = self._get_spark().table(self._table.full_table_location()) + if self._table.dataframe_type == "pandas": + data = data.toPandas() + return data + + def _save_append(self, data: DataFrame) -> None: + """Saves the data to the table by appending it + to the location defined in the init + + Args: + data (DataFrame): the Spark dataframe to append to the table + """ + data.write.format("delta").mode("append").saveAsTable( + self._table.full_table_location() + ) + + def _save_overwrite(self, data: DataFrame) -> None: + """Overwrites the data in the table with the data provided. + (this is the default save mode) + + Args: + data (DataFrame): the Spark dataframe to overwrite the table with. + """ + delta_table = data.write.format("delta") + if self._table.write_mode == "overwrite": + delta_table = delta_table.mode("overwrite").option( + "overwriteSchema", "true" + ) + delta_table.saveAsTable(self._table.full_table_location()) + + def _save_upsert(self, update_data: DataFrame) -> None: + """Upserts the data by joining on primary_key columns or column. + If table doesn't exist at save, the data is inserted to a new table. + + Args: + update_data (DataFrame): the Spark dataframe to upsert + """ + if self._exists(): + base_data = self._get_spark().table(self._table.full_table_location()) + base_columns = base_data.columns + update_columns = update_data.columns + + if set(update_columns) != set(base_columns): + raise DataSetError( + f"Upsert requires tables to have identical columns. 
" + f"Delta table {self._table.full_table_location()} " + f"has columns: {base_columns}, whereas " + f"dataframe has columns {update_columns}" + ) + + where_expr = "" + if isinstance(self._table.primary_key, str): + where_expr = ( + f"base.{self._table.primary_key}=update.{self._table.primary_key}" + ) + elif isinstance(self._table.primary_key, list): + where_expr = " AND ".join( + f"base.{col}=update.{col}" for col in self._table.primary_key + ) + + update_data.createOrReplaceTempView("update") + self._get_spark().conf.set( + "fullTableAddress", self._table.full_table_location() + ) + self._get_spark().conf.set("whereExpr", where_expr) + upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} + WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""" + self._get_spark().sql(upsert_sql) + else: + self._save_append(update_data) + + def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None: + """Saves the data based on the write_mode and dataframe_type in the init. + If write_mode is pandas, Spark dataframe is created first. + If schema is provided, data is matched to schema before saving + (columns will be sorted and truncated). + + Args: + data (Any): Spark or pandas dataframe to save to the table location + """ + # filter columns specified in schema and match their ordering + if self._table.schema(): + cols = self._table.schema().fieldNames() + if self._table.dataframe_type == "pandas": + data = self._get_spark().createDataFrame( + data.loc[:, cols], schema=self._table.schema() + ) + else: + data = data.select(*cols) + else: + if self._table.dataframe_type == "pandas": + data = self._get_spark().createDataFrame(data) + if self._table.write_mode == "overwrite": + self._save_overwrite(data) + elif self._table.write_mode == "upsert": + self._save_upsert(data) + elif self._table.write_mode == "append": + self._save_append(data) + + def _describe(self) -> Dict[str, str]: + """Returns a description of the instance of ManagedTableDataSet + + Returns: + Dict[str, str]: Dict with the details of the dataset + """ + return { + "catalog": self._table.catalog, + "database": self._table.database, + "table": self._table.table, + "write_mode": self._table.write_mode, + "dataframe_type": self._table.dataframe_type, + "primary_key": self._table.primary_key, + "version": str(self._version), + "owner_group": self._table.owner_group, + "partition_columns": self._table.partition_columns, + } + + def _exists(self) -> bool: + """Checks to see if the table exists + + Returns: + bool: boolean of whether the table defined + in the dataset instance exists in the Spark session + """ + if self._table.catalog: + try: + self._get_spark().sql(f"USE CATALOG `{self._table.catalog}`") + except (ParseException, AnalysisException) as exc: + logger.warning( + "catalog %s not found or unity not enabled. 
Error message: %s", + self._table.catalog, + exc, + ) + try: + return ( + self._get_spark() + .sql(f"SHOW TABLES IN `{self._table.database}`") + .filter(f"tableName = '{self._table.table}'") + .count() + > 0 + ) + except (ParseException, AnalysisException) as exc: + logger.warning("error occured while trying to find table: %s", exc) + return False diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index f2f4921a5..f5c5d931a 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -8,6 +8,7 @@ HDFS = "hdfs>=2.5.8, <3.0" S3FS = "s3fs>=0.3.0, <0.5" POLARS = "polars~=0.17.0" +DELTA = "delta-spark~=1.2.1" def _collect_requirements(requires): @@ -16,7 +17,10 @@ def _collect_requirements(requires): api_require = {"api.APIDataSet": ["requests~=2.20"]} biosequence_require = {"biosequence.BioSequenceDataSet": ["biopython~=1.73"]} -dask_require = {"dask.ParquetDataSet": ["dask[complete]", "triad>=0.6.7, <1.0"]} +dask_require = {"dask.ParquetDataSet": ["dask[complete]~=2021.10", "triad>=0.6.7, <1.0"]} +databricks_require = { + "databricks.ManagedTableDataSet": [SPARK, PANDAS, DELTA] +} geopandas_require = { "geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] } @@ -76,6 +80,7 @@ def _collect_requirements(requires): "api": _collect_requirements(api_require), "biosequence": _collect_requirements(biosequence_require), "dask": _collect_requirements(dask_require), + "databricks": _collect_requirements(databricks_require), "docs": [ "docutils==0.16", "sphinx~=3.4.3", @@ -105,6 +110,7 @@ def _collect_requirements(requires): **api_require, **biosequence_require, **dask_require, + **databricks_require, **geopandas_require, **holoviews_require, **matplotlib_require, diff --git a/kedro-datasets/test_requirements.txt b/kedro-datasets/test_requirements.txt index 4d4954739..fe20fee5f 100644 --- a/kedro-datasets/test_requirements.txt +++ b/kedro-datasets/test_requirements.txt @@ -30,7 +30,7 @@ networkx~=2.4 opencv-python~=4.5.5.64 openpyxl>=3.0.3, <4.0 pandas-gbq>=0.12.0, <0.18.0 -pandas>=1.3 # 1.3 for read_xml/to_xml +pandas>=1.3, <2 # 1.3 for read_xml/to_xml, <2 for compatibility with Spark < 3.4 Pillow~=9.0 plotly>=4.8.0, <6.0 polars~=0.15.13 diff --git a/kedro-datasets/tests/databricks/__init__.py b/kedro-datasets/tests/databricks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/tests/databricks/conftest.py b/kedro-datasets/tests/databricks/conftest.py new file mode 100644 index 000000000..26d63b056 --- /dev/null +++ b/kedro-datasets/tests/databricks/conftest.py @@ -0,0 +1,25 @@ +""" +This file contains the fixtures that are reusable by any tests within +this directory. You don't need to import the fixtures as pytest will +discover them automatically. 
More info here: +https://docs.pytest.org/en/latest/fixture.html +""" +import pytest +from pyspark.sql import SparkSession + + +@pytest.fixture(scope="class", autouse=True) +def spark_session(): + spark = ( + SparkSession.builder.appName("test") + .config("spark.jars.packages", "io.delta:delta-core_2.12:1.2.1") + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + .getOrCreate() + ) + spark.sql("create database if not exists test") + yield spark + spark.sql("drop database test cascade;") diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py new file mode 100644 index 000000000..9aae08707 --- /dev/null +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -0,0 +1,484 @@ +import pandas as pd +import pytest +from kedro.io.core import DataSetError, Version, VersionNotFoundError +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import IntegerType, StringType, StructField, StructType + +from kedro_datasets.databricks import ManagedTableDataSet + + +@pytest.fixture +def sample_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def mismatched_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]} + ) + + +@pytest.fixture +def subset_expected_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def sample_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]} + ) + + +@pytest.fixture +def append_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Evan", 23), ("Frank", 13)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_append_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", 
IntegerType(), True), + ] + ) + + data = [ + ("Alex", 31), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ("Frank", 13), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 32), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 31), + ("Alex", 32), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ] + + return spark_session.createDataFrame(data, schema) + + +# pylint: disable=too-many-public-methods +class TestManagedTableDataSet: + def test_full_table(self): + unity_ds = ManagedTableDataSet(catalog="test", database="test", table="test") + assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" + + unity_ds = ManagedTableDataSet( + catalog="test-test", database="test", table="test" + ) + assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" + + unity_ds = ManagedTableDataSet(database="test", table="test") + assert unity_ds._table.full_table_location() == "`test`.`test`" + + unity_ds = ManagedTableDataSet(table="test") + assert unity_ds._table.full_table_location() == "`default`.`test`" + + with pytest.raises(TypeError): + ManagedTableDataSet() # pylint: disable=no-value-for-parameter + + def test_describe(self): + unity_ds = ManagedTableDataSet(table="test") + assert unity_ds._describe() == { + "catalog": None, + "database": "default", + "table": "test", + "write_mode": "overwrite", + "dataframe_type": "spark", + "primary_key": None, + "version": "None", + "owner_group": None, + "partition_columns": None, + } + + def test_invalid_write_mode(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="test", write_mode="invalid") + + def test_dataframe_type(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="test", dataframe_type="invalid") + + def test_missing_primary_key_upsert(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="test", write_mode="upsert") + + def test_invalid_table_name(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="invalid!") + + def test_invalid_database(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="test", database="invalid!") + + def test_invalid_catalog(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="test", catalog="invalid!") + + def test_schema(self): + unity_ds = ManagedTableDataSet( + table="test", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + ) + expected_schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + assert unity_ds._table.schema() == expected_schema + + def test_invalid_schema(self): + with pytest.raises(DataSetError): + ManagedTableDataSet( + table="test", + schema={ + "fields": [ + { + "invalid": "schema", + } + ], + "type": "struct", + }, + )._table.schema() + + def 
test_catalog_exists(self): + unity_ds = ManagedTableDataSet( + catalog="test", database="invalid", table="test_not_there" + ) + assert not unity_ds._exists() + + def test_table_does_not_exist(self): + unity_ds = ManagedTableDataSet(database="invalid", table="test_not_there") + assert not unity_ds._exists() + + def test_save_default(self, sample_spark_df: DataFrame): + unity_ds = ManagedTableDataSet(database="test", table="test_save") + unity_ds.save(sample_spark_df) + saved_table = unity_ds.load() + assert ( + unity_ds._exists() and sample_spark_df.exceptAll(saved_table).count() == 0 + ) + + def test_save_schema_spark( + self, subset_spark_df: DataFrame, subset_expected_df: DataFrame + ): + unity_ds = ManagedTableDataSet( + database="test", + table="test_save_spark_schema", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + ) + unity_ds.save(subset_spark_df) + saved_table = unity_ds.load() + assert subset_expected_df.exceptAll(saved_table).count() == 0 + + def test_save_schema_pandas( + self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame + ): + unity_ds = ManagedTableDataSet( + database="test", + table="test_save_pd_schema", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + dataframe_type="pandas", + ) + unity_ds.save(subset_pandas_df) + saved_ds = ManagedTableDataSet( + database="test", + table="test_save_pd_schema", + ) + saved_table = saved_ds.load() + assert subset_expected_df.exceptAll(saved_table).count() == 0 + + def test_save_overwrite( + self, sample_spark_df: DataFrame, append_spark_df: DataFrame + ): + unity_ds = ManagedTableDataSet(database="test", table="test_save") + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + overwritten_table = unity_ds.load() + + assert append_spark_df.exceptAll(overwritten_table).count() == 0 + + def test_save_append( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + expected_append_spark_df: DataFrame, + ): + unity_ds = ManagedTableDataSet( + database="test", table="test_save_append", write_mode="append" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + appended_table = unity_ds.load() + + assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + + def test_save_upsert( + self, + sample_spark_df: DataFrame, + upsert_spark_df: DataFrame, + expected_upsert_spark_df: DataFrame, + ): + unity_ds = ManagedTableDataSet( + database="test", + table="test_save_upsert", + write_mode="upsert", + primary_key="name", + ) + unity_ds.save(sample_spark_df) + unity_ds.save(upsert_spark_df) + + upserted_table = unity_ds.load() + + assert expected_upsert_spark_df.exceptAll(upserted_table).count() == 0 + + def test_save_upsert_multiple_primary( + self, + sample_spark_df: DataFrame, + upsert_spark_df: DataFrame, + expected_upsert_multiple_primary_spark_df: DataFrame, + ): + unity_ds = ManagedTableDataSet( + database="test", + table="test_save_upsert_multiple", + write_mode="upsert", + primary_key=["name", "age"], + ) + unity_ds.save(sample_spark_df) + unity_ds.save(upsert_spark_df) + + upserted_table = unity_ds.load() + + assert ( + expected_upsert_multiple_primary_spark_df.exceptAll(upserted_table).count() + == 0 + ) + + def 
test_save_upsert_mismatched_columns( + self, + sample_spark_df: DataFrame, + mismatched_upsert_spark_df: DataFrame, + ): + unity_ds = ManagedTableDataSet( + database="test", + table="test_save_upsert_mismatch", + write_mode="upsert", + primary_key="name", + ) + unity_ds.save(sample_spark_df) + with pytest.raises(DataSetError): + unity_ds.save(mismatched_upsert_spark_df) + + def test_load_spark(self, sample_spark_df: DataFrame): + unity_ds = ManagedTableDataSet(database="test", table="test_load_spark") + unity_ds.save(sample_spark_df) + + delta_ds = ManagedTableDataSet(database="test", table="test_load_spark") + delta_table = delta_ds.load() + + assert ( + isinstance(delta_table, DataFrame) + and delta_table.exceptAll(sample_spark_df).count() == 0 + ) + + def test_load_spark_no_version(self, sample_spark_df: DataFrame): + unity_ds = ManagedTableDataSet(database="test", table="test_load_spark") + unity_ds.save(sample_spark_df) + + delta_ds = ManagedTableDataSet( + database="test", table="test_load_spark", version=Version(2, None) + ) + with pytest.raises(VersionNotFoundError): + _ = delta_ds.load() + + def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame): + unity_ds = ManagedTableDataSet( + database="test", table="test_load_version", write_mode="append" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + loaded_ds = ManagedTableDataSet( + database="test", table="test_load_version", version=Version(0, None) + ) + loaded_df = loaded_ds.load() + + assert loaded_df.exceptAll(sample_spark_df).count() == 0 + + def test_load_pandas(self, sample_pandas_df: pd.DataFrame): + unity_ds = ManagedTableDataSet( + database="test", table="test_load_pandas", dataframe_type="pandas" + ) + unity_ds.save(sample_pandas_df) + + pandas_ds = ManagedTableDataSet( + database="test", table="test_load_pandas", dataframe_type="pandas" + ) + pandas_df = pandas_ds.load().sort_values("name", ignore_index=True) + + assert isinstance(pandas_df, pd.DataFrame) and pandas_df.equals( + sample_pandas_df + )
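
Reviewer note: the sketch below is a minimal end-to-end walkthrough of the upsert path added in managed_table_dataset.py, mirroring test_save_upsert. It assumes a Spark session with Delta Lake available (the Databricks default, or a local session configured as in tests/databricks/conftest.py); the "demo" database and "people" table names are illustrative and not part of this change.

    from pyspark.sql import SparkSession
    from pyspark.sql.types import IntegerType, StringType, StructField, StructType

    from kedro_datasets.databricks import ManagedTableDataSet

    spark = SparkSession.builder.getOrCreate()
    spark.sql("CREATE DATABASE IF NOT EXISTS demo")  # illustrative database name

    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    base = spark.createDataFrame([("Alex", 31), ("Bob", 12)], schema)
    update = spark.createDataFrame([("Alex", 32), ("Evan", 23)], schema)

    data_set = ManagedTableDataSet(
        database="demo",
        table="people",
        write_mode="upsert",  # "upsert" requires primary_key
        primary_key="name",   # used to build the MERGE join condition
    )
    data_set.save(base)    # table does not exist yet, so the data is appended (table created)
    data_set.save(update)  # "Alex" row is updated, "Evan" row is inserted

    data_set.load().show()  # expected rows: Alex=32, Bob=12, Evan=23

Reading the table back as pandas only requires dataframe_type="pandas" on a second dataset instance, and a specific Delta version can be loaded by passing version=Version(0, None), as the new tests exercise.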