Skip to content

Commit

Permalink
Included predict method for PKBC, updated Changelog and Documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
rmj3197 committed Mar 13, 2024
1 parent 5ae98d4 commit 9ed3948
Show file tree
Hide file tree
Showing 74 changed files with 14,518 additions and 291 deletions.
64 changes: 64 additions & 0 deletions QuadratiK/spherical_clustering/_pkbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,67 @@ def stats(self):
"""
summary_stats = stats(self.dat, self.labels_)
return summary_stats

def predict(self, X):
"""
Predict the cluster membership for each sample in X.
Parameters
-----------
X : numpy.ndarray, pandas.DataFrame
New data to predict membership
Returns
--------
(Cluster Probabilities, Membership) : tuple
The first element of the tuple is the cluster probabilities of the input samples.
The second element of the tuple is the predicted cluster membership of the new data.
"""
num_data, num_var = X.shape
if self.dat.shape[1] != X.shape[1]:
raise ValueError(
f"X has {num_var} features, but PKBC is expecting {self.dat.shape[1]} features as input. Please provide same number of features as the fitted data."
)
log_w_d = (num_var / 2) * (np.log(2) + np.log(np.pi)) - sp.gammaln(num_var / 2)
v_mat = np.dot(X, self.mu_.T)
alpha_mat_current = np.tile(self.alpha_, (num_data, 1))
rho_mat_current = np.tile(self.rho_, (num_data, 1))
log_prob_mat_denom = np.log(
1 + rho_mat_current**2 - 2 * np.asarray(rho_mat_current) * np.asarray(v_mat)
)
log_prob_mat = (
np.log(1 - (rho_mat_current) ** 2)
- log_w_d
- (num_var / 2) * log_prob_mat_denom
)
prob_sum = np.tile(
np.dot(np.exp(log_prob_mat), self.alpha_).reshape(num_data, 1),
(1, self.num_clust),
)
log_norm_prob_mat_current = (
np.log(alpha_mat_current) + log_prob_mat - np.log(prob_sum)
)
log_weight_mat = log_norm_prob_mat_current - log_prob_mat_denom
alpha_current = np.sum(np.exp(log_norm_prob_mat_current), axis=0) / num_data
mu_num_sum_mat = np.dot(np.exp(log_weight_mat).T, X)
mu_denom = np.linalg.norm(mu_num_sum_mat, axis=1, keepdims=True)
for h in range(self.num_clust):
sum_h_weight_mat = np.sum(np.exp(log_weight_mat[:, h]))
alpha_current_h = alpha_current[h]
mu_denom_h = mu_denom[h]
self.rho_[h] = root_scalar(
root_func,
args=(
num_data,
alpha_current_h,
num_var,
mu_denom_h,
sum_h_weight_mat,
),
bracket=[0, 1],
xtol=0.001,
).root
norm_prob_mat_best = np.exp(log_norm_prob_mat_current)
memb_best = np.argmax(norm_prob_mat_best, axis=1)

return (norm_prob_mat_best, memb_best)
Binary file not shown.
Binary file modified doc/build/doctrees/changelog/v1.0.1.doctree
Binary file not shown.
Binary file modified doc/build/doctrees/environment.pickle
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 9ed3948

Please sign in to comment.