From ed4c0e2f367f6f7fd6237db2b1893225bf2b8e0e Mon Sep 17 00:00:00 2001 From: Constantin Dumitrascu Date: Wed, 17 Apr 2024 12:52:23 -0700 Subject: [PATCH 1/4] Added Python linter for table creation with implicit format. Starting with DBR v8.0, the default table format changed from 'parquet' to 'delta'. This linter yields advice for situations where a 'writeTo', 'table', 'insertInto', or 'saveAsTable' method is invoked, and no '.format(...)' invocation is found in the same chain of calls. For 'saveAsTable' in particular, the alternative of passing the 'format' as a direct argument is also supported. --- .../labs/ucx/source_code/languages.py | 7 +- .../labs/ucx/source_code/python_ast_util.py | 69 +++++++++++ .../labs/ucx/source_code/table_creation.py | 108 ++++++++++++++++++ .../unit/source_code/test_python_ast_util.py | 67 +++++++++++ tests/unit/source_code/test_table_creation.py | 97 ++++++++++++++++ 5 files changed, 347 insertions(+), 1 deletion(-) create mode 100644 src/databricks/labs/ucx/source_code/python_ast_util.py create mode 100644 src/databricks/labs/ucx/source_code/table_creation.py create mode 100644 tests/unit/source_code/test_python_ast_util.py create mode 100644 tests/unit/source_code/test_table_creation.py diff --git a/src/databricks/labs/ucx/source_code/languages.py b/src/databricks/labs/ucx/source_code/languages.py index 30eb129403..e3e50ae809 100644 --- a/src/databricks/labs/ucx/source_code/languages.py +++ b/src/databricks/labs/ucx/source_code/languages.py @@ -5,6 +5,7 @@ from databricks.labs.ucx.source_code.pyspark import SparkSql from databricks.labs.ucx.source_code.queries import FromTable from databricks.labs.ucx.source_code.dbfs import DBFSUsageLinter, FromDbfsFolder +from databricks.labs.ucx.source_code.table_creation import DBRv8_0Linter class Languages: @@ -13,7 +14,11 @@ def __init__(self, index: MigrationIndex): from_table = FromTable(index) dbfs_from_folder = FromDbfsFolder() self._linters = { - Language.PYTHON: SequentialLinter([SparkSql(from_table, index), DBFSUsageLinter()]), + Language.PYTHON: SequentialLinter([ + SparkSql(from_table, index), + DBFSUsageLinter(), + DBRv8_0Linter(dbr_version=None), + ]), Language.SQL: SequentialLinter([from_table, dbfs_from_folder]), } self._fixers: dict[Language, list[Fixer]] = { diff --git a/src/databricks/labs/ucx/source_code/python_ast_util.py b/src/databricks/labs/ucx/source_code/python_ast_util.py new file mode 100644 index 0000000000..99a2b42bd2 --- /dev/null +++ b/src/databricks/labs/ucx/source_code/python_ast_util.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass + + +@dataclass +class Span: + """ Represents a (possibly multiline) source code span. 
""" + start_line: int + start_col: int + end_line: int + end_col: int + + +class AstUtil: + @staticmethod + def extract_callchain(node: ast.AST) -> ast.Call | None: + """ If 'node' is an assignment or expression, extract its full call-chain (if it has one) """ + call = None + if isinstance(node, ast.Assign): + call = node.value + elif isinstance(node, ast.Expr): + call = node.value + if not isinstance(call, ast.Call): + call = None + return call + + @staticmethod + def extract_call_by_name(node: ast.Call, name: str) -> ast.Call | None: + """ Given a call-chain, extract its sub-call by method name (if it has one) """ + while True: + if not isinstance(node, ast.Call): + return None + + func = node.func + if not isinstance(func, ast.Attribute): + return None + if func.attr == name: + return node + node = func.value + + @staticmethod + def args_count(node: ast.Call) -> int: + """ Count the number of arguments (positionals + keywords) """ + return len(node.args) + len(node.keywords) + + @staticmethod + def get_arg( + node: ast.Call, + arg_index: int | None, + arg_name: str | None, + ) -> ast.expr | None: + """ Extract the call argument identified by an optional position or name (if it has one) """ + if arg_index is not None and len(node.args) > arg_index: + return node.args[arg_index] + if arg_name is not None: + arg = [kw.value for kw in node.keywords if kw.arg == arg_name] + if len(arg) == 1: + return arg[0] + return None + + @staticmethod + def is_none(node: ast.expr) -> bool: + """ Check if the given AST expression is the None constant """ + if not isinstance(node, ast.Constant): + return False + return node.value is None + diff --git a/src/databricks/labs/ucx/source_code/table_creation.py b/src/databricks/labs/ucx/source_code/table_creation.py new file mode 100644 index 0000000000..8024fb6f32 --- /dev/null +++ b/src/databricks/labs/ucx/source_code/table_creation.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import ast +from collections.abc import Iterable, Iterator +from dataclasses import dataclass + +from databricks.labs.ucx.source_code.python_ast_util import ( + AstUtil, + Span, +) +from databricks.labs.ucx.source_code.base import ( + Advice, + Linter, +) + + +@dataclass +class NoFormatPythonMatcher: + """ Matches Python AST nodes where tables are created with implicit format. + It performs only the matching, while linting / fixing are separated concerns. 
+ """ + method_name: str + min_args: int + max_args: int + # some method_names accept 'format' as a direct (optional) argument: + format_arg_index: int = None + format_arg_name: str = None + + def matches(self, node: ast.AST) -> bool: + return self.get_advice_span(node) is not None + + def get_advice_span(self, node: ast.AST) -> Span | None: + # retrieve full callchain: + callchain = AstUtil.extract_callchain(node) + if callchain is None: + return None + + # check presence of the table-creating method call: + call = AstUtil.extract_call_by_name(callchain, self.method_name) + if call is None: + return None + call_args_count = AstUtil.args_count(call) + if call_args_count < self.min_args or call_args_count > self.max_args: + return None + + # check presence of the format specifier: + format_arg = AstUtil.get_arg(call, self.format_arg_index, self.format_arg_name) + if format_arg is not None and not AstUtil.is_none(format_arg): + return None + format_call = AstUtil.extract_call_by_name(callchain, "format") + if format_call is not None: + return None + + # matched need for issuing advice: + return Span( + call.lineno, + call.col_offset, + call.end_lineno or 0, + call.end_col_offset or 0, + ) + + +@dataclass +class NoFormatPythonLinter(Linter): + """ Python linting for table-creation with implicit format """ + _matchers: list[NoFormatPythonMatcher] + + def matches(self, node: ast.AST) -> bool: + return any(m.matches(node) for m in self._matchers) + + def lint(self, node: ast.AST) -> Iterator[Advice]: + for matcher in self._matchers: + span = matcher.get_advice_span(node) + if span is not None: + yield Advice( + code="table-migrate", + message="The default format changed in Databricks Runtime 8.0, from Parquet to Delta", + start_line=span.start_line, + start_col=span.start_col, + end_line=span.end_line, + end_col=span.end_col, + ) + + +class DBRv8_0Linter(Linter): + """ Performs Python linting for backwards incompatible changes in DBR version 8.0. + Specifically, it yields advice for table-creation with implicit format. 
+ """ + # https://docs.databricks.com/en/archive/runtime-release-notes/8.0.html#delta-is-now-the-default-format-when-a-format-is-not-specified + + def __init__(self, dbr_version: tuple[int, int] | None): + version_cutoff = (8, 0) + self._skip_dbr = dbr_version is not None and dbr_version >= version_cutoff + + self._linter = NoFormatPythonLinter([ + NoFormatPythonMatcher("writeTo", 1, 1), + NoFormatPythonMatcher("table", 1, 1), + NoFormatPythonMatcher("insertInto", 1, 2), + NoFormatPythonMatcher("saveAsTable", 1, 4, 2, "format"), + ]) + + def lint(self, code: str) -> Iterable[Advice]: + if self._skip_dbr: + return + + tree = ast.parse(code) + for node in ast.walk(tree): + yield from self._linter.lint(node) diff --git a/tests/unit/source_code/test_python_ast_util.py b/tests/unit/source_code/test_python_ast_util.py new file mode 100644 index 0000000000..19540eb478 --- /dev/null +++ b/tests/unit/source_code/test_python_ast_util.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import pytest +import ast + +from databricks.labs.ucx.source_code.python_ast_util import AstUtil + + +def get_statement_node(stmt: str) -> ast.stmt: + node = ast.parse(stmt) + return node.body[0] + + +@pytest.mark.parametrize("stmt", ["o.m1().m2().m3()", "a = o.m1().m2().m3()"]) +def test_extract_callchain(migration_index, stmt): + node = get_statement_node(stmt) + act = AstUtil.extract_callchain(node) + assert isinstance(act, ast.Call) + assert isinstance(act.func, ast.Attribute) + assert "m3" == act.func.attr + + +@pytest.mark.parametrize("stmt", ["a = 3", "[x+1 for x in xs]"]) +def test_extract_callchain_none(migration_index, stmt): + node = get_statement_node(stmt) + act = AstUtil.extract_callchain(node) + assert act is None + + +def test_extract_call_by_name(migration_index): + callchain = get_statement_node("o.m1().m2().m3()").value + act = AstUtil.extract_call_by_name(callchain, "m2") + assert isinstance(act, ast.Call) + assert isinstance(act.func, ast.Attribute) + assert "m2" == act.func.attr + + +def test_extract_call_by_name_none(migration_index): + callchain = get_statement_node("o.m1().m2().m3()").value + act = AstUtil.extract_call_by_name(callchain, "m5000") + assert act is None + + +@pytest.mark.parametrize("param", [ + {"stmt": "o.m1()", "expected": 0}, + {"stmt": "o.m1(3)", "expected": 1}, + {"stmt": "o.m1(first=3)", "expected": 1}, + {"stmt": "o.m1(3, 3)", "expected": 2}, + {"stmt": "o.m1(first=3, second=3)", "expected": 2}, + {"stmt": "o.m1(3, second=3)", "expected": 2}, +]) +def test_args_count(migration_index, param): + call = get_statement_node(param["stmt"]).value + act = AstUtil.args_count(call) + assert param["expected"] == act + + +@pytest.mark.parametrize("param", [ + {"stmt": "a = x", "expected": False}, + {"stmt": "a = 3", "expected": False}, + {"stmt": "a = 'None'", "expected": False}, + {"stmt": "a = None", "expected": True}, +]) +def test_is_none(migration_index, param): + val = get_statement_node(param["stmt"]).value + act = AstUtil.is_none(val) + assert param["expected"] == act diff --git a/tests/unit/source_code/test_table_creation.py b/tests/unit/source_code/test_table_creation.py new file mode 100644 index 0000000000..f673e84209 --- /dev/null +++ b/tests/unit/source_code/test_table_creation.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import pytest + +from databricks.labs.ucx.source_code.base import Advice +from databricks.labs.ucx.source_code.table_creation import DBRv8_0Linter + + +METHOD_NAMES = [ + "writeTo", + "table", + "insertInto", + "saveAsTable", +] 
+ASSIGN = [True, False] +DBR_VERSIONS = [ # version as tuple of ints: (major, minor) + { "version": None, "suppress": False, }, + { "version": (7, 9), "suppress": False, }, + { "version": (8, 0), "suppress": True, }, + { "version": (9, 0), "suppress": True, }, +] + + +def get_code(assign: bool, stmt: str) -> str: + """ Return code snippet to be linted, with customizable statement """ + assign_str = 'df = ' if assign else '' + return f""" +spark.read.csv("s3://bucket/path") +for i in range(10): + {assign_str}{stmt} + do_stuff_with_df(df) +""" + + +def get_advice(assign: bool, method_name: str, args_len: int) -> Advice: + """ Repeated boilerplate Advice constructing """ + return Advice( + code="table-migrate", + message="The default format changed in Databricks Runtime 8.0, from Parquet to Delta", + start_line=4, + start_col=(9 if assign else 4), + end_line=4, + end_col=(29 if assign else 24) + len(method_name) + args_len, + ) + + +def lint( + code: str, + dbr_version: tuple[int, int] | None = (7, 9), +) -> list[Advice]: + """ Invoke linting for the given dbr version """ + return list(DBRv8_0Linter(dbr_version).lint(code)) + + +@pytest.mark.parametrize("method_name", METHOD_NAMES) +@pytest.mark.parametrize("assign", ASSIGN) +def test_has_format_call(migration_index, method_name, assign): + """ Tests that calling "format" doesn't yield advice """ + old_code = get_code(assign, f'spark.foo().format("delta").bar().{method_name}("catalog.db.table").baz()') + assert [] == lint(old_code) + + +@pytest.mark.parametrize("method_name", METHOD_NAMES) +@pytest.mark.parametrize("assign", ASSIGN) +def test_no_format(migration_index, method_name, assign): + """ Tests that not setting "format" yields advice (both in assignment or standalone callchain) """ + old_code = get_code(assign, f'spark.foo().bar().{method_name}("catalog.db.table").baz()') + assert [get_advice(assign, method_name, 18)] == lint(old_code) + + +@pytest.mark.parametrize("assign", ASSIGN) +def test_has_format_arg(migration_index, assign): + """ Tests that setting "format" positional arg doesn't yield advice """ + old_code = get_code(assign, f'spark.foo().format("delta").bar().saveAsTable("catalog.db.table", "csv").baz()') + assert [] == lint(old_code) + + +@pytest.mark.parametrize("assign", ASSIGN) +def test_has_format_kwarg(migration_index, assign): + """ Tests that setting "format" kwarg doesn't yield advice """ + old_code = get_code(assign, f'spark.foo().format("delta").bar().saveAsTable("catalog.db.table", format="csv").baz()') + assert [] == lint(old_code) + + +@pytest.mark.parametrize("assign", ASSIGN) +def test_has_format_arg_none(migration_index, assign): + """ Tests that explicitly setting "format" parameter to None yields advice """ + old_code = get_code(assign, f'spark.foo().bar().saveAsTable("catalog.db.table", format=None).baz()') + assert [get_advice(assign, "saveAsTable", 31)] == lint(old_code) + + +@pytest.mark.parametrize("dbr_version", DBR_VERSIONS) +def test_has_format_arg_none(migration_index, dbr_version): + """ Tests the DBR version cutoff filter """ + old_code = get_code(False, f'spark.foo().bar().table("catalog.db.table").baz()') + expected = [] if dbr_version["suppress"] else [get_advice(False, 'table', 18)] + assert expected == lint(old_code, dbr_version["version"]) From c1ac71c8934543e8400999667a228ecbab25d615 Mon Sep 17 00:00:00 2001 From: Constantin Dumitrascu Date: Thu, 18 Apr 2024 07:32:43 -0700 Subject: [PATCH 2/4] Fix formatting and boost coverage --- .../labs/ucx/source_code/languages.py | 14 ++-- 
.../labs/ucx/source_code/python_ast_util.py | 25 +++--- .../labs/ucx/source_code/table_creation.py | 37 +++++---- .../unit/source_code/test_python_ast_util.py | 67 ++++++++++++---- tests/unit/source_code/test_table_creation.py | 77 +++++++++++++------ 5 files changed, 140 insertions(+), 80 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/languages.py b/src/databricks/labs/ucx/source_code/languages.py index e3e50ae809..ee22222fc9 100644 --- a/src/databricks/labs/ucx/source_code/languages.py +++ b/src/databricks/labs/ucx/source_code/languages.py @@ -5,7 +5,7 @@ from databricks.labs.ucx.source_code.pyspark import SparkSql from databricks.labs.ucx.source_code.queries import FromTable from databricks.labs.ucx.source_code.dbfs import DBFSUsageLinter, FromDbfsFolder -from databricks.labs.ucx.source_code.table_creation import DBRv8_0Linter +from databricks.labs.ucx.source_code.table_creation import DBRv8d0Linter class Languages: @@ -14,11 +14,13 @@ def __init__(self, index: MigrationIndex): from_table = FromTable(index) dbfs_from_folder = FromDbfsFolder() self._linters = { - Language.PYTHON: SequentialLinter([ - SparkSql(from_table, index), - DBFSUsageLinter(), - DBRv8_0Linter(dbr_version=None), - ]), + Language.PYTHON: SequentialLinter( + [ + SparkSql(from_table, index), + DBFSUsageLinter(), + DBRv8d0Linter(dbr_version=None), + ] + ), Language.SQL: SequentialLinter([from_table, dbfs_from_folder]), } self._fixers: dict[Language, list[Fixer]] = { diff --git a/src/databricks/labs/ucx/source_code/python_ast_util.py b/src/databricks/labs/ucx/source_code/python_ast_util.py index 99a2b42bd2..af67b73d09 100644 --- a/src/databricks/labs/ucx/source_code/python_ast_util.py +++ b/src/databricks/labs/ucx/source_code/python_ast_util.py @@ -6,7 +6,8 @@ @dataclass class Span: - """ Represents a (possibly multiline) source code span. 
""" + """Represents a (possibly multiline) source code span.""" + start_line: int start_col: int end_line: int @@ -16,7 +17,7 @@ class Span: class AstUtil: @staticmethod def extract_callchain(node: ast.AST) -> ast.Call | None: - """ If 'node' is an assignment or expression, extract its full call-chain (if it has one) """ + """If 'node' is an assignment or expression, extract its full call-chain (if it has one)""" call = None if isinstance(node, ast.Assign): call = node.value @@ -28,30 +29,29 @@ def extract_callchain(node: ast.AST) -> ast.Call | None: @staticmethod def extract_call_by_name(node: ast.Call, name: str) -> ast.Call | None: - """ Given a call-chain, extract its sub-call by method name (if it has one) """ + """Given a call-chain, extract its sub-call by method name (if it has one)""" while True: - if not isinstance(node, ast.Call): - return None - func = node.func if not isinstance(func, ast.Attribute): return None if func.attr == name: return node + if not isinstance(func.value, ast.Call): + return None node = func.value @staticmethod def args_count(node: ast.Call) -> int: - """ Count the number of arguments (positionals + keywords) """ + """Count the number of arguments (positionals + keywords)""" return len(node.args) + len(node.keywords) @staticmethod def get_arg( - node: ast.Call, - arg_index: int | None, - arg_name: str | None, + node: ast.Call, + arg_index: int | None, + arg_name: str | None, ) -> ast.expr | None: - """ Extract the call argument identified by an optional position or name (if it has one) """ + """Extract the call argument identified by an optional position or name (if it has one)""" if arg_index is not None and len(node.args) > arg_index: return node.args[arg_index] if arg_name is not None: @@ -62,8 +62,7 @@ def get_arg( @staticmethod def is_none(node: ast.expr) -> bool: - """ Check if the given AST expression is the None constant """ + """Check if the given AST expression is the None constant""" if not isinstance(node, ast.Constant): return False return node.value is None - diff --git a/src/databricks/labs/ucx/source_code/table_creation.py b/src/databricks/labs/ucx/source_code/table_creation.py index 8024fb6f32..80b66ea2a7 100644 --- a/src/databricks/labs/ucx/source_code/table_creation.py +++ b/src/databricks/labs/ucx/source_code/table_creation.py @@ -16,18 +16,16 @@ @dataclass class NoFormatPythonMatcher: - """ Matches Python AST nodes where tables are created with implicit format. + """Matches Python AST nodes where tables are created with implicit format. It performs only the matching, while linting / fixing are separated concerns. 
""" + method_name: str min_args: int max_args: int # some method_names accept 'format' as a direct (optional) argument: - format_arg_index: int = None - format_arg_name: str = None - - def matches(self, node: ast.AST) -> bool: - return self.get_advice_span(node) is not None + format_arg_index: int | None = None + format_arg_name: str | None = None def get_advice_span(self, node: ast.AST) -> Span | None: # retrieve full callchain: @@ -61,12 +59,10 @@ def get_advice_span(self, node: ast.AST) -> Span | None: @dataclass -class NoFormatPythonLinter(Linter): - """ Python linting for table-creation with implicit format """ - _matchers: list[NoFormatPythonMatcher] +class NoFormatPythonLinter: + """Python linting for table-creation with implicit format""" - def matches(self, node: ast.AST) -> bool: - return any(m.matches(node) for m in self._matchers) + _matchers: list[NoFormatPythonMatcher] def lint(self, node: ast.AST) -> Iterator[Advice]: for matcher in self._matchers: @@ -82,22 +78,25 @@ def lint(self, node: ast.AST) -> Iterator[Advice]: ) -class DBRv8_0Linter(Linter): - """ Performs Python linting for backwards incompatible changes in DBR version 8.0. +class DBRv8d0Linter(Linter): + """Performs Python linting for backwards incompatible changes in DBR version 8.0. Specifically, it yields advice for table-creation with implicit format. """ + # https://docs.databricks.com/en/archive/runtime-release-notes/8.0.html#delta-is-now-the-default-format-when-a-format-is-not-specified def __init__(self, dbr_version: tuple[int, int] | None): version_cutoff = (8, 0) self._skip_dbr = dbr_version is not None and dbr_version >= version_cutoff - self._linter = NoFormatPythonLinter([ - NoFormatPythonMatcher("writeTo", 1, 1), - NoFormatPythonMatcher("table", 1, 1), - NoFormatPythonMatcher("insertInto", 1, 2), - NoFormatPythonMatcher("saveAsTable", 1, 4, 2, "format"), - ]) + self._linter = NoFormatPythonLinter( + [ + NoFormatPythonMatcher("writeTo", 1, 1), + NoFormatPythonMatcher("table", 1, 1), + NoFormatPythonMatcher("insertInto", 1, 2), + NoFormatPythonMatcher("saveAsTable", 1, 4, 2, "format"), + ] + ) def lint(self, code: str) -> Iterable[Advice]: if self._skip_dbr: diff --git a/tests/unit/source_code/test_python_ast_util.py b/tests/unit/source_code/test_python_ast_util.py index 19540eb478..d60108b901 100644 --- a/tests/unit/source_code/test_python_ast_util.py +++ b/tests/unit/source_code/test_python_ast_util.py @@ -1,7 +1,7 @@ from __future__ import annotations -import pytest import ast +import pytest from databricks.labs.ucx.source_code.python_ast_util import AstUtil @@ -17,7 +17,7 @@ def test_extract_callchain(migration_index, stmt): act = AstUtil.extract_callchain(node) assert isinstance(act, ast.Call) assert isinstance(act.func, ast.Attribute) - assert "m3" == act.func.attr + assert act.func.attr == "m3" @pytest.mark.parametrize("stmt", ["a = 3", "[x+1 for x in xs]"]) @@ -32,7 +32,7 @@ def test_extract_call_by_name(migration_index): act = AstUtil.extract_call_by_name(callchain, "m2") assert isinstance(act, ast.Call) assert isinstance(act.func, ast.Attribute) - assert "m2" == act.func.attr + assert act.func.attr == "m2" def test_extract_call_by_name_none(migration_index): @@ -41,26 +41,59 @@ def test_extract_call_by_name_none(migration_index): assert act is None -@pytest.mark.parametrize("param", [ - {"stmt": "o.m1()", "expected": 0}, - {"stmt": "o.m1(3)", "expected": 1}, - {"stmt": "o.m1(first=3)", "expected": 1}, - {"stmt": "o.m1(3, 3)", "expected": 2}, - {"stmt": "o.m1(first=3, second=3)", "expected": 
2}, - {"stmt": "o.m1(3, second=3)", "expected": 2}, -]) +@pytest.mark.parametrize( + "param", + [ + {"stmt": "o.m1()", "arg_index": 1, "arg_name": "second", "expected": None}, + {"stmt": "o.m1(3)", "arg_index": 1, "arg_name": "second", "expected": None}, + {"stmt": "o.m1(first=3)", "arg_index": 1, "arg_name": "second", "expected": None}, + {"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": None, "expected": None}, + {"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": "second", "expected": None}, + {"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": None, "expected": 3}, + {"stmt": "o.m1(first=4, second=3)", "arg_index": 1, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(second=3, first=4)", "arg_index": 1, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(second=3, first=4)", "arg_index": None, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(second=3)", "arg_index": 1, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(4, 3, 2)", "arg_index": 1, "arg_name": "second", "expected": 3}, + ], +) +def test_get_arg(migration_index, param): + call = get_statement_node(param["stmt"]).value + act = AstUtil.get_arg(call, param["arg_index"], param["arg_name"]) + if param["expected"] is None: + assert param["expected"] is None + else: + assert isinstance(act, ast.Constant) + assert act.value == param["expected"] + + +@pytest.mark.parametrize( + "param", + [ + {"stmt": "o.m1()", "expected": 0}, + {"stmt": "o.m1(3)", "expected": 1}, + {"stmt": "o.m1(first=3)", "expected": 1}, + {"stmt": "o.m1(3, 3)", "expected": 2}, + {"stmt": "o.m1(first=3, second=3)", "expected": 2}, + {"stmt": "o.m1(3, second=3)", "expected": 2}, + ], +) def test_args_count(migration_index, param): call = get_statement_node(param["stmt"]).value act = AstUtil.args_count(call) assert param["expected"] == act -@pytest.mark.parametrize("param", [ - {"stmt": "a = x", "expected": False}, - {"stmt": "a = 3", "expected": False}, - {"stmt": "a = 'None'", "expected": False}, - {"stmt": "a = None", "expected": True}, -]) +@pytest.mark.parametrize( + "param", + [ + {"stmt": "a = x", "expected": False}, + {"stmt": "a = 3", "expected": False}, + {"stmt": "a = 'None'", "expected": False}, + {"stmt": "a = None", "expected": True}, + ], +) def test_is_none(migration_index, param): val = get_statement_node(param["stmt"]).value act = AstUtil.is_none(val) diff --git a/tests/unit/source_code/test_table_creation.py b/tests/unit/source_code/test_table_creation.py index f673e84209..6214c31db3 100644 --- a/tests/unit/source_code/test_table_creation.py +++ b/tests/unit/source_code/test_table_creation.py @@ -3,7 +3,7 @@ import pytest from databricks.labs.ucx.source_code.base import Advice -from databricks.labs.ucx.source_code.table_creation import DBRv8_0Linter +from databricks.labs.ucx.source_code.table_creation import DBRv8d0Linter METHOD_NAMES = [ @@ -14,15 +14,27 @@ ] ASSIGN = [True, False] DBR_VERSIONS = [ # version as tuple of ints: (major, minor) - { "version": None, "suppress": False, }, - { "version": (7, 9), "suppress": False, }, - { "version": (8, 0), "suppress": True, }, - { "version": (9, 0), "suppress": True, }, + { + "version": None, + "suppress": False, + }, + { + "version": (7, 9), + "suppress": False, + }, + { + "version": (8, 0), + "suppress": True, + }, + { + "version": (9, 0), + "suppress": True, + }, ] def get_code(assign: bool, stmt: str) -> str: - """ Return code snippet to be linted, with customizable statement """ + """Return code 
snippet to be linted, with customizable statement""" assign_str = 'df = ' if assign else '' return f""" spark.read.csv("s3://bucket/path") @@ -33,7 +45,7 @@ def get_code(assign: bool, stmt: str) -> str: def get_advice(assign: bool, method_name: str, args_len: int) -> Advice: - """ Repeated boilerplate Advice constructing """ + """Repeated boilerplate Advice constructing""" return Advice( code="table-migrate", message="The default format changed in Databricks Runtime 8.0, from Parquet to Delta", @@ -45,53 +57,68 @@ def get_advice(assign: bool, method_name: str, args_len: int) -> Advice: def lint( - code: str, - dbr_version: tuple[int, int] | None = (7, 9), + code: str, + dbr_version: tuple[int, int] | None = (7, 9), ) -> list[Advice]: - """ Invoke linting for the given dbr version """ - return list(DBRv8_0Linter(dbr_version).lint(code)) + """Invoke linting for the given dbr version""" + return list(DBRv8d0Linter(dbr_version).lint(code)) @pytest.mark.parametrize("method_name", METHOD_NAMES) @pytest.mark.parametrize("assign", ASSIGN) def test_has_format_call(migration_index, method_name, assign): - """ Tests that calling "format" doesn't yield advice """ + """Tests that calling "format" doesn't yield advice""" old_code = get_code(assign, f'spark.foo().format("delta").bar().{method_name}("catalog.db.table").baz()') - assert [] == lint(old_code) + assert not lint(old_code) @pytest.mark.parametrize("method_name", METHOD_NAMES) @pytest.mark.parametrize("assign", ASSIGN) def test_no_format(migration_index, method_name, assign): - """ Tests that not setting "format" yields advice (both in assignment or standalone callchain) """ + """Tests that not setting "format" yields advice (both in assignment or standalone callchain)""" old_code = get_code(assign, f'spark.foo().bar().{method_name}("catalog.db.table").baz()') assert [get_advice(assign, method_name, 18)] == lint(old_code) +@pytest.mark.parametrize( + "params", + [ + {"stmt": 'spark.foo().bar().table().baz()', "expected": False}, + {"stmt": 'spark.foo().bar().table("catalog.db.table").baz()', "expected": True}, + {"stmt": 'spark.foo().bar().table("catalog.db.table", "xyz").baz()', "expected": False}, + {"stmt": 'spark.foo().bar().table("catalog.db.table", fmt="xyz").baz()', "expected": False}, + ], +) +def test_no_format_args_count(migration_index, params): + """Tests that the number of arguments to table creation call is considered in matching""" + old_code = get_code(False, params["stmt"]) + assert (not params["expected"]) == (not lint(old_code)) + + @pytest.mark.parametrize("assign", ASSIGN) def test_has_format_arg(migration_index, assign): - """ Tests that setting "format" positional arg doesn't yield advice """ - old_code = get_code(assign, f'spark.foo().format("delta").bar().saveAsTable("catalog.db.table", "csv").baz()') - assert [] == lint(old_code) + """Tests that setting "format" positional arg doesn't yield advice""" + old_code = get_code(assign, 'spark.foo().format("delta").bar().saveAsTable("catalog.db.table", "csv").baz()') + assert not lint(old_code) @pytest.mark.parametrize("assign", ASSIGN) def test_has_format_kwarg(migration_index, assign): - """ Tests that setting "format" kwarg doesn't yield advice """ - old_code = get_code(assign, f'spark.foo().format("delta").bar().saveAsTable("catalog.db.table", format="csv").baz()') - assert [] == lint(old_code) + """Tests that setting "format" kwarg doesn't yield advice""" + old_code = get_code(assign, 'spark.foo().format("delta").bar().saveAsTable("catalog.db.table", format="csv").baz()') 
+ assert not lint(old_code) @pytest.mark.parametrize("assign", ASSIGN) def test_has_format_arg_none(migration_index, assign): - """ Tests that explicitly setting "format" parameter to None yields advice """ - old_code = get_code(assign, f'spark.foo().bar().saveAsTable("catalog.db.table", format=None).baz()') + """Tests that explicitly setting "format" parameter to None yields advice""" + old_code = get_code(assign, 'spark.foo().bar().saveAsTable("catalog.db.table", format=None).baz()') assert [get_advice(assign, "saveAsTable", 31)] == lint(old_code) @pytest.mark.parametrize("dbr_version", DBR_VERSIONS) -def test_has_format_arg_none(migration_index, dbr_version): - """ Tests the DBR version cutoff filter """ - old_code = get_code(False, f'spark.foo().bar().table("catalog.db.table").baz()') +def test_dbr_version_filter(migration_index, dbr_version): + """Tests the DBR version cutoff filter""" + old_code = get_code(False, 'spark.foo().bar().table("catalog.db.table").baz()') expected = [] if dbr_version["suppress"] else [get_advice(False, 'table', 18)] assert expected == lint(old_code, dbr_version["version"]) From 35306932f57a1bab656463420069277d30af8df4 Mon Sep 17 00:00:00 2001 From: Constantin Dumitrascu Date: Thu, 18 Apr 2024 09:02:08 -0700 Subject: [PATCH 3/4] Feedback suggestions. --- .../labs/ucx/source_code/table_creation.py | 18 ++++++++++++------ tests/unit/source_code/test_python_ast_util.py | 3 ++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/table_creation.py b/src/databricks/labs/ucx/source_code/table_creation.py index 80b66ea2a7..4e050bff0a 100644 --- a/src/databricks/labs/ucx/source_code/table_creation.py +++ b/src/databricks/labs/ucx/source_code/table_creation.py @@ -28,12 +28,12 @@ class NoFormatPythonMatcher: format_arg_name: str | None = None def get_advice_span(self, node: ast.AST) -> Span | None: - # retrieve full callchain: + # Check 1: retrieve full callchain: callchain = AstUtil.extract_callchain(node) if callchain is None: return None - # check presence of the table-creating method call: + # Check 2: check presence of the table-creating method call: call = AstUtil.extract_call_by_name(callchain, self.method_name) if call is None: return None @@ -41,15 +41,21 @@ def get_advice_span(self, node: ast.AST) -> Span | None: if call_args_count < self.min_args or call_args_count > self.max_args: return None - # check presence of the format specifier: + # Check 3: check presence of the format specifier: + # Option A: format specifier may be given as a direct parameter to the table-creating call + # example: df.saveToTable("c.db.table", format="csv") format_arg = AstUtil.get_arg(call, self.format_arg_index, self.format_arg_name) if format_arg is not None and not AstUtil.is_none(format_arg): + # i.e., found an explicit "format" argument, and its value is not None. return None + # Option B. format specifier may be a separate ".format(...)" call in this callchain + # example: df.format("csv").saveToTable("c.db.table") format_call = AstUtil.extract_call_by_name(callchain, "format") if format_call is not None: + # i.e., found an explicit ".format(...)" call in this chain. 
return None - # matched need for issuing advice: + # Finally: matched the need for advice, so return the corresponding source span: return Span( call.lineno, call.col_offset, @@ -58,11 +64,11 @@ def get_advice_span(self, node: ast.AST) -> Span | None: ) -@dataclass class NoFormatPythonLinter: """Python linting for table-creation with implicit format""" - _matchers: list[NoFormatPythonMatcher] + def __init__(self, matchers: list[NoFormatPythonMatcher]): + self._matchers = matchers def lint(self, node: ast.AST) -> Iterator[Advice]: for matcher in self._matchers: diff --git a/tests/unit/source_code/test_python_ast_util.py b/tests/unit/source_code/test_python_ast_util.py index d60108b901..eae838e254 100644 --- a/tests/unit/source_code/test_python_ast_util.py +++ b/tests/unit/source_code/test_python_ast_util.py @@ -62,7 +62,7 @@ def test_get_arg(migration_index, param): call = get_statement_node(param["stmt"]).value act = AstUtil.get_arg(call, param["arg_index"], param["arg_name"]) if param["expected"] is None: - assert param["expected"] is None + assert act is None else: assert isinstance(act, ast.Constant) assert act.value == param["expected"] @@ -77,6 +77,7 @@ def test_get_arg(migration_index, param): {"stmt": "o.m1(3, 3)", "expected": 2}, {"stmt": "o.m1(first=3, second=3)", "expected": 2}, {"stmt": "o.m1(3, second=3)", "expected": 2}, + {"stmt": "o.m1(3, *b, **c, second=3)", "expected": 4}, ], ) def test_args_count(migration_index, param): From 543b35e18f44adc0ad770533b7a0956c52be26cd Mon Sep 17 00:00:00 2001 From: Constantin Dumitrascu Date: Thu, 18 Apr 2024 15:21:03 -0700 Subject: [PATCH 4/4] PR suggestions --- .../labs/ucx/source_code/python_ast_util.py | 68 ------------ .../labs/ucx/source_code/python_linter.py | 52 +++++++++ .../labs/ucx/source_code/table_creation.py | 50 +++++---- .../unit/source_code/test_python_ast_util.py | 101 ------------------ tests/unit/source_code/test_python_linter.py | 100 +++++++++++++++++ 5 files changed, 181 insertions(+), 190 deletions(-) delete mode 100644 src/databricks/labs/ucx/source_code/python_ast_util.py delete mode 100644 tests/unit/source_code/test_python_ast_util.py diff --git a/src/databricks/labs/ucx/source_code/python_ast_util.py b/src/databricks/labs/ucx/source_code/python_ast_util.py deleted file mode 100644 index af67b73d09..0000000000 --- a/src/databricks/labs/ucx/source_code/python_ast_util.py +++ /dev/null @@ -1,68 +0,0 @@ -from __future__ import annotations - -import ast -from dataclasses import dataclass - - -@dataclass -class Span: - """Represents a (possibly multiline) source code span.""" - - start_line: int - start_col: int - end_line: int - end_col: int - - -class AstUtil: - @staticmethod - def extract_callchain(node: ast.AST) -> ast.Call | None: - """If 'node' is an assignment or expression, extract its full call-chain (if it has one)""" - call = None - if isinstance(node, ast.Assign): - call = node.value - elif isinstance(node, ast.Expr): - call = node.value - if not isinstance(call, ast.Call): - call = None - return call - - @staticmethod - def extract_call_by_name(node: ast.Call, name: str) -> ast.Call | None: - """Given a call-chain, extract its sub-call by method name (if it has one)""" - while True: - func = node.func - if not isinstance(func, ast.Attribute): - return None - if func.attr == name: - return node - if not isinstance(func.value, ast.Call): - return None - node = func.value - - @staticmethod - def args_count(node: ast.Call) -> int: - """Count the number of arguments (positionals + keywords)""" - return 
len(node.args) + len(node.keywords) - - @staticmethod - def get_arg( - node: ast.Call, - arg_index: int | None, - arg_name: str | None, - ) -> ast.expr | None: - """Extract the call argument identified by an optional position or name (if it has one)""" - if arg_index is not None and len(node.args) > arg_index: - return node.args[arg_index] - if arg_name is not None: - arg = [kw.value for kw in node.keywords if kw.arg == arg_name] - if len(arg) == 1: - return arg[0] - return None - - @staticmethod - def is_none(node: ast.expr) -> bool: - """Check if the given AST expression is the None constant""" - if not isinstance(node, ast.Constant): - return False - return node.value is None diff --git a/src/databricks/labs/ucx/source_code/python_linter.py b/src/databricks/labs/ucx/source_code/python_linter.py index 04061eb6dc..2e4001e6c4 100644 --- a/src/databricks/labs/ucx/source_code/python_linter.py +++ b/src/databricks/labs/ucx/source_code/python_linter.py @@ -177,6 +177,58 @@ def collect_appended_sys_paths(self): visitor.visit(self._root) return visitor.appended_paths + def extract_callchain(self) -> ast.Call | None: + """If 'node' is an assignment or expression, extract its full call-chain (if it has one)""" + call = None + if isinstance(self._root, ast.Assign): + call = self._root.value + elif isinstance(self._root, ast.Expr): + call = self._root.value + if not isinstance(call, ast.Call): + call = None + return call + + def extract_call_by_name(self, name: str) -> ast.Call | None: + """Given a call-chain, extract its sub-call by method name (if it has one)""" + assert isinstance(self._root, ast.Call) + node = self._root + while True: + func = node.func + if not isinstance(func, ast.Attribute): + return None + if func.attr == name: + return node + if not isinstance(func.value, ast.Call): + return None + node = func.value + + def args_count(self) -> int: + """Count the number of arguments (positionals + keywords)""" + assert isinstance(self._root, ast.Call) + return len(self._root.args) + len(self._root.keywords) + + def get_arg( + self, + arg_index: int | None, + arg_name: str | None, + ) -> ast.expr | None: + """Extract the call argument identified by an optional position or name (if it has one)""" + assert isinstance(self._root, ast.Call) + if arg_index is not None and len(self._root.args) > arg_index: + return self._root.args[arg_index] + if arg_name is not None: + arg = [kw.value for kw in self._root.keywords if kw.arg == arg_name] + if len(arg) == 1: + return arg[0] + return None + + def is_none(self) -> bool: + """Check if the given AST expression is the None constant""" + assert isinstance(self._root, ast.expr) + if not isinstance(self._root, ast.Constant): + return False + return self._root.value is None + class PythonLinter(Linter): diff --git a/src/databricks/labs/ucx/source_code/table_creation.py b/src/databricks/labs/ucx/source_code/table_creation.py index 4e050bff0a..41a57d6dfe 100644 --- a/src/databricks/labs/ucx/source_code/table_creation.py +++ b/src/databricks/labs/ucx/source_code/table_creation.py @@ -4,16 +4,26 @@ from collections.abc import Iterable, Iterator from dataclasses import dataclass -from databricks.labs.ucx.source_code.python_ast_util import ( - AstUtil, - Span, -) + +from databricks.labs.ucx.source_code.python_linter import ASTLinter from databricks.labs.ucx.source_code.base import ( Advice, Linter, ) +@dataclass +class Position: + line: int + character: int + + +@dataclass +class Range: + start: Position + end: Position + + @dataclass class NoFormatPythonMatcher: 
"""Matches Python AST nodes where tables are created with implicit format. @@ -27,40 +37,38 @@ class NoFormatPythonMatcher: format_arg_index: int | None = None format_arg_name: str | None = None - def get_advice_span(self, node: ast.AST) -> Span | None: + def get_advice_span(self, node: ast.AST) -> Range | None: # Check 1: retrieve full callchain: - callchain = AstUtil.extract_callchain(node) + callchain = ASTLinter(node).extract_callchain() if callchain is None: return None # Check 2: check presence of the table-creating method call: - call = AstUtil.extract_call_by_name(callchain, self.method_name) + call = ASTLinter(callchain).extract_call_by_name(self.method_name) if call is None: return None - call_args_count = AstUtil.args_count(call) + call_args_count = ASTLinter(call).args_count() if call_args_count < self.min_args or call_args_count > self.max_args: return None # Check 3: check presence of the format specifier: # Option A: format specifier may be given as a direct parameter to the table-creating call # example: df.saveToTable("c.db.table", format="csv") - format_arg = AstUtil.get_arg(call, self.format_arg_index, self.format_arg_name) - if format_arg is not None and not AstUtil.is_none(format_arg): + format_arg = ASTLinter(call).get_arg(self.format_arg_index, self.format_arg_name) + if format_arg is not None and not ASTLinter(format_arg).is_none(): # i.e., found an explicit "format" argument, and its value is not None. return None # Option B. format specifier may be a separate ".format(...)" call in this callchain # example: df.format("csv").saveToTable("c.db.table") - format_call = AstUtil.extract_call_by_name(callchain, "format") + format_call = ASTLinter(callchain).extract_call_by_name("format") if format_call is not None: # i.e., found an explicit ".format(...)" call in this chain. 
return None - # Finally: matched the need for advice, so return the corresponding source span: - return Span( - call.lineno, - call.col_offset, - call.end_lineno or 0, - call.end_col_offset or 0, + # Finally: matched the need for advice, so return the corresponding source range: + return Range( + Position(call.lineno, call.col_offset), + Position(call.end_lineno or 0, call.end_col_offset or 0), ) @@ -77,10 +85,10 @@ def lint(self, node: ast.AST) -> Iterator[Advice]: yield Advice( code="table-migrate", message="The default format changed in Databricks Runtime 8.0, from Parquet to Delta", - start_line=span.start_line, - start_col=span.start_col, - end_line=span.end_line, - end_col=span.end_col, + start_line=span.start.line, + start_col=span.start.character, + end_line=span.end.line, + end_col=span.end.character, ) diff --git a/tests/unit/source_code/test_python_ast_util.py b/tests/unit/source_code/test_python_ast_util.py deleted file mode 100644 index eae838e254..0000000000 --- a/tests/unit/source_code/test_python_ast_util.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -import ast -import pytest - -from databricks.labs.ucx.source_code.python_ast_util import AstUtil - - -def get_statement_node(stmt: str) -> ast.stmt: - node = ast.parse(stmt) - return node.body[0] - - -@pytest.mark.parametrize("stmt", ["o.m1().m2().m3()", "a = o.m1().m2().m3()"]) -def test_extract_callchain(migration_index, stmt): - node = get_statement_node(stmt) - act = AstUtil.extract_callchain(node) - assert isinstance(act, ast.Call) - assert isinstance(act.func, ast.Attribute) - assert act.func.attr == "m3" - - -@pytest.mark.parametrize("stmt", ["a = 3", "[x+1 for x in xs]"]) -def test_extract_callchain_none(migration_index, stmt): - node = get_statement_node(stmt) - act = AstUtil.extract_callchain(node) - assert act is None - - -def test_extract_call_by_name(migration_index): - callchain = get_statement_node("o.m1().m2().m3()").value - act = AstUtil.extract_call_by_name(callchain, "m2") - assert isinstance(act, ast.Call) - assert isinstance(act.func, ast.Attribute) - assert act.func.attr == "m2" - - -def test_extract_call_by_name_none(migration_index): - callchain = get_statement_node("o.m1().m2().m3()").value - act = AstUtil.extract_call_by_name(callchain, "m5000") - assert act is None - - -@pytest.mark.parametrize( - "param", - [ - {"stmt": "o.m1()", "arg_index": 1, "arg_name": "second", "expected": None}, - {"stmt": "o.m1(3)", "arg_index": 1, "arg_name": "second", "expected": None}, - {"stmt": "o.m1(first=3)", "arg_index": 1, "arg_name": "second", "expected": None}, - {"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": None, "expected": None}, - {"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": "second", "expected": None}, - {"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": None, "expected": 3}, - {"stmt": "o.m1(first=4, second=3)", "arg_index": 1, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(second=3, first=4)", "arg_index": 1, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(second=3, first=4)", "arg_index": None, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(second=3)", "arg_index": 1, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(4, 3, 2)", "arg_index": 1, "arg_name": "second", "expected": 3}, - ], -) -def test_get_arg(migration_index, param): - call = get_statement_node(param["stmt"]).value - act = AstUtil.get_arg(call, param["arg_index"], param["arg_name"]) - if 
param["expected"] is None: - assert act is None - else: - assert isinstance(act, ast.Constant) - assert act.value == param["expected"] - - -@pytest.mark.parametrize( - "param", - [ - {"stmt": "o.m1()", "expected": 0}, - {"stmt": "o.m1(3)", "expected": 1}, - {"stmt": "o.m1(first=3)", "expected": 1}, - {"stmt": "o.m1(3, 3)", "expected": 2}, - {"stmt": "o.m1(first=3, second=3)", "expected": 2}, - {"stmt": "o.m1(3, second=3)", "expected": 2}, - {"stmt": "o.m1(3, *b, **c, second=3)", "expected": 4}, - ], -) -def test_args_count(migration_index, param): - call = get_statement_node(param["stmt"]).value - act = AstUtil.args_count(call) - assert param["expected"] == act - - -@pytest.mark.parametrize( - "param", - [ - {"stmt": "a = x", "expected": False}, - {"stmt": "a = 3", "expected": False}, - {"stmt": "a = 'None'", "expected": False}, - {"stmt": "a = None", "expected": True}, - ], -) -def test_is_none(migration_index, param): - val = get_statement_node(param["stmt"]).value - act = AstUtil.is_none(val) - assert param["expected"] == act diff --git a/tests/unit/source_code/test_python_linter.py b/tests/unit/source_code/test_python_linter.py index 8c987517ea..ec527f1e13 100644 --- a/tests/unit/source_code/test_python_linter.py +++ b/tests/unit/source_code/test_python_linter.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +import ast +import pytest + from databricks.labs.ucx.source_code.python_linter import ASTLinter, PythonLinter @@ -127,3 +132,98 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_alias(): linter = ASTLinter.parse(code) appended = PythonLinter.list_appended_sys_paths(linter) assert "relative_path" in [p.path for p in appended] + + +def get_statement_node(stmt: str) -> ast.stmt: + node = ast.parse(stmt) + return node.body[0] + + +@pytest.mark.parametrize("stmt", ["o.m1().m2().m3()", "a = o.m1().m2().m3()"]) +def test_extract_callchain(migration_index, stmt): + node = get_statement_node(stmt) + act = ASTLinter(node).extract_callchain() + assert isinstance(act, ast.Call) + assert isinstance(act.func, ast.Attribute) + assert act.func.attr == "m3" + + +@pytest.mark.parametrize("stmt", ["a = 3", "[x+1 for x in xs]"]) +def test_extract_callchain_none(migration_index, stmt): + node = get_statement_node(stmt) + act = ASTLinter(node).extract_callchain() + assert act is None + + +def test_extract_call_by_name(migration_index): + callchain = get_statement_node("o.m1().m2().m3()").value + act = ASTLinter(callchain).extract_call_by_name("m2") + assert isinstance(act, ast.Call) + assert isinstance(act.func, ast.Attribute) + assert act.func.attr == "m2" + + +def test_extract_call_by_name_none(migration_index): + callchain = get_statement_node("o.m1().m2().m3()").value + act = ASTLinter(callchain).extract_call_by_name("m5000") + assert act is None + + +@pytest.mark.parametrize( + "param", + [ + {"stmt": "o.m1()", "arg_index": 1, "arg_name": "second", "expected": None}, + {"stmt": "o.m1(3)", "arg_index": 1, "arg_name": "second", "expected": None}, + {"stmt": "o.m1(first=3)", "arg_index": 1, "arg_name": "second", "expected": None}, + {"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": None, "expected": None}, + {"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": "second", "expected": None}, + {"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": None, "expected": 3}, + {"stmt": "o.m1(first=4, second=3)", "arg_index": 1, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(second=3, 
first=4)", "arg_index": 1, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(second=3, first=4)", "arg_index": None, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(second=3)", "arg_index": 1, "arg_name": "second", "expected": 3}, + {"stmt": "o.m1(4, 3, 2)", "arg_index": 1, "arg_name": "second", "expected": 3}, + ], +) +def test_get_arg(migration_index, param): + call = get_statement_node(param["stmt"]).value + act = ASTLinter(call).get_arg(param["arg_index"], param["arg_name"]) + if param["expected"] is None: + assert act is None + else: + assert isinstance(act, ast.Constant) + assert act.value == param["expected"] + + +@pytest.mark.parametrize( + "param", + [ + {"stmt": "o.m1()", "expected": 0}, + {"stmt": "o.m1(3)", "expected": 1}, + {"stmt": "o.m1(first=3)", "expected": 1}, + {"stmt": "o.m1(3, 3)", "expected": 2}, + {"stmt": "o.m1(first=3, second=3)", "expected": 2}, + {"stmt": "o.m1(3, second=3)", "expected": 2}, + {"stmt": "o.m1(3, *b, **c, second=3)", "expected": 4}, + ], +) +def test_args_count(migration_index, param): + call = get_statement_node(param["stmt"]).value + act = ASTLinter(call).args_count() + assert param["expected"] == act + + +@pytest.mark.parametrize( + "param", + [ + {"stmt": "a = x", "expected": False}, + {"stmt": "a = 3", "expected": False}, + {"stmt": "a = 'None'", "expected": False}, + {"stmt": "a = None", "expected": True}, + ], +) +def test_is_none(migration_index, param): + val = get_statement_node(param["stmt"]).value + act = ASTLinter(val).is_none() + assert param["expected"] == act