diff --git a/proglearn/base.py b/proglearn/base.py index d3bf5e718c..84a1a12629 100644 --- a/proglearn/base.py +++ b/proglearn/base.py @@ -133,6 +133,17 @@ def predict(self, X): Input data matrix. """ pass + + @abc.abstractmethod + def is_fitted(self): + """ + Indicates whether the decider is fitted. + + Parameters + ---------- + None + """ + pass class ClassificationDecider(BaseDecider): @@ -229,4 +240,4 @@ def predict_proba(self, X, task_id): task_id : obj The task on which you are interested in estimating posteriors. """ - pass \ No newline at end of file + pass diff --git a/proglearn/deciders.py b/proglearn/deciders.py index 9a3f6870d0..ccde5aa299 100755 --- a/proglearn/deciders.py +++ b/proglearn/deciders.py @@ -25,6 +25,7 @@ class SimpleAverage(ClassificationDecider): def __init__(self, classes=[]): self.classes = classes + self._is_fitted = False def fit( self, @@ -43,6 +44,8 @@ def fit( self.classes = np.array(self.classes) self.transformer_id_to_transformers = transformer_id_to_transformers self.transformer_id_to_voters = transformer_id_to_voters + + self._is_fitted = True return self def predict_proba(self, X, transformer_ids=None): @@ -67,5 +70,19 @@ def predict_proba(self, X, transformer_ids=None): return np.mean(vote_per_transformer_id, axis=0) def predict(self, X, transformer_ids=None): + if not self.is_fitted(): + msg = ( + "This %(name)s instance is not fitted yet. Call 'fit' with " + "appropriate arguments before using this decider." + ) + raise NotFittedError(msg % {"name": type(self).__name__}) + vote_overall = self.predict_proba(X, transformer_ids=transformer_ids) return self.classes[np.argmax(vote_overall, axis=1)] + + def is_fitted(self): + """ + Doc strings here. + """ + + return self._is_fitted