Skip to content

Commit

Permalink
FullyQualifiedNameProvider: Optionally consider pyproject.toml files when determining a file's module name and package (#1148)
Browse files Browse the repository at this point in the history
  • Loading branch information
camillol committed Jun 12, 2024
1 parent 47ff8cb commit 9f6e276
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 20 deletions.
20 changes: 17 additions & 3 deletions libcst/helpers/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
from dataclasses import dataclass
from itertools import islice
from pathlib import PurePath
from pathlib import Path, PurePath
from typing import List, Optional

from libcst import Comment, EmptyLine, ImportFrom, Module
Expand Down Expand Up @@ -132,11 +132,25 @@ class ModuleNameAndPackage:


def calculate_module_and_package(
repo_root: StrPath, filename: StrPath
repo_root: StrPath, filename: StrPath, use_pyproject_toml: bool = False
) -> ModuleNameAndPackage:
# Given an absolute repo_root and an absolute filename, calculate the
# python module name for the file.
relative_filename = PurePath(filename).relative_to(repo_root)
if use_pyproject_toml:
# But also look for pyproject.toml files, indicating nested packages in the repo.
abs_repo_root = Path(repo_root).resolve()
abs_filename = Path(filename).resolve()
package_root = abs_filename.parent
while package_root != abs_repo_root:
if (package_root / "pyproject.toml").exists():
break
if package_root == package_root.parent:
break
package_root = package_root.parent

relative_filename = abs_filename.relative_to(package_root)
else:
relative_filename = PurePath(filename).relative_to(repo_root)
relative_filename = relative_filename.with_suffix("")

# handle special cases
Expand Down
75 changes: 74 additions & 1 deletion libcst/helpers/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Optional
from pathlib import Path, PurePath
from typing import Any, Optional
from unittest.mock import patch

import libcst as cst
from libcst.helpers.common import ensure_type
Expand Down Expand Up @@ -251,6 +253,77 @@ def test_calculate_module_and_package(
calculate_module_and_package(repo_root, filename), module_and_package
)

@data_provider(
    (
        ("foo/foo/__init__.py", ModuleNameAndPackage("foo", "foo")),
        ("foo/foo/file.py", ModuleNameAndPackage("foo.file", "foo")),
        (
            "foo/foo/sub/subfile.py",
            ModuleNameAndPackage("foo.sub.subfile", "foo.sub"),
        ),
        ("libs/bar/bar/thing.py", ModuleNameAndPackage("bar.thing", "bar")),
        (
            "noproj/some/file.py",
            ModuleNameAndPackage("noproj.some.file", "noproj.some"),
        ),
    )
)
def test_calculate_module_and_package_using_pyproject_toml(
    self,
    rel_path: str,
    module_and_package: Optional[ModuleNameAndPackage],
) -> None:
    """
    Check ``calculate_module_and_package(..., use_pyproject_toml=True)``
    against a fake directory tree.

    ``Path.exists`` is patched so that only entries of the in-memory tree
    below are reported as existing. ``foo`` and ``libs/bar`` each carry a
    ``pyproject.toml``; ``noproj`` does not.
    """
    # Sub-trees of the fake repository; leaf values stand in for file contents.
    foo_project: dict[str, Any] = {
        "pyproject.toml": "content",
        "foo": {
            "__init__.py": "content",
            "file.py": "content",
            "sub": {
                "subfile.py": "content",
            },
        },
    }
    bar_project: dict[str, Any] = {
        "pyproject.toml": "content",
        "bar": {
            "__init__.py": "content",
            "thing.py": "content",
        },
    }
    noproj_tree: dict[str, Any] = {
        "some": {
            "file.py": "content",
        },
    }
    mock_tree: dict[str, Any] = {
        "home": {
            "user": {
                "root": {
                    "foo": foo_project,
                    "libs": {"bar": bar_project},
                    "noproj": noproj_tree,
                },
            },
        },
    }

    repo_root = Path("/home/user/root").resolve()
    # Three levels up from <root>: the directory that holds "home".
    fake_root: Path = repo_root.parent.parent.parent

    def mock_exists(path: PurePath) -> bool:
        # Walk the mock tree one path component at a time; a missing key
        # means the path does not exist.
        node = mock_tree
        for component in path.relative_to(fake_root).parts:
            node = node.get(component)
            if node is None:
                return False
        return True

    target = repo_root / rel_path
    with patch("pathlib.Path.exists", new=mock_exists):
        actual = calculate_module_and_package(
            repo_root, target, use_pyproject_toml=True
        )
        self.assertEqual(actual, module_and_package)

@data_provider(
(
# Providing a file outside the root should raise an exception
Expand Down
16 changes: 14 additions & 2 deletions libcst/metadata/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from pathlib import Path
from types import MappingProxyType
from typing import (
Callable,
Generic,
List,
Mapping,
MutableMapping,
Optional,
Protocol,
Type,
TYPE_CHECKING,
TypeVar,
Expand Down Expand Up @@ -40,6 +40,18 @@
MaybeLazyMetadataT = Union[LazyValue[_ProvidedMetadataT], _ProvidedMetadataT]


class GenCacheMethod(Protocol):
    """
    Callback protocol describing the call signature of a ``gen_cache``
    implementation.

    ``root_path`` and ``paths`` may be passed positionally; ``timeout`` and
    ``use_pyproject_toml`` are keyword-only and both have defaults.
    """

    def __call__(
        self,
        root_path: Path,
        paths: List[str],
        *,
        timeout: Optional[int] = None,
        use_pyproject_toml: bool = False,
    ) -> Mapping[str, object]:
        """Return a mapping with string keys and arbitrary cache objects as values."""
        ...


# We can't use an ABCMeta here, because of metaclass conflicts
class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
"""
Expand All @@ -59,7 +71,7 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
#: Implement gen_cache to indicate the metadata provider depends on cache from external
#: system. This function will be called by :class:`~libcst.metadata.FullRepoManager`
#: to compute required cache object per file path.
gen_cache: Optional[Callable[[Path, List[str], int], Mapping[str, object]]] = None
gen_cache: Optional[GenCacheMethod] = None

def __init__(self, cache: object = None) -> None:
super().__init__()
Expand Down
4 changes: 2 additions & 2 deletions libcst/metadata/file_path_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import List, Mapping, Optional
from typing import Any, List, Mapping, Optional

import libcst as cst
from libcst.metadata.base_provider import BatchableMetadataProvider
Expand Down Expand Up @@ -41,7 +41,7 @@ def visit_Module(self, node: libcst.Module) -> None:

@classmethod
def gen_cache(
cls, root_path: Path, paths: List[str], timeout: Optional[int] = None
cls, root_path: Path, paths: List[str], **kwargs: Any
) -> Mapping[str, Path]:
cache = {path: (root_path / path).resolve() for path in paths}
return cache
Expand Down
7 changes: 6 additions & 1 deletion libcst/metadata/full_repo_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
paths: Collection[str],
providers: Collection["ProviderT"],
timeout: int = 5,
use_pyproject_toml: bool = False,
) -> None:
"""
Given project root directory with pyre and watchman setup, :class:`~libcst.metadata.FullRepoManager`
Expand All @@ -38,6 +39,7 @@ def __init__(
self.root_path: Path = Path(repo_root_dir)
self._cache: Dict["ProviderT", Mapping[str, object]] = {}
self._timeout = timeout
self._use_pyproject_toml = use_pyproject_toml
self._providers = providers
self._paths: List[str] = list(paths)

Expand Down Expand Up @@ -65,7 +67,10 @@ def resolve_cache(self) -> None:
handler = provider.gen_cache
if handler:
cache[provider] = handler(
self.root_path, self._paths, self._timeout
self.root_path,
self._paths,
timeout=self._timeout,
use_pyproject_toml=self._use_pyproject_toml,
)
self._cache = cache

Expand Down
16 changes: 13 additions & 3 deletions libcst/metadata/name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import dataclasses
from pathlib import Path
from typing import Collection, List, Mapping, Optional, Union
from typing import Any, Collection, List, Mapping, Optional, Union

import libcst as cst
from libcst._metadata_dependent import LazyValue, MetadataDependent
Expand Down Expand Up @@ -112,9 +112,19 @@ class FullyQualifiedNameProvider(BatchableMetadataProvider[Collection[QualifiedN

@classmethod
def gen_cache(
cls, root_path: Path, paths: List[str], timeout: Optional[int] = None
cls,
root_path: Path,
paths: List[str],
*,
use_pyproject_toml: bool = False,
**kwargs: Any,
) -> Mapping[str, ModuleNameAndPackage]:
cache = {path: calculate_module_and_package(root_path, path) for path in paths}
cache = {
path: calculate_module_and_package(
root_path, path, use_pyproject_toml=use_pyproject_toml
)
for path in paths
}
return cache

def __init__(self, cache: ModuleNameAndPackage) -> None:
Expand Down
8 changes: 6 additions & 2 deletions libcst/metadata/tests/test_metadata_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ def test_hash_by_identity(self) -> None:
self.assertNotEqual(hash(mw1), hash(mw3))
self.assertNotEqual(hash(mw2), hash(mw3))

@staticmethod
def ignore_args(*args: object, **kwargs: object) -> tuple[object, ...]:
    """
    Accept any combination of arguments and return them as a two-item
    tuple: the positional arguments tuple, then the keyword arguments dict.
    """
    received = (args, kwargs)
    return received

def test_metadata_cache(self) -> None:
class DummyMetadataProvider(BatchableMetadataProvider[None]):
gen_cache = tuple
gen_cache = self.ignore_args

m = cst.parse_module("pass")
mw = MetadataWrapper(m)
Expand All @@ -60,7 +64,7 @@ class DummyMetadataProvider(BatchableMetadataProvider[None]):
mw.resolve(DummyMetadataProvider)

class SimpleCacheMetadataProvider(BatchableMetadataProvider[object]):
gen_cache = tuple
gen_cache = self.ignore_args

def __init__(self, cache: object) -> None:
super().__init__(cache)
Expand Down
2 changes: 1 addition & 1 deletion libcst/metadata/tests/test_name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_fully_qualified_names(file_path: str, module_str: str) -> Set[QualifiedN
cst.parse_module(dedent(module_str)),
cache={
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(
Path(""), [file_path], None
Path(""), [file_path], timeout=None
).get(file_path, "")
},
)
Expand Down
12 changes: 7 additions & 5 deletions libcst/metadata/type_inference_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import subprocess
from pathlib import Path
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TypedDict
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, TypedDict

import libcst as cst
from libcst._position import CodePosition, CodeRange
Expand Down Expand Up @@ -50,11 +50,13 @@ class TypeInferenceProvider(BatchableMetadataProvider[str]):

METADATA_DEPENDENCIES = (PositionProvider,)

@staticmethod
# pyre-fixme[40]: Static method `gen_cache` cannot override a non-static method
# defined in `cst.metadata.base_provider.BaseMetadataProvider`.
@classmethod
def gen_cache(
root_path: Path, paths: List[str], timeout: Optional[int]
cls,
root_path: Path,
paths: List[str],
timeout: Optional[int] = None,
**kwargs: Any,
) -> Mapping[str, object]:
params = ",".join(f"path='{root_path / path}'" for path in paths)
cmd_args = ["pyre", "--noninteractive", "query", f"types({params})"]
Expand Down

0 comments on commit 9f6e276

Please sign in to comment.