diff --git a/src/capymoa/classifier/_efdt.py b/src/capymoa/classifier/_efdt.py index 3150a0f9..a049e8a6 100644 --- a/src/capymoa/classifier/_efdt.py +++ b/src/capymoa/classifier/_efdt.py @@ -44,10 +44,6 @@ class EFDT(MOAClassifier): 84.39999999999999 """ - MAJORITY_CLASS = 0 - NAIVE_BAYES = 1 - NAIVE_BAYES_ADAPTIVE = 2 - def __init__( self, schema: Schema | None = None, @@ -57,7 +53,7 @@ def __init__( split_criterion: Union[str, SplitCriterion] = "InfoGainSplitCriterion", confidence: float = 1e-3, tie_threshold: float = 0.05, - leaf_prediction: str = MAJORITY_CLASS, + leaf_prediction: str = "MajorityClass", nb_threshold: int = 0, numeric_attribute_observer: str = "GaussianNumericAttributeClassObserver", binary_split: bool = False, @@ -77,8 +73,8 @@ def __init__( :param confidence: Significance level to calculate the Hoeffding bound. The significance level is given by `1 - delta`. Values closer to zero imply longer split decision delays. :param tie_threshold: Threshold below which a split will be forced to break ties. - :param leaf_prediction: Prediction mechanism used at leafs - (0: Majority Class, 1: Naive Bayes, 2: Naive Bayes Adaptive). + :param leaf_prediction: Prediction mechanism used at the leaves + ("MajorityClass" or 0, "NaiveBayes" or 1, "NaiveBayesAdaptive" or 2). :param nb_threshold: Number of instances a leaf should observe before allowing Naive Bayes. :param numeric_attribute_observer: The Splitter or Attribute Observer (AO) used to monitor the class statistics of numeric features and perform splits. @@ -106,6 +102,9 @@ def __init__( "remove_poor_attrs": "-r", "disable_prepruning": "-p", } + if isinstance(leaf_prediction, str): + leaf_prediction_mapping = {"MajorityClass": 0, "NaiveBayes": 1, "NaiveBayesAdaptive": 2} + leaf_prediction = leaf_prediction_mapping.get(leaf_prediction, None) split_criterion = _split_criterion_to_cli_str(split_criterion) config_str = build_cli_str_from_mapping_and_locals(mapping, locals())