diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py
index cba69d17c..7819a2e06 100644
--- a/kedro-datasets/kedro_datasets/databricks/__init__.py
+++ b/kedro-datasets/kedro_datasets/databricks/__init__.py
@@ -1,4 +1,3 @@
 """Provides interface to Unity Catalog Tables."""
 from .unity import ManagedTableDataSet
-from .mlflow import MLFlowModel, MLFlowArtifact, MLFlowDataSet, MLFlowMetrics, MLFlowModelMetadata, MLFlowTags
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py b/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py
deleted file mode 100644
index 1c3babc0f..000000000
--- a/kedro-datasets/kedro_datasets/databricks/mlflow/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .artifact import MLFlowArtifact
-from .dataset import MLFlowDataSet
-from .metrics import MLFlowMetrics
-from .model_metadata import MLFlowModelMetadata
-from .tags import MLFlowTags
-from .model import MLFlowModel
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py b/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py
deleted file mode 100644
index 15691db43..000000000
--- a/kedro-datasets/kedro_datasets/databricks/mlflow/artifact.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import logging
-import os
-from pathlib import Path
-from tempfile import mkdtemp
-from typing import Any, Dict
-
-import mlflow
-from kedro.io.core import AbstractDataSet
-from kedro.utils import load_obj as load_dataset
-from mlflow.exceptions import MlflowException
-from mlflow.tracking.artifact_utils import _download_artifact_from_uri
-
-from .common import MLFLOW_RUN_ID_ENV_VAR, ModelOpsException
-
-logger = logging.getLogger(__name__)
-
-
-class MLFlowArtifact(AbstractDataSet):
-    def __init__(
-        self,
-        dataset_name: str,
-        dataset_type: str,
-        dataset_args: Dict[str, Any] = None,
-        *,
-        file_suffix: str,
-        run_id: str = None,
-        registered_model_name: str = None,
-        registered_model_version: str = None,
-    ):
-        """
-        Log arbitrary Kedro datasets as mlflow artifacts
-
-        Args:
-            dataset_name: dataset name as it should appear on mlflow run
-            dataset_type: full kedro dataset class name (incl. module)
-            dataset_args: kedro dataset args
-            file_suffix: file extension as it should appear on mlflow run
-            run_id: mlflow run-id, this should only be used when loading a
-                dataset saved from run which is different from active run
-            registered_model_name: mlflow registered model name, this should
-                only be used when loading an artifact linked to a model of
-                interest (i.e. back tracing atifacts from the run corresponding
-                to the model)
-            registered_model_version: mlflow registered model name, should be
-                used in combination with `registered_model_name`
-
-        `run_id` and `registered_model_name` can't be specified together.
- """ - if None in (registered_model_name, registered_model_version): - if registered_model_name or registered_model_version: - raise ModelOpsException( - "'registered_model_name' and " - "'registered_model_version' should be " - "set together" - ) - - if run_id and registered_model_name: - raise ModelOpsException( - "'run_id' cannot be passed when " "'registered_model_name' is set" - ) - - self._dataset_name = dataset_name - self._dataset_type = dataset_type - self._dataset_args = dataset_args or {} - self._file_suffix = file_suffix - self._run_id = run_id or os.environ.get(MLFLOW_RUN_ID_ENV_VAR) - self._registered_model_name = registered_model_name - self._registered_model_version = registered_model_version - - self._artifact_path = f"{dataset_name}{self._file_suffix}" - - self._filepath = Path(mkdtemp()) / self._artifact_path - - if registered_model_name: - self._version = f"{registered_model_name}/{registered_model_version}" - else: - self._version = run_id - - def _save(self, data: Any) -> None: - cls = load_dataset(self._dataset_type) - ds = cls(filepath=self._filepath.as_posix(), **self._dataset_args) - ds.save(data) - - filepath = self._filepath.as_posix() - if os.path.isdir(filepath): - mlflow.log_artifacts(self._filepath.as_posix(), self._artifact_path) - elif os.path.isfile(filepath): - mlflow.log_artifact(self._filepath.as_posix()) - else: - raise RuntimeError("cls.save() didn't work. Unexpected error.") - - run_id = mlflow.active_run().info.run_id - if self._version is not None: - logger.warning( - f"Ignoring version {self._version} set " - f"earlier, will use version='{run_id}' for loading" - ) - self._version = run_id - - def _load(self) -> Any: - if self._version is None: - msg = ( - "Could not determine the version to load. " - "Please specify either 'run_id' or 'registered_model_name' " - "along with 'registered_model_version' explicitly in " - "MLFlowArtifact constructor" - ) - raise MlflowException(msg) - - if "/" in self._version: - model_uri = f"models:/{self._version}" - model = mlflow.pyfunc.load_model(model_uri) - run_id = model._model_meta.run_id - else: - run_id = self._version - - local_path = _download_artifact_from_uri( - f"runs:/{run_id}/{self._artifact_path}" - ) - - cls = load_dataset(self._dataset_type) - ds = cls(filepath=local_path, **self._dataset_args) - return ds.load() - - def _describe(self) -> Dict[str, Any]: - return dict( - dataset_name=self._dataset_name, - dataset_type=self._dataset_type, - dataset_args=self._dataset_args, - file_suffix=self._file_suffix, - registered_model_name=self._registered_model_name, - registered_model_version=self._registered_model_version, - ) diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/common.py b/kedro-datasets/kedro_datasets/databricks/mlflow/common.py deleted file mode 100644 index af102d6b3..000000000 --- a/kedro-datasets/kedro_datasets/databricks/mlflow/common.py +++ /dev/null @@ -1,89 +0,0 @@ -import mlflow -from mlflow.tracking import MlflowClient - -MLFLOW_RUN_ID_ENV_VAR = "mlflow_run_id" - - -def parse_model_uri(model_uri): - parts = model_uri.split("/") - - if len(parts) < 2 or len(parts) > 3: - raise ValueError( - f"model uri should have the format " - f"'models:/' or " - f"'models://', got {model_uri}" - ) - - if parts[0] == "models:": - protocol = "models" - else: - raise ValueError("model uri should start with `models:/`, got %s", model_uri) - - name = parts[1] - - client = MlflowClient() - if len(parts) == 2: - results = client.search_model_versions(f"name='{name}'") - sorted_results = 
-            results,
-            key=lambda modelversion: modelversion.creation_timestamp,
-            reverse=True,
-        )
-        latest_version = sorted_results[0].version
-        version = latest_version
-    else:
-        version = parts[2]
-        if version in ["Production", "Staging", "Archived"]:
-            results = client.get_latest_versions(name, stages=[version])
-            if len(results) > 0:
-                version = results[0].version
-            else:
-                version = None
-
-    return protocol, name, version
-
-
-def promote_model(model_name, model_version, stage):
-    import datetime
-
-    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
-    client = MlflowClient()
-
-    new_model_uri = f"models:/{model_name}/{model_version}"
-    _, _, new_model_version = parse_model_uri(new_model_uri)
-    new_model = mlflow.pyfunc.load_model(new_model_uri)
-    new_model_runid = new_model._model_meta.run_id
-
-    msg = f"```Promoted version {model_version} to {stage}, at {now}```"
-    client.set_tag(new_model_runid, "mlflow.note.content", msg)
-    client.set_tag(new_model_runid, "Promoted at", now)
-
-    results = client.get_latest_versions(model_name, stages=[stage])
-    if len(results) > 0:
-        old_model_uri = f"models:/{model_name}/{stage}"
-        _, _, old_model_version = parse_model_uri(old_model_uri)
-        old_model = mlflow.pyfunc.load_model(old_model_uri)
-        old_model_runid = old_model._model_meta.run_id
-
-        client.set_tag(
-            old_model._model_meta.run_id,
-            "mlflow.note.content",
-            f"```Replaced by version {new_model_version}, at {now}```",
-        )
-        client.set_tag(old_model_runid, "Retired at", now)
-        client.set_tag(old_model_runid, "Replaced by", new_model_version)
-
-        client.set_tag(new_model_runid, "Replaces", old_model_version)
-
-        client.transition_model_version_stage(
-            name=model_name, version=old_model_version, stage="Archived"
-        )
-
-    client.transition_model_version_stage(
-        name=model_name, version=new_model_version, stage=stage
-    )
-
-
-class ModelOpsException(Exception):
-    pass
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py b/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py
deleted file mode 100644
index ee0a1e0ed..000000000
--- a/kedro-datasets/kedro_datasets/databricks/mlflow/dataset.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import importlib
-import logging
-from typing import Any, Dict
-
-from kedro.io.core import AbstractDataSet
-
-from .common import ModelOpsException, parse_model_uri
-
-logger = logging.getLogger(__name__)
-
-
-class MLFlowDataSet(AbstractDataSet):
-    def __init__(
-        self,
-        flavor: str,
-        dataset_name: str = None,
-        dataset_type: str = None,
-        dataset_args: Dict[str, Any] = None,
-        *,
-        file_suffix: str = None,
-        load_version: str = None,
-    ):
-        self._flavor = flavor
-        self._dataset_name = dataset_name
-        self._dataset_type = dataset_type
-        self._dataset_args = dataset_args or {}
-        self._file_suffix = file_suffix
-        self._load_version = load_version
-
-    def _save(self, model: Any) -> None:
-        if self._load_version is not None:
-            msg = (
-                f"Trying to save an MLFlowDataSet::{self._describe} which "
-                f"was initialized with load_version={self._load_version}. "
-                f"This can lead to inconsistency between saved and loaded "
-                f"versions, therefore disallowed. Please create separate "
-                f"catalog entries for saved and loaded datasets."
-            )
-            raise ModelOpsException(msg)
-
-        importlib.import_module(self._flavor).log_model(
-            model,
-            self._dataset_name,
-            registered_model_name=self._dataset_name,
-            dataset_type=self._dataset_type,
-            dataset_args=self._dataset_args,
-            file_suffix=self._file_suffix,
-        )
-
-    def _load(self) -> Any:
-        *_, latest_version = parse_model_uri(f"models:/{self._dataset_name}")
-
-        dataset_version = self._load_version or latest_version
-        *_, dataset_version = parse_model_uri(
-            f"models:/{self._dataset_name}/{dataset_version}"
-        )
-
-        logger.info(f"Loading model '{self._dataset_name}' version '{dataset_version}'")
-
-        if dataset_version != latest_version:
-            logger.warning(f"Newer version {latest_version} exists in repo")
-
-        model = importlib.import_module(self._flavor).load_model(
-            f"models:/{self._dataset_name}/{dataset_version}",
-            dataset_type=self._dataset_type,
-            dataset_args=self._dataset_args,
-            file_suffix=self._file_suffix,
-        )
-
-        return model
-
-    def _describe(self) -> Dict[str, Any]:
-        return dict(
-            flavor=self._flavor,
-            dataset_name=self._dataset_name,
-            dataset_type=self._dataset_type,
-            dataset_args=self._dataset_args,
-            file_suffix=self._file_suffix,
-            load_version=self._load_version,
-        )
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/__init__.py b/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py b/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py
deleted file mode 100644
index e0a43a1b0..000000000
--- a/kedro-datasets/kedro_datasets/databricks/mlflow/flavors/kedro_dataset_flavor.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import os
-import sys
-from pathlib import Path
-from typing import Any, Dict, Union
-
-import kedro
-import yaml
-from kedro.utils import load_obj as load_dataset
-from mlflow import pyfunc
-from mlflow.exceptions import MlflowException
-from mlflow.models import Model
-from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
-from mlflow.tracking.artifact_utils import _download_artifact_from_uri
-from mlflow.utils.environment import _mlflow_conda_env
-from mlflow.utils.model_utils import _get_flavor_configuration
-
-FLAVOR_NAME = "kedro_dataset"
-
-
-DEFAULT_CONDA_ENV = _mlflow_conda_env(
-    additional_conda_deps=["kedro[all]={}".format(kedro.__version__)],
-    additional_pip_deps=None,
-    additional_conda_channels=None,
-)
-
-
-def save_model(
-    data: Any,
-    path: str,
-    conda_env: Union[str, Dict[str, Any]] = None,
-    mlflow_model: Model = Model(),
-    *,
-    dataset_type: str,
-    dataset_args: Dict[str, Any],
-    file_suffix: str,
-):
-    if os.path.exists(path):
-        raise RuntimeError("Path '{}' already exists".format(path))
-    os.makedirs(path)
-
-    model_data_subpath = f"data.{file_suffix}"
-    model_data_path = os.path.join(path, model_data_subpath)
-
-    cls = load_dataset(dataset_type)
-    ds = cls(filepath=model_data_path, **dataset_args)
-    ds.save(data)
-
-    conda_env_subpath = "conda.yaml"
-    if conda_env is None:
-        conda_env = DEFAULT_CONDA_ENV
-    elif not isinstance(conda_env, dict):
-        with open(conda_env, "r") as f:
-            conda_env = yaml.safe_load(f)
-    with open(os.path.join(path, conda_env_subpath), "w") as f:
-        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
-
-    pyfunc.add_to_model(
-        mlflow_model,
-        loader_module=__name__,
-        data=model_data_subpath,
-        env=conda_env_subpath,
-    )
-
-    mlflow_model.add_flavor(
-        FLAVOR_NAME,
-        data=model_data_subpath,
-        dataset_type=dataset_type,
-        dataset_args=dataset_args,
-        file_suffix=file_suffix,
-    )
-    mlflow_model.save(os.path.join(path, "MLmodel"))
-
-
-def log_model(
-    model: Any,
-    artifact_path: str,
-    conda_env: Dict[str, Any] = None,
-    registered_model_name: str = None,
-    await_registration_for: int = DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
-    *,
-    dataset_type: str,
-    dataset_args: Dict[str, Any],
-    file_suffix: str,
-):
-    return Model.log(
-        artifact_path=artifact_path,
-        flavor=sys.modules[__name__],
-        registered_model_name=registered_model_name,
-        await_registration_for=await_registration_for,
-        data=model,
-        conda_env=conda_env,
-        dataset_type=dataset_type,
-        dataset_args=dataset_args,
-        file_suffix=file_suffix,
-    )
-
-
-def _load_model_from_local_file(
-    local_path: str,
-    *,
-    dataset_type: str = None,
-    dataset_args: Dict[str, Any] = None,
-    file_suffix: str = None,
-):
-    if dataset_type is not None:
-        model_data_subpath = f"data.{file_suffix}"
-        data_path = os.path.join(local_path, model_data_subpath)
-    else:
-        flavor_conf = _get_flavor_configuration(
-            model_path=local_path, flavor_name=FLAVOR_NAME
-        )
-        data_path = os.path.join(local_path, flavor_conf["data"])
-        dataset_type = flavor_conf["dataset_type"]
-        dataset_args = flavor_conf["dataset_args"]
-
-    cls = load_dataset(dataset_type)
-    ds = cls(filepath=data_path, **dataset_args)
-    return ds.load()
-
-
-def load_model(
-    model_uri: str,
-    *,
-    dataset_type: str = None,
-    dataset_args: Dict[str, Any] = None,
-    file_suffix: str = None,
-):
-    if dataset_type is not None or dataset_args is not None or file_suffix is not None:
-        assert (
-            dataset_type is not None
-            and dataset_args is not None
-            and file_suffix is not None
-        ), ("Please set 'dataset_type', " "'dataset_args' and 'file_suffix'")
-
-    local_path = _download_artifact_from_uri(model_uri)
-    return _load_model_from_local_file(
-        local_path,
-        dataset_type=dataset_type,
-        dataset_args=dataset_args,
-        file_suffix=file_suffix,
-    )
-
-
-def _load_pyfunc(model_file: str):
-    local_path = Path(model_file).parent.absolute()
-    model = _load_model_from_local_file(local_path)
-    if not hasattr(model, "predict"):
-        try:
-            setattr(model, "predict", None)
-        except AttributeError:
-            raise MlflowException(
-                f"`pyfunc` flavor not supported, use " f"{__name__}.load instead"
-            )
-    return model
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py b/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py
deleted file mode 100644
index 1c7760375..000000000
--- a/kedro-datasets/kedro_datasets/databricks/mlflow/metrics.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import logging
-from typing import Any, Dict, Union
-
-import mlflow
-from kedro.io.core import AbstractDataSet
-from mlflow.exceptions import MlflowException
-from mlflow.tracking import MlflowClient
-
-from .common import ModelOpsException
-
-logger = logging.getLogger(__name__)
-
-
-class MLFlowMetrics(AbstractDataSet):
-    def __init__(
-        self,
-        prefix: str = None,
-        run_id: str = None,
-        registered_model_name: str = None,
-        registered_model_version: str = None,
-    ):
-        if None in (registered_model_name, registered_model_version):
-            if registered_model_name or registered_model_version:
-                raise ModelOpsException(
-                    "'registered_model_name' and "
-                    "'registered_model_version' should be "
-                    "set together"
-                )
-
-        if run_id and registered_model_name:
-            raise ModelOpsException(
-                "'run_id' cannot be passed when " "'registered_model_name' is set"
-            )
-
-        self._prefix = prefix
-        self._run_id = run_id
-        self._registered_model_name = registered_model_name
-        self._registered_model_version = registered_model_version
-
-        if registered_model_name:
-            self._version = f"{registered_model_name}/{registered_model_version}"
-        else:
-            self._version = run_id
-
-    def _save(self, metrics: Dict[str, Union[str, float, int]]) -> None:
-        if self._prefix is not None:
-            metrics = {f"{self._prefix}_{key}": value for key, value in metrics.items()}
-        mlflow.log_metrics(metrics)
-
-        run_id = mlflow.active_run().info.run_id
-        if self._version is not None:
-            logger.warning(
-                f"Ignoring version {self._version.save} set "
-                f"earlier, will use version='{run_id}' for loading"
-            )
-        self._version = run_id
-
-    def _load(self) -> Any:
-        if self._version is None:
-            msg = (
-                "Could not determine the version to load. "
-                "Please specify either 'run_id' or 'registered_model_name' "
-                "along with 'registered_model_version' explicitly in "
-                "MLFlowMetrics constructor"
-            )
-            raise MlflowException(msg)
-
-        client = MlflowClient()
-
-        if "/" in self._version:
-            model_uri = f"models:/{self._version}"
-            model = mlflow.pyfunc.load_model(model_uri)
-            run_id = model._model_meta.run_id
-        else:
-            run_id = self._version
-
-        run = client.get_run(run_id)
-        metrics = run.data.metrics
-        if self._prefix is not None:
-            metrics = {
-                key[len(self._prefix) + 1 :]: value
-                for key, value in metrics.items()
-                if key[: len(self._prefix)] == self._prefix
-            }
-        return metrics
-
-    def _describe(self) -> Dict[str, Any]:
-        return dict(
-            prefix=self._prefix,
-            run_id=self._run_id,
-            registered_model_name=self._registered_model_name,
-            registered_model_version=self._registered_model_version,
-        )
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/model.py b/kedro-datasets/kedro_datasets/databricks/mlflow/model.py
deleted file mode 100644
index c5f2356a2..000000000
--- a/kedro-datasets/kedro_datasets/databricks/mlflow/model.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import importlib
-import logging
-from typing import Any, Dict
-
-from kedro.io.core import AbstractDataSet
-from mlflow.models.signature import ModelSignature
-
-from .common import ModelOpsException, parse_model_uri
-
-logger = logging.getLogger(__name__)
-
-
-class MLFlowModel(AbstractDataSet):
-    def __init__(
-        self,
-        flavor: str,
-        model_name: str,
-        signature: Dict[str, Dict[str, str]] = None,
-        input_example: Dict[str, Any] = None,
-        load_version: str = None,
-    ):
-        self._flavor = flavor
-        self._model_name = model_name
-
-        if signature:
-            self._signature = ModelSignature.from_dict(signature)
-        else:
-            self._signature = None
-        self._input_example = input_example
-
-        self._load_version = load_version
-
-    def _save(self, model: Any) -> None:
-        if self._load_version is not None:
-            msg = (
-                f"Trying to save an MLFlowModel::{self._describe} which "
-                f"was initialized with load_version={self._load_version}. "
-                f"This can lead to inconsistency between saved and loaded "
-                f"versions, therefore disallowed. Please create separate "
-                f"catalog entries for saved and loaded datasets."
-            )
-            raise ModelOpsException(msg)
-
-        importlib.import_module(self._flavor).log_model(
-            model,
-            self._model_name,
-            registered_model_name=self._model_name,
-            signature=self._signature,
-            input_example=self._input_example,
-        )
-
-    def _load(self) -> Any:
-        *_, latest_version = parse_model_uri(f"models:/{self._model_name}")
-
-        model_version = self._load_version or latest_version
-
-        logger.info(f"Loading model '{self._model_name}' version '{model_version}'")
-
-        if model_version != latest_version:
-            logger.warning(f"Newer version {latest_version} exists in repo")
-
-        model = importlib.import_module(self._flavor).load_model(
-            f"models:/{self._model_name}/{model_version}"
-        )
-
-        return model
-
-    def _describe(self) -> Dict[str, Any]:
-        return dict(
-            flavor=self._flavor,
-            model_name=self._model_name,
-            signature=self._signature,
-            input_example=self._input_example,
-            load_version=self._load_version,
-        )
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py b/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py
deleted file mode 100644
index 3c160cec4..000000000
--- a/kedro-datasets/kedro_datasets/databricks/mlflow/model_metadata.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import logging
-from typing import Any, Dict, Union
-
-import mlflow
-from kedro.io.core import AbstractDataSet
-
-from .common import ModelOpsException, parse_model_uri
-
-logger = logging.getLogger(__name__)
-
-
-class MLFlowModelMetadata(AbstractDataSet):
-    def __init__(
-        self, registered_model_name: str, registered_model_version: str = None
-    ):
-        self._model_name = registered_model_name
-        self._model_version = registered_model_version
-
-    def _save(self, tags: Dict[str, Union[str, float, int]]) -> None:
-        raise NotImplementedError()
-
-    def _load(self) -> Any:
-        if self._model_version is None:
-            model_uri = f"models:/{self._model_name}"
-        else:
-            model_uri = f"models:/{self._model_name}/{self._model_version}"
-        _, _, load_version = parse_model_uri(model_uri)
-
-        if load_version is None:
-            raise ModelOpsException(
-                f"No model with version " f"'{self._model_version}'"
-            )
-
-        pyfunc_model = mlflow.pyfunc.load_model(
-            f"models:/{self._model_name}/{load_version}"
-        )
-        all_metadata = pyfunc_model._model_meta
-        model_metadata = {
-            "model_name": self._model_name,
-            "model_version": int(load_version),
-            "run_id": all_metadata.run_id,
-        }
-        return model_metadata
-
-    def _describe(self) -> Dict[str, Any]:
-        return dict(
-            registered_model_name=self._model_name,
-            registered_model_version=self._model_version,
-        )
diff --git a/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py b/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py
deleted file mode 100644
index 153810ae4..000000000
--- a/kedro-datasets/kedro_datasets/databricks/mlflow/tags.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import logging
-from typing import Any, Dict, Union
-
-import mlflow
-from kedro.io.core import AbstractDataSet
-from mlflow.exceptions import MlflowException
-from mlflow.tracking import MlflowClient
-
-from .common import ModelOpsException
-
-logger = logging.getLogger(__name__)
-
-
-class MLFlowTags(AbstractDataSet):
-    def __init__(
-        self,
-        prefix: str = None,
-        run_id: str = None,
-        registered_model_name: str = None,
-        registered_model_version: str = None,
-    ):
-        if None in (registered_model_name, registered_model_version):
-            if registered_model_name or registered_model_version:
-                raise ModelOpsException(
-                    "'registered_model_name' and "
-                    "'registered_model_version' should be "
-                    "set together"
-                )
-
-        if run_id and registered_model_name:
-            raise ModelOpsException(
-                "'run_id' cannot be passed when " "'registered_model_name' is set"
-            )
-
-        self._prefix = prefix
-        self._run_id = run_id
-        self._registered_model_name = registered_model_name
-        self._registered_model_version = registered_model_version
-
-        if registered_model_name:
-            self._version = f"{registered_model_name}/{registered_model_version}"
-        else:
-            self._version = run_id
-
-    def _save(self, tags: Dict[str, Union[str, float, int]]) -> None:
-        if self._prefix is not None:
-            tags = {f"{self._prefix}_{key}": value for key, value in tags.items()}
-
-        mlflow.set_tags(tags)
-
-        run_id = mlflow.active_run().info.run_id
-        if self._version is not None:
-            logger.warning(
-                f"Ignoring version {self._version.save} set "
-                f"earlier, will use version='{run_id}' for loading"
-            )
-        self._version = run_id
-
-    def _load(self) -> Any:
-        if self._version is None:
-            msg = (
-                "Could not determine the version to load. "
-                "Please specify either 'run_id' or 'registered_model_name' "
-                "along with 'registered_model_version' explicitly in "
-                "MLFlowTags constructor"
-            )
-            raise MlflowException(msg)
-
-        client = MlflowClient()
-
-        if "/" in self._version:
-            model_uri = f"models:/{self._version}"
-            model = mlflow.pyfunc.load_model(model_uri)
-            run_id = model._model_meta.run_id
-        else:
-            run_id = self._version
-
-        run = client.get_run(run_id)
-        tags = run.data.tags
-        if self._prefix is not None:
-            tags = {
-                key[len(self._prefix) + 1 :]: value
-                for key, value in tags.items()
-                if key[: len(self._prefix)] == self._prefix
-            }
-        return tags
-
-    def _describe(self) -> Dict[str, Any]:
-        return dict(
-            prefix=self._prefix,
-            run_id=self._run_id,
-            registered_model_name=self._registered_model_name,
-            registered_model_version=self._registered_model_version,
-        )
diff --git a/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py
index b46122197..f0f04b7be 100644
--- a/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py
+++ b/kedro-datasets/kedro_datasets/databricks/unity/managed_table_dataset.py
@@ -1,22 +1,26 @@
 import logging
-from typing import Any, Dict, List, Union
 import pandas as pd
+from operator import attrgetter
+from functools import partial
+from cachetools.keys import hashkey
+from typing import Any, Dict, List, Union
+from cachetools import Cache, cachedmethod
 from kedro.io.core import (
     AbstractVersionedDataSet,
     DataSetError,
+    Version,
     VersionNotFoundError,
 )
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.types import StructType
-from pyspark.sql.utils import AnalysisException
 from cachetools import Cache

 logger = logging.getLogger(__name__)


 class ManagedTableDataSet(AbstractVersionedDataSet):
-    """``ManagedTableDataSet`` loads data into Unity managed tables."""
+    """``ManagedTableDataSet`` loads and saves data into managed delta tables."""

     # this dataset cannot be used with ``ParallelRunner``,
     # therefore it has the attribute ``_SINGLE_PROCESS = True``
@@ -34,7 +38,7 @@ def __init__(
         write_mode: str = "overwrite",
         dataframe_type: str = "spark",
         primary_key: Union[str, List[str]] = None,
-        version: int = None,
+        version: Version = None,
         *,
         # the following parameters are used by the hook to create or update unity
         schema: Dict[str, Any] = None,  # pylint: disable=unused-argument
@@ -73,9 +77,8 @@ def __init__(
         )

         self._primary_key = primary_key
-
-        self._version = version
         self._version_cache = Cache(maxsize=2)
+        self._version = version

         self._schema = None
         if schema is not None:
@@ -83,24 +86,16 @@ def _get_spark(self) -> SparkSession:
         return (
-            SparkSession.builder.config(
-                "spark.jars.packages", "io.delta:delta-core_2.12:1.2.1"
-            )
-            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
-            .config(
-                "spark.sql.catalog.spark_catalog",
-                "org.apache.spark.sql.delta.catalog.DeltaCatalog",
-            )
-            .getOrCreate()
+            SparkSession.builder.getOrCreate()
         )

     def _load(self) -> Union[DataFrame, pd.DataFrame]:
-        if self._version is not None and self._version >= 0:
+        if self._version and self._version.load >= 0:
             try:
                 data = (
                     self._get_spark()
                     .read.format("delta")
-                    .option("versionAsOf", self._version)
+                    .option("versionAsOf", self._version.load)
                     .table(self._full_table_address)
                 )
             except:
diff --git a/kedro-datasets/tests/databricks/conftest.py b/kedro-datasets/tests/databricks/conftest.py
index d360ffb68..26d63b056 100644
--- a/kedro-datasets/tests/databricks/conftest.py
+++ b/kedro-datasets/tests/databricks/conftest.py
@@ -6,7 +6,6 @@
 """
 import pytest
 from pyspark.sql import SparkSession
-from delta.pip_utils import configure_spark_with_delta_pip


 @pytest.fixture(scope="class", autouse=True)
diff --git a/kedro-datasets/tests/databricks/test_unity.py b/kedro-datasets/tests/databricks/test_unity.py
index 471f81f57..0d54e29e4 100644
--- a/kedro-datasets/tests/databricks/test_unity.py
+++ b/kedro-datasets/tests/databricks/test_unity.py
@@ -1,5 +1,5 @@
 import pytest
-from kedro.io.core import DataSetError, VersionNotFoundError
+from kedro.io.core import DataSetError, VersionNotFoundError, Version
 from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 from pyspark.sql import DataFrame, SparkSession
 import pandas as pd
@@ -195,6 +195,7 @@ def test_describe(self):
             "dataframe_type": "spark",
             "primary_key": None,
             "version": None,
+            "owner_group": None,
         }

     def test_invalid_write_mode(self):
@@ -413,7 +414,7 @@ def test_load_spark_no_version(self, sample_spark_df: DataFrame):
         unity_ds.save(sample_spark_df)

         delta_ds = ManagedTableDataSet(
-            database="test", table="test_load_spark", version=2
+            database="test", table="test_load_spark", version=Version(2, None)
         )
         with pytest.raises(VersionNotFoundError):
             _ = delta_ds.load()
@@ -426,7 +427,7 @@ def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFra
         unity_ds.save(append_spark_df)

         loaded_ds = ManagedTableDataSet(
-            database="test", table="test_load_version", version=0
+            database="test", table="test_load_version", version=Version(0, None)
         )
         loaded_df = loaded_ds.load()
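
Note on the versioning change above: `ManagedTableDataSet` now takes a kedro `Version` instead of a bare `int`, with the Delta `versionAsOf` value carried in the `load` field. Below is a minimal sketch of a pinned load after this change; the `test`/`example` table, the sample rows, and the use of `write_mode="append"` are illustrative assumptions, not part of this diff:

```python
from kedro.io.core import Version
from kedro_datasets.databricks import ManagedTableDataSet
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
first_df = spark.createDataFrame([("Alex", 31)], ["name", "age"])
second_df = spark.createDataFrame([("Evan", 23)], ["name", "age"])

# Two appends leave the managed Delta table with versions 0 and 1.
ds = ManagedTableDataSet(database="test", table="example", write_mode="append")
ds.save(first_df)
ds.save(second_df)

# Pin the load to Delta version 0: Version(load, save) replaces the old
# bare-int `version` argument, and only the `load` field is consulted
# (it becomes the `versionAsOf` option in `_load`).
pinned = ManagedTableDataSet(database="test", table="example", version=Version(0, None))
assert pinned.load().count() == 1  # only the first append is visible
```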