diff --git a/src/python/pants/core/goals/lint.py b/src/python/pants/core/goals/lint.py index 008c7f77647..f1ddb347fee 100644 --- a/src/python/pants/core/goals/lint.py +++ b/src/python/pants/core/goals/lint.py @@ -405,12 +405,6 @@ def _get_error_code(results: Sequence[LintResult]) -> int: return 0 -# TODO(16868): Rule parser requires arguments to be values in module scope -_LintTargetsPartitionRequest = LintTargetsRequest.PartitionRequest -_LintFilesPartitionRequest = LintFilesRequest.PartitionRequest -_LintSubPartition = LintRequest.SubPartition - - @goal_rule async def lint( console: Console, @@ -512,7 +506,7 @@ def partition_request_get( lint_targets_request_type = cast("type[LintTargetsRequest]", request_type) return Get( Partitions, - _LintTargetsPartitionRequest, + LintTargetsRequest.PartitionRequest, lint_targets_request_type.PartitionRequest( tuple( lint_targets_request_type.field_set_type.create(target) @@ -525,7 +519,7 @@ def partition_request_get( assert partition_request_type in file_partitioners return Get( Partitions, - _LintFilesPartitionRequest, + LintFilesRequest.PartitionRequest, cast("type[LintFilesRequest]", request_type).PartitionRequest(specs_paths.files), ) @@ -589,7 +583,7 @@ def partition_request_get( ) all_requests = [ - *(Get(LintResult, _LintSubPartition, request) for request in lint_batches), + *(Get(LintResult, LintRequest.SubPartition, request) for request in lint_batches), *(Get(LintResult, FmtTargetsRequest, request) for request in fmt_target_requests), *(Get(LintResult, _FmtBuildFilesRequest, request) for request in fmt_build_requests), ] diff --git a/src/python/pants/engine/internals/rule_visitor.py b/src/python/pants/engine/internals/rule_visitor.py index 2149653ebdf..1498a37e678 100644 --- a/src/python/pants/engine/internals/rule_visitor.py +++ b/src/python/pants/engine/internals/rule_visitor.py @@ -7,9 +7,8 @@ import itertools import logging import sys -import types from functools import partial -from typing import Callable, List, Sequence, cast +from typing import Any, Callable, List from pants.engine.internals.selectors import AwaitableConstraints, GetParseError from pants.util.memo import memoized @@ -29,18 +28,6 @@ def _get_starting_indent(source: str) -> int: return 0 -def _get_lookup_names(attr: ast.expr) -> list[str]: - names = [] - while isinstance(attr, ast.Attribute): - names.append(attr.attr) - attr = attr.value - # NB: attr could be a constant, like `",".join()` - id = getattr(attr, "id", None) - if id is not None: - names.append(id) - return names - - class _AwaitableCollector(ast.NodeVisitor): def __init__(self, func: Callable): self.func = func @@ -55,41 +42,47 @@ def __init__(self, func: Callable): self.awaitables: List[AwaitableConstraints] = [] self.visit(ast.parse(source)) - def _expect_dict_value_is_name( - self, call_node_name: ast.Name, call_node_args: Sequence[ast.expr], value_expr: ast.expr - ) -> ast.Name: - lineno = value_expr.lineno + self.func.__code__.co_firstlineno - 1 - if not isinstance(value_expr, ast.Name): - raise GetParseError( - f"All values of the input dict should be literal type names, but got " - f"{value_expr} (type `{type(value_expr).__name__}`)`) " - f"in {self.source_file}:{lineno}", - get_args=call_node_args, - source_file_name=(self.source_file or ""), - ) - return value_expr - - def _resolve_constraint_arg_type(self, name: str, lineno: int) -> type: - lineno += self.func.__code__.co_firstlineno - 1 - resolved = ( - getattr(self.owning_module, name, None) - or self.owning_module.__builtins__.get(name, None) - ) # fmt: skip + def _lookup(self, attr: ast.expr) -> Any: + names = [] + while isinstance(attr, ast.Attribute): + names.append(attr.attr) + attr = attr.value + # NB: attr could be a constant, like `",".join()` + id = getattr(attr, "id", None) + if id is not None: + names.append(id) + + if not names: + return attr + + name = names.pop() + result = ( + getattr(self.owning_module, name) + if hasattr(self.owning_module, name) + else self.owning_module.__builtins__.get(name, None) + ) + while result is not None and names: + result = getattr(result, names.pop(), None) + + return result + + def _check_constraint_arg_type(self, resolved: Any, node: ast.AST) -> type: + lineno = node.lineno + self.func.__code__.co_firstlineno - 1 if resolved is None: raise ValueError( - f"Could not resolve type `{name}` in top level of module " + f"Could not resolve type `{node}` in top level of module " f"{self.owning_module.__name__} defined in {self.source_file}:{lineno}" ) elif not isinstance(resolved, type): raise ValueError( - f"Expected a `type` constructor for `{name}`, but got: {resolved} (type " - f"`{type(resolved).__name__}`) in {self.source_file}:{lineno}" + f"Expected a `type`, but got: {resolved}" + + f" (type `{type(resolved).__name__}`) in {self.source_file}:{lineno}" ) return resolved def _get_awaitable(self, call_node: ast.Call) -> AwaitableConstraints: - assert isinstance(call_node.func, ast.Name) - is_effect = call_node.func.id == "Effect" + func = self._lookup(call_node.func) + is_effect = func.__name__ == "Effect" get_args = call_node.args parse_error = partial(GetParseError, get_args=get_args, source_file_name=self.source_file) @@ -98,53 +91,29 @@ def _get_awaitable(self, call_node: ast.Call) -> AwaitableConstraints: f"Expected either two or three arguments, but got {len(get_args)} arguments." ) - output_expr = get_args[0] - if not isinstance(output_expr, ast.Name): - raise parse_error( - "The first argument should be the output type, like `Digest` or `ProcessResult`." - ) - output_type = output_expr + output_node = get_args[0] + output_type = self._lookup(output_node) - input_args = get_args[1:] - input_types: List[ast.Name] - if len(input_args) == 1: - input_constructor = input_args[0] + input_nodes = get_args[1:] + input_types: List[Any] + if len(input_nodes) == 1: + input_constructor = input_nodes[0] if isinstance(input_constructor, ast.Call): - if not isinstance(input_constructor.func, ast.Name): - raise parse_error( - f"Because you are using the shorthand form {call_node.func.id}(OutputType, " - "InputType(constructor args)), the second argument should be a top-level " - "constructor function call, like `MergeDigest(...)` or `Process(...)`, rather " - "than a method call." - ) - input_types = [input_constructor.func] + input_nodes = [input_constructor.func] + input_types = [self._lookup(input_constructor.func)] elif isinstance(input_constructor, ast.Dict): - input_types = [ - self._expect_dict_value_is_name(call_node.func, get_args, v) - for v in input_constructor.values - ] + input_nodes = input_constructor.values + input_types = [self._lookup(v) for v in input_constructor.values] else: - raise parse_error( - f"Because you are using the two-argument form {call_node.func.id}(OutputType, " - "$input), the $input argument should either be a " - "constructor call, like `MergeDigest(...)` or `Process(...)`, or a dict " - "literal mapping inputs to their declared types, like " - "`{merge_digest: MergeDigest}`." - ) + input_types = [self._lookup(n) for n in input_nodes] else: - if not isinstance(input_args[0], ast.Name): - raise parse_error( - f"Because you are using the longhand form {call_node.func.id}(OutputType, " - "InputType, input), the second argument should be a type, like `MergeDigests` or " - "`Process`." - ) - input_types = [input_args[0]] + input_types = [self._lookup(input_nodes[0])] return AwaitableConstraints( - self._resolve_constraint_arg_type(output_type.id, output_type.lineno), + self._check_constraint_arg_type(output_type, output_node), tuple( - self._resolve_constraint_arg_type(input_type.id, input_type.lineno) - for input_type in input_types + self._check_constraint_arg_type(input_type, input_node) + for input_type, input_node in zip(input_types, input_nodes) ), is_effect, ) @@ -153,12 +122,7 @@ def visit_Call(self, call_node: ast.Call) -> None: if _is_awaitable_constraint(call_node): self.awaitables.append(self._get_awaitable(call_node)) else: - func_node = call_node.func - lookup_names = _get_lookup_names(func_node) - attr = cast(types.FunctionType, self.func).__globals__.get(lookup_names.pop(), None) - while attr is not None and lookup_names: - attr = getattr(attr, lookup_names.pop(), None) - + attr = self._lookup(call_node.func) if hasattr(attr, "rule_helper"): self.awaitables.extend(collect_awaitables(attr)) diff --git a/src/python/pants/engine/internals/rule_visitor_test.py b/src/python/pants/engine/internals/rule_visitor_test.py index 53e3ce3b90e..b1c53a6787b 100644 --- a/src/python/pants/engine/internals/rule_visitor_test.py +++ b/src/python/pants/engine/internals/rule_visitor_test.py @@ -43,6 +43,15 @@ async def _static_helper(): container_instance = HelperContainer() +class InnerScope: + STR = str + INT = int + BOOL = bool + + HelperContainer = HelperContainer + container_instance = container_instance + + def assert_awaitables(func, awaitable_types: Iterable[tuple[type | list[type], type]]): gets = collect_awaitables(func) actual_types = tuple((list(get.input_types), get.output_type) for get in gets) @@ -90,6 +99,14 @@ async def rule(): assert_awaitables(rule, [(int, str), (str, int)]) +def test_attribute_lookup() -> None: + async def rule1(): + await Get(InnerScope.STR, InnerScope.INT, 42) + await Get(InnerScope.STR, InnerScope.INT(42)) + + assert_awaitables(rule1, [(int, str), (int, str)]) + + def test_get_no_index_call_no_subject_call_allowed() -> None: async def rule(): get_type: type = Get # noqa: F841 @@ -108,18 +125,36 @@ def test_rule_helpers_class_methods() -> None: async def rule1(): HelperContainer()._static_helper(1) - # Rule helpers must be called via module-scoped attributes - assert_awaitables(rule1, []) + async def rule1_inner(): + InnerScope.HelperContainer()._static_helper(1) async def rule2(): - container_instance._static_helper(1) + HelperContainer._static_helper(1) - assert_awaitables(rule2, [(int, str), (str, int)]) + async def rule2_inner(): + InnerScope.HelperContainer._static_helper(1) async def rule3(): + container_instance._static_helper(1) + + async def rule3_inner(): + InnerScope.container_instance._static_helper(1) + + async def rule4(): container_instance._method_helper(1) - assert_awaitables(rule3, [(int, str)]) + async def rule4_inner(): + InnerScope.container_instance._method_helper(1) + + # Rule helpers must be called via module-scoped attribute lookup + assert_awaitables(rule1, []) + assert_awaitables(rule1_inner, []) + assert_awaitables(rule2, [(int, str), (str, int)]) + assert_awaitables(rule2_inner, [(int, str), (str, int)]) + assert_awaitables(rule3, [(int, str), (str, int)]) + assert_awaitables(rule3_inner, [(int, str), (str, int)]) + assert_awaitables(rule4, [(int, str)]) + assert_awaitables(rule4_inner, [(int, str)]) def test_valid_get_unresolvable_product_type() -> None: @@ -160,7 +195,7 @@ def test_invalid_get_invalid_subject_arg_no_constructor_call() -> None: async def rule(): Get(STR, "bob") - with pytest.raises(GetParseError): + with pytest.raises(ValueError): collect_awaitables(rule) @@ -168,7 +203,7 @@ def test_invalid_get_invalid_product_type_not_a_type_name() -> None: async def rule(): Get(call(), STR("bob")) # noqa: F821 - with pytest.raises(GetParseError): + with pytest.raises(ValueError): collect_awaitables(rule) @@ -176,5 +211,5 @@ def test_invalid_get_dict_value_not_type() -> None: async def rule(): Get(int, {"str": "not a type"}) - with pytest.raises(GetParseError): + with pytest.raises(ValueError): collect_awaitables(rule)