diff --git a/libcst/helpers/module.py b/libcst/helpers/module.py index 3c26122d3..37e6af088 100644 --- a/libcst/helpers/module.py +++ b/libcst/helpers/module.py @@ -5,7 +5,7 @@ # from dataclasses import dataclass from itertools import islice -from pathlib import PurePath +from pathlib import Path, PurePath from typing import List, Optional from libcst import Comment, EmptyLine, ImportFrom, Module @@ -132,11 +132,25 @@ class ModuleNameAndPackage: def calculate_module_and_package( - repo_root: StrPath, filename: StrPath + repo_root: StrPath, filename: StrPath, use_pyproject_toml: bool = False ) -> ModuleNameAndPackage: # Given an absolute repo_root and an absolute filename, calculate the # python module name for the file. - relative_filename = PurePath(filename).relative_to(repo_root) + if use_pyproject_toml: + # But also look for pyproject.toml files, indicating nested packages in the repo. + abs_repo_root = Path(repo_root).resolve() + abs_filename = Path(filename).resolve() + package_root = abs_filename.parent + while package_root != abs_repo_root: + if (package_root / "pyproject.toml").exists(): + break + if package_root == package_root.parent: + break + package_root = package_root.parent + + relative_filename = abs_filename.relative_to(package_root) + else: + relative_filename = PurePath(filename).relative_to(repo_root) relative_filename = relative_filename.with_suffix("") # handle special cases diff --git a/libcst/helpers/tests/test_module.py b/libcst/helpers/tests/test_module.py index 7260f5cc8..815e1fa23 100644 --- a/libcst/helpers/tests/test_module.py +++ b/libcst/helpers/tests/test_module.py @@ -3,7 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # -from typing import Optional +from pathlib import Path, PurePath +from typing import Any, Optional +from unittest.mock import patch import libcst as cst from libcst.helpers.common import ensure_type @@ -251,6 +253,77 @@ def test_calculate_module_and_package( calculate_module_and_package(repo_root, filename), module_and_package ) + @data_provider( + ( + ("foo/foo/__init__.py", ModuleNameAndPackage("foo", "foo")), + ("foo/foo/file.py", ModuleNameAndPackage("foo.file", "foo")), + ( + "foo/foo/sub/subfile.py", + ModuleNameAndPackage("foo.sub.subfile", "foo.sub"), + ), + ("libs/bar/bar/thing.py", ModuleNameAndPackage("bar.thing", "bar")), + ( + "noproj/some/file.py", + ModuleNameAndPackage("noproj.some.file", "noproj.some"), + ), + ) + ) + def test_calculate_module_and_package_using_pyproject_toml( + self, + rel_path: str, + module_and_package: Optional[ModuleNameAndPackage], + ) -> None: + mock_tree: dict[str, Any] = { + "home": { + "user": { + "root": { + "foo": { + "pyproject.toml": "content", + "foo": { + "__init__.py": "content", + "file.py": "content", + "sub": { + "subfile.py": "content", + }, + }, + }, + "libs": { + "bar": { + "pyproject.toml": "content", + "bar": { + "__init__.py": "content", + "thing.py": "content", + }, + } + }, + "noproj": { + "some": { + "file.py": "content", + } + }, + }, + }, + }, + } + repo_root = Path("/home/user/root").resolve() + fake_root: Path = repo_root.parent.parent.parent + + def mock_exists(path: PurePath) -> bool: + parts = path.relative_to(fake_root).parts + subtree = mock_tree + for part in parts: + if (subtree := subtree.get(part)) is None: + return False + return True + + with patch("pathlib.Path.exists", new=mock_exists): + self.assertEqual( + calculate_module_and_package( + repo_root, repo_root / rel_path, use_pyproject_toml=True + ), + module_and_package, + ) + @data_provider( ( # Providing a file outside the root should raise an exception diff --git a/libcst/metadata/base_provider.py b/libcst/metadata/base_provider.py index 1c113f57a..2e03416f0 100644 --- a/libcst/metadata/base_provider.py +++ b/libcst/metadata/base_provider.py @@ -6,12 +6,12 @@ from pathlib import Path from types import MappingProxyType from typing import ( - Callable, Generic, List, Mapping, MutableMapping, Optional, + Protocol, Type, TYPE_CHECKING, TypeVar, @@ -40,6 +40,18 @@ MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT] +class GenCacheMethod(Protocol): + def __call__( + self, + root_path: Path, + paths: List[str], + *, + timeout: Optional[int] = None, + use_pyproject_toml: bool = False, + ) -> Mapping[str, object]: + ... + + # We can't use an ABCMeta here, because of metaclass conflicts class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]): """ @@ -59,7 +71,7 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]): #: Implement gen_cache to indicate the metadata provider depends on cache from external #: system. This function will be called by :class:`~libcst.metadata.FullRepoManager` #: to compute required cache object per file path. - gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None + gen_cache: Optional[GenCacheMethod] = None def __init__(self, cache: object = None) -> None: super().__init__() diff --git a/libcst/metadata/file_path_provider.py b/libcst/metadata/file_path_provider.py index 5ed9baa68..6ab01b5fa 100644 --- a/libcst/metadata/file_path_provider.py +++ b/libcst/metadata/file_path_provider.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from pathlib import Path -from typing import List, Mapping, Optional +from typing import Any, List, Mapping, Optional import libcst as cst from libcst.metadata.base_provider import BatchableMetadataProvider @@ -41,7 +41,7 @@ def visit_Module(self, node: libcst.Module) -> None: @classmethod def gen_cache( - cls, root_path: Path, paths: List[str], timeout: Optional[int] = None + cls, root_path: Path, paths: List[str], **kwargs: Any ) -> Mapping[str, Path]: cache = {path: (root_path / path).resolve() for path in paths} return cache diff --git a/libcst/metadata/full_repo_manager.py b/libcst/metadata/full_repo_manager.py index 83bb6e83a..770ba1f63 100644 --- a/libcst/metadata/full_repo_manager.py +++ b/libcst/metadata/full_repo_manager.py @@ -22,6 +22,7 @@ def __init__( paths: Collection[str], providers: Collection["ProviderT"], timeout: int = 5, + use_pyproject_toml: bool = False, ) -> None: """ Given project root directory with pyre and watchman setup, :class:`~libcst.metadata.FullRepoManager` @@ -38,6 +39,7 @@ def __init__( self.root_path: Path = Path(repo_root_dir) self._cache: Dict["ProviderT", Mapping[str, object]] = {} self._timeout = timeout + self._use_pyproject_toml = use_pyproject_toml self._providers = providers self._paths: List[str] = list(paths) @@ -65,7 +67,10 @@ def resolve_cache(self) -> None: handler = provider.gen_cache if handler: cache[provider] = handler( - self.root_path, self._paths, self._timeout + self.root_path, + self._paths, + timeout=self._timeout, + use_pyproject_toml=self._use_pyproject_toml, ) self._cache = cache diff --git a/libcst/metadata/name_provider.py b/libcst/metadata/name_provider.py index 1868fa665..7de76eb5e 100644 --- a/libcst/metadata/name_provider.py +++ b/libcst/metadata/name_provider.py @@ -5,7 +5,7 @@ import dataclasses from pathlib import Path -from typing import Collection, List, Mapping, Optional, Union +from typing import Any, Collection, List, Mapping, Optional, Union import libcst as cst from libcst._metadata_dependent import LazyValue, MetadataDependent @@ -112,9 +112,19 @@ class FullyQualifiedNameProvider(BatchableMetadataProvider[Collection[QualifiedN @classmethod def gen_cache( - cls, root_path: Path, paths: List[str], timeout: Optional[int] = None + cls, + root_path: Path, + paths: List[str], + *, + use_pyproject_toml: bool = False, + **kwargs: Any, ) -> Mapping[str, ModuleNameAndPackage]: - cache = {path: calculate_module_and_package(root_path, path) for path in paths} + cache = { + path: calculate_module_and_package( + root_path, path, use_pyproject_toml=use_pyproject_toml + ) + for path in paths + } return cache def __init__(self, cache: ModuleNameAndPackage) -> None: diff --git a/libcst/metadata/tests/test_metadata_wrapper.py b/libcst/metadata/tests/test_metadata_wrapper.py index ee61e14fc..9063a99ad 100644 --- a/libcst/metadata/tests/test_metadata_wrapper.py +++ b/libcst/metadata/tests/test_metadata_wrapper.py @@ -48,9 +48,13 @@ def test_hash_by_identity(self) -> None: self.assertNotEqual(hash(mw1), hash(mw3)) self.assertNotEqual(hash(mw2), hash(mw3)) + @staticmethod + def ignore_args(*args: object, **kwargs: object) -> tuple[object, ...]: + return (args, kwargs) + def test_metadata_cache(self) -> None: class DummyMetadataProvider(BatchableMetadataProvider[None]): - gen_cache = tuple + gen_cache = self.ignore_args m = cst.parse_module("pass") mw = MetadataWrapper(m) @@ -60,7 +64,7 @@ class DummyMetadataProvider(BatchableMetadataProvider[None]): mw.resolve(DummyMetadataProvider) class SimpleCacheMetadataProvider(BatchableMetadataProvider[object]): - gen_cache = tuple + gen_cache = self.ignore_args def __init__(self, cache: object) -> None: super().__init__(cache) diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index 80215dc66..fbd3631af 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -54,7 +54,7 @@ def get_fully_qualified_names(file_path: str, module_str: str) -> Set[QualifiedN cst.parse_module(dedent(module_str)), cache={ FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache( - Path(""), [file_path], None + Path(""), [file_path], timeout=None ).get(file_path, "") }, ) diff --git a/libcst/metadata/type_inference_provider.py b/libcst/metadata/type_inference_provider.py index c9c1fc9ac..f00c97b6a 100644 --- a/libcst/metadata/type_inference_provider.py +++ b/libcst/metadata/type_inference_provider.py @@ -6,7 +6,7 @@ import json import subprocess from pathlib import Path -from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TypedDict +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, TypedDict import libcst as cst from libcst._position import CodePosition, CodeRange @@ -50,11 +50,13 @@ class TypeInferenceProvider(BatchableMetadataProvider[str]): METADATA_DEPENDENCIES = (PositionProvider,) - @staticmethod - # pyre-fixme[40]: Static method `gen_cache` cannot override a non-static method - # defined in `cst.metadata.base_provider.BaseMetadataProvider`. + @classmethod def gen_cache( - root_path: Path, paths: List[str], timeout: Optional[int] + cls, + root_path: Path, + paths: List[str], + timeout: Optional[int] = None, + **kwargs: Any, ) -> Mapping[str, object]: params = ",".join(f"path='{root_path / path}'" for path in paths) cmd_args = ["pyre", "--noninteractive", "query", f"types({params})"]