Skip to content

Commit

Permalink
Update transformers.py
Browse files Browse the repository at this point in the history
Add docstrings to classes
  • Loading branch information
Michael-Ainsworth committed Sep 16, 2020
1 parent 883f509 commit be18fc8
Showing 1 changed file with 116 additions and 12 deletions.
128 changes: 116 additions & 12 deletions proglearn/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,63 @@


class NeuralClassificationTransformer(BaseTransformer):
"""
A class used to transform data from a category to a specialized representation.
Parameters
----------
network : object
A neural network used in the classification transformer.
euclidean_layer_idx : int
An integer to represent the final layer of the transformer.
optimizer : str
An optimizer used when compiling the neural network.
loss : str, default="categorical_crossentropy"
A loss function used when compiling the neural network.
pretrained : bool, default=False
A boolean used to identify if the network is pretrained.
compile_kwargs : dict, default={"metrics": ["acc"]}
A dictionary containing metrics for judging network performance.
fit_kwargs : dict, default={
"epochs": 100,
"callbacks": [keras.callbacks.EarlyStopping(patience=5, monitor="val_acc")],
"verbose": False,
"validation_split": 0.33,
},
A dictionary to hold epochs, callbacks, verbose, and validation split for the network.
Attributes (class)
----------
None
Attributes (object)
----------
network : object
A Keras model cloned from the network parameter.
encoder : object
A Keras model with inputs and outputs based on the network attribute. Output layers
are determined by the euclidean_layer_idx parameter.
_is_fitted : bool
A boolean to identify if the network has already been fitted.
optimizer : str
A string to identify the optimizer used in the network.
loss : str
A string to identify the loss function used in the network.
compile_kwargs : dict
A dictionary containing metrics for judging network performance.
fit_kwargs : dict
A dictionary to hold epochs, callbacks, verbose, and validation split for the network.
Methods
----------
fit(X, y)
Fits the transformer to data X with labels y.
transform(X)
Performs inference using the transformer.
is_fitted()
Indicates whether the transformer is fitted.
"""

def __init__(
self,
network,
Expand All @@ -35,9 +92,6 @@ def __init__(
"validation_split": 0.33,
},
):
"""
Doc strings here.
"""
self.network = keras.models.clone_model(network)
self.encoder = keras.models.Model(
inputs=self.network.inputs,
Expand All @@ -51,8 +105,16 @@ def __init__(

def fit(self, X, y):
"""
Doc strings here.
Fits the transformer to data X with labels y.
Parameters
----------
X : ndarray
Input data matrix.
y : ndarray
Output (i.e. response data matrix).
"""

check_classification_targets(y)
_, y = np.unique(y, return_inverse=True)
self.num_classes = len(np.unique(y))
Expand All @@ -72,7 +134,12 @@ def fit(self, X, y):

def transform(self, X):
"""
Doc strings here.
Performs inference using the transformer.
Parameters
----------
X : ndarray
Input data matrix.
"""

if not self.is_fitted():
Expand All @@ -87,25 +154,53 @@ def transform(self, X):

def is_fitted(self):
"""
Doc strings here.
Indicates whether the transformer is fitted.
Parameters
----------
None
"""

return self._is_fitted


class TreeClassificationTransformer(BaseTransformer):
"""
A class used to transform data from a category to a specialized representation.
Attributes (object)
----------
kwargs : dict
A dictionary to contain parameters of the tree.
_is_fitted_ : bool
A boolean to identify if the model is currently fitted.
Methods
----------
fit(X, y)
Fits the transformer to data X with labels y.
transform(X)
Performs inference using the transformer.
is_fitted()
Indicates whether the transformer is fitted.
"""

def __init__(self, kwargs={}):
"""
Doc strings here.
"""

self.kwargs = kwargs

self._is_fitted = False

def fit(self, X, y):
"""
Doc strings here.
Fits the transformer to data X with labels y.
Parameters
----------
X : ndarray
Input data matrix.
y : ndarray
Output (i.e. response data matrix).
"""

X, y = check_X_y(X, y)
Expand All @@ -119,7 +214,12 @@ def fit(self, X, y):

def transform(self, X):
"""
Doc strings here.
Performs inference using the transformer.
Parameters
----------
X : ndarray
Input data matrix.
"""

if not self.is_fitted():
Expand All @@ -134,7 +234,11 @@ def transform(self, X):

def is_fitted(self):
"""
Doc strings here.
Indicates whether the transformer is fitted.
Parameters
----------
None
"""

return self._is_fitted

0 comments on commit be18fc8

Please sign in to comment.