diff --git a/src/test/py/bazel/py_test.py b/src/test/py/bazel/py_test.py index c75923b3c5c29e..f02506bcd09af6 100644 --- a/src/test/py/bazel/py_test.py +++ b/src/test/py/bazel/py_test.py @@ -207,6 +207,116 @@ def testPyTestWithStdlibCollisionRunsRemotely(self): self.AssertExitCode(exit_code, 0, stderr, stdout) self.assertIn('Test ran', stdout) +class PyRunfilesLibraryTest(test_base.TestBase): + def testPyRunfilesLibraryCurrentRepository(self): + self.CreateWorkspaceWithDefaultRepos('WORKSPACE', [ + 'local_repository(', + ' name = "other_repo",', + ' path = "other_repo_path",', + ')' + ]) + + self.ScratchFile('pkg/BUILD.bazel', [ + 'py_library(', + ' name = "library",', + ' srcs = ["library.py"],', + ' visibility = ["//visibility:public"],', + ' deps = ["@bazel_tools//tools/python/runfiles"],', + ')', + '', + 'py_binary(', + ' name = "binary",', + ' srcs = ["binary.py"],', + ' deps = [', + ' ":library",', + ' "@bazel_tools//tools/python/runfiles",' + ' ],', + ')', + '', + 'py_test(', + ' name = "test",', + ' srcs = ["test.py"],', + ' deps = [', + ' ":library",', + ' "@bazel_tools//tools/python/runfiles",', + ' ],', + ')', + ]) + self.ScratchFile('pkg/library.py', [ + 'from bazel_tools.tools.python.runfiles import runfiles', + 'def print_repo_name():', + ' print("in pkg/library.py: \'%s\'" % runfiles.Create().CurrentRepository())', + ]) + self.ScratchFile('pkg/binary.py', [ + 'from bazel_tools.tools.python.runfiles import runfiles', + 'from pkg import library', + 'library.print_repo_name()', + 'print("in pkg/binary.py: \'%s\'" % runfiles.Create().CurrentRepository())', + ]) + self.ScratchFile('pkg/test.py', [ + 'from bazel_tools.tools.python.runfiles import runfiles', + 'from pkg import library', + 'library.print_repo_name()', + 'print("in pkg/test.py: \'%s\'" % runfiles.Create().CurrentRepository())', + ]) + + self.ScratchFile('other_repo_path/WORKSPACE') + self.ScratchFile('other_repo_path/pkg/BUILD.bazel', [ + 'py_binary(', + ' name = "binary",', + ' srcs = ["binary.py"],', + ' deps = [', + ' "@//pkg:library",', + ' "@bazel_tools//tools/python/runfiles",' + ' ],', + ')', + '', + 'py_test(', + ' name = "test",', + ' srcs = ["test.py"],', + ' deps = [', + ' "@//pkg:library",', + ' "@bazel_tools//tools/python/runfiles",', + ' ],', + ')', + ]) + self.ScratchFile('other_repo_path/pkg/binary.py', [ + 'from bazel_tools.tools.python.runfiles import runfiles', + 'from pkg import library', + 'library.print_repo_name()', + 'print("in external/other_repo/pkg/binary.py: \'%s\'" % runfiles.Create().CurrentRepository())', + ]) + self.ScratchFile('other_repo_path/pkg/test.py', [ + 'from bazel_tools.tools.python.runfiles import runfiles', + 'from pkg import library', + 'library.print_repo_name()', + 'print("in external/other_repo/pkg/test.py: \'%s\'" % runfiles.Create().CurrentRepository())', + ]) + + exit_code, stdout, stderr = self.RunBazel( + ['run', '//pkg:binary']) + self.AssertExitCode(exit_code, 0, stderr, stdout) + self.assertIn('in pkg/binary.py: \'\'', stdout) + self.assertIn('in pkg/library.py: \'\'', stdout) + + exit_code, stdout, stderr = self.RunBazel( + ['test', '//pkg:test', '--test_output=streamed']) + self.AssertExitCode(exit_code, 0, stderr, stdout) + self.assertIn('in pkg/test.py: \'\'', stdout) + self.assertIn('in pkg/library.py: \'\'', stdout) + + exit_code, stdout, stderr = self.RunBazel( + ['run', '@other_repo//pkg:binary']) + self.AssertExitCode(exit_code, 0, stderr, stdout) + self.assertIn('in external/other_repo/pkg/binary.py: \'other_repo\'', stdout) + self.assertIn('in pkg/library.py: \'\'', stdout) + + exit_code, stdout, stderr = self.RunBazel( + ['test', '@other_repo//pkg:test', '--test_output=streamed']) + self.AssertExitCode(exit_code, 0, stderr, stdout) + self.assertIn('in external/other_repo/pkg/test.py: \'other_repo\'', stdout) + self.assertIn('in pkg/library.py: \'\'', stdout) + if __name__ == '__main__': unittest.main() diff --git a/tools/python/gen_runfiles_constants.bzl b/tools/python/gen_runfiles_constants.bzl new file mode 100644 index 00000000000000..cfd2fee07098fe --- /dev/null +++ b/tools/python/gen_runfiles_constants.bzl @@ -0,0 +1,32 @@ +# Copyright 2022 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_RUNFILES_CONSTANTS_TEMPLATE = """# Generated by gen_runfiles_constants.bzl +# Internal-only; do no use. +# The name of the runfiles directory corresponding to the main repository. +MAIN_REPOSITORY_RUNFILES_DIRECTORY = '%s' +""" + +def _gen_runfiles_constants_impl(ctx): + out = ctx.actions.declare_file(ctx.attr.name + ".py") + ctx.actions.write(out, _RUNFILES_CONSTANTS_TEMPLATE % ctx.workspace_name) + + return DefaultInfo( + files = depset([out]), + runfiles = ctx.runfiles([out]), + ) + +gen_runfiles_constants = rule( + implementation = _gen_runfiles_constants_impl, +) diff --git a/tools/python/runfiles/BUILD b/tools/python/runfiles/BUILD index 21a88a9174c320..2afdd88f282ce2 100644 --- a/tools/python/runfiles/BUILD +++ b/tools/python/runfiles/BUILD @@ -1,3 +1,4 @@ +load("//tools/python:gen_runfiles_constants.bzl", "gen_runfiles_constants") load("//tools/python:private/defs.bzl", "py_library", "py_test") package(default_visibility = ["//visibility:private"]) @@ -20,7 +21,14 @@ filegroup( py_library( name = "runfiles", testonly = 1, - srcs = ["runfiles.py"], + srcs = [ + "runfiles.py", + ":_runfiles_constants", + ], +) + +gen_runfiles_constants( + name = "_runfiles_constants", ) py_test( diff --git a/tools/python/runfiles/BUILD.tools b/tools/python/runfiles/BUILD.tools index 3bfe889f34fc3f..abd6acac406e1d 100644 --- a/tools/python/runfiles/BUILD.tools +++ b/tools/python/runfiles/BUILD.tools @@ -1,7 +1,15 @@ +load("//tools/python:gen_runfiles_constants.bzl", "gen_runfiles_constants") load("//tools/python:private/defs.bzl", "py_library") py_library( name = "runfiles", - srcs = ["runfiles.py"], + srcs = [ + "runfiles.py", + ":_runfiles_constants", + ], visibility = ["//visibility:public"], ) + +gen_runfiles_constants( + name = "_runfiles_constants", +) diff --git a/tools/python/runfiles/runfiles.py b/tools/python/runfiles/runfiles.py index 03cff1c27ee386..ae2e0b5517c08c 100644 --- a/tools/python/runfiles/runfiles.py +++ b/tools/python/runfiles/runfiles.py @@ -58,9 +58,12 @@ p = subprocess.Popen([r.Rlocation("path/to/binary")], env, ...) """ +import inspect import os import posixpath +import sys +from ._runfiles_constants import MAIN_REPOSITORY_RUNFILES_DIRECTORY def CreateManifestBased(manifest_path): return _Runfiles(_ManifestBased(manifest_path)) @@ -114,6 +117,7 @@ class _Runfiles(object): def __init__(self, strategy): self._strategy = strategy + self._python_runfiles_root = _FindPythonRunfilesRoot() def Rlocation(self, path): """Returns the runtime path of a runfile. @@ -161,6 +165,61 @@ def EnvVars(self): """ return self._strategy.EnvVars() + def CurrentRepository(self, frame = 1): + """Returns the canonical name of the caller's Bazel repository. + + For example, this function returns '' (the empty string) when called from + the main repository and a string of the form 'rules_python~0.13.0` when + called from code in the repository corresponding to the rules_python Bazel + module. + + More information about the difference between canonical repository names and + the `@repo` part of labels is available at: + https://bazel.build/build/bzlmod#repository-names + + NOTE: This function inspects the callstack to determine where in the + runfiles the caller is located to determine which repository it came from. + This may fail or produce incorrect results depending on who the caller is, + for example if it is not represented by a Python source file. Use the + `frame` argument to control the stack lookup. + + Args: + frame: int; the stack frame to return the repository name for. Defaults to + 1, the caller of the CurrentRepository function. + Returns: + The canonical name of the Bazel repository containing the file containing + the frame-th caller of this function + Raises: + ValueError: if the caller cannot be determined or the caller's file path + is not contained in the Python runfiles tree + """ + try: + caller_path = inspect.getfile(sys._getframe(frame)) + except (TypeError, ValueError): + raise ValueError("failed to determine caller's file path") + caller_runfiles_path = os.path.relpath(caller_path, self._python_runfiles_root) + if caller_runfiles_path.startswith(".." + os.path.sep): + raise ValueError('{} does not lie under the runfiles root {}'.format(caller_path, self._python_runfiles_root)) + + caller_runfiles_directory = caller_runfiles_path[:caller_runfiles_path.find(os.path.sep)] + if caller_runfiles_directory == MAIN_REPOSITORY_RUNFILES_DIRECTORY: + # The canonical name of the main repository (also known as the workspace) + # is the empty string. + return '' + # For all other repositories, the name of the runfiles directory is the + # canonical name. + return caller_runfiles_directory + +def _FindPythonRunfilesRoot(): + root = __file__ + # Walk up our own runfiles path to the root of the runfiles tree from which + # the current file is being run. This path coincides with what the Bazel + # Python stub sets up as sys.path[0]. Since that entry can be changed at + # runtime, we rederive it here. + for _ in range("bazel_tools/tools/python/runfiles/runfiles.py".count("/") + 1): + root = os.path.dirname(root) + return root + class _ManifestBased(object): """`Runfiles` strategy that parses a runfiles-manifest to look up runfiles.""" diff --git a/tools/python/runfiles/runfiles_test.py b/tools/python/runfiles/runfiles_test.py index 70168cbb6107e5..b35da055a9549f 100644 --- a/tools/python/runfiles/runfiles_test.py +++ b/tools/python/runfiles/runfiles_test.py @@ -262,6 +262,9 @@ def testPathsFromEnvvars(self): self.assertEqual(mf, "") self.assertEqual(dr, "") + def testCurrentRepository(self): + self.assertEqual(runfiles.Create().CurrentRepository(), "") + @staticmethod def IsWindows(): return os.name == "nt"