Added "what" property for migration to scope down table migrations (#856
Browse files Browse the repository at this point in the history
)

## Changes
Adds a `What` enum that classifies each Hive metastore table (external SYNC-able, external non-SYNC-able, DBFS-root Delta, DBFS-root non-Delta, view, Databricks dataset, or unknown) and a keyword-only `what` parameter on `TablesMigrate.migrate_tables`, so a migration run can be scoped down to a single class of tables instead of always migrating everything.
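
A minimal usage sketch; `table_crawler`, `ws`, `backend`, and `table_mapping` are placeholders for the crawler, workspace client, SQL backend, and table mapping that a real deployment already wires up:

```python
from databricks.labs.ucx.hive_metastore.table_migrate import TablesMigrate
from databricks.labs.ucx.hive_metastore.tables import What

table_migrate = TablesMigrate(table_crawler, ws, backend, table_mapping)
# Scope this run to external tables that SYNC can handle;
# omitting `what` (or passing None) migrates every mapped table, as before.
table_migrate.migrate_tables(what=What.EXTERNAL_SYNC)
```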

### Linked issues
Related to #333

Resolves #..

### Functionality 

- [ ] added relevant user documentation
- [ ] added new CLI command
- [ ] modified existing command: `databricks labs ucx ...`
- [ ] added a new workflow
- [ ] modified existing workflow: `...`
- [ ] added a new table
- [ ] modified existing table: `...`

### Tests

- [ ] manually tested
- [ ] added unit tests
- [ ] added integration tests
- [ ] verified on staging environment (screenshot attached)
FastLee committed Jan 30, 2024
1 parent ee67586 commit fedf569
Showing 5 changed files with 205 additions and 56 deletions.
14 changes: 8 additions & 6 deletions src/databricks/labs/ucx/hive_metastore/table_migrate.py
@@ -15,12 +15,13 @@
 from databricks.labs.ucx.framework.crawlers import SqlBackend
 from databricks.labs.ucx.hive_metastore import TablesCrawler
 from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
-from databricks.labs.ucx.hive_metastore.tables import MigrationCount, Table
+from databricks.labs.ucx.hive_metastore.tables import MigrationCount, Table, What
 
 logger = logging.getLogger(__name__)
 
 
 class TablesMigrate:
+
     def __init__(
         self,
         tc: TablesCrawler,
@@ -34,23 +35,24 @@ def __init__(
         self._tm = tm
         self._seen_tables: dict[str, str] = {}
 
-    def migrate_tables(self):
+    def migrate_tables(self, *, what: What | None = None):
         self._init_seen_tables()
         tables_to_migrate = self._tm.get_tables_to_migrate(self._tc)
         tasks = []
         for table in tables_to_migrate:
-            tasks.append(partial(self._migrate_table, table.src, table.rule))
+            if not what or table.src.what == what:
+                tasks.append(partial(self._migrate_table, table.src, table.rule))
         Threads.strict("migrate tables", tasks)
 
     def _migrate_table(self, src_table: Table, rule: Rule):
         if self._table_already_upgraded(rule.as_uc_table_key):
             logger.info(f"Table {src_table.key} already upgraded to {rule.as_uc_table_key}")
             return True
-        if src_table.kind == "TABLE" and src_table.table_format == "DELTA" and src_table.is_dbfs_root:
+        if src_table.what == What.DBFS_ROOT_DELTA:
             return self._migrate_dbfs_root_table(src_table, rule)
-        if src_table.kind == "TABLE" and src_table.is_format_supported_for_sync:
+        if src_table.what == What.EXTERNAL_SYNC:
             return self._migrate_external_table(src_table, rule)
-        if src_table.kind == "VIEW":
+        if src_table.what == What.VIEW:
             return self._migrate_view(src_table, rule)
         logger.info(f"Table {src_table.key} is not supported for migration")
         return True
27 changes: 27 additions & 0 deletions src/databricks/labs/ucx/hive_metastore/tables.py
@@ -3,6 +3,7 @@
 import typing
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
+from enum import Enum, auto
 from functools import partial
 
 from databricks.labs.blueprint.parallel import Threads
@@ -14,6 +15,16 @@
 logger = logging.getLogger(__name__)
 
 
+class What(Enum):
+    EXTERNAL_SYNC = auto()
+    EXTERNAL_NO_SYNC = auto()
+    DBFS_ROOT_DELTA = auto()
+    DBFS_ROOT_NON_DELTA = auto()
+    VIEW = auto()
+    DB_DATASET = auto()
+    UNKNOWN = auto()
+
+
 @dataclass
 class Table:
     catalog: str
@@ -96,6 +107,22 @@ def is_databricks_dataset(self) -> bool:
                 return True
         return False
 
+    @property
+    def what(self) -> What:
+        if self.is_databricks_dataset:
+            return What.DB_DATASET
+        if self.is_dbfs_root and self.table_format == "DELTA":
+            return What.DBFS_ROOT_DELTA
+        if self.is_dbfs_root:
+            return What.DBFS_ROOT_NON_DELTA
+        if self.kind == "TABLE" and self.is_format_supported_for_sync:
+            return What.EXTERNAL_SYNC
+        if self.kind == "TABLE":
+            return What.EXTERNAL_NO_SYNC
+        if self.kind == "VIEW":
+            return What.VIEW
+        return What.UNKNOWN
+
     def sql_migrate_external(self, target_table_key):
         return f"SYNC TABLE {escape_sql_identifier(target_table_key)} FROM {escape_sql_identifier(self.key)};"
 
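
How the new `what` property classifies a few tables, as a quick sketch (hypothetical table values, mirroring the unit tests below; the checks run in the order of the `if` chain above, so the Databricks-dataset check takes precedence over the DBFS-root ones):

```python
from databricks.labs.ucx.hive_metastore.tables import Table, What

assert Table("a", "b", "c", "MANAGED", "PARQUET", location="dbfs:/somelocation/t").what == What.DBFS_ROOT_NON_DELTA
assert Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/databricks-datasets/t").what == What.DB_DATASET
assert Table("a", "b", "c", "EXTERNAL", "AVRO", location="s3://bucket/t").what == What.EXTERNAL_NO_SYNC
assert Table("a", "b", "c", "VIEW", "VIEW", view_text="SELECT 1").what == What.VIEW
```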
6 changes: 6 additions & 0 deletions tests/integration/hive_metastore/test_tables.py
@@ -5,6 +5,7 @@
 from databricks.sdk.retries import retried
 
 from databricks.labs.ucx.hive_metastore import TablesCrawler
+from databricks.labs.ucx.hive_metastore.tables import What
 
 logger = logging.getLogger(__name__)
 
@@ -38,8 +39,13 @@ def test_describe_all_tables_in_databases(ws, sql_backend, inventory_schema, mak
 
     assert len(all_tables) >= 5
     assert all_tables[non_delta.full_name].table_format == "JSON"
+    assert all_tables[non_delta.full_name].what == What.DB_DATASET
     assert all_tables[managed_table.full_name].object_type == "MANAGED"
+    assert all_tables[managed_table.full_name].what == What.DBFS_ROOT_DELTA
     assert all_tables[tmp_table.full_name].object_type == "MANAGED"
+    assert all_tables[tmp_table.full_name].what == What.DBFS_ROOT_DELTA
     assert all_tables[external_table.full_name].object_type == "EXTERNAL"
+    assert all_tables[external_table.full_name].what == What.EXTERNAL_NO_SYNC
     assert all_tables[view.full_name].object_type == "VIEW"
     assert all_tables[view.full_name].view_text == "SELECT 2+2 AS four"
+    assert all_tables[view.full_name].what == What.VIEW
72 changes: 72 additions & 0 deletions tests/unit/hive_metastore/test_table_migrate.py
@@ -16,6 +16,7 @@
     MigrationCount,
     Table,
     TablesCrawler,
+    What,
 )
 
 from ..framework.mocks import MockBackend
@@ -66,6 +67,25 @@ def test_migrate_dbfs_root_tables_should_produce_proper_queries():
     )
 
 
+def test_migrate_dbfs_root_tables_should_be_skipped_when_upgrading_external():
+    errors = {}
+    rows = {}
+    backend = MockBackend(fails_on_first=errors, rows=rows)
+    table_crawler = TablesCrawler(backend, "inventory_database")
+    client = MagicMock()
+    table_mapping = create_autospec(TableMapping)
+    table_mapping.get_tables_to_migrate.return_value = [
+        TableToMigrate(
+            Table("hive_metastore", "db1_src", "managed_dbfs", "MANAGED", "DELTA", "dbfs:/some_location"),
+            Rule("workspace", "ucx_default", "db1_src", "db1_dst", "managed_dbfs", "managed_dbfs"),
+        ),
+    ]
+    table_migrate = TablesMigrate(table_crawler, client, backend, table_mapping)
+    table_migrate.migrate_tables(what=What.EXTERNAL_SYNC)
+
+    assert len(backend.queries) == 0
+
+
 def test_migrate_external_tables_should_produce_proper_queries():
     errors = {}
     rows = {}
@@ -87,6 +107,58 @@ def test_migrate_external_tables_should_produce_proper_queries():
     ]
 
 
+def test_migrate_already_upgraded_table_should_produce_no_queries():
+    errors = {}
+    rows = {}
+    backend = MockBackend(fails_on_first=errors, rows=rows)
+    table_crawler = TablesCrawler(backend, "inventory_database")
+    client = create_autospec(WorkspaceClient)
+    client.catalogs.list.return_value = [CatalogInfo(name="cat1")]
+    client.schemas.list.return_value = [
+        SchemaInfo(catalog_name="cat1", name="test_schema1"),
+    ]
+    client.tables.list.return_value = [
+        TableInfo(
+            catalog_name="cat1",
+            schema_name="schema1",
+            name="dest1",
+            full_name="cat1.schema1.dest1",
+            properties={"upgraded_from": "hive_metastore.db1_src.external_src"},
+        ),
+    ]
+
+    table_mapping = create_autospec(TableMapping)
+    table_mapping.get_tables_to_migrate.return_value = [
+        TableToMigrate(
+            Table("hive_metastore", "db1_src", "external_src", "EXTERNAL", "DELTA"),
+            Rule("workspace", "cat1", "db1_src", "schema1", "external_src", "dest1"),
+        )
+    ]
+    table_migrate = TablesMigrate(table_crawler, client, backend, table_mapping)
+    table_migrate.migrate_tables()
+
+    assert len(backend.queries) == 0
+
+
+def test_migrate_unsupported_format_table_should_produce_no_queries():
+    errors = {}
+    rows = {}
+    backend = MockBackend(fails_on_first=errors, rows=rows)
+    table_crawler = TablesCrawler(backend, "inventory_database")
+    client = create_autospec(WorkspaceClient)
+    table_mapping = create_autospec(TableMapping)
+    table_mapping.get_tables_to_migrate.return_value = [
+        TableToMigrate(
+            Table("hive_metastore", "db1_src", "external_src", "EXTERNAL", "UNSUPPORTED_FORMAT"),
+            Rule("workspace", "cat1", "db1_src", "schema1", "external_src", "dest1"),
+        )
+    ]
+    table_migrate = TablesMigrate(table_crawler, client, backend, table_mapping)
+    table_migrate.migrate_tables()
+
+    assert len(backend.queries) == 0
+
+
 def test_migrate_view_should_produce_proper_queries():
     errors = {}
     rows = {}
142 changes: 92 additions & 50 deletions tests/unit/hive_metastore/test_tables.py
@@ -1,6 +1,6 @@
 import pytest
 
-from databricks.labs.ucx.hive_metastore.tables import Table, TablesCrawler
+from databricks.labs.ucx.hive_metastore.tables import Table, TablesCrawler, What
 
 from ..framework.mocks import MockBackend
 
@@ -136,52 +136,94 @@ def test_tables_returning_error_when_describing():
     assert len(results) == 1
 
 
-def test_is_dbfs_root():
-    assert Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/somelocation/tablename").is_dbfs_root
-    assert Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/somelocation/tablename").is_dbfs_root
-    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/mnt/somelocation/tablename").is_dbfs_root
-    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/mnt/somelocation/tablename").is_dbfs_root
-    assert not Table(
-        "a", "b", "c", "MANAGED", "DELTA", location="dbfs:/databricks-datasets/somelocation/tablename"
-    ).is_dbfs_root
-    assert not Table(
-        "a", "b", "c", "MANAGED", "DELTA", location="/dbfs/databricks-datasets/somelocation/tablename"
-    ).is_dbfs_root
-    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="s3:/somelocation/tablename").is_dbfs_root
-    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="adls:/somelocation/tablename").is_dbfs_root
-
-
-def test_is_db_dataset():
-    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/somelocation/tablename").is_databricks_dataset
-    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/somelocation/tablename").is_databricks_dataset
-    assert not Table(
-        "a", "b", "c", "MANAGED", "DELTA", location="dbfs:/mnt/somelocation/tablename"
-    ).is_databricks_dataset
-    assert not Table(
-        "a", "b", "c", "MANAGED", "DELTA", location="/dbfs/mnt/somelocation/tablename"
-    ).is_databricks_dataset
-    assert Table(
-        "a", "b", "c", "MANAGED", "DELTA", location="dbfs:/databricks-datasets/somelocation/tablename"
-    ).is_databricks_dataset
-    assert Table(
-        "a", "b", "c", "MANAGED", "DELTA", location="/dbfs/databricks-datasets/somelocation/tablename"
-    ).is_databricks_dataset
-    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="s3:/somelocation/tablename").is_databricks_dataset
-    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="adls:/somelocation/tablename").is_databricks_dataset
-
-
-def test_is_supported_for_sync():
-    assert Table(
-        "a", "b", "c", "EXTERNAL", "DELTA", location="dbfs:/somelocation/tablename"
-    ).is_format_supported_for_sync
-    assert Table("a", "b", "c", "EXTERNAL", "CSV", location="dbfs:/somelocation/tablename").is_format_supported_for_sync
-    assert Table(
-        "a", "b", "c", "EXTERNAL", "TEXT", location="dbfs:/somelocation/tablename"
-    ).is_format_supported_for_sync
-    assert Table("a", "b", "c", "EXTERNAL", "ORC", location="dbfs:/somelocation/tablename").is_format_supported_for_sync
-    assert Table(
-        "a", "b", "c", "EXTERNAL", "JSON", location="dbfs:/somelocation/tablename"
-    ).is_format_supported_for_sync
-    assert not (
-        Table("a", "b", "c", "EXTERNAL", "AVRO", location="dbfs:/somelocation/tablename").is_format_supported_for_sync
-    )
+@pytest.mark.parametrize(
+    'table,dbfs_root,what',
+    [
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/somelocation/tablename"), True, What.DBFS_ROOT_DELTA),
+        (
+            Table("a", "b", "c", "MANAGED", "PARQUET", location="dbfs:/somelocation/tablename"),
+            True,
+            What.DBFS_ROOT_NON_DELTA,
+        ),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/somelocation/tablename"), True, What.DBFS_ROOT_DELTA),
+        (
+            Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/mnt/somelocation/tablename"),
+            False,
+            What.EXTERNAL_SYNC,
+        ),
+        (
+            Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/mnt/somelocation/tablename"),
+            False,
+            What.EXTERNAL_SYNC,
+        ),
+        (
+            Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/databricks-datasets/somelocation/tablename"),
+            False,
+            What.DB_DATASET,
+        ),
+        (
+            Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/databricks-datasets/somelocation/tablename"),
+            False,
+            What.DB_DATASET,
+        ),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="s3:/somelocation/tablename"), False, What.EXTERNAL_SYNC),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="adls:/somelocation/tablename"), False, What.EXTERNAL_SYNC),
+    ],
+)
+def test_is_dbfs_root(table, dbfs_root, what):
+    assert table.is_dbfs_root == dbfs_root
+    assert table.what == what
+
+
+@pytest.mark.parametrize(
+    'table,db_dataset',
+    [
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/somelocation/tablename"), False),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/somelocation/tablename"), False),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/mnt/somelocation/tablename"), False),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/mnt/somelocation/tablename"), False),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/databricks-datasets/somelocation/tablename"), True),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/databricks-datasets/somelocation/tablename"), True),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="s3:/somelocation/tablename"), False),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="adls:/somelocation/tablename"), False),
+    ],
+)
+def test_is_db_dataset(table, db_dataset):
+    assert table.is_databricks_dataset == db_dataset
+    assert (table.what == What.DB_DATASET) == db_dataset
+
+
+@pytest.mark.parametrize(
+    'table,supported',
+    [
+        (Table("a", "b", "c", "EXTERNAL", "DELTA", location="dbfs:/somelocation/tablename"), True),
+        (Table("a", "b", "c", "EXTERNAL", "CSV", location="dbfs:/somelocation/tablename"), True),
+        (Table("a", "b", "c", "EXTERNAL", "TEXT", location="dbfs:/somelocation/tablename"), True),
+        (Table("a", "b", "c", "EXTERNAL", "ORC", location="dbfs:/somelocation/tablename"), True),
+        (Table("a", "b", "c", "EXTERNAL", "JSON", location="dbfs:/somelocation/tablename"), True),
+        (Table("a", "b", "c", "EXTERNAL", "AVRO", location="dbfs:/somelocation/tablename"), False),
+    ],
+)
+def test_is_supported_for_sync(table, supported):
+    assert table.is_format_supported_for_sync == supported
+
+
+@pytest.mark.parametrize(
+    'table,what',
+    [
+        (Table("a", "b", "c", "EXTERNAL", "DELTA", location="s3://external_location/table"), What.EXTERNAL_SYNC),
+        (
+            Table("a", "b", "c", "EXTERNAL", "UNSUPPORTED_FORMAT", location="s3://external_location/table"),
+            What.EXTERNAL_NO_SYNC,
+        ),
+        (Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/somelocation/tablename"), What.DBFS_ROOT_DELTA),
+        (Table("a", "b", "c", "MANAGED", "PARQUET", location="dbfs:/somelocation/tablename"), What.DBFS_ROOT_NON_DELTA),
+        (Table("a", "b", "c", "VIEW", "VIEW", view_text="select * from some_table"), What.VIEW),
+        (
+            Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/databricks-datasets/somelocation/tablename"),
+            What.DB_DATASET,
+        ),
+    ],
+)
+def test_table_what(table, what):
+    assert table.what == what
