Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mask the environment name in graph and specs_rules #17179

Merged
Merged
2 changes: 1 addition & 1 deletion src/python/pants/build_graph/build_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def register_rules(self, plugin_or_backend: str, rules: Iterable[Rule | UnionRul

# "Index" the rules to normalize them and expand their dependencies.
rule_index = RuleIndex.create(rules)
rules_and_queries = (*rule_index.rules, *rule_index.queries)
rules_and_queries: tuple[Rule, ...] = (*rule_index.rules, *rule_index.queries)
for rule in rules_and_queries:
self._rule_to_providers[rule].append(plugin_or_backend)
for union_rule in rule_index.union_rules:
Expand Down
3 changes: 2 additions & 1 deletion src/python/pants/engine/internals/native_engine.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def tasks_task_begin(
tasks: PyTasks,
func: Any,
return_type: type,
arg_types: Sequence[type],
masked_types: Sequence[type],
side_effecting: bool,
engine_aware_return_type: bool,
cacheable: bool,
Expand All @@ -299,7 +301,6 @@ def tasks_add_get(tasks: PyTasks, output: type, inputs: Sequence[type]) -> None:
def tasks_add_get_union(
tasks: PyTasks, output_type: type, input_types: Sequence[type], in_scope_types: Sequence[type]
) -> None: ...
def tasks_add_select(tasks: PyTasks, selector: type) -> None: ...
def tasks_add_query(tasks: PyTasks, output_type: type, input_types: Sequence[type]) -> None: ...
def execution_add_root_select(
scheduler: PyScheduler,
Expand Down
5 changes: 2 additions & 3 deletions src/python/pants/engine/internals/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,8 @@ def register_task(rule: TaskRule) -> None:
tasks,
rule.func,
rule.output_type,
rule.input_selectors,
rule.masked_types,
side_effecting=any(issubclass(t, SideEffecting) for t in rule.input_selectors),
engine_aware_return_type=issubclass(rule.output_type, EngineAwareReturnType),
cacheable=rule.cacheable,
Expand All @@ -651,9 +653,6 @@ def register_task(rule: TaskRule) -> None:
level=rule.level.level,
)

for selector in rule.input_selectors:
native_engine.tasks_add_select(tasks, selector)

for the_get in rule.input_gets:
unions = [t for t in the_get.input_types if is_union(t)]
if len(unions) == 1:
Expand Down
56 changes: 19 additions & 37 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import inspect
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from types import FrameType, ModuleType
Expand All @@ -23,7 +22,7 @@
overload,
)

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Protocol

from pants.base.deprecated import warn_or_error
from pants.engine.engine_aware import SideEffecting
Expand Down Expand Up @@ -72,6 +71,7 @@ def _make_rule(
rule_type: RuleType,
return_type: Type,
parameter_types: Iterable[Type],
masked_types: Iterable[Type],
*,
cacheable: bool,
canonical_name: str,
Expand Down Expand Up @@ -110,8 +110,9 @@ def wrapper(func):
func.rule = TaskRule(
return_type,
parameter_types,
awaitables,
masked_types,
func,
input_gets=awaitables,
canonical_name=canonical_name,
desc=desc,
level=level,
Expand Down Expand Up @@ -172,6 +173,10 @@ def _ensure_type_annotation(
# (We assume but not enforce since this is likely to be used with unions, which has the same
# assumption between the union base and its members).
"_param_type_overrides",
# Allows callers to prevent the given list of types from being included in the identity of
# a @rule. Although the type may be in scope for callers, it will not be consumable in the
# `@rule` which declares the type masked.
"_masked_types",
}
# We don't want @rule-writers to use 'rule_type' or 'cacheable' as kwargs directly,
# but rather set them implicitly based on the rule annotation.
Expand Down Expand Up @@ -201,6 +206,7 @@ def rule_decorator(func, **kwargs) -> Callable:

rule_type: RuleType = kwargs["rule_type"]
cacheable: bool = kwargs["cacheable"]
masked_types: tuple[type, ...] = tuple(kwargs.get("_masked_types", ()))
param_type_overrides: dict[str, type] = kwargs.get("_param_type_overrides", {})

func_id = f"@rule {func.__module__}:{func.__name__}"
Expand Down Expand Up @@ -273,6 +279,7 @@ def rule_decorator(func, **kwargs) -> Callable:
rule_type,
return_type,
parameter_types,
masked_types,
cacheable=cacheable,
canonical_name=effective_name,
desc=effective_desc,
Expand Down Expand Up @@ -417,15 +424,14 @@ def wrapper(func: Callable[P, R]) -> Callable[P, R]:
return _rule_helper_decorator(func, **kwargs)


class Rule(ABC):
class Rule(Protocol):
"""Rules declare how to produce products for the product graph.

A rule describes what dependencies must be provided to produce a particular product. They also
act as factories for constructing the nodes within the graph.
"""

@property
@abstractmethod
def output_type(self):
"""An output `type` for the rule."""

Expand Down Expand Up @@ -465,43 +471,23 @@ def iter_rules():
return list(iter_rules())


@frozen_after_init
@dataclass(unsafe_hash=True)
class TaskRule(Rule):
@dataclass(frozen=True)
class TaskRule:
"""A Rule that runs a task function when all of its input selectors are satisfied.

NB: This API is not meant for direct consumption. To create a `TaskRule` you should always
prefer the `@rule` constructor.
"""

_output_type: Type
output_type: Type
input_selectors: Tuple[Type, ...]
input_gets: Tuple[AwaitableConstraints, ...]
masked_types: Tuple[Type, ...]
func: Callable
cacheable: bool
canonical_name: str
desc: Optional[str]
level: LogLevel

def __init__(
self,
output_type: Type,
input_selectors: Iterable[Type],
func: Callable,
input_gets: Iterable[AwaitableConstraints],
canonical_name: str,
desc: Optional[str] = None,
level: LogLevel = LogLevel.TRACE,
cacheable: bool = True,
) -> None:
self._output_type = output_type
self.input_selectors = tuple(input_selectors)
self.input_gets = tuple(input_gets)
self.func = func
self.cacheable = cacheable
self.canonical_name = canonical_name
self.desc = desc
self.level = level
desc: Optional[str] = None
level: LogLevel = LogLevel.TRACE
cacheable: bool = True

def __str__(self):
return "(name={}, {}, {!r}, {}, gets={})".format(
Expand All @@ -512,14 +498,10 @@ def __str__(self):
self.input_gets,
)

@property
def output_type(self):
return self._output_type


@frozen_after_init
@dataclass(unsafe_hash=True)
class QueryRule(Rule):
class QueryRule:
"""A QueryRule declares that a given set of Params will be used to request an output type.

Every callsite to `Scheduler.product_request` should have a corresponding QueryRule to ensure
Expand Down
6 changes: 4 additions & 2 deletions src/python/pants/option/subsystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,13 @@ def _construct_subsystem_rule(cls) -> Rule:
return TaskRule(
output_type=cls,
input_selectors=(),
func=partial_construct_subsystem,
input_gets=(
AwaitableConstraints(
output_type=ScopedOptions, input_types=(Scope,), is_effect=False
),
),
masked_types=(),
func=partial_construct_subsystem,
canonical_name=name,
)

Expand All @@ -228,14 +229,15 @@ async def inner(*a, **k):
return TaskRule(
output_type=cls.EnvironmentAware,
input_selectors=(cls, EnvironmentTarget),
func=inner,
input_gets=(
AwaitableConstraints(
output_type=EnvironmentVars,
input_types=(EnvironmentVarsRequest,),
is_effect=False,
),
),
masked_types=(),
func=inner,
canonical_name=name,
)

Expand Down
52 changes: 40 additions & 12 deletions src/rust/engine/rule_graph/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,39 @@ impl<R: Rule> Builder<R> {
for edge_to_delete in edges_to_delete {
graph[edge_to_delete].mark_deleted(EdgePrunedReason::SmallerParamSetAvailable);
}

// Validate masked params.
if let Node::Rule(rule) = &graph[node_id].0.node {
for masked_param in rule.masked_params() {
if graph[node_id].0.in_set.contains(&masked_param) {
let in_set = params_str(&graph[node_id].0.in_set);
let dependencies = graph
.edges_directed(node_id, Direction::Outgoing)
.filter(|edge_ref| {
!edge_ref.weight().is_deleted()
&& !edge_ref.weight().0.provides(&masked_param)
&& graph[edge_ref.target()].0.in_set.contains(&masked_param)
})
.map(|edge_ref| {
let dep_id = edge_ref.target();
format!(
"{} for {}",
graph[dep_id].0.node,
params_str(&graph[dep_id].0.in_set)
)
})
.collect::<Vec<_>>()
.join("\n ");
errored
.entry(node_id)
.or_insert_with(Vec::new)
.push(format!(
"Rule `{rule} (for {in_set})` masked the parameter type `{masked_param}`, but \
it was required by some dependencies:\n {dependencies}"
));
}
}
}
}

if errored.is_empty() {
Expand Down Expand Up @@ -1507,22 +1540,17 @@ impl<R: Rule> Builder<R> {
continue;
}

// Compute the out_set for this combination: any Params that are consumed here are removed
// Compute the out_set for this combination. Any Params that are consumed here are removed
// from the out_set that Rule dependencies will be allowed to consume. Params that weren't
// present in the out_set were already filtered near the top of this method.
let out_set = {
let consumed_by_params = combination
.iter()
.filter_map(|(_, dependency_id, _)| match graph[*dependency_id].0.node {
Node::Param(p) => Some(p),
_ => None,
})
.collect::<ParamTypes<_>>();

let mut out_set = out_set.clone();
for (_, dependency_id, _) in &combination {
if let Node::Param(p) = graph[*dependency_id].0.node {
out_set.remove(&p);
}
}
out_set
.difference(&consumed_by_params)
.cloned()
.collect::<ParamTypes<R::TypeId>>()
};

// We can eliminate this candidate if any dependencies have minimal in_sets which contain
Expand Down
5 changes: 5 additions & 0 deletions src/rust/engine/rule_graph/src/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ pub trait Rule:
///
fn dependency_keys(&self) -> Vec<&DependencyKey<Self::TypeId>>;

///
/// Returns types which this rule is not allowed to consume from the calling scope.
///
fn masked_params(&self) -> Vec<Self::TypeId>;

///
/// True if this rule implementation should be required to be reachable in the RuleGraph.
///
Expand Down
44 changes: 44 additions & 0 deletions src/rust/engine/rule_graph/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,39 @@ fn ambiguity() {
.contains("Encountered 1 rule graph error:\n Too many"));
}

#[test]
fn masked_params() {
Comment on lines +89 to +90
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be good to test when masking works

// The middle rule masks "e", so even though the query provides it, construction should fail.
let rules = indexset![
Rule::new(
"a",
"a_from_b",
vec![DependencyKey::new("b").provided_params(vec!["c"])]
),
Rule::new(
"b",
"b_from_c",
vec![
DependencyKey::new("c"),
DependencyKey::new("c").provided_params(vec!["d"])
],
)
.masked_params(vec!["e"]),
Rule::new(
"c",
"c_from_d",
vec![DependencyKey::new("d"), DependencyKey::new("e")],
),
];
let queries = indexset![Query::new("a", vec!["e"])];

let res = RuleGraph::new(rules, queries).err().unwrap();
assert!(res.contains(
"Encountered 1 rule graph error:\n \
Rule `b_from_c(2) -> b (for c+e)` masked the parameter type `e`, but it"
));
}

#[test]
fn nested_single() {
let rules = indexset![
Expand Down Expand Up @@ -943,6 +976,7 @@ struct Rule {
product: &'static str,
name: &'static str,
dependency_keys: Vec<DependencyKey<&'static str>>,
masked_params: Vec<&'static str>,
}

impl Rule {
Expand All @@ -955,8 +989,14 @@ impl Rule {
product,
name,
dependency_keys,
masked_params: vec![],
}
}

fn masked_params(mut self, masked_params: Vec<&'static str>) -> Self {
self.masked_params = masked_params;
self
}
}

impl super::Rule for Rule {
Expand All @@ -970,6 +1010,10 @@ impl super::Rule for Rule {
self.dependency_keys.iter().collect()
}

fn masked_params(&self) -> Vec<Self::TypeId> {
self.masked_params.clone()
}

fn require_reachable(&self) -> bool {
!self.name.ends_with("_unreachable")
}
Expand Down
Loading