-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1dc6234
commit 80f7007
Showing
3 changed files
with
100 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from __future__ import annotations | ||
from typing import Union | ||
|
||
from capymoa.base import ( | ||
MOAClassifier, | ||
) | ||
from capymoa.stream import Schema | ||
from capymoa._utils import build_cli_str_from_mapping_and_locals | ||
|
||
from moa.classifiers.meta import StreamingGradientBoostedTrees as _MOA_SGBT | ||
|
||
|
||
class SGBT(MOAClassifier): | ||
"""Streaming Gradient Boosted Trees (SGBT) Classifier | ||
Streaming Gradient Boosted Trees (SGBT), which is trained using weighted squared loss elicited | ||
in XGBoost. SGBT exploits trees with a replacement strategy to detect and recover from drifts, | ||
thus enabling the ensemble to adapt without sacrificing the predictive performance. | ||
Reference: | ||
`Gradient boosted trees for evolving data streams. | ||
Nuwan Gunasekara, Bernhard Pfahringer, Heitor Murilo Gomes, Albert Bifet. | ||
Machine Learning, Springer, 2024. | ||
<https://doi.org/10.1007/s10994-024-06517-y>`_ | ||
Example usages: | ||
>>> from capymoa.datasets import ElectricityTiny | ||
>>> from capymoa.classifier import SGBT | ||
>>> from capymoa.evaluation import prequential_evaluation | ||
>>> stream = ElectricityTiny() | ||
>>> schema = stream.get_schema() | ||
>>> learner = SGBT(schema) | ||
>>> results = prequential_evaluation(stream, learner, max_instances=1000) | ||
>>> results["cumulative"].accuracy() | ||
86.3 | ||
>>> stream = ElectricityTiny() | ||
>>> schema = stream.get_schema() | ||
>>> learner = SGBT(schema, base_learner='meta.AdaptiveRandomForestRegressor -s 10', boosting_iterations=10) | ||
>>> results = prequential_evaluation(stream, learner, max_instances=1000) | ||
>>> results["cumulative"].accuracy() | ||
86.8 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
schema: Schema | None = None, | ||
random_seed: int = 0, | ||
base_learner = 'trees.FIMTDD -s VarianceReductionSplitCriterion -g 25 -c 0.05 -e -p', | ||
boosting_iterations: int = 100, | ||
percentage_of_features: int = 75, | ||
learning_rate = 0.0125, | ||
disable_one_hot: bool = False, | ||
multiply_hessian_by: int = 1, | ||
skip_training: int =1, | ||
use_squared_loss: bool = False, | ||
): | ||
"""Streaming Gradient Boosted Trees (SGBT) Classifier | ||
:param schema: The schema of the stream. | ||
:param random_seed: The random seed passed to the MOA learner. | ||
:param base_learner: The base learner to be trained. Default FIMTDD -s VarianceReductionSplitCriterion -g 25 -c 0.05 -e -p. | ||
:param boosting_iterations: The number of boosting iterations. | ||
:param percentage_of_features: The percentage of features to use. | ||
:param learning_rate: The learning rate. | ||
:param disable_one_hot: Whether to disable one-hot encoding for regressors that supports nominal attributes. | ||
:param multiply_hessian_by: The multiply hessian by this parameter to generate weights for multiple iterations. | ||
:param skip_training: Skip training of 1/skip_training instances. skip_training=1 means no skipping is performed (train on all instances). | ||
:param use_squared_loss: Whether to use squared loss for classification. | ||
""" | ||
|
||
mapping = { | ||
"base_learner": "-l", | ||
"boosting_iterations": "-s", | ||
"percentage_of_features": "-m", | ||
"learning_rate": "-L", | ||
"disable_one_hot": "-H", | ||
"multiply_hessian_by": "-M", | ||
"skip_training": "-S", | ||
"use_squared_loss": "-K", | ||
"random_seed": "-r", | ||
} | ||
|
||
assert (type(base_learner) == str | ||
), "Only MOA CLI strings are supported for SGBT base_learner, at the moment." | ||
|
||
config_str = build_cli_str_from_mapping_and_locals(mapping, locals()) | ||
super(SGBT, self).__init__( | ||
moa_learner=_MOA_SGBT, | ||
schema=schema, | ||
CLI=config_str, | ||
random_seed=random_seed, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters