From d1bc9ab5e1a696769196f83c5a28f20b4d43b71a Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Fri, 10 Feb 2023 14:32:15 -0500 Subject: [PATCH 01/40] committing first version of UnityTableCatalog with unit tests. This datasets allows users to interface with Unity catalog tables in Databricks to both read and write. Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- kedro-datasets/.gitignore | 3 + .../kedro_datasets/databricks/__init__.py | 8 + .../kedro_datasets/databricks/unity.py | 202 ++++++++ kedro-datasets/setup.py | 5 +- kedro-datasets/tests/databricks/__init__.py | 0 kedro-datasets/tests/databricks/conftest.py | 26 + .../tests/databricks/test_unity_dataset.py | 448 ++++++++++++++++++ 7 files changed, 691 insertions(+), 1 deletion(-) create mode 100644 kedro-datasets/kedro_datasets/databricks/__init__.py create mode 100644 kedro-datasets/kedro_datasets/databricks/unity.py create mode 100644 kedro-datasets/tests/databricks/__init__.py create mode 100644 kedro-datasets/tests/databricks/conftest.py create mode 100644 kedro-datasets/tests/databricks/test_unity_dataset.py diff --git a/kedro-datasets/.gitignore b/kedro-datasets/.gitignore index d20ee9733..3725bd847 100644 --- a/kedro-datasets/.gitignore +++ b/kedro-datasets/.gitignore @@ -145,3 +145,6 @@ kedro.db kedro/html docs/tmp-build-artifacts docs/build +spark-warehouse +metastore_db/ +derby.log \ No newline at end of file diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py new file mode 100644 index 000000000..2fd3eccb9 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -0,0 +1,8 @@ +"""Provides interface to Unity Catalog Tables.""" + +__all__ = ["UnityTableDataSet"] + +from contextlib import suppress + +with suppress(ImportError): + from .unity import UnityTableDataSet diff --git a/kedro-datasets/kedro_datasets/databricks/unity.py b/kedro-datasets/kedro_datasets/databricks/unity.py new file mode 100644 index 000000000..8921fca1b --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/unity.py @@ -0,0 +1,202 @@ +import logging +from typing import Any, Dict, List, Union +import pandas as pd + +from kedro.io.core import ( + AbstractVersionedDataSet, + DataSetError, + VersionNotFoundError, +) +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import StructType +from pyspark.sql.utils import AnalysisException +from cachetools import Cache + +logger = logging.getLogger(__name__) + + +class UnityTableDataSet(AbstractVersionedDataSet): + """``UnityTableDataSet`` loads data into Unity managed tables.""" + + # this dataset cannot be used with ``ParallelRunner``, + # therefore it has the attribute ``_SINGLE_PROCESS = True`` + # for parallelism within a Spark pipeline please consider + # using ``ThreadRunner`` instead + _SINGLE_PROCESS = True + _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] + _VALID_DATAFRAME_TYPES = ["spark", "pandas"] + + def __init__( + self, + table: str, + catalog: str = None, + database: str = "default", + write_mode: str = "overwrite", + dataframe_type: str = "spark", + primary_key: Union[str, List[str]] = None, + version: int = None, + *, + # the following parameters are used by the hook to create or update unity + schema: Dict[str, Any] = None, # pylint: disable=unused-argument + partition_columns: List[str] = None, # pylint: disable=unused-argument + owner_group: str = None, + ) -> None: + """Creates a new instance of ``UnityTableDataSet``.""" + + self._database = database + 
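        # Illustrative usage (not part of this patch): constructing the dataset
        # directly with the arguments handled below; the table, database, catalog
        # and primary-key names here are hypothetical.
        #
        #   features_ds = UnityTableDataSet(
        #       table="features",
        #       database="analytics",
        #       catalog="main",
        #       write_mode="upsert",
        #       primary_key="id",
        #       dataframe_type="pandas",
        #   )
        #   features_ds.save(features_pandas_df)
        #   reloaded = features_ds.load()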
self._catalog = catalog + self._table = table + self._owner_group = owner_group + self._partition_columns = partition_columns + if catalog and database and table: + self._full_table_address = f"{catalog}.{database}.{table}" + elif table: + self._full_table_address = f"{database}.{table}" + + if write_mode not in self._VALID_WRITE_MODES: + valid_modes = ", ".join(self._VALID_WRITE_MODES) + raise DataSetError( + f"Invalid `write_mode` provided: {write_mode}. " + f"`write_mode` must be one of: {valid_modes}" + ) + self._write_mode = write_mode + + if dataframe_type not in self._VALID_DATAFRAME_TYPES: + valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) + raise DataSetError(f"`dataframe_type` must be one of {valid_types}") + self._dataframe_type = dataframe_type + + if primary_key is None or len(primary_key) == 0: + if write_mode == "upsert": + raise DataSetError( + f"`primary_key` must be provided for" f"`write_mode` {write_mode}" + ) + + self._primary_key = primary_key + + self._version = version + self._version_cache = Cache(maxsize=2) + + self._schema = None + if schema is not None: + self._schema = StructType.fromJson(schema) + + def _get_spark(self) -> SparkSession: + return ( + SparkSession.builder.config( + "spark.jars.packages", "io.delta:delta-core_2.12:1.2.1" + ) + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + .getOrCreate() + ) + + def _load(self) -> Union[DataFrame, pd.DataFrame]: + if self._version is not None and self._version >= 0: + try: + data = ( + self._get_spark() + .read.format("delta") + .option("versionAsOf", self._version) + .table(self._full_table_address) + ) + except: + raise VersionNotFoundError + else: + data = self._get_spark().table(self._full_table_address) + if self._dataframe_type == "pandas": + data = data.toPandas() + return data + + def _save_append(self, data: DataFrame) -> None: + data.write.format("delta").mode("append").saveAsTable(self._full_table_address) + + def _save_overwrite(self, data: DataFrame) -> None: + delta_table = data.write.format("delta") + if self._write_mode == "overwrite": + delta_table = delta_table.mode("overwrite").option( + "overwriteSchema", "true" + ) + delta_table.saveAsTable(self._full_table_address) + + def _save_upsert(self, update_data: DataFrame) -> None: + if self._exists(): + base_data = self._get_spark().table(self._full_table_address) + base_columns = base_data.columns + update_columns = update_data.columns + + if set(update_columns) != set(base_columns): + raise DataSetError( + f"Upsert requires tables to have identical columns. 
" + f"Delta table {self._full_table_address} " + f"has columns: {base_columns}, whereas " + f"dataframe has columns {update_columns}" + ) + + where_expr = "" + if isinstance(self._primary_key, str): + where_expr = f"base.{self._primary_key}=update.{self._primary_key}" + elif isinstance(self._primary_key, list): + where_expr = " AND ".join( + f"base.{col}=update.{col}" for col in self._primary_key + ) + + update_data.createOrReplaceTempView("update") + + upsert_sql = f"""MERGE INTO {self._full_table_address} base USING update + ON {where_expr} WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT * + """ + self._get_spark().sql(upsert_sql) + else: + self._save_append(update_data) + + def _save(self, data: Any) -> None: + # filter columns specified in schema and match their ordering + if self._schema: + cols = self._schema.fieldNames() + if self._dataframe_type == "pandas": + data = self._get_spark().createDataFrame( + data.loc[:, cols], schema=self._schema + ) + else: + data = data.select(*cols) + else: + if self._dataframe_type == "pandas": + data = self._get_spark().createDataFrame(data) + if self._write_mode == "overwrite": + self._save_overwrite(data) + elif self._write_mode == "upsert": + self._save_upsert(data) + elif self._write_mode == "append": + self._save_append(data) + + def _describe(self) -> Dict[str, str]: + return dict( + catalog=self._catalog, + database=self._database, + table=self._table, + write_mode=self._write_mode, + dataframe_type=self._dataframe_type, + primary_key=self._primary_key, + version=self._version, + ) + + def _exists(self) -> bool: + if self._catalog: + try: + self._get_spark().sql(f"USE CATALOG {self._catalog}") + except: + logger.warn(f"catalog {self._catalog} not found") + try: + return ( + self._get_spark() + .sql(f"SHOW TABLES IN `{self._database}`") + .filter(f"tableName = '{self._table}'") + .count() + > 0 + ) + except: + return False diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index f2f4921a5..635127e20 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -16,7 +16,8 @@ def _collect_requirements(requires): api_require = {"api.APIDataSet": ["requests~=2.20"]} biosequence_require = {"biosequence.BioSequenceDataSet": ["biopython~=1.73"]} -dask_require = {"dask.ParquetDataSet": ["dask[complete]", "triad>=0.6.7, <1.0"]} +dask_require = {"dask.ParquetDataSet": ["dask[complete]~=2021.10", "triad>=0.6.7, <1.0"]} +databricks_require = {"databricks.UnityTableDataSet": [SPARK]} geopandas_require = { "geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] } @@ -76,6 +77,7 @@ def _collect_requirements(requires): "api": _collect_requirements(api_require), "biosequence": _collect_requirements(biosequence_require), "dask": _collect_requirements(dask_require), + "databricks": _collect_requirements(databricks_require), "docs": [ "docutils==0.16", "sphinx~=3.4.3", @@ -105,6 +107,7 @@ def _collect_requirements(requires): **api_require, **biosequence_require, **dask_require, + **databricks_require, **geopandas_require, **holoviews_require, **matplotlib_require, diff --git a/kedro-datasets/tests/databricks/__init__.py b/kedro-datasets/tests/databricks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/tests/databricks/conftest.py b/kedro-datasets/tests/databricks/conftest.py new file mode 100644 index 000000000..d360ffb68 --- /dev/null +++ b/kedro-datasets/tests/databricks/conftest.py @@ -0,0 +1,26 @@ +""" +This file contains the fixtures that are reusable by any tests within 
+this directory. You don't need to import the fixtures as pytest will +discover them automatically. More info here: +https://docs.pytest.org/en/latest/fixture.html +""" +import pytest +from pyspark.sql import SparkSession +from delta.pip_utils import configure_spark_with_delta_pip + + +@pytest.fixture(scope="class", autouse=True) +def spark_session(): + spark = ( + SparkSession.builder.appName("test") + .config("spark.jars.packages", "io.delta:delta-core_2.12:1.2.1") + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + .getOrCreate() + ) + spark.sql("create database if not exists test") + yield spark + spark.sql("drop database test cascade;") diff --git a/kedro-datasets/tests/databricks/test_unity_dataset.py b/kedro-datasets/tests/databricks/test_unity_dataset.py new file mode 100644 index 000000000..3f29a1e95 --- /dev/null +++ b/kedro-datasets/tests/databricks/test_unity_dataset.py @@ -0,0 +1,448 @@ +import pytest +from kedro.io.core import DataSetError, VersionNotFoundError +from pyspark.sql.types import IntegerType, StringType, StructField, StructType +from pyspark.sql import DataFrame, SparkSession +import pandas as pd +from kedro_datasets.databricks import UnityTableDataSet + + +@pytest.fixture +def sample_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def mismatched_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]} + ) + + +@pytest.fixture +def subset_expected_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def sample_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]} + ) + + +@pytest.fixture +def append_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Evan", 23), ("Frank", 13)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def 
expected_append_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 31), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ("Frank", 13), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 32), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 31), + ("Alex", 32), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ] + + return spark_session.createDataFrame(data, schema) + + +class TestUnityTableDataSet: + def test_full_table(self): + unity_ds = UnityTableDataSet(catalog="test", database="test", table="test") + assert unity_ds._full_table_address == "test.test.test" + + def test_database_table(self): + unity_ds = UnityTableDataSet(database="test", table="test") + assert unity_ds._full_table_address == "test.test" + + def test_table_only(self): + unity_ds = UnityTableDataSet(table="test") + assert unity_ds._full_table_address == "default.test" + + def test_table_missing(self): + with pytest.raises(TypeError): + UnityTableDataSet() + + def test_describe(self): + unity_ds = UnityTableDataSet(table="test") + assert unity_ds._describe() == { + "catalog": None, + "database": "default", + "table": "test", + "write_mode": "overwrite", + "dataframe_type": "spark", + "primary_key": None, + "version": None, + } + + def test_invalid_write_mode(self): + with pytest.raises(DataSetError): + UnityTableDataSet(table="test", write_mode="invalid") + + def test_dataframe_type(self): + with pytest.raises(DataSetError): + UnityTableDataSet(table="test", dataframe_type="invalid") + + def test_missing_primary_key_upsert(self): + with pytest.raises(DataSetError): + UnityTableDataSet(table="test", write_mode="upsert") + + def test_schema(self): + unity_ds = UnityTableDataSet( + table="test", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + ) + expected_schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + assert unity_ds._schema == expected_schema + + def test_catalog_exists(self): + unity_ds = UnityTableDataSet(catalog="test", database="invalid", table="test_not_there") + assert not unity_ds._exists() + + def test_table_does_not_exist(self): + unity_ds = UnityTableDataSet(database="invalid", table="test_not_there") + assert not unity_ds._exists() + + def test_save_default(self, sample_spark_df: DataFrame): + unity_ds = UnityTableDataSet(database="test", table="test_save") + unity_ds.save(sample_spark_df) + saved_table = unity_ds.load() + assert unity_ds.exists() and sample_spark_df.exceptAll(saved_table).count() == 0 + + def test_save_schema_spark( + self, subset_spark_df: DataFrame, subset_expected_df: DataFrame + ): + unity_ds = UnityTableDataSet( + database="test", + 
table="test_save_spark_schema", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + ) + unity_ds.save(subset_spark_df) + saved_table = unity_ds.load() + assert subset_expected_df.exceptAll(saved_table).count() == 0 + + def test_save_schema_pandas( + self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame + ): + unity_ds = UnityTableDataSet( + database="test", + table="test_save_pd_schema", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + dataframe_type="pandas", + ) + unity_ds.save(subset_pandas_df) + saved_ds = UnityTableDataSet( + database="test", + table="test_save_pd_schema", + ) + saved_table = saved_ds.load() + assert subset_expected_df.exceptAll(saved_table).count() == 0 + + def test_save_overwrite( + self, sample_spark_df: DataFrame, append_spark_df: DataFrame + ): + unity_ds = UnityTableDataSet(database="test", table="test_save") + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + overwritten_table = unity_ds.load() + + assert append_spark_df.exceptAll(overwritten_table).count() == 0 + + def test_save_append( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + expected_append_spark_df: DataFrame, + ): + unity_ds = UnityTableDataSet( + database="test", table="test_save_append", write_mode="append" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + appended_table = unity_ds.load() + + assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + + def test_save_upsert( + self, + sample_spark_df: DataFrame, + upsert_spark_df: DataFrame, + expected_upsert_spark_df: DataFrame, + ): + unity_ds = UnityTableDataSet( + database="test", + table="test_save_upsert", + write_mode="upsert", + primary_key="name", + ) + unity_ds.save(sample_spark_df) + unity_ds.save(upsert_spark_df) + + upserted_table = unity_ds.load() + + assert expected_upsert_spark_df.exceptAll(upserted_table).count() == 0 + + def test_save_upsert_multiple_primary( + self, + sample_spark_df: DataFrame, + upsert_spark_df: DataFrame, + expected_upsert_multiple_primary_spark_df: DataFrame, + ): + unity_ds = UnityTableDataSet( + database="test", + table="test_save_upsert_multiple", + write_mode="upsert", + primary_key=["name", "age"], + ) + unity_ds.save(sample_spark_df) + unity_ds.save(upsert_spark_df) + + upserted_table = unity_ds.load() + + assert ( + expected_upsert_multiple_primary_spark_df.exceptAll(upserted_table).count() + == 0 + ) + + def test_save_upsert_mismatched_columns( + self, + sample_spark_df: DataFrame, + mismatched_upsert_spark_df: DataFrame, + ): + unity_ds = UnityTableDataSet( + database="test", + table="test_save_upsert_mismatch", + write_mode="upsert", + primary_key="name", + ) + unity_ds.save(sample_spark_df) + with pytest.raises(DataSetError): + unity_ds.save(mismatched_upsert_spark_df) + + def test_load_spark(self, sample_spark_df: DataFrame): + unity_ds = UnityTableDataSet(database="test", table="test_load_spark") + unity_ds.save(sample_spark_df) + + delta_ds = UnityTableDataSet(database="test", table="test_load_spark") + delta_table = delta_ds.load() + + assert ( + isinstance(delta_table, DataFrame) + and delta_table.exceptAll(sample_spark_df).count() == 0 + ) + + def 
test_load_spark_no_version(self, sample_spark_df: DataFrame): + unity_ds = UnityTableDataSet(database="test", table="test_load_spark") + unity_ds.save(sample_spark_df) + + delta_ds = UnityTableDataSet( + database="test", table="test_load_spark", version=2 + ) + with pytest.raises(VersionNotFoundError): + _ = delta_ds.load() + + def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame): + unity_ds = UnityTableDataSet( + database="test", table="test_load_version", write_mode="append" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + loaded_ds = UnityTableDataSet( + database="test", table="test_load_version", version=0 + ) + loaded_df = loaded_ds.load() + + assert loaded_df.exceptAll(sample_spark_df).count() == 0 + + def test_load_pandas(self, sample_pandas_df: pd.DataFrame): + unity_ds = UnityTableDataSet( + database="test", table="test_load_pandas", dataframe_type="pandas" + ) + unity_ds.save(sample_pandas_df) + + pandas_ds = UnityTableDataSet( + database="test", table="test_load_pandas", dataframe_type="pandas" + ) + pandas_df = pandas_ds.load().sort_values("name", ignore_index=True) + + assert isinstance(pandas_df, pd.DataFrame) and pandas_df.equals( + sample_pandas_df + ) From 798055ec6a7fc245b69654a905f1df7b1ea740b8 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Mon, 13 Feb 2023 23:29:15 -0500 Subject: [PATCH 02/40] renaming dataset Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/__init__.py | 4 +- .../kedro_datasets/databricks/unity.py | 6 +- .../{test_unity_dataset.py => test_unity.py} | 60 +++++++++---------- 3 files changed, 35 insertions(+), 35 deletions(-) rename kedro-datasets/tests/databricks/{test_unity_dataset.py => test_unity.py} (87%) diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py index 2fd3eccb9..313f3bdba 100644 --- a/kedro-datasets/kedro_datasets/databricks/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -1,8 +1,8 @@ """Provides interface to Unity Catalog Tables.""" -__all__ = ["UnityTableDataSet"] +__all__ = ["ManagedTableDataSet"] from contextlib import suppress with suppress(ImportError): - from .unity import UnityTableDataSet + from .unity import ManagedTableDataSet diff --git a/kedro-datasets/kedro_datasets/databricks/unity.py b/kedro-datasets/kedro_datasets/databricks/unity.py index 8921fca1b..b6270f58c 100644 --- a/kedro-datasets/kedro_datasets/databricks/unity.py +++ b/kedro-datasets/kedro_datasets/databricks/unity.py @@ -15,8 +15,8 @@ logger = logging.getLogger(__name__) -class UnityTableDataSet(AbstractVersionedDataSet): - """``UnityTableDataSet`` loads data into Unity managed tables.""" +class ManagedTableDataSet(AbstractVersionedDataSet): + """``ManagedTableDataSet`` loads data into Unity managed tables.""" # this dataset cannot be used with ``ParallelRunner``, # therefore it has the attribute ``_SINGLE_PROCESS = True`` @@ -41,7 +41,7 @@ def __init__( partition_columns: List[str] = None, # pylint: disable=unused-argument owner_group: str = None, ) -> None: - """Creates a new instance of ``UnityTableDataSet``.""" + """Creates a new instance of ``ManagedTableDataSet``.""" self._database = database self._catalog = catalog diff --git a/kedro-datasets/tests/databricks/test_unity_dataset.py b/kedro-datasets/tests/databricks/test_unity.py similarity index 87% rename from kedro-datasets/tests/databricks/test_unity_dataset.py rename to 
kedro-datasets/tests/databricks/test_unity.py index 3f29a1e95..471f81f57 100644 --- a/kedro-datasets/tests/databricks/test_unity_dataset.py +++ b/kedro-datasets/tests/databricks/test_unity.py @@ -3,7 +3,7 @@ from pyspark.sql.types import IntegerType, StringType, StructField, StructType from pyspark.sql import DataFrame, SparkSession import pandas as pd -from kedro_datasets.databricks import UnityTableDataSet +from kedro_datasets.databricks import ManagedTableDataSet @pytest.fixture @@ -168,25 +168,25 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): return spark_session.createDataFrame(data, schema) -class TestUnityTableDataSet: +class TestManagedTableDataSet: def test_full_table(self): - unity_ds = UnityTableDataSet(catalog="test", database="test", table="test") + unity_ds = ManagedTableDataSet(catalog="test", database="test", table="test") assert unity_ds._full_table_address == "test.test.test" def test_database_table(self): - unity_ds = UnityTableDataSet(database="test", table="test") + unity_ds = ManagedTableDataSet(database="test", table="test") assert unity_ds._full_table_address == "test.test" def test_table_only(self): - unity_ds = UnityTableDataSet(table="test") + unity_ds = ManagedTableDataSet(table="test") assert unity_ds._full_table_address == "default.test" def test_table_missing(self): with pytest.raises(TypeError): - UnityTableDataSet() + ManagedTableDataSet() def test_describe(self): - unity_ds = UnityTableDataSet(table="test") + unity_ds = ManagedTableDataSet(table="test") assert unity_ds._describe() == { "catalog": None, "database": "default", @@ -199,18 +199,18 @@ def test_describe(self): def test_invalid_write_mode(self): with pytest.raises(DataSetError): - UnityTableDataSet(table="test", write_mode="invalid") + ManagedTableDataSet(table="test", write_mode="invalid") def test_dataframe_type(self): with pytest.raises(DataSetError): - UnityTableDataSet(table="test", dataframe_type="invalid") + ManagedTableDataSet(table="test", dataframe_type="invalid") def test_missing_primary_key_upsert(self): with pytest.raises(DataSetError): - UnityTableDataSet(table="test", write_mode="upsert") + ManagedTableDataSet(table="test", write_mode="upsert") def test_schema(self): - unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( table="test", schema={ "fields": [ @@ -239,15 +239,15 @@ def test_schema(self): assert unity_ds._schema == expected_schema def test_catalog_exists(self): - unity_ds = UnityTableDataSet(catalog="test", database="invalid", table="test_not_there") + unity_ds = ManagedTableDataSet(catalog="test", database="invalid", table="test_not_there") assert not unity_ds._exists() def test_table_does_not_exist(self): - unity_ds = UnityTableDataSet(database="invalid", table="test_not_there") + unity_ds = ManagedTableDataSet(database="invalid", table="test_not_there") assert not unity_ds._exists() def test_save_default(self, sample_spark_df: DataFrame): - unity_ds = UnityTableDataSet(database="test", table="test_save") + unity_ds = ManagedTableDataSet(database="test", table="test_save") unity_ds.save(sample_spark_df) saved_table = unity_ds.load() assert unity_ds.exists() and sample_spark_df.exceptAll(saved_table).count() == 0 @@ -255,7 +255,7 @@ def test_save_default(self, sample_spark_df: DataFrame): def test_save_schema_spark( self, subset_spark_df: DataFrame, subset_expected_df: DataFrame ): - unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( database="test", table="test_save_spark_schema", schema={ @@ -283,7 +283,7 @@ def 
test_save_schema_spark( def test_save_schema_pandas( self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame ): - unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( database="test", table="test_save_pd_schema", schema={ @@ -306,7 +306,7 @@ def test_save_schema_pandas( dataframe_type="pandas", ) unity_ds.save(subset_pandas_df) - saved_ds = UnityTableDataSet( + saved_ds = ManagedTableDataSet( database="test", table="test_save_pd_schema", ) @@ -316,7 +316,7 @@ def test_save_schema_pandas( def test_save_overwrite( self, sample_spark_df: DataFrame, append_spark_df: DataFrame ): - unity_ds = UnityTableDataSet(database="test", table="test_save") + unity_ds = ManagedTableDataSet(database="test", table="test_save") unity_ds.save(sample_spark_df) unity_ds.save(append_spark_df) @@ -330,7 +330,7 @@ def test_save_append( append_spark_df: DataFrame, expected_append_spark_df: DataFrame, ): - unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( database="test", table="test_save_append", write_mode="append" ) unity_ds.save(sample_spark_df) @@ -346,7 +346,7 @@ def test_save_upsert( upsert_spark_df: DataFrame, expected_upsert_spark_df: DataFrame, ): - unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( database="test", table="test_save_upsert", write_mode="upsert", @@ -365,7 +365,7 @@ def test_save_upsert_multiple_primary( upsert_spark_df: DataFrame, expected_upsert_multiple_primary_spark_df: DataFrame, ): - unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( database="test", table="test_save_upsert_multiple", write_mode="upsert", @@ -386,7 +386,7 @@ def test_save_upsert_mismatched_columns( sample_spark_df: DataFrame, mismatched_upsert_spark_df: DataFrame, ): - unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( database="test", table="test_save_upsert_mismatch", write_mode="upsert", @@ -397,10 +397,10 @@ def test_save_upsert_mismatched_columns( unity_ds.save(mismatched_upsert_spark_df) def test_load_spark(self, sample_spark_df: DataFrame): - unity_ds = UnityTableDataSet(database="test", table="test_load_spark") + unity_ds = ManagedTableDataSet(database="test", table="test_load_spark") unity_ds.save(sample_spark_df) - delta_ds = UnityTableDataSet(database="test", table="test_load_spark") + delta_ds = ManagedTableDataSet(database="test", table="test_load_spark") delta_table = delta_ds.load() assert ( @@ -409,23 +409,23 @@ def test_load_spark(self, sample_spark_df: DataFrame): ) def test_load_spark_no_version(self, sample_spark_df: DataFrame): - unity_ds = UnityTableDataSet(database="test", table="test_load_spark") + unity_ds = ManagedTableDataSet(database="test", table="test_load_spark") unity_ds.save(sample_spark_df) - delta_ds = UnityTableDataSet( + delta_ds = ManagedTableDataSet( database="test", table="test_load_spark", version=2 ) with pytest.raises(VersionNotFoundError): _ = delta_ds.load() def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame): - unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( database="test", table="test_load_version", write_mode="append" ) unity_ds.save(sample_spark_df) unity_ds.save(append_spark_df) - loaded_ds = UnityTableDataSet( + loaded_ds = ManagedTableDataSet( database="test", table="test_load_version", version=0 ) loaded_df = loaded_ds.load() @@ -433,12 +433,12 @@ def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFra assert loaded_df.exceptAll(sample_spark_df).count() == 0 def test_load_pandas(self, sample_pandas_df: pd.DataFrame): - 
unity_ds = UnityTableDataSet( + unity_ds = ManagedTableDataSet( database="test", table="test_load_pandas", dataframe_type="pandas" ) unity_ds.save(sample_pandas_df) - pandas_ds = UnityTableDataSet( + pandas_ds = ManagedTableDataSet( database="test", table="test_load_pandas", dataframe_type="pandas" ) pandas_df = pandas_ds.load().sort_values("name", ignore_index=True) From f2ea2558e1202bd8ecd2a467df7ef77aede3cc4d Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Thu, 23 Feb 2023 13:49:30 -0500 Subject: [PATCH 03/40] adding mlflow connectors Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/__init__.py | 3 +- .../databricks/mlflow/__init__.py | 6 + .../databricks/mlflow/artifact.py | 133 +++++++++++++++ .../databricks/mlflow/common.py | 89 ++++++++++ .../databricks/mlflow/dataset.py | 80 +++++++++ .../databricks/mlflow/flavors/__init__.py | 0 .../mlflow/flavors/kedro_dataset_flavor.py | 154 ++++++++++++++++++ .../databricks/mlflow/metrics.py | 93 +++++++++++ .../kedro_datasets/databricks/mlflow/model.py | 75 +++++++++ .../databricks/mlflow/model_metadata.py | 49 ++++++ .../kedro_datasets/databricks/mlflow/tags.py | 94 +++++++++++ .../databricks/unity/__init__.py | 1 + .../managed_table_dataset.py} | 1 + kedro-datasets/setup.py | 2 +- 14 files changed, 778 insertions(+), 2 deletions(-) create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/common.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/flavors/__init__.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/model.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py create mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/tags.py create mode 100644 kedro-datasets/kedro_datasets/databricks/unity/__init__.py rename kedro-datasets/kedro_datasets/databricks/{unity.py => unity/managed_table_dataset.py} (99%) diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py index 313f3bdba..ec9d4b45d 100644 --- a/kedro-datasets/kedro_datasets/databricks/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -1,8 +1,9 @@ """Provides interface to Unity Catalog Tables.""" -__all__ = ["ManagedTableDataSet"] +__all__ = ["ManagedTableDataSet", "MLFlowModel", "MLFlowArtifact", "MLFlowDataSet", "MLFlowMetrics", "MLFlowModelMetadata", "MLFlowTags"] from contextlib import suppress with suppress(ImportError): from .unity import ManagedTableDataSet + from .mlflow import MLFlowModel, MLFlowArtifact, MLFlowDataSet, MLFlowMetrics, MLFlowModelMetadata, MLFlowTags diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py b/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py new file mode 100644 index 000000000..f4cc1567a --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py @@ -0,0 +1,6 @@ +from .artifact import MLFlowArtifact +from .dataset import MLFlowDataSet +from .metrics import MLFlowMetrics +from .model_metadata import MLFlowModelMetadata +from .tags import MLFlowTags +from .model import 
MLFlowModel \ No newline at end of file diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py b/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py new file mode 100644 index 000000000..15691db43 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py @@ -0,0 +1,133 @@ +import logging +import os +from pathlib import Path +from tempfile import mkdtemp +from typing import Any, Dict + +import mlflow +from kedro.io.core import AbstractDataSet +from kedro.utils import load_obj as load_dataset +from mlflow.exceptions import MlflowException +from mlflow.tracking.artifact_utils import _download_artifact_from_uri + +from .common import MLFLOW_RUN_ID_ENV_VAR, ModelOpsException + +logger = logging.getLogger(__name__) + + +class MLFlowArtifact(AbstractDataSet): + def __init__( + self, + dataset_name: str, + dataset_type: str, + dataset_args: Dict[str, Any] = None, + *, + file_suffix: str, + run_id: str = None, + registered_model_name: str = None, + registered_model_version: str = None, + ): + """ + Log arbitrary Kedro datasets as mlflow artifacts + + Args: + dataset_name: dataset name as it should appear on mlflow run + dataset_type: full kedro dataset class name (incl. module) + dataset_args: kedro dataset args + file_suffix: file extension as it should appear on mlflow run + run_id: mlflow run-id, this should only be used when loading a + dataset saved from run which is different from active run + registered_model_name: mlflow registered model name, this should + only be used when loading an artifact linked to a model of + interest (i.e. back tracing atifacts from the run corresponding + to the model) + registered_model_version: mlflow registered model name, should be + used in combination with `registered_model_name` + + `run_id` and `registered_model_name` can't be specified together. + """ + if None in (registered_model_name, registered_model_version): + if registered_model_name or registered_model_version: + raise ModelOpsException( + "'registered_model_name' and " + "'registered_model_version' should be " + "set together" + ) + + if run_id and registered_model_name: + raise ModelOpsException( + "'run_id' cannot be passed when " "'registered_model_name' is set" + ) + + self._dataset_name = dataset_name + self._dataset_type = dataset_type + self._dataset_args = dataset_args or {} + self._file_suffix = file_suffix + self._run_id = run_id or os.environ.get(MLFLOW_RUN_ID_ENV_VAR) + self._registered_model_name = registered_model_name + self._registered_model_version = registered_model_version + + self._artifact_path = f"{dataset_name}{self._file_suffix}" + + self._filepath = Path(mkdtemp()) / self._artifact_path + + if registered_model_name: + self._version = f"{registered_model_name}/{registered_model_version}" + else: + self._version = run_id + + def _save(self, data: Any) -> None: + cls = load_dataset(self._dataset_type) + ds = cls(filepath=self._filepath.as_posix(), **self._dataset_args) + ds.save(data) + + filepath = self._filepath.as_posix() + if os.path.isdir(filepath): + mlflow.log_artifacts(self._filepath.as_posix(), self._artifact_path) + elif os.path.isfile(filepath): + mlflow.log_artifact(self._filepath.as_posix()) + else: + raise RuntimeError("cls.save() didn't work. 
Unexpected error.") + + run_id = mlflow.active_run().info.run_id + if self._version is not None: + logger.warning( + f"Ignoring version {self._version} set " + f"earlier, will use version='{run_id}' for loading" + ) + self._version = run_id + + def _load(self) -> Any: + if self._version is None: + msg = ( + "Could not determine the version to load. " + "Please specify either 'run_id' or 'registered_model_name' " + "along with 'registered_model_version' explicitly in " + "MLFlowArtifact constructor" + ) + raise MlflowException(msg) + + if "/" in self._version: + model_uri = f"models:/{self._version}" + model = mlflow.pyfunc.load_model(model_uri) + run_id = model._model_meta.run_id + else: + run_id = self._version + + local_path = _download_artifact_from_uri( + f"runs:/{run_id}/{self._artifact_path}" + ) + + cls = load_dataset(self._dataset_type) + ds = cls(filepath=local_path, **self._dataset_args) + return ds.load() + + def _describe(self) -> Dict[str, Any]: + return dict( + dataset_name=self._dataset_name, + dataset_type=self._dataset_type, + dataset_args=self._dataset_args, + file_suffix=self._file_suffix, + registered_model_name=self._registered_model_name, + registered_model_version=self._registered_model_version, + ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/common.py b/kedro-datasets/kedro_datasets/databricks/mlflow/common.py new file mode 100644 index 000000000..af102d6b3 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/common.py @@ -0,0 +1,89 @@ +import mlflow +from mlflow.tracking import MlflowClient + +MLFLOW_RUN_ID_ENV_VAR = "mlflow_run_id" + + +def parse_model_uri(model_uri): + parts = model_uri.split("/") + + if len(parts) < 2 or len(parts) > 3: + raise ValueError( + f"model uri should have the format " + f"'models:/' or " + f"'models://', got {model_uri}" + ) + + if parts[0] == "models:": + protocol = "models" + else: + raise ValueError("model uri should start with `models:/`, got %s", model_uri) + + name = parts[1] + + client = MlflowClient() + if len(parts) == 2: + results = client.search_model_versions(f"name='{name}'") + sorted_results = sorted( + results, + key=lambda modelversion: modelversion.creation_timestamp, + reverse=True, + ) + latest_version = sorted_results[0].version + version = latest_version + else: + version = parts[2] + if version in ["Production", "Staging", "Archived"]: + results = client.get_latest_versions(name, stages=[version]) + if len(results) > 0: + version = results[0].version + else: + version = None + + return protocol, name, version + + +def promote_model(model_name, model_version, stage): + import datetime + + now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + client = MlflowClient() + + new_model_uri = f"models:/{model_name}/{model_version}" + _, _, new_model_version = parse_model_uri(new_model_uri) + new_model = mlflow.pyfunc.load_model(new_model_uri) + new_model_runid = new_model._model_meta.run_id + + msg = f"```Promoted version {model_version} to {stage}, at {now}```" + client.set_tag(new_model_runid, "mlflow.note.content", msg) + client.set_tag(new_model_runid, "Promoted at", now) + + results = client.get_latest_versions(model_name, stages=[stage]) + if len(results) > 0: + old_model_uri = f"models:/{model_name}/{stage}" + _, _, old_model_version = parse_model_uri(old_model_uri) + old_model = mlflow.pyfunc.load_model(old_model_uri) + old_model_runid = old_model._model_meta.run_id + + client.set_tag( + old_model._model_meta.run_id, + "mlflow.note.content", + f"```Replaced by version 
{new_model_version}, at {now}```", + ) + client.set_tag(old_model_runid, "Retired at", now) + client.set_tag(old_model_runid, "Replaced by", new_model_version) + + client.set_tag(new_model_runid, "Replaces", old_model_version) + + client.transition_model_version_stage( + name=model_name, version=old_model_version, stage="Archived" + ) + + client.transition_model_version_stage( + name=model_name, version=new_model_version, stage=stage + ) + + +class ModelOpsException(Exception): + pass diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py b/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py new file mode 100644 index 000000000..ee0a1e0ed --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py @@ -0,0 +1,80 @@ +import importlib +import logging +from typing import Any, Dict + +from kedro.io.core import AbstractDataSet + +from .common import ModelOpsException, parse_model_uri + +logger = logging.getLogger(__name__) + + +class MLFlowDataSet(AbstractDataSet): + def __init__( + self, + flavor: str, + dataset_name: str = None, + dataset_type: str = None, + dataset_args: Dict[str, Any] = None, + *, + file_suffix: str = None, + load_version: str = None, + ): + self._flavor = flavor + self._dataset_name = dataset_name + self._dataset_type = dataset_type + self._dataset_args = dataset_args or {} + self._file_suffix = file_suffix + self._load_version = load_version + + def _save(self, model: Any) -> None: + if self._load_version is not None: + msg = ( + f"Trying to save an MLFlowDataSet::{self._describe} which " + f"was initialized with load_version={self._load_version}. " + f"This can lead to inconsistency between saved and loaded " + f"versions, therefore disallowed. Please create separate " + f"catalog entries for saved and loaded datasets." 
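            # Illustrative (not part of this patch): a typical construction of this
            # dataset points `flavor` at the bundled kedro_dataset_flavor module and
            # describes the wrapped Kedro dataset; the names below are hypothetical.
            #
            #   MLFlowDataSet(
            #       flavor="kedro_datasets.databricks.mlflow.flavors.kedro_dataset_flavor",
            #       dataset_name="model_input_table",
            #       dataset_type="kedro_datasets.pandas.ParquetDataSet",
            #       dataset_args={},
            #       file_suffix="parquet",
            #   )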
+ ) + raise ModelOpsException(msg) + + importlib.import_module(self._flavor).log_model( + model, + self._dataset_name, + registered_model_name=self._dataset_name, + dataset_type=self._dataset_type, + dataset_args=self._dataset_args, + file_suffix=self._file_suffix, + ) + + def _load(self) -> Any: + *_, latest_version = parse_model_uri(f"models:/{self._dataset_name}") + + dataset_version = self._load_version or latest_version + *_, dataset_version = parse_model_uri( + f"models:/{self._dataset_name}/{dataset_version}" + ) + + logger.info(f"Loading model '{self._dataset_name}' version '{dataset_version}'") + + if dataset_version != latest_version: + logger.warning(f"Newer version {latest_version} exists in repo") + + model = importlib.import_module(self._flavor).load_model( + f"models:/{self._dataset_name}/{dataset_version}", + dataset_type=self._dataset_type, + dataset_args=self._dataset_args, + file_suffix=self._file_suffix, + ) + + return model + + def _describe(self) -> Dict[str, Any]: + return dict( + flavor=self._flavor, + dataset_name=self._dataset_name, + dataset_type=self._dataset_type, + dataset_args=self._dataset_args, + file_suffix=self._file_suffix, + load_version=self._load_version, + ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/__init__.py b/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py b/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py new file mode 100644 index 000000000..e0a43a1b0 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py @@ -0,0 +1,154 @@ +import os +import sys +from pathlib import Path +from typing import Any, Dict, Union + +import kedro +import yaml +from kedro.utils import load_obj as load_dataset +from mlflow import pyfunc +from mlflow.exceptions import MlflowException +from mlflow.models import Model +from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS +from mlflow.tracking.artifact_utils import _download_artifact_from_uri +from mlflow.utils.environment import _mlflow_conda_env +from mlflow.utils.model_utils import _get_flavor_configuration + +FLAVOR_NAME = "kedro_dataset" + + +DEFAULT_CONDA_ENV = _mlflow_conda_env( + additional_conda_deps=["kedro[all]={}".format(kedro.__version__)], + additional_pip_deps=None, + additional_conda_channels=None, +) + + +def save_model( + data: Any, + path: str, + conda_env: Union[str, Dict[str, Any]] = None, + mlflow_model: Model = Model(), + *, + dataset_type: str, + dataset_args: Dict[str, Any], + file_suffix: str, +): + if os.path.exists(path): + raise RuntimeError("Path '{}' already exists".format(path)) + os.makedirs(path) + + model_data_subpath = f"data.{file_suffix}" + model_data_path = os.path.join(path, model_data_subpath) + + cls = load_dataset(dataset_type) + ds = cls(filepath=model_data_path, **dataset_args) + ds.save(data) + + conda_env_subpath = "conda.yaml" + if conda_env is None: + conda_env = DEFAULT_CONDA_ENV + elif not isinstance(conda_env, dict): + with open(conda_env, "r") as f: + conda_env = yaml.safe_load(f) + with open(os.path.join(path, conda_env_subpath), "w") as f: + yaml.safe_dump(conda_env, stream=f, default_flow_style=False) + + pyfunc.add_to_model( + mlflow_model, + loader_module=__name__, + data=model_data_subpath, + env=conda_env_subpath, + ) + + mlflow_model.add_flavor( + FLAVOR_NAME, + 
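        # For orientation (not part of this patch): `add_flavor` records these
        # keyword arguments under the flavor name in the saved MLmodel file,
        # roughly:
        #
        #   flavors:
        #     kedro_dataset:
        #       data: data.<file_suffix>
        #       dataset_type: <full Kedro dataset class path>
        #       dataset_args: {}
        #       file_suffix: <file_suffix>
        #
        # _load_model_from_local_file reads this configuration back through
        # _get_flavor_configuration when no dataset_type is passed explicitly.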
data=model_data_subpath, + dataset_type=dataset_type, + dataset_args=dataset_args, + file_suffix=file_suffix, + ) + mlflow_model.save(os.path.join(path, "MLmodel")) + + +def log_model( + model: Any, + artifact_path: str, + conda_env: Dict[str, Any] = None, + registered_model_name: str = None, + await_registration_for: int = DEFAULT_AWAIT_MAX_SLEEP_SECONDS, + *, + dataset_type: str, + dataset_args: Dict[str, Any], + file_suffix: str, +): + return Model.log( + artifact_path=artifact_path, + flavor=sys.modules[__name__], + registered_model_name=registered_model_name, + await_registration_for=await_registration_for, + data=model, + conda_env=conda_env, + dataset_type=dataset_type, + dataset_args=dataset_args, + file_suffix=file_suffix, + ) + + +def _load_model_from_local_file( + local_path: str, + *, + dataset_type: str = None, + dataset_args: Dict[str, Any] = None, + file_suffix: str = None, +): + if dataset_type is not None: + model_data_subpath = f"data.{file_suffix}" + data_path = os.path.join(local_path, model_data_subpath) + else: + flavor_conf = _get_flavor_configuration( + model_path=local_path, flavor_name=FLAVOR_NAME + ) + data_path = os.path.join(local_path, flavor_conf["data"]) + dataset_type = flavor_conf["dataset_type"] + dataset_args = flavor_conf["dataset_args"] + + cls = load_dataset(dataset_type) + ds = cls(filepath=data_path, **dataset_args) + return ds.load() + + +def load_model( + model_uri: str, + *, + dataset_type: str = None, + dataset_args: Dict[str, Any] = None, + file_suffix: str = None, +): + if dataset_type is not None or dataset_args is not None or file_suffix is not None: + assert ( + dataset_type is not None + and dataset_args is not None + and file_suffix is not None + ), ("Please set 'dataset_type', " "'dataset_args' and 'file_suffix'") + + local_path = _download_artifact_from_uri(model_uri) + return _load_model_from_local_file( + local_path, + dataset_type=dataset_type, + dataset_args=dataset_args, + file_suffix=file_suffix, + ) + + +def _load_pyfunc(model_file: str): + local_path = Path(model_file).parent.absolute() + model = _load_model_from_local_file(local_path) + if not hasattr(model, "predict"): + try: + setattr(model, "predict", None) + except AttributeError: + raise MlflowException( + f"`pyfunc` flavor not supported, use " f"{__name__}.load instead" + ) + return model diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py b/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py new file mode 100644 index 000000000..1c7760375 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py @@ -0,0 +1,93 @@ +import logging +from typing import Any, Dict, Union + +import mlflow +from kedro.io.core import AbstractDataSet +from mlflow.exceptions import MlflowException +from mlflow.tracking import MlflowClient + +from .common import ModelOpsException + +logger = logging.getLogger(__name__) + + +class MLFlowMetrics(AbstractDataSet): + def __init__( + self, + prefix: str = None, + run_id: str = None, + registered_model_name: str = None, + registered_model_version: str = None, + ): + if None in (registered_model_name, registered_model_version): + if registered_model_name or registered_model_version: + raise ModelOpsException( + "'registered_model_name' and " + "'registered_model_version' should be " + "set together" + ) + + if run_id and registered_model_name: + raise ModelOpsException( + "'run_id' cannot be passed when " "'registered_model_name' is set" + ) + + self._prefix = prefix + self._run_id = run_id + 
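        # Illustrative behaviour (not part of this patch): with prefix="test",
        # saving {"f1": 0.9, "accuracy": 0.8} logs the MLflow metrics
        # "test_f1" and "test_accuracy"; _load() strips the prefix again and
        # returns {"f1": 0.9, "accuracy": 0.8} for the resolved run.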
self._registered_model_name = registered_model_name + self._registered_model_version = registered_model_version + + if registered_model_name: + self._version = f"{registered_model_name}/{registered_model_version}" + else: + self._version = run_id + + def _save(self, metrics: Dict[str, Union[str, float, int]]) -> None: + if self._prefix is not None: + metrics = {f"{self._prefix}_{key}": value for key, value in metrics.items()} + mlflow.log_metrics(metrics) + + run_id = mlflow.active_run().info.run_id + if self._version is not None: + logger.warning( + f"Ignoring version {self._version.save} set " + f"earlier, will use version='{run_id}' for loading" + ) + self._version = run_id + + def _load(self) -> Any: + if self._version is None: + msg = ( + "Could not determine the version to load. " + "Please specify either 'run_id' or 'registered_model_name' " + "along with 'registered_model_version' explicitly in " + "MLFlowMetrics constructor" + ) + raise MlflowException(msg) + + client = MlflowClient() + + if "/" in self._version: + model_uri = f"models:/{self._version}" + model = mlflow.pyfunc.load_model(model_uri) + run_id = model._model_meta.run_id + else: + run_id = self._version + + run = client.get_run(run_id) + metrics = run.data.metrics + if self._prefix is not None: + metrics = { + key[len(self._prefix) + 1 :]: value + for key, value in metrics.items() + if key[: len(self._prefix)] == self._prefix + } + return metrics + + def _describe(self) -> Dict[str, Any]: + return dict( + prefix=self._prefix, + run_id=self._run_id, + registered_model_name=self._registered_model_name, + registered_model_version=self._registered_model_version, + ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/model.py b/kedro-datasets/kedro_datasets/databricks/mlflow/model.py new file mode 100644 index 000000000..c5f2356a2 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/model.py @@ -0,0 +1,75 @@ +import importlib +import logging +from typing import Any, Dict + +from kedro.io.core import AbstractDataSet +from mlflow.models.signature import ModelSignature + +from .common import ModelOpsException, parse_model_uri + +logger = logging.getLogger(__name__) + + +class MLFlowModel(AbstractDataSet): + def __init__( + self, + flavor: str, + model_name: str, + signature: Dict[str, Dict[str, str]] = None, + input_example: Dict[str, Any] = None, + load_version: str = None, + ): + self._flavor = flavor + self._model_name = model_name + + if signature: + self._signature = ModelSignature.from_dict(signature) + else: + self._signature = None + self._input_example = input_example + + self._load_version = load_version + + def _save(self, model: Any) -> None: + if self._load_version is not None: + msg = ( + f"Trying to save an MLFlowModel::{self._describe} which " + f"was initialized with load_version={self._load_version}. " + f"This can lead to inconsistency between saved and loaded " + f"versions, therefore disallowed. Please create separate " + f"catalog entries for saved and loaded datasets." 
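            # Illustrative (not part of this patch): `flavor` names an importable
            # MLflow flavor module, e.g. a hypothetical entry backed by
            # mlflow.sklearn:
            #
            #   MLFlowModel(flavor="mlflow.sklearn", model_name="price_regressor")
            #
            # whose _save() then calls mlflow.sklearn.log_model(model,
            # "price_regressor", registered_model_name="price_regressor", ...).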
+ ) + raise ModelOpsException(msg) + + importlib.import_module(self._flavor).log_model( + model, + self._model_name, + registered_model_name=self._model_name, + signature=self._signature, + input_example=self._input_example, + ) + + def _load(self) -> Any: + *_, latest_version = parse_model_uri(f"models:/{self._model_name}") + + model_version = self._load_version or latest_version + + logger.info(f"Loading model '{self._model_name}' version '{model_version}'") + + if model_version != latest_version: + logger.warning(f"Newer version {latest_version} exists in repo") + + model = importlib.import_module(self._flavor).load_model( + f"models:/{self._model_name}/{model_version}" + ) + + return model + + def _describe(self) -> Dict[str, Any]: + return dict( + flavor=self._flavor, + model_name=self._model_name, + signature=self._signature, + input_example=self._input_example, + load_version=self._load_version, + ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py b/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py new file mode 100644 index 000000000..3c160cec4 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py @@ -0,0 +1,49 @@ +import logging +from typing import Any, Dict, Union + +import mlflow +from kedro.io.core import AbstractDataSet + +from .common import ModelOpsException, parse_model_uri + +logger = logging.getLogger(__name__) + + +class MLFlowModelMetadata(AbstractDataSet): + def __init__( + self, registered_model_name: str, registered_model_version: str = None + ): + self._model_name = registered_model_name + self._model_version = registered_model_version + + def _save(self, tags: Dict[str, Union[str, float, int]]) -> None: + raise NotImplementedError() + + def _load(self) -> Any: + if self._model_version is None: + model_uri = f"models:/{self._model_name}" + else: + model_uri = f"models:/{self._model_name}/{self._model_version}" + _, _, load_version = parse_model_uri(model_uri) + + if load_version is None: + raise ModelOpsException( + f"No model with version " f"'{self._model_version}'" + ) + + pyfunc_model = mlflow.pyfunc.load_model( + f"models:/{self._model_name}/{load_version}" + ) + all_metadata = pyfunc_model._model_meta + model_metadata = { + "model_name": self._model_name, + "model_version": int(load_version), + "run_id": all_metadata.run_id, + } + return model_metadata + + def _describe(self) -> Dict[str, Any]: + return dict( + registered_model_name=self._model_name, + registered_model_version=self._model_version, + ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py b/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py new file mode 100644 index 000000000..153810ae4 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py @@ -0,0 +1,94 @@ +import logging +from typing import Any, Dict, Union + +import mlflow +from kedro.io.core import AbstractDataSet +from mlflow.exceptions import MlflowException +from mlflow.tracking import MlflowClient + +from .common import ModelOpsException + +logger = logging.getLogger(__name__) + + +class MLFlowTags(AbstractDataSet): + def __init__( + self, + prefix: str = None, + run_id: str = None, + registered_model_name: str = None, + registered_model_version: str = None, + ): + if None in (registered_model_name, registered_model_version): + if registered_model_name or registered_model_version: + raise ModelOpsException( + "'registered_model_name' and " + "'registered_model_version' should be " + "set together" + ) + + if run_id and 
registered_model_name: + raise ModelOpsException( + "'run_id' cannot be passed when " "'registered_model_name' is set" + ) + + self._prefix = prefix + self._run_id = run_id + self._registered_model_name = registered_model_name + self._registered_model_version = registered_model_version + + if registered_model_name: + self._version = f"{registered_model_name}/{registered_model_version}" + else: + self._version = run_id + + def _save(self, tags: Dict[str, Union[str, float, int]]) -> None: + if self._prefix is not None: + tags = {f"{self._prefix}_{key}": value for key, value in tags.items()} + + mlflow.set_tags(tags) + + run_id = mlflow.active_run().info.run_id + if self._version is not None: + logger.warning( + f"Ignoring version {self._version.save} set " + f"earlier, will use version='{run_id}' for loading" + ) + self._version = run_id + + def _load(self) -> Any: + if self._version is None: + msg = ( + "Could not determine the version to load. " + "Please specify either 'run_id' or 'registered_model_name' " + "along with 'registered_model_version' explicitly in " + "MLFlowTags constructor" + ) + raise MlflowException(msg) + + client = MlflowClient() + + if "/" in self._version: + model_uri = f"models:/{self._version}" + model = mlflow.pyfunc.load_model(model_uri) + run_id = model._model_meta.run_id + else: + run_id = self._version + + run = client.get_run(run_id) + tags = run.data.tags + if self._prefix is not None: + tags = { + key[len(self._prefix) + 1 :]: value + for key, value in tags.items() + if key[: len(self._prefix)] == self._prefix + } + return tags + + def _describe(self) -> Dict[str, Any]: + return dict( + prefix=self._prefix, + run_id=self._run_id, + registered_model_name=self._registered_model_name, + registered_model_version=self._registered_model_version, + ) diff --git a/kedro-datasets/kedro_datasets/databricks/unity/__init__.py b/kedro-datasets/kedro_datasets/databricks/unity/__init__.py new file mode 100644 index 000000000..ab452e146 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/unity/__init__.py @@ -0,0 +1 @@ +from .managed_table_dataset import ManagedTableDataSet \ No newline at end of file diff --git a/kedro-datasets/kedro_datasets/databricks/unity.py b/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py similarity index 99% rename from kedro-datasets/kedro_datasets/databricks/unity.py rename to kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py index b6270f58c..b46122197 100644 --- a/kedro-datasets/kedro_datasets/databricks/unity.py +++ b/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py @@ -182,6 +182,7 @@ def _describe(self) -> Dict[str, str]: dataframe_type=self._dataframe_type, primary_key=self._primary_key, version=self._version, + owner_group=self._owner_group, ) def _exists(self) -> bool: diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index 635127e20..16a8336dc 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -17,7 +17,7 @@ def _collect_requirements(requires): api_require = {"api.APIDataSet": ["requests~=2.20"]} biosequence_require = {"biosequence.BioSequenceDataSet": ["biopython~=1.73"]} dask_require = {"dask.ParquetDataSet": ["dask[complete]~=2021.10", "triad>=0.6.7, <1.0"]} -databricks_require = {"databricks.UnityTableDataSet": [SPARK]} +databricks_require = {"databricks.ManagedTableDataSet": [SPARK]} geopandas_require = { "geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] } From 9bb88c2fb85b960ac4deabe94c1d1661ae859544 Mon 
Sep 17 00:00:00 2001 From: Danny Farah Date: Thu, 23 Feb 2023 17:21:24 -0500 Subject: [PATCH 04/40] fixing mlflow imports Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- kedro-datasets/kedro_datasets/databricks/__init__.py | 9 ++------- .../kedro_datasets/databricks/mlflow/__init__.py | 2 +- kedro-datasets/setup.py | 10 +++++++++- kedro-datasets/test_requirements.txt | 1 + 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py index ec9d4b45d..cba69d17c 100644 --- a/kedro-datasets/kedro_datasets/databricks/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -1,9 +1,4 @@ """Provides interface to Unity Catalog Tables.""" -__all__ = ["ManagedTableDataSet", "MLFlowModel", "MLFlowArtifact", "MLFlowDataSet", "MLFlowMetrics", "MLFlowModelMetadata", "MLFlowTags"] - -from contextlib import suppress - -with suppress(ImportError): - from .unity import ManagedTableDataSet - from .mlflow import MLFlowModel, MLFlowArtifact, MLFlowDataSet, MLFlowMetrics, MLFlowModelMetadata, MLFlowTags +from .unity import ManagedTableDataSet +from .mlflow import MLFlowModel, MLFlowArtifact, MLFlowDataSet, MLFlowMetrics, MLFlowModelMetadata, MLFlowTags diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py b/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py index f4cc1567a..1c3babc0f 100644 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py @@ -3,4 +3,4 @@ from .metrics import MLFlowMetrics from .model_metadata import MLFlowModelMetadata from .tags import MLFlowTags -from .model import MLFlowModel \ No newline at end of file +from .model import MLFlowModel diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index 16a8336dc..ea64d9314 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -17,7 +17,15 @@ def _collect_requirements(requires): api_require = {"api.APIDataSet": ["requests~=2.20"]} biosequence_require = {"biosequence.BioSequenceDataSet": ["biopython~=1.73"]} dask_require = {"dask.ParquetDataSet": ["dask[complete]~=2021.10", "triad>=0.6.7, <1.0"]} -databricks_require = {"databricks.ManagedTableDataSet": [SPARK]} +databricks_require = { + "databricks.ManagedTableDataSet": [SPARK, PANDAS], + "databricks.MLFlowModel":[SPARK, PANDAS, "mlflow>=2.0.0"], + "databricks.MLFlowArtifact":[SPARK, PANDAS, "mlflow>=2.0.0"], + "databricks.MLFlowDataSet":[SPARK, PANDAS, "mlflow>=2.0.0"], + "databricks.MLFlowMetrics":[SPARK, PANDAS, "mlflow>=2.0.0"], + "databricks.MLFlowModelMetadata":[SPARK, PANDAS, "mlflow>=2.0.0"], + "databricks.MLFlowTags":[SPARK, PANDAS, "mlflow>=2.0.0"] +} geopandas_require = { "geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] } diff --git a/kedro-datasets/test_requirements.txt b/kedro-datasets/test_requirements.txt index 4d4954739..90faa0b02 100644 --- a/kedro-datasets/test_requirements.txt +++ b/kedro-datasets/test_requirements.txt @@ -24,6 +24,7 @@ lxml~=4.6 matplotlib>=3.0.3, <3.4; python_version < '3.10' # 3.4.0 breaks holoviews matplotlib>=3.5, <3.6; python_version == '3.10' memory_profiler>=0.50.0, <1.0 +mlflow==2.2.1 moto==1.3.7; python_version < '3.10' moto==3.0.4; python_version == '3.10' networkx~=2.4 From 20d20b57381d673cad63d494a9899436af1adf3e Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 8 Mar 2023 14:26:57 -0500 Subject: [PATCH 05/40] cleaned up mlflow for initial release 
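For reference, a minimal sketch of loading a pinned table version after this
change, which switches the dataset from a bare integer to kedro's ``Version``
object (the ``test.features`` table name is purely illustrative):

    >>> from kedro.io.core import Version
    >>> from kedro_datasets.databricks import ManagedTableDataSet
    >>> ds = ManagedTableDataSet(database="test", table="features",
    ...                          version=Version(0, None))
    >>> reloaded = ds.load()  # reads the Delta table with versionAsOf=0
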
Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/__init__.py | 1 - .../databricks/mlflow/__init__.py | 6 - .../databricks/mlflow/artifact.py | 133 --------------- .../databricks/mlflow/common.py | 89 ---------- .../databricks/mlflow/dataset.py | 80 --------- .../databricks/mlflow/flavors/__init__.py | 0 .../mlflow/flavors/kedro_dataset_flavor.py | 154 ------------------ .../databricks/mlflow/metrics.py | 93 ----------- .../kedro_datasets/databricks/mlflow/model.py | 75 --------- .../databricks/mlflow/model_metadata.py | 49 ------ .../kedro_datasets/databricks/mlflow/tags.py | 94 ----------- .../databricks/unity/managed_table_dataset.py | 29 ++-- kedro-datasets/tests/databricks/conftest.py | 1 - kedro-datasets/tests/databricks/test_unity.py | 7 +- 14 files changed, 16 insertions(+), 795 deletions(-) delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/common.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/flavors/__init__.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/model.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/mlflow/tags.py diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py index cba69d17c..7819a2e06 100644 --- a/kedro-datasets/kedro_datasets/databricks/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -1,4 +1,3 @@ """Provides interface to Unity Catalog Tables.""" from .unity import ManagedTableDataSet -from .mlflow import MLFlowModel, MLFlowArtifact, MLFlowDataSet, MLFlowMetrics, MLFlowModelMetadata, MLFlowTags diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py b/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py deleted file mode 100644 index 1c3babc0f..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .artifact import MLFlowArtifact -from .dataset import MLFlowDataSet -from .metrics import MLFlowMetrics -from .model_metadata import MLFlowModelMetadata -from .tags import MLFlowTags -from .model import MLFlowModel diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py b/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py deleted file mode 100644 index 15691db43..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py +++ /dev/null @@ -1,133 +0,0 @@ -import logging -import os -from pathlib import Path -from tempfile import mkdtemp -from typing import Any, Dict - -import mlflow -from kedro.io.core import AbstractDataSet -from kedro.utils import load_obj as load_dataset -from mlflow.exceptions import MlflowException -from mlflow.tracking.artifact_utils import _download_artifact_from_uri - -from .common import MLFLOW_RUN_ID_ENV_VAR, ModelOpsException - -logger = logging.getLogger(__name__) - - -class MLFlowArtifact(AbstractDataSet): - def __init__( - self, - dataset_name: str, - dataset_type: str, - dataset_args: Dict[str, Any] = None, - *, - 
file_suffix: str, - run_id: str = None, - registered_model_name: str = None, - registered_model_version: str = None, - ): - """ - Log arbitrary Kedro datasets as mlflow artifacts - - Args: - dataset_name: dataset name as it should appear on mlflow run - dataset_type: full kedro dataset class name (incl. module) - dataset_args: kedro dataset args - file_suffix: file extension as it should appear on mlflow run - run_id: mlflow run-id, this should only be used when loading a - dataset saved from run which is different from active run - registered_model_name: mlflow registered model name, this should - only be used when loading an artifact linked to a model of - interest (i.e. back tracing atifacts from the run corresponding - to the model) - registered_model_version: mlflow registered model name, should be - used in combination with `registered_model_name` - - `run_id` and `registered_model_name` can't be specified together. - """ - if None in (registered_model_name, registered_model_version): - if registered_model_name or registered_model_version: - raise ModelOpsException( - "'registered_model_name' and " - "'registered_model_version' should be " - "set together" - ) - - if run_id and registered_model_name: - raise ModelOpsException( - "'run_id' cannot be passed when " "'registered_model_name' is set" - ) - - self._dataset_name = dataset_name - self._dataset_type = dataset_type - self._dataset_args = dataset_args or {} - self._file_suffix = file_suffix - self._run_id = run_id or os.environ.get(MLFLOW_RUN_ID_ENV_VAR) - self._registered_model_name = registered_model_name - self._registered_model_version = registered_model_version - - self._artifact_path = f"{dataset_name}{self._file_suffix}" - - self._filepath = Path(mkdtemp()) / self._artifact_path - - if registered_model_name: - self._version = f"{registered_model_name}/{registered_model_version}" - else: - self._version = run_id - - def _save(self, data: Any) -> None: - cls = load_dataset(self._dataset_type) - ds = cls(filepath=self._filepath.as_posix(), **self._dataset_args) - ds.save(data) - - filepath = self._filepath.as_posix() - if os.path.isdir(filepath): - mlflow.log_artifacts(self._filepath.as_posix(), self._artifact_path) - elif os.path.isfile(filepath): - mlflow.log_artifact(self._filepath.as_posix()) - else: - raise RuntimeError("cls.save() didn't work. Unexpected error.") - - run_id = mlflow.active_run().info.run_id - if self._version is not None: - logger.warning( - f"Ignoring version {self._version} set " - f"earlier, will use version='{run_id}' for loading" - ) - self._version = run_id - - def _load(self) -> Any: - if self._version is None: - msg = ( - "Could not determine the version to load. 
" - "Please specify either 'run_id' or 'registered_model_name' " - "along with 'registered_model_version' explicitly in " - "MLFlowArtifact constructor" - ) - raise MlflowException(msg) - - if "/" in self._version: - model_uri = f"models:/{self._version}" - model = mlflow.pyfunc.load_model(model_uri) - run_id = model._model_meta.run_id - else: - run_id = self._version - - local_path = _download_artifact_from_uri( - f"runs:/{run_id}/{self._artifact_path}" - ) - - cls = load_dataset(self._dataset_type) - ds = cls(filepath=local_path, **self._dataset_args) - return ds.load() - - def _describe(self) -> Dict[str, Any]: - return dict( - dataset_name=self._dataset_name, - dataset_type=self._dataset_type, - dataset_args=self._dataset_args, - file_suffix=self._file_suffix, - registered_model_name=self._registered_model_name, - registered_model_version=self._registered_model_version, - ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/common.py b/kedro-datasets/kedro_datasets/databricks/mlflow/common.py deleted file mode 100644 index af102d6b3..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/common.py +++ /dev/null @@ -1,89 +0,0 @@ -import mlflow -from mlflow.tracking import MlflowClient - -MLFLOW_RUN_ID_ENV_VAR = "mlflow_run_id" - - -def parse_model_uri(model_uri): - parts = model_uri.split("/") - - if len(parts) < 2 or len(parts) > 3: - raise ValueError( - f"model uri should have the format " - f"'models:/' or " - f"'models://', got {model_uri}" - ) - - if parts[0] == "models:": - protocol = "models" - else: - raise ValueError("model uri should start with `models:/`, got %s", model_uri) - - name = parts[1] - - client = MlflowClient() - if len(parts) == 2: - results = client.search_model_versions(f"name='{name}'") - sorted_results = sorted( - results, - key=lambda modelversion: modelversion.creation_timestamp, - reverse=True, - ) - latest_version = sorted_results[0].version - version = latest_version - else: - version = parts[2] - if version in ["Production", "Staging", "Archived"]: - results = client.get_latest_versions(name, stages=[version]) - if len(results) > 0: - version = results[0].version - else: - version = None - - return protocol, name, version - - -def promote_model(model_name, model_version, stage): - import datetime - - now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - client = MlflowClient() - - new_model_uri = f"models:/{model_name}/{model_version}" - _, _, new_model_version = parse_model_uri(new_model_uri) - new_model = mlflow.pyfunc.load_model(new_model_uri) - new_model_runid = new_model._model_meta.run_id - - msg = f"```Promoted version {model_version} to {stage}, at {now}```" - client.set_tag(new_model_runid, "mlflow.note.content", msg) - client.set_tag(new_model_runid, "Promoted at", now) - - results = client.get_latest_versions(model_name, stages=[stage]) - if len(results) > 0: - old_model_uri = f"models:/{model_name}/{stage}" - _, _, old_model_version = parse_model_uri(old_model_uri) - old_model = mlflow.pyfunc.load_model(old_model_uri) - old_model_runid = old_model._model_meta.run_id - - client.set_tag( - old_model._model_meta.run_id, - "mlflow.note.content", - f"```Replaced by version {new_model_version}, at {now}```", - ) - client.set_tag(old_model_runid, "Retired at", now) - client.set_tag(old_model_runid, "Replaced by", new_model_version) - - client.set_tag(new_model_runid, "Replaces", old_model_version) - - client.transition_model_version_stage( - name=model_name, version=old_model_version, stage="Archived" - ) - - 
client.transition_model_version_stage( - name=model_name, version=new_model_version, stage=stage - ) - - -class ModelOpsException(Exception): - pass diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py b/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py deleted file mode 100644 index ee0a1e0ed..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py +++ /dev/null @@ -1,80 +0,0 @@ -import importlib -import logging -from typing import Any, Dict - -from kedro.io.core import AbstractDataSet - -from .common import ModelOpsException, parse_model_uri - -logger = logging.getLogger(__name__) - - -class MLFlowDataSet(AbstractDataSet): - def __init__( - self, - flavor: str, - dataset_name: str = None, - dataset_type: str = None, - dataset_args: Dict[str, Any] = None, - *, - file_suffix: str = None, - load_version: str = None, - ): - self._flavor = flavor - self._dataset_name = dataset_name - self._dataset_type = dataset_type - self._dataset_args = dataset_args or {} - self._file_suffix = file_suffix - self._load_version = load_version - - def _save(self, model: Any) -> None: - if self._load_version is not None: - msg = ( - f"Trying to save an MLFlowDataSet::{self._describe} which " - f"was initialized with load_version={self._load_version}. " - f"This can lead to inconsistency between saved and loaded " - f"versions, therefore disallowed. Please create separate " - f"catalog entries for saved and loaded datasets." - ) - raise ModelOpsException(msg) - - importlib.import_module(self._flavor).log_model( - model, - self._dataset_name, - registered_model_name=self._dataset_name, - dataset_type=self._dataset_type, - dataset_args=self._dataset_args, - file_suffix=self._file_suffix, - ) - - def _load(self) -> Any: - *_, latest_version = parse_model_uri(f"models:/{self._dataset_name}") - - dataset_version = self._load_version or latest_version - *_, dataset_version = parse_model_uri( - f"models:/{self._dataset_name}/{dataset_version}" - ) - - logger.info(f"Loading model '{self._dataset_name}' version '{dataset_version}'") - - if dataset_version != latest_version: - logger.warning(f"Newer version {latest_version} exists in repo") - - model = importlib.import_module(self._flavor).load_model( - f"models:/{self._dataset_name}/{dataset_version}", - dataset_type=self._dataset_type, - dataset_args=self._dataset_args, - file_suffix=self._file_suffix, - ) - - return model - - def _describe(self) -> Dict[str, Any]: - return dict( - flavor=self._flavor, - dataset_name=self._dataset_name, - dataset_type=self._dataset_type, - dataset_args=self._dataset_args, - file_suffix=self._file_suffix, - load_version=self._load_version, - ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/__init__.py b/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py b/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py deleted file mode 100644 index e0a43a1b0..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py +++ /dev/null @@ -1,154 +0,0 @@ -import os -import sys -from pathlib import Path -from typing import Any, Dict, Union - -import kedro -import yaml -from kedro.utils import load_obj as load_dataset -from mlflow import pyfunc -from mlflow.exceptions import MlflowException -from mlflow.models import Model -from mlflow.tracking._model_registry 
import DEFAULT_AWAIT_MAX_SLEEP_SECONDS -from mlflow.tracking.artifact_utils import _download_artifact_from_uri -from mlflow.utils.environment import _mlflow_conda_env -from mlflow.utils.model_utils import _get_flavor_configuration - -FLAVOR_NAME = "kedro_dataset" - - -DEFAULT_CONDA_ENV = _mlflow_conda_env( - additional_conda_deps=["kedro[all]={}".format(kedro.__version__)], - additional_pip_deps=None, - additional_conda_channels=None, -) - - -def save_model( - data: Any, - path: str, - conda_env: Union[str, Dict[str, Any]] = None, - mlflow_model: Model = Model(), - *, - dataset_type: str, - dataset_args: Dict[str, Any], - file_suffix: str, -): - if os.path.exists(path): - raise RuntimeError("Path '{}' already exists".format(path)) - os.makedirs(path) - - model_data_subpath = f"data.{file_suffix}" - model_data_path = os.path.join(path, model_data_subpath) - - cls = load_dataset(dataset_type) - ds = cls(filepath=model_data_path, **dataset_args) - ds.save(data) - - conda_env_subpath = "conda.yaml" - if conda_env is None: - conda_env = DEFAULT_CONDA_ENV - elif not isinstance(conda_env, dict): - with open(conda_env, "r") as f: - conda_env = yaml.safe_load(f) - with open(os.path.join(path, conda_env_subpath), "w") as f: - yaml.safe_dump(conda_env, stream=f, default_flow_style=False) - - pyfunc.add_to_model( - mlflow_model, - loader_module=__name__, - data=model_data_subpath, - env=conda_env_subpath, - ) - - mlflow_model.add_flavor( - FLAVOR_NAME, - data=model_data_subpath, - dataset_type=dataset_type, - dataset_args=dataset_args, - file_suffix=file_suffix, - ) - mlflow_model.save(os.path.join(path, "MLmodel")) - - -def log_model( - model: Any, - artifact_path: str, - conda_env: Dict[str, Any] = None, - registered_model_name: str = None, - await_registration_for: int = DEFAULT_AWAIT_MAX_SLEEP_SECONDS, - *, - dataset_type: str, - dataset_args: Dict[str, Any], - file_suffix: str, -): - return Model.log( - artifact_path=artifact_path, - flavor=sys.modules[__name__], - registered_model_name=registered_model_name, - await_registration_for=await_registration_for, - data=model, - conda_env=conda_env, - dataset_type=dataset_type, - dataset_args=dataset_args, - file_suffix=file_suffix, - ) - - -def _load_model_from_local_file( - local_path: str, - *, - dataset_type: str = None, - dataset_args: Dict[str, Any] = None, - file_suffix: str = None, -): - if dataset_type is not None: - model_data_subpath = f"data.{file_suffix}" - data_path = os.path.join(local_path, model_data_subpath) - else: - flavor_conf = _get_flavor_configuration( - model_path=local_path, flavor_name=FLAVOR_NAME - ) - data_path = os.path.join(local_path, flavor_conf["data"]) - dataset_type = flavor_conf["dataset_type"] - dataset_args = flavor_conf["dataset_args"] - - cls = load_dataset(dataset_type) - ds = cls(filepath=data_path, **dataset_args) - return ds.load() - - -def load_model( - model_uri: str, - *, - dataset_type: str = None, - dataset_args: Dict[str, Any] = None, - file_suffix: str = None, -): - if dataset_type is not None or dataset_args is not None or file_suffix is not None: - assert ( - dataset_type is not None - and dataset_args is not None - and file_suffix is not None - ), ("Please set 'dataset_type', " "'dataset_args' and 'file_suffix'") - - local_path = _download_artifact_from_uri(model_uri) - return _load_model_from_local_file( - local_path, - dataset_type=dataset_type, - dataset_args=dataset_args, - file_suffix=file_suffix, - ) - - -def _load_pyfunc(model_file: str): - local_path = Path(model_file).parent.absolute() - 
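- # The MLmodel file sits inside the downloaded artifact directory; reload
- # the underlying Kedro dataset from that directory.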
model = _load_model_from_local_file(local_path) - if not hasattr(model, "predict"): - try: - setattr(model, "predict", None) - except AttributeError: - raise MlflowException( - f"`pyfunc` flavor not supported, use " f"{__name__}.load instead" - ) - return model diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py b/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py deleted file mode 100644 index 1c7760375..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py +++ /dev/null @@ -1,93 +0,0 @@ -import logging -from typing import Any, Dict, Union - -import mlflow -from kedro.io.core import AbstractDataSet -from mlflow.exceptions import MlflowException -from mlflow.tracking import MlflowClient - -from .common import ModelOpsException - -logger = logging.getLogger(__name__) - - -class MLFlowMetrics(AbstractDataSet): - def __init__( - self, - prefix: str = None, - run_id: str = None, - registered_model_name: str = None, - registered_model_version: str = None, - ): - if None in (registered_model_name, registered_model_version): - if registered_model_name or registered_model_version: - raise ModelOpsException( - "'registered_model_name' and " - "'registered_model_version' should be " - "set together" - ) - - if run_id and registered_model_name: - raise ModelOpsException( - "'run_id' cannot be passed when " "'registered_model_name' is set" - ) - - self._prefix = prefix - self._run_id = run_id - self._registered_model_name = registered_model_name - self._registered_model_version = registered_model_version - - if registered_model_name: - self._version = f"{registered_model_name}/{registered_model_version}" - else: - self._version = run_id - - def _save(self, metrics: Dict[str, Union[str, float, int]]) -> None: - if self._prefix is not None: - metrics = {f"{self._prefix}_{key}": value for key, value in metrics.items()} - mlflow.log_metrics(metrics) - - run_id = mlflow.active_run().info.run_id - if self._version is not None: - logger.warning( - f"Ignoring version {self._version.save} set " - f"earlier, will use version='{run_id}' for loading" - ) - self._version = run_id - - def _load(self) -> Any: - if self._version is None: - msg = ( - "Could not determine the version to load. 
" - "Please specify either 'run_id' or 'registered_model_name' " - "along with 'registered_model_version' explicitly in " - "MLFlowMetrics constructor" - ) - raise MlflowException(msg) - - client = MlflowClient() - - if "/" in self._version: - model_uri = f"models:/{self._version}" - model = mlflow.pyfunc.load_model(model_uri) - run_id = model._model_meta.run_id - else: - run_id = self._version - - run = client.get_run(run_id) - metrics = run.data.metrics - if self._prefix is not None: - metrics = { - key[len(self._prefix) + 1 :]: value - for key, value in metrics.items() - if key[: len(self._prefix)] == self._prefix - } - return metrics - - def _describe(self) -> Dict[str, Any]: - return dict( - prefix=self._prefix, - run_id=self._run_id, - registered_model_name=self._registered_model_name, - registered_model_version=self._registered_model_version, - ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/model.py b/kedro-datasets/kedro_datasets/databricks/mlflow/model.py deleted file mode 100644 index c5f2356a2..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/model.py +++ /dev/null @@ -1,75 +0,0 @@ -import importlib -import logging -from typing import Any, Dict - -from kedro.io.core import AbstractDataSet -from mlflow.models.signature import ModelSignature - -from .common import ModelOpsException, parse_model_uri - -logger = logging.getLogger(__name__) - - -class MLFlowModel(AbstractDataSet): - def __init__( - self, - flavor: str, - model_name: str, - signature: Dict[str, Dict[str, str]] = None, - input_example: Dict[str, Any] = None, - load_version: str = None, - ): - self._flavor = flavor - self._model_name = model_name - - if signature: - self._signature = ModelSignature.from_dict(signature) - else: - self._signature = None - self._input_example = input_example - - self._load_version = load_version - - def _save(self, model: Any) -> None: - if self._load_version is not None: - msg = ( - f"Trying to save an MLFlowModel::{self._describe} which " - f"was initialized with load_version={self._load_version}. " - f"This can lead to inconsistency between saved and loaded " - f"versions, therefore disallowed. Please create separate " - f"catalog entries for saved and loaded datasets." 
- ) - raise ModelOpsException(msg) - - importlib.import_module(self._flavor).log_model( - model, - self._model_name, - registered_model_name=self._model_name, - signature=self._signature, - input_example=self._input_example, - ) - - def _load(self) -> Any: - *_, latest_version = parse_model_uri(f"models:/{self._model_name}") - - model_version = self._load_version or latest_version - - logger.info(f"Loading model '{self._model_name}' version '{model_version}'") - - if model_version != latest_version: - logger.warning(f"Newer version {latest_version} exists in repo") - - model = importlib.import_module(self._flavor).load_model( - f"models:/{self._model_name}/{model_version}" - ) - - return model - - def _describe(self) -> Dict[str, Any]: - return dict( - flavor=self._flavor, - model_name=self._model_name, - signature=self._signature, - input_example=self._input_example, - load_version=self._load_version, - ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py b/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py deleted file mode 100644 index 3c160cec4..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py +++ /dev/null @@ -1,49 +0,0 @@ -import logging -from typing import Any, Dict, Union - -import mlflow -from kedro.io.core import AbstractDataSet - -from .common import ModelOpsException, parse_model_uri - -logger = logging.getLogger(__name__) - - -class MLFlowModelMetadata(AbstractDataSet): - def __init__( - self, registered_model_name: str, registered_model_version: str = None - ): - self._model_name = registered_model_name - self._model_version = registered_model_version - - def _save(self, tags: Dict[str, Union[str, float, int]]) -> None: - raise NotImplementedError() - - def _load(self) -> Any: - if self._model_version is None: - model_uri = f"models:/{self._model_name}" - else: - model_uri = f"models:/{self._model_name}/{self._model_version}" - _, _, load_version = parse_model_uri(model_uri) - - if load_version is None: - raise ModelOpsException( - f"No model with version " f"'{self._model_version}'" - ) - - pyfunc_model = mlflow.pyfunc.load_model( - f"models:/{self._model_name}/{load_version}" - ) - all_metadata = pyfunc_model._model_meta - model_metadata = { - "model_name": self._model_name, - "model_version": int(load_version), - "run_id": all_metadata.run_id, - } - return model_metadata - - def _describe(self) -> Dict[str, Any]: - return dict( - registered_model_name=self._model_name, - registered_model_version=self._model_version, - ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py b/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py deleted file mode 100644 index 153810ae4..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py +++ /dev/null @@ -1,94 +0,0 @@ -import logging -from typing import Any, Dict, Union - -import mlflow -from kedro.io.core import AbstractDataSet -from mlflow.exceptions import MlflowException -from mlflow.tracking import MlflowClient - -from .common import ModelOpsException - -logger = logging.getLogger(__name__) - - -class MLFlowTags(AbstractDataSet): - def __init__( - self, - prefix: str = None, - run_id: str = None, - registered_model_name: str = None, - registered_model_version: str = None, - ): - if None in (registered_model_name, registered_model_version): - if registered_model_name or registered_model_version: - raise ModelOpsException( - "'registered_model_name' and " - "'registered_model_version' should be " - "set together" - ) - - if 
run_id and registered_model_name: - raise ModelOpsException( - "'run_id' cannot be passed when " "'registered_model_name' is set" - ) - - self._prefix = prefix - self._run_id = run_id - self._registered_model_name = registered_model_name - self._registered_model_version = registered_model_version - - if registered_model_name: - self._version = f"{registered_model_name}/{registered_model_version}" - else: - self._version = run_id - - def _save(self, tags: Dict[str, Union[str, float, int]]) -> None: - if self._prefix is not None: - tags = {f"{self._prefix}_{key}": value for key, value in tags.items()} - - mlflow.set_tags(tags) - - run_id = mlflow.active_run().info.run_id - if self._version is not None: - logger.warning( - f"Ignoring version {self._version.save} set " - f"earlier, will use version='{run_id}' for loading" - ) - self._version = run_id - - def _load(self) -> Any: - if self._version is None: - msg = ( - "Could not determine the version to load. " - "Please specify either 'run_id' or 'registered_model_name' " - "along with 'registered_model_version' explicitly in " - "MLFlowTags constructor" - ) - raise MlflowException(msg) - - client = MlflowClient() - - if "/" in self._version: - model_uri = f"models:/{self._version}" - model = mlflow.pyfunc.load_model(model_uri) - run_id = model._model_meta.run_id - else: - run_id = self._version - - run = client.get_run(run_id) - tags = run.data.tags - if self._prefix is not None: - tags = { - key[len(self._prefix) + 1 :]: value - for key, value in tags.items() - if key[: len(self._prefix)] == self._prefix - } - return tags - - def _describe(self) -> Dict[str, Any]: - return dict( - prefix=self._prefix, - run_id=self._run_id, - registered_model_name=self._registered_model_name, - registered_model_version=self._registered_model_version, - ) diff --git a/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py index b46122197..f0f04b7be 100644 --- a/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py @@ -1,22 +1,26 @@ import logging -from typing import Any, Dict, List, Union import pandas as pd +from operator import attrgetter +from functools import partial +from cachetools.keys import hashkey +from typing import Any, Dict, List, Union +from cachetools import Cache, cachedmethod from kedro.io.core import ( AbstractVersionedDataSet, DataSetError, + Version, VersionNotFoundError, ) from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import StructType -from pyspark.sql.utils import AnalysisException from cachetools import Cache logger = logging.getLogger(__name__) class ManagedTableDataSet(AbstractVersionedDataSet): - """``ManagedTableDataSet`` loads data into Unity managed tables.""" + """``ManagedTableDataSet`` loads and saves data into managed delta tables.""" # this dataset cannot be used with ``ParallelRunner``, # therefore it has the attribute ``_SINGLE_PROCESS = True`` @@ -34,7 +38,7 @@ def __init__( write_mode: str = "overwrite", dataframe_type: str = "spark", primary_key: Union[str, List[str]] = None, - version: int = None, + version: Version = None, *, # the following parameters are used by the hook to create or update unity schema: Dict[str, Any] = None, # pylint: disable=unused-argument @@ -73,9 +77,8 @@ def __init__( ) self._primary_key = primary_key - - self._version = version self._version_cache = Cache(maxsize=2) + self._version = 
version self._schema = None if schema is not None: @@ -83,24 +86,16 @@ def __init__( def _get_spark(self) -> SparkSession: return ( - SparkSession.builder.config( - "spark.jars.packages", "io.delta:delta-core_2.12:1.2.1" - ) - .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") - .config( - "spark.sql.catalog.spark_catalog", - "org.apache.spark.sql.delta.catalog.DeltaCatalog", - ) - .getOrCreate() + SparkSession.builder.getOrCreate() ) def _load(self) -> Union[DataFrame, pd.DataFrame]: - if self._version is not None and self._version >= 0: + if self._version and self._version.load >= 0: try: data = ( self._get_spark() .read.format("delta") - .option("versionAsOf", self._version) + .option("versionAsOf", self._version.load) .table(self._full_table_address) ) except: diff --git a/kedro-datasets/tests/databricks/conftest.py b/kedro-datasets/tests/databricks/conftest.py index d360ffb68..26d63b056 100644 --- a/kedro-datasets/tests/databricks/conftest.py +++ b/kedro-datasets/tests/databricks/conftest.py @@ -6,7 +6,6 @@ """ import pytest from pyspark.sql import SparkSession -from delta.pip_utils import configure_spark_with_delta_pip @pytest.fixture(scope="class", autouse=True) diff --git a/kedro-datasets/tests/databricks/test_unity.py b/kedro-datasets/tests/databricks/test_unity.py index 471f81f57..0d54e29e4 100644 --- a/kedro-datasets/tests/databricks/test_unity.py +++ b/kedro-datasets/tests/databricks/test_unity.py @@ -1,5 +1,5 @@ import pytest -from kedro.io.core import DataSetError, VersionNotFoundError +from kedro.io.core import DataSetError, VersionNotFoundError, Version from pyspark.sql.types import IntegerType, StringType, StructField, StructType from pyspark.sql import DataFrame, SparkSession import pandas as pd @@ -195,6 +195,7 @@ def test_describe(self): "dataframe_type": "spark", "primary_key": None, "version": None, + "owner_group": None } def test_invalid_write_mode(self): @@ -413,7 +414,7 @@ def test_load_spark_no_version(self, sample_spark_df: DataFrame): unity_ds.save(sample_spark_df) delta_ds = ManagedTableDataSet( - database="test", table="test_load_spark", version=2 + database="test", table="test_load_spark", version=Version(2,None) ) with pytest.raises(VersionNotFoundError): _ = delta_ds.load() @@ -426,7 +427,7 @@ def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFra unity_ds.save(append_spark_df) loaded_ds = ManagedTableDataSet( - database="test", table="test_load_version", version=0 + database="test", table="test_load_version", version=Version(0,None) ) loaded_df = loaded_ds.load() From d6bc149d3e9663819167a1fe55f5153a894ab32c Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 8 Mar 2023 15:14:37 -0500 Subject: [PATCH 06/40] cleaned up mlflow references from setup.py for initial release Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- kedro-datasets/setup.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index ea64d9314..6f151c1fa 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -18,13 +18,7 @@ def _collect_requirements(requires): biosequence_require = {"biosequence.BioSequenceDataSet": ["biopython~=1.73"]} dask_require = {"dask.ParquetDataSet": ["dask[complete]~=2021.10", "triad>=0.6.7, <1.0"]} databricks_require = { - "databricks.ManagedTableDataSet": [SPARK, PANDAS], - "databricks.MLFlowModel":[SPARK, PANDAS, "mlflow>=2.0.0"], - "databricks.MLFlowArtifact":[SPARK, PANDAS, "mlflow>=2.0.0"], - 
"databricks.MLFlowDataSet":[SPARK, PANDAS, "mlflow>=2.0.0"], - "databricks.MLFlowMetrics":[SPARK, PANDAS, "mlflow>=2.0.0"], - "databricks.MLFlowModelMetadata":[SPARK, PANDAS, "mlflow>=2.0.0"], - "databricks.MLFlowTags":[SPARK, PANDAS, "mlflow>=2.0.0"] + "databricks.ManagedTableDataSet": [SPARK, PANDAS] } geopandas_require = { "geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] From aee12a29e849f66174e1fe7f09d4e37279d9a542 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 8 Mar 2023 15:16:51 -0500 Subject: [PATCH 07/40] fixed deps in setup.py Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- kedro-datasets/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index 6f151c1fa..f5c5d931a 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -8,6 +8,7 @@ HDFS = "hdfs>=2.5.8, <3.0" S3FS = "s3fs>=0.3.0, <0.5" POLARS = "polars~=0.17.0" +DELTA = "delta-spark~=1.2.1" def _collect_requirements(requires): @@ -18,7 +19,7 @@ def _collect_requirements(requires): biosequence_require = {"biosequence.BioSequenceDataSet": ["biopython~=1.73"]} dask_require = {"dask.ParquetDataSet": ["dask[complete]~=2021.10", "triad>=0.6.7, <1.0"]} databricks_require = { - "databricks.ManagedTableDataSet": [SPARK, PANDAS] + "databricks.ManagedTableDataSet": [SPARK, PANDAS, DELTA] } geopandas_require = { "geopandas.GeoJSONDataSet": ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] From 911e53f3a004cd54a6f94a3360f8b816d5e39107 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Mon, 13 Mar 2023 18:08:37 -0400 Subject: [PATCH 08/40] adding comments before intiial PR Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/__init__.py | 7 +- .../databricks/managed_table_dataset.py | 342 ++++++++++++++++++ .../databricks/unity/__init__.py | 1 - .../databricks/unity/managed_table_dataset.py | 198 ---------- .../kedro_datasets/pandas/generic_dataset.py | 2 - .../spark/spark_jdbc_dataset.py | 1 - ...unity.py => test_managed_table_dataset.py} | 33 +- 7 files changed, 366 insertions(+), 218 deletions(-) create mode 100644 kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/unity/__init__.py delete mode 100644 kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py rename kedro-datasets/tests/databricks/{test_unity.py => test_managed_table_dataset.py} (94%) diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py index 7819a2e06..d416ac291 100644 --- a/kedro-datasets/kedro_datasets/databricks/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -1,3 +1,8 @@ """Provides interface to Unity Catalog Tables.""" -from .unity import ManagedTableDataSet +__all__ = ["ManagedTableDataSet"] + +from contextlib import suppress + +with suppress(ImportError): + from .managed_table_dataset import ManagedTableDataSet diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py new file mode 100644 index 000000000..1b9e0c737 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -0,0 +1,342 @@ +"""``ManagedTableDataSet`` implementation to access managed delta tables +in Databricks. 
+""" +import dataclasses +import logging +from functools import partial +from operator import attrgetter +from typing import Any, Dict, List, Union + +import pandas as pd +from cachetools import Cache, cachedmethod +from cachetools.keys import hashkey +from kedro.io.core import ( + AbstractVersionedDataSet, + DataSetError, + Version, + VersionNotFoundError, +) +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import StructType +from pyspark.sql.utils import AnalysisException, ParseException + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Table: # pylint: disable=R0902 + """Stores the definition of a managed table""" + + database: str + catalog: str + table: str + full_table_location: str + write_mode: str + dataframe_type: str + primary_key: str + owner_group: str + partition_columns: str | List[str] + + +class ManagedTableDataSet(AbstractVersionedDataSet): + """``ManagedTableDataSet`` loads and saves data into managed delta tables on Databricks. + + Example usage for the + `YAML API `_: + .. code-block:: yaml + + names_and_ages@spark: + type: databricks.ManagedTableDataSet + table: names_and_ages + + names_and_ages@pandas: + type: databricks.ManagedTableDataSet + table: names_and_ages + dataframe_type: pandas + + Example usage for the + `Python API `_: + :: + Launch a pyspark session with the following configs: + % pyspark --packages io.delta:delta-core_2.12:1.2.1 + --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" + --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog" + + >>> from pyspark.sql import SparkSession + >>> from pyspark.sql.types import (StructField, StringType, + IntegerType, StructType) + >>> from kedro_datasets.databricks import ManagedTableDataSet + >>> schema = StructType([StructField("name", StringType(), True), + StructField("age", IntegerType(), True)]) + >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] + >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) + >>> data_set = ManagedTableDataSet(table="names_and_ages") + >>> data_set.save(spark_df) + >>> reloaded = data_set.load() + >>> reloaded.take(4)""" + + # this dataset cannot be used with ``ParallelRunner``, + # therefore it has the attribute ``_SINGLE_PROCESS = True`` + # for parallelism within a Spark pipeline please consider + # using ``ThreadRunner`` instead + _SINGLE_PROCESS = True + _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] + _VALID_DATAFRAME_TYPES = ["spark", "pandas"] + + def __init__( # pylint: disable=R0913 + self, + table: str, + catalog: str = None, + database: str = "default", + write_mode: str = "overwrite", + dataframe_type: str = "spark", + primary_key: Union[str, List[str]] = None, + version: Version = None, + *, + # the following parameters are used by project hooks + # to create or update table properties + schema: Dict[str, Any] = None, + partition_columns: List[str] = None, + owner_group: str = None, + ) -> None: + """Creates a new instance of ``ManagedTableDataSet``.""" + + full_table_location = None + if catalog and database and table: + full_table_location = f"{catalog}.{database}.{table}" + elif table: + full_table_location = f"{database}.{table}" + if write_mode not in self._VALID_WRITE_MODES: + valid_modes = ", ".join(self._VALID_WRITE_MODES) + raise DataSetError( + f"Invalid `write_mode` provided: {write_mode}. 
" + f"`write_mode` must be one of: {valid_modes}" + ) + if dataframe_type not in self._VALID_DATAFRAME_TYPES: + valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) + raise DataSetError(f"`dataframe_type` must be one of {valid_types}") + if primary_key is None or len(primary_key) == 0: + if write_mode == "upsert": + raise DataSetError( + f"`primary_key` must be provided for" f"`write_mode` {write_mode}" + ) + self._table = Table( + database=database, + catalog=catalog, + table=table, + full_table_location=full_table_location, + write_mode=write_mode, + dataframe_type=dataframe_type, + primary_key=primary_key, + owner_group=owner_group, + partition_columns=partition_columns, + ) + + self._version_cache = Cache(maxsize=2) + self._version = version + + self._schema = None + if schema is not None: + self._schema = StructType.fromJson(schema) + + super().__init__( + filepath=None, + version=version, + exists_function=self._exists, + ) + + @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "load")) + def _fetch_latest_load_version(self) -> int: + # When load version is unpinned, fetch the most recent existing + # version from the given path. + latest_history = ( + self._get_spark() + .sql(f"DESCRIBE HISTORY {self._table.full_table_location} LIMIT 1") + .collect() + ) + if len(latest_history) != 1: + raise VersionNotFoundError( + f"Did not find any versions for {self._table.full_table_location}" + ) + return latest_history[0].version + + # 'key' is set to prevent cache key overlapping for load and save: + # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod + @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "save")) + def _fetch_latest_save_version(self) -> int: + """Generate and cache the current save version""" + return None + + @staticmethod + def _get_spark() -> SparkSession: + return SparkSession.builder.getOrCreate() + + def _load(self) -> Union[DataFrame, pd.DataFrame]: + """Loads the version of data in the format defined in the init + (spark|pandas dataframe) + + Raises: + VersionNotFoundError: if the version defined in + the init doesn't exist + + Returns: + Union[DataFrame, pd.DataFrame]: Returns a dataframe + in the format defined in the init + """ + if self._version and self._version.load >= 0: + try: + data = ( + self._get_spark() + .read.format("delta") + .option("versionAsOf", self._version.load) + .table(self._table.full_table_location) + ) + except Exception as exc: + raise VersionNotFoundError(self._version) from exc + else: + data = self._get_spark().table(self._table.full_table_location) + if self._table.dataframe_type == "pandas": + data = data.toPandas() + return data + + def _save_append(self, data: DataFrame) -> None: + """Saves the data to the table by appending it + to the location defined in the init + + Args: + data (DataFrame): the Spark dataframe to append to the table + """ + data.write.format("delta").mode("append").saveAsTable( + self._table.full_table_location + ) + + def _save_overwrite(self, data: DataFrame) -> None: + """Overwrites the data in the table with the data provided. + (this is the default save mode) + + Args: + data (DataFrame): the Spark dataframe to overwrite the table with. 
+ """ + delta_table = data.write.format("delta") + if self._table.write_mode == "overwrite": + delta_table = delta_table.mode("overwrite").option( + "overwriteSchema", "true" + ) + delta_table.saveAsTable(self._table.full_table_location) + + def _save_upsert(self, update_data: DataFrame) -> None: + """Upserts the data by joining on primary_key columns or column. + If table doesn't exist at save, the data is inserted to a new table. + + Args: + update_data (DataFrame): the Spark dataframe to upsert + """ + if self._exists(): + base_data = self._get_spark().table(self._table.full_table_location) + base_columns = base_data.columns + update_columns = update_data.columns + + if set(update_columns) != set(base_columns): + raise DataSetError( + f"Upsert requires tables to have identical columns. " + f"Delta table {self._table.full_table_location} " + f"has columns: {base_columns}, whereas " + f"dataframe has columns {update_columns}" + ) + + where_expr = "" + if isinstance(self._table.primary_key, str): + where_expr = ( + f"base.{self._table.primary_key}=update.{self._table.primary_key}" + ) + elif isinstance(self._table.primary_key, list): + where_expr = " AND ".join( + f"base.{col}=update.{col}" for col in self._table.primary_key + ) + + update_data.createOrReplaceTempView("update") + self._get_spark().conf.set( + "fullTableAddress", self._table.full_table_location + ) + self._get_spark().conf.set("whereExpr", where_expr) + upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} + WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""" + self._get_spark().sql(upsert_sql) + else: + self._save_append(update_data) + + def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None: + """Saves the data based on the write_mode and dataframe_type in the init. + If write_mode is pandas, Spark dataframe is created first. + If schema is provided, data is matched to schema before saving + (columns will be sorted and truncated). 
+ + Args: + data (Any): Spark or pandas dataframe to save to the table location + """ + # filter columns specified in schema and match their ordering + if self._schema: + cols = self._schema.fieldNames() + if self._table.dataframe_type == "pandas": + data = self._get_spark().createDataFrame( + data.loc[:, cols], schema=self._schema + ) + else: + data = data.select(*cols) + else: + if self._table.dataframe_type == "pandas": + data = self._get_spark().createDataFrame(data) + if self._table.write_mode == "overwrite": + self._save_overwrite(data) + elif self._table.write_mode == "upsert": + self._save_upsert(data) + elif self._table.write_mode == "append": + self._save_append(data) + + def _describe(self) -> Dict[str, str]: + """Returns a description of the instance of ManagedTableDataSet + + Returns: + Dict[str, str]: Dict with the details of the dataset + """ + return { + "catalog": self._table.catalog, + "database": self._table.database, + "table": self._table.table, + "write_mode": self._table.write_mode, + "dataframe_type": self._table.dataframe_type, + "primary_key": self._table.primary_key, + "version": self._version, + "owner_group": self._table.owner_group, + "partition_columns": self._table.partition_columns, + } + + def _exists(self) -> bool: + """Checks to see if the table exists + + Returns: + bool: boolean of whether the table defined + in the dataset instance exists in the Spark session + """ + if self._table.catalog: + try: + self._get_spark().sql(f"USE CATALOG {self._table.catalog}") + except (ParseException, AnalysisException) as exc: + logger.warning( + "catalog %s not found or unity not enabled. Error message: %s", + self._table.catalog, + exc, + ) + try: + return ( + self._get_spark() + .sql(f"SHOW TABLES IN `{self._table.database}`") + .filter(f"tableName = '{self._table.table}'") + .count() + > 0 + ) + except (ParseException, AnalysisException) as exc: + logger.warning("error occured while trying to find table: %s", exc) + return False diff --git a/kedro-datasets/kedro_datasets/databricks/unity/__init__.py b/kedro-datasets/kedro_datasets/databricks/unity/__init__.py deleted file mode 100644 index ab452e146..000000000 --- a/kedro-datasets/kedro_datasets/databricks/unity/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .managed_table_dataset import ManagedTableDataSet \ No newline at end of file diff --git a/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py deleted file mode 100644 index f0f04b7be..000000000 --- a/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py +++ /dev/null @@ -1,198 +0,0 @@ -import logging -import pandas as pd - -from operator import attrgetter -from functools import partial -from cachetools.keys import hashkey -from typing import Any, Dict, List, Union -from cachetools import Cache, cachedmethod -from kedro.io.core import ( - AbstractVersionedDataSet, - DataSetError, - Version, - VersionNotFoundError, -) -from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.types import StructType -from cachetools import Cache - -logger = logging.getLogger(__name__) - - -class ManagedTableDataSet(AbstractVersionedDataSet): - """``ManagedTableDataSet`` loads and saves data into managed delta tables.""" - - # this dataset cannot be used with ``ParallelRunner``, - # therefore it has the attribute ``_SINGLE_PROCESS = True`` - # for parallelism within a Spark pipeline please consider - # using ``ThreadRunner`` instead - _SINGLE_PROCESS = True - 
_VALID_WRITE_MODES = ["overwrite", "upsert", "append"] - _VALID_DATAFRAME_TYPES = ["spark", "pandas"] - - def __init__( - self, - table: str, - catalog: str = None, - database: str = "default", - write_mode: str = "overwrite", - dataframe_type: str = "spark", - primary_key: Union[str, List[str]] = None, - version: Version = None, - *, - # the following parameters are used by the hook to create or update unity - schema: Dict[str, Any] = None, # pylint: disable=unused-argument - partition_columns: List[str] = None, # pylint: disable=unused-argument - owner_group: str = None, - ) -> None: - """Creates a new instance of ``ManagedTableDataSet``.""" - - self._database = database - self._catalog = catalog - self._table = table - self._owner_group = owner_group - self._partition_columns = partition_columns - if catalog and database and table: - self._full_table_address = f"{catalog}.{database}.{table}" - elif table: - self._full_table_address = f"{database}.{table}" - - if write_mode not in self._VALID_WRITE_MODES: - valid_modes = ", ".join(self._VALID_WRITE_MODES) - raise DataSetError( - f"Invalid `write_mode` provided: {write_mode}. " - f"`write_mode` must be one of: {valid_modes}" - ) - self._write_mode = write_mode - - if dataframe_type not in self._VALID_DATAFRAME_TYPES: - valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) - raise DataSetError(f"`dataframe_type` must be one of {valid_types}") - self._dataframe_type = dataframe_type - - if primary_key is None or len(primary_key) == 0: - if write_mode == "upsert": - raise DataSetError( - f"`primary_key` must be provided for" f"`write_mode` {write_mode}" - ) - - self._primary_key = primary_key - self._version_cache = Cache(maxsize=2) - self._version = version - - self._schema = None - if schema is not None: - self._schema = StructType.fromJson(schema) - - def _get_spark(self) -> SparkSession: - return ( - SparkSession.builder.getOrCreate() - ) - - def _load(self) -> Union[DataFrame, pd.DataFrame]: - if self._version and self._version.load >= 0: - try: - data = ( - self._get_spark() - .read.format("delta") - .option("versionAsOf", self._version.load) - .table(self._full_table_address) - ) - except: - raise VersionNotFoundError - else: - data = self._get_spark().table(self._full_table_address) - if self._dataframe_type == "pandas": - data = data.toPandas() - return data - - def _save_append(self, data: DataFrame) -> None: - data.write.format("delta").mode("append").saveAsTable(self._full_table_address) - - def _save_overwrite(self, data: DataFrame) -> None: - delta_table = data.write.format("delta") - if self._write_mode == "overwrite": - delta_table = delta_table.mode("overwrite").option( - "overwriteSchema", "true" - ) - delta_table.saveAsTable(self._full_table_address) - - def _save_upsert(self, update_data: DataFrame) -> None: - if self._exists(): - base_data = self._get_spark().table(self._full_table_address) - base_columns = base_data.columns - update_columns = update_data.columns - - if set(update_columns) != set(base_columns): - raise DataSetError( - f"Upsert requires tables to have identical columns. 
" - f"Delta table {self._full_table_address} " - f"has columns: {base_columns}, whereas " - f"dataframe has columns {update_columns}" - ) - - where_expr = "" - if isinstance(self._primary_key, str): - where_expr = f"base.{self._primary_key}=update.{self._primary_key}" - elif isinstance(self._primary_key, list): - where_expr = " AND ".join( - f"base.{col}=update.{col}" for col in self._primary_key - ) - - update_data.createOrReplaceTempView("update") - - upsert_sql = f"""MERGE INTO {self._full_table_address} base USING update - ON {where_expr} WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT * - """ - self._get_spark().sql(upsert_sql) - else: - self._save_append(update_data) - - def _save(self, data: Any) -> None: - # filter columns specified in schema and match their ordering - if self._schema: - cols = self._schema.fieldNames() - if self._dataframe_type == "pandas": - data = self._get_spark().createDataFrame( - data.loc[:, cols], schema=self._schema - ) - else: - data = data.select(*cols) - else: - if self._dataframe_type == "pandas": - data = self._get_spark().createDataFrame(data) - if self._write_mode == "overwrite": - self._save_overwrite(data) - elif self._write_mode == "upsert": - self._save_upsert(data) - elif self._write_mode == "append": - self._save_append(data) - - def _describe(self) -> Dict[str, str]: - return dict( - catalog=self._catalog, - database=self._database, - table=self._table, - write_mode=self._write_mode, - dataframe_type=self._dataframe_type, - primary_key=self._primary_key, - version=self._version, - owner_group=self._owner_group, - ) - - def _exists(self) -> bool: - if self._catalog: - try: - self._get_spark().sql(f"USE CATALOG {self._catalog}") - except: - logger.warn(f"catalog {self._catalog} not found") - try: - return ( - self._get_spark() - .sql(f"SHOW TABLES IN `{self._database}`") - .filter(f"tableName = '{self._table}'") - .count() - > 0 - ) - except: - return False diff --git a/kedro-datasets/kedro_datasets/pandas/generic_dataset.py b/kedro-datasets/kedro_datasets/pandas/generic_dataset.py index a2bb6b1be..91229edcf 100644 --- a/kedro-datasets/kedro_datasets/pandas/generic_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/generic_dataset.py @@ -181,7 +181,6 @@ def _ensure_file_system_target(self) -> None: ) def _load(self) -> pd.DataFrame: - self._ensure_file_system_target() load_path = get_filepath_str(self._get_load_path(), self._protocol) @@ -196,7 +195,6 @@ def _load(self) -> pd.DataFrame: ) def _save(self, data: pd.DataFrame) -> None: - self._ensure_file_system_target() save_path = get_filepath_str(self._get_save_path(), self._protocol) diff --git a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py index ca3c7643c..c90c5f958 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py @@ -126,7 +126,6 @@ def __init__( # Update properties in load_args and save_args with credentials. if credentials is not None: - # Check credentials for bad inputs. 
for cred_key, cred_value in credentials.items(): if cred_value is None: diff --git a/kedro-datasets/tests/databricks/test_unity.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py similarity index 94% rename from kedro-datasets/tests/databricks/test_unity.py rename to kedro-datasets/tests/databricks/test_managed_table_dataset.py index 0d54e29e4..f5bc494a1 100644 --- a/kedro-datasets/tests/databricks/test_unity.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -1,8 +1,9 @@ +import pandas as pd import pytest -from kedro.io.core import DataSetError, VersionNotFoundError, Version -from pyspark.sql.types import IntegerType, StringType, StructField, StructType +from kedro.io.core import DataSetError, Version, VersionNotFoundError from pyspark.sql import DataFrame, SparkSession -import pandas as pd +from pyspark.sql.types import IntegerType, StringType, StructField, StructType + from kedro_datasets.databricks import ManagedTableDataSet @@ -171,19 +172,16 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): class TestManagedTableDataSet: def test_full_table(self): unity_ds = ManagedTableDataSet(catalog="test", database="test", table="test") - assert unity_ds._full_table_address == "test.test.test" + assert unity_ds._table.full_table_location == "test.test.test" - def test_database_table(self): unity_ds = ManagedTableDataSet(database="test", table="test") - assert unity_ds._full_table_address == "test.test" + assert unity_ds._table.full_table_location == "test.test" - def test_table_only(self): unity_ds = ManagedTableDataSet(table="test") - assert unity_ds._full_table_address == "default.test" + assert unity_ds._table.full_table_location == "default.test" - def test_table_missing(self): with pytest.raises(TypeError): - ManagedTableDataSet() + ManagedTableDataSet() # pylint: disable=no-value-for-parameter def test_describe(self): unity_ds = ManagedTableDataSet(table="test") @@ -195,7 +193,8 @@ def test_describe(self): "dataframe_type": "spark", "primary_key": None, "version": None, - "owner_group": None + "owner_group": None, + "partition_columns": None, } def test_invalid_write_mode(self): @@ -240,7 +239,9 @@ def test_schema(self): assert unity_ds._schema == expected_schema def test_catalog_exists(self): - unity_ds = ManagedTableDataSet(catalog="test", database="invalid", table="test_not_there") + unity_ds = ManagedTableDataSet( + catalog="test", database="invalid", table="test_not_there" + ) assert not unity_ds._exists() def test_table_does_not_exist(self): @@ -251,7 +252,9 @@ def test_save_default(self, sample_spark_df: DataFrame): unity_ds = ManagedTableDataSet(database="test", table="test_save") unity_ds.save(sample_spark_df) saved_table = unity_ds.load() - assert unity_ds.exists() and sample_spark_df.exceptAll(saved_table).count() == 0 + assert ( + unity_ds._exists() and sample_spark_df.exceptAll(saved_table).count() == 0 + ) def test_save_schema_spark( self, subset_spark_df: DataFrame, subset_expected_df: DataFrame @@ -414,7 +417,7 @@ def test_load_spark_no_version(self, sample_spark_df: DataFrame): unity_ds.save(sample_spark_df) delta_ds = ManagedTableDataSet( - database="test", table="test_load_spark", version=Version(2,None) + database="test", table="test_load_spark", version=Version(2, None) ) with pytest.raises(VersionNotFoundError): _ = delta_ds.load() @@ -427,7 +430,7 @@ def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFra unity_ds.save(append_spark_df) loaded_ds = ManagedTableDataSet( - 
database="test", table="test_load_version", version=Version(0,None) + database="test", table="test_load_version", version=Version(0, None) ) loaded_df = loaded_ds.load() From 10932fb2ba2344be8d676b02628182bedab072d7 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Mon, 13 Mar 2023 21:47:50 -0400 Subject: [PATCH 09/40] moved validation to dataclass Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../databricks/managed_table_dataset.py | 168 +++++++++++++----- .../databricks/test_managed_table_dataset.py | 8 +- 2 files changed, 129 insertions(+), 47 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 1b9e0c737..aeef0d1af 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -1,8 +1,9 @@ """``ManagedTableDataSet`` implementation to access managed delta tables in Databricks. """ -import dataclasses import logging +import re +from dataclasses import dataclass from functools import partial from operator import attrgetter from typing import Any, Dict, List, Union @@ -21,21 +22,127 @@ from pyspark.sql.utils import AnalysisException, ParseException logger = logging.getLogger(__name__) +NAMING_REGEX = r"\b[0-9a-zA-Z_]{1,32}\b" +_VALID_WRITE_MODES = ["overwrite", "upsert", "append"] +_VALID_DATAFRAME_TYPES = ["spark", "pandas"] -@dataclasses.dataclass -class Table: # pylint: disable=R0902 +@dataclass(frozen=True) +class ManagedTable: # pylint: disable=R0902 """Stores the definition of a managed table""" database: str catalog: str table: str - full_table_location: str write_mode: str dataframe_type: str primary_key: str owner_group: str partition_columns: str | List[str] + json_schema: StructType + + def __post_init__(self): + """Run validation methods if declared. + The validation method can be a simple check + that raises ValueError or a transformation to + the field value. + The validation is performed by calling a function named: + `validate_(self, value) -> raises DataSetError` + """ + for name, _ in self.__dataclass_fields__.items(): # pylint: disable=E1101 + if method := getattr(self, f"validate_{name}", None): + method() + + def validate_table(self): + """validates table name + + Raises: + DataSetError: + """ + if not re.fullmatch(NAMING_REGEX, self.table): + raise DataSetError( + "table does not conform to naming and is a required field" + ) + + def validate_database(self): + """validates database name + + Raises: + DataSetError: + """ + if self.database: + if not re.fullmatch(NAMING_REGEX, self.database): + raise DataSetError("database does not conform to naming") + + def validate_catalog(self): + """validates catalog name + + Raises: + DataSetError: + """ + if self.catalog: + if not re.fullmatch(NAMING_REGEX, self.catalog): + raise DataSetError("catalog does not conform to naming") + + def validate_write_mode(self): + """validates the write mode + + Raises: + DataSetError: + """ + if self.write_mode not in _VALID_WRITE_MODES: + valid_modes = ", ".join(_VALID_WRITE_MODES) + raise DataSetError( + f"Invalid `write_mode` provided: {self.write_mode}. 
" + f"`write_mode` must be one of: {valid_modes}" + ) + + def validate_dataframe_type(self): + """validates the dataframe type + + Raises: + DataSetError: + """ + if self.dataframe_type not in _VALID_DATAFRAME_TYPES: + valid_types = ", ".join(_VALID_DATAFRAME_TYPES) + raise DataSetError(f"`dataframe_type` must be one of {valid_types}") + + def validate_primary_key(self): + """validates the primary key of the table + + Raises: + DataSetError: + """ + if self.primary_key is None or len(self.primary_key) == 0: + if self.write_mode == "upsert": + raise DataSetError( + f"`primary_key` must be provided for" + f"`write_mode` {self.write_mode}" + ) + + def full_table_location(self) -> str: + """Returns the full table location + + Returns: + str: table location in the format catalog.database.table + """ + full_table_location = None + if self.catalog and self.database and self.table: + full_table_location = f"{self.catalog}.{self.database}.{self.table}" + elif self.table: + full_table_location = f"{self.database}.{self.table}" + return full_table_location + + def schema(self) -> StructType: + """Returns the Spark schema of the table if it exists + + Returns: + StructType: + """ + schema = None + if self.json_schema is not None: + schema = StructType.fromJson(self.json_schema) + return schema class ManagedTableDataSet(AbstractVersionedDataSet): @@ -82,8 +189,6 @@ class ManagedTableDataSet(AbstractVersionedDataSet): # for parallelism within a Spark pipeline please consider # using ``ThreadRunner`` instead _SINGLE_PROCESS = True - _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] - _VALID_DATAFRAME_TYPES = ["spark", "pandas"] def __init__( # pylint: disable=R0913 self, @@ -103,44 +208,21 @@ def __init__( # pylint: disable=R0913 ) -> None: """Creates a new instance of ``ManagedTableDataSet``.""" - full_table_location = None - if catalog and database and table: - full_table_location = f"{catalog}.{database}.{table}" - elif table: - full_table_location = f"{database}.{table}" - if write_mode not in self._VALID_WRITE_MODES: - valid_modes = ", ".join(self._VALID_WRITE_MODES) - raise DataSetError( - f"Invalid `write_mode` provided: {write_mode}. " - f"`write_mode` must be one of: {valid_modes}" - ) - if dataframe_type not in self._VALID_DATAFRAME_TYPES: - valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) - raise DataSetError(f"`dataframe_type` must be one of {valid_types}") - if primary_key is None or len(primary_key) == 0: - if write_mode == "upsert": - raise DataSetError( - f"`primary_key` must be provided for" f"`write_mode` {write_mode}" - ) - self._table = Table( + self._table = ManagedTable( database=database, catalog=catalog, table=table, - full_table_location=full_table_location, write_mode=write_mode, dataframe_type=dataframe_type, primary_key=primary_key, owner_group=owner_group, partition_columns=partition_columns, + json_schema=schema, ) self._version_cache = Cache(maxsize=2) self._version = version - self._schema = None - if schema is not None: - self._schema = StructType.fromJson(schema) - super().__init__( filepath=None, version=version, @@ -153,12 +235,12 @@ def _fetch_latest_load_version(self) -> int: # version from the given path. 
latest_history = ( self._get_spark() - .sql(f"DESCRIBE HISTORY {self._table.full_table_location} LIMIT 1") + .sql(f"DESCRIBE HISTORY {self._table.full_table_location()} LIMIT 1") .collect() ) if len(latest_history) != 1: raise VersionNotFoundError( - f"Did not find any versions for {self._table.full_table_location}" + f"Did not find any versions for {self._table.full_table_location()}" ) return latest_history[0].version @@ -191,12 +273,12 @@ def _load(self) -> Union[DataFrame, pd.DataFrame]: self._get_spark() .read.format("delta") .option("versionAsOf", self._version.load) - .table(self._table.full_table_location) + .table(self._table.full_table_location()) ) except Exception as exc: raise VersionNotFoundError(self._version) from exc else: - data = self._get_spark().table(self._table.full_table_location) + data = self._get_spark().table(self._table.full_table_location()) if self._table.dataframe_type == "pandas": data = data.toPandas() return data @@ -209,7 +291,7 @@ def _save_append(self, data: DataFrame) -> None: data (DataFrame): the Spark dataframe to append to the table """ data.write.format("delta").mode("append").saveAsTable( - self._table.full_table_location + self._table.full_table_location() ) def _save_overwrite(self, data: DataFrame) -> None: @@ -224,7 +306,7 @@ def _save_overwrite(self, data: DataFrame) -> None: delta_table = delta_table.mode("overwrite").option( "overwriteSchema", "true" ) - delta_table.saveAsTable(self._table.full_table_location) + delta_table.saveAsTable(self._table.full_table_location()) def _save_upsert(self, update_data: DataFrame) -> None: """Upserts the data by joining on primary_key columns or column. @@ -234,14 +316,14 @@ def _save_upsert(self, update_data: DataFrame) -> None: update_data (DataFrame): the Spark dataframe to upsert """ if self._exists(): - base_data = self._get_spark().table(self._table.full_table_location) + base_data = self._get_spark().table(self._table.full_table_location()) base_columns = base_data.columns update_columns = update_data.columns if set(update_columns) != set(base_columns): raise DataSetError( f"Upsert requires tables to have identical columns. 
" - f"Delta table {self._table.full_table_location} " + f"Delta table {self._table.full_table_location()} " f"has columns: {base_columns}, whereas " f"dataframe has columns {update_columns}" ) @@ -258,7 +340,7 @@ def _save_upsert(self, update_data: DataFrame) -> None: update_data.createOrReplaceTempView("update") self._get_spark().conf.set( - "fullTableAddress", self._table.full_table_location + "fullTableAddress", self._table.full_table_location() ) self._get_spark().conf.set("whereExpr", where_expr) upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} @@ -277,11 +359,11 @@ def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None: data (Any): Spark or pandas dataframe to save to the table location """ # filter columns specified in schema and match their ordering - if self._schema: - cols = self._schema.fieldNames() + if self._table.schema(): + cols = self._table.schema().fieldNames() if self._table.dataframe_type == "pandas": data = self._get_spark().createDataFrame( - data.loc[:, cols], schema=self._schema + data.loc[:, cols], schema=self._table.schema() ) else: data = data.select(*cols) diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index f5bc494a1..4520042ab 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -172,13 +172,13 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): class TestManagedTableDataSet: def test_full_table(self): unity_ds = ManagedTableDataSet(catalog="test", database="test", table="test") - assert unity_ds._table.full_table_location == "test.test.test" + assert unity_ds._table.full_table_location() == "test.test.test" unity_ds = ManagedTableDataSet(database="test", table="test") - assert unity_ds._table.full_table_location == "test.test" + assert unity_ds._table.full_table_location() == "test.test" unity_ds = ManagedTableDataSet(table="test") - assert unity_ds._table.full_table_location == "default.test" + assert unity_ds._table.full_table_location() == "default.test" with pytest.raises(TypeError): ManagedTableDataSet() # pylint: disable=no-value-for-parameter @@ -236,7 +236,7 @@ def test_schema(self): StructField("age", IntegerType(), True), ] ) - assert unity_ds._schema == expected_schema + assert unity_ds._table.schema() == expected_schema def test_catalog_exists(self): unity_ds = ManagedTableDataSet( From 74471a8cdf3c62dd0759e18e89157f185d54b8d2 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Tue, 21 Mar 2023 12:58:31 -0400 Subject: [PATCH 10/40] bug fix in type of partition column and cleanup Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../databricks/managed_table_dataset.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index aeef0d1af..ee82253e3 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -22,15 +22,16 @@ from pyspark.sql.utils import AnalysisException, ParseException logger = logging.getLogger(__name__) -NAMING_REGEX = r"\b[0-9a-zA-Z_]{1,32}\b" -_VALID_WRITE_MODES = ["overwrite", "upsert", "append"] -_VALID_DATAFRAME_TYPES = ["spark", "pandas"] @dataclass(frozen=True) class ManagedTable: # pylint: disable=R0902 
"""Stores the definition of a managed table""" + # regex for tables, catalogs and schemas + _NAMING_REGEX = r"\b[0-9a-zA-Z_]{1,32}\b" + _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] + _VALID_DATAFRAME_TYPES = ["spark", "pandas"] database: str catalog: str table: str @@ -38,14 +39,13 @@ class ManagedTable: # pylint: disable=R0902 dataframe_type: str primary_key: str owner_group: str - partition_columns: str | List[str] + partition_columns: Union[str, List[str]] json_schema: StructType def __post_init__(self): """Run validation methods if declared. The validation method can be a simple check - that raises ValueError or a transformation to - the field value. + that raises DataSetError. The validation is performed by calling a function named: `validate_(self, value) -> raises DataSetError` """ @@ -59,10 +59,8 @@ def validate_table(self): Raises: DataSetError: """ - if not re.fullmatch(NAMING_REGEX, self.table): - raise DataSetError( - "table does not conform to naming and is a required field" - ) + if not re.fullmatch(self._NAMING_REGEX, self.table): + raise DataSetError("table does not conform to naming") def validate_database(self): """validates database name @@ -71,7 +69,7 @@ def validate_database(self): DataSetError: """ if self.database: - if not re.fullmatch(NAMING_REGEX, self.database): + if not re.fullmatch(self._NAMING_REGEX, self.database): raise DataSetError("database does not conform to naming") def validate_catalog(self): @@ -81,7 +79,7 @@ def validate_catalog(self): DataSetError: """ if self.catalog: - if not re.fullmatch(NAMING_REGEX, self.catalog): + if not re.fullmatch(self._NAMING_REGEX, self.catalog): raise DataSetError("catalog does not conform to naming") def validate_write_mode(self): @@ -90,8 +88,8 @@ def validate_write_mode(self): Raises: DataSetError: """ - if self.write_mode not in _VALID_WRITE_MODES: - valid_modes = ", ".join(_VALID_WRITE_MODES) + if self.write_mode not in self._VALID_WRITE_MODES: + valid_modes = ", ".join(self._VALID_WRITE_MODES) raise DataSetError( f"Invalid `write_mode` provided: {self.write_mode}. 
" f"`write_mode` must be one of: {valid_modes}" @@ -103,8 +101,8 @@ def validate_dataframe_type(self): Raises: DataSetError: """ - if self.dataframe_type not in _VALID_DATAFRAME_TYPES: - valid_types = ", ".join(_VALID_DATAFRAME_TYPES) + if self.dataframe_type not in self._VALID_DATAFRAME_TYPES: + valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) raise DataSetError(f"`dataframe_type` must be one of {valid_types}") def validate_primary_key(self): @@ -140,8 +138,11 @@ def schema(self) -> StructType: StructType: """ schema = None - if self.json_schema is not None: - schema = StructType.fromJson(self.json_schema) + try: + if self.json_schema is not None: + schema = StructType.fromJson(self.json_schema) + except ParseException as exc: + raise DataSetError(exc) from exc return schema From 4022f0d8ffd4cd9a91c47fe26846e64eec35196c Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Tue, 21 Mar 2023 14:53:51 -0400 Subject: [PATCH 11/40] updated docstring for ManagedTableDataSet Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../databricks/managed_table_dataset.py | 74 ++++++++++++++----- 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index ee82253e3..fd5cd5e03 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -148,10 +148,18 @@ def schema(self) -> StructType: class ManagedTableDataSet(AbstractVersionedDataSet): """``ManagedTableDataSet`` loads and saves data into managed delta tables on Databricks. + Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. + When saving data, you can specify one of three modes: overwtire(default), append, + or upsert. Upsert requires you to specify the primary_column parameter which + will be used as part of the join condition. This dataset works best with + the databricks kedro starter. That starter comes with hooks that allow this + dataset to function properly. Follow the instructions in that starter to + setup your project for this dataset. Example usage for the `YAML API `_: + .. 
code-block:: yaml names_and_ages@spark: @@ -167,23 +175,24 @@ class ManagedTableDataSet(AbstractVersionedDataSet): `Python API `_: :: - Launch a pyspark session with the following configs: - % pyspark --packages io.delta:delta-core_2.12:1.2.1 - --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" - --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog" - - >>> from pyspark.sql import SparkSession - >>> from pyspark.sql.types import (StructField, StringType, - IntegerType, StructType) - >>> from kedro_datasets.databricks import ManagedTableDataSet - >>> schema = StructType([StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) - >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] - >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) - >>> data_set = ManagedTableDataSet(table="names_and_ages") - >>> data_set.save(spark_df) - >>> reloaded = data_set.load() - >>> reloaded.take(4)""" + + % pyspark --packages io.delta:delta-core_2.12:1.2.1 + --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" + --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog" + + >>> from pyspark.sql import SparkSession + >>> from pyspark.sql.types import (StructField, StringType, + IntegerType, StructType) + >>> from kedro_datasets.databricks import ManagedTableDataSet + >>> schema = StructType([StructField("name", StringType(), True), + StructField("age", IntegerType(), True)]) + >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] + >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) + >>> data_set = ManagedTableDataSet(table="names_and_ages") + >>> data_set.save(spark_df) + >>> reloaded = data_set.load() + >>> reloaded.take(4) + """ # this dataset cannot be used with ``ParallelRunner``, # therefore it has the attribute ``_SINGLE_PROCESS = True`` @@ -207,7 +216,36 @@ def __init__( # pylint: disable=R0913 partition_columns: List[str] = None, owner_group: str = None, ) -> None: - """Creates a new instance of ``ManagedTableDataSet``.""" + """Creates a new instance of ``ManagedTableDataSet`` + + Args: + table (str): the name of the table + catalog (str, optional): the name of the catalog in Unity. + Defaults to None. + database (str, optional): the name of the database + (also referred to as schema). Defaults to "default". + write_mode (str, optional): the mode to write the data into the table. + Options are:["overwrite", "append", "upsert"]. + "upsert" mode requires primary_key field to be populated. + Defaults to "overwrite". + dataframe_type (str, optional): "pandas" or "spark" dataframe. + Defaults to "spark". + primary_key (Union[str, List[str]], optional): the primary key of the table. + Can be in the form of a list. Defaults to None. + version (Version, optional): kedro.io.core.Version instance to load the data. + Defaults to None. + schema (Dict[str, Any], optional): the schema of the table in JSON form. + Dataframes will be truncated to match the schema if provided. + Used by the hooks to create the table if the schema is provided + Defaults to None. + partition_columns (List[str], optional): the columns to use for partitioning the table. + Used by the hooks. Defaults to None. + owner_group (str, optional): if table access control is enabled in your workspace, + specifying owner_group will transfer ownership of the table and database to + this owner. 
All databases should have the same owner_group. Defaults to None. + Raises: + DataSetError: Invalid configuration supplied (through ManagedTable validation) + """ self._table = ManagedTable( database=database, From f6531e12a2bfe06389c4ae30314651a20dc67a0f Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 5 Apr 2023 14:56:25 -0400 Subject: [PATCH 12/40] added backticks to catalog Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 4 ++-- .../tests/databricks/test_managed_table_dataset.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index fd5cd5e03..015daff93 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -126,9 +126,9 @@ def full_table_location(self) -> str: """ full_table_location = None if self.catalog and self.database and self.table: - full_table_location = f"{self.catalog}.{self.database}.{self.table}" + full_table_location = f"`{self.catalog}`.`{self.database}`.`{self.table}`" elif self.table: - full_table_location = f"{self.database}.{self.table}" + full_table_location = f"`{self.database}`.`{self.table}`" return full_table_location def schema(self) -> StructType: diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index 4520042ab..000aa8d6e 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -172,13 +172,13 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): class TestManagedTableDataSet: def test_full_table(self): unity_ds = ManagedTableDataSet(catalog="test", database="test", table="test") - assert unity_ds._table.full_table_location() == "test.test.test" + assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" unity_ds = ManagedTableDataSet(database="test", table="test") - assert unity_ds._table.full_table_location() == "test.test" + assert unity_ds._table.full_table_location() == "`test`.`test`" unity_ds = ManagedTableDataSet(table="test") - assert unity_ds._table.full_table_location() == "default.test" + assert unity_ds._table.full_table_location() == "`default`.`test`" with pytest.raises(TypeError): ManagedTableDataSet() # pylint: disable=no-value-for-parameter From 3ed18a18d9fd7c268e69fde4501f8ac49671a0a8 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Tue, 11 Apr 2023 09:30:47 -0400 Subject: [PATCH 13/40] fixing regex to allow hyphens Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 015daff93..41fd7d2e5 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -29,7 +29,7 @@ class ManagedTable: # pylint: disable=R0902 """Stores the definition of a managed table""" # regex for tables, catalogs and schemas - _NAMING_REGEX = r"\b[0-9a-zA-Z_]{1,32}\b" + _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,32}\b" _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] 
_VALID_DATAFRAME_TYPES = ["spark", "pandas"] database: str From a149b4dfca67636df40c914f6e00123c932fbc3b Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 10:59:33 -0700 Subject: [PATCH 14/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 41fd7d2e5..9fc2f3f30 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -66,7 +66,7 @@ def validate_database(self): """validates database name Raises: - DataSetError: + DataSetError: If the table name does not conform to naming constraints. """ if self.database: if not re.fullmatch(self._NAMING_REGEX, self.database): From c854a647347afda638f9c4adf94201ed2420f3db Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 10:59:42 -0700 Subject: [PATCH 15/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 9fc2f3f30..6e895ba51 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -86,7 +86,7 @@ def validate_write_mode(self): """validates the write mode Raises: - DataSetError: + DataSetError: If an invalid `write_mode` is passed. """ if self.write_mode not in self._VALID_WRITE_MODES: valid_modes = ", ".join(self._VALID_WRITE_MODES) From c994c3f16a06fe825aa8e2081d9dc8f13ed55d44 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 10:59:56 -0700 Subject: [PATCH 16/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 6e895ba51..25fe150ba 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -109,7 +109,7 @@ def validate_primary_key(self): """validates the primary key of the table Raises: - DataSetError: + DataSetError: If no `primary_key` is specified. 
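The relaxed naming rule from the hyphen fix a few patches above can be checked directly against the pattern; the names below are made-up examples, not values taken from the patch:

    import re

    NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,32}\b"  # pattern as of the hyphen fix

    for name in ["test", "test-test", "my_table_01", "invalid!"]:
        ok = re.fullmatch(NAMING_REGEX, name) is not None
        print(f"{name!r}: {'valid' if ok else 'rejected'}")
    # 'invalid!' is rejected because '!' falls outside the allowed character class.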
""" if self.primary_key is None or len(self.primary_key) == 0: if self.write_mode == "upsert": From 09bf84735d1b72c57d8944fc886e63c3584870e5 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:01:50 -0700 Subject: [PATCH 17/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 25fe150ba..5e8214a8f 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -76,7 +76,7 @@ def validate_catalog(self): """validates catalog name Raises: - DataSetError: + DataSetError: If the catalog name does not conform to naming constraints. """ if self.catalog: if not re.fullmatch(self._NAMING_REGEX, self.catalog): From b7e8cff5b90d1b1ad18ecf071b2d0b72c74a24e7 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:01:59 -0700 Subject: [PATCH 18/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 5e8214a8f..790df3def 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -99,7 +99,7 @@ def validate_dataframe_type(self): """validates the dataframe type Raises: - DataSetError: + DataSetError: If an invalid `dataframe_type` is passed """ if self.dataframe_type not in self._VALID_DATAFRAME_TYPES: valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) From e7b8e4028ad70de2d7186eeec5939fe7333c10c5 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:02:09 -0700 Subject: [PATCH 19/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 790df3def..08f62f769 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -222,7 +222,7 @@ def __init__( # pylint: disable=R0913 table (str): the name of the table catalog (str, optional): the name of the catalog in Unity. Defaults to None. - database (str, optional): the name of the database + database (str, optional): the name of the database. (also referred to as schema). Defaults to "default". write_mode (str, optional): the mode to write the data into the table. Options are:["overwrite", "append", "upsert"]. 
From 31a0c73da9a3d9f2eb973351dc70a107533624a1 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:02:17 -0700 Subject: [PATCH 20/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 08f62f769..4e2da3cf6 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -149,7 +149,7 @@ def schema(self) -> StructType: class ManagedTableDataSet(AbstractVersionedDataSet): """``ManagedTableDataSet`` loads and saves data into managed delta tables on Databricks. Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. - When saving data, you can specify one of three modes: overwtire(default), append, + When saving data, you can specify one of three modes: overwrite(default), append, or upsert. Upsert requires you to specify the primary_column parameter which will be used as part of the join condition. This dataset works best with the databricks kedro starter. That starter comes with hooks that allow this From 83704b4e2cb0869bd1d59a7c7a8023f9d31e7bd0 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:11:51 -0700 Subject: [PATCH 21/40] Update kedro-datasets/test_requirements.txt Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- kedro-datasets/test_requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/test_requirements.txt b/kedro-datasets/test_requirements.txt index 90faa0b02..d2231136b 100644 --- a/kedro-datasets/test_requirements.txt +++ b/kedro-datasets/test_requirements.txt @@ -24,7 +24,8 @@ lxml~=4.6 matplotlib>=3.0.3, <3.4; python_version < '3.10' # 3.4.0 breaks holoviews matplotlib>=3.5, <3.6; python_version == '3.10' memory_profiler>=0.50.0, <1.0 -mlflow==2.2.1 +mlflow~=2.2.1; python_version>='3.8' +mlflow~=1.30.0; python_version=='3.7' moto==1.3.7; python_version < '3.10' moto==3.0.4; python_version == '3.10' networkx~=2.4 From 1bf1e29984ed34f87f2796bfbad9845f503196a7 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:12:03 -0700 Subject: [PATCH 22/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 4e2da3cf6..d26c0b189 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -207,7 +207,7 @@ def __init__( # pylint: disable=R0913 database: str = "default", write_mode: str = "overwrite", dataframe_type: str = "spark", - primary_key: Union[str, List[str]] = None, + primary_key: Optional[Union[str, List[str]]] = None, version: Version = None, *, # the following parameters are used by project hooks From 
9e391ee87522c8b294be2ae239dd68315cb358a2 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:12:16 -0700 Subject: [PATCH 23/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index d26c0b189..e562650e9 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -429,7 +429,7 @@ def _describe(self) -> Dict[str, str]: "write_mode": self._table.write_mode, "dataframe_type": self._table.dataframe_type, "primary_key": self._table.primary_key, - "version": self._version, + "version": str(self._version), "owner_group": self._table.owner_group, "partition_columns": self._table.partition_columns, } From 651e37956b8bf54a78fe1c2282a23f686ffc0097 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:16:39 -0700 Subject: [PATCH 24/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../databricks/managed_table_dataset.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index e562650e9..6266bfc35 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -174,24 +174,20 @@ class ManagedTableDataSet(AbstractVersionedDataSet): Example usage for the `Python API `_: - :: + .. 
code-block:: python - % pyspark --packages io.delta:delta-core_2.12:1.2.1 - --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" - --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog" - - >>> from pyspark.sql import SparkSession - >>> from pyspark.sql.types import (StructField, StringType, + from pyspark.sql import SparkSession + from pyspark.sql.types import (StructField, StringType, IntegerType, StructType) - >>> from kedro_datasets.databricks import ManagedTableDataSet - >>> schema = StructType([StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) - >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] - >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) - >>> data_set = ManagedTableDataSet(table="names_and_ages") - >>> data_set.save(spark_df) - >>> reloaded = data_set.load() - >>> reloaded.take(4) + from kedro_datasets.databricks import ManagedTableDataSet + schema = StructType([StructField("name", StringType(), True), + StructField("age", IntegerType(), True)]) + data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] + spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) + data_set = ManagedTableDataSet(table="names_and_ages") + data_set.save(spark_df) + reloaded = data_set.load() + reloaded.take(4) """ # this dataset cannot be used with ``ParallelRunner``, From b2676161fe890f12d233b3279dc99e718703bf63 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:28:10 -0700 Subject: [PATCH 25/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Jannic <37243923+jmholzer@users.noreply.github.com> Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 6266bfc35..7d0490d45 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -37,7 +37,7 @@ class ManagedTable: # pylint: disable=R0902 table: str write_mode: str dataframe_type: str - primary_key: str + primary_key: Optional[str] owner_group: str partition_columns: Union[str, List[str]] json_schema: StructType From f8f9786e73ad1b90f52a0ef5d32d0b31dde3b7a5 Mon Sep 17 00:00:00 2001 From: Danny Farah Date: Wed, 3 May 2023 11:29:37 -0700 Subject: [PATCH 26/40] adding backticks to catalog Signed-off-by: Danny Farah Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 4 ++-- .../tests/databricks/test_managed_table_dataset.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 7d0490d45..5b9e83e57 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from functools import partial from operator import attrgetter -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import pandas as pd from cachetools import Cache, cachedmethod @@ -439,7 +439,7 @@ def _exists(self) -> bool: """ if self._table.catalog: 
try: - self._get_spark().sql(f"USE CATALOG {self._table.catalog}") + self._get_spark().sql(f"USE CATALOG `{self._table.catalog}`") except (ParseException, AnalysisException) as exc: logger.warning( "catalog %s not found or unity not enabled. Error message: %s", diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index 000aa8d6e..fbdbaaebc 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -174,6 +174,11 @@ def test_full_table(self): unity_ds = ManagedTableDataSet(catalog="test", database="test", table="test") assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" + unity_ds = ManagedTableDataSet( + catalog="test-test", database="test", table="test" + ) + assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" + unity_ds = ManagedTableDataSet(database="test", table="test") assert unity_ds._table.full_table_location() == "`test`.`test`" @@ -192,7 +197,7 @@ def test_describe(self): "write_mode": "overwrite", "dataframe_type": "spark", "primary_key": None, - "version": None, + "version": "None", "owner_group": None, "partition_columns": None, } From 57248ea7f7086b3a96cf95306be90f9e5d0cdcbe Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Thu, 4 May 2023 14:49:19 +0100 Subject: [PATCH 27/40] Require pandas < 2.0 for compatibility with spark < 3.4 Signed-off-by: Jannic Holzer --- kedro-datasets/.gitignore | 2 +- kedro-datasets/test_requirements.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kedro-datasets/.gitignore b/kedro-datasets/.gitignore index 3725bd847..721e13f70 100644 --- a/kedro-datasets/.gitignore +++ b/kedro-datasets/.gitignore @@ -147,4 +147,4 @@ docs/tmp-build-artifacts docs/build spark-warehouse metastore_db/ -derby.log \ No newline at end of file +derby.log diff --git a/kedro-datasets/test_requirements.txt b/kedro-datasets/test_requirements.txt index d2231136b..a35bfdf44 100644 --- a/kedro-datasets/test_requirements.txt +++ b/kedro-datasets/test_requirements.txt @@ -24,15 +24,15 @@ lxml~=4.6 matplotlib>=3.0.3, <3.4; python_version < '3.10' # 3.4.0 breaks holoviews matplotlib>=3.5, <3.6; python_version == '3.10' memory_profiler>=0.50.0, <1.0 -mlflow~=2.2.1; python_version>='3.8' mlflow~=1.30.0; python_version=='3.7' +mlflow~=2.2.1; python_version>='3.8' moto==1.3.7; python_version < '3.10' moto==3.0.4; python_version == '3.10' networkx~=2.4 opencv-python~=4.5.5.64 openpyxl>=3.0.3, <4.0 pandas-gbq>=0.12.0, <0.18.0 -pandas>=1.3 # 1.3 for read_xml/to_xml +pandas>=1.3, <2 # 1.3 for read_xml/to_xml, <2 for compatibility with Spark < 3.4 Pillow~=9.0 plotly>=4.8.0, <6.0 polars~=0.15.13 From 944009af752335ce07364669637d616b02e01b44 Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Thu, 4 May 2023 15:24:40 +0100 Subject: [PATCH 28/40] Replace use of walrus operator Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 5b9e83e57..68f14ebc9 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -50,7 +50,8 @@ def __post_init__(self): `validate_(self, value) -> raises DataSetError` """ for name, _ in 
self.__dataclass_fields__.items(): # pylint: disable=E1101 - if method := getattr(self, f"validate_{name}", None): + method = getattr(self, f"validate_{name}", None) + if method: method() def validate_table(self): From 25a293e824176400dc938a081dc5549b1c6aed58 Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Fri, 5 May 2023 00:26:11 +0100 Subject: [PATCH 29/40] Add test coverage for validation methods Signed-off-by: Jannic Holzer --- .../tests/databricks/test_managed_table_dataset.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index fbdbaaebc..8a71a4a0f 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -214,6 +214,18 @@ def test_missing_primary_key_upsert(self): with pytest.raises(DataSetError): ManagedTableDataSet(table="test", write_mode="upsert") + def test_invalid_table_name(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="invalid!") + + def test_invalid_database(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="test", database="invalid!") + + def test_invalid_catalog(self): + with pytest.raises(DataSetError): + ManagedTableDataSet(table="test", catalog="invalid!") + def test_schema(self): unity_ds = ManagedTableDataSet( table="test", From 3d6b682db7df90f39ce3c75f81e3f8fc224b4fed Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Fri, 5 May 2023 00:27:27 +0100 Subject: [PATCH 30/40] Remove unused versioning functions Signed-off-by: Jannic Holzer --- .../databricks/managed_table_dataset.py | 29 +------------------ 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 68f14ebc9..8cd0d803b 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -4,13 +4,9 @@ import logging import re from dataclasses import dataclass -from functools import partial -from operator import attrgetter from typing import Any, Dict, List, Optional, Union import pandas as pd -from cachetools import Cache, cachedmethod -from cachetools.keys import hashkey from kedro.io.core import ( AbstractVersionedDataSet, DataSetError, @@ -256,7 +252,6 @@ def __init__( # pylint: disable=R0913 json_schema=schema, ) - self._version_cache = Cache(maxsize=2) self._version = version super().__init__( @@ -265,28 +260,6 @@ def __init__( # pylint: disable=R0913 exists_function=self._exists, ) - @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "load")) - def _fetch_latest_load_version(self) -> int: - # When load version is unpinned, fetch the most recent existing - # version from the given path. 
- latest_history = ( - self._get_spark() - .sql(f"DESCRIBE HISTORY {self._table.full_table_location()} LIMIT 1") - .collect() - ) - if len(latest_history) != 1: - raise VersionNotFoundError( - f"Did not find any versions for {self._table.full_table_location()}" - ) - return latest_history[0].version - - # 'key' is set to prevent cache key overlapping for load and save: - # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod - @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "save")) - def _fetch_latest_save_version(self) -> int: - """Generate and cache the current save version""" - return None - @staticmethod def _get_spark() -> SparkSession: return SparkSession.builder.getOrCreate() @@ -312,7 +285,7 @@ def _load(self) -> Union[DataFrame, pd.DataFrame]: .table(self._table.full_table_location()) ) except Exception as exc: - raise VersionNotFoundError(self._version) from exc + raise VersionNotFoundError(self._version.load) from exc else: data = self._get_spark().table(self._table.full_table_location()) if self._table.dataframe_type == "pandas": From b37a19837ff2e124bd2e4b9abe9119818b43d081 Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Fri, 5 May 2023 01:13:43 +0100 Subject: [PATCH 31/40] Fix exception catching for invalid schema, add test for invalid schema Signed-off-by: Jannic Holzer --- .../databricks/managed_table_dataset.py | 2 +- .../tests/databricks/test_managed_table_dataset.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 8cd0d803b..1fb2729b7 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -138,7 +138,7 @@ def schema(self) -> StructType: try: if self.json_schema is not None: schema = StructType.fromJson(self.json_schema) - except ParseException as exc: + except (KeyError, ValueError) as exc: raise DataSetError(exc) from exc return schema diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index 8a71a4a0f..7f015c6a2 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -255,6 +255,20 @@ def test_schema(self): ) assert unity_ds._table.schema() == expected_schema + def test_invalid_schema(self): + with pytest.raises(DataSetError): + ManagedTableDataSet( + table="test", + schema={ + "fields": [ + { + "invalid": "schema", + } + ], + "type": "struct", + }, + )._table.schema() + def test_catalog_exists(self): unity_ds = ManagedTableDataSet( catalog="test", database="invalid", table="test_not_there" From 952cf3de0f3b0e098dfe910567659037b7cbab70 Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Fri, 5 May 2023 01:29:37 +0100 Subject: [PATCH 32/40] Add pylint ignore Signed-off-by: Jannic Holzer --- kedro-datasets/tests/databricks/test_managed_table_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index 7f015c6a2..9aae08707 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -169,6 +169,7 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): 
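As a usage aside to the version-handling changes in this patch: pinning a historical Delta snapshot goes through `kedro.io.core.Version`, exactly as the tests in this series do. A small sketch, assuming a managed table that has already been written so that Delta version 0 exists (the table name is illustrative):

    from kedro.io.core import Version
    from kedro_datasets.databricks import ManagedTableDataSet

    # Version(load, save): only the load side is pinned here.
    ds_latest = ManagedTableDataSet(database="test", table="names_and_ages")
    ds_v0 = ManagedTableDataSet(
        database="test", table="names_and_ages", version=Version(0, None)
    )

    latest_df = ds_latest.load()  # current state of the managed table
    first_df = ds_v0.load()       # reads with .option("versionAsOf", 0)

If the requested version does not exist, `_load` raises `VersionNotFoundError`, as exercised by `test_load_spark_no_version`.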
return spark_session.createDataFrame(data, schema) +# pylint: disable=too-many-public-methods class TestManagedTableDataSet: def test_full_table(self): unity_ds = ManagedTableDataSet(catalog="test", database="test", table="test") From 743816ea818c6992d531c924f111f732e4053883 Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Fri, 12 May 2023 14:10:09 +0100 Subject: [PATCH 33/40] Add tests/databricks to ignore for no-spark tests Signed-off-by: Jannic Holzer --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index be653ed59..4e0b4e640 100644 --- a/Makefile +++ b/Makefile @@ -52,10 +52,10 @@ sign-off: # kedro-datasets related only test-no-spark: - cd kedro-datasets && pytest tests --no-cov --ignore tests/spark --numprocesses 4 --dist loadfile + cd kedro-datasets && pytest tests --no-cov --ignore tests/spark --ignore tests/databricks --numprocesses 4 --dist loadfile test-no-spark-sequential: - cd kedro-datasets && pytest tests --no-cov --ignore tests/spark + cd kedro-datasets && pytest tests --no-cov --ignore tests/spark --ignore tests/databricks # kedro-datasets/snowflake tests skipped from default scope test-snowflake-only: From 0a160a5ce9fab152e0dfb405af39d6249e36453a Mon Sep 17 00:00:00 2001 From: Jannic <37243923+jmholzer@users.noreply.github.com> Date: Wed, 17 May 2023 17:47:57 +0100 Subject: [PATCH 34/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Nok Lam Chan --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 1fb2729b7..2dc1abe8b 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -21,7 +21,7 @@ @dataclass(frozen=True) -class ManagedTable: # pylint: disable=R0902 +class ManagedTable: # pylint: disable=too-many-instance-attributes """Stores the definition of a managed table""" # regex for tables, catalogs and schemas From daf5411578ce5791022e62227873b0610b76f651 Mon Sep 17 00:00:00 2001 From: Jannic <37243923+jmholzer@users.noreply.github.com> Date: Wed, 17 May 2023 17:48:10 +0100 Subject: [PATCH 35/40] Update kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py Co-authored-by: Nok Lam Chan --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 2dc1abe8b..c7d44b1e8 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -45,7 +45,7 @@ def __post_init__(self): The validation is performed by calling a function named: `validate_(self, value) -> raises DataSetError` """ - for name, _ in self.__dataclass_fields__.items(): # pylint: disable=E1101 + for name in self.__dataclass_fields__.keys(): # pylint: disable=no-member method = getattr(self, f"validate_{name}", None) if method: method() From c1e78cd8d0859601c46e2c1638d95657ba2c34bf Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Thu, 18 May 2023 22:04:22 +0100 Subject: [PATCH 36/40] Remove spurious mlflow test dependency Signed-off-by: Jannic Holzer --- kedro-datasets/test_requirements.txt | 4 +--- 1 file changed, 1 
insertion(+), 3 deletions(-) diff --git a/kedro-datasets/test_requirements.txt b/kedro-datasets/test_requirements.txt index a35bfdf44..fe20fee5f 100644 --- a/kedro-datasets/test_requirements.txt +++ b/kedro-datasets/test_requirements.txt @@ -24,15 +24,13 @@ lxml~=4.6 matplotlib>=3.0.3, <3.4; python_version < '3.10' # 3.4.0 breaks holoviews matplotlib>=3.5, <3.6; python_version == '3.10' memory_profiler>=0.50.0, <1.0 -mlflow~=1.30.0; python_version=='3.7' -mlflow~=2.2.1; python_version>='3.8' moto==1.3.7; python_version < '3.10' moto==3.0.4; python_version == '3.10' networkx~=2.4 opencv-python~=4.5.5.64 openpyxl>=3.0.3, <4.0 pandas-gbq>=0.12.0, <0.18.0 -pandas>=1.3, <2 # 1.3 for read_xml/to_xml, <2 for compatibility with Spark < 3.4 +pandas>=1.3, <2 # 1.3 for read_xml/to_xml, <2 for compatibility with Spark < 3.4 Pillow~=9.0 plotly>=4.8.0, <6.0 polars~=0.15.13 From dbfd64100728fcf14a5bcfbde98d2fb636e28c93 Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Mon, 22 May 2023 13:49:08 +0100 Subject: [PATCH 37/40] Add explicit check for database existence Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index c7d44b1e8..cca8639c7 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -124,7 +124,7 @@ def full_table_location(self) -> str: full_table_location = None if self.catalog and self.database and self.table: full_table_location = f"`{self.catalog}`.`{self.database}`.`{self.table}`" - elif self.table: + elif self.database and self.table: full_table_location = f"`{self.database}`.`{self.table}`" return full_table_location From 5ea0d66e5c93e38eddffd88aad30badc5458067e Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Mon, 22 May 2023 13:51:05 +0100 Subject: [PATCH 38/40] Remove character limit for table names Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index cca8639c7..da6e00d8e 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -25,7 +25,7 @@ class ManagedTable: # pylint: disable=too-many-instance-attributes """Stores the definition of a managed table""" # regex for tables, catalogs and schemas - _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,32}\b" + _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,}\b" _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] _VALID_DATAFRAME_TYPES = ["spark", "pandas"] database: str From c2fd478270f77052bfffa564582db0673efa989a Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Mon, 22 May 2023 15:37:30 +0100 Subject: [PATCH 39/40] Refactor validation steps in ManagedTable Signed-off-by: Jannic Holzer --- .../databricks/managed_table_dataset.py | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index da6e00d8e..a24303a70 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ 
b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -29,7 +29,7 @@ class ManagedTable: # pylint: disable=too-many-instance-attributes _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] _VALID_DATAFRAME_TYPES = ["spark", "pandas"] database: str - catalog: str + catalog: Optional[str] table: str write_mode: str dataframe_type: str @@ -46,31 +46,36 @@ def __post_init__(self): `validate_(self, value) -> raises DataSetError` """ for name in self.__dataclass_fields__.keys(): # pylint: disable=no-member - method = getattr(self, f"validate_{name}", None) + method = getattr(self, f"_validate_{name}", None) if method: method() - def validate_table(self): - """validates table name + def _validate_table(self): + """Validates table name Raises: - DataSetError: + DataSetError: If the table name does not conform to naming constraints. """ + if not self.table: + raise DataSetError("table name must be provided") + if not re.fullmatch(self._NAMING_REGEX, self.table): raise DataSetError("table does not conform to naming") - def validate_database(self): - """validates database name + def _validate_database(self): + """Validates database name Raises: - DataSetError: If the table name does not conform to naming constraints. + DataSetError: If the dataset name does not conform to naming constraints. """ - if self.database: - if not re.fullmatch(self._NAMING_REGEX, self.database): - raise DataSetError("database does not conform to naming") + if not self.database: + raise DataSetError("database name must be provided") + + if not re.fullmatch(self._NAMING_REGEX, self.database): + raise DataSetError("database does not conform to naming") - def validate_catalog(self): - """validates catalog name + def _validate_catalog(self): + """Validates catalog name Raises: DataSetError: If the catalog name does not conform to naming constraints. @@ -79,8 +84,8 @@ def validate_catalog(self): if not re.fullmatch(self._NAMING_REGEX, self.catalog): raise DataSetError("catalog does not conform to naming") - def validate_write_mode(self): - """validates the write mode + def _validate_write_mode(self): + """Validates the write mode Raises: DataSetError: If an invalid `write_mode` is passed. @@ -92,8 +97,8 @@ def validate_write_mode(self): f"`write_mode` must be one of: {valid_modes}" ) - def validate_dataframe_type(self): - """validates the dataframe type + def _validate_dataframe_type(self): + """Validates the dataframe type Raises: DataSetError: If an invalid `dataframe_type` is passed @@ -102,8 +107,8 @@ def validate_dataframe_type(self): valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) raise DataSetError(f"`dataframe_type` must be one of {valid_types}") - def validate_primary_key(self): - """validates the primary key of the table + def _validate_primary_key(self): + """Validates the primary key of the table Raises: DataSetError: If no `primary_key` is specified. 
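Note: the refactor in PATCH 39 above is easier to follow with the per-field validation dispatch spelled out in one place. The sketch below is a minimal, self-contained approximation, not the code shipped in managed_table_dataset.py: the class and its name are hypothetical, only three fields are kept, and DataSetError stands in for kedro.io.core.DataSetError.

import re
from dataclasses import dataclass
from typing import Optional


class DataSetError(Exception):
    """Stand-in for kedro.io.core.DataSetError."""


@dataclass(frozen=True)
class ManagedTableSketch:
    """Hypothetical, trimmed-down version of the ManagedTable dataclass."""

    # naming pattern as of PATCH 38 (the 32-character cap was removed)
    _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,}\b"

    database: str
    table: str
    catalog: Optional[str] = None

    def __post_init__(self):
        # Call `_validate_<field>` for every declared dataclass field;
        # fields without a matching validator are simply skipped.
        for name in self.__dataclass_fields__:
            method = getattr(self, f"_validate_{name}", None)
            if method:
                method()

    def _validate_database(self):
        if not re.fullmatch(self._NAMING_REGEX, self.database):
            raise DataSetError("database does not conform to naming")

    def _validate_table(self):
        if not re.fullmatch(self._NAMING_REGEX, self.table):
            raise DataSetError("table does not conform to naming")

    def _validate_catalog(self):
        if self.catalog and not re.fullmatch(self._NAMING_REGEX, self.catalog):
            raise DataSetError("catalog does not conform to naming")


# ManagedTableSketch(database="default", table="shuttles") constructs cleanly,
# while ManagedTableSketch(database="default", table="bad.name") raises DataSetError.

Keeping the checks in __post_init__ means every instance is validated at construction time; PATCH 40 below then drops the explicit presence checks for table and database as spurious.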
From 7e52e9c5be9433b44dd387c4a6885c96c8898425 Mon Sep 17 00:00:00 2001 From: Jannic Holzer Date: Mon, 22 May 2023 15:58:13 +0100 Subject: [PATCH 40/40] Remove spurious checks for table and schema name existence Signed-off-by: Jannic Holzer --- .../kedro_datasets/databricks/managed_table_dataset.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index a24303a70..01ec15a6f 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -56,9 +56,6 @@ def _validate_table(self): Raises: DataSetError: If the table name does not conform to naming constraints. """ - if not self.table: - raise DataSetError("table name must be provided") - if not re.fullmatch(self._NAMING_REGEX, self.table): raise DataSetError("table does not conform to naming") @@ -68,9 +65,6 @@ def _validate_database(self): Raises: DataSetError: If the dataset name does not conform to naming constraints. """ - if not self.database: - raise DataSetError("database name must be provided") - if not re.fullmatch(self._NAMING_REGEX, self.database): raise DataSetError("database does not conform to naming")
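As a closing illustration, the two behavioural changes from PATCH 37 and PATCH 38 can be exercised with a standalone snippet. This is a sketch rather than anything taken from the repository's test suite; the helper function below is hypothetical and only mirrors the resolution logic of the full_table_location property.

import re

# Relaxed naming pattern from PATCH 38: the 32-character cap is gone,
# but the allowed character set is unchanged.
_NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,}\b"

assert re.fullmatch(_NAMING_REGEX, "a_table_name_well_over_thirty_two_characters")
assert re.fullmatch(_NAMING_REGEX, "my.table") is None  # dots are still rejected


def full_table_location(catalog, database, table):
    """Sketch of the resolution order from PATCH 37: a three-part location
    needs catalog, database and table; a two-part location now requires an
    explicit database as well as a table."""
    if catalog and database and table:
        return f"`{catalog}`.`{database}`.`{table}`"
    if database and table:
        return f"`{database}`.`{table}`"
    return None


assert full_table_location("dev", "default", "shuttles") == "`dev`.`default`.`shuttles`"
assert full_table_location(None, "default", "shuttles") == "`default`.`shuttles`"
assert full_table_location(None, None, "shuttles") is None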