From df74b25c415279dba39c42d760422b97c4512460 Mon Sep 17 00:00:00 2001 From: Justin Liu Date: Fri, 24 May 2024 10:42:31 +1200 Subject: [PATCH] feat(Anomaly-Detection): add abstract class for anomaly detector algorithms --- notebooks/anomaly_detection.ipynb | 34 +++-------- src/capymoa/anomaly/half_space_trees.py | 6 +- src/capymoa/base.py | 76 ++++++++++++++++++++++++- src/capymoa/evaluation/evaluation.py | 7 ++- src/capymoa/type_alias.py | 5 ++ tests/test_anomaly_detectors.py | 10 ++-- 6 files changed, 99 insertions(+), 39 deletions(-) diff --git a/notebooks/anomaly_detection.ipynb b/notebooks/anomaly_detection.ipynb index 92e13759..d63fca15 100644 --- a/notebooks/anomaly_detection.ipynb +++ b/notebooks/anomaly_detection.ipynb @@ -23,15 +23,16 @@ "## 1. Unsupervised Anomaly Detection for data streams\n", "\n", "* Recent research has been focused on unsupervised anomaly detection for data streams, as it is often difficult to obtain labeled data for training.\n", - "* Instead of using evaluation functions such as **test-then-train loop**, we will show a basic loop from scratch to evaluate the model's performance." + "* Instead of using evaluation functions such as **test-then-train loop**, we will show a basic loop from scratch to evaluate the model's performance.\n", + "* Please notice that lower scores indicate higher anomaly likelihood." 
], "id": "c190066b9765a7e7" }, { "metadata": { "ExecuteTime": { - "end_time": "2024-05-20T01:15:08.424694Z", - "start_time": "2024-05-20T01:15:08.189760Z" + "end_time": "2024-05-24T02:31:53.712036Z", + "start_time": "2024-05-24T02:31:51.410148Z" } }, "cell_type": "code", @@ -45,8 +46,8 @@ "evaluator = AUCEvaluator(schema)\n", "while stream.has_more_instances():\n", " instance = stream.next_instance()\n", - " proba = learner.predict_proba(instance)\n", - " evaluator.update(instance.y_index, proba)\n", + " score = learner.score_instance(instance)\n", + " evaluator.update(instance.y_index, score)\n", " learner.train(instance)\n", " \n", "auc = evaluator.auc()\n", @@ -62,28 +63,7 @@ ] } ], - "execution_count": 2 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-05-20T01:12:04.982687Z", - "start_time": "2024-05-20T01:12:04.980906Z" - } - }, - "cell_type": "code", - "source": "", - "id": "28966d267979885f", - "outputs": [], - "execution_count": 4 - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "", - "id": "5541f346dc4c42bc" + "execution_count": 1 } ], "metadata": { diff --git a/src/capymoa/anomaly/half_space_trees.py b/src/capymoa/anomaly/half_space_trees.py index 531c02cf..b3b827d4 100644 --- a/src/capymoa/anomaly/half_space_trees.py +++ b/src/capymoa/anomaly/half_space_trees.py @@ -1,11 +1,11 @@ from capymoa.base import ( - MOAClassifier, + MOAAnomalyDetector, ) from moa.classifiers.oneclass import HSTrees as _MOA_HSTrees -class HalfSpaceTrees(MOAClassifier): +class HalfSpaceTrees(MOAAnomalyDetector): """ Half-Space Trees This class implements the Half-Space Trees (HS-Trees) algorithm, which is @@ -29,7 +29,7 @@ class HalfSpaceTrees(MOAClassifier): >>> evaluator = AUCEvaluator(schema) >>> while stream.has_more_instances(): ... instance = stream.next_instance() - ... proba = learner.predict_proba(instance) + ... proba = learner.score_instance(instance) ... 
evaluator.update(instance.y_index, proba) ... learner.train(instance) >>> auc = evaluator.auc() diff --git a/src/capymoa/base.py b/src/capymoa/base.py index 80fd975e..46f8705e 100644 --- a/src/capymoa/base.py +++ b/src/capymoa/base.py @@ -15,7 +15,7 @@ from capymoa.instance import Instance, LabeledInstance, RegressionInstance from capymoa.stream._stream import Schema -from capymoa.type_alias import LabelIndex, LabelProbabilities, TargetValue +from capymoa.type_alias import LabelIndex, LabelProbabilities, TargetValue, AnomalyScore from sklearn.base import ClassifierMixin as _SKClassifierMixin from sklearn.base import RegressorMixin as _SKRegressorMixin @@ -457,3 +457,77 @@ def predict(self, instance): return [0, 0, 0] else: return prediction_PI + + +class AnomalyDetector(ABC): + """ + Abstract base class for anomaly detectors. + + Attributes: + - schema: The schema representing the instances. Required; must not be None. + - random_seed: The random seed for reproducibility. Defaults to 1. + """ + + def __init__(self, schema: Schema, random_seed=1): + self.random_seed = random_seed + self.schema = schema + if self.schema is None: + raise ValueError("Schema must be initialised") + + def __str__(self): + return type(self).__name__ + + @abstractmethod + def train(self, instance: Instance): + pass + + @abstractmethod + def predict(self, instance: Instance) -> Optional[LabelIndex]: + # Returns the predicted label for the instance. + pass + + @abstractmethod + def score_instance(self, instance: Instance) -> AnomalyScore: + # Returns the anomaly score for the instance. A high score is indicative of a normal instance. 
+ pass + + +class MOAAnomalyDetector(AnomalyDetector): + def __init__(self, schema=None, CLI=None, random_seed=1, moa_learner=None): + super().__init__(schema=schema, random_seed=random_seed) + self.CLI = CLI + self.moa_learner = moa_learner + + if self.moa_learner is None: + raise ValueError("moa_learner must be provided") + if random_seed is not None: + self.moa_learner.setRandomSeed(random_seed) + self.moa_learner.setModelContext(self.schema.get_moa_header()) + + if self.CLI is not None: + self.moa_learner.getOptions().setViaCLIString(CLI) + + self.moa_learner.prepareForUse() + self.moa_learner.resetLearning() + self.moa_learner.setModelContext(self.schema.get_moa_header()) + + def __str__(self): + full_name = str(self.moa_learner.getClass().getCanonicalName()) + return full_name.rsplit(".", 1)[1] if "." in full_name else full_name + + def CLI_help(self): + return self.moa_learner.getOptions().getHelpString() + + def train(self, instance): + self.moa_learner.trainOnInstance(instance.java_instance) + + def predict(self, instance): + return Utils.maxIndex( + self.moa_learner.getVotesForInstance(instance.java_instance) + ) + + def score_instance(self, instance): + # We assume that the anomaly score is the first element of the prediction array. + # However, if it is not the case for a MOA learner, this method should be overridden. 
+ prediction_array = self.moa_learner.getVotesForInstance(instance.java_instance) + return prediction_array[0] diff --git a/src/capymoa/evaluation/evaluation.py b/src/capymoa/evaluation/evaluation.py index f17bb779..d0c95553 100644 --- a/src/capymoa/evaluation/evaluation.py +++ b/src/capymoa/evaluation/evaluation.py @@ -1,4 +1,5 @@ import typing +import typing_extensions from typing import Optional import pandas as pd @@ -377,13 +378,13 @@ def __str__(self): def get_instances_seen(self): return self.instances_seen - def update(self, y_target_index: int, y_pred: typing.List[float]): + def update(self, y_target_index: int, score: float): """Update the evaluator with the ground-truth and the prediction. :param y_target_index: The ground-truth class index. This is NOT the actual class value, but the index of the class value in the schema. - :param y_pred: The predicted scores. + :param score: The predicted anomaly score. Should be in the range [0, 1]. """ if not isinstance(y_target_index, (np.integer, int)): raise ValueError( @@ -395,7 +396,7 @@ def update(self, y_target_index: int, y_pred: typing.List[float]): self._instance.setClassValue(y_target_index) example = InstanceExample(self._instance) - self.moa_basic_evaluator.addResult(example, y_pred) + self.moa_basic_evaluator.addResult(example, [score, 1 - score]) self.instances_seen += 1 diff --git a/src/capymoa/type_alias.py b/src/capymoa/type_alias.py index 37a02b92..24954767 100644 --- a/src/capymoa/type_alias.py +++ b/src/capymoa/type_alias.py @@ -28,3 +28,8 @@ """ Alias for a dependent variable in a regression task. """ + +AnomalyScore = float +""" +Alias for the anomaly score an anomaly detector assigns to an instance. 
+""" \ No newline at end of file diff --git a/tests/test_anomaly_detectors.py b/tests/test_anomaly_detectors.py index 61157329..403472eb 100644 --- a/tests/test_anomaly_detectors.py +++ b/tests/test_anomaly_detectors.py @@ -2,7 +2,7 @@ from capymoa.anomaly import ( HalfSpaceTrees, ) -from capymoa.base import Classifier +from capymoa.base import Classifier, AnomalyDetector from capymoa.base import MOAClassifier from capymoa.datasets import ElectricityTiny import pytest @@ -23,7 +23,7 @@ ], ) def test_anomaly_detectors( - learner_constructor: Callable[[Schema], Classifier], + learner_constructor: Callable[[Schema], AnomalyDetector], auc: float, cli_string: Optional[str], ): @@ -41,12 +41,12 @@ def test_anomaly_detectors( stream = ElectricityTiny() evaluator = AUCEvaluator(schema=stream.get_schema()) - learner: Classifier = learner_constructor(schema=stream.get_schema()) + learner: AnomalyDetector = learner_constructor(schema=stream.get_schema()) while stream.has_more_instances(): instance = stream.next_instance() - proba = learner.predict_proba(instance) - evaluator.update(instance.y_index, proba) + score = learner.score_instance(instance) + evaluator.update(instance.y_index, score) learner.train(instance) # Check if the AUC score matches the expected value for both evaluator types