Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for attribute lookup in rule parsing #16948

Merged
merged 1 commit into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions src/python/pants/core/goals/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,6 @@ def _get_error_code(results: Sequence[LintResult]) -> int:
return 0


# TODO(16868): Rule parser requires arguments to be values in module scope
_LintTargetsPartitionRequest = LintTargetsRequest.PartitionRequest
_LintFilesPartitionRequest = LintFilesRequest.PartitionRequest
_LintSubPartition = LintRequest.SubPartition


@goal_rule
async def lint(
console: Console,
Expand Down Expand Up @@ -512,7 +506,7 @@ def partition_request_get(
lint_targets_request_type = cast("type[LintTargetsRequest]", request_type)
return Get(
Partitions,
_LintTargetsPartitionRequest,
LintTargetsRequest.PartitionRequest,
lint_targets_request_type.PartitionRequest(
tuple(
lint_targets_request_type.field_set_type.create(target)
Expand All @@ -525,7 +519,7 @@ def partition_request_get(
assert partition_request_type in file_partitioners
return Get(
Partitions,
_LintFilesPartitionRequest,
LintFilesRequest.PartitionRequest,
cast("type[LintFilesRequest]", request_type).PartitionRequest(specs_paths.files),
)

Expand Down Expand Up @@ -589,7 +583,7 @@ def partition_request_get(
)

all_requests = [
*(Get(LintResult, _LintSubPartition, request) for request in lint_batches),
*(Get(LintResult, LintRequest.SubPartition, request) for request in lint_batches),
*(Get(LintResult, FmtTargetsRequest, request) for request in fmt_target_requests),
*(Get(LintResult, _FmtBuildFilesRequest, request) for request in fmt_build_requests),
]
Expand Down
132 changes: 48 additions & 84 deletions src/python/pants/engine/internals/rule_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import itertools
import logging
import sys
import types
from functools import partial
from typing import Callable, List, Sequence, cast
from typing import Any, Callable, List

from pants.engine.internals.selectors import AwaitableConstraints, GetParseError
from pants.util.memo import memoized
Expand All @@ -29,18 +28,6 @@ def _get_starting_indent(source: str) -> int:
return 0


def _get_lookup_names(attr: ast.expr) -> list[str]:
    """Flatten a dotted expression into its component names, outermost attribute first.

    For `a.b.c` the result is `["c", "b", "a"]`. The root's identifier is appended last,
    and only when the root node actually has one.
    """
    collected: list[str] = []
    node = attr
    while isinstance(node, ast.Attribute):
        collected.append(node.attr)
        node = node.value
    # NB: the root need not be a name; it could be a constant, like `",".join()`.
    root_name = getattr(node, "id", None)
    if root_name is None:
        return collected
    return [*collected, root_name]


class _AwaitableCollector(ast.NodeVisitor):
def __init__(self, func: Callable):
self.func = func
Expand All @@ -55,41 +42,47 @@ def __init__(self, func: Callable):
self.awaitables: List[AwaitableConstraints] = []
self.visit(ast.parse(source))

def _expect_dict_value_is_name(
    self, call_node_name: ast.Name, call_node_args: Sequence[ast.expr], value_expr: ast.expr
) -> ast.Name:
    """Return `value_expr` unchanged if it is a bare name node, otherwise raise.

    `call_node_name` is accepted but not used. `call_node_args` is only forwarded to the
    error as `get_args`.

    Raises:
        GetParseError: if `value_expr` is not an `ast.Name`.
    """
    # Translate the function-relative line number into a line number within the file.
    first_line = self.func.__code__.co_firstlineno
    lineno = value_expr.lineno + first_line - 1
    if isinstance(value_expr, ast.Name):
        return value_expr

    message = (
        f"All values of the input dict should be literal type names, but got "
        f"{value_expr} (type `{type(value_expr).__name__}`)`) "
        f"in {self.source_file}:{lineno}"
    )
    raise GetParseError(
        message,
        get_args=call_node_args,
        source_file_name=(self.source_file or "<none>"),
    )

def _resolve_constraint_arg_type(self, name: str, lineno: int) -> type:
lineno += self.func.__code__.co_firstlineno - 1
resolved = (
getattr(self.owning_module, name, None)
or self.owning_module.__builtins__.get(name, None)
) # fmt: skip
def _lookup(self, attr: ast.expr) -> Any:
    """Resolve a (possibly dotted) expression node to the runtime object it names.

    The attribute chain is unwound down to its root. A root `ast.Name` is looked up on
    `self.owning_module` first and in that module's builtins second; the remaining
    attribute names are then applied with `getattr`, innermost first.

    Returns:
        - the resolved object, or
        - `None` if any step of the lookup fails, or if an attribute chain is rooted at
          something other than a plain name (e.g. `",".join` or `Foo().bar`), or
        - `attr` itself, unchanged, when it is neither a name nor an attribute access
          (e.g. a constant or a call).
    """
    names = []
    node = attr
    while isinstance(node, ast.Attribute):
        names.append(node.attr)
        node = node.value

    if not isinstance(node, ast.Name):
        # NB: the root could be a constant or a call, like `",".join()`. Resolving the
        # trailing attribute names against the module would pick up an unrelated global
        # that merely shares the attribute's name, so report such chains as unresolved.
        return None if names else node

    module = self.owning_module
    root = node.id
    if hasattr(module, root):
        # `hasattr` first so that a module attribute bound to `None` is not replaced by
        # a builtin of the same name.
        result = getattr(module, root)
    else:
        # `__builtins__` may be either a dict or the `builtins` module itself.
        builtins_ns = getattr(module, "__builtins__", None)
        if builtins_ns is None:
            result = None
        elif isinstance(builtins_ns, dict):
            result = builtins_ns.get(root)
        else:
            result = getattr(builtins_ns, root, None)

    # `names` is ordered outermost-first; walk it innermost-first.
    for name in reversed(names):
        if result is None:
            break
        result = getattr(result, name, None)

    return result

def _check_constraint_arg_type(self, resolved: Any, node: ast.AST) -> type:
lineno = node.lineno + self.func.__code__.co_firstlineno - 1
if resolved is None:
raise ValueError(
f"Could not resolve type `{name}` in top level of module "
f"Could not resolve type `{node}` in top level of module "
f"{self.owning_module.__name__} defined in {self.source_file}:{lineno}"
)
elif not isinstance(resolved, type):
raise ValueError(
f"Expected a `type` constructor for `{name}`, but got: {resolved} (type "
f"`{type(resolved).__name__}`) in {self.source_file}:{lineno}"
f"Expected a `type`, but got: {resolved}"
+ f" (type `{type(resolved).__name__}`) in {self.source_file}:{lineno}"
)
return resolved

def _get_awaitable(self, call_node: ast.Call) -> AwaitableConstraints:
assert isinstance(call_node.func, ast.Name)
is_effect = call_node.func.id == "Effect"
func = self._lookup(call_node.func)
is_effect = func.__name__ == "Effect"
get_args = call_node.args
parse_error = partial(GetParseError, get_args=get_args, source_file_name=self.source_file)

Expand All @@ -98,53 +91,29 @@ def _get_awaitable(self, call_node: ast.Call) -> AwaitableConstraints:
f"Expected either two or three arguments, but got {len(get_args)} arguments."
)

output_expr = get_args[0]
if not isinstance(output_expr, ast.Name):
raise parse_error(
"The first argument should be the output type, like `Digest` or `ProcessResult`."
)
output_type = output_expr
output_node = get_args[0]
output_type = self._lookup(output_node)

input_args = get_args[1:]
input_types: List[ast.Name]
if len(input_args) == 1:
input_constructor = input_args[0]
input_nodes = get_args[1:]
input_types: List[Any]
if len(input_nodes) == 1:
input_constructor = input_nodes[0]
if isinstance(input_constructor, ast.Call):
if not isinstance(input_constructor.func, ast.Name):
raise parse_error(
f"Because you are using the shorthand form {call_node.func.id}(OutputType, "
"InputType(constructor args)), the second argument should be a top-level "
"constructor function call, like `MergeDigest(...)` or `Process(...)`, rather "
"than a method call."
Comment on lines -114 to -118
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These error messages become less specific, but 🤷‍♂️ they'll still report the incorrect value and the file/line.

)
input_types = [input_constructor.func]
input_nodes = [input_constructor.func]
input_types = [self._lookup(input_constructor.func)]
elif isinstance(input_constructor, ast.Dict):
input_types = [
self._expect_dict_value_is_name(call_node.func, get_args, v)
for v in input_constructor.values
]
input_nodes = input_constructor.values
input_types = [self._lookup(v) for v in input_constructor.values]
else:
raise parse_error(
f"Because you are using the two-argument form {call_node.func.id}(OutputType, "
"$input), the $input argument should either be a "
"constructor call, like `MergeDigest(...)` or `Process(...)`, or a dict "
"literal mapping inputs to their declared types, like "
"`{merge_digest: MergeDigest}`."
)
input_types = [self._lookup(n) for n in input_nodes]
else:
if not isinstance(input_args[0], ast.Name):
raise parse_error(
f"Because you are using the longhand form {call_node.func.id}(OutputType, "
"InputType, input), the second argument should be a type, like `MergeDigests` or "
"`Process`."
)
input_types = [input_args[0]]
input_types = [self._lookup(input_nodes[0])]

return AwaitableConstraints(
self._resolve_constraint_arg_type(output_type.id, output_type.lineno),
self._check_constraint_arg_type(output_type, output_node),
tuple(
self._resolve_constraint_arg_type(input_type.id, input_type.lineno)
for input_type in input_types
self._check_constraint_arg_type(input_type, input_node)
for input_type, input_node in zip(input_types, input_nodes)
),
is_effect,
)
Expand All @@ -153,12 +122,7 @@ def visit_Call(self, call_node: ast.Call) -> None:
if _is_awaitable_constraint(call_node):
self.awaitables.append(self._get_awaitable(call_node))
else:
func_node = call_node.func
lookup_names = _get_lookup_names(func_node)
attr = cast(types.FunctionType, self.func).__globals__.get(lookup_names.pop(), None)
while attr is not None and lookup_names:
attr = getattr(attr, lookup_names.pop(), None)

attr = self._lookup(call_node.func)
if hasattr(attr, "rule_helper"):
self.awaitables.extend(collect_awaitables(attr))

Expand Down
51 changes: 43 additions & 8 deletions src/python/pants/engine/internals/rule_visitor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ async def _static_helper():
container_instance = HelperContainer()


class InnerScope:
    """Namespace fixture for reaching types and rule helpers via attribute lookup."""

    # Aliases of builtin types.
    STR = str
    INT = int
    BOOL = bool

    # Re-exports of the module-level helper container class and its instance.
    HelperContainer = HelperContainer
    container_instance = container_instance


def assert_awaitables(func, awaitable_types: Iterable[tuple[type | list[type], type]]):
gets = collect_awaitables(func)
actual_types = tuple((list(get.input_types), get.output_type) for get in gets)
Expand Down Expand Up @@ -90,6 +99,14 @@ async def rule():
assert_awaitables(rule, [(int, str), (str, int)])


def test_attribute_lookup() -> None:
    """`Get` type arguments written as attribute accesses resolve to the underlying types."""

    async def rule_with_attribute_types():
        # Longhand form: Get(OutputType, InputType, input).
        await Get(InnerScope.STR, InnerScope.INT, 42)
        # Shorthand form: Get(OutputType, InputType(constructor args)).
        await Get(InnerScope.STR, InnerScope.INT(42))

    expected = [(int, str), (int, str)]
    assert_awaitables(rule_with_attribute_types, expected)


def test_get_no_index_call_no_subject_call_allowed() -> None:
async def rule():
get_type: type = Get # noqa: F841
Expand All @@ -108,18 +125,36 @@ def test_rule_helpers_class_methods() -> None:
async def rule1():
HelperContainer()._static_helper(1)

# Rule helpers must be called via module-scoped attributes
assert_awaitables(rule1, [])
async def rule1_inner():
InnerScope.HelperContainer()._static_helper(1)

async def rule2():
container_instance._static_helper(1)
HelperContainer._static_helper(1)

assert_awaitables(rule2, [(int, str), (str, int)])
async def rule2_inner():
InnerScope.HelperContainer._static_helper(1)

async def rule3():
container_instance._static_helper(1)

async def rule3_inner():
InnerScope.container_instance._static_helper(1)

async def rule4():
container_instance._method_helper(1)

assert_awaitables(rule3, [(int, str)])
async def rule4_inner():
InnerScope.container_instance._method_helper(1)

# Rule helpers must be called via module-scoped attribute lookup
assert_awaitables(rule1, [])
assert_awaitables(rule1_inner, [])
assert_awaitables(rule2, [(int, str), (str, int)])
assert_awaitables(rule2_inner, [(int, str), (str, int)])
assert_awaitables(rule3, [(int, str), (str, int)])
assert_awaitables(rule3_inner, [(int, str), (str, int)])
assert_awaitables(rule4, [(int, str)])
assert_awaitables(rule4_inner, [(int, str)])


def test_valid_get_unresolvable_product_type() -> None:
Expand Down Expand Up @@ -160,21 +195,21 @@ def test_invalid_get_invalid_subject_arg_no_constructor_call() -> None:
async def rule():
Get(STR, "bob")

with pytest.raises(GetParseError):
with pytest.raises(ValueError):
collect_awaitables(rule)


def test_invalid_get_invalid_product_type_not_a_type_name() -> None:
async def rule():
Get(call(), STR("bob")) # noqa: F821

with pytest.raises(GetParseError):
with pytest.raises(ValueError):
collect_awaitables(rule)


def test_invalid_get_dict_value_not_type() -> None:
async def rule():
Get(int, {"str": "not a type"})

with pytest.raises(GetParseError):
with pytest.raises(ValueError):
collect_awaitables(rule)