Skip to content

Commit

Permalink
Add support for loading pipelines if decorated by or inherited from P…
Browse files Browse the repository at this point in the history
…ipeline class while maintaining backwards compatibility
  • Loading branch information
mostafa committed Jun 29, 2023
1 parent cd9fe30 commit f740c62
Showing 1 changed file with 64 additions and 14 deletions.
78 changes: 64 additions & 14 deletions sigma/plugins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import builtins
from dataclasses import dataclass, field
from enum import Enum, auto
from importlib import import_module
import importlib
import importlib.metadata
import inspect
Expand All @@ -12,9 +12,9 @@
import requests
from packaging.version import Version
from packaging.specifiers import Specifier
import warnings

from sigma.conversion.base import Backend
from sigma.pipelines.base import Pipeline
from sigma.processing.pipeline import ProcessingPipeline
from sigma.processing.resolver import ProcessingPipelineResolver
from sigma.rule import EnumLowercaseStringMixin
Expand Down Expand Up @@ -53,26 +53,76 @@ def _discover_module_directories(
cls, module, directory_name: str, include: bool
) -> Dict[str, Any]:
result = dict()

def is_pipeline(obj):
return any(
[
issubclass(obj.__class__, Pipeline),
isinstance(obj, Pipeline),
inspect.isfunction(obj)
and get_type_hints(obj).get("return") == ProcessingPipeline,
]
)

if include:
for mod in pkgutil.iter_modules(module.__path__, module.__name__ + "."):
# attempt to merge backend directory from module into collected backend directory
try:
imported_module = importlib.import_module(mod.name)
submodules = imported_module.__dict__[directory_name]
submodules: Dict[str, Any] = {}

# Skip base, common and test pipelines
if imported_module.__name__ in [
"sigma.pipelines.base",
"sigma.pipelines.common",
"sigma.pipelines.test",
]:
continue

# Add exported objects to submodules
# This is to ensure backwards compatibility with older plugins
# that do not use __all__ to export their objects, but instead
# rely on gloal variables that map function/class names to objects
# The global variable name is the "directory_name" in this case,
# which is either "backends", "pipelines" or "validators".
if directory_name in imported_module.__dict__:
submodules.update(imported_module.__dict__[directory_name])

# Look for __all__ at the root (__init__) and
# add all objects that are in __all__ :D
if "__all__" in imported_module.__dict__:
submodules.update(
{
k: v
for k, v in imported_module.__dict__.items()
if k in imported_module.__dict__["__all__"]
and k not in builtins.__dict__
and v not in submodules.values()
}
)
# There is no __all__, so add all objects that are not private, not in builtins,
# and not already in submodules (to avoid duplicates)
else:
submodules.update(
{
k: v
for k, v in imported_module.__dict__.items()
if not k.startswith("_")
and k not in builtins.__dict__
and v not in submodules.values()
}
)

# Pipelines and validators reside in submodules
if directory_name == "pipelines":
for obj_name in submodules:
possible_func_obj = submodules[obj_name]
if inspect.isfunction(possible_func_obj):
if (
not get_type_hints(possible_func_obj).get("return")
== ProcessingPipeline
):
# TODO: This should be a hard error in the future
warnings.warn(
f"Function {mod.name}.{obj_name} does not have a return type hint of ProcessingPipeline."
)
result[obj_name] = possible_func_obj
possible_obj = submodules[obj_name]

# OR'd condition ensures backwards compatibility with older plugins
if is_pipeline(possible_obj) or inspect.isfunction(
possible_obj
):
result[obj_name] = possible_obj
elif directory_name == "validators":
for cls_name in submodules:
if inspect.isclass(submodules[cls_name]) and issubclass(
Expand Down

0 comments on commit f740c62

Please sign in to comment.