adding comments before initial PR
Signed-off-by: Danny Farah <danny_farah@mckinsey.com>
dannyrfar committed Mar 21, 2023
1 parent 9389aa4 commit e6157a5
Showing 10 changed files with 366 additions and 223 deletions.
7 changes: 6 additions & 1 deletion kedro-datasets/kedro_datasets/databricks/__init__.py
@@ -1,3 +1,8 @@
"""Provides interface to Unity Catalog Tables."""

from .unity import ManagedTableDataSet
__all__ = ["ManagedTableDataSet"]

from contextlib import suppress

with suppress(ImportError):
from .managed_table_dataset import ManagedTableDataSet
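The guarded import above keeps ``kedro_datasets.databricks`` importable even when the optional Spark dependencies are not installed; only the dataset symbol goes missing. A minimal sketch of that behaviour (assuming ``pyspark`` is absent from the environment):

from kedro_datasets import databricks

# suppress(ImportError) swallows the failed import, so the package
# imports cleanly and only the attribute is missing.
hasattr(databricks, "ManagedTableDataSet")  # False when pyspark is absent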
342 changes: 342 additions & 0 deletions kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
@@ -0,0 +1,342 @@
"""``ManagedTableDataSet`` implementation to access managed delta tables
in Databricks.
"""
import dataclasses
import logging
from functools import partial
from operator import attrgetter
from typing import Any, Dict, List, Union

import pandas as pd
from cachetools import Cache, cachedmethod
from cachetools.keys import hashkey
from kedro.io.core import (
    AbstractVersionedDataSet,
    DataSetError,
    Version,
    VersionNotFoundError,
)
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType
from pyspark.sql.utils import AnalysisException, ParseException

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class Table:  # pylint: disable=R0902
    """Stores the definition of a managed table."""

    database: str
    catalog: str
    table: str
    full_table_location: str
    write_mode: str
    dataframe_type: str
    primary_key: Union[str, List[str]]
    owner_group: str
    partition_columns: Union[str, List[str]]


class ManagedTableDataSet(AbstractVersionedDataSet):
    """``ManagedTableDataSet`` loads and saves data into managed delta tables on Databricks.

    Example usage for the
    `YAML API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog.html#use-the-data-catalog-with-the-yaml-api>`_:

    .. code-block:: yaml

        names_and_ages@spark:
          type: databricks.ManagedTableDataSet
          table: names_and_ages

        names_and_ages@pandas:
          type: databricks.ManagedTableDataSet
          table: names_and_ages
          dataframe_type: pandas

    Example usage for the
    `Python API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog.html#use-the-data-catalog-with-the-code-api>`_:
    ::

        Launch a pyspark session with the following configs:

        % pyspark --packages io.delta:delta-core_2.12:1.2.1
            --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension"
            --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog"

        >>> from pyspark.sql import SparkSession
        >>> from pyspark.sql.types import (StructField, StringType,
        ...                                IntegerType, StructType)
        >>> from kedro_datasets.databricks import ManagedTableDataSet
        >>> schema = StructType([StructField("name", StringType(), True),
        ...                      StructField("age", IntegerType(), True)])
        >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
        >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema)
        >>> data_set = ManagedTableDataSet(table="names_and_ages")
        >>> data_set.save(spark_df)
        >>> reloaded = data_set.load()
        >>> reloaded.take(4)
    """

    # This dataset cannot be used with ``ParallelRunner``,
    # so it sets the attribute ``_SINGLE_PROCESS = True``.
    # For parallelism within a Spark pipeline, consider
    # using ``ThreadRunner`` instead.
    _SINGLE_PROCESS = True
    _VALID_WRITE_MODES = ["overwrite", "upsert", "append"]
    _VALID_DATAFRAME_TYPES = ["spark", "pandas"]
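    # For illustration (not part of the diff): because _SINGLE_PROCESS is
    # True, pipelines using this dataset can still be parallelised with
    # `kedro run --runner=ThreadRunner`, which shares one Spark session
    # across threads.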

    def __init__(  # pylint: disable=R0913
        self,
        table: str,
        catalog: str = None,
        database: str = "default",
        write_mode: str = "overwrite",
        dataframe_type: str = "spark",
        primary_key: Union[str, List[str]] = None,
        version: Version = None,
        *,
        # the following parameters are used by project hooks
        # to create or update table properties
        schema: Dict[str, Any] = None,
        partition_columns: List[str] = None,
        owner_group: str = None,
    ) -> None:
        """Creates a new instance of ``ManagedTableDataSet``."""

        full_table_location = None
        if catalog and database and table:
            full_table_location = f"{catalog}.{database}.{table}"
        elif table:
            full_table_location = f"{database}.{table}"
        if write_mode not in self._VALID_WRITE_MODES:
            valid_modes = ", ".join(self._VALID_WRITE_MODES)
            raise DataSetError(
                f"Invalid `write_mode` provided: {write_mode}. "
                f"`write_mode` must be one of: {valid_modes}"
            )
        if dataframe_type not in self._VALID_DATAFRAME_TYPES:
            valid_types = ", ".join(self._VALID_DATAFRAME_TYPES)
            raise DataSetError(f"`dataframe_type` must be one of {valid_types}")
        if primary_key is None or len(primary_key) == 0:
            if write_mode == "upsert":
                raise DataSetError(
                    f"`primary_key` must be provided for "
                    f"`write_mode` {write_mode}"
                )
        self._table = Table(
            database=database,
            catalog=catalog,
            table=table,
            full_table_location=full_table_location,
            write_mode=write_mode,
            dataframe_type=dataframe_type,
            primary_key=primary_key,
            owner_group=owner_group,
            partition_columns=partition_columns,
        )

        self._version_cache = Cache(maxsize=2)
        self._version = version

        self._schema = None
        if schema is not None:
            self._schema = StructType.fromJson(schema)

        super().__init__(
            filepath=None,
            version=version,
            exists_function=self._exists,
        )

    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "load"))
    def _fetch_latest_load_version(self) -> int:
        # When the load version is unpinned, fetch the most recent
        # existing version of the table.
        latest_history = (
            self._get_spark()
            .sql(f"DESCRIBE HISTORY {self._table.full_table_location} LIMIT 1")
            .collect()
        )
        if len(latest_history) != 1:
            raise VersionNotFoundError(
                f"Did not find any versions for {self._table.full_table_location}"
            )
        return latest_history[0].version

    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "save"))
    def _fetch_latest_save_version(self) -> None:
        """Generate and cache the current save version (always ``None``)."""
        return None
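    # For illustration (not part of the diff): both methods share the same
    # Cache(maxsize=2), and the hashkey partials keep their entries apart,
    # e.g. {hashkey("load"): 3, hashkey("save"): None} after one load/save.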

    @staticmethod
    def _get_spark() -> SparkSession:
        return SparkSession.builder.getOrCreate()

    def _load(self) -> Union[DataFrame, pd.DataFrame]:
        """Loads the version of data in the format defined in the init
        (spark|pandas dataframe).

        Raises:
            VersionNotFoundError: if the version defined in
                the init doesn't exist

        Returns:
            Union[DataFrame, pd.DataFrame]: a dataframe
                in the format defined in the init
        """
        if self._version and self._version.load >= 0:
            try:
                data = (
                    self._get_spark()
                    .read.format("delta")
                    .option("versionAsOf", self._version.load)
                    .table(self._table.full_table_location)
                )
            except Exception as exc:
                raise VersionNotFoundError(self._version) from exc
        else:
            data = self._get_spark().table(self._table.full_table_location)
        if self._table.dataframe_type == "pandas":
            data = data.toPandas()
        return data
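    # For illustration (not part of the diff): a pinned load version, e.g.
    # Version(load=3, save=None), makes _load() read Delta table version 3
    # via the "versionAsOf" reader option above (Delta time travel).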

    def _save_append(self, data: DataFrame) -> None:
        """Saves the data to the table by appending it
        to the location defined in the init.

        Args:
            data (DataFrame): the Spark dataframe to append to the table
        """
        data.write.format("delta").mode("append").saveAsTable(
            self._table.full_table_location
        )

    def _save_overwrite(self, data: DataFrame) -> None:
        """Overwrites the data in the table with the data provided
        (this is the default save mode).

        Args:
            data (DataFrame): the Spark dataframe to overwrite the table with
        """
        delta_table = data.write.format("delta")
        if self._table.write_mode == "overwrite":
            delta_table = delta_table.mode("overwrite").option(
                "overwriteSchema", "true"
            )
        delta_table.saveAsTable(self._table.full_table_location)

    def _save_upsert(self, update_data: DataFrame) -> None:
        """Upserts the data by joining on the primary_key column(s).
        If the table doesn't exist at save time, the data is inserted
        into a new table.

        Args:
            update_data (DataFrame): the Spark dataframe to upsert
        """
        if self._exists():
            base_data = self._get_spark().table(self._table.full_table_location)
            base_columns = base_data.columns
            update_columns = update_data.columns

            if set(update_columns) != set(base_columns):
                raise DataSetError(
                    f"Upsert requires tables to have identical columns. "
                    f"Delta table {self._table.full_table_location} "
                    f"has columns: {base_columns}, whereas "
                    f"dataframe has columns {update_columns}"
                )

            where_expr = ""
            if isinstance(self._table.primary_key, str):
                where_expr = (
                    f"base.{self._table.primary_key}=update.{self._table.primary_key}"
                )
            elif isinstance(self._table.primary_key, list):
                where_expr = " AND ".join(
                    f"base.{col}=update.{col}" for col in self._table.primary_key
                )

            update_data.createOrReplaceTempView("update")
            self._get_spark().conf.set(
                "fullTableAddress", self._table.full_table_location
            )
            self._get_spark().conf.set("whereExpr", where_expr)
            upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
                WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
            self._get_spark().sql(upsert_sql)
        else:
            self._save_append(update_data)
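    # For illustration (not part of the diff): with a hypothetical table
    # default.people and primary_key=["name", "dob"], Spark's SQL variable
    # substitution expands the MERGE above to roughly:
    #   MERGE INTO default.people base USING update
    #   ON base.name=update.name AND base.dob=update.dob
    #   WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *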

    def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None:
        """Saves the data based on the write_mode and dataframe_type
        defined in the init. If dataframe_type is pandas, a Spark
        dataframe is created first. If a schema is provided, the data is
        matched to the schema before saving (columns will be sorted
        and truncated).

        Args:
            data (Union[DataFrame, pd.DataFrame]): Spark or pandas
                dataframe to save to the table location
        """
        # filter columns specified in schema and match their ordering
        if self._schema:
            cols = self._schema.fieldNames()
            if self._table.dataframe_type == "pandas":
                data = self._get_spark().createDataFrame(
                    data.loc[:, cols], schema=self._schema
                )
            else:
                data = data.select(*cols)
        else:
            if self._table.dataframe_type == "pandas":
                data = self._get_spark().createDataFrame(data)
        if self._table.write_mode == "overwrite":
            self._save_overwrite(data)
        elif self._table.write_mode == "upsert":
            self._save_upsert(data)
        elif self._table.write_mode == "append":
            self._save_append(data)

    def _describe(self) -> Dict[str, str]:
        """Returns a description of the instance of ManagedTableDataSet.

        Returns:
            Dict[str, str]: dict with the details of the dataset
        """
        return {
            "catalog": self._table.catalog,
            "database": self._table.database,
            "table": self._table.table,
            "write_mode": self._table.write_mode,
            "dataframe_type": self._table.dataframe_type,
            "primary_key": self._table.primary_key,
            "version": self._version,
            "owner_group": self._table.owner_group,
            "partition_columns": self._table.partition_columns,
        }

    def _exists(self) -> bool:
        """Checks to see if the table exists.

        Returns:
            bool: whether the table defined in the dataset
                instance exists in the Spark session
        """
        if self._table.catalog:
            try:
                self._get_spark().sql(f"USE CATALOG {self._table.catalog}")
            except (ParseException, AnalysisException) as exc:
                logger.warning(
                    "catalog %s not found or Unity Catalog not enabled. Error message: %s",
                    self._table.catalog,
                    exc,
                )
        try:
            return (
                self._get_spark()
                .sql(f"SHOW TABLES IN `{self._table.database}`")
                .filter(f"tableName = '{self._table.table}'")
                .count()
                > 0
            )
        except (ParseException, AnalysisException) as exc:
            logger.warning("error occurred while trying to find table: %s", exc)
            return False
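
For reference, a short usage sketch of the new upsert write mode (the table name and rows are illustrative, not taken from this commit):

from pyspark.sql import SparkSession

from kedro_datasets.databricks import ManagedTableDataSet

spark = SparkSession.builder.getOrCreate()
updates = spark.createDataFrame([("Alex", 32)], ["name", "age"])

# primary_key is mandatory for write_mode="upsert" (enforced in __init__);
# rows matching on "name" are updated via MERGE INTO, the rest are inserted.
data_set = ManagedTableDataSet(
    table="names_and_ages",
    write_mode="upsert",
    primary_key="name",
)
data_set.save(updates)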
1 change: 0 additions & 1 deletion kedro-datasets/kedro_datasets/databricks/unity/__init__.py

This file was deleted.

(7 more changed files not shown)
