
Ensure that USE statements are recognized and apply to table references without a qualifying schema in SQL and pyspark #1433

Merged: 12 commits, Apr 22, 2024
3 changes: 2 additions & 1 deletion src/databricks/labs/ucx/hive_metastore/view_migrate.py
@@ -8,6 +8,7 @@

from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex, TableView
from databricks.labs.ucx.hive_metastore.mapping import TableToMigrate
from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.queries import FromTable

logger = logging.getLogger(__name__)
@@ -41,7 +42,7 @@ def _view_dependencies(self):
yield TableView("hive_metastore", src_db, old_table.name)

def sql_migrate_view(self, index: MigrationIndex) -> str:
from_table = FromTable(index, use_schema=self.src.database)
from_table = FromTable(index, CurrentSessionState(self.src.database))
assert self.src.view_text is not None, 'Expected a view text'
migrated_select = from_table.apply(self.src.view_text)
statements = sqlglot.parse(migrated_select, read='databricks')
19 changes: 19 additions & 0 deletions src/databricks/labs/ucx/source_code/base.py
@@ -82,6 +82,25 @@ def name(self) -> str: ...
def apply(self, code: str) -> str: ...


# The default schema to use when the schema is not specified in a table reference
# See: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-qry-select-usedb.html
DEFAULT_SCHEMA = 'default'


@dataclass
class CurrentSessionState:
"""
A data class that represents the current state of a session.

This class can be used to track various aspects of a session, such as the current schema.

Attributes:
schema (str): The current schema of the session. If not provided, it defaults to DEFAULT_SCHEMA ('default').
"""

schema: str = DEFAULT_SCHEMA


class SequentialLinter(Linter):
def __init__(self, linters: list[Linter]):
self._linters = linters
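A minimal sketch of how the new dataclass is meant to be used (the call sites below are assumptions drawn from the rest of this diff, not code from the PR itself):

from databricks.labs.ucx.source_code.base import CurrentSessionState

session_state = CurrentSessionState()         # schema defaults to 'default'
session_state = CurrentSessionState('sales')  # or start the session in an explicit schema
session_state.schema = 'finance'              # linters/fixers overwrite this when they encounter `USE finance`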
2 changes: 1 addition & 1 deletion src/databricks/labs/ucx/source_code/dbfs.py
@@ -91,7 +91,7 @@ def name() -> str:
return 'dbfs-query'

def lint(self, code: str) -> Iterable[Advice]:
for statement in sqlglot.parse(code, dialect='databricks'):
for statement in sqlglot.parse(code, read='databricks'):
if not statement:
continue
for table in statement.find_all(Table):
5 changes: 3 additions & 2 deletions src/databricks/labs/ucx/source_code/languages.py
@@ -1,7 +1,7 @@
from databricks.sdk.service.workspace import Language

from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter
from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter, CurrentSessionState
from databricks.labs.ucx.source_code.pyspark import SparkSql
from databricks.labs.ucx.source_code.queries import FromTable
from databricks.labs.ucx.source_code.dbfs import DBFSUsageLinter, FromDbfsFolder
@@ -11,7 +11,8 @@
class Languages:
def __init__(self, index: MigrationIndex):
self._index = index
from_table = FromTable(index)
session_state = CurrentSessionState()
from_table = FromTable(index, session_state=session_state)
dbfs_from_folder = FromDbfsFolder()
self._linters = {
Language.PYTHON: SequentialLinter(
2 changes: 2 additions & 0 deletions src/databricks/labs/ucx/source_code/notebook.py
@@ -243,6 +243,8 @@ def requires_isolated_pi(self) -> str:

@classmethod
def of_language(cls, language: Language) -> CellLanguage:
# TODO: Should this not raise a ValueError if the language is not found?
# It also causes a GeneratorExit exception to be raised. Maybe an explicit loop is better.
return next((cl for cl in CellLanguage if cl.language == language))

@classmethod
4 changes: 3 additions & 1 deletion src/databricks/labs/ucx/source_code/notebook_linter.py
@@ -1,5 +1,6 @@
from collections.abc import Iterable

from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import Advice
from databricks.labs.ucx.source_code.notebook import Notebook
from databricks.labs.ucx.source_code.languages import Languages, Language
@@ -16,7 +17,8 @@ def __init__(self, langs: Languages, notebook: Notebook):
self._notebook: Notebook = notebook

@classmethod
def from_source(cls, langs: Languages, source: str, default_language: Language) -> 'NotebookLinter':
def from_source(cls, index: MigrationIndex, source: str, default_language: Language) -> 'NotebookLinter':
langs = Languages(index)
notebook = Notebook.parse("", source, default_language)
assert notebook is not None
return cls(langs, notebook)
9 changes: 6 additions & 3 deletions src/databricks/labs/ucx/source_code/pyspark.py
@@ -79,7 +79,7 @@
def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
table_arg = self._get_table_arg(node)
if isinstance(table_arg, ast.Constant):
dst = self._find_dest(index, table_arg.value)
dst = self._find_dest(index, table_arg.value, from_table.schema)
if dst is not None:
yield Deprecation(
code='table-migrate',
@@ -104,13 +104,16 @@
def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
table_arg = self._get_table_arg(node)
assert isinstance(table_arg, ast.Constant)
dst = self._find_dest(index, table_arg.value)
dst = self._find_dest(index, table_arg.value, from_table.schema)

[Codecov / codecov/patch: added line src/databricks/labs/ucx/source_code/pyspark.py#L107 was not covered by tests]
if dst is not None:
table_arg.value = dst.destination()

@staticmethod
def _find_dest(index: MigrationIndex, value: str):
def _find_dest(index: MigrationIndex, value: str, schema: str):
parts = value.split(".")
# Ensure that unqualified table references use the current schema
if len(parts) == 1:
return index.get(schema, parts[0])
return None if len(parts) != 2 else index.get(parts[0], parts[1])


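To illustrate the new schema-aware lookup, here is a hedged sketch that mirrors the patched helper; the fixture entry is taken from the extended test index added later in this diff, and the standalone `_find_dest` below is a copy for illustration only, not the class method itself:

from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex, MigrationStatus

index = MigrationIndex([
    MigrationStatus('default', 'testtable', dst_catalog='cata', dst_schema='nondefault', dst_table='table'),
])

def _find_dest(index, value, schema):
    # mirror of the patched helper, for illustration only
    parts = value.split(".")
    if len(parts) == 1:
        return index.get(schema, parts[0])
    return None if len(parts) != 2 else index.get(parts[0], parts[1])

assert _find_dest(index, 'testtable', 'default') is not None          # unqualified -> session schema
assert _find_dest(index, 'default.testtable', 'ignored') is not None  # two-part -> explicit schema
assert _find_dest(index, 'cat.db.tbl', 'default') is None             # three-part names are left alone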
1 change: 0 additions & 1 deletion src/databricks/labs/ucx/source_code/python_linter.py
@@ -8,7 +8,6 @@

from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory


logger = logging.getLogger(__name__)


67 changes: 52 additions & 15 deletions src/databricks/labs/ucx/source_code/queries.py
@@ -2,39 +2,71 @@

import logging
import sqlglot
from sqlglot.expressions import Table, Expression
from sqlglot.expressions import Table, Expression, Use
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter
from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter, CurrentSessionState

logger = logging.getLogger(__name__)


class FromTable(Linter, Fixer):
def __init__(self, index: MigrationIndex, *, use_schema: str | None = None):
self._index = index
self._use_schema = use_schema
"""Linter and Fixer for table migrations in SQL queries.

This class is responsible for identifying and fixing table migrations in
SQL queries.
"""

def __init__(self, index: MigrationIndex, session_state: CurrentSessionState):
"""
Initializes the FromTable class.

Args:
index: The migration index, which is a mapping of source tables to destination tables.
session_state: The current session state, which will be used to track the current schema.

We need to be careful with the nomenclature here. For instance when parsing a table reference,
sqlglot uses `db` instead of `schema` to refer to the schema. The following table references
show how sqlglot represents them::

catalog.schema.table -> Table(catalog='catalog', db='schema', this='table')
schema.table -> Table(catalog='', db='schema', this='table')
table -> Table(catalog='', db='', this='table')
"""
self._index: MigrationIndex = index
self._session_state: CurrentSessionState = session_state if session_state else CurrentSessionState()

def name(self) -> str:
return 'table-migrate'

@property
def schema(self):
return self._session_state.schema

def lint(self, code: str) -> Iterable[Advice]:
for statement in sqlglot.parse(code, dialect='databricks'):
for statement in sqlglot.parse(code, read='databricks'):
if not statement:
continue
for table in statement.find_all(Table):
catalog = self._catalog(table)
if catalog != 'hive_metastore':
if isinstance(statement, Use):
# Sqlglot captures the database name in the Use statement as a Table, with
# the schema as the table name.
self._session_state.schema = table.name
continue
src_db = table.db if table.db else self._use_schema
if not src_db:

# we only migrate tables in the hive_metastore catalog
if self._catalog(table) != 'hive_metastore':
continue
# Sqlglot uses db instead of schema, watch out for that
src_schema = table.db if table.db else self._session_state.schema
if not src_schema:
logger.error(f"Could not determine schema for table {table.name}")
continue
dst = self._index.get(src_db, table.name)
dst = self._index.get(src_schema, table.name)
if not dst:
continue
yield Deprecation(
code='table-migrate',
message=f"Table {table.db}.{table.name} is migrated to {dst.destination()} in Unity Catalog",
message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
start_line=0,
start_col=0,
Expand All @@ -53,12 +85,17 @@ def apply(self, code: str) -> str:
for statement in sqlglot.parse(code, read='databricks'):
if not statement:
continue
if isinstance(statement, Use):
table = statement.this
self._session_state.schema = table.name
new_statements.append(statement.sql('databricks'))
continue
for old_table in self._dependent_tables(statement):
src_db = old_table.db if old_table.db else self._use_schema
if not src_db:
src_schema = old_table.db if old_table.db else self._session_state.schema
if not src_schema:
logger.error(f"Could not determine schema for table {old_table.name}")
continue
dst = self._index.get(src_db, old_table.name)
dst = self._index.get(src_schema, old_table.name)
if not dst:
continue
new_table = Table(catalog=dst.dst_catalog, db=dst.dst_schema, this=dst.dst_table)
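Putting the pieces together, a minimal sketch of the behaviour this file now supports; the index entry mirrors the extended fixture below, and the snippet is illustrative rather than a test copied from the PR:

from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex, MigrationStatus
from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.queries import FromTable

index = MigrationIndex([
    MigrationStatus('different_db', 'testtable', dst_catalog='cata2', dst_schema='newspace', dst_table='table'),
])
from_table = FromTable(index, CurrentSessionState())
# The USE statement switches the session schema, so the unqualified reference that follows
# is resolved as hive_metastore.different_db.testtable and rewritten to its UC destination.
migrated = from_table.apply("USE different_db; SELECT * FROM testtable")
print(migrated)  # the SELECT now references cata2.newspace.table instead of the bare testtable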
21 changes: 21 additions & 0 deletions tests/unit/source_code/conftest.py
@@ -19,3 +19,24 @@ def migration_index():
MigrationStatus('other', 'matters', dst_catalog='some', dst_schema='certain', dst_table='issues'),
]
)


@pytest.fixture
def extended_test_index():
return MigrationIndex(
[
MigrationStatus('old', 'things', dst_catalog='brand', dst_schema='new', dst_table='stuff'),
MigrationStatus('other', 'matters', dst_catalog='some', dst_schema='certain', dst_table='issues'),
MigrationStatus('old', 'stuff', dst_catalog='brand', dst_schema='new', dst_table='things'),
MigrationStatus('other', 'issues', dst_catalog='some', dst_schema='certain', dst_table='matters'),
MigrationStatus('default', 'testtable', dst_catalog='cata', dst_schema='nondefault', dst_table='table'),
MigrationStatus('different_db', 'testtable', dst_catalog='cata2', dst_schema='newspace', dst_table='table'),
MigrationStatus('old', 'testtable', dst_catalog='cata3', dst_schema='newspace', dst_table='table'),
MigrationStatus('default', 'people', dst_catalog='cata4', dst_schema='nondefault', dst_table='newpeople'),
MigrationStatus(
'something', 'persons', dst_catalog='cata4', dst_schema='newsomething', dst_table='persons'
),
MigrationStatus('whatever', 'kittens', dst_catalog='cata4', dst_schema='felines', dst_table='toms'),
MigrationStatus('whatever', 'numbers', dst_catalog='cata4', dst_schema='counting', dst_table='numbers'),
]
)
3 changes: 2 additions & 1 deletion tests/unit/source_code/test_notebook.py
@@ -40,7 +40,8 @@
SQL_NOTEBOOK_SAMPLE = (
"chf-pqi-scoring.sql.txt",
Language.SQL,
['md', 'sql', 'sql', 'md', 'sql', 'python', 'sql', 'sql', 'sql', 'md', 'sql', 'sql', 'md', 'sql', 'sql', 'md', 'sql'],
['md', 'sql', 'sql', 'md', 'sql', 'python', 'sql', 'sql', 'sql', 'md', 'sql',
'sql', 'md', 'sql', 'sql', 'md', 'sql'],
)
SHELL_NOTEBOOK_SAMPLE = (
"notebook-with-shell-cell.py.txt",