diff --git a/python/private/python.bzl b/python/private/python.bzl index 6a265d1395..e1d13b9f1c 100644 --- a/python/private/python.bzl +++ b/python/private/python.bzl @@ -28,22 +28,21 @@ load(":util.bzl", "IS_BAZEL_6_4_OR_HIGHER") _MAX_NUM_TOOLCHAINS = 9999 _TOOLCHAIN_INDEX_PAD_LENGTH = len(str(_MAX_NUM_TOOLCHAINS)) -def _python_register_toolchains(name, toolchain_attr, module, ignore_root_user_error): - """Calls python_register_toolchains and returns a struct used to collect the toolchains. +def parse_modules(module_ctx): + """Parse the modules and return a struct for registrations. + + Args: + module_ctx: {type}`module_ctx` module context. + + Returns: + A struct with the following attributes: + * `toolchains`: The list of toolchains to register. The last + element is special and is treated as the default toolchain. + * `defaults`: The default `kwargs` passed to + {bzl:obj}`python_register_toolchains`. + * `debug_info`: {type}`None | dict` extra information to be passed + to the debug repo. """ - python_register_toolchains( - name = name, - python_version = toolchain_attr.python_version, - register_coverage_tool = toolchain_attr.configure_coverage_tool, - ignore_root_user_error = ignore_root_user_error, - ) - return struct( - python_version = toolchain_attr.python_version, - name = name, - module = struct(name = module.name, is_root = module.is_root), - ) - -def _python_impl(module_ctx): if module_ctx.os.environ.get("RULES_PYTHON_BZLMOD_DEBUG", "0") == "1": debug_info = { "toolchains_registered": [], @@ -61,7 +60,7 @@ def _python_impl(module_ctx): # This is a toolchain_info struct. default_toolchain = None - # Map of string Major.Minor to the toolchain_info struct + # Map of version string to the toolchain_info struct global_toolchain_versions = {} ignore_root_user_error = None @@ -139,11 +138,11 @@ def _python_impl(module_ctx): ) toolchain_info = None else: - toolchain_info = _python_register_toolchains( - toolchain_name, - toolchain_attr, - module = mod, - ignore_root_user_error = ignore_root_user_error, + toolchain_info = struct( + python_version = toolchain_attr.python_version, + name = toolchain_name, + register_coverage_tool = toolchain_attr.configure_coverage_tool, + module = struct(name = mod.name, is_root = mod.is_root), ) global_toolchain_versions[toolchain_version] = toolchain_info if debug_info: @@ -184,23 +183,51 @@ def _python_impl(module_ctx): if len(toolchains) > _MAX_NUM_TOOLCHAINS: fail("more than {} python versions are not supported".format(_MAX_NUM_TOOLCHAINS)) + return struct( + toolchains = [ + struct( + python_version = t.python_version, + name = t.name, + register_coverage_tool = t.register_coverage_tool, + ) + for t in toolchains + ], + debug_info = debug_info, + default_python_version = toolchains[-1].python_version, + defaults = { + "ignore_root_user_error": ignore_root_user_error, + }, + ) + +def _python_impl(module_ctx): + py = parse_modules(module_ctx) + + for toolchain_info in py.toolchains: + python_register_toolchains( + name = toolchain_info.name, + python_version = toolchain_info.python_version, + register_coverage_tool = toolchain_info.register_coverage_tool, + **py.defaults + ) + # Create the pythons_hub repo for the interpreter meta data and the # the various toolchains. hub_repo( name = "pythons_hub", - default_python_version = default_toolchain.python_version, + # Last toolchain is default + default_python_version = py.default_python_version, toolchain_prefixes = [ render.toolchain_prefix(index, toolchain.name, _TOOLCHAIN_INDEX_PAD_LENGTH) - for index, toolchain in enumerate(toolchains) + for index, toolchain in enumerate(py.toolchains) ], - toolchain_python_versions = [t.python_version for t in toolchains], + toolchain_python_versions = [t.python_version for t in py.toolchains], # The last toolchain is the default; it can't have version constraints # Despite the implication of the arg name, the values are strs, not bools toolchain_set_python_version_constraints = [ - "True" if i != len(toolchains) - 1 else "False" - for i in range(len(toolchains)) + "True" if i != len(py.toolchains) - 1 else "False" + for i in range(len(py.toolchains)) ], - toolchain_user_repository_names = [t.name for t in toolchains], + toolchain_user_repository_names = [t.name for t in py.toolchains], ) # This is require in order to support multiple version py_test @@ -208,15 +235,15 @@ def _python_impl(module_ctx): multi_toolchain_aliases( name = "python_versions", python_versions = { - version: toolchain.name - for version, toolchain in global_toolchain_versions.items() + toolchain.python_version: toolchain.name + for toolchain in py.toolchains }, ) - if debug_info != None: + if py.debug_info != None: _debug_repo( name = "rules_python_bzlmod_debug", - debug_info = json.encode_indent(debug_info), + debug_info = json.encode_indent(py.debug_info), ) if bazel_features.external_deps.extension_metadata_has_reproducible: diff --git a/tests/python/BUILD.bazel b/tests/python/BUILD.bazel new file mode 100644 index 0000000000..2553536b63 --- /dev/null +++ b/tests/python/BUILD.bazel @@ -0,0 +1,17 @@ +# Copyright 2024 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load(":python_tests.bzl", "python_test_suite") + +python_test_suite(name = "python_tests") diff --git a/tests/python/python_tests.bzl b/tests/python/python_tests.bzl new file mode 100644 index 0000000000..acbd6676dc --- /dev/null +++ b/tests/python/python_tests.bzl @@ -0,0 +1,253 @@ +# Copyright 2024 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"" + +load("@rules_testing//lib:test_suite.bzl", "test_suite") +load("//python/private:python.bzl", _parse_modules = "parse_modules") # buildifier: disable=bzl-visibility + +_tests = [] + +def parse_modules(*, mctx, **kwargs): + return _parse_modules(module_ctx = mctx, **kwargs) + +def _mock_mctx(*modules, environ = {}): + return struct( + os = struct(environ = environ), + modules = [ + struct( + name = modules[0].name, + tags = modules[0].tags, + is_root = modules[0].is_root, + ), + ] + [ + struct( + name = mod.name, + tags = mod.tags, + is_root = False, + ) + for mod in modules[1:] + ], + ) + +def _mod(*, name, toolchain = [], rules_python_private_testing = [], is_root = True): + return struct( + name = name, + tags = struct( + toolchain = toolchain, + rules_python_private_testing = rules_python_private_testing, + ), + is_root = is_root, + ) + +def _toolchain(python_version, *, is_default = False, **kwargs): + return struct( + is_default = is_default, + python_version = python_version, + **kwargs + ) + +def _test_default(env): + py = parse_modules( + mctx = _mock_mctx( + _mod(name = "rules_python", toolchain = [_toolchain("3.11")]), + ), + ) + + env.expect.that_collection(py.defaults.keys()).contains_exactly([ + "ignore_root_user_error", + ]) + env.expect.that_bool(py.defaults["ignore_root_user_error"]).equals(False) + env.expect.that_str(py.default_python_version).equals("3.11") + + want_toolchain = struct( + name = "python_3_11", + python_version = "3.11", + register_coverage_tool = False, + ) + env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain]) + +_tests.append(_test_default) + +def _test_default_some_module(env): + py = parse_modules( + mctx = _mock_mctx( + _mod(name = "rules_python", toolchain = [_toolchain("3.11")], is_root = False), + ), + ) + + env.expect.that_collection(py.defaults.keys()).contains_exactly([ + "ignore_root_user_error", + ]) + env.expect.that_str(py.default_python_version).equals("3.11") + + want_toolchain = struct( + name = "python_3_11", + python_version = "3.11", + register_coverage_tool = False, + ) + env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain]) + +_tests.append(_test_default_some_module) + +def _test_default_with_patch_version(env): + py = parse_modules( + mctx = _mock_mctx( + _mod(name = "rules_python", toolchain = [_toolchain("3.11.2")]), + ), + ) + + env.expect.that_str(py.default_python_version).equals("3.11.2") + + want_toolchain = struct( + name = "python_3_11_2", + python_version = "3.11.2", + register_coverage_tool = False, + ) + env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain]) + +_tests.append(_test_default_with_patch_version) + +def _test_default_non_rules_python(env): + py = parse_modules( + mctx = _mock_mctx( + # NOTE @aignas 2024-09-06: the first item in the module_ctx.modules + # could be a non-root module, which is the case if the root module + # does not make any calls to the extension. + _mod(name = "rules_python", toolchain = [_toolchain("3.11")], is_root = False), + ), + ) + + env.expect.that_str(py.default_python_version).equals("3.11") + rules_python_toolchain = struct( + name = "python_3_11", + python_version = "3.11", + register_coverage_tool = False, + ) + env.expect.that_collection(py.toolchains).contains_exactly([rules_python_toolchain]) + +_tests.append(_test_default_non_rules_python) + +def _test_default_non_rules_python_ignore_root_user_error(env): + py = parse_modules( + mctx = _mock_mctx( + _mod( + name = "my_module", + toolchain = [_toolchain("3.12", ignore_root_user_error = True)], + ), + _mod(name = "rules_python", toolchain = [_toolchain("3.11")]), + ), + ) + + env.expect.that_bool(py.defaults["ignore_root_user_error"]).equals(True) + env.expect.that_str(py.default_python_version).equals("3.12") + + my_module_toolchain = struct( + name = "python_3_12", + python_version = "3.12", + register_coverage_tool = False, + ) + rules_python_toolchain = struct( + name = "python_3_11", + python_version = "3.11", + register_coverage_tool = False, + ) + env.expect.that_collection(py.toolchains).contains_exactly([ + rules_python_toolchain, + my_module_toolchain, + ]).in_order() + +_tests.append(_test_default_non_rules_python_ignore_root_user_error) + +def _test_default_non_rules_python_ignore_root_user_error_non_root_module(env): + py = parse_modules( + mctx = _mock_mctx( + _mod(name = "my_module", toolchain = [_toolchain("3.13")]), + _mod(name = "some_module", toolchain = [_toolchain("3.12", ignore_root_user_error = True)]), + _mod(name = "rules_python", toolchain = [_toolchain("3.11")]), + ), + ) + + env.expect.that_str(py.default_python_version).equals("3.13") + env.expect.that_bool(py.defaults["ignore_root_user_error"]).equals(False) + + my_module_toolchain = struct( + name = "python_3_13", + python_version = "3.13", + register_coverage_tool = False, + ) + some_module_toolchain = struct( + name = "python_3_12", + python_version = "3.12", + register_coverage_tool = False, + ) + rules_python_toolchain = struct( + name = "python_3_11", + python_version = "3.11", + register_coverage_tool = False, + ) + env.expect.that_collection(py.toolchains).contains_exactly([ + some_module_toolchain, + rules_python_toolchain, + my_module_toolchain, # this was the only toolchain, default to that + ]).in_order() + +_tests.append(_test_default_non_rules_python_ignore_root_user_error_non_root_module) + +def _test_first_occurance_of_the_toolchain_wins(env): + py = parse_modules( + mctx = _mock_mctx( + _mod(name = "my_module", toolchain = [_toolchain("3.12")]), + _mod(name = "some_module", toolchain = [_toolchain("3.12", configure_coverage_tool = True)]), + _mod(name = "rules_python", toolchain = [_toolchain("3.11")]), + environ = { + "RULES_PYTHON_BZLMOD_DEBUG": "1", + }, + ), + ) + + env.expect.that_str(py.default_python_version).equals("3.12") + + my_module_toolchain = struct( + name = "python_3_12", + python_version = "3.12", + # NOTE: coverage stays disabled even though `some_module` was + # configuring something else. + register_coverage_tool = False, + ) + rules_python_toolchain = struct( + name = "python_3_11", + python_version = "3.11", + register_coverage_tool = False, + ) + env.expect.that_collection(py.toolchains).contains_exactly([ + rules_python_toolchain, + my_module_toolchain, # default toolchain is last + ]).in_order() + env.expect.that_dict(py.debug_info).contains_exactly({ + "toolchains_registered": [ + {"ignore_root_user_error": False, "name": "python_3_12"}, + {"ignore_root_user_error": False, "name": "python_3_11"}, + ], + }) + +_tests.append(_test_first_occurance_of_the_toolchain_wins) + +def python_test_suite(name): + """Create the test suite. + + Args: + name: the name of the test suite + """ + test_suite(name = name, basic_tests = _tests)