From 1a25f0ec5a4bb975dbbc153b2b05ba5d51e07907 Mon Sep 17 00:00:00 2001 From: Joshua Cannon Date: Thu, 22 Sep 2022 15:10:10 -0500 Subject: [PATCH] Support overriding param types for rule code (#16929) The main use-case is runtime-typed parameters (for use in "helper"/"inner" rules). ```python def make_rule(sometype: type[BaseType]): @rule(param_type_overrides={"request": sometype.var_type}) async def helper_rule(request: WhateverVarTypeShouldBe) -> str: ... return helper_rule ``` [ci skip-rust] [ci skip-build-wheels] --- src/python/pants/engine/rules.py | 25 +++++++++++++++++++++++-- src/python/pants/engine/rules_test.py | 22 ++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index 17537c57305..1cd7c2543b5 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -163,6 +163,16 @@ def _ensure_type_annotation( PUBLIC_RULE_DECORATOR_ARGUMENTS = {"canonical_name", "desc", "level"} +# We aren't sure if these'll stick around or be removed at some point, so they are "private" +# and should only be used in Pants' codebase. +PRIVATE_RULE_DECORATOR_ARGUMENTS = { + # Allows callers to override the type Pants will use for the params listed. + # + # It is assumed (but not enforced) that the provided type is a subclass of the annotated type. + # (We assume but not enforce since this is likely to be used with unions, which has the same + # assumption between the union base and its members). + "_param_type_overrides", +} # We don't want @rule-writers to use 'rule_type' or 'cacheable' as kwargs directly, # but rather set them implicitly based on the rule annotation. # So we leave it out of PUBLIC_RULE_DECORATOR_ARGUMENTS. @@ -180,6 +190,7 @@ def rule_decorator(func, **kwargs) -> Callable: len( set(kwargs) - PUBLIC_RULE_DECORATOR_ARGUMENTS + - PRIVATE_RULE_DECORATOR_ARGUMENTS - IMPLICIT_PRIVATE_RULE_DECORATOR_ARGUMENTS ) != 0 @@ -190,6 +201,7 @@ def rule_decorator(func, **kwargs) -> Callable: rule_type: RuleType = kwargs["rule_type"] cacheable: bool = kwargs["cacheable"] + param_type_overrides: dict[str, type] = kwargs.get("_param_type_overrides", {}) func_id = f"@rule {func.__module__}:{func.__name__}" type_hints = get_type_hints(func) @@ -198,13 +210,22 @@ def rule_decorator(func, **kwargs) -> Callable: name=f"{func_id} return", raise_type=MissingReturnTypeAnnotation, ) + + func_params = inspect.signature(func).parameters + for parameter in param_type_overrides: + if parameter not in func_params: + raise ValueError( + f"Unknown parameter name in `param_type_overrides`: {parameter}." + + f" Parameter names: '{', '.join(func_params)}'" + ) + parameter_types = tuple( _ensure_type_annotation( - type_annotation=type_hints.get(parameter), + type_annotation=param_type_overrides.get(parameter, type_hints.get(parameter)), name=f"{func_id} parameter {parameter}", raise_type=MissingParameterTypeAnnotation, ) - for parameter in inspect.signature(func).parameters + for parameter in func_params ) is_goal_cls = issubclass(return_type, Goal) diff --git a/src/python/pants/engine/rules_test.py b/src/python/pants/engine/rules_test.py index f3ac76b1452..8ac54e21200 100644 --- a/src/python/pants/engine/rules_test.py +++ b/src/python/pants/engine/rules_test.py @@ -1013,6 +1013,28 @@ async def dup_a() -> B: # noqa: F811 return B() +def test_param_type_overrides() -> None: + type1 = int # use a runtime type + + @rule(_param_type_overrides={"param1": type1, "param2": dict}) + async def dont_injure_humans(param1: str, param2, param3: list) -> A: + return A() + + assert dont_injure_humans.rule.input_selectors == (int, dict, list) + + with pytest.raises(ValueError, match="paramX"): + + @rule(_param_type_overrides={"paramX": int}) + async def obey_human_orders() -> A: + return A() + + with pytest.raises(MissingParameterTypeAnnotation, match="must be a type"): + + @rule(_param_type_overrides={"param1": "A string"}) + async def protect_existence(param1) -> A: + return A() + + def test_invalid_rule_helper_name() -> None: with pytest.raises(ValueError, match="must be private"):