Skip to content

Commit

Permalink
feat(EFDT): leaf_prediction as str
Browse files Browse the repository at this point in the history
Users can still use leaf_prediction as an integer (0, 1 or 2), but it can also be used as a string:
"MajorityClass": 0, "NaiveBayes": 1, "NaiveBayesAdaptive": 2
  • Loading branch information
hmgomes committed May 1, 2024
1 parent 2e55155 commit 6454179
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/capymoa/classifier/_efdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ class EFDT(MOAClassifier):
84.39999999999999
"""

MAJORITY_CLASS = 0
NAIVE_BAYES = 1
NAIVE_BAYES_ADAPTIVE = 2

def __init__(
self,
schema: Schema | None = None,
Expand All @@ -57,7 +53,7 @@ def __init__(
split_criterion: Union[str, SplitCriterion] = "InfoGainSplitCriterion",
confidence: float = 1e-3,
tie_threshold: float = 0.05,
leaf_prediction: str = MAJORITY_CLASS,
leaf_prediction: str = "MajorityClass",
nb_threshold: int = 0,
numeric_attribute_observer: str = "GaussianNumericAttributeClassObserver",
binary_split: bool = False,
Expand All @@ -77,8 +73,8 @@ def __init__(
:param confidence: Significance level to calculate the Hoeffding bound. The significance level is given by
`1 - delta`. Values closer to zero imply longer split decision delays.
:param tie_threshold: Threshold below which a split will be forced to break ties.
:param leaf_prediction: Prediction mechanism used at leafs
(0: Majority Class, 1: Naive Bayes, 2: Naive Bayes Adaptive).
:param leaf_prediction: Prediction mechanism used at the leaves
("MajorityClass" or 0, "NaiveBayes" or 1, "NaiveBayesAdaptive" or 2).
:param nb_threshold: Number of instances a leaf should observe before allowing Naive Bayes.
:param numeric_attribute_observer: The Splitter or Attribute Observer (AO) used to monitor the class statistics
of numeric features and perform splits.
Expand Down Expand Up @@ -106,6 +102,9 @@ def __init__(
"remove_poor_attrs": "-r",
"disable_prepruning": "-p",
}
if isinstance(leaf_prediction, str):
leaf_prediction_mapping = {"MajorityClass": 0, "NaiveBayes": 1, "NaiveBayesAdaptive": 2}
leaf_prediction = leaf_prediction_mapping.get(leaf_prediction, None)

split_criterion = _split_criterion_to_cli_str(split_criterion)
config_str = build_cli_str_from_mapping_and_locals(mapping, locals())
Expand Down

0 comments on commit 6454179

Please sign in to comment.