Skip to content

Commit

Permalink
Merge pull request #34 from tachyonicClock/PassiveAggressiveClassifier
Browse files Browse the repository at this point in the history
Add `PassiveAggressiveClassifier`
  • Loading branch information
hmgomes authored Apr 13, 2024
2 parents 49ebec5 + a4c206a commit 43b0471
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
'wiki': ('https://en.wikipedia.org/wiki/%s', ''),
'moa-api': ('https://javadoc.io/doc/nz.ac.waikato.cms.moa/moa/latest/%s', ''),
'doi': ('https://doi.org/%s', ''),
'sklearn': ('https://scikit-learn.org/stable/modules/generated/sklearn.%s.html', 'sklearn.%s'),
}


Expand Down
10 changes: 9 additions & 1 deletion src/capymoa/learner/classifier/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from .classifiers import AdaptiveRandomForest, OnlineBagging, AdaptiveRandomForest
from .efdt import EFDT
from .sklearn import PassiveAggressiveClassifier
from .hoeffding_tree import HoeffdingTree

__all__ = ["AdaptiveRandomForest", "OnlineBagging", "AdaptiveRandomForest", "EFDT", "HoeffdingTree"]
__all__ = [
"AdaptiveRandomForest",
"OnlineBagging",
"AdaptiveRandomForest",
"EFDT",
"HoeffdingTree",
"PassiveAggressiveClassifier",
]
111 changes: 111 additions & 0 deletions src/capymoa/learner/classifier/sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Optional, Dict, Union, Literal
from capymoa.learner.learners import Classifier
from sklearn.linear_model import (
PassiveAggressiveClassifier as skPassiveAggressiveClassifier,
)
from capymoa.stream.instance import Instance, LabeledInstance
from capymoa.stream.stream import Schema
from capymoa.type_alias import LabelIndex, LabelProbabilities
import numpy as np


class PassiveAggressiveClassifier(Classifier):
"""Streaming Passive Aggressive Classifier
This wraps :sklearn:`linear_model.PassiveAggressiveClassifier` for
ease of use in the streaming context. Some options are missing because
they are not relevant in the streaming context.
`Online Passive-Aggressive Algorithms K. Crammer, O. Dekel, J. Keshat, S.
Shalev-Shwartz, Y. Singer - JMLR (2006)
<http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>`_
>>> from capymoa.datasets import ElectricityTiny
>>> from capymoa.learner.classifier import PassiveAggressiveClassifier
>>> from capymoa.evaluation import prequential_evaluation
>>> stream = ElectricityTiny()
>>> schema = stream.get_schema()
>>> learner = PassiveAggressiveClassifier(schema)
>>> results = prequential_evaluation(stream, learner, max_instances=1000, optimise=False)
>>> results["cumulative"].accuracy()
84.3
"""

sklearner: skPassiveAggressiveClassifier
"""The underlying scikit-learn object. See: :sklearn:`linear_model.PassiveAggressiveClassifier`"""

def __init__(
self,
schema: Schema,
max_step_size: float = 1.0,
fit_intercept: bool = True,
loss: str = "hinge",
n_jobs: Optional[int] = None,
class_weight: Union[Dict[int, float], None, Literal["balanced"]] = None,
average: bool = False,
random_seed=1,
):
"""Construct a passive aggressive classifier.
:param schema: Stream schema
:param max_step_size: Maximum step size (regularization).
:param fit_intercept: Whether the intercept should be estimated or not.
If False, the data is assumed to be already centered.
:param loss: The loss function to be used: hinge: equivalent to PA-I in
the reference paper. squared_hinge: equivalent to PA-II in the reference paper.
:param n_jobs: The number of CPUs to use to do the OVA (One Versus All,
for multi-class problems) computation. None means 1 unless in a
``joblib.parallel_backend`` context. -1 means using all processors.
:param class_weight: Preset for the ``sklearner.class_weight`` fit parameter.
Weights associated with classes. If not given, all classes are
supposed to have weight one.
The “balanced” mode uses the values of y to automatically adjust
weights inversely proportional to class frequencies in the input
data as ``n_samples / (n_classes * np.bincount(y))``.
:param average: When set to True, computes the averaged SGD weights and
stores the result in the ``sklearner.coef_`` attribute. If set to an int greater
than 1, averaging will begin once the total number of samples
seen reaches average. So ``average=10`` will begin averaging after
seeing 10 samples.
:param random_seed: Seed for the random number generator.
"""

super().__init__(schema, random_seed)

self.sklearner = skPassiveAggressiveClassifier(
C=max_step_size,
fit_intercept=fit_intercept,
early_stopping=False,
shuffle=False,
verbose=0,
loss=loss,
n_jobs=n_jobs,
warm_start=False,
class_weight=class_weight,
average=average,
random_state=random_seed,
)
self._classes = schema.get_label_indexes()
self._is_fitted = False

def __str__(self):
return str(self.sklearner)

def train(self, instance: LabeledInstance):
x = instance.x.reshape(1, -1)
y = np.array(instance.y_index).reshape(1)
self.sklearner.partial_fit(x, y, classes=self._classes)
self._is_fitted = True

def predict(self, instance: Instance) -> Optional[LabelIndex]:
if not self._is_fitted:
return None
x = instance.x.reshape(1, -1)
return self.sklearner.predict(x).item()

def predict_proba(self, instance: Instance) -> LabelProbabilities:
proba = np.zeros(len(self._classes))
proba[self.predict(instance)] = 1
return proba
15 changes: 10 additions & 5 deletions tests/test_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,25 @@
import pytest
from functools import partial

from capymoa.learner.classifier.sklearn import PassiveAggressiveClassifier


@pytest.mark.parametrize(
"learner_constructor,accuracy,win_accuracy",
[
(partial(OnlineBagging, ensemble_size=5), 84.6, 89.0),
(partial(AdaptiveRandomForest), 89.6, 91.0),
(partial(HoeffdingTree), 73.85, 73.0),
(partial(EFDT), 82.7, 82.0)
(partial(EFDT), 82.7, 82.0),
(partial(PassiveAggressiveClassifier), 84.7, 81.0),
],
ids=[
"OnlineBagging",
"AdaptiveRandomForest",
"HoeffdingTree",
"EFDT"
]
"EFDT",
"PassiveAggressiveClassifier",
],
)
def test_on_tiny(learner_constructor, accuracy, win_accuracy):
"""Test on tiny is a fast running simple test to check if a learner's
Expand All @@ -32,7 +36,9 @@ def test_on_tiny(learner_constructor, accuracy, win_accuracy):
"""
stream = ElectricityTiny()
evaluator = ClassificationEvaluator(schema=stream.get_schema())
win_evaluator = ClassificationWindowedEvaluator(schema=stream.get_schema(), window_size=100)
win_evaluator = ClassificationWindowedEvaluator(
schema=stream.get_schema(), window_size=100
)
learner = learner_constructor(schema=stream.get_schema())

while stream.has_more_instances():
Expand All @@ -44,4 +50,3 @@ def test_on_tiny(learner_constructor, accuracy, win_accuracy):

assert evaluator.accuracy() == pytest.approx(accuracy, abs=0.1)
assert win_evaluator.accuracy() == pytest.approx(win_accuracy, abs=0.1)

0 comments on commit 43b0471

Please sign in to comment.