Skip to content

Commit

Permalink
feat(Anomaly-Detection): add abstract class for anomaly detector algo…
Browse files Browse the repository at this point in the history
…rithms
  • Loading branch information
justinuliu authored and hmgomes committed May 29, 2024
1 parent e627f2d commit df74b25
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 39 deletions.
34 changes: 7 additions & 27 deletions notebooks/anomaly_detection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
"## 1. Unsupervised Anomaly Detection for data streams\n",
"\n",
"* Recent research has been focused on unsupervised anomaly detection for data streams, as it is often difficult to obtain labeled data for training.\n",
"* Instead of using evaluation functions such as **test-then-train loop**, we will show a basic loop from scratch to evaluate the model's performance."
"* Instead of using evaluation functions such as **test-then-train loop**, we will show a basic loop from scratch to evaluate the model's performance.\n",
"* Please notice that lower scores indicate higher anomaly likelihood."
],
"id": "c190066b9765a7e7"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-20T01:15:08.424694Z",
"start_time": "2024-05-20T01:15:08.189760Z"
"end_time": "2024-05-24T02:31:53.712036Z",
"start_time": "2024-05-24T02:31:51.410148Z"
}
},
"cell_type": "code",
Expand All @@ -45,8 +46,8 @@
"evaluator = AUCEvaluator(schema)\n",
"while stream.has_more_instances():\n",
" instance = stream.next_instance()\n",
" proba = learner.predict_proba(instance)\n",
" evaluator.update(instance.y_index, proba)\n",
" score = learner.score_instance(instance)\n",
" evaluator.update(instance.y_index, score)\n",
" learner.train(instance)\n",
" \n",
"auc = evaluator.auc()\n",
Expand All @@ -62,28 +63,7 @@
]
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-20T01:12:04.982687Z",
"start_time": "2024-05-20T01:12:04.980906Z"
}
},
"cell_type": "code",
"source": "",
"id": "28966d267979885f",
"outputs": [],
"execution_count": 4
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "",
"id": "5541f346dc4c42bc"
"execution_count": 1
}
],
"metadata": {
Expand Down
6 changes: 3 additions & 3 deletions src/capymoa/anomaly/half_space_trees.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from capymoa.base import (
MOAClassifier,
MOAAnomalyDetector,
)

from moa.classifiers.oneclass import HSTrees as _MOA_HSTrees


class HalfSpaceTrees(MOAClassifier):
class HalfSpaceTrees(MOAAnomalyDetector):
""" Half-Space Trees
This class implements the Half-Space Trees (HS-Trees) algorithm, which is
Expand All @@ -29,7 +29,7 @@ class HalfSpaceTrees(MOAClassifier):
>>> evaluator = AUCEvaluator(schema)
>>> while stream.has_more_instances():
... instance = stream.next_instance()
... proba = learner.predict_proba(instance)
... proba = learner.score_instance(instance)
... evaluator.update(instance.y_index, proba)
... learner.train(instance)
>>> auc = evaluator.auc()
Expand Down
76 changes: 75 additions & 1 deletion src/capymoa/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from capymoa.instance import Instance, LabeledInstance, RegressionInstance
from capymoa.stream._stream import Schema
from capymoa.type_alias import LabelIndex, LabelProbabilities, TargetValue
from capymoa.type_alias import LabelIndex, LabelProbabilities, TargetValue, AnomalyScore

from sklearn.base import ClassifierMixin as _SKClassifierMixin
from sklearn.base import RegressorMixin as _SKRegressorMixin
Expand Down Expand Up @@ -457,3 +457,77 @@ def predict(self, instance):
return [0, 0, 0]
else:
return prediction_PI


class AnomalyDetector(ABC):
    """Abstract base class for anomaly detectors.

    Concrete detectors must implement :meth:`train`, :meth:`predict` and
    :meth:`score_instance`.

    Attributes:
    - schema: The schema representing the instances. Required; ``None`` is rejected.
    - random_seed: The random seed for reproducibility. Defaults to 1.
    """

    def __init__(self, schema: Schema, random_seed=1):
        """Store the schema and random seed.

        :param schema: The schema representing the instances. Must not be ``None``.
        :param random_seed: The random seed for reproducibility.
        :raises ValueError: If ``schema`` is ``None``.
        """
        # Validate before touching any state so a failed construction
        # never leaves a half-initialised object behind.
        if schema is None:
            raise ValueError("Schema must be initialised")
        self.random_seed = random_seed
        self.schema = schema

    def __str__(self):
        """Return the name of the concrete detector class.

        ``__str__`` must return a string; returning ``None`` would make
        ``str(detector)`` raise ``TypeError``.
        """
        return type(self).__name__

    @abstractmethod
    def train(self, instance: Instance):
        """Update the detector with a single instance."""

    @abstractmethod
    def predict(self, instance: Instance) -> Optional[LabelIndex]:
        """Return the predicted label for the instance."""

    @abstractmethod
    def score_instance(self, instance: Instance) -> AnomalyScore:
        """Return the anomaly score for the instance.

        A high score is indicative of a normal instance.
        """


class MOAAnomalyDetector(AnomalyDetector):
    """Anomaly detector that delegates to a wrapped MOA learner object.

    Attributes:
    - CLI: Optional MOA command-line option string applied to the learner.
    - moa_learner: The wrapped MOA learner that does the actual work.
    """

    def __init__(self, schema=None, CLI=None, random_seed=1, moa_learner=None):
        """Configure and prepare the wrapped MOA learner.

        :param schema: The schema representing the instances. Must not be
            ``None`` (rejected by the base-class constructor).
        :param CLI: Optional MOA CLI string used to set the learner's options.
        :param random_seed: Seed passed to the learner; ``None`` leaves the
            learner's seed untouched.
        :param moa_learner: The MOA learner object to wrap. Must not be ``None``.
        :raises ValueError: If ``schema`` or ``moa_learner`` is ``None``.
        """
        super().__init__(schema=schema, random_seed=random_seed)
        # Fail with a clear message instead of an AttributeError on the
        # first method call against ``None`` below.
        if moa_learner is None:
            raise ValueError("moa_learner must be provided")
        self.CLI = CLI
        self.moa_learner = moa_learner

        if random_seed is not None:
            self.moa_learner.setRandomSeed(random_seed)

        # The base constructor guarantees that the schema is not None here.
        self.moa_learner.setModelContext(self.schema.get_moa_header())

        if self.CLI is not None:
            self.moa_learner.getOptions().setViaCLIString(CLI)

        self.moa_learner.prepareForUse()
        self.moa_learner.resetLearning()
        # NOTE(review): the context is set a second time after the reset,
        # presumably because resetting may discard it -- confirm against MOA.
        self.moa_learner.setModelContext(self.schema.get_moa_header())

    def __str__(self):
        """Return the simple (unqualified) Java class name of the learner."""
        full_name = str(self.moa_learner.getClass().getCanonicalName())
        # rpartition yields the whole string when there is no "." at all.
        return full_name.rpartition(".")[2]

    def CLI_help(self):
        """Return the help string describing the learner's options."""
        return self.moa_learner.getOptions().getHelpString()

    def train(self, instance):
        """Train the wrapped learner on a single instance."""
        self.moa_learner.trainOnInstance(instance.java_instance)

    def predict(self, instance):
        """Return the index of the largest vote the learner gives the instance."""
        votes = self.moa_learner.getVotesForInstance(instance.java_instance)
        return Utils.maxIndex(votes)

    def score_instance(self, instance):
        """Return the anomaly score for the instance.

        We assume that the anomaly score is the first element of the
        prediction array. If that is not the case for a MOA learner, this
        method should be overridden.
        """
        # NOTE(review): an empty prediction array would raise here; confirm
        # whether wrapped learners can return one (e.g. before any training).
        prediction_array = self.moa_learner.getVotesForInstance(instance.java_instance)
        return prediction_array[0]
7 changes: 4 additions & 3 deletions src/capymoa/evaluation/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing
import typing_extensions
from typing import Optional

import pandas as pd
Expand Down Expand Up @@ -377,13 +378,13 @@ def __str__(self):
def get_instances_seen(self):
return self.instances_seen

def update(self, y_target_index: int, y_pred: typing.List[float]):
def update(self, y_target_index: int, score: float):
"""Update the evaluator with the ground-truth and the prediction.
:param y_target_index: The ground-truth class index. This is NOT
the actual class value, but the index of the class value in the
schema.
:param y_pred: The predicted scores.
:param score: The predicted scores. Should be in the range [0, 1].
"""
if not isinstance(y_target_index, (np.integer, int)):
raise ValueError(
Expand All @@ -395,7 +396,7 @@ def update(self, y_target_index: int, y_pred: typing.List[float]):
self._instance.setClassValue(y_target_index)
example = InstanceExample(self._instance)

self.moa_basic_evaluator.addResult(example, y_pred)
self.moa_basic_evaluator.addResult(example, [score, 1-score])

self.instances_seen += 1

Expand Down
5 changes: 5 additions & 0 deletions src/capymoa/type_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@
"""
Alias for a dependent variable in a regression task.
"""

AnomalyScore = double
"""
Alias for the score that an anomaly detector assigns to a single instance.
"""
10 changes: 5 additions & 5 deletions tests/test_anomaly_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from capymoa.anomaly import (
HalfSpaceTrees,
)
from capymoa.base import Classifier
from capymoa.base import Classifier, AnomalyDetector
from capymoa.base import MOAClassifier
from capymoa.datasets import ElectricityTiny
import pytest
Expand All @@ -23,7 +23,7 @@
],
)
def test_anomaly_detectors(
learner_constructor: Callable[[Schema], Classifier],
learner_constructor: Callable[[Schema], AnomalyDetector],
auc: float,
cli_string: Optional[str],
):
Expand All @@ -41,12 +41,12 @@ def test_anomaly_detectors(
stream = ElectricityTiny()
evaluator = AUCEvaluator(schema=stream.get_schema())

learner: Classifier = learner_constructor(schema=stream.get_schema())
learner: AnomalyDetector = learner_constructor(schema=stream.get_schema())

while stream.has_more_instances():
instance = stream.next_instance()
proba = learner.predict_proba(instance)
evaluator.update(instance.y_index, proba)
score = learner.score_instance(instance)
evaluator.update(instance.y_index, score)
learner.train(instance)

# Check if the AUC score matches the expected value for both evaluator types
Expand Down

0 comments on commit df74b25

Please sign in to comment.