Skip to content

Commit

Permalink
Factor out a field-filtering rule helper
Browse files Browse the repository at this point in the history
(cherry picked from commit b8045313375093894515ca24becd0ed17aaa4f02)
  • Loading branch information
Christopher Neugebauer committed Oct 20, 2022
1 parent f004a8a commit c25b1ec
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 15 deletions.
15 changes: 10 additions & 5 deletions src/python/pants/backend/python/goals/lockfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
KnownUserResolveNames,
KnownUserResolveNamesRequest,
RequestedUserResolveNames,
TargetsWithSpecifiedFields,
TargetsWithSpecifiedFieldsRequest,
UserGenerateLockfiles,
WrappedGenerateLockfile,
)
Expand All @@ -41,7 +43,6 @@
from pants.engine.internals.target_adaptor import TargetAdaptor
from pants.engine.process import ProcessCacheScope, ProcessResult
from pants.engine.rules import Get, collect_rules, rule, rule_helper
from pants.engine.target import AllTargets
from pants.engine.unions import UnionRule
from pants.util.docutil import bin_name
from pants.util.logging import LogLevel
Expand Down Expand Up @@ -241,15 +242,19 @@ def determine_python_user_resolves(

@rule
async def setup_user_lockfile_requests(
requested: RequestedPythonUserResolveNames, all_targets: AllTargets, python_setup: PythonSetup
requested: RequestedPythonUserResolveNames, python_setup: PythonSetup
) -> UserGenerateLockfiles:
if not (python_setup.enable_resolves and python_setup.resolves_generate_lockfiles):
return UserGenerateLockfiles()

resolve_to_requirements_fields = defaultdict(set)
for tgt in all_targets:
if not tgt.has_fields((PythonRequirementResolveField, PythonRequirementsField)):
continue

filtered_targets = await Get(
TargetsWithSpecifiedFields,
TargetsWithSpecifiedFieldsRequest((PythonRequirementResolveField, PythonRequirementsField)),
)

for tgt in filtered_targets:
resolve = tgt[PythonRequirementResolveField].normalized_value(python_setup)
resolve_to_requirements_fields[resolve].add(tgt[PythonRequirementsField])

Expand Down
6 changes: 4 additions & 2 deletions src/python/pants/backend/python/goals/lockfile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from pants.backend.python.util_rules import pex
from pants.backend.python.util_rules.interpreter_constraints import InterpreterConstraints
from pants.core.goals.generate_lockfiles import GenerateLockfileResult, UserGenerateLockfiles
from pants.core.goals.generate_lockfiles import rules as generate_lockfiles_rules
from pants.engine.fs import DigestContents
from pants.engine.rules import SubsystemRule
from pants.testutil.rule_runner import PYTHON_BOOTSTRAP_ENV, QueryRule, RuleRunner
from pants.util.ordered_set import FrozenOrderedSet
from pants.util.strutil import strip_prefix
Expand All @@ -32,6 +32,7 @@ def rule_runner() -> RuleRunner:
rules=[
*lockfile_rules(),
*pex.rules(),
*generate_lockfiles_rules(),
QueryRule(GenerateLockfileResult, [GeneratePythonLockfile]),
]
)
Expand Down Expand Up @@ -197,7 +198,8 @@ def test_multiple_resolves() -> None:
rule_runner = RuleRunner(
rules=[
setup_user_lockfile_requests,
SubsystemRule(PythonSetup),
*PythonSetup.rules(), # type: ignore[call-arg]
*generate_lockfiles_rules(),
QueryRule(UserGenerateLockfiles, [RequestedPythonUserResolveNames]),
],
target_types=[PythonRequirementTarget],
Expand Down
2 changes: 2 additions & 0 deletions src/python/pants/backend/scala/resolve/lockfile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pants.backend.scala.resolve.lockfile import rules as scala_lockfile_rules
from pants.backend.scala.target_types import ScalaSourcesGeneratorTarget, ScalaSourceTarget
from pants.core.goals.generate_lockfiles import GenerateLockfileResult, UserGenerateLockfiles
from pants.core.goals.generate_lockfiles import rules as generate_lockfiles_rules
from pants.core.util_rules import external_tool, source_files, system_binaries
from pants.engine.internals import build_files, graph
from pants.jvm import jdk_rules
Expand All @@ -35,6 +36,7 @@ def rule_runner() -> RuleRunner:
*jdk_rules.rules(),
*coursier_fetch_rules(),
*coursier_jvm_tool_rules(),
*generate_lockfiles_rules(),
*lockfile.rules(),
*coursier_setup_rules(),
*external_tool.rules(),
Expand Down
31 changes: 28 additions & 3 deletions src/python/pants/core/goals/generate_lockfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Callable, ClassVar, Iterable, Sequence
from typing import Callable, ClassVar, Iterable, Sequence, Type

from pants.engine.collection import Collection
from pants.engine.environment import EnvironmentName
from pants.engine.fs import Digest, MergeDigests, Workspace
from pants.engine.goal import Goal, GoalSubsystem
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.rules import collect_rules, goal_rule
from pants.engine.target import Target
from pants.engine.rules import collect_rules, goal_rule, rule
from pants.engine.target import AllTargets, Field, Target
from pants.engine.unions import UnionMembership, union
from pants.option.option_types import StrListOption, StrOption
from pants.util.docutil import bin_name, doc_url
Expand Down Expand Up @@ -407,6 +407,31 @@ async def generate_lockfiles_goal(
return GenerateLockfilesGoal(exit_code=0)


# -----------------------------------------------------------------------------------------------
# Rules used by most user lockfile-determining rules
# -----------------------------------------------------------------------------------------------


class TargetsWithSpecifiedFields(Collection[Target]):
"""A filtered set of all the `Target`s that have specified fields."""


@dataclass(frozen=True)
class TargetsWithSpecifiedFieldsRequest:
"""A request type to fetch all of the targets with specified field types."""

field_types: tuple[Type[Field], ...]


@rule
async def filter_targets_by_specified_fields(
request: TargetsWithSpecifiedFieldsRequest, all_targets: AllTargets
) -> TargetsWithSpecifiedFields:
return TargetsWithSpecifiedFields(
tgt for tgt in all_targets if tgt.has_fields(request.field_types)
)


# -----------------------------------------------------------------------------------------------
# Helpers for determining the resolve
# -----------------------------------------------------------------------------------------------
Expand Down
14 changes: 9 additions & 5 deletions src/python/pants/jvm/goals/lockfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
KnownUserResolveNames,
KnownUserResolveNamesRequest,
RequestedUserResolveNames,
TargetsWithSpecifiedFields,
TargetsWithSpecifiedFieldsRequest,
UserGenerateLockfiles,
WrappedGenerateLockfile,
)
from pants.engine.environment import EnvironmentName
from pants.engine.fs import CreateDigest, Digest, FileContent
from pants.engine.internals.selectors import MultiGet
from pants.engine.rules import Get, collect_rules, rule
from pants.engine.target import AllTargets
from pants.engine.unions import UnionMembership, UnionRule, union
from pants.jvm.resolve import coursier_fetch
from pants.jvm.resolve.common import ArtifactRequirement, ArtifactRequirements
Expand Down Expand Up @@ -132,13 +133,16 @@ async def validate_jvm_artifacts_for_resolve(
@rule
async def setup_user_lockfile_requests(
requested: RequestedJVMUserResolveNames,
all_targets: AllTargets,
jvm_subsystem: JvmSubsystem,
) -> UserGenerateLockfiles:
resolve_to_artifacts: Mapping[str, OrderedSet[ArtifactRequirement]] = defaultdict(OrderedSet)
for tgt in sorted(all_targets, key=lambda t: t.address):
if not tgt.has_field(JvmArtifactResolveField):
continue

filtered_targets = await Get(
TargetsWithSpecifiedFields,
TargetsWithSpecifiedFieldsRequest((JvmArtifactResolveField,)),
)

for tgt in sorted(filtered_targets, key=lambda t: t.address):
artifact = ArtifactRequirement.from_jvm_artifact_target(tgt)
resolve = tgt[JvmResolveField].normalized_value(jvm_subsystem)
resolve_to_artifacts[resolve].add(artifact)
Expand Down
2 changes: 2 additions & 0 deletions src/python/pants/jvm/goals/lockfile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from pants.core.goals.generate_lockfiles import GenerateLockfileResult, UserGenerateLockfiles
from pants.core.goals.generate_lockfiles import rules as generate_lockfiles_rules
from pants.core.util_rules import source_files
from pants.core.util_rules.external_tool import rules as external_tool_rules
from pants.engine.fs import DigestContents, FileDigest
Expand Down Expand Up @@ -39,6 +40,7 @@ def rule_runner() -> RuleRunner:
*lockfile.rules(),
*coursier_setup_rules(),
*external_tool_rules(),
*generate_lockfiles_rules(),
*source_files.rules(),
*util_rules(),
QueryRule(UserGenerateLockfiles, [RequestedJVMUserResolveNames]),
Expand Down

0 comments on commit c25b1ec

Please sign in to comment.