Ensure proper sequencing of view migrations (#1157)
## Changes
Add a `views_migrator` module and corresponding test cases.
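
For context, a minimal usage sketch based on the new unit tests (the `sql_backend` variable, the `"inventory"` schema name, and the database list are placeholders, not part of this PR):

```python
from databricks.labs.ucx.hive_metastore import TablesCrawler
from databricks.labs.ucx.hive_metastore.views_migrator import ViewsMigrator

# sql_backend is any configured SqlBackend; "inventory" and ["db1"] are placeholder values
crawler = TablesCrawler(sql_backend, "inventory", ["db1"])
migrator = ViewsMigrator(crawler)

# sequence() returns the crawled views ordered so that every view comes after the
# views it reads from; it raises ValueError on invalid view SQL, unknown schema
# objects, or circular view references
for view in migrator.sequence():
    print(view.key)  # e.g. hive_metastore.db1.v1 before hive_metastore.db1.v4
```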

### Linked issues
Resolves #1132

### Functionality 

- [ ] added relevant user documentation
- [ ] added new CLI command
- [ ] modified existing command: `databricks labs ucx ...`
- [ ] added a new workflow
- [ ] modified existing workflow: `...`
- [ ] added a new table
- [ ] modified existing table: `...`

### Tests

- [ ] manually tested
- [x] added unit tests
- [ ] added integration tests
- [ ] verified on staging environment (screenshot attached)

---------

Co-authored-by: Ganesh Girase <ganeshgirase@gmail.com>
ericvergnaud and ganeshgirase authored Mar 29, 2024
1 parent f775a99 commit 8a76764
Showing 4 changed files with 358 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/databricks/labs/ucx/hive_metastore/tables.py
@@ -72,6 +72,9 @@ def is_delta(self) -> bool:
    def key(self) -> str:
        return f"{self.catalog}.{self.database}.{self.name}".lower()

    def __hash__(self):
        return hash(self.key)

    @property
    def kind(self) -> str:
        return "VIEW" if self.view_text is not None else "TABLE"
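
The added `__hash__` is what lets `Table` instances be collected into sets, which the new `ViewsMigrator` below relies on to track processed views. A minimal sketch, assuming `Table` keeps its field-based equality (the column values are illustrative only):

```python
from databricks.labs.ucx.hive_metastore.tables import Table

# equal Table values hash to the same lower-cased key and deduplicate in a set
t1 = Table("hive_metastore", "db1", "t1", "TABLE", "DELTA")
t2 = Table("hive_metastore", "db1", "t1", "TABLE", "DELTA")
assert hash(t1) == hash("hive_metastore.db1.t1")
assert len({t1, t2}) == 1
```
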
123 changes: 123 additions & 0 deletions src/databricks/labs/ucx/hive_metastore/views_migrator.py
@@ -0,0 +1,123 @@
import sqlglot
from sqlglot import ParseError
from sqlglot.expressions import Expression as SqlExpression
from sqlglot.expressions import Table as SqlTable

from databricks.labs.ucx.hive_metastore import TablesCrawler
from databricks.labs.ucx.hive_metastore.tables import Table


class ViewToMigrate:

    _view: Table
    _table_dependencies: list[Table] | None
    _view_dependencies: list[Table] | None

    def __init__(self, table: Table):
        if table.view_text is None:
            raise RuntimeError("Should never get here! A view must have 'view_text'!")
        self._view = table
        self._table_dependencies = None
        self._view_dependencies = None

    @property
    def view(self):
        return self._view

    def view_dependencies(self, all_tables: dict[str, Table]) -> list[Table]:
        if self._table_dependencies is None or self._view_dependencies is None:
            self._compute_dependencies(all_tables)
            assert self._view_dependencies is not None
        return self._view_dependencies

    def _compute_dependencies(self, all_tables: dict[str, Table]):
        table_dependencies = set()
        view_dependencies = set()
        statement = self._parse_view_text()
        for sql_table in statement.find_all(SqlTable):
            catalog = self._catalog(sql_table)
            if catalog != 'hive_metastore':
                continue
            table_with_key = Table(catalog, sql_table.db, sql_table.name, "type", "")
            table = all_tables.get(table_with_key.key)
            if table is None:
                raise ValueError(
                    f"Unknown schema object: {table_with_key.key} in view SQL: {self._view.view_text} of table {self._view.key}"
                )
            if table.view_text is None:
                table_dependencies.add(table)
            else:
                view_dependencies.add(table)
        self._table_dependencies = list(table_dependencies)
        self._view_dependencies = list(view_dependencies)

    def _parse_view_text(self) -> SqlExpression:
        try:
            # the assert below can never fail, but it avoids a pylint error
            assert self._view.view_text is not None
            statements = sqlglot.parse(self._view.view_text)
            if len(statements) != 1 or statements[0] is None:
                raise ValueError(f"Could not analyze view SQL: {self._view.view_text} of table {self._view.key}")
            return statements[0]
        except ParseError as e:
            raise ValueError(f"Could not analyze view SQL: {self._view.view_text} of table {self._view.key}") from e

    # duplicated from FromTable._catalog; not sure whether it is worth factoring out
    @staticmethod
    def _catalog(table):
        if table.catalog:
            return table.catalog
        return 'hive_metastore'

    def __hash__(self):
        return hash(self._view)


class ViewsMigrator:

    def __init__(self, crawler: TablesCrawler):
        self._crawler = crawler
        self._result_view_list: list[ViewToMigrate] = []
        self._result_tables_set: set[Table] = set()

    def sequence(self) -> list[Table]:
        # sequencing is achieved using a very simple algorithm:
        # for each view, we register its dependencies (extracted from view_text);
        # then, given the remaining set of views to process and the growing set
        # of views already processed, we check whether each remaining view refers
        # to views that have not been processed yet; if it does not, it is safe
        # to add that view to the next batch of views.
        # for a set of views of size v and a dependency depth d, the complexity is roughly O(v^d);
        # this seems enormous, but in practice d remains small and v decreases rapidly
        table_list = self._crawler.snapshot()
        all_tables = {}
        views = set()
        for table in table_list:
            all_tables[table.key] = table
            if table.view_text is None:
                continue
            views.add(ViewToMigrate(table))
        while len(views) > 0:
            next_batch = self._next_batch(views, all_tables)
            self._result_view_list.extend(next_batch)
            self._result_tables_set.update([v.view for v in next_batch])
            views.difference_update(next_batch)
        return [v.view for v in self._result_view_list]

    def _next_batch(self, views: set[ViewToMigrate], all_tables: dict[str, Table]) -> set[ViewToMigrate]:
        # we can't (even slightly) optimize by special-casing len(views) == 0 or 1,
        # because we would lose the opportunity to check the SQL
        result: set[ViewToMigrate] = set()
        for view in views:
            view_deps = view.view_dependencies(all_tables)
            if len(view_deps) == 0:
                result.add(view)
            else:
                # does the view have at least one view dependency that is not yet processed?
                not_processed_yet = next((t for t in view_deps if t not in self._result_tables_set), None)
                if not_processed_yet is None:
                    result.add(view)
        # prevent an infinite loop
        if len(result) == 0 and len(views) > 0:
            raise ValueError(f"Circular view references are preventing migration: {views}")
        return result
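
As a worked example of the batching above, consider the dependency chain exercised by the fixture and tests added in this PR (a sketch; `crawler` stands for a `TablesCrawler` over db1 wired up as in the unit tests below):

```python
# db1.v4 selects from db1.v1, v7 from v4, v6 from v7, and v5 from v6 and v7,
# so the batches resolve to {v1}, {v4}, {v7}, {v6}, {v5} and the flattened
# sequence keeps every view after the views it reads from
sequence = ViewsMigrator(crawler).sequence()
assert [t.key for t in sequence] == [
    "hive_metastore.db1.v1",
    "hive_metastore.db1.v4",
    "hive_metastore.db1.v7",
    "hive_metastore.db1.v6",
    "hive_metastore.db1.v5",
]
```
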
79 changes: 79 additions & 0 deletions tests/unit/hive_metastore/tables/tables_and_views.json
@@ -0,0 +1,79 @@
[
{
"db": "db1",
"table": "t1"
},
{
"db": "db1",
"table": "t2"
},
{
"db": "db1",
"table": "t3"
},
{
"db": "db2",
"table": "t1"
},
{
"db": "db2",
"table": "t3"
},
{
"db": "db1",
"table": "v1",
"view_text": "select * from db1.t1"
},
{
"db": "db1",
"table": "v2",
"view_text": "select * from db1.t2 where db1.t2.c1 = 32"
},
{
"db": "db1",
"table": "v3",
"view_text": "select * from db1.t1, db1.t2 where db1.t1.c1 = db1.t2.c1"
},
{
"db": "db1",
"table": "v4",
"view_text": "select * from db1.v1"
},
{
"db": "db1",
"table": "v5",
"view_text": "select * from db1.v7, db1.v6"
},
{
"db": "db1",
"table": "v6",
"view_text": "select * from db1.v7"
},
{
"db": "db1",
"table": "v7",
"view_text": "select * from db1.v4"
},
{
"db": "db1",
"table": "v8",
"view_text": "123 invalid SQL"
},
{
"db": "db1",
"table": "v9",
"view_text": "select * from db15.t32"
},
{
"db": "db1",
"table": "v10",
"view_text": "select * from db1.v11"
},
{
"db": "db1",
"table": "v11",
"view_text": "select * from db1.v10"
}


]
153 changes: 153 additions & 0 deletions tests/unit/hive_metastore/test_views.py
@@ -0,0 +1,153 @@
import json
from pathlib import Path

import pytest
from databricks.labs.lsql.backends import MockBackend, SqlBackend

from databricks.labs.ucx.hive_metastore import TablesCrawler
from databricks.labs.ucx.hive_metastore.views_migrator import ViewsMigrator

SCHEMA_NAME = "schema"


def test_migrate_no_view_returns_empty_sequence():
    samples = Samples.load("db1.t1", "db2.t1")
    sql_backend = mock_backend(samples, "db1", "db2")
    crawler = TablesCrawler(sql_backend, SCHEMA_NAME, ["db1", "db2"])
    migrator = ViewsMigrator(crawler)
    sequence = migrator.sequence()
    assert len(sequence) == 0


def test_migrate_direct_view_returns_singleton_sequence() -> None:
    samples = Samples.load("db1.t1", "db1.v1")
    sql_backend = mock_backend(samples, "db1")
    crawler = TablesCrawler(sql_backend, SCHEMA_NAME, ["db1"])
    migrator = ViewsMigrator(crawler)
    sequence = migrator.sequence()
    assert len(sequence) == 1
    table = sequence[0]
    assert table.key == "hive_metastore.db1.v1"


def test_migrate_direct_views_returns_sequence() -> None:
    samples = Samples.load("db1.t1", "db1.v1", "db1.t2", "db1.v2")
    sql_backend = mock_backend(samples, "db1")
    crawler = TablesCrawler(sql_backend, SCHEMA_NAME, ["db1"])
    migrator = ViewsMigrator(crawler)
    sequence = migrator.sequence()
    assert len(sequence) == 2
    expected = {"hive_metastore.db1.v1", "hive_metastore.db1.v2"}
    actual = {t.key for t in sequence}
    assert expected == actual


def test_migrate_indirect_views_returns_correct_sequence() -> None:
    samples = Samples.load("db1.t1", "db1.v1", "db1.v4")
    sql_backend = mock_backend(samples, "db1")
    crawler = TablesCrawler(sql_backend, SCHEMA_NAME, ["db1"])
    migrator = ViewsMigrator(crawler)
    sequence = migrator.sequence()
    assert len(sequence) == 2
    expected = ["hive_metastore.db1.v1", "hive_metastore.db1.v4"]
    actual = [t.key for t in sequence]
    assert expected == actual


def test_migrate_deep_indirect_views_returns_correct_sequence() -> None:
    samples = Samples.load("db1.t1", "db1.v1", "db1.v4", "db1.v5", "db1.v6", "db1.v7")
    sql_backend = mock_backend(samples, "db1")
    crawler = TablesCrawler(sql_backend, SCHEMA_NAME, ["db1"])
    migrator = ViewsMigrator(crawler)
    sequence = migrator.sequence()
    assert len(sequence) == 5
    expected = [
        "hive_metastore.db1.v1",
        "hive_metastore.db1.v4",
        "hive_metastore.db1.v7",
        "hive_metastore.db1.v6",
        "hive_metastore.db1.v5",
    ]
    actual = [t.key for t in sequence]
    assert expected == actual


def test_migrate_invalid_sql_raises_value_error() -> None:
    with pytest.raises(ValueError) as error:
        samples = Samples.load("db1.v8")
        sql_backend = mock_backend(samples, "db1")
        crawler = TablesCrawler(sql_backend, SCHEMA_NAME, ["db1"])
        migrator = ViewsMigrator(crawler)
        sequence = migrator.sequence()
        assert sequence is None  # should never get here
    assert "Could not analyze view SQL:" in str(error)


def test_migrate_invalid_sql_tables_raises_value_error() -> None:
    with pytest.raises(ValueError) as error:
        samples = Samples.load("db1.v9")
        sql_backend = mock_backend(samples, "db1")
        crawler = TablesCrawler(sql_backend, SCHEMA_NAME, ["db1"])
        migrator = ViewsMigrator(crawler)
        sequence = migrator.sequence()
        assert sequence is None  # should never get here
    assert "Unknown schema object:" in str(error)


def test_migrate_circular_views_raises_value_error() -> None:
    with pytest.raises(ValueError) as error:
        samples = Samples.load("db1.v10", "db1.v11")
        sql_backend = mock_backend(samples, "db1")
        crawler = TablesCrawler(sql_backend, SCHEMA_NAME, ["db1"])
        migrator = ViewsMigrator(crawler)
        sequence = migrator.sequence()
        assert sequence is None  # should never get here
    assert "Circular view references are preventing migration:" in str(error)


def mock_backend(samples: list[dict], *dbnames: str) -> SqlBackend:
    db_rows: dict[str, list[tuple]] = {}
    select_query = 'SELECT \\* FROM hive_metastore.schema.tables'
    for dbname in dbnames:
        # pylint warning W0640 is a pylint bug (verified manually), see https://github.com/pylint-dev/pylint/issues/5263
        # pylint: disable=cell-var-from-loop
        valid_samples = list(filter(lambda s: s["db"] == dbname, samples))
        show_tuples = [(s["db"], s["table"], "true") for s in valid_samples]
        db_rows[f'SHOW TABLES FROM hive_metastore.{dbname}'] = show_tuples
        # catalog, database, table, object_type, table_format, location, view_text
        select_tuples = [
            (
                "hive_metastore",
                s["db"],
                s["table"],
                "type",
                "DELTA" if s.get("view_text", None) is None else "VIEW",
                None,
                s.get("view_text", None),
            )
            for s in valid_samples
        ]
        db_rows[select_query] = select_tuples
    return MockBackend(rows=db_rows)


class Samples:

    samples: dict = {}

    @classmethod
    def load(cls, *names: str):
        cls._preload_all()
        valid_keys = set(names)
        return [cls.samples[key] for key in filter(lambda key: key in valid_keys, cls.samples.keys())]

    @classmethod
    def _preload_all(cls):
        if len(cls.samples) == 0:
            path = Path(Path(__file__).parent, "tables", "tables_and_views.json")
            with open(path, encoding="utf-8") as file:
                samples = json.load(file)
                cls.samples = {}
                for sample in samples:
                    key = sample["db"] + "." + sample["table"]
                    cls.samples[key] = sample
