From 9a91954598a327439c1a724bf7ff6ce4c501f450 Mon Sep 17 00:00:00 2001
From: Danny Farah
Date: Fri, 10 Feb 2023 14:32:15 -0500
Subject: [PATCH] committing first version of UnityTableDataSet with unit tests.

This dataset allows users to interface with Unity Catalog tables in
Databricks for both reading and writing.

Signed-off-by: Danny Farah
---
 kedro-datasets/.gitignore                     |   3 +
 .../kedro_datasets/databricks/__init__.py     |   8 +
 .../kedro_datasets/databricks/unity.py        | 202 ++++++++
 kedro-datasets/setup.py                       |   3 +
 kedro-datasets/tests/databricks/__init__.py   |   0
 kedro-datasets/tests/databricks/conftest.py   |  26 +
 .../tests/databricks/test_unity_dataset.py    | 448 ++++++++++++++++++
 7 files changed, 690 insertions(+)
 create mode 100644 kedro-datasets/kedro_datasets/databricks/__init__.py
 create mode 100644 kedro-datasets/kedro_datasets/databricks/unity.py
 create mode 100644 kedro-datasets/tests/databricks/__init__.py
 create mode 100644 kedro-datasets/tests/databricks/conftest.py
 create mode 100644 kedro-datasets/tests/databricks/test_unity_dataset.py

diff --git a/kedro-datasets/.gitignore b/kedro-datasets/.gitignore
index d20ee9733..3725bd847 100644
--- a/kedro-datasets/.gitignore
+++ b/kedro-datasets/.gitignore
@@ -145,3 +145,6 @@ kedro.db
 kedro/html
 docs/tmp-build-artifacts
 docs/build
+spark-warehouse
+metastore_db/
+derby.log
\ No newline at end of file
diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py
new file mode 100644
index 000000000..2fd3eccb9
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/databricks/__init__.py
@@ -0,0 +1,8 @@
+"""Provides an interface to Unity Catalog tables."""
+
+__all__ = ["UnityTableDataSet"]
+
+from contextlib import suppress
+
+with suppress(ImportError):
+    from .unity import UnityTableDataSet
diff --git a/kedro-datasets/kedro_datasets/databricks/unity.py b/kedro-datasets/kedro_datasets/databricks/unity.py
new file mode 100644
index 000000000..8921fca1b
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/databricks/unity.py
@@ -0,0 +1,202 @@
+import logging
+from typing import Any, Dict, List, Union
+import pandas as pd
+
+from kedro.io.core import (
+    AbstractVersionedDataSet,
+    DataSetError,
+    VersionNotFoundError,
+)
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import StructType
+from pyspark.sql.utils import AnalysisException, ParseException
+from cachetools import Cache
+
+logger = logging.getLogger(__name__)
+
+
+class UnityTableDataSet(AbstractVersionedDataSet):
+    """``UnityTableDataSet`` loads and saves data in Unity Catalog managed tables."""
+
+    # this dataset cannot be used with ``ParallelRunner``,
+    # therefore it has the attribute ``_SINGLE_PROCESS = True``
+    # for parallelism within a Spark pipeline please consider
+    # using ``ThreadRunner`` instead
+    _SINGLE_PROCESS = True
+    _VALID_WRITE_MODES = ["overwrite", "upsert", "append"]
+    _VALID_DATAFRAME_TYPES = ["spark", "pandas"]
+
+    def __init__(
+        self,
+        table: str,
+        catalog: str = None,
+        database: str = "default",
+        write_mode: str = "overwrite",
+        dataframe_type: str = "spark",
+        primary_key: Union[str, List[str]] = None,
+        version: int = None,
+        *,
+        # the following parameters are used by the hook to create or update the Unity table
+        schema: Dict[str, Any] = None,
+        partition_columns: List[str] = None,
+        owner_group: str = None,
+    ) -> None:
+        """Creates a new instance of ``UnityTableDataSet``."""
+
+        self._database = database
+        self._catalog = catalog
+        self._table = table
+        self._owner_group = owner_group
+        self._partition_columns = partition_columns
+        if catalog and database and table:
+            self._full_table_address = f"{catalog}.{database}.{table}"
+        elif table:
+            self._full_table_address = f"{database}.{table}"
+
+        if write_mode not in self._VALID_WRITE_MODES:
+            valid_modes = ", ".join(self._VALID_WRITE_MODES)
+            raise DataSetError(
+                f"Invalid `write_mode` provided: {write_mode}. "
+                f"`write_mode` must be one of: {valid_modes}"
+            )
+        self._write_mode = write_mode
+
+        if dataframe_type not in self._VALID_DATAFRAME_TYPES:
+            valid_types = ", ".join(self._VALID_DATAFRAME_TYPES)
+            raise DataSetError(f"`dataframe_type` must be one of {valid_types}")
+        self._dataframe_type = dataframe_type
+
+        if primary_key is None or len(primary_key) == 0:
+            if write_mode == "upsert":
+                raise DataSetError(
+                    f"`primary_key` must be provided for `write_mode` {write_mode}"
+                )
+
+        self._primary_key = primary_key
+
+        self._version = version
+        self._version_cache = Cache(maxsize=2)
+
+        self._schema = None
+        if schema is not None:
+            self._schema = StructType.fromJson(schema)
+
+    def _get_spark(self) -> SparkSession:
+        return (
+            SparkSession.builder.config(
+                "spark.jars.packages", "io.delta:delta-core_2.12:1.2.1"
+            )
+            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
+            .config(
+                "spark.sql.catalog.spark_catalog",
+                "org.apache.spark.sql.delta.catalog.DeltaCatalog",
+            )
+            .getOrCreate()
+        )
+
+    def _load(self) -> Union[DataFrame, pd.DataFrame]:
+        if self._version is not None and self._version >= 0:
+            try:
+                data = (
+                    self._get_spark()
+                    .read.format("delta")
+                    .option("versionAsOf", self._version)
+                    .table(self._full_table_address)
+                )
+            except AnalysisException as exc:
+                raise VersionNotFoundError(self._version) from exc
+        else:
+            data = self._get_spark().table(self._full_table_address)
+        if self._dataframe_type == "pandas":
+            data = data.toPandas()
+        return data
+
+    def _save_append(self, data: DataFrame) -> None:
+        data.write.format("delta").mode("append").saveAsTable(self._full_table_address)
+
+    def _save_overwrite(self, data: DataFrame) -> None:
+        writer = data.write.format("delta")
+        if self._write_mode == "overwrite":
+            writer = writer.mode("overwrite").option(
+                "overwriteSchema", "true"
+            )
+        writer.saveAsTable(self._full_table_address)
+
+    def _save_upsert(self, update_data: DataFrame) -> None:
+        if self._exists():
+            base_data = self._get_spark().table(self._full_table_address)
+            base_columns = base_data.columns
+            update_columns = update_data.columns
+
+            if set(update_columns) != set(base_columns):
+                raise DataSetError(
+                    f"Upsert requires tables to have identical columns. "
+                    f"Delta table {self._full_table_address} "
+                    f"has columns: {base_columns}, whereas "
+                    f"dataframe has columns {update_columns}"
+                )
+
+            where_expr = ""
+            if isinstance(self._primary_key, str):
+                where_expr = f"base.{self._primary_key}=update.{self._primary_key}"
+            elif isinstance(self._primary_key, list):
+                where_expr = " AND ".join(
+                    f"base.{col}=update.{col}" for col in self._primary_key
+                )
+
+            update_data.createOrReplaceTempView("update")
+
+            upsert_sql = f"""MERGE INTO {self._full_table_address} base USING update
+            ON {where_expr} WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *
+            """
+            self._get_spark().sql(upsert_sql)
+        else:
+            self._save_append(update_data)
+
+    def _save(self, data: Any) -> None:
+        # filter columns specified in schema and match their ordering
+        if self._schema:
+            cols = self._schema.fieldNames()
+            if self._dataframe_type == "pandas":
+                data = self._get_spark().createDataFrame(
+                    data.loc[:, cols], schema=self._schema
+                )
+            else:
+                data = data.select(*cols)
+        else:
+            if self._dataframe_type == "pandas":
+                data = self._get_spark().createDataFrame(data)
+        if self._write_mode == "overwrite":
+            self._save_overwrite(data)
+        elif self._write_mode == "upsert":
+            self._save_upsert(data)
+        elif self._write_mode == "append":
+            self._save_append(data)
+
+    def _describe(self) -> Dict[str, Any]:
+        return dict(
+            catalog=self._catalog,
+            database=self._database,
+            table=self._table,
+            write_mode=self._write_mode,
+            dataframe_type=self._dataframe_type,
+            primary_key=self._primary_key,
+            version=self._version,
+        )
+
+    def _exists(self) -> bool:
+        if self._catalog:
+            try:
+                self._get_spark().sql(f"USE CATALOG {self._catalog}")
+            except (ParseException, AnalysisException):
+                logger.warning(f"catalog {self._catalog} not found")
+        try:
+            return (
+                self._get_spark()
+                .sql(f"SHOW TABLES IN `{self._database}`")
+                .filter(f"tableName = '{self._table}'")
+                .count()
+                > 0
+            )
+        except AnalysisException:
+            return False
diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py
index f75d3cad1..8c5440a75 100644
--- a/kedro-datasets/setup.py
+++ b/kedro-datasets/setup.py
@@ -37,6 +37,7 @@ def _collect_requirements(requires):
 api_require = {"api.APIDataSet": ["requests~=2.20"]}
 biosequence_require = {"biosequence.BioSequenceDataSet": ["biopython~=1.73"]}
 dask_require = {"dask.ParquetDataSet": ["dask[complete]~=2021.10", "triad>=0.6.7, <1.0"]}
+databricks_require = {"databricks.UnityTableDataSet": [SPARK]}
 geopandas_require = {
     "geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"]
 }
@@ -90,6 +91,7 @@ def _collect_requirements(requires):
     "api": _collect_requirements(api_require),
     "biosequence": _collect_requirements(biosequence_require),
     "dask": _collect_requirements(dask_require),
+    "databricks": _collect_requirements(databricks_require),
     "docs": [
         "docutils==0.16",
         "sphinx~=3.4.3",
@@ -117,6 +119,7 @@ def _collect_requirements(requires):
     **api_require,
     **biosequence_require,
     **dask_require,
+    **databricks_require,
     **geopandas_require,
     **matplotlib_require,
     **holoviews_require,
diff --git a/kedro-datasets/tests/databricks/__init__.py b/kedro-datasets/tests/databricks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/kedro-datasets/tests/databricks/conftest.py b/kedro-datasets/tests/databricks/conftest.py
new file mode 100644
index 000000000..d360ffb68
--- /dev/null
+++ b/kedro-datasets/tests/databricks/conftest.py
@@ -0,0 +1,26 @@
+"""
+This file contains the fixtures that are reusable by any tests within
+this directory. You don't need to import the fixtures as pytest will
+discover them automatically. More info here:
+https://docs.pytest.org/en/latest/fixture.html
+"""
+import pytest
+from pyspark.sql import SparkSession
+from delta.pip_utils import configure_spark_with_delta_pip
+
+
+@pytest.fixture(scope="class", autouse=True)
+def spark_session():
+    spark = (
+        SparkSession.builder.appName("test")
+        .config("spark.jars.packages", "io.delta:delta-core_2.12:1.2.1")
+        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
+        .config(
+            "spark.sql.catalog.spark_catalog",
+            "org.apache.spark.sql.delta.catalog.DeltaCatalog",
+        )
+        .getOrCreate()
+    )
+    spark.sql("create database if not exists test")
+    yield spark
+    spark.sql("drop database test cascade;")
diff --git a/kedro-datasets/tests/databricks/test_unity_dataset.py b/kedro-datasets/tests/databricks/test_unity_dataset.py
new file mode 100644
index 000000000..3f29a1e95
--- /dev/null
+++ b/kedro-datasets/tests/databricks/test_unity_dataset.py
@@ -0,0 +1,448 @@
+import pytest
+from kedro.io.core import DataSetError, VersionNotFoundError
+from pyspark.sql.types import IntegerType, StringType, StructField, StructType
+from pyspark.sql import DataFrame, SparkSession
+import pandas as pd
+from kedro_datasets.databricks import UnityTableDataSet
+
+
+@pytest.fixture
+def sample_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32), ("Evan", 23)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def mismatched_upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32, 174), ("Evan", 23, 166)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def subset_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32, 174), ("Evan", 23, 166)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def subset_pandas_df():
+    return pd.DataFrame(
+        {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]}
+    )
+
+
+@pytest.fixture
+def subset_expected_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32), ("Evan", 23)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def sample_pandas_df():
+    return pd.DataFrame(
+        {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]}
+    )
+
+
+@pytest.fixture
+def append_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Evan", 23), ("Frank", 13)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_append_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 31),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+        ("Frank", 13),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 32),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 31),
+        ("Alex", 32),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+class TestUnityTableDataSet:
+    def test_full_table(self):
+        unity_ds = UnityTableDataSet(catalog="test", database="test", table="test")
+        assert unity_ds._full_table_address == "test.test.test"
+
+    def test_database_table(self):
+        unity_ds = UnityTableDataSet(database="test", table="test")
+        assert unity_ds._full_table_address == "test.test"
+
+    def test_table_only(self):
+        unity_ds = UnityTableDataSet(table="test")
+        assert unity_ds._full_table_address == "default.test"
+
+    def test_table_missing(self):
+        with pytest.raises(TypeError):
+            UnityTableDataSet()
+
+    def test_describe(self):
+        unity_ds = UnityTableDataSet(table="test")
+        assert unity_ds._describe() == {
+            "catalog": None,
+            "database": "default",
+            "table": "test",
+            "write_mode": "overwrite",
+            "dataframe_type": "spark",
+            "primary_key": None,
+            "version": None,
+        }
+
+    def test_invalid_write_mode(self):
+        with pytest.raises(DataSetError):
+            UnityTableDataSet(table="test", write_mode="invalid")
+
+    def test_dataframe_type(self):
+        with pytest.raises(DataSetError):
+            UnityTableDataSet(table="test", dataframe_type="invalid")
+
+    def test_missing_primary_key_upsert(self):
+        with pytest.raises(DataSetError):
+            UnityTableDataSet(table="test", write_mode="upsert")
+
+    def test_schema(self):
+        unity_ds = UnityTableDataSet(
+            table="test",
+            schema={
+                "fields": [
+                    {
+                        "metadata": {},
+                        "name": "name",
+                        "nullable": True,
+                        "type": "string",
+                    },
+                    {
+                        "metadata": {},
+                        "name": "age",
+                        "nullable": True,
+                        "type": "integer",
+                    },
+                ],
+                "type": "struct",
+            },
+        )
+        expected_schema = StructType(
+            [
+                StructField("name", StringType(), True),
+                StructField("age", IntegerType(), True),
+            ]
+        )
+        assert unity_ds._schema == expected_schema
+
+    def test_catalog_exists(self):
+        unity_ds = UnityTableDataSet(catalog="test", database="invalid", table="test_not_there")
+        assert not unity_ds._exists()
+
+    def test_table_does_not_exist(self):
+        unity_ds = UnityTableDataSet(database="invalid", table="test_not_there")
+        assert not unity_ds._exists()
+
+    def test_save_default(self, sample_spark_df: DataFrame):
+        unity_ds = UnityTableDataSet(database="test", table="test_save")
+        unity_ds.save(sample_spark_df)
+        saved_table = unity_ds.load()
+        assert unity_ds.exists() and sample_spark_df.exceptAll(saved_table).count() == 0
+
+    def test_save_schema_spark(
+        self, subset_spark_df: DataFrame, subset_expected_df: DataFrame
+    ):
+        unity_ds = UnityTableDataSet(
+            database="test",
+            table="test_save_spark_schema",
+            schema={
+                "fields": [
+                    {
+                        "metadata": {},
+                        "name": "name",
+                        "nullable": True,
+                        "type": "string",
+                    },
+                    {
+                        "metadata": {},
+                        "name": "age",
+                        "nullable": True,
+                        "type": "integer",
+                    },
+                ],
+                "type": "struct",
+            },
+        )
+        unity_ds.save(subset_spark_df)
+        saved_table = unity_ds.load()
+        assert subset_expected_df.exceptAll(saved_table).count() == 0
+
+    def test_save_schema_pandas(
+        self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame
+    ):
+        unity_ds = UnityTableDataSet(
+            database="test",
+            table="test_save_pd_schema",
+            schema={
+                "fields": [
+                    {
+                        "metadata": {},
+                        "name": "name",
+                        "nullable": True,
+                        "type": "string",
+                    },
+                    {
+                        "metadata": {},
+                        "name": "age",
+                        "nullable": True,
+                        "type": "integer",
+                    },
+                ],
+                "type": "struct",
+            },
+            dataframe_type="pandas",
+        )
+        unity_ds.save(subset_pandas_df)
+        saved_ds = UnityTableDataSet(
+            database="test",
+            table="test_save_pd_schema",
+        )
+        saved_table = saved_ds.load()
+        assert subset_expected_df.exceptAll(saved_table).count() == 0
+
+    def test_save_overwrite(
+        self, sample_spark_df: DataFrame, append_spark_df: DataFrame
+    ):
+        unity_ds = UnityTableDataSet(database="test", table="test_save")
+        unity_ds.save(sample_spark_df)
+        unity_ds.save(append_spark_df)
+
+        overwritten_table = unity_ds.load()
+
+        assert append_spark_df.exceptAll(overwritten_table).count() == 0
+
+    def test_save_append(
+        self,
+        sample_spark_df: DataFrame,
+        append_spark_df: DataFrame,
+        expected_append_spark_df: DataFrame,
+    ):
+        unity_ds = UnityTableDataSet(
+            database="test", table="test_save_append", write_mode="append"
+        )
+        unity_ds.save(sample_spark_df)
+        unity_ds.save(append_spark_df)
+
+        appended_table = unity_ds.load()
+
+        assert expected_append_spark_df.exceptAll(appended_table).count() == 0
+
+    def test_save_upsert(
+        self,
+        sample_spark_df: DataFrame,
+        upsert_spark_df: DataFrame,
+        expected_upsert_spark_df: DataFrame,
+    ):
+        unity_ds = UnityTableDataSet(
+            database="test",
+            table="test_save_upsert",
+            write_mode="upsert",
+            primary_key="name",
+        )
+        unity_ds.save(sample_spark_df)
+        unity_ds.save(upsert_spark_df)
+
+        upserted_table = unity_ds.load()
+
+        assert expected_upsert_spark_df.exceptAll(upserted_table).count() == 0
+
+    def test_save_upsert_multiple_primary(
+        self,
+        sample_spark_df: DataFrame,
+        upsert_spark_df: DataFrame,
+        expected_upsert_multiple_primary_spark_df: DataFrame,
+    ):
+        unity_ds = UnityTableDataSet(
+            database="test",
+            table="test_save_upsert_multiple",
+            write_mode="upsert",
+            primary_key=["name", "age"],
+        )
+        unity_ds.save(sample_spark_df)
+        unity_ds.save(upsert_spark_df)
+
+        upserted_table = unity_ds.load()
+
+        assert (
+            expected_upsert_multiple_primary_spark_df.exceptAll(upserted_table).count()
+            == 0
+        )
+
+    def test_save_upsert_mismatched_columns(
+        self,
+        sample_spark_df: DataFrame,
+        mismatched_upsert_spark_df: DataFrame,
+    ):
+        unity_ds = UnityTableDataSet(
+            database="test",
+            table="test_save_upsert_mismatch",
+            write_mode="upsert",
+            primary_key="name",
+        )
+        unity_ds.save(sample_spark_df)
+        with pytest.raises(DataSetError):
+            unity_ds.save(mismatched_upsert_spark_df)
+
+    def test_load_spark(self, sample_spark_df: DataFrame):
+        unity_ds = UnityTableDataSet(database="test", table="test_load_spark")
+        unity_ds.save(sample_spark_df)
+
+        delta_ds = UnityTableDataSet(database="test", table="test_load_spark")
+        delta_table = delta_ds.load()
+
+        assert (
+            isinstance(delta_table, DataFrame)
+            and delta_table.exceptAll(sample_spark_df).count() == 0
+        )
+
+    def test_load_spark_no_version(self, sample_spark_df: DataFrame):
+        unity_ds = UnityTableDataSet(database="test", table="test_load_spark")
+        unity_ds.save(sample_spark_df)
+
+        delta_ds = UnityTableDataSet(
+            database="test", table="test_load_spark", version=2
+        )
+        with pytest.raises(VersionNotFoundError):
+            _ = delta_ds.load()
+
+    def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame):
+        unity_ds = UnityTableDataSet(
+            database="test", table="test_load_version", write_mode="append"
+        )
+        unity_ds.save(sample_spark_df)
+        unity_ds.save(append_spark_df)
+
+        loaded_ds = UnityTableDataSet(
+            database="test", table="test_load_version", version=0
+        )
+        loaded_df = loaded_ds.load()
+
+        assert loaded_df.exceptAll(sample_spark_df).count() == 0
+
+    def test_load_pandas(self, sample_pandas_df: pd.DataFrame):
+        unity_ds = UnityTableDataSet(
+            database="test", table="test_load_pandas", dataframe_type="pandas"
+        )
+        unity_ds.save(sample_pandas_df)
+
+        pandas_ds = UnityTableDataSet(
+            database="test", table="test_load_pandas", dataframe_type="pandas"
+        )
+        pandas_df = pandas_ds.load().sort_values("name", ignore_index=True)
+
+        assert isinstance(pandas_df, pd.DataFrame) and pandas_df.equals(
+            sample_pandas_df
+        )
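
For reviewers, a minimal usage sketch of the new dataset follows. It is illustrative only and not part of the patch; the catalog, schema, table and column names are assumptions, and it requires a Spark session with Delta and Unity Catalog access.

    import pandas as pd
    from kedro_datasets.databricks import UnityTableDataSet

    # "my_catalog", "my_schema", "customers" and "customer_id" are placeholder
    # names, not values used anywhere in this patch.
    customers = UnityTableDataSet(
        table="customers",
        catalog="my_catalog",
        database="my_schema",
        write_mode="upsert",
        primary_key="customer_id",
        dataframe_type="pandas",
    )
    customers.save(pd.DataFrame({"customer_id": [1], "name": ["Alex"]}))
    latest = customers.load()  # reads the table back as a pandas DataFrame

In a Kedro project the same dataset would typically be declared in catalog.yml with type: databricks.UnityTableDataSet and the arguments shown above.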