diff --git a/ci_cd/tasks/update_deps.py b/ci_cd/tasks/update_deps.py index a7152e08..0337bd33 100644 --- a/ci_cd/tasks/update_deps.py +++ b/ci_cd/tasks/update_deps.py @@ -9,15 +9,15 @@ import operator import re import sys -from collections import namedtuple from pathlib import Path from typing import TYPE_CHECKING import tomlkit from invoke import task +from pip._vendor.packaging.requirements import Requirement from ci_cd.exceptions import CICDException, InputError, InputParserError -from ci_cd.utils import Emoji, SemanticVersion, update_file +from ci_cd.utils import Emoji, SemanticVersion, SemanticVersionRange, update_file if TYPE_CHECKING: # pragma: no cover from typing import Literal @@ -74,19 +74,6 @@ def update_deps( # pylint: disable=too-many-branches,too-many-locals,too-many-s if not ignore: ignore: list[str] = [] # type: ignore[no-redef] - VersionSpec = namedtuple( - "VersionSpec", - [ - "full_dependency", - "package", - "url_version", - "operator", - "version", - "extra_operator_version", - "environment_marker", - ], - ) - try: ignore_rules = parse_ignore_entries(ignore, ignore_separator) except InputError as exc: @@ -122,67 +109,46 @@ def update_deps( # pylint: disable=too-many-branches,too-many-locals,too-many-s already_handled_packages = set() updated_packages = {} - dependencies = pyproject.get("project", {}).get("dependencies", []) + dependencies: list[str] = pyproject.get("project", {}).get("dependencies", []) for optional_deps in ( pyproject.get("project", {}).get("optional-dependencies", {}).values() ): dependencies.extend(optional_deps) error = False - for line in dependencies: - match = re.match( - r"^(?P(?P[a-zA-Z0-9_.-]+)(?:\s*\[.*\])?)\s*" - r"(?:" - r"(?P@\s*\S+)|" - r"(?P>|<|<=|>=|==|!=|~=)\s*" - r"(?P[0-9]+(?:\.[0-9]+){0,2})" - r")?\s*" - r"(?P(?:,(?:>|<|<=|>=|==|!=|~=)\s*" - r"[0-9]+(?:\.[0-9]+){0,2}\s*)+)*" - r"(?P;.+)*$", - line, - ) - if match is None: - msg = ( - f"Could not parse package and version specification for line:\n {line}" - ) - LOGGER.warning(msg) - if fail_fast: - sys.exit(f"{Emoji.CROSS_MARK.value} {msg}") - print(msg) - error = True - continue - - version_spec = VersionSpec(**match.groupdict()) - LOGGER.debug("version_spec: %s", version_spec) + for dependency in dependencies: + parsed_requirement = Requirement(dependency) + LOGGER.debug("parsed_requirement: %r", parsed_requirement) # Skip package if already handled - if version_spec.package in already_handled_packages: + if parsed_requirement.name in already_handled_packages: continue # Skip URL versioned dependencies - if version_spec.url_version: + if parsed_requirement.url: msg = ( - f"Dependency {version_spec.full_dependency!r} is pinned to a URL and " + f"Dependency {parsed_requirement.name!r} is pinned to a URL and " "will be skipped." ) LOGGER.info(msg) - print(msg) + print(msg, flush=True) + already_handled_packages.add(parsed_requirement.name) continue # Skip and warn if package is not version-restricted - if not version_spec.operator and not version_spec.url_version: + if not parsed_requirement.specifier: msg = ( - f"Dependency {version_spec.full_dependency!r} is not version " + f"Dependency {parsed_requirement.name!r} is not version " "restricted and will be skipped. Consider adding version restrictions." ) LOGGER.warning(msg) - print(msg) + print(msg, flush=True) + already_handled_packages.add(parsed_requirement.name) continue # Check version from PyPI's online package index out: "Result" = context.run( - f"pip index versions --python-version {py_version} {version_spec.package}", + f"pip index versions --python-version {py_version} {parsed_requirement.name}", hide=True, ) package_latest_version_line = out.stdout.split(sep="\n", maxsplit=1)[0] @@ -198,37 +164,64 @@ def update_deps( # pylint: disable=too-many-branches,too-many-locals,too-many-s LOGGER.warning(msg) if fail_fast: sys.exit(f"{Emoji.CROSS_MARK.value} {msg}") - print(msg) - already_handled_packages.add(version_spec.package) + print(msg, flush=True) + already_handled_packages.add(parsed_requirement.name) error = True continue # Sanity check - if version_spec.package != match.group("package"): + if parsed_requirement.name != match.group("package"): msg = ( - f"Package name parsed from pyproject.toml ({version_spec.package!r}) " - "does not match the name returned from 'pip index versions': " - f"{match.group('package')!r}" + "Package name parsed from pyproject.toml " + f"({parsed_requirement.name!r}) does not match the name returned from " + f"'pip index versions': {match.group('package')!r}" ) LOGGER.warning(msg) if fail_fast: sys.exit(f"{Emoji.CROSS_MARK.value} {msg}") - print(msg) - already_handled_packages.add(version_spec.package) + print(msg, flush=True) + already_handled_packages.add(parsed_requirement.name) error = True continue # Check whether pyproject.toml already uses the latest version - latest_version = match.group("version").split(".") - for index, version_part in enumerate(version_spec.version.split(".")): - if version_part != latest_version[index]: - break - else: - already_handled_packages.add(version_spec.package) + try: + current_version_range = SemanticVersionRange(parsed_requirement) + except ValueError as exc: + msg = ( + f"Could not parse version range specifier set for package " + f"{parsed_requirement.name!r} from pyproject.toml: {exc}" + ) + LOGGER.warning(msg) + if fail_fast: + sys.exit(f"{Emoji.CROSS_MARK.value} {msg}") + print(msg, flush=True) + already_handled_packages.add(parsed_requirement.name) + error = True continue + latest_version = match.group("version") + if latest_version in current_version_range: + if ">=" in current_version_range.operators: + # Not updating version, since minimum version is already satisfied + already_handled_packages.add(parsed_requirement.name) + continue + if "~=" in current_version_range.operators: + # Maybe update version, since the ~= specifier is used + # NOTE: This may be wrong if there are multiple ~= specifiers, + # expecting only one, though + version = current_version_range.version_from_operator("~=") + for index, version_part in enumerate(version.split(".")): + if version_part != latest_version.split(".")[index]: + break + else: + # Not updating version, since the minimum version range matches the + # latest version + already_handled_packages.add(parsed_requirement.name) + continue + # Apply ignore rules - if version_spec.package in ignore_rules or "*" in ignore_rules: + if parsed_requirement.name in ignore_rules or "*" in ignore_rules: versions: "list[dict[Literal['operator', 'version'], str]]" = [] update_types: "dict[Literal['version-update'], list[Literal['major', 'minor', 'patch']]]" = ( # pylint: disable=line-too-long {} @@ -237,8 +230,8 @@ def update_deps( # pylint: disable=too-many-branches,too-many-locals,too-many-s if "*" in ignore_rules: versions, update_types = parse_ignore_rules(ignore_rules["*"]) - if version_spec.package in ignore_rules: - parsed_rules = parse_ignore_rules(ignore_rules[version_spec.package]) + if parsed_requirement.name in ignore_rules: + parsed_rules = parse_ignore_rules(ignore_rules[parsed_requirement.name]) versions.extend(parsed_rules[0]) update_types.update(parsed_rules[1]) @@ -248,37 +241,184 @@ def update_deps( # pylint: disable=too-many-branches,too-many-locals,too-many-s ) if ignore_version( - current=version_spec.version.split("."), - latest=latest_version, + # Cast to str() to ensure a split on a padded version + current=str(current_version_range.lower).split("."), + latest=latest_version.split("."), version_rules=versions, semver_rules=update_types, ): - already_handled_packages.add(version_spec.package) + already_handled_packages.add(parsed_requirement.name) continue + # Update specifiers + updated_version_range = None + if latest_version > current_version_range: + spec_operator = "<=" + if spec_operator in current_version_range.operators: + # Update to include latest version + n_version_parts = current_version_range.version_from_operator( + spec_operator + ).number_of_original_core_version_parts + updated_specifier_version = ".".join( + latest_version.split(".")[:n_version_parts] + ) + updated_specifier = f"{spec_operator}{updated_specifier_version}" + updated_specifier_set = [ + str(_) for _ in current_version_range if _.operator != spec_operator + ] + updated_specifier_set.append(updated_specifier) + updated_version_range = SemanticVersionRange( + ",".join(updated_specifier_set) + ) + else: + spec_operator = "~=" + if spec_operator in current_version_range.operators: + # Expand and change ~= to >= and < operators if the latest version + # changes major version. + # Otherwise, update to include latest version as the minimum version + current_version = current_version_range.version_from_operator( + spec_operator + ) + parsed_latest_version = SemanticVersion(latest_version) + + if parsed_latest_version.major > current_version.major: + # Expand and change ~= to >= and < operators + # >= current_version + specifier_set_updates = [f">={current_version}"] + # < next major version up from latest_version + specifier_set_updates.append( + f"<{str(parsed_latest_version.next_version('major').major)}" + ) + else: + # Keep the ~= operator, but update to include the latest + # version as the minimum version + updated_specifier_version = ".".join( + latest_version.split(".")[ + : current_version.number_of_original_core_version_parts + ] + ) + specifier_set_updates = [ + f"{spec_operator}{updated_specifier_version}" + ] + + updated_specifier_set = [ + str(_) + for _ in current_version_range + if _.operator != spec_operator + ] + updated_specifier_set.extend(specifier_set_updates) + updated_version_range = SemanticVersionRange( + ",".join(updated_specifier_set) + ) + else: + spec_operator = "<" + if spec_operator in current_version_range.operators: + # Update to include latest version by upping to the next + # version up from the latest version + n_version_parts = current_version_range.version_from_operator( + spec_operator + ).number_of_original_core_version_parts + parsed_latest_version = SemanticVersion(latest_version) + if n_version_parts == 1: + updated_specifier_version = str( + parsed_latest_version.next_version("major").major + ) + elif n_version_parts == 2: + updated_specifier_version = ".".join( + parsed_latest_version.next_version("minor").split(".")[ + :2 + ] + ) + elif n_version_parts == 3: + updated_specifier_version = ( + parsed_latest_version.next_version("patch") + ) + else: + raise CICDException( + f"Invalid number of version parts: {n_version_parts}" + ) + updated_specifier = ( + f"{spec_operator}{updated_specifier_version}" + ) + updated_specifier_set = [ + str(_) + for _ in current_version_range + if _.operator != spec_operator + ] + updated_specifier_set.append(updated_specifier) + updated_version_range = SemanticVersionRange( + ",".join(updated_specifier_set) + ) + elif latest_version in current_version_range: + for spec_operator in ["~=", "=="]: + if spec_operator in current_version_range.operators: + n_version_parts = current_version_range.version_from_operator( + spec_operator + ).number_of_original_core_version_parts + updated_version = ".".join( + latest_version.split(".")[:n_version_parts] + ) + updated_specifier = f"{spec_operator}{updated_version}" + updated_specifier_set = [ + str(_) + for _ in current_version_range + if _.operator != spec_operator + ] + updated_specifier_set.append(updated_specifier) + updated_version_range = SemanticVersionRange( + ",".join(updated_specifier_set) + ) + break + else: + msg = ( + "Could not determine how to update to the latest version using the " + f"version range specifier set: {current_version_range}. " + f"Package: {parsed_requirement.name}. Latest version: {latest_version}" + ) + LOGGER.warning(msg) + if fail_fast: + sys.exit(f"{Emoji.CROSS_MARK.value} {msg}") + print(msg, flush=True) + already_handled_packages.add(parsed_requirement.name) + error = True + continue + + LOGGER.debug("updated_version_range: %s", updated_version_range) + if updated_version_range is None: + msg = ( + "Could not determine how to update to the latest version using the " + f"version range specifier set: {current_version_range}. " + f"Package: {parsed_requirement.name}. Latest version: {latest_version}" + ) + LOGGER.warning(msg) + if fail_fast: + sys.exit(f"{Emoji.CROSS_MARK.value} {msg}") + print(msg, flush=True) + already_handled_packages.add(parsed_requirement.name) + error = True + continue + if not error: # Update pyproject.toml - updated_version = ".".join( - latest_version[: len(version_spec.version.split("."))] - ) - escaped_full_dependency_name = version_spec.full_dependency.replace( - "[", r"\[" - ).replace("]", r"\]") - update_file( - pyproject_path, - ( - rf'"{escaped_full_dependency_name} {version_spec.operator}.*"', - f'"{version_spec.full_dependency} ' - f"{version_spec.operator}{updated_version}" - f'{version_spec.extra_operator_version if version_spec.extra_operator_version else ""}' # pylint: disable=line-too-long - f'{version_spec.environment_marker if version_spec.environment_marker else ""}"', # pylint: disable=line-too-long - ), + updated_dependency = parsed_requirement.name + if parsed_requirement.extras: + formatted_extras = ",".join(sorted(parsed_requirement.extras)) + updated_dependency += f"[{formatted_extras}]" + match = re.search( + rf"{parsed_requirement.name}(?:\[.*\])?(?P\s)+", dependency ) - already_handled_packages.add(version_spec.package) - updated_packages[version_spec.full_dependency] = ( - f"{version_spec.operator}{updated_version}" - f"{version_spec.extra_operator_version if version_spec.extra_operator_version else ''}" # pylint: disable=line-too-long - f"{' ' + version_spec.environment_marker if version_spec.environment_marker else ''}" # pylint: disable=line-too-long + if match: + updated_dependency += match.group("space") + updated_dependency += ",".join( + str(_) for _ in sorted(updated_version_range) + ) # Specifier set + if parsed_requirement.marker: + updated_dependency += f"; {parsed_requirement.marker}" + + update_file(pyproject_path, (re.escape(dependency), updated_dependency)) + already_handled_packages.add(parsed_requirement.name) + updated_packages[parsed_requirement.name] = str(updated_version_range) + ( + f"; {parsed_requirement.marker}" if parsed_requirement.marker else "" ) if error: @@ -294,10 +434,11 @@ def update_deps( # pylint: disable=too-many-branches,too-many-locals,too-many-s f" {package} ({version})" for package, version in updated_packages.items() ) - + "\n" + + "\n", + flush=True, ) else: - print(f"{Emoji.CHECK_MARK.value} No dependency updates available.") + print(f"{Emoji.CHECK_MARK.value} No dependency updates available.", flush=True) def parse_ignore_entries( @@ -535,9 +676,9 @@ def ignore_version( Parameters: current: The current version as a list of version parts. It's expected, but not - required, the version is a semantic version. + required, to be a semantic version. latest: The latest version as a list of version parts. It's expected, but not - required, the version is a semantic version. + required, to be a semantic version. version_rules: Version ignore rules. semver_rules: Semantic version ignore rules. diff --git a/ci_cd/utils.py b/ci_cd/utils.py index b88d5982..e8117014 100644 --- a/ci_cd/utils.py +++ b/ci_cd/utils.py @@ -8,8 +8,20 @@ from pathlib import Path from typing import TYPE_CHECKING, no_type_check +from pip._vendor.packaging.requirements import InvalidRequirement, Requirement +from pip._vendor.packaging.specifiers import ( + InvalidSpecifier, + Specifier, + SpecifierSet, + _IndividualSpecifier, +) + if TYPE_CHECKING: # pragma: no cover - from typing import Any, Optional, Tuple, Union + from typing import Any, Iterator, Optional, Union + + from pip._vendor.packaging.specifiers import LegacySpecifier + + ParsedSpecifier = Union[Specifier, LegacySpecifier] LOGGER = logging.getLogger(__file__) @@ -124,6 +136,10 @@ def __init__( self._pre_release = pre_release if pre_release else None self._build = build if build else None + self._original_version = self._build_version( + major, minor, patch, pre_release, build + ) + @classmethod def _build_version( cls, @@ -161,6 +177,21 @@ def _build_version( version += f"+{build}" return version + @property + def original_version(self) -> str: + """The original version string used to create the instance.""" + return self._original_version + + @property + def number_of_original_core_version_parts(self) -> int: + """The original semantic version parts used to create the instance, + ignoring pre-relase and build.""" + match = re.match(self._REGEX, self.original_version) + if match is None: + raise AssertionError + major, minor, patch, _, _ = match.groups() + return len([part for part in (major, minor, patch) if part]) + @property def major(self) -> int: """The major version.""" @@ -204,19 +235,20 @@ def __repr__(self) -> str: """Return the string representation of the object.""" return repr(self.__str__()) - def _validate_other_type(self, other: "Any") -> "SemanticVersion": + @classmethod + def _validate_other_type(cls, other: "Any") -> "SemanticVersion": """Initial check/validation of `other` before rich comparisons.""" not_implemented_exc = NotImplementedError( - f"Rich comparison not implemented between {self.__class__.__name__} and " + f"Rich comparison not implemented between {cls.__name__} and " f"{type(other)}" ) - if isinstance(other, self.__class__): + if isinstance(other, cls): return other if isinstance(other, str): try: - return self.__class__(other) + return cls(other) except (TypeError, ValueError) as exc: raise not_implemented_exc from exc @@ -224,6 +256,9 @@ def _validate_other_type(self, other: "Any") -> "SemanticVersion": def __lt__(self, other: "Any") -> bool: """Less than (`<`) rich comparison.""" + if isinstance(other, SemanticVersionRange): + return self < other.lower + other_semver = self._validate_other_type(other) if self.major < other_semver.major: @@ -248,6 +283,9 @@ def __le__(self, other: "Any") -> bool: def __eq__(self, other: "Any") -> bool: """Equal to (`==`) rich comparison.""" + if isinstance(other, SemanticVersionRange): + return self in other + other_semver = self._validate_other_type(other) return ( @@ -269,7 +307,7 @@ def __gt__(self, other: "Any") -> bool: """Greater than (`>`) rich comparison.""" return not self.__le__(other) - def next_version(self, version_part: str) -> "SemanticVersion": + def next_version(self, version_part: "Optional[str]" = None) -> "SemanticVersion": """Return the next version for the specified version part. Parameters: @@ -282,22 +320,439 @@ def next_version(self, version_part: str) -> "SemanticVersion": ValueError: If the version part is not one of `major`, `minor`, or `patch`. """ - if version_part not in ("major", "minor", "patch"): + # Deduce "original_version" for new instance, disregard pre-release and build + match = re.match(self._REGEX, self.original_version) + if match is None: + raise AssertionError + _, minor, patch, _, _ = match.groups() + + if not version_part: + # Determine what version part to increment based on the original version + version_part = "patch" if patch else "minor" if minor else "major" + + if version_part == "major": + incremented_version = str(self.major + 1) + if minor: + incremented_version += ".0" + if patch: + incremented_version += ".0" + elif version_part == "minor": + incremented_version = f"{self.major}.{self.minor + 1}" + if patch: + incremented_version += ".0" + elif version_part == "patch": + incremented_version = f"{self.major}.{self.minor}.{self.patch + 1}" + else: raise ValueError( "version_part must be one of 'major', 'minor', or 'patch', not " f"{version_part!r}" ) - if version_part == "major": - return self.__class__(f"{self.major + 1}.0.0") - if version_part == "minor": - return self.__class__(f"{self.major}.{self.minor + 1}.0") + return self.__class__(incremented_version) + + +class SortableSpecifier(Specifier): + """A sortable specifier.""" + + _sorted_operators = ["===", "==", "~=", ">=", ">", "<", "<=", "!="] + + def __lt__(self, other: "Any") -> bool: + """Less than (`<`) rich comparison.""" + if isinstance(other, _IndividualSpecifier): + return self._sorted_operators.index( + self.operator + ) < self._sorted_operators.index(other.operator) + if isinstance(other, str): + try: + other = self.__class__(other) + except InvalidSpecifier as exc: + raise NotImplementedError from exc + return self._sorted_operators.index( + self.operator + ) < self._sorted_operators.index(other.operator) + + raise NotImplementedError + + +class SemanticVersionRange: + """A range of semantic versions. + + The implementation relies on the pip package `packaging` for parsing the version + requirements and checking if a version is in the range. + """ + + _arbritrary_upper_limit = "9" * 3 + + def __init__(self, specifier: "Union[SpecifierSet, Requirement, str]") -> None: + if isinstance(specifier, str): + try: + specifier = Requirement(specifier).specifier + except InvalidRequirement: + try: + specifier = SpecifierSet(specifier) # type: ignore[arg-type] + except InvalidSpecifier as exc: + raise ValueError( + f"specifier ({specifier}) cannot be parsed as a requirement " + "or specifier set" + ) from exc + elif isinstance(specifier, Requirement): + specifier = specifier.specifier + if not isinstance(specifier, SpecifierSet): + raise TypeError( + f"specifier must be of type {SpecifierSet.__name__}, " + f"{Requirement.__name__}, or str, not {type(specifier).__name__}" + ) + + self._specifier = self._sanitize_specifier_set(specifier) + self._lower = self._determine_lower() + self._upper = self._determine_upper() + + def _sanitize_specifier_set(self, specifier_set: "SpecifierSet") -> "SpecifierSet": + """Sanitize the specifier set.""" + if not specifier_set: + return specifier_set + + # Check all version specifiers are semantic + for specifier in specifier_set: + try: + SemanticVersion(specifier.version) + except ValueError as exc: + raise ValueError( + f"Specifier {specifier} is not a semantic version specifier" + ) from exc + + # Check single-use operators are used only once + if ( + len( + [ + specifier + for specifier in specifier_set + if specifier.operator in ("<", "<=") + ] + ) + > 1 + ): + raise ValueError( + "Multiple upper bound specifiers ('<', '<=') found in specifier set " + f"{specifier_set}. Instead, consider using the != operator. For " + "example, to avoid a complete minor range: !=1.*" + ) + if ( + len( + [ + specifier + for specifier in specifier_set + if specifier.operator in (">", ">=", "~=") + ] + ) + > 1 + ): + raise ValueError( + "Multiple lower bound specifiers ('>', '>=', '~=') found in specifier " + f"set {specifier_set}. Instead, consider using the != operator. For " + "example, to avoid a complete minor range: !=1.*" + ) + if ( + len( + [specifier for specifier in specifier_set if specifier.operator == "=="] + ) + > 1 + ): + raise ValueError( + "Single-use specifier '==' found multiple times in specifier set " + f"{specifier_set}." + ) + + # Check no other specifiers are given if equals is given + if ( + any(specifier.operator == "==" for specifier in specifier_set) + and len(specifier_set) > 1 + ): + raise ValueError( + "Specifier set erroneously contains specifiers alongside a '==' " + "specifier" + ) + + # Check that the version range is not "flipped" + apparent_range = [None, None] + for specifier in specifier_set: + if specifier.operator in (">", ">=", "~="): + apparent_range[0] = SemanticVersion(specifier.version) + elif specifier.operator in ("<", "<="): + apparent_range[1] = SemanticVersion(specifier.version) + else: + continue + if ( + apparent_range[0] + and apparent_range[1] + and apparent_range[0] > apparent_range[1] + ): + raise ValueError( + f"Version range from specifier set {specifier_set} is 'flipped'. I.e.," + " the apparent lower bound is larger than the apparent upper bound." + ) + + return specifier_set + + def __contains__(self, version: "str") -> bool: + """Check if the given version is in the range.""" + return version in self._specifier + + def __str__(self) -> str: + """Return the string representation of the object.""" + return str(self._specifier) + + def __repr__(self) -> str: + """Return the string representation of the object.""" + return f"<{self.__class__.__name__}({str(self)!r})>" + + @property + def lower(self) -> SemanticVersion: + """The lower bound of the range.""" + return self._lower + + @lower.setter + def lower(self, value: "Union[SemanticVersion, str]") -> None: + """Set the lower bound of the range.""" + if isinstance(value, str): + value = SemanticVersion(value) + + if not isinstance(value, SemanticVersion): + raise TypeError( + f"lower must be of type {SemanticVersion.__name__} or str, not " + f"{type(value).__name__}" + ) + + if value > self._upper: + raise ValueError( + f"lower ({value}) cannot be greater than upper ({self._upper})" + ) + if value not in self and value != SemanticVersion("0"): + raise ValueError( + f"lower ({value}) is not in the range ({self}) and is not '0'" + ) + + self._lower = value + + @property + def upper(self) -> SemanticVersion: + """The upper bound of the range.""" + return self._upper + + @upper.setter + def upper(self, value: "Union[SemanticVersion, str]") -> None: + """Set the upper bound of the range.""" + if isinstance(value, str): + value = SemanticVersion(value) + + if not isinstance(value, SemanticVersion): + raise TypeError( + f"upper must be of type {SemanticVersion.__name__} or str, not " + f"{type(value).__name__}" + ) + + if value < self._lower: + raise ValueError( + f"upper ({value}) cannot be less than lower ({self._lower})" + ) + if value not in self: + raise ValueError(f"upper ({value}) is not in the range ({self})") + + self._upper = value + + @property + def operators(self) -> tuple[str, ...]: + """The operators used in the specifier.""" + return tuple(specifier.operator for specifier in self._specifier) + + def version_from_operator(self, operator: str) -> SemanticVersion: + """Return the raw string version for the given operator.""" + for specifier in self._specifier: + if specifier.operator == operator: + return SemanticVersion(specifier.version) + raise ValueError( + f"Operator {operator} not found in specifier set {self._specifier}" + ) - return self.__class__(f"{self.major}.{self.minor}.{self.patch + 1}") + def _determine_lower(self) -> SemanticVersion: + """Determine the lower version range limit based on the specifier.""" + + def __next_version(version: SemanticVersion) -> SemanticVersion: + """Return the next version of the given version.""" + if version.patch < int(self._arbritrary_upper_limit): + return version.next_version("patch") + if version.minor < int(self._arbritrary_upper_limit): + return version.next_version("minor") + if version.major >= int(self._arbritrary_upper_limit): + raise ValueError( + f"Major version for {version} exceeds {self.__class__.__name__}'s " + "internal upper limit" + ) + return version.next_version("major") + + if not self._specifier: + return SemanticVersion("0") + + lower = None + for specifier in self._specifier: + if specifier.operator == ">=": + lower = SemanticVersion(specifier.version) + break + if specifier.operator == "==": + lower = SemanticVersion(specifier.version) + break + if specifier.operator == "~=": + lower = SemanticVersion(specifier.version) + break + + if specifier.operator == ">": + next_version = __next_version(SemanticVersion(specifier.version)) + while next_version not in self and next_version <= SemanticVersion( + self._arbritrary_upper_limit + ): + next_version = __next_version(next_version) + + lower = ( + min(SemanticVersion(specifier.version).next_version("patch"), lower) + if lower and lower != SemanticVersion("0") + else SemanticVersion(specifier.version).next_version("patch") + ) + elif specifier.operator in ("<", "<=", "!="): + lower = lower or SemanticVersion("0") + else: + # The arbitrary operator (===) is not supported + raise NotImplementedError( + f"Specifier operator {specifier.operator} not implemented" + ) + if lower is None: + raise ValueError(f"Could not determine lower bound for {self}") + return lower + + def _determine_upper(self) -> SemanticVersion: + """Determine the upper version range limit based on the specifier.""" + + def __previous_version(version: SemanticVersion) -> SemanticVersion: + """Return the previous version of the given version.""" + if version.patch > 0: + return SemanticVersion( + major=version.major, minor=version.minor, patch=version.patch - 1 + ) + if version.minor > 0: + return SemanticVersion( + major=version.major, + minor=version.minor - 1, + patch=self._arbritrary_upper_limit, + ) + if version.major <= 0: + raise ValueError(f"Cannot determine previous version for {version}") + return SemanticVersion( + major=version.major - 1, + minor=self._arbritrary_upper_limit, + patch=self._arbritrary_upper_limit, + ) + + if not self._specifier: + # An arbitrary upper limit major version + return SemanticVersion(self._arbritrary_upper_limit) + + upper = None + for specifier in self._specifier: + if specifier.operator == "<=": + upper = SemanticVersion(specifier.version) + break + if specifier.operator == "==": + upper = SemanticVersion(specifier.version) + break + + if specifier.operator == "<": + previous_version = __previous_version( + SemanticVersion(specifier.version) + ) + while ( + previous_version not in self + and previous_version != SemanticVersion("0") + ): + previous_version = __previous_version(previous_version) + upper = ( + max(previous_version, upper) + if ( + upper and upper != SemanticVersion(self._arbritrary_upper_limit) + ) + else previous_version + ) + elif specifier.operator in (">", ">=", "!=", "~="): + upper = upper or SemanticVersion(self._arbritrary_upper_limit) + else: + # The arbitrary operator (===) is not supported + raise NotImplementedError( + f"Specifier operator {specifier.operator} not implemented" + ) + if upper is None: + raise ValueError(f"Could not determine upper bound for {self}") + return upper + + def __gt__(self, other: "Any") -> bool: + """Greater than (`>`) rich comparison.""" + if isinstance(other, self.__class__): + other_semver = SemanticVersion(other.upper) + else: + other_semver = SemanticVersion._validate_other_type(other) + + return self.upper > other_semver + + def __eq__(self, other: "Any") -> bool: + """Equal to (`==`) rich comparison.""" + if isinstance(other, self.__class__): + return self._specifier == other._specifier + + other_semver = SemanticVersion._validate_other_type(other) + + return other_semver in self + + def __ne__(self, other: "Any") -> bool: + """Not equal to (`!=`) rich comparison.""" + return not self.__eq__(other) + + def __lt__(self, other: "Any") -> bool: + """Less than (`<`) rich comparison.""" + if isinstance(other, self.__class__): + other_semver = SemanticVersion(other.lower) + else: + other_semver = SemanticVersion._validate_other_type(other) + + return self.lower < other_semver + + def __ge__(self, other: "Any") -> bool: + """Greater than or equal to (`>=`) rich comparison.""" + return not self.__lt__(other) + + def __le__(self, other: "Any") -> bool: + """Less than or equal to (`<=`) rich comparison.""" + return self.__lt__(other) or self.__eq__(other) + + def __and__(self, other: "Any") -> "SemanticVersionRange": + """Intersection (`&`) of two version ranges.""" + if isinstance(other, self.__class__): + return self.__class__(str(self) + "," + str(other)) + if isinstance(other, SemanticVersion): + return self.__class__(str(self) + "," + str(other)) + if isinstance(other, str): + return self.__class__(str(self) + "," + other) + raise NotImplementedError( + f"Intersection between {self.__class__.__name__} and {type(other).__name__} " + "not implemented" + ) + + def __iter__(self) -> "Iterator[SortableSpecifier]": + """Iterate over the range, or rather, the underlying SpecifierSet.""" + return iter( + SortableSpecifier(spec=str(_), prereleases=_.prereleases or None) + for _ in self._specifier + ) def update_file( - filename: Path, sub_line: "Tuple[str, str]", strip: "Optional[str]" = None + filename: Path, sub_line: tuple[str, str], strip: "Optional[str]" = None ) -> None: """Utility function for tasks to read, update, and write files""" if strip is None and filename.suffix == ".md": diff --git a/tests/tasks/test_update_deps.py b/tests/tasks/test_update_deps.py index 09ecc04b..3c618dc8 100644 --- a/tests/tasks/test_update_deps.py +++ b/tests/tasks/test_update_deps.py @@ -65,7 +65,8 @@ def test_update_deps(tmp_path: "Path", caplog: pytest.LogCaptureFixture) -> None "name", "name1<=1", "name2>=3", - "name3>=3,<2", + # Non-contiguous version ranges are currently not allowed/supported. + # "name3>=3,<2", "name4@http://foo.com", "name5 [fred,bar] @ http://foo.com ; python_version=='2.7'", "name6[quux, strange];python_version<'2.7' and platform_version=='2'", @@ -88,7 +89,7 @@ def test_update_deps(tmp_path: "Path", caplog: pytest.LogCaptureFixture) -> None **{ re.compile(r".*invoke$"): "invoke (1.7.1)\n", re.compile(r".*tomlkit$"): "tomlkit (1.0.0)", - re.compile(r".*mike$"): "mike (1.0.1)", + re.compile(r".*mike$"): "mike (1.1.1)", re.compile(r".*pytest$"): "pytest (7.1.0)", re.compile(r".*pytest-cov$"): "pytest-cov (3.1.0)", re.compile(r".*pre-commit$"): "pre-commit (2.20.0)", @@ -98,7 +99,7 @@ def test_update_deps(tmp_path: "Path", caplog: pytest.LogCaptureFixture) -> None re.compile(r".*aa$"): "aa (1.2.3)", re.compile(r".*name$"): "name (1.2.3)", }, - **{re.compile(rf".*name{i}$"): f"name{i} (1.2.3)" for i in range(1, 12)}, + **{re.compile(rf".*name{i}$"): f"name{i} (3.2.1)" for i in range(1, 12)}, } ) @@ -123,10 +124,11 @@ def test_update_deps(tmp_path: "Path", caplog: pytest.LogCaptureFixture) -> None package_name = line.split(maxsplit=1)[0] assert line == f"{package_name} ~={original_dependencies[package_name]}" elif "tomlkit" in line: - # Should be three version digits, since the original dependency had three. - assert line == "tomlkit[test,docs] ~=1.0.0" + # Should expand the ~= operator to >=,< + # Also, alphabetically sorts the extras + assert line == f"tomlkit[docs,test] >={original_dependencies['tomlkit']},<2" elif "mike" in line: - assert line == "mike >=1.0,<3" + assert line == "mike >=1.1,<3" elif "pytest-cov" in line: assert line == "pytest-cov ~=3.1" elif "pylint" in line: @@ -137,7 +139,9 @@ def test_update_deps(tmp_path: "Path", caplog: pytest.LogCaptureFixture) -> None line.startswith(package_name) for package_name in [f"name{i}" for i in range(6, 12)] ): - package_name = line.split(";", maxsplit=1)[0].strip() + package_name = ( + line.split(";", maxsplit=1)[0].split("[", maxsplit=1)[0].strip() + ) assert ( f"{package_name!r} is not version restricted and will be skipped." in caplog.text @@ -147,20 +151,18 @@ def test_update_deps(tmp_path: "Path", caplog: pytest.LogCaptureFixture) -> None "python_version" in line and "platform_version" in line ) elif "name1" in line: - assert line == "name1<=1" + # Update, since the latest version is 3.2.1 + assert line == "name1<=3" elif "name2" in line: assert line == "name2>=3" - elif "name3" in line: - assert line == "name3>=3,<2" + # elif "name3" in line: + # assert line == "name3>=3,<2" elif "name4" in line: assert line == "name4@http://foo.com" assert "'name4' is pinned to a URL and will be skipped" in caplog.text elif "name5" in line: assert line == "name5 [fred,bar] @ http://foo.com ; python_version=='2.7'" - assert ( - "'name5 [fred,bar]' is pinned to a URL and will be skipped" - in caplog.text - ) + assert "'name5' is pinned to a URL and will be skipped" in caplog.text else: pytest.fail(f"Unknown package in line: {line}") @@ -973,7 +975,7 @@ def test_ignore_version_fails() -> None: ["dependency-name=*...update-types=version-update:semver-major"], { "invoke": "invoke ~=1.7", - "tomlkit[test,docs]": "tomlkit[test,docs] ~=0.11.4", + "tomlkit[docs,test]": "tomlkit[docs,test] ~=0.11.4", "mike": "mike >=1.0,<3", "pytest": "pytest ~=7.2", "pytest-cov": "pytest-cov ~=3.1", @@ -986,13 +988,13 @@ def test_ignore_version_fails() -> None: ["dependency-name=invoke...versions=>=2"], { "invoke": "invoke ~=1.7", - "tomlkit[test,docs]": "tomlkit[test,docs] ~=1.0.0", + "tomlkit[docs,test]": "tomlkit[docs,test] >=0.11.4,<2", "mike": "mike >=1.0,<3", "pytest": "pytest ~=7.2", "pytest-cov": "pytest-cov ~=3.1", "pre-commit": "pre-commit ~=2.20", "pylint": "pylint ~=2.14", - "Sphinx": "Sphinx >=6.1.3,<6", + "Sphinx": "Sphinx >=4.5.0,<7", }, ), ( @@ -1002,59 +1004,59 @@ def test_ignore_version_fails() -> None: ], { "invoke": "invoke ~=1.7", - "tomlkit[test,docs]": "tomlkit[test,docs] ~=1.0.0", + "tomlkit[docs,test]": "tomlkit[docs,test] >=0.11.4,<2", "mike": "mike >=1.0,<3", "pytest": "pytest ~=7.2", "pytest-cov": "pytest-cov ~=3.1", "pre-commit": "pre-commit ~=2.20", "pylint": "pylint ~=2.14", - "Sphinx": "Sphinx >=6.1.3,<6", + "Sphinx": "Sphinx >=4.5.0,<7", }, ), ( ["dependency-name=pylint...versions=~=2.14"], { "invoke": "invoke ~=1.7", - "tomlkit[test,docs]": "tomlkit[test,docs] ~=1.0.0", + "tomlkit[docs,test]": "tomlkit[docs,test] >=0.11.4,<2", "mike": "mike >=1.0,<3", "pytest": "pytest ~=7.2", "pytest-cov": "pytest-cov ~=3.1", "pre-commit": "pre-commit ~=2.20", "pylint": "pylint ~=2.13", - "Sphinx": "Sphinx >=6.1.3,<6", + "Sphinx": "Sphinx >=4.5.0,<7", }, ), ( ["dependency-name=pytest"], { "invoke": "invoke ~=1.7", - "tomlkit[test,docs]": "tomlkit[test,docs] ~=1.0.0", + "tomlkit[docs,test]": "tomlkit[docs,test] >=0.11.4,<2", "mike": "mike >=1.0,<3", "pytest": "pytest ~=7.1", "pytest-cov": "pytest-cov ~=3.1", "pre-commit": "pre-commit ~=2.20", "pylint": "pylint ~=2.14", - "Sphinx": "Sphinx >=6.1.3,<6", + "Sphinx": "Sphinx >=4.5.0,<7", }, ), ( ["dependency-name=pytest-cov...update-types=version-update:semver-minor"], { "invoke": "invoke ~=1.7", - "tomlkit[test,docs]": "tomlkit[test,docs] ~=1.0.0", + "tomlkit[docs,test]": "tomlkit[docs,test] >=0.11.4,<2", "mike": "mike >=1.0,<3", "pytest": "pytest ~=7.2", "pytest-cov": "pytest-cov ~=3.0", "pre-commit": "pre-commit ~=2.20", "pylint": "pylint ~=2.14", - "Sphinx": "Sphinx >=6.1.3,<6", # This should be fixed! + "Sphinx": "Sphinx >=4.5.0,<7", }, ), ( ["dependency-name=Sphinx...versions=>=4.5.0"], { "invoke": "invoke ~=1.7", - "tomlkit[test,docs]": "tomlkit[test,docs] ~=1.0.0", + "tomlkit[docs,test]": "tomlkit[docs,test] >=0.11.4,<2", "mike": "mike >=1.0,<3", "pytest": "pytest ~=7.2", "pytest-cov": "pytest-cov ~=3.1", @@ -1088,7 +1090,7 @@ def test_ignore_rules_logic( original_dependencies = { "invoke": "1.7", "tomlkit": "0.11.4", - "mike": "1.1", + "mike": "1.0", "pytest": "7.1", "pytest-cov": "3.0", "pre-commit": "2.20", @@ -1104,7 +1106,7 @@ def test_ignore_rules_logic( dependencies = [ "invoke ~={original_dependencies['invoke']}", - "tomlkit[test,docs] ~={original_dependencies['tomlkit']}", + "tomlkit[docs,test] ~={original_dependencies['tomlkit']}", ] [project.optional-dependencies]