Skip to content

Commit

Permalink
refactor: Always use search paths to import modules
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Jan 28, 2022
1 parent 7290642 commit a9a378f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 39 deletions.
13 changes: 10 additions & 3 deletions src/griffe/agents/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def inspect(
module_name: str,
*,
filepath: Path | None = None,
import_paths: list[Path] | None = None,
extensions: Extensions | None = None,
parent: Module | None = None,
docstring_parser: Parser | None = None,
Expand All @@ -68,6 +69,7 @@ def inspect(
Parameters:
module_name: The module name (as when importing [from] it).
filepath: The module file path.
import_paths: Paths to import the module from.
extensions: The extensions to use when inspecting the module.
parent: The optional parent of this module.
docstring_parser: The docstring parser to use. By default, no parsing is done.
Expand All @@ -77,6 +79,8 @@ def inspect(
Returns:
The module, with its members populated.
"""
if not import_paths and filepath:
import_paths = [filepath.parent]
return Inspector(
module_name,
filepath,
Expand All @@ -85,7 +89,7 @@ def inspect(
docstring_parser=docstring_parser,
docstring_options=docstring_options,
lines_collection=lines_collection,
).get_module()
).get_module(import_paths)


_compiled_modules = {*sys.builtin_module_names, "_socket", "_struct"}
Expand Down Expand Up @@ -185,18 +189,21 @@ def _get_docstring(self, node: ObjectNode) -> Docstring | None:
parser_options=self.docstring_options,
)

def get_module(self) -> Module:
def get_module(self, import_paths: list[Path] | None = None) -> Module:
"""Build and return the object representing the module attached to this inspector.
This method triggers a complete inspection of the module members.
Parameters:
import_paths: Paths replacing `sys.path` to import the module.
Returns:
A module instance.
"""
import_path = self.module_name
if self.parent is not None:
import_path = f"{self.parent.path}.{import_path}"
value = dynamic_import(import_path)
value = dynamic_import(import_path, import_paths)
top_node = ObjectNode(value, self.module_name)
self.inspect(top_node)
return self.current.module
Expand Down
51 changes: 39 additions & 12 deletions src/griffe/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,44 @@

from __future__ import annotations

import sys
from contextlib import contextmanager
from importlib import import_module
from typing import Any
from pathlib import Path
from typing import Any, Iterator


def dynamic_import(import_path: str) -> Any: # noqa: WPS231
@contextmanager
def sys_path(*paths: str | Path) -> Iterator[None]:
"""Redefine `sys.path` temporarily.
Parameters:
*paths: The paths to use when importing modules.
If no paths are given, keep `sys.path` untouched.
Yields:
Nothing.
"""
if not paths:
yield
return
old_path = sys.path
sys.path = [str(path) for path in paths]
try:
yield
finally:
sys.path = old_path


def dynamic_import(import_path: str, import_paths: list[Path] | None = None) -> Any: # noqa: WPS231
"""Dynamically import the specified object.
It can be a module, class, method, function, attribute,
nested arbitrarily.
Parameters:
import_path: The path of the object to import.
import_paths: The paths to import the object from.
Raises:
ModuleNotFoundError: When the object's module could not be found.
Expand All @@ -24,16 +50,17 @@ def dynamic_import(import_path: str) -> Any: # noqa: WPS231
module_parts: list[str] = import_path.split(".")
object_parts: list[str] = []

while True:
module_path = ".".join(module_parts)
try: # noqa: WPS503 (false-positive)
module = import_module(module_path)
except ModuleNotFoundError:
if len(module_parts) == 1:
raise
object_parts.insert(0, module_parts.pop(-1))
else:
break
with sys_path(*(import_paths or ())):
while True:
module_path = ".".join(module_parts)
try: # noqa: WPS503 (false-positive)
module = import_module(module_path)
except ModuleNotFoundError:
if len(module_parts) == 1:
raise
object_parts.insert(0, module_parts.pop(-1))
else:
break

value = module
for part in object_parts:
Expand Down
1 change: 1 addition & 0 deletions src/griffe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def _inspect_module(self, module_name: str, filepath: Path | None = None, parent
module = inspect(
module_name,
filepath=filepath,
import_paths=self.finder.search_paths,
extensions=self.extensions,
parent=parent,
docstring_parser=self.docstring_parser,
Expand Down
29 changes: 5 additions & 24 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,6 @@
TMPDIR_PREFIX = "griffe_"


@contextmanager
def sys_path(directory: str | Path) -> Iterator[None]:
"""Insert a directory in front of `sys.path` then remove it.
Parameters:
directory: The directory to insert.
Yields:
Nothing.
"""
old_path = sys.path
sys.path.insert(0, str(directory))
try:
yield
finally:
sys.path = old_path


@contextmanager
def temporary_pyfile(code: str) -> Iterator[tuple[str, Path]]:
"""Create a module.py file containing the given code in a temporary directory.
Expand Down Expand Up @@ -115,12 +97,11 @@ def temporary_inspected_module(code: str) -> Iterator[Module]:
The inspected module.
"""
with temporary_pyfile(dedent(code)) as (name, path):
with sys_path(path.parent):
try:
yield inspect(name, filepath=path)
finally:
del sys.modules["module"] # noqa: WPS420
invalidate_caches()
try:
yield inspect(name, filepath=path, import_paths=[path.parent])
finally:
del sys.modules["module"] # noqa: WPS420
invalidate_caches()


def vtree(*objects: Object, return_leaf: bool = False) -> Object:
Expand Down

0 comments on commit a9a378f

Please sign in to comment.