Skip to content

Commit

Permalink
go: refactor tailor rule into rule helpers (#17082)
Browse files Browse the repository at this point in the history
Refactor the tailor rule for Go into several rule helpers.

[ci skip-build-wheels]

[ci skip-rust]
  • Loading branch information
Tom Dyas authored Oct 2, 2022
1 parent 585af30 commit fb64383
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 71 deletions.
161 changes: 92 additions & 69 deletions src/python/pants/backend/go/goals/tailor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
group_by_dir,
)
from pants.engine.fs import DigestContents, PathGlobs, Paths
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.rules import Get, MultiGet, collect_rules, rule, rule_helper
from pants.engine.target import UnexpandedTargets
from pants.engine.unions import UnionRule
from pants.util.logging import LogLevel
Expand All @@ -48,6 +48,92 @@ def has_go_mod_ancestor(dirname: str, all_go_mod_dirs: set[str]) -> bool:
return any(dirname.startswith(go_mod_dir) for go_mod_dir in all_go_mod_dirs)


@rule_helper
async def _find_go_mod_targets(
all_go_mod_files: set[str], all_owned_sources: AllOwnedSources
) -> list[PutativeTarget]:
unowned_go_mod_files = all_go_mod_files - set(all_owned_sources)
return [
PutativeTarget.for_target_type(
GoModTarget,
path=dirname,
name=None,
triggering_sources=sorted(filenames),
)
for dirname, filenames in group_by_dir(unowned_go_mod_files).items()
]


@rule_helper
async def _find_go_package_targets(
request: PutativeGoTargetsRequest, all_go_mod_dirs: set[str], all_owned_sources: AllOwnedSources
) -> list[PutativeTarget]:
putative_targets = []
all_go_files = await Get(Paths, PathGlobs, request.path_globs("*.go"))
unowned_go_files = set(all_go_files.files) - set(all_owned_sources)
for dirname, filenames in group_by_dir(unowned_go_files).items():
# Ignore paths that have `testdata` or `vendor` in them.
# From `go help packages`: Note, however, that a directory named vendor that itself
# contains code is not a vendored package: cmd/vendor would be a command named vendor.
dirname_parts = PurePath(dirname).parts
if "testdata" in dirname_parts or "vendor" in dirname_parts[0:-1]:
continue
if not has_go_mod_ancestor(dirname, all_go_mod_dirs):
continue

putative_targets.append(
PutativeTarget.for_target_type(
GoPackageTarget,
path=dirname,
name=None,
triggering_sources=sorted(filenames),
)
)

return putative_targets


@rule_helper
async def _find_go_binary_targets(
request: PutativeGoTargetsRequest, all_go_mod_dirs: set[str]
) -> list[PutativeTarget]:
all_go_files_digest_contents = await Get(DigestContents, PathGlobs, request.path_globs("*.go"))

main_package_dirs = []
for file_content in all_go_files_digest_contents:
dirname = os.path.dirname(file_content.path)
if has_package_main(file_content.content) and has_go_mod_ancestor(dirname, all_go_mod_dirs):
main_package_dirs.append(dirname)

existing_targets = await Get(
UnexpandedTargets,
RawSpecs(
ancestor_globs=tuple(AncestorGlobSpec(d) for d in main_package_dirs),
description_of_origin="the `go_binary` tailor rule",
),
)
owned_main_packages = await MultiGet(
Get(GoBinaryMainPackage, GoBinaryMainPackageRequest(t[GoBinaryMainPackageField]))
for t in existing_targets
if t.has_field(GoBinaryMainPackageField)
)
unowned_main_package_dirs = set(main_package_dirs) - {
# NB: We assume the `go_package` lives in the directory it's defined, which we validate
# by e.g. banning `**` in its sources field.
pkg.address.spec_path
for pkg in owned_main_packages
}
return [
PutativeTarget.for_target_type(
GoBinaryTarget,
path=main_pkg_dir,
name="bin",
triggering_sources=tuple(),
)
for main_pkg_dir in unowned_main_package_dirs
]


@rule(level=LogLevel.DEBUG, desc="Determine candidate Go targets to create")
async def find_putative_go_targets(
request: PutativeGoTargetsRequest,
Expand All @@ -60,79 +146,16 @@ async def find_putative_go_targets(
all_go_mod_dirs = {os.path.dirname(fp) for fp in all_go_mod_files}

if golang_subsystem.tailor_go_mod_targets:
unowned_go_mod_files = all_go_mod_files - set(all_owned_sources)
for dirname, filenames in group_by_dir(unowned_go_mod_files).items():
putative_targets.append(
PutativeTarget.for_target_type(
GoModTarget,
path=dirname,
name=None,
triggering_sources=sorted(filenames),
)
)
putative_targets.extend(await _find_go_mod_targets(all_go_mod_files, all_owned_sources))

if golang_subsystem.tailor_package_targets:
all_go_files = await Get(Paths, PathGlobs, request.path_globs("*.go"))
unowned_go_files = set(all_go_files.files) - set(all_owned_sources)
for dirname, filenames in group_by_dir(unowned_go_files).items():
# Ignore paths that have `testdata` or `vendor` in them.
# From `go help packages`: Note, however, that a directory named vendor that itself
# contains code is not a vendored package: cmd/vendor would be a command named vendor.
dirname_parts = PurePath(dirname).parts
if "testdata" in dirname_parts or "vendor" in dirname_parts[0:-1]:
continue
if not has_go_mod_ancestor(dirname, all_go_mod_dirs):
continue
putative_targets.append(
PutativeTarget.for_target_type(
GoPackageTarget,
path=dirname,
name=None,
triggering_sources=sorted(filenames),
)
)

if golang_subsystem.tailor_binary_targets:
all_go_files_digest_contents = await Get(
DigestContents, PathGlobs, request.path_globs("*.go")
)

main_package_dirs = []
for file_content in all_go_files_digest_contents:
dirname = os.path.dirname(file_content.path)
if has_package_main(file_content.content) and has_go_mod_ancestor(
dirname, all_go_mod_dirs
):
main_package_dirs.append(dirname)

existing_targets = await Get(
UnexpandedTargets,
RawSpecs(
ancestor_globs=tuple(AncestorGlobSpec(d) for d in main_package_dirs),
description_of_origin="the `go_binary` tailor rule",
),
)
owned_main_packages = await MultiGet(
Get(GoBinaryMainPackage, GoBinaryMainPackageRequest(t[GoBinaryMainPackageField]))
for t in existing_targets
if t.has_field(GoBinaryMainPackageField)
)
unowned_main_package_dirs = set(main_package_dirs) - {
# NB: We assume the `go_package` lives in the directory it's defined, which we validate
# by e.g. banning `**` in its sources field.
pkg.address.spec_path
for pkg in owned_main_packages
}
putative_targets.extend(
PutativeTarget.for_target_type(
GoBinaryTarget,
path=main_pkg_dir,
name="bin",
triggering_sources=tuple(),
)
for main_pkg_dir in unowned_main_package_dirs
await _find_go_package_targets(request, all_go_mod_dirs, all_owned_sources)
)

if golang_subsystem.tailor_binary_targets:
putative_targets.extend(await _find_go_binary_targets(request, all_go_mod_dirs))

return PutativeTargets(putative_targets)


Expand Down
4 changes: 3 additions & 1 deletion src/python/pants/backend/go/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,12 @@ class GoModTarget(TargetGenerator):


class GoPackageSourcesField(MultipleSourcesField):
default = ("*.go", "*.s")
default = ("*.go",)
expected_file_extensions = (
".go",
".s",
".S",
".sx",
".c",
".h",
".hh",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_build_package_with_assembly(rule_runner: RuleRunner) -> None:
"BUILD": dedent(
"""\
go_mod(name="mod")
go_package(name="pkg")
go_package(name="pkg", sources=["*.go", "*.s"])
go_binary(name="bin")
"""
),
Expand Down

0 comments on commit fb64383

Please sign in to comment.