Skip to content

Commit

Permalink
Support overriding param types for rule code (#16929)
Browse files Browse the repository at this point in the history
The main use-case is runtime-typed parameters (for use in "helper"/"inner" rules).

```python

def make_rule(sometype: type[BaseType]):
    @rule(param_type_overrides={"request": sometype.var_type})
    async def helper_rule(request: WhateverVarTypeShouldBe) -> str:
        ...

    return helper_rule
```

[ci skip-rust]
[ci skip-build-wheels]
  • Loading branch information
thejcannon authored Sep 22, 2022
1 parent c64579c commit 1a25f0e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
25 changes: 23 additions & 2 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@ def _ensure_type_annotation(


PUBLIC_RULE_DECORATOR_ARGUMENTS = {"canonical_name", "desc", "level"}
# We aren't sure if these'll stick around or be removed at some point, so they are "private"
# and should only be used in Pants' codebase.
PRIVATE_RULE_DECORATOR_ARGUMENTS = {
# Allows callers to override the type Pants will use for the params listed.
#
# It is assumed (but not enforced) that the provided type is a subclass of the annotated type.
# (We assume but not enforce since this is likely to be used with unions, which has the same
# assumption between the union base and its members).
"_param_type_overrides",
}
# We don't want @rule-writers to use 'rule_type' or 'cacheable' as kwargs directly,
# but rather set them implicitly based on the rule annotation.
# So we leave it out of PUBLIC_RULE_DECORATOR_ARGUMENTS.
Expand All @@ -180,6 +190,7 @@ def rule_decorator(func, **kwargs) -> Callable:
len(
set(kwargs)
- PUBLIC_RULE_DECORATOR_ARGUMENTS
- PRIVATE_RULE_DECORATOR_ARGUMENTS
- IMPLICIT_PRIVATE_RULE_DECORATOR_ARGUMENTS
)
!= 0
Expand All @@ -190,6 +201,7 @@ def rule_decorator(func, **kwargs) -> Callable:

rule_type: RuleType = kwargs["rule_type"]
cacheable: bool = kwargs["cacheable"]
param_type_overrides: dict[str, type] = kwargs.get("_param_type_overrides", {})

func_id = f"@rule {func.__module__}:{func.__name__}"
type_hints = get_type_hints(func)
Expand All @@ -198,13 +210,22 @@ def rule_decorator(func, **kwargs) -> Callable:
name=f"{func_id} return",
raise_type=MissingReturnTypeAnnotation,
)

func_params = inspect.signature(func).parameters
for parameter in param_type_overrides:
if parameter not in func_params:
raise ValueError(
f"Unknown parameter name in `param_type_overrides`: {parameter}."
+ f" Parameter names: '{', '.join(func_params)}'"
)

parameter_types = tuple(
_ensure_type_annotation(
type_annotation=type_hints.get(parameter),
type_annotation=param_type_overrides.get(parameter, type_hints.get(parameter)),
name=f"{func_id} parameter {parameter}",
raise_type=MissingParameterTypeAnnotation,
)
for parameter in inspect.signature(func).parameters
for parameter in func_params
)
is_goal_cls = issubclass(return_type, Goal)

Expand Down
22 changes: 22 additions & 0 deletions src/python/pants/engine/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,28 @@ async def dup_a() -> B: # noqa: F811
return B()


def test_param_type_overrides() -> None:
type1 = int # use a runtime type

@rule(_param_type_overrides={"param1": type1, "param2": dict})
async def dont_injure_humans(param1: str, param2, param3: list) -> A:
return A()

assert dont_injure_humans.rule.input_selectors == (int, dict, list)

with pytest.raises(ValueError, match="paramX"):

@rule(_param_type_overrides={"paramX": int})
async def obey_human_orders() -> A:
return A()

with pytest.raises(MissingParameterTypeAnnotation, match="must be a type"):

@rule(_param_type_overrides={"param1": "A string"})
async def protect_existence(param1) -> A:
return A()


def test_invalid_rule_helper_name() -> None:
with pytest.raises(ValueError, match="must be private"):

Expand Down

0 comments on commit 1a25f0e

Please sign in to comment.