Skip to content

Commit

Permalink
Merge pull request neurodata#280 from latasianguy/document-deciders.p…
Browse files Browse the repository at this point in the history
…y-new

Document deciders.py
  • Loading branch information
levinwil authored Oct 5, 2020
2 parents 9ee1c15 + 959f544 commit d2300d2
Showing 1 changed file with 128 additions and 2 deletions.
130 changes: 128 additions & 2 deletions proglearn/deciders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,47 @@

class SimpleArgmaxAverage(BaseClassificationDecider):
"""
A decider that uses the average vote across transformers for classification.
Uses BaseClassificationDecider as a base class.
Parameters:
-----------
classes : list, default=[]
List of final output classification labels of type obj.
Defaults to an empty list of classes.
Attributes (objects):
-----------
classes : list, default=[]
List of final output classification labels of type obj.
Defaults to an empty list of classes.
_is_fitted : boolean, default=False
Boolean variable to see if the decider is fitted, defaults to False
transformer_id_to_transformers : dict
A dictionary with keys of type obj corresponding to transformer ids
and values of type obj corresponding to a transformer. This dictionary
maps each transformer id to its corresponding transformer.
transformer_id_to_voters : dict
A dictionary with keys of type obj corresponding to transformer ids
and values of type obj corresponding to a voter class. This dictionary
maps each transformer id to its corresponding voter class.
Methods
-----------
fit(X, y, transformer_id_to_transformers, transformer_id_to_voters, classes=None)
Fits the decider to inputs X and final classification outputs y.
predict_proba(X, transformer_ids=None)
Predicts posterior probabilities given input data, X, for each class.
predict(X, transformer_ids=None)
Predicts the most likely class given input data X.
is_fitted()
Returns if the decider has been fitted.
"""

def __init__(self, classes=[]):
Expand All @@ -32,6 +72,42 @@ def fit(
transformer_id_to_voters,
classes=None,
):
"""
Function for fitting.
Stores attributes (classes, transformer_id_to_transformers,
and transformer_id_to_voters) of a ClassificationDecider.
Parameters:
-----------
X : ndarray
Input data matrix.
y : ndarray
Output (i.e. response) data matrix.
transformer_id_to_transformers : dict
A dictionary with keys of type obj corresponding to transformer ids
and values of type obj corresponding to a transformer. This dictionary
maps each transformer id to its corresponding transformer.
transformer_id_to_voters : dict
A dictionary with keys of type obj corresponding to transformer ids
and values of type obj corresponding to a voter class. This dictionary thus
maps each transformer id to its corresponding voter class.
classes : list, default=None
List of final output classification labels of type obj.
Raises:
-----------
ValueError :
When the labels have not been provided and the classes are empty.
Returns:
----------
SimpleArgmaxAverage : obj
The ClassificationDecider object of class SimpleArgmaxAverage is returned.
"""
if not isinstance(self.classes, (list, np.ndarray)):
if len(y) == 0:
raise ValueError(
Expand All @@ -48,6 +124,33 @@ def fit(
return self

def predict_proba(self, X, transformer_ids=None):
"""
Predicts posterior probabilities per input example.
Loops through each transformer and bag of transformers.
Performs a transformation of the input data with the transformer.
Gets a voter to map the transformed input data into a posterior distribution.
Gets the mean vote per bag and appends it to the list of votes per transformer id.
Returns the mean vote across transformer ids.
Parameters:
-----------
X : ndarray
Input data matrix.
transformer_ids : list, default=None
A list with specific transformer ids that will be used for inference. Defaults
to using all transformers if no transformer ids are given.
Raises:
-----------
NotFittedError :
When the model is not fitted.
Returns:
-----------
Returns mean vote across transformer ids as an ndarray.
"""
vote_per_transformer_id = []
for transformer_id in (
transformer_ids
Expand Down Expand Up @@ -76,6 +179,25 @@ def predict_proba(self, X, transformer_ids=None):
return np.mean(vote_per_transformer_id, axis=0)

def predict(self, X, transformer_ids=None):
"""
Predicts the most likely class per input example.
Uses the predict_proba method to get the mean vote per id.
Returns the class with the highest vote.
Parameters:
-----------
X : ndarray
Input data matrix.
transformer_ids : list, default=None
A list of the transformer ids to use when predicting. Defaults to None when no
transformer ids are given.
Returns:
-----------
The class with the highest vote based on the argmax of the votes as an int.
"""
if not self.is_fitted():
msg = (
"This %(name)s instance is not fitted yet. Call 'fit' with "
Expand All @@ -88,6 +210,10 @@ def predict(self, X, transformer_ids=None):

def is_fitted(self):
"""
Getter function to check whether the decider has been fitted.

Returns:
-----------
The boolean `_is_fitted` attribute of this decider.
"""
return self._is_fitted

0 comments on commit d2300d2

Please sign in to comment.