From f2601eb551b8f1011e91c02cc1797dee873241bf Mon Sep 17 00:00:00 2001 From: aignas <240938+aignas@users.noreply.github.com> Date: Tue, 16 Jul 2024 16:30:49 +0900 Subject: [PATCH] refactor(pypi): split out more utils and start passing 'abi_os_arch' around This is extra preparation needed for #2059. Summary: - Create `pypi_repo_utils` for more ergonomic handling of Python in repo context. - Split the resolution of requirements files to platforms into a separate function to make the testing easier. This also allows more validation that was realized that there is a need for in the WIP feature PR. - Make the code more robust about the assumption of the target platform label. Work towards #260, #1105, #1868. --- python/private/pypi/BUILD.bazel | 22 +- python/private/pypi/extension.bzl | 19 +- python/private/pypi/parse_requirements.bzl | 178 ++---------- python/private/pypi/pip_repository.bzl | 16 +- python/private/pypi/pypi_repo_utils.bzl | 94 ++++++ python/private/pypi/render_pkg_aliases.bzl | 26 +- .../pypi/requirements_files_by_platform.bzl | 258 +++++++++++++++++ python/private/pypi/whl_library.bzl | 80 +----- .../parse_requirements_tests.bzl | 271 ++---------------- .../render_pkg_aliases_test.bzl | 10 +- .../BUILD.bazel | 3 + .../requirements_files_by_platform_tests.bzl | 205 +++++++++++++ 12 files changed, 674 insertions(+), 508 deletions(-) create mode 100644 python/private/pypi/pypi_repo_utils.bzl create mode 100644 python/private/pypi/requirements_files_by_platform.bzl create mode 100644 tests/pypi/requirements_files_by_platform/BUILD.bazel create mode 100644 tests/pypi/requirements_files_by_platform/requirements_files_by_platform_tests.bzl diff --git a/python/private/pypi/BUILD.bazel b/python/private/pypi/BUILD.bazel index 08fb7259ec..00602b298c 100644 --- a/python/private/pypi/BUILD.bazel +++ b/python/private/pypi/BUILD.bazel @@ -161,6 +161,8 @@ bzl_library( deps = [ ":index_sources_bzl", ":parse_requirements_txt_bzl", + ":pypi_repo_utils_bzl", + ":requirements_files_by_platform_bzl", ":whl_target_platforms_bzl", "//python/private:normalize_name_bzl", ], @@ -227,6 +229,15 @@ bzl_library( srcs = ["pip_repository_attrs.bzl"], ) +bzl_library( + name = "pypi_repo_utils_bzl", + srcs = ["pypi_repo_utils.bzl"], + deps = [ + "//python:versions_bzl", + "//python/private:toolchains_repo_bzl", + ], +) + bzl_library( name = "render_pkg_aliases_bzl", srcs = ["render_pkg_aliases.bzl"], @@ -240,6 +251,14 @@ bzl_library( ], ) +bzl_library( + name = "requirements_files_by_platform_bzl", + srcs = ["requirements_files_by_platform.bzl"], + deps = [ + ":whl_target_platforms_bzl", + ], +) + bzl_library( name = "simpleapi_download_bzl", srcs = ["simpleapi_download.bzl"], @@ -270,13 +289,12 @@ bzl_library( ":generate_whl_library_build_bazel_bzl", ":parse_whl_name_bzl", ":patch_whl_bzl", + ":pypi_repo_utils_bzl", ":whl_target_platforms_bzl", "//python:repositories_bzl", - "//python:versions_bzl", "//python/private:auth_bzl", "//python/private:envsubst_bzl", "//python/private:repo_utils_bzl", - "//python/private:toolchains_repo_bzl", ], ) diff --git a/python/private/pypi/extension.bzl b/python/private/pypi/extension.bzl index 6aafc71831..d837d8d50a 100644 --- a/python/private/pypi/extension.bzl +++ b/python/private/pypi/extension.bzl @@ -26,6 +26,7 @@ load(":parse_requirements.bzl", "host_platform", "parse_requirements", "select_r load(":parse_whl_name.bzl", "parse_whl_name") load(":pip_repository_attrs.bzl", "ATTRS") load(":render_pkg_aliases.bzl", "whl_alias") +load(":requirements_files_by_platform.bzl", "requirements_files_by_platform") load(":simpleapi_download.bzl", "simpleapi_download") load(":whl_library.bzl", "whl_library") load(":whl_repo_name.bzl", "whl_repo_name") @@ -183,12 +184,16 @@ def _create_whl_repos(module_ctx, pip_attr, whl_map, whl_overrides, group_map, s requirements_by_platform = parse_requirements( module_ctx, - requirements_by_platform = pip_attr.requirements_by_platform, - requirements_linux = pip_attr.requirements_linux, - requirements_lock = pip_attr.requirements_lock, - requirements_osx = pip_attr.requirements_darwin, - requirements_windows = pip_attr.requirements_windows, - extra_pip_args = pip_attr.extra_pip_args, + requirements_by_platform = requirements_files_by_platform( + requirements_by_platform = pip_attr.requirements_by_platform, + requirements_linux = pip_attr.requirements_linux, + requirements_lock = pip_attr.requirements_lock, + requirements_osx = pip_attr.requirements_darwin, + requirements_windows = pip_attr.requirements_windows, + extra_pip_args = pip_attr.extra_pip_args, + python_version = major_minor, + logger = logger, + ), get_index_urls = get_index_urls, python_version = major_minor, logger = logger, @@ -298,7 +303,7 @@ def _create_whl_repos(module_ctx, pip_attr, whl_map, whl_overrides, group_map, s requirement = select_requirement( requirements, - platform = repository_platform, + platform = None if pip_attr.download_only else repository_platform, ) if not requirement: # Sometimes the package is not present for host platform if there diff --git a/python/private/pypi/parse_requirements.bzl b/python/private/pypi/parse_requirements.bzl index d52180c009..5258153a84 100644 --- a/python/private/pypi/parse_requirements.bzl +++ b/python/private/pypi/parse_requirements.bzl @@ -29,7 +29,7 @@ behavior. load("//python/private:normalize_name.bzl", "normalize_name") load(":index_sources.bzl", "index_sources") load(":parse_requirements_txt.bzl", "parse_requirements_txt") -load(":whl_target_platforms.bzl", "select_whls", "whl_target_platforms") +load(":whl_target_platforms.bzl", "select_whls") # This includes the vendored _translate_cpu and _translate_os from # @platforms//host:extension.bzl at version 0.0.9 so that we don't @@ -80,72 +80,10 @@ DEFAULT_PLATFORMS = [ "windows_x86_64", ] -def _default_platforms(*, filter): - if not filter: - fail("Must specific a filter string, got: {}".format(filter)) - - if filter.startswith("cp3"): - # TODO @aignas 2024-05-23: properly handle python versions in the filter. - # For now we are just dropping it to ensure that we don't fail. - _, _, filter = filter.partition("_") - - sanitized = filter.replace("*", "").replace("_", "") - if sanitized and not sanitized.isalnum(): - fail("The platform filter can only contain '*', '_' and alphanumerics") - - if "*" in filter: - prefix = filter.rstrip("*") - if "*" in prefix: - fail("The filter can only contain '*' at the end of it") - - if not prefix: - return DEFAULT_PLATFORMS - - return [p for p in DEFAULT_PLATFORMS if p.startswith(prefix)] - else: - return [p for p in DEFAULT_PLATFORMS if filter in p] - -def _platforms_from_args(extra_pip_args): - platform_values = [] - - for arg in extra_pip_args: - if platform_values and platform_values[-1] == "": - platform_values[-1] = arg - continue - - if arg == "--platform": - platform_values.append("") - continue - - if not arg.startswith("--platform"): - continue - - _, _, plat = arg.partition("=") - if not plat: - _, _, plat = arg.partition(" ") - if plat: - platform_values.append(plat) - else: - platform_values.append("") - - if not platform_values: - return [] - - platforms = { - p.target_platform: None - for arg in platform_values - for p in whl_target_platforms(arg) - } - return list(platforms.keys()) - def parse_requirements( ctx, *, requirements_by_platform = {}, - requirements_osx = None, - requirements_linux = None, - requirements_lock = None, - requirements_windows = None, extra_pip_args = [], get_index_urls = None, python_version = None, @@ -158,10 +96,6 @@ def parse_requirements( requirements_by_platform (label_keyed_string_dict): a way to have different package versions (or different packages) for different os, arch combinations. - requirements_osx (label): The requirements file for the osx OS. - requirements_linux (label): The requirements file for the linux OS. - requirements_lock (label): The requirements file for all OSes, or used as a fallback. - requirements_windows (label): The requirements file for windows OS. extra_pip_args (string list): Extra pip arguments to perform extra validations and to be joined with args fined in files. get_index_urls: Callable[[ctx, list[str]], dict], a callable to get all @@ -186,91 +120,11 @@ def parse_requirements( The second element is extra_pip_args should be passed to `whl_library`. """ - if not ( - requirements_lock or - requirements_linux or - requirements_osx or - requirements_windows or - requirements_by_platform - ): - fail_fn( - "A 'requirements_lock' attribute must be specified, a platform-specific lockfiles " + - "via 'requirements_by_platform' or an os-specific lockfiles must be specified " + - "via 'requirements_*' attributes", - ) - return None - - platforms = _platforms_from_args(extra_pip_args) - - if platforms: - lock_files = [ - f - for f in [ - requirements_lock, - requirements_linux, - requirements_osx, - requirements_windows, - ] + list(requirements_by_platform.keys()) - if f - ] - - if len(lock_files) > 1: - # If the --platform argument is used, check that we are using - # a single `requirements_lock` file instead of the OS specific ones as that is - # the only correct way to use the API. - fail_fn("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute") - return None - - files_by_platform = [ - (lock_files[0], platforms), - ] - else: - files_by_platform = { - file: [ - platform - for filter_or_platform in specifier.split(",") - for platform in (_default_platforms(filter = filter_or_platform) if filter_or_platform.endswith("*") else [filter_or_platform]) - ] - for file, specifier in requirements_by_platform.items() - }.items() - - for f in [ - # If the users need a greater span of the platforms, they should consider - # using the 'requirements_by_platform' attribute. - (requirements_linux, _default_platforms(filter = "linux_*")), - (requirements_osx, _default_platforms(filter = "osx_*")), - (requirements_windows, _default_platforms(filter = "windows_*")), - (requirements_lock, None), - ]: - if f[0]: - files_by_platform.append(f) - - configured_platforms = {} - options = {} requirements = {} - for file, plats in files_by_platform: - if plats: - for p in plats: - if p in configured_platforms: - fail_fn( - "Expected the platform '{}' to be map only to a single requirements file, but got multiple: '{}', '{}'".format( - p, - configured_platforms[p], - file, - ), - ) - return None - configured_platforms[p] = file - else: - plats = [ - p - for p in DEFAULT_PLATFORMS - if p not in configured_platforms - ] - for p in plats: - configured_platforms[p] = file - + for file, plats in requirements_by_platform.items(): + if logger: + logger.debug(lambda: "Using {} for {}".format(file, plats)) contents = ctx.read(file) # Parse the requirements file directly in starlark to get the information @@ -303,9 +157,9 @@ def parse_requirements( tokenized_options.append(p) pip_args = tokenized_options + extra_pip_args - for p in plats: - requirements[p] = requirements_dict - options[p] = pip_args + for plat in plats: + requirements[plat] = requirements_dict + options[plat] = pip_args requirements_by_platform = {} for target_platform, reqs_ in requirements.items(): @@ -325,7 +179,6 @@ def parse_requirements( requirement_line = requirement_line, target_platforms = [], extra_pip_args = extra_pip_args, - download = len(platforms) > 0, ), ) for_req.target_platforms.append(target_platform) @@ -353,12 +206,12 @@ def parse_requirements( for p in r.target_platforms: requirement_target_platforms[p] = None - is_exposed = len(requirement_target_platforms) == len(configured_platforms) + is_exposed = len(requirement_target_platforms) == len(requirements) if not is_exposed and logger: - logger.debug(lambda: "Package {} will not be exposed because it is only present on a subset of platforms: {} out of {}".format( + logger.debug(lambda: "Package '{}' will not be exposed because it is only present on a subset of platforms: {} out of {}".format( whl_name, sorted(requirement_target_platforms), - sorted(configured_platforms), + sorted(requirements), )) for r in sorted(reqs.values(), key = lambda r: r.requirement_line): @@ -376,13 +229,15 @@ def parse_requirements( requirement_line = r.requirement_line, target_platforms = sorted(r.target_platforms), extra_pip_args = r.extra_pip_args, - download = r.download, whls = whls, sdist = sdist, is_exposed = is_exposed, ), ) + if logger: + logger.debug(lambda: "Will configure whl repos: {}".format(ret.keys())) + return ret def select_requirement(requirements, *, platform): @@ -391,8 +246,9 @@ def select_requirement(requirements, *, platform): Args: requirements (list[struct]): The list of requirements as returned by the `parse_requirements` function above. - platform (str): The host platform. Usually an output of the - `host_platform` function. + platform (str or None): The host platform. Usually an output of the + `host_platform` function. If None, then this function will return + the first requirement it finds. Returns: None if not found or a struct returned as one of the values in the @@ -402,7 +258,7 @@ def select_requirement(requirements, *, platform): maybe_requirement = [ req for req in requirements - if platform in req.target_platforms or req.download + if not platform or [p for p in req.target_platforms if p.endswith(platform)] ] if not maybe_requirement: # Sometimes the package is not present for host platform if there diff --git a/python/private/pypi/pip_repository.bzl b/python/private/pypi/pip_repository.bzl index a22f4d9d2c..42622c3c73 100644 --- a/python/private/pypi/pip_repository.bzl +++ b/python/private/pypi/pip_repository.bzl @@ -21,6 +21,7 @@ load("//python/private:text_util.bzl", "render") load(":parse_requirements.bzl", "host_platform", "parse_requirements", "select_requirement") load(":pip_repository_attrs.bzl", "ATTRS") load(":render_pkg_aliases.bzl", "render_pkg_aliases", "whl_alias") +load(":requirements_files_by_platform.bzl", "requirements_files_by_platform") def _get_python_interpreter_attr(rctx): """A helper function for getting the `python_interpreter` attribute or it's default @@ -71,11 +72,14 @@ exports_files(["requirements.bzl"]) def _pip_repository_impl(rctx): requirements_by_platform = parse_requirements( rctx, - requirements_by_platform = rctx.attr.requirements_by_platform, - requirements_linux = rctx.attr.requirements_linux, - requirements_lock = rctx.attr.requirements_lock, - requirements_osx = rctx.attr.requirements_darwin, - requirements_windows = rctx.attr.requirements_windows, + requirements_by_platform = requirements_files_by_platform( + requirements_by_platform = rctx.attr.requirements_by_platform, + requirements_linux = rctx.attr.requirements_linux, + requirements_lock = rctx.attr.requirements_lock, + requirements_osx = rctx.attr.requirements_darwin, + requirements_windows = rctx.attr.requirements_windows, + extra_pip_args = rctx.attr.extra_pip_args, + ), extra_pip_args = rctx.attr.extra_pip_args, ) selected_requirements = {} @@ -84,7 +88,7 @@ def _pip_repository_impl(rctx): for name, requirements in requirements_by_platform.items(): r = select_requirement( requirements, - platform = repository_platform, + platform = None if rctx.attr.download_only else repository_platform, ) if not r: continue diff --git a/python/private/pypi/pypi_repo_utils.bzl b/python/private/pypi/pypi_repo_utils.bzl new file mode 100644 index 0000000000..6e5d93b160 --- /dev/null +++ b/python/private/pypi/pypi_repo_utils.bzl @@ -0,0 +1,94 @@ +# Copyright 2024 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"" + +load("//python:versions.bzl", "WINDOWS_NAME") +load("//python/private:toolchains_repo.bzl", "get_host_os_arch") + +def _get_python_interpreter_attr(ctx, *, python_interpreter = None): + """A helper function for getting the `python_interpreter` attribute or it's default + + Args: + ctx (repository_ctx): Handle to the rule repository context. + python_interpreter (str): The python interpreter override. + + Returns: + str: The attribute value or it's default + """ + if python_interpreter: + return python_interpreter + + if "win" in ctx.os.name: + return "python.exe" + else: + return "python3" + +def _resolve_python_interpreter(ctx, *, python_interpreter = None, python_interpreter_target = None): + """Helper function to find the python interpreter from the common attributes + + Args: + ctx: Handle to the rule repository context. + python_interpreter: The python interpreter to use. + python_interpreter_target: The python interpreter to use after downloading the label. + + Returns: + `path` object, for the resolved path to the Python interpreter. + """ + python_interpreter = _get_python_interpreter_attr(ctx, python_interpreter = python_interpreter) + + if python_interpreter_target != None: + python_interpreter = ctx.path(python_interpreter_target) + + (os, _) = get_host_os_arch(ctx) + + # On Windows, the symlink doesn't work because Windows attempts to find + # Python DLLs where the symlink is, not where the symlink points. + if os == WINDOWS_NAME: + python_interpreter = python_interpreter.realpath + elif "/" not in python_interpreter: + # It's a plain command, e.g. "python3", to look up in the environment. + found_python_interpreter = ctx.which(python_interpreter) + if not found_python_interpreter: + fail("python interpreter `{}` not found in PATH".format(python_interpreter)) + python_interpreter = found_python_interpreter + else: + python_interpreter = ctx.path(python_interpreter) + return python_interpreter + +def _construct_pypath(rctx, *, entries): + """Helper function to construct a PYTHONPATH. + + Contains entries for code in this repo as well as packages downloaded from //python/pip_install:repositories.bzl. + This allows us to run python code inside repository rule implementations. + + Args: + rctx: Handle to the repository_context. + entries: The list of entries to add to PYTHONPATH. + + Returns: String of the PYTHONPATH. + """ + + separator = ":" if not "windows" in rctx.os.name.lower() else ";" + pypath = separator.join([ + str(rctx.path(entry).dirname) + # Use a dict as a way to remove duplicates and then sort it. + for entry in sorted({x: None for x in entries}) + ]) + return pypath + +pypi_repo_utils = struct( + resolve_python_interpreter = _resolve_python_interpreter, + construct_pythonpath = _construct_pypath, +) diff --git a/python/private/pypi/render_pkg_aliases.bzl b/python/private/pypi/render_pkg_aliases.bzl index eb907fee0f..9e5158f8f0 100644 --- a/python/private/pypi/render_pkg_aliases.bzl +++ b/python/private/pypi/render_pkg_aliases.bzl @@ -265,6 +265,11 @@ def whl_alias(*, repo, version = None, config_setting = None, filename = None, t config_setting = config_setting or ("//_config:is_python_" + version) config_setting = str(config_setting) + if target_platforms: + for p in target_platforms: + if not p.startswith("cp"): + fail("target_platform should start with 'cp' denoting the python version, got: " + p) + return struct( repo = repo, version = version, @@ -448,7 +453,7 @@ def get_whl_flag_versions(aliases): parsed = parse_whl_name(a.filename) else: for plat in a.target_platforms or []: - target_platforms[plat] = None + target_platforms[_non_versioned_platform(plat)] = None continue for platform_tag in parsed.platform_tag.split("."): @@ -486,6 +491,19 @@ def get_whl_flag_versions(aliases): if v } +def _non_versioned_platform(p, *, strict = False): + """A small utility function that converts 'cp311_linux_x86_64' to 'linux_x86_64'. + + This is so that we can tighten the code structure later by using strict = True. + """ + has_abi = p.startswith("cp") + if has_abi: + return p.partition("_")[-1] + elif not strict: + return p + else: + fail("Expected to always have a platform in the form '{{abi}}_{{os}}_{{arch}}', got: {}".format(p)) + def get_filename_config_settings( *, filename, @@ -499,7 +517,7 @@ def get_filename_config_settings( Args: filename: the distribution filename (can be a whl or an sdist). - target_platforms: list[str], target platforms in "{os}_{cpu}" format. + target_platforms: list[str], target platforms in "{abi}_{os}_{cpu}" format. glibc_versions: list[tuple[int, int]], list of versions. muslc_versions: list[tuple[int, int]], list of versions. osx_versions: list[tuple[int, int]], list of versions. @@ -541,7 +559,7 @@ def get_filename_config_settings( if parsed.platform_tag == "any": prefixes = ["{}_{}_any".format(py, abi)] - suffixes = target_platforms + suffixes = [_non_versioned_platform(p) for p in target_platforms or []] else: prefixes = ["{}_{}".format(py, abi)] suffixes = _whl_config_setting_suffixes( @@ -553,7 +571,7 @@ def get_filename_config_settings( ) else: prefixes = ["sdist"] - suffixes = target_platforms + suffixes = [_non_versioned_platform(p) for p in target_platforms or []] if python_default and python_version: prefixes += ["cp{}_{}".format(python_version, p) for p in prefixes] diff --git a/python/private/pypi/requirements_files_by_platform.bzl b/python/private/pypi/requirements_files_by_platform.bzl new file mode 100644 index 0000000000..e3aafc083f --- /dev/null +++ b/python/private/pypi/requirements_files_by_platform.bzl @@ -0,0 +1,258 @@ +# Copyright 2024 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Get the requirement files by platform.""" + +load(":whl_target_platforms.bzl", "whl_target_platforms") + +# TODO @aignas 2024-05-13: consider using the same platform tags as are used in +# the //python:versions.bzl +DEFAULT_PLATFORMS = [ + "linux_aarch64", + "linux_arm", + "linux_ppc", + "linux_s390x", + "linux_x86_64", + "osx_aarch64", + "osx_x86_64", + "windows_x86_64", +] + +def _default_platforms(*, filter): + if not filter: + fail("Must specific a filter string, got: {}".format(filter)) + + if filter.startswith("cp3"): + # TODO @aignas 2024-05-23: properly handle python versions in the filter. + # For now we are just dropping it to ensure that we don't fail. + _, _, filter = filter.partition("_") + + sanitized = filter.replace("*", "").replace("_", "") + if sanitized and not sanitized.isalnum(): + fail("The platform filter can only contain '*', '_' and alphanumerics") + + if "*" in filter: + prefix = filter.rstrip("*") + if "*" in prefix: + fail("The filter can only contain '*' at the end of it") + + if not prefix: + return DEFAULT_PLATFORMS + + return [p for p in DEFAULT_PLATFORMS if p.startswith(prefix)] + else: + return [p for p in DEFAULT_PLATFORMS if filter in p] + +def _platforms_from_args(extra_pip_args): + platform_values = [] + + if not extra_pip_args: + return platform_values + + for arg in extra_pip_args: + if platform_values and platform_values[-1] == "": + platform_values[-1] = arg + continue + + if arg == "--platform": + platform_values.append("") + continue + + if not arg.startswith("--platform"): + continue + + _, _, plat = arg.partition("=") + if not plat: + _, _, plat = arg.partition(" ") + if plat: + platform_values.append(plat) + else: + platform_values.append("") + + if not platform_values: + return [] + + platforms = { + p.target_platform: None + for arg in platform_values + for p in whl_target_platforms(arg) + } + return list(platforms.keys()) + +def _platform(platform_string, python_version = None): + if not python_version or platform_string.startswith("cp3"): + return platform_string + + _, _, tail = python_version.partition(".") + minor, _, _ = tail.partition(".") + + return "cp3{}_{}".format(minor, platform_string) + +def requirements_files_by_platform( + *, + requirements_by_platform = {}, + requirements_osx = None, + requirements_linux = None, + requirements_lock = None, + requirements_windows = None, + extra_pip_args = None, + python_version = None, + logger = None, + fail_fn = fail): + """Resolve the requirement files by target platform. + + Args: + requirements_by_platform (label_keyed_string_dict): a way to have + different package versions (or different packages) for different + os, arch combinations. + requirements_osx (label): The requirements file for the osx OS. + requirements_linux (label): The requirements file for the linux OS. + requirements_lock (label): The requirements file for all OSes, or used as a fallback. + requirements_windows (label): The requirements file for windows OS. + extra_pip_args (string list): Extra pip arguments to perform extra validations and to + be joined with args fined in files. + python_version: str or None. This is needed when the get_index_urls is + specified. It should be of the form "3.x.x", + logger: repo_utils.logger or None, a simple struct to log diagnostic messages. + fail_fn (Callable[[str], None]): A failure function used in testing failure cases. + + Returns: + A dict with keys as the labels to the files and values as lists of + platforms that the files support. + """ + if not ( + requirements_lock or + requirements_linux or + requirements_osx or + requirements_windows or + requirements_by_platform + ): + fail_fn( + "A 'requirements_lock' attribute must be specified, a platform-specific lockfiles " + + "via 'requirements_by_platform' or an os-specific lockfiles must be specified " + + "via 'requirements_*' attributes", + ) + return None + + platforms = _platforms_from_args(extra_pip_args) + if logger: + logger.debug(lambda: "Platforms from pip args: {}".format(platforms)) + + if platforms: + lock_files = [ + f + for f in [ + requirements_lock, + requirements_linux, + requirements_osx, + requirements_windows, + ] + list(requirements_by_platform.keys()) + if f + ] + + if len(lock_files) > 1: + # If the --platform argument is used, check that we are using + # a single `requirements_lock` file instead of the OS specific ones as that is + # the only correct way to use the API. + fail_fn("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute") + return None + + files_by_platform = [ + (lock_files[0], platforms), + ] + if logger: + logger.debug(lambda: "Files by platform with the platform set in the args: {}".format(files_by_platform)) + else: + files_by_platform = { + file: [ + platform + for filter_or_platform in specifier.split(",") + for platform in (_default_platforms(filter = filter_or_platform) if filter_or_platform.endswith("*") else [filter_or_platform]) + ] + for file, specifier in requirements_by_platform.items() + }.items() + + if logger: + logger.debug(lambda: "Files by platform with the platform set in the attrs: {}".format(files_by_platform)) + + for f in [ + # If the users need a greater span of the platforms, they should consider + # using the 'requirements_by_platform' attribute. + (requirements_linux, _default_platforms(filter = "linux_*")), + (requirements_osx, _default_platforms(filter = "osx_*")), + (requirements_windows, _default_platforms(filter = "windows_*")), + (requirements_lock, None), + ]: + if f[0]: + if logger: + logger.debug(lambda: "Adding an extra item to files_by_platform: {}".format(f)) + files_by_platform.append(f) + + configured_platforms = {} + requirements = {} + for file, plats in files_by_platform: + if plats: + plats = [_platform(p, python_version) for p in plats] + for p in plats: + if p in configured_platforms: + fail_fn( + "Expected the platform '{}' to be map only to a single requirements file, but got multiple: '{}', '{}'".format( + p, + configured_platforms[p], + file, + ), + ) + return None + + configured_platforms[p] = file + else: + default_platforms = [_platform(p, python_version) for p in DEFAULT_PLATFORMS] + plats = [ + p + for p in default_platforms + if p not in configured_platforms + ] + if logger: + logger.debug(lambda: "File {} will be used for the remaining platforms {} that are not in configured_platforms: {}".format( + file, + plats, + default_platforms, + )) + for p in plats: + configured_platforms[p] = file + + if logger: + logger.debug(lambda: "Configured platforms for file {} are {}".format(file, plats)) + + for p in plats: + if p in requirements: + # This should never happen because in the code above we should + # have unambiguous selection of the requirements files. + fail_fn("Attempting to override a requirements file '{}' with '{}' for platform '{}'".format( + requirements[p], + file, + p, + )) + return None + requirements[p] = file + + # Now return a dict that is similar to requirements_by_platform - where we + # have labels/files as keys in the dict to minimize the number of times we + # may parse the same file. + + ret = {} + for plat, file in requirements.items(): + ret.setdefault(file, []).append(plat) + + return ret diff --git a/python/private/pypi/whl_library.bzl b/python/private/pypi/whl_library.bzl index a3fa1d8e36..f453f92ccf 100644 --- a/python/private/pypi/whl_library.bzl +++ b/python/private/pypi/whl_library.bzl @@ -15,88 +15,21 @@ "" load("//python:repositories.bzl", "is_standalone_interpreter") -load("//python:versions.bzl", "WINDOWS_NAME") load("//python/private:auth.bzl", "AUTH_ATTRS", "get_auth") load("//python/private:envsubst.bzl", "envsubst") load("//python/private:repo_utils.bzl", "REPO_DEBUG_ENV_VAR", "repo_utils") -load("//python/private:toolchains_repo.bzl", "get_host_os_arch") load(":attrs.bzl", "ATTRS", "use_isolated") load(":deps.bzl", "all_repo_names") load(":generate_whl_library_build_bazel.bzl", "generate_whl_library_build_bazel") load(":parse_whl_name.bzl", "parse_whl_name") load(":patch_whl.bzl", "patch_whl") +load(":pypi_repo_utils.bzl", "pypi_repo_utils") load(":whl_target_platforms.bzl", "whl_target_platforms") _CPPFLAGS = "CPPFLAGS" _COMMAND_LINE_TOOLS_PATH_SLUG = "commandlinetools" _WHEEL_ENTRY_POINT_PREFIX = "rules_python_wheel_entry_point" -def _construct_pypath(rctx): - """Helper function to construct a PYTHONPATH. - - Contains entries for code in this repo as well as packages downloaded from //python/pip_install:repositories.bzl. - This allows us to run python code inside repository rule implementations. - - Args: - rctx: Handle to the repository_context. - - Returns: String of the PYTHONPATH. - """ - - separator = ":" if not "windows" in rctx.os.name.lower() else ";" - pypath = separator.join([ - str(rctx.path(entry).dirname) - for entry in rctx.attr._python_path_entries - ]) - return pypath - -def _get_python_interpreter_attr(rctx): - """A helper function for getting the `python_interpreter` attribute or it's default - - Args: - rctx (repository_ctx): Handle to the rule repository context. - - Returns: - str: The attribute value or it's default - """ - if rctx.attr.python_interpreter: - return rctx.attr.python_interpreter - - if "win" in rctx.os.name: - return "python.exe" - else: - return "python3" - -def _resolve_python_interpreter(rctx): - """Helper function to find the python interpreter from the common attributes - - Args: - rctx: Handle to the rule repository context. - - Returns: - `path` object, for the resolved path to the Python interpreter. - """ - python_interpreter = _get_python_interpreter_attr(rctx) - - if rctx.attr.python_interpreter_target != None: - python_interpreter = rctx.path(rctx.attr.python_interpreter_target) - - (os, _) = get_host_os_arch(rctx) - - # On Windows, the symlink doesn't work because Windows attempts to find - # Python DLLs where the symlink is, not where the symlink points. - if os == WINDOWS_NAME: - python_interpreter = python_interpreter.realpath - elif "/" not in python_interpreter: - # It's a plain command, e.g. "python3", to look up in the environment. - found_python_interpreter = rctx.which(python_interpreter) - if not found_python_interpreter: - fail("python interpreter `{}` not found in PATH".format(python_interpreter)) - python_interpreter = found_python_interpreter - else: - python_interpreter = rctx.path(python_interpreter) - return python_interpreter - def _get_xcode_location_cflags(rctx): """Query the xcode sdk location to update cflags @@ -230,14 +163,21 @@ def _create_repository_execution_environment(rctx, python_interpreter): cppflags.extend(_get_toolchain_unix_cflags(rctx, python_interpreter)) env = { - "PYTHONPATH": _construct_pypath(rctx), + "PYTHONPATH": pypi_repo_utils.construct_pythonpath( + rctx, + entries = rctx.attr._python_path_entries, + ), _CPPFLAGS: " ".join(cppflags), } return env def _whl_library_impl(rctx): - python_interpreter = _resolve_python_interpreter(rctx) + python_interpreter = pypi_repo_utils.resolve_python_interpreter( + rctx, + python_interpreter = rctx.attr.python_interpreter, + python_interpreter_target = rctx.attr.python_interpreter_target, + ) args = [ python_interpreter, "-m", diff --git a/tests/pypi/parse_requirements/parse_requirements_tests.bzl b/tests/pypi/parse_requirements/parse_requirements_tests.bzl index 5c33dd83b2..1a7143b747 100644 --- a/tests/pypi/parse_requirements/parse_requirements_tests.bzl +++ b/tests/pypi/parse_requirements/parse_requirements_tests.bzl @@ -52,33 +52,17 @@ bar==0.0.1 --hash=sha256:deadb00f _tests = [] -def _test_fail_no_requirements(env): - errors = [] - parse_requirements( - ctx = _mock_ctx(), - fail_fn = errors.append, - ) - env.expect.that_str(errors[0]).equals("""\ -A 'requirements_lock' attribute must be specified, a platform-specific lockfiles via 'requirements_by_platform' or an os-specific lockfiles must be specified via 'requirements_*' attributes""") - -_tests.append(_test_fail_no_requirements) - def _test_simple(env): got = parse_requirements( - ctx = _mock_ctx(), - requirements_lock = "requirements_lock", - ) - got_alternative = parse_requirements( ctx = _mock_ctx(), requirements_by_platform = { - "requirements_lock": "*", + "requirements_lock": ["linux_x86_64", "windows_x86_64"], }, ) env.expect.that_dict(got).contains_exactly({ "foo": [ struct( distribution = "foo", - download = False, extra_pip_args = [], requirement_line = "foo[extra]==0.0.1 --hash=sha256:deadbeef", srcs = struct( @@ -87,13 +71,7 @@ def _test_simple(env): version = "0.0.1", ), target_platforms = [ - "linux_aarch64", - "linux_arm", - "linux_ppc", - "linux_s390x", "linux_x86_64", - "osx_aarch64", - "osx_x86_64", "windows_x86_64", ], whls = [], @@ -102,68 +80,26 @@ def _test_simple(env): ), ], }) - env.expect.that_dict(got).contains_exactly(got_alternative) env.expect.that_str( select_requirement( got["foo"], - platform = "linux_ppc", + platform = "linux_x86_64", ).srcs.version, ).equals("0.0.1") _tests.append(_test_simple) -def _test_platform_markers_with_python_version(env): +def _test_dupe_requirements(env): got = parse_requirements( ctx = _mock_ctx(), requirements_by_platform = { - "requirements_lock": "cp39_linux_*", - }, - ) - got_alternative = parse_requirements( - ctx = _mock_ctx(), - requirements_by_platform = { - "requirements_lock": "linux_*", + "requirements_lock_dupe": ["linux_x86_64"], }, ) env.expect.that_dict(got).contains_exactly({ "foo": [ struct( distribution = "foo", - download = False, - extra_pip_args = [], - requirement_line = "foo[extra]==0.0.1 --hash=sha256:deadbeef", - srcs = struct( - requirement = "foo[extra]==0.0.1", - shas = ["deadbeef"], - version = "0.0.1", - ), - target_platforms = [ - "linux_aarch64", - "linux_arm", - "linux_ppc", - "linux_s390x", - "linux_x86_64", - ], - whls = [], - sdist = None, - is_exposed = True, - ), - ], - }) - env.expect.that_dict(got).contains_exactly(got_alternative) - -_tests.append(_test_platform_markers_with_python_version) - -def _test_dupe_requirements(env): - got = parse_requirements( - ctx = _mock_ctx(), - requirements_lock = "requirements_lock_dupe", - ) - env.expect.that_dict(got).contains_exactly({ - "foo": [ - struct( - distribution = "foo", - download = False, extra_pip_args = [], requirement_line = "foo[extra,extra_2]==0.0.1 --hash=sha256:deadbeef", srcs = struct( @@ -171,16 +107,7 @@ def _test_dupe_requirements(env): shas = ["deadbeef"], version = "0.0.1", ), - target_platforms = [ - "linux_aarch64", - "linux_arm", - "linux_ppc", - "linux_s390x", - "linux_x86_64", - "osx_aarch64", - "osx_x86_64", - "windows_x86_64", - ], + target_platforms = ["linux_x86_64"], whls = [], sdist = None, is_exposed = True, @@ -192,19 +119,10 @@ _tests.append(_test_dupe_requirements) def _test_multi_os(env): got = parse_requirements( - ctx = _mock_ctx(), - requirements_linux = "requirements_linux", - requirements_osx = "requirements_osx", - requirements_windows = "requirements_windows", - ) - - # This is an alternative way to express the same intent - got_alternative = parse_requirements( ctx = _mock_ctx(), requirements_by_platform = { - "requirements_linux": "linux_*", - "requirements_osx": "osx_*", - "requirements_windows": "windows_*", + "requirements_linux": ["linux_x86_64"], + "requirements_windows": ["windows_x86_64"], }, ) @@ -212,7 +130,6 @@ def _test_multi_os(env): "bar": [ struct( distribution = "bar", - download = False, extra_pip_args = [], requirement_line = "bar==0.0.1 --hash=sha256:deadb00f", srcs = struct( @@ -229,7 +146,6 @@ def _test_multi_os(env): "foo": [ struct( distribution = "foo", - download = False, extra_pip_args = [], requirement_line = "foo==0.0.3 --hash=sha256:deadbaaf", srcs = struct( @@ -237,22 +153,13 @@ def _test_multi_os(env): shas = ["deadbaaf"], version = "0.0.3", ), - target_platforms = [ - "linux_aarch64", - "linux_arm", - "linux_ppc", - "linux_s390x", - "linux_x86_64", - "osx_aarch64", - "osx_x86_64", - ], + target_platforms = ["linux_x86_64"], whls = [], sdist = None, is_exposed = True, ), struct( distribution = "foo", - download = False, extra_pip_args = [], requirement_line = "foo[extra]==0.0.2 --hash=sha256:deadbeef", srcs = struct( @@ -267,7 +174,6 @@ def _test_multi_os(env): ), ], }) - env.expect.that_dict(got).contains_exactly(got_alternative) env.expect.that_str( select_requirement( got["foo"], @@ -277,168 +183,27 @@ def _test_multi_os(env): _tests.append(_test_multi_os) -def _test_fail_duplicate_platforms(env): - errors = [] - parse_requirements( - ctx = _mock_ctx(), - requirements_by_platform = { - "requirements_linux": "linux_x86_64", - "requirements_lock": "*", - }, - fail_fn = errors.append, - ) - env.expect.that_collection(errors).has_size(1) - env.expect.that_str(",".join(errors)).equals("Expected the platform 'linux_x86_64' to be map only to a single requirements file, but got multiple: 'requirements_linux', 'requirements_lock'") - -_tests.append(_test_fail_duplicate_platforms) - -def _test_multi_os_download_only_platform(env): - got = parse_requirements( - ctx = _mock_ctx(), - requirements_lock = "requirements_linux", - extra_pip_args = [ - "--platform", - "manylinux_2_27_x86_64", - "--platform=manylinux_2_12_x86_64", - "--platform manylinux_2_5_x86_64", - ], - ) - env.expect.that_dict(got).contains_exactly({ - "foo": [ +def _test_select_requirement_none_platform(env): + got = select_requirement( + [ struct( - distribution = "foo", - download = True, - extra_pip_args = [ - "--platform", - "manylinux_2_27_x86_64", - "--platform=manylinux_2_12_x86_64", - "--platform manylinux_2_5_x86_64", - ], - requirement_line = "foo==0.0.3 --hash=sha256:deadbaaf", - srcs = struct( - requirement = "foo==0.0.3", - shas = ["deadbaaf"], - version = "0.0.3", - ), + some_attr = "foo", target_platforms = ["linux_x86_64"], - whls = [], - sdist = None, - is_exposed = True, ), ], - }) - env.expect.that_str( - select_requirement( - got["foo"], - platform = "windows_x86_64", - ).srcs.version, - ).equals("0.0.3") - -_tests.append(_test_multi_os_download_only_platform) - -def _test_fail_download_only_bad_attr(env): - errors = [] - parse_requirements( - ctx = _mock_ctx(), - requirements_linux = "requirements_linux", - requirements_osx = "requirements_osx", - extra_pip_args = [ - "--platform", - "manylinux_2_27_x86_64", - "--platform=manylinux_2_12_x86_64", - "--platform manylinux_2_5_x86_64", - ], - fail_fn = errors.append, + platform = None, ) - env.expect.that_str(errors[0]).equals("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute") + env.expect.that_str(got.some_attr).equals("foo") -_tests.append(_test_fail_download_only_bad_attr) - -def _test_os_arch_requirements_with_default(env): - got = parse_requirements( - ctx = _mock_ctx(), - requirements_by_platform = { - "requirements_direct": "linux_super_exotic", - "requirements_linux": "linux_x86_64,linux_aarch64", - }, - requirements_lock = "requirements_lock", - ) - env.expect.that_dict(got).contains_exactly({ - "foo": [ - struct( - distribution = "foo", - download = False, - extra_pip_args = [], - requirement_line = "foo==0.0.3 --hash=sha256:deadbaaf", - srcs = struct( - requirement = "foo==0.0.3", - shas = ["deadbaaf"], - version = "0.0.3", - ), - target_platforms = ["linux_aarch64", "linux_x86_64"], - whls = [], - sdist = None, - is_exposed = True, - ), - struct( - distribution = "foo", - download = False, - extra_pip_args = [], - requirement_line = "foo[extra] @ https://some-url", - srcs = struct( - requirement = "foo[extra] @ https://some-url", - shas = [], - version = "", - ), - target_platforms = ["linux_super_exotic"], - whls = [], - sdist = None, - is_exposed = True, - ), - struct( - distribution = "foo", - download = False, - extra_pip_args = [], - requirement_line = "foo[extra]==0.0.1 --hash=sha256:deadbeef", - srcs = struct( - requirement = "foo[extra]==0.0.1", - shas = ["deadbeef"], - version = "0.0.1", - ), - target_platforms = [ - "linux_arm", - "linux_ppc", - "linux_s390x", - "osx_aarch64", - "osx_x86_64", - "windows_x86_64", - ], - whls = [], - sdist = None, - is_exposed = True, - ), - ], - }) - env.expect.that_str( - select_requirement( - got["foo"], - platform = "windows_x86_64", - ).srcs.version, - ).equals("0.0.1") - env.expect.that_str( - select_requirement( - got["foo"], - platform = "linux_x86_64", - ).srcs.version, - ).equals("0.0.3") - -_tests.append(_test_os_arch_requirements_with_default) +_tests.append(_test_select_requirement_none_platform) def _test_fail_no_python_version(env): errors = [] parse_requirements( ctx = _mock_ctx(), - requirements_lock = "requirements_lock", + requirements_by_platform = { + "requirements_lock": [""], + }, get_index_urls = lambda _, __: {}, fail_fn = errors.append, ) diff --git a/tests/pypi/render_pkg_aliases/render_pkg_aliases_test.bzl b/tests/pypi/render_pkg_aliases/render_pkg_aliases_test.bzl index 0d4c75e3c2..09a06311fc 100644 --- a/tests/pypi/render_pkg_aliases/render_pkg_aliases_test.bzl +++ b/tests/pypi/render_pkg_aliases/render_pkg_aliases_test.bzl @@ -517,7 +517,7 @@ def _test_get_python_versions_from_filenames(env): _tests.append(_test_get_python_versions_from_filenames) -def _test_target_platforms_from_alias_target_platforms(env): +def _test_get_flag_versions_from_alias_target_platforms(env): got = get_whl_flag_versions( aliases = [ whl_alias( @@ -534,7 +534,7 @@ def _test_target_platforms_from_alias_target_platforms(env): version = "3.3", filename = "foo-0.0.0-py3-none-any.whl", target_platforms = [ - "linux_x86_64", + "cp33_linux_x86_64", ], ), ], @@ -548,7 +548,7 @@ def _test_target_platforms_from_alias_target_platforms(env): } env.expect.that_dict(got).contains_exactly(want) -_tests.append(_test_target_platforms_from_alias_target_platforms) +_tests.append(_test_get_flag_versions_from_alias_target_platforms) def _test_config_settings( env, @@ -820,8 +820,8 @@ def _test_multiplatform_whl_aliases_filename(env): filename = "foo-0.0.2-py3-none-any.whl", version = "3.1", target_platforms = [ - "linux_x86_64", - "linux_aarch64", + "cp31_linux_x86_64", + "cp31_linux_aarch64", ], ), ] diff --git a/tests/pypi/requirements_files_by_platform/BUILD.bazel b/tests/pypi/requirements_files_by_platform/BUILD.bazel new file mode 100644 index 0000000000..d78d459f59 --- /dev/null +++ b/tests/pypi/requirements_files_by_platform/BUILD.bazel @@ -0,0 +1,3 @@ +load(":requirements_files_by_platform_tests.bzl", "requirements_files_by_platform_test_suite") + +requirements_files_by_platform_test_suite(name = "requirements_files_by_platform_tests") diff --git a/tests/pypi/requirements_files_by_platform/requirements_files_by_platform_tests.bzl b/tests/pypi/requirements_files_by_platform/requirements_files_by_platform_tests.bzl new file mode 100644 index 0000000000..b729b0eaf0 --- /dev/null +++ b/tests/pypi/requirements_files_by_platform/requirements_files_by_platform_tests.bzl @@ -0,0 +1,205 @@ +# Copyright 2024 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"" + +load("@rules_testing//lib:test_suite.bzl", "test_suite") +load("//python/private/pypi:requirements_files_by_platform.bzl", "requirements_files_by_platform") # buildifier: disable=bzl-visibility + +_tests = [] + +def _test_fail_no_requirements(env): + errors = [] + requirements_files_by_platform( + fail_fn = errors.append, + ) + env.expect.that_str(errors[0]).equals("""\ +A 'requirements_lock' attribute must be specified, a platform-specific lockfiles via 'requirements_by_platform' or an os-specific lockfiles must be specified via 'requirements_*' attributes""") + +_tests.append(_test_fail_no_requirements) + +def _test_fail_duplicate_platforms(env): + errors = [] + requirements_files_by_platform( + requirements_by_platform = { + "requirements_linux": "linux_x86_64", + "requirements_lock": "*", + }, + fail_fn = errors.append, + ) + env.expect.that_collection(errors).has_size(1) + env.expect.that_str(",".join(errors)).equals("Expected the platform 'linux_x86_64' to be map only to a single requirements file, but got multiple: 'requirements_linux', 'requirements_lock'") + +_tests.append(_test_fail_duplicate_platforms) + +def _test_fail_download_only_bad_attr(env): + errors = [] + requirements_files_by_platform( + requirements_linux = "requirements_linux", + requirements_osx = "requirements_osx", + extra_pip_args = [ + "--platform", + "manylinux_2_27_x86_64", + "--platform=manylinux_2_12_x86_64", + "--platform manylinux_2_5_x86_64", + ], + fail_fn = errors.append, + ) + env.expect.that_str(errors[0]).equals("only a single 'requirements_lock' file can be used when using '--platform' pip argument, consider specifying it via 'requirements_lock' attribute") + +_tests.append(_test_fail_download_only_bad_attr) + +def _test_simple(env): + for got in [ + requirements_files_by_platform( + requirements_lock = "requirements_lock", + ), + requirements_files_by_platform( + requirements_by_platform = { + "requirements_lock": "*", + }, + ), + ]: + env.expect.that_dict(got).contains_exactly({ + "requirements_lock": [ + "linux_aarch64", + "linux_arm", + "linux_ppc", + "linux_s390x", + "linux_x86_64", + "osx_aarch64", + "osx_x86_64", + "windows_x86_64", + ], + }) + +_tests.append(_test_simple) + +def _test_simple_with_python_version(env): + for got in [ + requirements_files_by_platform( + requirements_lock = "requirements_lock", + python_version = "3.11", + ), + requirements_files_by_platform( + requirements_by_platform = { + "requirements_lock": "*", + }, + python_version = "3.11", + ), + # TODO @aignas 2024-07-15: consider supporting this way of specifying + # the requirements without the need of the `python_version` attribute + # setting. However, this might need more tweaks, hence only leaving a + # comment in the test. + # requirements_files_by_platform( + # requirements_by_platform = { + # "requirements_lock": "cp311_*", + # }, + # ), + ]: + env.expect.that_dict(got).contains_exactly({ + "requirements_lock": [ + "cp311_linux_aarch64", + "cp311_linux_arm", + "cp311_linux_ppc", + "cp311_linux_s390x", + "cp311_linux_x86_64", + "cp311_osx_aarch64", + "cp311_osx_x86_64", + "cp311_windows_x86_64", + ], + }) + +_tests.append(_test_simple_with_python_version) + +def _test_multi_os(env): + for got in [ + requirements_files_by_platform( + requirements_linux = "requirements_linux", + requirements_osx = "requirements_osx", + requirements_windows = "requirements_windows", + ), + requirements_files_by_platform( + requirements_by_platform = { + "requirements_linux": "linux_*", + "requirements_osx": "osx_*", + "requirements_windows": "windows_*", + }, + ), + ]: + env.expect.that_dict(got).contains_exactly({ + "requirements_linux": [ + "linux_aarch64", + "linux_arm", + "linux_ppc", + "linux_s390x", + "linux_x86_64", + ], + "requirements_osx": [ + "osx_aarch64", + "osx_x86_64", + ], + "requirements_windows": [ + "windows_x86_64", + ], + }) + +_tests.append(_test_multi_os) + +def _test_multi_os_download_only_platform(env): + got = requirements_files_by_platform( + requirements_lock = "requirements_linux", + extra_pip_args = [ + "--platform", + "manylinux_2_27_x86_64", + "--platform=manylinux_2_12_x86_64", + "--platform manylinux_2_5_x86_64", + ], + ) + env.expect.that_dict(got).contains_exactly({ + "requirements_linux": ["linux_x86_64"], + }) + +_tests.append(_test_multi_os_download_only_platform) + +def _test_os_arch_requirements_with_default(env): + got = requirements_files_by_platform( + requirements_by_platform = { + "requirements_exotic": "linux_super_exotic", + "requirements_linux": "linux_x86_64,linux_aarch64", + }, + requirements_lock = "requirements_lock", + ) + env.expect.that_dict(got).contains_exactly({ + "requirements_exotic": ["linux_super_exotic"], + "requirements_linux": ["linux_x86_64", "linux_aarch64"], + "requirements_lock": [ + "linux_arm", + "linux_ppc", + "linux_s390x", + "osx_aarch64", + "osx_x86_64", + "windows_x86_64", + ], + }) + +_tests.append(_test_os_arch_requirements_with_default) + +def requirements_files_by_platform_test_suite(name): + """Create the test suite. + + Args: + name: the name of the test suite + """ + test_suite(name = name, basic_tests = _tests)