diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCompilationContext.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCompilationContext.java index df8715d371175a..66b573035da29e 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCompilationContext.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCompilationContext.java @@ -531,6 +531,13 @@ public NestedSet getTransitiveModules(boolean usePic) { return usePic ? transitivePicModules : transitiveModules; } + @Override + public Depset getStarlarkTransitiveModules(boolean usePic, StarlarkThread thread) + throws EvalException { + CcModule.checkPrivateStarlarkificationAllowlist(thread); + return Depset.of(Artifact.TYPE, getTransitiveModules(usePic)); + } + /** * Returns the immutable set of additional transitive inputs needed for compilation, like C++ * module map artifacts. @@ -541,6 +548,12 @@ public NestedSet getAdditionalInputs() { return builder.build(); } + @Override + public Depset getStarlarkAdditionalInputs(StarlarkThread thread) throws EvalException { + CcModule.checkPrivateStarlarkificationAllowlist(thread); + return Depset.of(Artifact.TYPE, getAdditionalInputs()); + } + /** Adds additional transitive inputs needed for compilation to builder. */ void addAdditionalInputs(NestedSetBuilder builder) { builder.addAll(directModuleMaps); diff --git a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/cpp/CcCompilationContextApi.java b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/cpp/CcCompilationContextApi.java index 9a642b78a1f088..cdad44e7240100 100644 --- a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/cpp/CcCompilationContextApi.java +++ b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/cpp/CcCompilationContextApi.java @@ -17,6 +17,8 @@ import com.google.devtools.build.docgen.annot.DocCategory; import com.google.devtools.build.lib.collect.nestedset.Depset; import com.google.devtools.build.lib.starlarkbuildapi.FileApi; +import net.starlark.java.annot.Param; +import net.starlark.java.annot.ParamType; import net.starlark.java.annot.StarlarkBuiltin; import net.starlark.java.annot.StarlarkMethod; import net.starlark.java.eval.EvalException; @@ -133,4 +135,20 @@ public interface CcCompilationContextApi extends Starlark doc = "Returns the set of validation artifacts.", structField = true) Depset getStarlarkValidationArtifacts(); + + @StarlarkMethod(name = "additional_inputs", documented = false, useStarlarkThread = true) + Depset getStarlarkAdditionalInputs(StarlarkThread thread) throws EvalException; + + @StarlarkMethod( + name = "transitive_modules", + documented = false, + useStarlarkThread = true, + parameters = { + @Param( + name = "use_pic", + positional = false, + named = true, + allowedTypes = {@ParamType(type = Boolean.class)}) + }) + Depset getStarlarkTransitiveModules(boolean usePic, StarlarkThread thread) throws EvalException; } diff --git a/src/main/starlark/builtins_bzl/common/cc/cc_binary.bzl b/src/main/starlark/builtins_bzl/common/cc/cc_binary.bzl index 5d4e47691a2507..5bf82474fdc13c 100644 --- a/src/main/starlark/builtins_bzl/common/cc/cc_binary.bzl +++ b/src/main/starlark/builtins_bzl/common/cc/cc_binary.bzl @@ -660,7 +660,7 @@ def cc_binary_impl(ctx, additional_linkopts): Returns: Appropriate providers for cc_binary/cc_test. """ - cc_helper.check_srcs_extensions(ctx, ALLOWED_SRC_FILES, "cc_binary") + cc_helper.check_srcs_extensions(ctx, ALLOWED_SRC_FILES, "cc_binary", True) common = cc_internal.create_common(ctx = ctx) semantics.validate_deps(ctx) diff --git a/src/main/starlark/builtins_bzl/common/cc/cc_helper.bzl b/src/main/starlark/builtins_bzl/common/cc/cc_helper.bzl index c9c7966a5063f4..4182c572010395 100644 --- a/src/main/starlark/builtins_bzl/common/cc/cc_helper.bzl +++ b/src/main/starlark/builtins_bzl/common/cc/cc_helper.bzl @@ -41,23 +41,23 @@ artifact_category = struct( CLIF_OUTPUT_PROTO = "CLIF_OUTPUT_PROTO", ) -def _check_src_extension(file, allowed_src_files): +def _check_src_extension(file, allowed_src_files, allow_versioned_shared_libraries): extension = "." + file.extension - if _matches_extension(extension, allowed_src_files) or _is_shared_library_extension_valid(file.path): + if _matches_extension(extension, allowed_src_files) or (allow_versioned_shared_libraries and _is_versioned_shared_library_extension_valid(file.path)): return True return False -def _check_srcs_extensions(ctx, allowed_src_files, rule_name): +def _check_srcs_extensions(ctx, allowed_src_files, rule_name, allow_versioned_shared_libraries): for src in ctx.attr.srcs: if DefaultInfo in src: files = src[DefaultInfo].files.to_list() if len(files) == 1 and files[0].is_source: - if not _check_src_extension(files[0], allowed_src_files) and not files[0].is_directory: + if not _check_src_extension(files[0], allowed_src_files, allow_versioned_shared_libraries) and not files[0].is_directory: fail("in srcs attribute of {} rule {}: source file '{}' is misplaced here".format(rule_name, ctx.label, str(src.label))) else: at_least_one_good = False for file in files: - if _check_src_extension(file, allowed_src_files) or file.is_directory: + if _check_src_extension(file, allowed_src_files, allow_versioned_shared_libraries) or file.is_directory: at_least_one_good = True break if not at_least_one_good: @@ -134,6 +134,24 @@ def _find_cpp_toolchain(ctx): # We didn't find anything. fail("In order to use find_cpp_toolchain, you must define the '_cc_toolchain' attribute on your rule or aspect.") +def _collect_compilation_prerequisites(ctx, compilation_context): + direct = [] + transitive = [] + if hasattr(ctx.attr, "srcs"): + for src in ctx.attr.srcs: + if DefaultInfo in src: + files = src[DefaultInfo].files.to_list() + for file in files: + if _check_src_extension(file, extensions.CC_AND_OBJC, False): + direct.append(file) + + transitive.append(compilation_context.headers) + transitive.append(compilation_context.additional_inputs()) + transitive.append(compilation_context.transitive_modules(use_pic = True)) + transitive.append(compilation_context.transitive_modules(use_pic = False)) + + return depset(direct = direct, transitive = transitive) + def _build_output_groups_for_emitting_compile_providers( compilation_outputs, compilation_context, @@ -151,7 +169,7 @@ def _build_output_groups_for_emitting_compile_providers( use_pic = use_pic, ) output_groups_builder["compilation_outputs"] = files_to_compile - output_groups_builder["compilation_prerequisites_INTERNAL_"] = cc_internal.collect_compilation_prerequisites(ctx = ctx, compilation_context = compilation_context) + output_groups_builder["compilation_prerequisites_INTERNAL_"] = _collect_compilation_prerequisites(ctx = ctx, compilation_context = compilation_context) if generate_hidden_top_level_group: output_groups_builder["_hidden_top_level_INTERNAL_"] = _collect_library_hidden_top_level_artifacts( @@ -219,14 +237,6 @@ OBJC_SOURCE = [".m"] OBJCPP_SOURCE = [".mm"] CLIF_INPUT_PROTO = [".ipb"] CLIF_OUTPUT_PROTO = [".opb"] -CC_AND_OBJC = [] -CC_AND_OBJC.extend(CC_SOURCE) -CC_AND_OBJC.extend(C_SOURCE) -CC_AND_OBJC.extend(OBJC_SOURCE) -CC_AND_OBJC.extend(OBJCPP_SOURCE) -CC_AND_OBJC.extend(CLIF_INPUT_PROTO) -CC_AND_OBJC.extend(CLIF_OUTPUT_PROTO) - CC_HEADER = [".h", ".hh", ".hpp", ".ipp", ".hxx", ".h++", ".inc", ".inl", ".tlh", ".tli", ".H", ".tcc"] ASSESMBLER_WITH_C_PREPROCESSOR = [".S"] ASSEMBLER = [".s", ".asm"] @@ -238,6 +248,15 @@ SHARED_LIBRARY = [".so", ".dylib", ".dll"] OBJECT_FILE = [".o", ".obj"] PIC_OBJECT_FILE = [".pic.o"] +CC_AND_OBJC = [] +CC_AND_OBJC.extend(CC_SOURCE) +CC_AND_OBJC.extend(C_SOURCE) +CC_AND_OBJC.extend(OBJC_SOURCE) +CC_AND_OBJC.extend(OBJCPP_SOURCE) +CC_AND_OBJC.extend(CC_HEADER) +CC_AND_OBJC.extend(ASSEMBLER) +CC_AND_OBJC.extend(ASSESMBLER_WITH_C_PREPROCESSOR) + extensions = struct( CC_SOURCE = CC_SOURCE, C_SOURCE = C_SOURCE, @@ -392,12 +411,7 @@ def _build_precompiled_files(ctx): shared_libraries, ) -def _is_shared_library_extension_valid(shared_library_name): - if (shared_library_name.endswith(".so") or - shared_library_name.endswith(".dll") or - shared_library_name.endswith(".dylib")): - return True - +def _is_versioned_shared_library_extension_valid(shared_library_name): # validate agains the regex "^.+\\.((so)|(dylib))(\\.\\d\\w*)+$", # must match VERSIONED_SHARED_LIBRARY. for ext in (".so.", ".dylib."): @@ -411,9 +425,16 @@ def _is_shared_library_extension_valid(shared_library_name): if not (c.isalnum() or c == "_"): return False return True - return False +def _is_shared_library_extension_valid(shared_library_name): + if (shared_library_name.endswith(".so") or + shared_library_name.endswith(".dll") or + shared_library_name.endswith(".dylib")): + return True + + return _is_versioned_shared_library_extension_valid(shared_library_name) + def _get_providers(deps, provider): providers = [] for dep in deps: @@ -538,5 +559,6 @@ cc_helper = struct( generate_def_file = _generate_def_file, stringify_linker_input = _stringify_linker_input, get_linked_artifact = _get_linked_artifact, + collect_compilation_prerequisites = _collect_compilation_prerequisites, collect_native_cc_libraries = _collect_native_cc_libraries, ) diff --git a/src/main/starlark/builtins_bzl/common/cc/cc_library.bzl b/src/main/starlark/builtins_bzl/common/cc/cc_library.bzl index 185f9080bcaea7..a4e9bf6134f80d 100755 --- a/src/main/starlark/builtins_bzl/common/cc/cc_library.bzl +++ b/src/main/starlark/builtins_bzl/common/cc/cc_library.bzl @@ -22,7 +22,7 @@ cc_common = _builtins.toplevel.cc_common cc_internal = _builtins.internal.cc_internal def _cc_library_impl(ctx): - cc_helper.check_srcs_extensions(ctx, ALLOWED_SRC_FILES, "cc_library") + cc_helper.check_srcs_extensions(ctx, ALLOWED_SRC_FILES, "cc_library", True) common = cc_internal.create_common(ctx = ctx) common.report_invalid_options(ctx = ctx) diff --git a/src/test/java/com/google/devtools/build/lib/rules/cpp/StarlarkCcCommonTest.java b/src/test/java/com/google/devtools/build/lib/rules/cpp/StarlarkCcCommonTest.java index 5c38525506cca5..fea4a944bee94c 100755 --- a/src/test/java/com/google/devtools/build/lib/rules/cpp/StarlarkCcCommonTest.java +++ b/src/test/java/com/google/devtools/build/lib/rules/cpp/StarlarkCcCommonTest.java @@ -7080,12 +7080,19 @@ public void testExpandedCompileApiBlocked() throws Exception { public void testExpandedCcCompilationContextApiBlocked() throws Exception { scratch.file( "b/BUILD", - "load('//my_rules:rule.bzl', 'method_rule', 'param_2_rule')", + "load('//my_rules:rule.bzl', 'method_rule', 'param_2_rule', 'additional_inputs_rule'," + + " 'transitive_modules_rule')", "param_2_rule(", " name = 'p2',", ")", "method_rule(", " name = 'm',", + ")", + "additional_inputs_rule(", + " name = 'ai',", + ")", + "transitive_modules_rule(", + " name = 'tm',", ")"); scratch.file("my_rules/BUILD"); scratch.file( @@ -7097,16 +7104,30 @@ public void testExpandedCcCompilationContextApiBlocked() throws Exception { "def _p2_impl(ctx):", " comp_context = cc_common.create_compilation_context(purpose = 'testing')", " return [CcInfo(compilation_context = comp_context)]", + "def _additional_inputs_impl(ctx):", + " comp_context = cc_common.create_compilation_context()", + " comp_context.additional_inputs()", + "def _transitive_modules_impl(ctx):", + " comp_context = cc_common.create_compilation_context()", + " comp_context.transitive_modules(use_pic = True)", "method_rule = rule(", " implementation = _m_impl,", ")", "param_2_rule = rule(", - " implementation = _p2_impl,", + " implementation = _p2_impl)", + "additional_inputs_rule = rule(", + " implementation = _additional_inputs_impl)", + "transitive_modules_rule = rule(", + " implementation = _transitive_modules_impl", ")"); AssertionError e = assertThrows(AssertionError.class, () -> getConfiguredTarget("//b:m")); assertThat(e).hasMessageThat().contains("Rule in 'my_rules' cannot use private API"); e = assertThrows(AssertionError.class, () -> getConfiguredTarget("//b:p2")); assertThat(e).hasMessageThat().contains("Rule in 'my_rules' cannot use private API"); + e = assertThrows(AssertionError.class, () -> getConfiguredTarget("//b:ai")); + assertThat(e).hasMessageThat().contains("Rule in 'my_rules' cannot use private API"); + e = assertThrows(AssertionError.class, () -> getConfiguredTarget("//b:tm")); + assertThat(e).hasMessageThat().contains("Rule in 'my_rules' cannot use private API"); } @Test