Skip to content

Commit

Permalink
feat: add SGBT
Browse files Browse the repository at this point in the history
  • Loading branch information
nuwangunasekara authored and hmgomes committed May 1, 2024
1 parent 1dc6234 commit 80f7007
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/capymoa/classifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ._passive_aggressive_classifier import PassiveAggressiveClassifier
from ._sgd_classifier import SGDClassifier
from ._knn import KNN
from ._sgbt import SGBT

__all__ = [
"AdaptiveRandomForest",
Expand All @@ -17,4 +18,5 @@
"KNN",
"PassiveAggressiveClassifier",
"SGDClassifier",
"SGBT"
]
94 changes: 94 additions & 0 deletions src/capymoa/classifier/_sgbt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from __future__ import annotations
from typing import Union

from capymoa.base import (
MOAClassifier,
)
from capymoa.stream import Schema
from capymoa._utils import build_cli_str_from_mapping_and_locals

from moa.classifiers.meta import StreamingGradientBoostedTrees as _MOA_SGBT


class SGBT(MOAClassifier):
    """Streaming Gradient Boosted Trees (SGBT) Classifier.

    Streaming Gradient Boosted Trees (SGBT), which is trained using weighted squared loss elicited
    in XGBoost. SGBT exploits trees with a replacement strategy to detect and recover from drifts,
    thus enabling the ensemble to adapt without sacrificing the predictive performance.

    Reference:

    `Gradient boosted trees for evolving data streams.
    Nuwan Gunasekara, Bernhard Pfahringer, Heitor Murilo Gomes, Albert Bifet.
    Machine Learning, Springer, 2024.
    <https://doi.org/10.1007/s10994-024-06517-y>`_

    Example usages:

    >>> from capymoa.datasets import ElectricityTiny
    >>> from capymoa.classifier import SGBT
    >>> from capymoa.evaluation import prequential_evaluation
    >>> stream = ElectricityTiny()
    >>> schema = stream.get_schema()
    >>> learner = SGBT(schema)
    >>> results = prequential_evaluation(stream, learner, max_instances=1000)
    >>> results["cumulative"].accuracy()
    86.3
    >>> stream = ElectricityTiny()
    >>> schema = stream.get_schema()
    >>> learner = SGBT(schema, base_learner='meta.AdaptiveRandomForestRegressor -s 10', boosting_iterations=10)
    >>> results = prequential_evaluation(stream, learner, max_instances=1000)
    >>> results["cumulative"].accuracy()
    86.8
    """

    def __init__(
        self,
        schema: Schema | None = None,
        random_seed: int = 0,
        base_learner: str = 'trees.FIMTDD -s VarianceReductionSplitCriterion -g 25 -c 0.05 -e -p',
        boosting_iterations: int = 100,
        percentage_of_features: int = 75,
        learning_rate: float = 0.0125,
        disable_one_hot: bool = False,
        multiply_hessian_by: int = 1,
        skip_training: int = 1,
        use_squared_loss: bool = False,
    ):
        """Construct a Streaming Gradient Boosted Trees (SGBT) classifier.

        :param schema: The schema of the stream.
        :param random_seed: The random seed passed to the MOA learner.
        :param base_learner: The base learner to be trained, given as a MOA CLI string.
            Default FIMTDD -s VarianceReductionSplitCriterion -g 25 -c 0.05 -e -p.
        :param boosting_iterations: The number of boosting iterations.
        :param percentage_of_features: The percentage of features to use.
        :param learning_rate: The learning rate.
        :param disable_one_hot: Whether to disable one-hot encoding for regressors that supports nominal attributes.
        :param multiply_hessian_by: The multiply hessian by this parameter to generate weights for multiple iterations.
        :param skip_training: Skip training of 1/skip_training instances. skip_training=1 means no skipping is performed (train on all instances).
        :param use_squared_loss: Whether to use squared loss for classification.
        :raises ValueError: If ``base_learner`` is not a MOA CLI string.
        """
        # Validate eagerly with an explicit raise rather than `assert`, so the
        # check still runs when Python is invoked with -O (asserts stripped).
        if not isinstance(base_learner, str):
            raise ValueError(
                "Only MOA CLI strings are supported for SGBT base_learner, at the moment."
            )

        # Maps each constructor keyword argument to the corresponding CLI flag
        # of moa.classifiers.meta.StreamingGradientBoostedTrees.
        mapping = {
            "base_learner": "-l",
            "boosting_iterations": "-s",
            "percentage_of_features": "-m",
            "learning_rate": "-L",
            "disable_one_hot": "-H",
            "multiply_hessian_by": "-M",
            "skip_training": "-S",
            "use_squared_loss": "-K",
            "random_seed": "-r",
        }

        config_str = build_cli_str_from_mapping_and_locals(mapping, locals())
        super().__init__(
            moa_learner=_MOA_SGBT,
            schema=schema,
            CLI=config_str,
            random_seed=random_seed,
        )
5 changes: 4 additions & 1 deletion tests/test_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
AdaptiveRandomForest,
OnlineBagging,
NaiveBayes,
KNN
KNN,
SGBT
)
from capymoa.base import Classifier
from capymoa.base import MOAClassifier
Expand Down Expand Up @@ -38,6 +39,7 @@
(partial(KNN), 81.6, 74.0, None),
(partial(PassiveAggressiveClassifier), 84.7, 81.0, None),
(partial(SGDClassifier), 84.7, 83.0, None),
(partial(SGBT), 88.75, 88.0, None),
],
ids=[
"OnlineBagging",
Expand All @@ -49,6 +51,7 @@
"KNN",
"PassiveAggressiveClassifier",
"SGDClassifier",
"SGBT"
],
)
def test_classifiers(
Expand Down

0 comments on commit 80f7007

Please sign in to comment.