Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Python linter for table creation with implicit format. #1435

Merged
merged 6 commits into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/databricks/labs/ucx/source_code/languages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from databricks.labs.ucx.source_code.pyspark import SparkSql
from databricks.labs.ucx.source_code.queries import FromTable
from databricks.labs.ucx.source_code.dbfs import DBFSUsageLinter, FromDbfsFolder
from databricks.labs.ucx.source_code.table_creation import DBRv8d0Linter


class Languages:
Expand All @@ -13,7 +14,13 @@ def __init__(self, index: MigrationIndex):
from_table = FromTable(index)
dbfs_from_folder = FromDbfsFolder()
self._linters = {
Language.PYTHON: SequentialLinter([SparkSql(from_table, index), DBFSUsageLinter()]),
Language.PYTHON: SequentialLinter(
[
SparkSql(from_table, index),
DBFSUsageLinter(),
DBRv8d0Linter(dbr_version=None),
]
),
Language.SQL: SequentialLinter([from_table, dbfs_from_folder]),
}
self._fixers: dict[Language, list[Fixer]] = {
Expand Down
52 changes: 52 additions & 0 deletions src/databricks/labs/ucx/source_code/python_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,58 @@ def collect_appended_sys_paths(self):
visitor.visit(self._root)
return visitor.appended_paths

def extract_callchain(self) -> ast.Call | None:
"""If 'node' is an assignment or expression, extract its full call-chain (if it has one)"""
call = None
if isinstance(self._root, ast.Assign):
call = self._root.value
elif isinstance(self._root, ast.Expr):
call = self._root.value
if not isinstance(call, ast.Call):
call = None
return call

def extract_call_by_name(self, name: str) -> ast.Call | None:
"""Given a call-chain, extract its sub-call by method name (if it has one)"""
assert isinstance(self._root, ast.Call)
node = self._root
while True:
func = node.func
if not isinstance(func, ast.Attribute):
return None
if func.attr == name:
return node
if not isinstance(func.value, ast.Call):
return None
node = func.value

def args_count(self) -> int:
"""Count the number of arguments (positionals + keywords)"""
assert isinstance(self._root, ast.Call)
return len(self._root.args) + len(self._root.keywords)

def get_arg(
self,
arg_index: int | None,
arg_name: str | None,
) -> ast.expr | None:
"""Extract the call argument identified by an optional position or name (if it has one)"""
assert isinstance(self._root, ast.Call)
if arg_index is not None and len(self._root.args) > arg_index:
return self._root.args[arg_index]
if arg_name is not None:
arg = [kw.value for kw in self._root.keywords if kw.arg == arg_name]
if len(arg) == 1:
return arg[0]
return None

def is_none(self) -> bool:
"""Check if the given AST expression is the None constant"""
assert isinstance(self._root, ast.expr)
if not isinstance(self._root, ast.Constant):
return False
return self._root.value is None


class PythonLinter(Linter):

Expand Down
121 changes: 121 additions & 0 deletions src/databricks/labs/ucx/source_code/table_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

import ast
from collections.abc import Iterable, Iterator
from dataclasses import dataclass


from databricks.labs.ucx.source_code.python_linter import ASTLinter
from databricks.labs.ucx.source_code.base import (
Advice,
Linter,
)


@dataclass
class Position:
line: int
character: int


@dataclass
class Range:
start: Position
end: Position


@dataclass
class NoFormatPythonMatcher:
"""Matches Python AST nodes where tables are created with implicit format.
It performs only the matching, while linting / fixing are separated concerns.
"""

method_name: str
min_args: int
max_args: int
# some method_names accept 'format' as a direct (optional) argument:
format_arg_index: int | None = None
format_arg_name: str | None = None

def get_advice_span(self, node: ast.AST) -> Range | None:
# Check 1: retrieve full callchain:
callchain = ASTLinter(node).extract_callchain()
if callchain is None:
return None

# Check 2: check presence of the table-creating method call:
call = ASTLinter(callchain).extract_call_by_name(self.method_name)
if call is None:
return None
call_args_count = ASTLinter(call).args_count()
if call_args_count < self.min_args or call_args_count > self.max_args:
return None

# Check 3: check presence of the format specifier:
# Option A: format specifier may be given as a direct parameter to the table-creating call
# example: df.saveToTable("c.db.table", format="csv")
format_arg = ASTLinter(call).get_arg(self.format_arg_index, self.format_arg_name)
if format_arg is not None and not ASTLinter(format_arg).is_none():
# i.e., found an explicit "format" argument, and its value is not None.
return None
# Option B. format specifier may be a separate ".format(...)" call in this callchain
# example: df.format("csv").saveToTable("c.db.table")
format_call = ASTLinter(callchain).extract_call_by_name("format")
if format_call is not None:
JCZuurmond marked this conversation as resolved.
Show resolved Hide resolved
# i.e., found an explicit ".format(...)" call in this chain.
return None

# Finally: matched the need for advice, so return the corresponding source range:
return Range(
Position(call.lineno, call.col_offset),
Position(call.end_lineno or 0, call.end_col_offset or 0),
)


class NoFormatPythonLinter:
JCZuurmond marked this conversation as resolved.
Show resolved Hide resolved
"""Python linting for table-creation with implicit format"""

def __init__(self, matchers: list[NoFormatPythonMatcher]):
self._matchers = matchers

def lint(self, node: ast.AST) -> Iterator[Advice]:
for matcher in self._matchers:
span = matcher.get_advice_span(node)
if span is not None:
yield Advice(
code="table-migrate",
message="The default format changed in Databricks Runtime 8.0, from Parquet to Delta",
start_line=span.start.line,
start_col=span.start.character,
end_line=span.end.line,
end_col=span.end.character,
)


class DBRv8d0Linter(Linter):
"""Performs Python linting for backwards incompatible changes in DBR version 8.0.
Specifically, it yields advice for table-creation with implicit format.
"""

# https://docs.databricks.com/en/archive/runtime-release-notes/8.0.html#delta-is-now-the-default-format-when-a-format-is-not-specified

def __init__(self, dbr_version: tuple[int, int] | None):
version_cutoff = (8, 0)
self._skip_dbr = dbr_version is not None and dbr_version >= version_cutoff
JCZuurmond marked this conversation as resolved.
Show resolved Hide resolved

self._linter = NoFormatPythonLinter(
[
NoFormatPythonMatcher("writeTo", 1, 1),
NoFormatPythonMatcher("table", 1, 1),
NoFormatPythonMatcher("insertInto", 1, 2),
NoFormatPythonMatcher("saveAsTable", 1, 4, 2, "format"),
]
)

def lint(self, code: str) -> Iterable[Advice]:
if self._skip_dbr:
return

tree = ast.parse(code)
for node in ast.walk(tree):
yield from self._linter.lint(node)
100 changes: 100 additions & 0 deletions tests/unit/source_code/test_python_linter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

import ast
import pytest

from databricks.labs.ucx.source_code.python_linter import ASTLinter, PythonLinter


Expand Down Expand Up @@ -127,3 +132,98 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_alias():
linter = ASTLinter.parse(code)
appended = PythonLinter.list_appended_sys_paths(linter)
assert "relative_path" in [p.path for p in appended]


def get_statement_node(stmt: str) -> ast.stmt:
node = ast.parse(stmt)
return node.body[0]


@pytest.mark.parametrize("stmt", ["o.m1().m2().m3()", "a = o.m1().m2().m3()"])
def test_extract_callchain(migration_index, stmt):
node = get_statement_node(stmt)
act = ASTLinter(node).extract_callchain()
assert isinstance(act, ast.Call)
assert isinstance(act.func, ast.Attribute)
assert act.func.attr == "m3"


@pytest.mark.parametrize("stmt", ["a = 3", "[x+1 for x in xs]"])
def test_extract_callchain_none(migration_index, stmt):
node = get_statement_node(stmt)
act = ASTLinter(node).extract_callchain()
assert act is None


def test_extract_call_by_name(migration_index):
callchain = get_statement_node("o.m1().m2().m3()").value
act = ASTLinter(callchain).extract_call_by_name("m2")
assert isinstance(act, ast.Call)
assert isinstance(act.func, ast.Attribute)
assert act.func.attr == "m2"


def test_extract_call_by_name_none(migration_index):
callchain = get_statement_node("o.m1().m2().m3()").value
act = ASTLinter(callchain).extract_call_by_name("m5000")
assert act is None


@pytest.mark.parametrize(
"param",
[
{"stmt": "o.m1()", "arg_index": 1, "arg_name": "second", "expected": None},
{"stmt": "o.m1(3)", "arg_index": 1, "arg_name": "second", "expected": None},
{"stmt": "o.m1(first=3)", "arg_index": 1, "arg_name": "second", "expected": None},
{"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": None, "expected": None},
{"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": "second", "expected": None},
{"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": "second", "expected": 3},
{"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": None, "expected": 3},
{"stmt": "o.m1(first=4, second=3)", "arg_index": 1, "arg_name": "second", "expected": 3},
{"stmt": "o.m1(second=3, first=4)", "arg_index": 1, "arg_name": "second", "expected": 3},
{"stmt": "o.m1(second=3, first=4)", "arg_index": None, "arg_name": "second", "expected": 3},
{"stmt": "o.m1(second=3)", "arg_index": 1, "arg_name": "second", "expected": 3},
{"stmt": "o.m1(4, 3, 2)", "arg_index": 1, "arg_name": "second", "expected": 3},
],
)
def test_get_arg(migration_index, param):
call = get_statement_node(param["stmt"]).value
act = ASTLinter(call).get_arg(param["arg_index"], param["arg_name"])
if param["expected"] is None:
assert act is None
else:
assert isinstance(act, ast.Constant)
assert act.value == param["expected"]


@pytest.mark.parametrize(
"param",
[
{"stmt": "o.m1()", "expected": 0},
{"stmt": "o.m1(3)", "expected": 1},
{"stmt": "o.m1(first=3)", "expected": 1},
{"stmt": "o.m1(3, 3)", "expected": 2},
{"stmt": "o.m1(first=3, second=3)", "expected": 2},
{"stmt": "o.m1(3, second=3)", "expected": 2},
{"stmt": "o.m1(3, *b, **c, second=3)", "expected": 4},
],
)
def test_args_count(migration_index, param):
call = get_statement_node(param["stmt"]).value
act = ASTLinter(call).args_count()
assert param["expected"] == act


@pytest.mark.parametrize(
"param",
[
{"stmt": "a = x", "expected": False},
{"stmt": "a = 3", "expected": False},
{"stmt": "a = 'None'", "expected": False},
{"stmt": "a = None", "expected": True},
],
)
def test_is_none(migration_index, param):
val = get_statement_node(param["stmt"]).value
act = ASTLinter(val).is_none()
assert param["expected"] == act
Loading
Loading