diff --git a/mypy.ini b/mypy.ini index a63ccf787c..3581693038 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,8 @@ exclude = (?x)( | ^setuptools/config/_validate_pyproject/ # Auto-generated | ^setuptools/tests/bdist_wheel_testdata/ # Duplicate module name ) +# Too many false-positives +disable_error_code = overload-overlap # Ignoring attr-defined because setuptools wraps a lot of distutils classes, adding new attributes, # w/o updating all the attributes and return types from the base classes for type-checkers to understand diff --git a/pkg_resources/__init__.py b/pkg_resources/__init__.py index 04c6c68cb8..c4ace5aa77 100644 --- a/pkg_resources/__init__.py +++ b/pkg_resources/__init__.py @@ -34,6 +34,7 @@ import types from typing import ( Any, + Literal, Dict, Iterator, Mapping, @@ -98,8 +99,8 @@ from pkg_resources.extern.platformdirs import user_cache_dir as _user_cache_dir if TYPE_CHECKING: + from _typeshed import BytesPath, StrPath, StrOrBytesPath from typing_extensions import Self - from _typeshed import StrPath, StrOrBytesPath, BytesPath warnings.warn( "pkg_resources is deprecated as an API. " @@ -110,15 +111,20 @@ _T = TypeVar("_T") +_DistributionT = TypeVar("_DistributionT", bound="Distribution") # Type aliases _NestedStr = Union[str, Iterable[Union[str, Iterable["_NestedStr"]]]] +_InstallerTypeT = Callable[["Requirement"], "_DistributionT"] _InstallerType = Callable[["Requirement"], Union["Distribution", None]] _PkgReqType = Union[str, "Requirement"] _EPDistType = Union["Distribution", _PkgReqType] _MetadataType = Union["IResourceProvider", None] +_ResolvedEntryPoint = Any # Can be any attribute in the module +_ResourceStream = Any # TODO / Incomplete: A readable file-like object # Any object works, but let's indicate we expect something like a module (optionally has __loader__ or __file__) _ModuleLike = Union[object, types.ModuleType] -_ProviderFactoryType = Callable[[_ModuleLike], "IResourceProvider"] +# Any: Should be _ModuleLike but we end up with issues where _ModuleLike doesn't have _ZipLoaderModule's __loader__ +_ProviderFactoryType = Callable[[Any], "IResourceProvider"] _DistFinderType = Callable[[_T, str, bool], Iterable["Distribution"]] _NSHandlerType = Callable[[_T, str, str, types.ModuleType], Union[str, None]] _AdapterT = TypeVar( @@ -131,6 +137,10 @@ class _LoaderProtocol(Protocol): def load_module(self, fullname: str, /) -> types.ModuleType: ... +class _ZipLoaderModule(Protocol): + __loader__: zipimport.zipimporter + + _PEP440_FALLBACK = re.compile(r"^v?(?P(?:[0-9]+!)?[0-9]+(?:\.[0-9]+)*)", re.I) @@ -404,7 +414,11 @@ def register_loader_type( _provider_factories[loader_type] = provider_factory -def get_provider(moduleOrReq: str | Requirement): +@overload +def get_provider(moduleOrReq: str) -> IResourceProvider: ... +@overload +def get_provider(moduleOrReq: Requirement) -> Distribution: ... +def get_provider(moduleOrReq: str | Requirement) -> IResourceProvider | Distribution: """Return an IResourceProvider for the named module or requirement""" if isinstance(moduleOrReq, Requirement): return working_set.find(moduleOrReq) or require(str(moduleOrReq))[0] @@ -515,22 +529,33 @@ def compatible_platforms(provided: str | None, required: str | None): return False -def get_distribution(dist: _EPDistType): +@overload +def get_distribution(dist: _DistributionT) -> _DistributionT: ... +@overload +def get_distribution(dist: _PkgReqType) -> Distribution: ... +def get_distribution(dist: Distribution | _PkgReqType) -> Distribution: """Return a current distribution object for a Requirement or string""" if isinstance(dist, str): dist = Requirement.parse(dist) if isinstance(dist, Requirement): - dist = get_provider(dist) + # Bad type narrowing, dist has to be a Requirement here, so get_provider has to return Distribution + dist = get_provider(dist) # type: ignore[assignment] if not isinstance(dist, Distribution): - raise TypeError("Expected string, Requirement, or Distribution", dist) + raise TypeError("Expected str, Requirement, or Distribution", dist) return dist -def load_entry_point(dist: _EPDistType, group: str, name: str): +def load_entry_point(dist: _EPDistType, group: str, name: str) -> _ResolvedEntryPoint: """Return `name` entry point of `group` for `dist` or raise ImportError""" return get_distribution(dist).load_entry_point(group, name) +@overload +def get_entry_map( + dist: _EPDistType, group: None = None +) -> dict[str, dict[str, EntryPoint]]: ... +@overload +def get_entry_map(dist: _EPDistType, group: str) -> dict[str, EntryPoint]: ... def get_entry_map(dist: _EPDistType, group: str | None = None): """Return the entry point map for `group`, or the full entry map""" return get_distribution(dist).get_entry_map(group) @@ -545,10 +570,10 @@ class IMetadataProvider(Protocol): def has_metadata(self, name: str) -> bool: """Does the package's distribution contain the named metadata?""" - def get_metadata(self, name: str): + def get_metadata(self, name: str) -> str: """The named metadata resource as a string""" - def get_metadata_lines(self, name: str): + def get_metadata_lines(self, name: str) -> Iterator[str]: """Yield named metadata resource as list of non-blank non-comment lines Leading and trailing whitespace is stripped from each line, and lines @@ -557,22 +582,26 @@ def get_metadata_lines(self, name: str): def metadata_isdir(self, name: str) -> bool: """Is the named metadata a directory? (like ``os.path.isdir()``)""" - def metadata_listdir(self, name: str): + def metadata_listdir(self, name: str) -> list[str]: """List of metadata names in the directory (like ``os.listdir()``)""" - def run_script(self, script_name: str, namespace: dict[str, Any]): + def run_script(self, script_name: str, namespace: dict[str, Any]) -> None: """Execute the named script in the supplied namespace dictionary""" class IResourceProvider(IMetadataProvider, Protocol): """An object that provides access to package resources""" - def get_resource_filename(self, manager: ResourceManager, resource_name: str): + def get_resource_filename( + self, manager: ResourceManager, resource_name: str + ) -> str: """Return a true filesystem path for `resource_name` `manager` must be a ``ResourceManager``""" - def get_resource_stream(self, manager: ResourceManager, resource_name: str): + def get_resource_stream( + self, manager: ResourceManager, resource_name: str + ) -> _ResourceStream: """Return a readable file-like object for `resource_name` `manager` must be a ``ResourceManager``""" @@ -584,13 +613,13 @@ def get_resource_string( `manager` must be a ``ResourceManager``""" - def has_resource(self, resource_name: str): + def has_resource(self, resource_name: str) -> bool: """Does the package contain the named resource?""" - def resource_isdir(self, resource_name: str): + def resource_isdir(self, resource_name: str) -> bool: """Is the named resource a directory? (like ``os.path.isdir()``)""" - def resource_listdir(self, resource_name: str): + def resource_listdir(self, resource_name: str) -> list[str]: """List of resource names in the directory (like ``os.listdir()``)""" @@ -773,6 +802,26 @@ def add( keys2.append(dist.key) self._added_new(dist) + @overload + def resolve( + self, + requirements: Iterable[Requirement], + env: Environment | None, + installer: _InstallerTypeT[_DistributionT], + replace_conflicting: bool = False, + extras: tuple[str, ...] | None = None, + ) -> list[_DistributionT]: ... + @overload + def resolve( + self, + requirements: Iterable[Requirement], + env: Environment | None = None, + *, + installer: _InstallerTypeT[_DistributionT], + replace_conflicting: bool = False, + extras: tuple[str, ...] | None = None, + ) -> list[_DistributionT]: ... + @overload def resolve( self, requirements: Iterable[Requirement], @@ -780,7 +829,15 @@ def resolve( installer: _InstallerType | None = None, replace_conflicting: bool = False, extras: tuple[str, ...] | None = None, - ): + ) -> list[Distribution]: ... + def resolve( + self, + requirements: Iterable[Requirement], + env: Environment | None = None, + installer: _InstallerType | None | _InstallerTypeT[_DistributionT] = None, + replace_conflicting: bool = False, + extras: tuple[str, ...] | None = None, + ) -> list[Distribution] | list[_DistributionT]: """List all distributions needed to (recursively) meet `requirements` `requirements` must be a sequence of ``Requirement`` objects. `env`, @@ -878,13 +935,41 @@ def _resolve_dist( raise VersionConflict(dist, req).with_context(dependent_req) return dist + @overload + def find_plugins( + self, + plugin_env: Environment, + full_env: Environment | None, + installer: _InstallerTypeT[_DistributionT], + fallback: bool = True, + ) -> tuple[list[_DistributionT], dict[Distribution, Exception]]: ... + @overload + def find_plugins( + self, + plugin_env: Environment, + full_env: Environment | None = None, + *, + installer: _InstallerTypeT[_DistributionT], + fallback: bool = True, + ) -> tuple[list[_DistributionT], dict[Distribution, Exception]]: ... + @overload def find_plugins( self, plugin_env: Environment, full_env: Environment | None = None, installer: _InstallerType | None = None, fallback: bool = True, - ): + ) -> tuple[list[Distribution], dict[Distribution, Exception]]: ... + def find_plugins( + self, + plugin_env: Environment, + full_env: Environment | None = None, + installer: _InstallerType | None | _InstallerTypeT[_DistributionT] = None, + fallback: bool = True, + ) -> tuple[ + list[Distribution] | list[_DistributionT], + dict[Distribution, Exception], + ]: """Find all activatable distributions in `plugin_env` Example usage:: @@ -923,8 +1008,8 @@ def find_plugins( # scan project names in alphabetic order plugin_projects.sort() - error_info = {} - distributions = {} + error_info: dict[Distribution, Exception] = {} + distributions: dict[Distribution, Exception | None] = {} if full_env is None: env = Environment(self.entries) @@ -1121,13 +1206,29 @@ def add(self, dist: Distribution): dists.append(dist) dists.sort(key=operator.attrgetter('hashcmp'), reverse=True) + @overload def best_match( self, req: Requirement, working_set: WorkingSet, - installer: Callable[[Requirement], Any] | None = None, + installer: _InstallerTypeT[_DistributionT], replace_conflicting: bool = False, - ): + ) -> _DistributionT: ... + @overload + def best_match( + self, + req: Requirement, + working_set: WorkingSet, + installer: _InstallerType | None = None, + replace_conflicting: bool = False, + ) -> Distribution | None: ... + def best_match( + self, + req: Requirement, + working_set: WorkingSet, + installer: _InstallerType | None | _InstallerTypeT[_DistributionT] = None, + replace_conflicting: bool = False, + ) -> Distribution | None: """Find distribution best matching `req` and usable on `working_set` This calls the ``find(req)`` method of the `working_set` to see if a @@ -1154,11 +1255,32 @@ def best_match( # try to download/install return self.obtain(req, installer) + @overload def obtain( self, requirement: Requirement, - installer: Callable[[Requirement], Any] | None = None, - ): + installer: _InstallerTypeT[_DistributionT], + ) -> _DistributionT: ... + @overload + def obtain( + self, + requirement: Requirement, + installer: Callable[[Requirement], None] | None = None, + ) -> None: ... + @overload + def obtain( + self, + requirement: Requirement, + installer: _InstallerType | None = None, + ) -> Distribution | None: ... + def obtain( + self, + requirement: Requirement, + installer: Callable[[Requirement], None] + | _InstallerType + | None + | _InstallerTypeT[_DistributionT] = None, + ) -> Distribution | None: """Obtain a distribution matching `requirement` (e.g. via download) Obtain a distro that matches requirement (e.g. via download). In the @@ -1515,7 +1637,6 @@ class NullProvider: egg_name: str | None = None egg_info: str | None = None loader: _LoaderProtocol | None = None - module_path: str | None # Some subclasses can have a None module_path def __init__(self, module: _ModuleLike): self.loader = getattr(module, '__loader__', None) @@ -1558,7 +1679,7 @@ def get_metadata(self, name: str): exc.reason += ' in {} file at path: {}'.format(name, path) raise - def get_metadata_lines(self, name: str): + def get_metadata_lines(self, name: str) -> Iterator[str]: return yield_lines(self.get_metadata(name)) def resource_isdir(self, resource_name: str): @@ -1570,7 +1691,7 @@ def metadata_isdir(self, name: str) -> bool: def resource_listdir(self, resource_name: str): return self._listdir(self._fn(self.module_path, resource_name)) - def metadata_listdir(self, name: str): + def metadata_listdir(self, name: str) -> list[str]: if self.egg_info: return self._listdir(self._fn(self.egg_info, name)) return [] @@ -1583,6 +1704,7 @@ def run_script(self, script_name: str, namespace: dict[str, Any]): **locals() ), ) + script_text = self.get_metadata(script).replace('\r\n', '\n') script_text = script_text.replace('\r', '\n') script_filename = self._fn(self.egg_info, script) @@ -1613,12 +1735,16 @@ def _isdir(self, path) -> bool: "Can't perform this operation for unregistered loader type" ) - def _listdir(self, path): + def _listdir(self, path) -> list[str]: raise NotImplementedError( "Can't perform this operation for unregistered loader type" ) - def _fn(self, base, resource_name: str): + def _fn(self, base: str | None, resource_name: str): + if base is None: + raise TypeError( + "`base` parameter in `_fn` is `None`. Either override this method or check the parameter first." + ) self._validate_resource_path(resource_name) if resource_name: return os.path.join(base, *resource_name.split('/')) @@ -1778,7 +1904,8 @@ def _register(cls): class EmptyProvider(NullProvider): """Provider that returns nothing for all requests""" - module_path = None + # A special case, we don't want all Providers inheriting from NullProvider to have a potentially None module_path + module_path: str | None = None # type: ignore[assignment] _isdir = _has = lambda self, path: False @@ -1854,7 +1981,7 @@ class ZipProvider(EggProvider): # ZipProvider's loader should always be a zipimporter or equivalent loader: zipimport.zipimporter - def __init__(self, module: _ModuleLike): + def __init__(self, module: _ZipLoaderModule): super().__init__(module) self.zip_pre = self.loader.archive + os.sep @@ -1903,7 +2030,7 @@ def _get_date_and_size(zip_stat): return timestamp, size # FIXME: 'ZipProvider._extract_resource' is too complex (12) - def _extract_resource(self, manager: ResourceManager, zip_path): # noqa: C901 + def _extract_resource(self, manager: ResourceManager, zip_path) -> str: # noqa: C901 if zip_path in self._index(): for name in self._index()[zip_path]: last = self._extract_resource(manager, os.path.join(zip_path, name)) @@ -2040,7 +2167,7 @@ def _get_metadata_path(self, name): def has_metadata(self, name: str) -> bool: return name == 'PKG-INFO' and os.path.isfile(self.path) - def get_metadata(self, name): + def get_metadata(self, name: str): if name != 'PKG-INFO': raise KeyError("No metadata except PKG-INFO is available") @@ -2056,7 +2183,7 @@ def _warn_on_replacement(self, metadata): msg = tmpl.format(**locals()) warnings.warn(msg) - def get_metadata_lines(self, name): + def get_metadata_lines(self, name: str) -> Iterator[str]: return yield_lines(self.get_metadata(name)) @@ -2582,12 +2709,26 @@ def __str__(self): def __repr__(self): return "EntryPoint.parse(%r)" % str(self) + @overload + def load( + self, + require: Literal[True] = True, + env: Environment | None = None, + installer: _InstallerType | None = None, + ) -> _ResolvedEntryPoint: ... + @overload + def load( + self, + require: Literal[False], + *args: Any, + **kwargs: Any, + ) -> _ResolvedEntryPoint: ... def load( self, require: bool = True, *args: Environment | _InstallerType | None, **kwargs: Environment | _InstallerType | None, - ): + ) -> _ResolvedEntryPoint: """ Require packages for this EntryPoint, then resolve it. """ @@ -2604,7 +2745,7 @@ def load( self.require(*args, **kwargs) # type: ignore return self.resolve() - def resolve(self): + def resolve(self) -> _ResolvedEntryPoint: """ Resolve the entry point from its module and attrs. """ @@ -3033,13 +3174,17 @@ def as_requirement(self): return Requirement.parse(spec) - def load_entry_point(self, group: str, name: str): + def load_entry_point(self, group: str, name: str) -> _ResolvedEntryPoint: """Return the `name` entry point of `group` or raise ImportError""" ep = self.get_entry_info(group, name) if ep is None: raise ImportError("Entry point %r not found" % ((group, name),)) return ep.load() + @overload + def get_entry_map(self, group: None = None) -> dict[str, dict[str, EntryPoint]]: ... + @overload + def get_entry_map(self, group: str) -> dict[str, EntryPoint]: ... def get_entry_map(self, group: str | None = None): """Return the entry point map for `group`, or the full entry map""" if not hasattr(self, "_ep_map"): diff --git a/ruff.toml b/ruff.toml index 9f0b42cea9..8828fe61a5 100644 --- a/ruff.toml +++ b/ruff.toml @@ -12,7 +12,7 @@ extend-select = [ # local "FA", # flake8-future-annotations "F404", # late-future-import - "PYI", # flake8-pyi + "PYI", # flake8-pyi "UP", # pyupgrade "YTT", # flake8-2020 ] @@ -40,10 +40,6 @@ ignore = [ "ISC002", ] -[lint.per-file-ignores] -# Auto-generated code -"setuptools/config/_validate_pyproject/*" = ["FA100"] - [format] # Enable preview to get hugged parenthesis unwrapping preview = true