diff --git a/src/griffe/agents/extensions/base.py b/src/griffe/agents/extensions/base.py index 8af1e756..9cec5053 100644 --- a/src/griffe/agents/extensions/base.py +++ b/src/griffe/agents/extensions/base.py @@ -3,7 +3,10 @@ from __future__ import annotations import enum +import os +import sys from collections import defaultdict +from importlib.util import module_from_spec, spec_from_file_location from inspect import isclass from typing import TYPE_CHECKING, Any, Sequence, Union @@ -13,6 +16,7 @@ if TYPE_CHECKING: import ast + from types import ModuleType from griffe.agents.inspector import Inspector from griffe.agents.nodes import ObjectNode @@ -223,6 +227,17 @@ def after_inspection(self) -> list[InspectorExtension]: } +def _load_extension_path(path: str) -> ModuleType: + module_name = os.path.basename(path).rsplit(".", 1)[0] + spec = spec_from_file_location(module_name, path) + if not spec: + raise ExtensionNotLoadedError(f"Could not import module from path '{path}'") + module = module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore[union-attr] + return module + + def load_extension(extension: str | dict[str, Any] | Extension | type[Extension]) -> Extension: """Load a configured extension. @@ -238,6 +253,8 @@ def load_extension(extension: str | dict[str, Any] | Extension | type[Extension] Returns: An extension instance. """ + ext_object = None + if isinstance(extension, (VisitorExtension, InspectorExtension)): return extension @@ -253,16 +270,25 @@ def load_extension(extension: str | dict[str, Any] | Extension | type[Extension] if import_path in builtin_extensions: import_path = f"griffe.agents.extensions.{import_path}" + elif os.path.exists(import_path): + try: + ext_object = _load_extension_path(import_path) + except ImportError as error: + raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error + + if not ext_object: + try: + ext_object = dynamic_import(import_path) + except ModuleNotFoundError as error: + raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error + except ImportError as error: + raise ExtensionNotLoadedError(f"Error while importing extension '{import_path}': {error}") from error + + if isclass(ext_object) and issubclass(ext_object, (VisitorExtension, InspectorExtension)): + return ext_object(**options) # type: ignore[misc] try: - ext_module = dynamic_import(import_path) - except ModuleNotFoundError as error: - raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error - except ImportError as error: - raise ExtensionNotLoadedError(f"Error while importing extension module '{import_path}': {error}") from error - - try: - return ext_module.Extension(**options) + return ext_object.Extension(**options) # type: ignore[union-attr] except AttributeError as error: raise ExtensionNotLoadedError(f"Extension module '{import_path}' has no 'Extension' attribute") from error