Skip to content

Commit

Permalink
Merge pull request #134 from mostafa/apply-backend-naming-convention
Browse files Browse the repository at this point in the history
Apply backend naming convention
  • Loading branch information
thomaspatzke authored Jul 14, 2023
2 parents 3790aaa + 9f9c5e2 commit 8eeef06
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 18 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ cov.xml
dist/
docs/_build
coverage.xml
.python-version
96 changes: 89 additions & 7 deletions sigma/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import importlib.metadata
import inspect
import pkgutil
import re
import subprocess
import sys
from typing import Callable, Dict, Any, List, Optional, Set, Union, get_type_hints
from uuid import UUID
import requests
from packaging.version import Version
from packaging.specifiers import Specifier
import warnings

from sigma.conversion.base import Backend
from sigma.pipelines.base import Pipeline
Expand Down Expand Up @@ -55,6 +57,7 @@ def _discover_module_directories(
result = dict()

def is_pipeline(obj):
"""Checks if an object is a pipeline."""
return any(
[
inspect.isclass(obj) and issubclass(obj, Pipeline),
Expand All @@ -64,6 +67,21 @@ def is_pipeline(obj):
]
)

def is_validator(obj):
"""Checks if an object is a validator."""
return (
inspect.isclass(obj)
and issubclass(obj, SigmaRuleValidator)
and obj.__module__ != "sigma.validators.base"
)

def is_backend(obj):
"""Checks if an object is a backend."""
return inspect.isclass(obj) and issubclass(obj, Backend)

def is_duplicate(container, klass, name):
return name in container and container[name] != klass

if include:
for mod in pkgutil.iter_modules(module.__path__, module.__name__ + "."):
# attempt to merge backend directory from module into collected backend directory
Expand Down Expand Up @@ -123,6 +141,7 @@ def is_pipeline(obj):

# OR'd condition ensures backwards compatibility with older plugins
if is_pipeline(possible_obj) or inspect.isfunction(possible_obj):
# Instantiate the pipeline if it is a class.
if inspect.isclass(possible_obj) and issubclass(
possible_obj, Pipeline
):
Expand All @@ -131,18 +150,25 @@ def is_pipeline(obj):
result[obj_name] = possible_obj
elif directory_name == "validators":
for cls_name in submodules:
if (
inspect.isclass(submodules[cls_name])
and issubclass(submodules[cls_name], SigmaRuleValidator)
and submodules[cls_name].__module__ != "sigma.validators.base"
):
if is_validator(submodules[cls_name]):
result[cls_name] = submodules[cls_name]
elif directory_name == "backends":
# Backends reside on the module level
for cls_name in imported_module.__dict__:
klass = getattr(imported_module, cls_name)
if inspect.isclass(klass) and issubclass(klass, Backend):
result.update({cls_name: klass})
identifier = InstalledSigmaPlugins._get_backend_identifier(
klass, cls_name
)
if is_backend(klass):
if is_duplicate(result, klass, identifier):
# If there is a duplicate, use the class name instead.
# This prevents the backend from being overwritten.
warnings.warn(
f"The '{klass.__name__}' wanted to overwrite the class '{result[identifier].__name__}' registered as '{identifier}'. Consider setting the 'identifier' attribute on the '{result[identifier].__name__}'. Ignoring the '{klass.__name__}'.",
)
else:
# Ignore duplicate backends.
result.update({identifier: klass})
else:
raise ValueError(
f"Unknown directory name {directory_name} for module {mod.name}"
Expand Down Expand Up @@ -180,6 +206,62 @@ def get_pipeline_resolver(self) -> ProcessingPipelineResolver:
}
)

@staticmethod
def _get_backend_identifier(obj: Any, default: str) -> Optional[str]:
"""
Get the identifier of a backend object. This is either the identifier attribute of
the object, the __identifier__ attribute of the object, or the __class__ attribute
of the object. The identifier is then converted to snake_case. If the identifier is
empty, the default is returned.
Args:
obj: The Backend object to get the identifier from.
default: The default identifier to return if no identifier could be found.
Returns:
The identifier of the backend object in snake_case or the default identifier.
"""

def removesuffix(base: str, suffix: str) -> str:
"""Removes the suffix from the string if it exists.
This is a backport of the Python 3.9 removesuffix method.
"""
if base.endswith(suffix):
return base[: len(base) - len(suffix)]
return base

try:
# 1. Try to get the obj.identifier attribute.
identifier = getattr(obj, "identifier", None)

# 2. Try to get the obj.__identifier__ attribute.
if not identifier:
identifier = getattr(obj, "__identifier__", None)

# 3. Try to get the obj.__name__ attribute.
if not identifier:
identifier = getattr(obj, "__name__", None)

# 4. Convert the name to snake_case.
if identifier:
identifier = removesuffix(identifier, "Backend")
identifier = removesuffix(identifier, "backend")
identifier = removesuffix(identifier, "_")
words = re.findall(r"[A-Z](?:[A-Z]*(?![a-z])|[a-z]*)", identifier)
if len(words) == 0:
return identifier.lower()
rebuilt_identifier = "_".join(words).lower()
# 5. If we still have the "base" backend, return the module identifier instead.
if rebuilt_identifier == "base":
return obj.__module__.split(".")[-1].lower()
return rebuilt_identifier
else:
# 6. If we still don't have an identifier, return the default.
return default
except Exception:
# 7. If anything goes wrong, return the default.
return default


class SigmaPluginType(EnumLowercaseStringMixin, Enum):
BACKEND = auto()
Expand Down
79 changes: 79 additions & 0 deletions tests/test_backend_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest

from sigma.plugins import InstalledSigmaPlugins
from sigma.conversion.base import TextQueryBackend
from sigma.backends.test import TextQueryTestBackend, MandatoryPipelineTestBackend


class DummyBackend(TextQueryBackend):
"""Dummy backend for testing purposes."""

identifier = "dummy"


class DummyTestBackend(DummyBackend):
"""Dummy backend for testing purposes."""

identifier = "dummy_test"


class Dummy2TestBackend(DummyBackend):
"""Dummy backend for testing purposes."""

# This won't be used, because the identifier is already set by DummyTestBackend.
__identifier__ = "dummy2_test"


class DummyDunderIdentifierBackend(TextQueryBackend):
"""Dummy backend for testing purposes."""

__identifier__ = "dummy_dunder_identifier"


class AnotherDummyTestBackend(TextQueryBackend):
"""Dummy backend for testing purposes."""

pass


class something_something_backend(TextQueryBackend):
"""Dummy backend for testing purposes."""

pass


class BackendBackend(TextQueryBackend):
"""Dummy backend for testing purposes."""

pass


class BaseBackend(TextQueryBackend):
"""Dummy backend for testing purposes."""

pass


@pytest.mark.parametrize(
"backend_class, expected_backend_identifier",
[
(None, ""),
(TextQueryBackend, "text_query"),
(DummyBackend, "dummy"),
(DummyTestBackend, "dummy_test"),
(Dummy2TestBackend, "dummy"), # Dummy2TestBackend.__identifier__ won't be used.
(DummyDunderIdentifierBackend, "dummy_dunder_identifier"), # __identifier__ is used.
(AnotherDummyTestBackend, "another_dummy_test"), # identifier is generated from __name__.
(something_something_backend, "something_something"),
(BackendBackend, "backend"),
(BaseBackend, "test_backend_identifier"), # test file is the module name.
(TextQueryTestBackend, "text_query_test"),
(MandatoryPipelineTestBackend, "mandatory_pipeline_test"),
],
)
def test_get_backend_identifier(backend_class, expected_backend_identifier):
"""Test that the backend identifier is correctly returned."""
assert (
InstalledSigmaPlugins._get_backend_identifier(backend_class, "")
== expected_backend_identifier
)
16 changes: 5 additions & 11 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def test_autodiscover_backends():
plugins = InstalledSigmaPlugins.autodiscover(include_pipelines=False, include_validators=False)
assert plugins == InstalledSigmaPlugins(
backends={
"TextQueryTestBackend": TextQueryTestBackend,
"MandatoryPipelineTestBackend": MandatoryPipelineTestBackend,
"text_query_test": TextQueryTestBackend,
"mandatory_pipeline_test": MandatoryPipelineTestBackend,
},
pipelines=dict(),
validators=dict(),
Expand Down Expand Up @@ -204,9 +204,7 @@ def test_sigma_plugin_directory_get_by_uuid_str(plugin_directory: SigmaPluginDir
)


def test_sigma_plugin_directory_get_by_uuid_not_found(
plugin_directory: SigmaPluginDirectory,
):
def test_sigma_plugin_directory_get_by_uuid_not_found(plugin_directory: SigmaPluginDirectory):
with pytest.raises(SigmaPluginNotFoundError, match="Plugin with UUID.*not found"):
plugin_directory.get_plugin_by_uuid("6029969b-4e6b-4060-bb0d-464d476065e0")

Expand All @@ -217,9 +215,7 @@ def test_sigma_plugin_directory_get_by_id(plugin_directory: SigmaPluginDirectory
)


def test_sigma_plugin_directory_get_by_id_not_found(
plugin_directory: SigmaPluginDirectory,
):
def test_sigma_plugin_directory_get_by_id_not_found(plugin_directory: SigmaPluginDirectory):
with pytest.raises(SigmaPluginNotFoundError, match="Plugin with identifier.*not found"):
plugin_directory.get_plugin_by_id("not_existing")

Expand All @@ -228,9 +224,7 @@ def test_sigma_plugin_directory_get_plugins(plugin_directory: SigmaPluginDirecto
assert plugin_directory.get_plugins() == list(plugin_directory.plugins.values())


def test_sigma_plugin_directory_get_plugins_filtered(
plugin_directory: SigmaPluginDirectory,
):
def test_sigma_plugin_directory_get_plugins_filtered(plugin_directory: SigmaPluginDirectory):
plugins = plugin_directory.get_plugins(
plugin_types={SigmaPluginType.BACKEND}, plugin_states={SigmaPluginState.TESTING}
)
Expand Down

0 comments on commit 8eeef06

Please sign in to comment.