Skip to content

Commit

Permalink
feat: Allow loading extension from file path
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Feb 25, 2023
1 parent e8ad889 commit 131454e
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions src/griffe/agents/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from __future__ import annotations

import enum
import os
import sys
from collections import defaultdict
from importlib.util import module_from_spec, spec_from_file_location
from inspect import isclass
from typing import TYPE_CHECKING, Any, Sequence, Union

Expand All @@ -13,6 +16,7 @@

if TYPE_CHECKING:
import ast
from types import ModuleType

from griffe.agents.inspector import Inspector
from griffe.agents.nodes import ObjectNode
Expand Down Expand Up @@ -223,6 +227,17 @@ def after_inspection(self) -> list[InspectorExtension]:
}


def _load_extension_path(path: str) -> ModuleType:
module_name = os.path.basename(path).rsplit(".", 1)[0]
spec = spec_from_file_location(module_name, path)
if not spec:
raise ExtensionNotLoadedError(f"Could not import module from path '{path}'")
module = module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module) # type: ignore[union-attr]
return module


def load_extension(extension: str | dict[str, Any] | Extension | type[Extension]) -> Extension:
"""Load a configured extension.
Expand All @@ -238,6 +253,8 @@ def load_extension(extension: str | dict[str, Any] | Extension | type[Extension]
Returns:
An extension instance.
"""
ext_object = None

if isinstance(extension, (VisitorExtension, InspectorExtension)):
return extension

Expand All @@ -253,16 +270,25 @@ def load_extension(extension: str | dict[str, Any] | Extension | type[Extension]

if import_path in builtin_extensions:
import_path = f"griffe.agents.extensions.{import_path}"
elif os.path.exists(import_path):
try:
ext_object = _load_extension_path(import_path)
except ImportError as error:
raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error

if not ext_object:
try:
ext_object = dynamic_import(import_path)
except ModuleNotFoundError as error:
raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error
except ImportError as error:
raise ExtensionNotLoadedError(f"Error while importing extension '{import_path}': {error}") from error

if isclass(ext_object) and issubclass(ext_object, (VisitorExtension, InspectorExtension)):
return ext_object(**options) # type: ignore[misc]

try:
ext_module = dynamic_import(import_path)
except ModuleNotFoundError as error:
raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error
except ImportError as error:
raise ExtensionNotLoadedError(f"Error while importing extension module '{import_path}': {error}") from error

try:
return ext_module.Extension(**options)
return ext_object.Extension(**options) # type: ignore[union-attr]
except AttributeError as error:
raise ExtensionNotLoadedError(f"Extension module '{import_path}' has no 'Extension' attribute") from error

Expand Down

0 comments on commit 131454e

Please sign in to comment.