Skip to content

Commit

Permalink
Update parametric_umap.py
Browse files Browse the repository at this point in the history
pep8
  • Loading branch information
AMS-Hippo authored Oct 24, 2024
1 parent 2d96f2b commit bc424cc
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions umap/parametric_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,12 @@ def save(self, save_location, verbose=True):
print("Pickle of ParametricUMAP model saved to {}".format(model_output))

def add_landmarks(
self, X, sample_pct=0.01, sample_mode="uniform", landmark_loss_weight=0.01, idx=None
self,
X,
sample_pct=0.01,
sample_mode="uniform",
landmark_loss_weight=0.01,
idx=None,
):
"""Add some points from a dataset X as "landmarks."
Expand All @@ -510,15 +515,17 @@ def add_landmarks(
Method for sampling points. Allows "uniform" and "predefined."
landmark_loss_weight : float, optional
Multiplier for landmark loss function.
"""
self.sample_pct = sample_pct
self.sample_mode = sample_mode
self.landmark_loss_weight = landmark_loss_weight

if self.sample_mode == "uniform":
self.prev_epoch_idx = list(
np.random.choice(range(X.shape[0]), int(X.shape[0]*sample_pct), replace=False)
np.random.choice(
range(X.shape[0]), int(X.shape[0]*sample_pct), replace=False
)
)
self.prev_epoch_X = X[self.prev_epoch_idx]
elif self.sample_mode == "predetermined":
Expand Down

0 comments on commit bc424cc

Please sign in to comment.