Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(bzlmod): add python.toolchain unit tests #2204

Merged
merged 9 commits into from
Sep 11, 2024
89 changes: 58 additions & 31 deletions python/private/python.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,21 @@ load(":util.bzl", "IS_BAZEL_6_4_OR_HIGHER")
_MAX_NUM_TOOLCHAINS = 9999
_TOOLCHAIN_INDEX_PAD_LENGTH = len(str(_MAX_NUM_TOOLCHAINS))

def _python_register_toolchains(name, toolchain_attr, module, ignore_root_user_error):
"""Calls python_register_toolchains and returns a struct used to collect the toolchains.
def parse_modules(module_ctx):
"""Parse the modules and return a struct for registrations.

Args:
module_ctx: {type}`module_ctx` module context.

Returns:
A struct with the following attributes:
* `toolchains`: The list of toolchains to register. The last
element is special and is treated as the default toolchain.
* `defaults`: The default `kwargs` passed to
{bzl:obj}`python_register_toolchains`.
* `debug_info`: {type}`None | dict` extra information to be passed
to the debug repo.
"""
python_register_toolchains(
name = name,
python_version = toolchain_attr.python_version,
register_coverage_tool = toolchain_attr.configure_coverage_tool,
ignore_root_user_error = ignore_root_user_error,
)
return struct(
python_version = toolchain_attr.python_version,
name = name,
module = struct(name = module.name, is_root = module.is_root),
)

def _python_impl(module_ctx):
if module_ctx.os.environ.get("RULES_PYTHON_BZLMOD_DEBUG", "0") == "1":
debug_info = {
"toolchains_registered": [],
Expand All @@ -61,7 +60,7 @@ def _python_impl(module_ctx):
# This is a toolchain_info struct.
default_toolchain = None

# Map of string Major.Minor to the toolchain_info struct
# Map of version string to the toolchain_info struct
global_toolchain_versions = {}

ignore_root_user_error = None
Expand Down Expand Up @@ -139,11 +138,11 @@ def _python_impl(module_ctx):
)
toolchain_info = None
else:
toolchain_info = _python_register_toolchains(
toolchain_name,
toolchain_attr,
module = mod,
ignore_root_user_error = ignore_root_user_error,
toolchain_info = struct(
python_version = toolchain_attr.python_version,
name = toolchain_name,
register_coverage_tool = toolchain_attr.configure_coverage_tool,
module = struct(name = mod.name, is_root = mod.is_root),
)
global_toolchain_versions[toolchain_version] = toolchain_info
if debug_info:
Expand Down Expand Up @@ -184,39 +183,67 @@ def _python_impl(module_ctx):
if len(toolchains) > _MAX_NUM_TOOLCHAINS:
fail("more than {} python versions are not supported".format(_MAX_NUM_TOOLCHAINS))

return struct(
toolchains = [
struct(
python_version = t.python_version,
name = t.name,
register_coverage_tool = t.register_coverage_tool,
)
for t in toolchains
],
debug_info = debug_info,
default_python_version = toolchains[-1].python_version,
defaults = {
"ignore_root_user_error": ignore_root_user_error,
},
)

def _python_impl(module_ctx):
py = parse_modules(module_ctx)

for toolchain_info in py.toolchains:
python_register_toolchains(
name = toolchain_info.name,
python_version = toolchain_info.python_version,
register_coverage_tool = toolchain_info.register_coverage_tool,
**py.defaults
)

# Create the pythons_hub repo for the interpreter meta data and the
# the various toolchains.
hub_repo(
name = "pythons_hub",
default_python_version = default_toolchain.python_version,
# Last toolchain is default
default_python_version = py.default_python_version,
toolchain_prefixes = [
render.toolchain_prefix(index, toolchain.name, _TOOLCHAIN_INDEX_PAD_LENGTH)
for index, toolchain in enumerate(toolchains)
for index, toolchain in enumerate(py.toolchains)
],
toolchain_python_versions = [t.python_version for t in toolchains],
toolchain_python_versions = [t.python_version for t in py.toolchains],
# The last toolchain is the default; it can't have version constraints
# Despite the implication of the arg name, the values are strs, not bools
toolchain_set_python_version_constraints = [
"True" if i != len(toolchains) - 1 else "False"
for i in range(len(toolchains))
"True" if i != len(py.toolchains) - 1 else "False"
for i in range(len(py.toolchains))
],
toolchain_user_repository_names = [t.name for t in toolchains],
toolchain_user_repository_names = [t.name for t in py.toolchains],
)

# This is require in order to support multiple version py_test
# and py_binary
multi_toolchain_aliases(
name = "python_versions",
python_versions = {
version: toolchain.name
for version, toolchain in global_toolchain_versions.items()
toolchain.python_version: toolchain.name
for toolchain in py.toolchains
},
)

if debug_info != None:
if py.debug_info != None:
_debug_repo(
name = "rules_python_bzlmod_debug",
debug_info = json.encode_indent(debug_info),
debug_info = json.encode_indent(py.debug_info),
)

if bazel_features.external_deps.extension_metadata_has_reproducible:
Expand Down
3 changes: 3 additions & 0 deletions tests/python/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
load(":python_tests.bzl", "python_test_suite")
aignas marked this conversation as resolved.
Show resolved Hide resolved

python_test_suite(name = "python_tests")
253 changes: 253 additions & 0 deletions tests/python/python_tests.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""

load("@rules_testing//lib:test_suite.bzl", "test_suite")
load("//python/private:python.bzl", _parse_modules = "parse_modules") # buildifier: disable=bzl-visibility

_tests = []

def parse_modules(*, mctx, **kwargs):
return _parse_modules(module_ctx = mctx, **kwargs)

def _mock_mctx(*modules, environ = {}):
return struct(
os = struct(environ = environ),
modules = [
struct(
name = modules[0].name,
tags = modules[0].tags,
is_root = modules[0].is_root,
),
] + [
struct(
name = mod.name,
tags = mod.tags,
is_root = False,
)
for mod in modules[1:]
],
)

def _mod(*, name, toolchain = [], rules_python_private_testing = [], is_root = True):
return struct(
name = name,
tags = struct(
toolchain = toolchain,
rules_python_private_testing = rules_python_private_testing,
),
is_root = is_root,
)

def _toolchain(python_version, *, is_default = False, **kwargs):
return struct(
is_default = is_default,
python_version = python_version,
**kwargs
)

def _test_default(env):
py = parse_modules(
mctx = _mock_mctx(
_mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
),
)

env.expect.that_collection(py.defaults.keys()).contains_exactly([
"ignore_root_user_error",
])
env.expect.that_bool(py.defaults["ignore_root_user_error"]).equals(False)
env.expect.that_str(py.default_python_version).equals("3.11")

want_toolchain = struct(
name = "python_3_11",
python_version = "3.11",
register_coverage_tool = False,
)
env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain])

_tests.append(_test_default)

def _test_default_some_module(env):
py = parse_modules(
mctx = _mock_mctx(
_mod(name = "rules_python", toolchain = [_toolchain("3.11")], is_root = False),
),
)

env.expect.that_collection(py.defaults.keys()).contains_exactly([
"ignore_root_user_error",
])
env.expect.that_str(py.default_python_version).equals("3.11")

want_toolchain = struct(
name = "python_3_11",
python_version = "3.11",
register_coverage_tool = False,
)
env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain])

_tests.append(_test_default_some_module)

def _test_default_with_patch_version(env):
py = parse_modules(
mctx = _mock_mctx(
_mod(name = "rules_python", toolchain = [_toolchain("3.11.2")]),
),
)

env.expect.that_str(py.default_python_version).equals("3.11.2")

want_toolchain = struct(
name = "python_3_11_2",
python_version = "3.11.2",
register_coverage_tool = False,
)
env.expect.that_collection(py.toolchains).contains_exactly([want_toolchain])

_tests.append(_test_default_with_patch_version)

def _test_default_non_rules_python(env):
py = parse_modules(
mctx = _mock_mctx(
# NOTE @aignas 2024-09-06: the first item in the module_ctx.modules
# could be a non-root module, which is the case if the root module
# does not make any calls to the extension.
_mod(name = "rules_python", toolchain = [_toolchain("3.11")], is_root = False),
),
)

env.expect.that_str(py.default_python_version).equals("3.11")
rules_python_toolchain = struct(
name = "python_3_11",
python_version = "3.11",
register_coverage_tool = False,
)
env.expect.that_collection(py.toolchains).contains_exactly([rules_python_toolchain])

_tests.append(_test_default_non_rules_python)

def _test_default_non_rules_python_ignore_root_user_error(env):
py = parse_modules(
mctx = _mock_mctx(
_mod(
name = "my_module",
toolchain = [_toolchain("3.12", ignore_root_user_error = True)],
),
_mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
),
)

env.expect.that_bool(py.defaults["ignore_root_user_error"]).equals(True)
env.expect.that_str(py.default_python_version).equals("3.12")

my_module_toolchain = struct(
name = "python_3_12",
python_version = "3.12",
register_coverage_tool = False,
)
rules_python_toolchain = struct(
name = "python_3_11",
python_version = "3.11",
register_coverage_tool = False,
)
env.expect.that_collection(py.toolchains).contains_exactly([
rules_python_toolchain,
my_module_toolchain,
]).in_order()

_tests.append(_test_default_non_rules_python_ignore_root_user_error)

def _test_default_non_rules_python_ignore_root_user_error_non_root_module(env):
py = parse_modules(
mctx = _mock_mctx(
_mod(name = "my_module", toolchain = [_toolchain("3.13")]),
_mod(name = "some_module", toolchain = [_toolchain("3.12", ignore_root_user_error = True)]),
_mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
),
)

env.expect.that_str(py.default_python_version).equals("3.13")
env.expect.that_bool(py.defaults["ignore_root_user_error"]).equals(False)

my_module_toolchain = struct(
name = "python_3_13",
python_version = "3.13",
register_coverage_tool = False,
)
some_module_toolchain = struct(
name = "python_3_12",
python_version = "3.12",
register_coverage_tool = False,
)
rules_python_toolchain = struct(
name = "python_3_11",
python_version = "3.11",
register_coverage_tool = False,
)
env.expect.that_collection(py.toolchains).contains_exactly([
some_module_toolchain,
rules_python_toolchain,
my_module_toolchain,
aignas marked this conversation as resolved.
Show resolved Hide resolved
]).in_order()

_tests.append(_test_default_non_rules_python_ignore_root_user_error_non_root_module)

def _test_first_occurance_of_the_toolchain_wins(env):
py = parse_modules(
mctx = _mock_mctx(
_mod(name = "my_module", toolchain = [_toolchain("3.12")]),
_mod(name = "some_module", toolchain = [_toolchain("3.12", configure_coverage_tool = True)]),
_mod(name = "rules_python", toolchain = [_toolchain("3.11")]),
environ = {
"RULES_PYTHON_BZLMOD_DEBUG": "1",
},
),
)

env.expect.that_str(py.default_python_version).equals("3.12")

my_module_toolchain = struct(
name = "python_3_12",
python_version = "3.12",
# NOTE: coverage stays disabled even though `some_module` was
# configuring something else.
register_coverage_tool = False,
)
rules_python_toolchain = struct(
name = "python_3_11",
python_version = "3.11",
register_coverage_tool = False,
)
env.expect.that_collection(py.toolchains).contains_exactly([
rules_python_toolchain,
my_module_toolchain, # default toolchain is last
]).in_order()
env.expect.that_dict(py.debug_info).contains_exactly({
"toolchains_registered": [
{"ignore_root_user_error": False, "name": "python_3_12"},
{"ignore_root_user_error": False, "name": "python_3_11"},
],
})

_tests.append(_test_first_occurance_of_the_toolchain_wins)

def python_test_suite(name):
"""Create the test suite.

Args:
name: the name of the test suite
"""
test_suite(name = name, basic_tests = _tests)