Skip to content

Commit

Permalink
add : n_best parameter adde to the predict method. (RaviSoji#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
sadrasabouri committed Mar 28, 2022
1 parent 6dca577 commit 2c133c1
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions plda/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self):
def fit_model(self, X, Y, n_principal_components=None):
self.model = Model(X, Y, n_principal_components)

def predict(self, data, space='D', normalize_logps=False):
def predict(self, data, n_best=1, space='D', normalize_logps=False):
""" Classifies data into categories present in the training data.
DESCRIPTION
Expand All @@ -43,6 +43,11 @@ def predict(self, data, space='D', normalize_logps=False):
- The dimensionality of the data depends on the space (see below).
PARAMETERS
n_best (int)
- number of n best predictions
that function returns.
space (str)
- Must be either 'D', 'X', 'U', or 'U_model',
where 'D' is the data space,
Expand All @@ -60,7 +65,7 @@ def predict(self, data, space='D', normalize_logps=False):
the posterior predictive probabilities before returning them.
RETURNS
predictions (numpy.ndarray), shape=data.shape[:-1]
predictions (numpy.ndarray), shape=(n_best)+data.shape[:-1]
logps (numpy.ndarray), shape=(*data.shape[:-1], n_categories)
- Log posterior predictive probabilities for each category,
Expand All @@ -75,9 +80,9 @@ def predict(self, data, space='D', normalize_logps=False):
from_space=space, to_space='U_model')

logpps_k, K = self.calc_logp_pp_categories(data, normalize_logps)
predictions = K[np.argmax(logpps_k, axis=-1)]
predictions = K[np.argsort(logpps_k, axis=-1).T[::-1][:n_best].T]

return predictions, logpps_k
return np.squeeze(predictions), logpps_k

def calc_logp_pp_categories(self, data, normalize_logps):
""" Computes log posterior predictive probabilities for each category.
Expand Down

0 comments on commit 2c133c1

Please sign in to comment.