diff --git a/README.rst b/README.rst index 49845487..c8f70647 100644 --- a/README.rst +++ b/README.rst @@ -163,6 +163,40 @@ Based on the paper: *"Rates of convergence for the cluster tree."* In Advances in Neural Information Processing Systems, 2010. +---------------- +Branch detection +---------------- + +The hdbscan package supports a branch-detection post-processing step +by `Bot et al. `_. Cluster shapes, +such as branching structures, can reveal interesting patterns +that are not expressed in density-based cluster hierarchies. The +BranchDetector class mimics the HDBSCAN API and can be used to +detect branching hierarchies in clusters. It provides condensed +branch hierarchies, branch persistences, and branch memberships and +supports joblib's caching functionality. A notebook +`demonstrating the BranchDetector is available `_. + +Example usage: + +.. code:: python + + import hdbscan + from sklearn.datasets import make_blobs + + data, _ = make_blobs(1000) + + clusterer = hdbscan.HDBSCAN(branch_detection_data=True).fit(data) + branch_detector = hdbscan.BranchDetector().fit(clusterer) + branch_detector.cluster_approximation_graph_.plot(edge_width=0.1) + + +Based on the paper: + D. M. Bot, J. Peeters, J. Liesenborgs and J. Aerts + *"FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN\* for Detecting Branches in Clusters"* + Arxiv 2311.15887, 2023. + + ---------- Installing ---------- @@ -300,6 +334,24 @@ To reference the high performance algorithm developed in this library please cit organization={IEEE} } +If you used the branch-detection functionality in this codebase in a scientific publication and wish to cite it, please use the `Arxiv preprint `_: + + D. M. Bot, J. Peeters, J. Liesenborgs and J. Aerts + *"FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN\* for Detecting Branches in Clusters"* + Arxiv 2311.15887, 2023. + +.. 
code:: bibtex + + @misc{bot2023flasc, + title={FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for Detecting Branches in Clusters}, + author={D. M. Bot and J. Peeters and J. Liesenborgs and J. Aerts}, + year={2023}, + eprint={2311.15887}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2311.15887}, + } + --------- Licensing --------- diff --git a/docs/api.rst b/docs/api.rst index ff9fe976..a11b8285 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -36,3 +36,15 @@ and the prediction module. .. automodule:: hdbscan.prediction :members: + + +Branch detection +---------------- + +The branches module contains classes for detecting branches within clusters. + +.. automodule:: hdbscan.branches + :members: + +.. autoclass:: hdbscan.plots.ApproximationGraph + :members: diff --git a/docs/how_to_detect_branches.rst b/docs/how_to_detect_branches.rst new file mode 100644 index 00000000..80da6a29 --- /dev/null +++ b/docs/how_to_detect_branches.rst @@ -0,0 +1,242 @@ +How to detect branches in clusters +================================== + +HDBSCAN\* is often used to find subpopulations in exploratory data +analysis workflows. Not only clusters themselves, but also their shape +can represent meaningful subpopulations. For example, a Y-shaped cluster +may represent an evolving process with two distinct end-states. +Detecting these branches can reveal interesting patterns that are not +captured by density-based clustering. + +For example, HDBSCAN\* finds 4 clusters in the dataset below, which +does not inform us of the branching structure: + +.. image:: images/how_to_detect_branches_3_0.png + +Alternatively, HDBSCAN\*’s leaf clusters provide more detail. They +segment the points of different branches into distinct clusters. However, +the partitioning and cluster hierarchy do not (necessarily) tell us how +those clusters combine into a larger shape. + +.. 
image:: images/how_to_detect_branches_5_0.png + +This is where the branch detection post-processing step comes into play. +The functionality is described in detail by `Bot et +al `__. It operates on the detected +clusters and extracts a branch-hierarchy analogous to HDBSCAN\*’s +condensed cluster hierarchy. The process is very similar to HDBSCAN\* +clustering, except that it operates on an in-cluster eccentricity rather +than a density measure. Where peaks in a density profile correspond to +clusters, the peaks in an eccentricity profile correspond to branches: + +.. image:: images/how_to_detect_branches_7_0.png + +Using the branch detection functionality is fairly straightforward. +First, run hdbscan with parameter ``branch_detection_data=True``. This +tells hdbscan to cache the internal data structures needed for the +branch detection process. Then, configure the ``BranchDetector`` class +and fit it with the HDBSCAN object. + +The resulting partitioning reflects subgroups for clusters and their +branches: + +.. code:: python + from hdbscan import HDBSCAN, BranchDetector + + clusterer = HDBSCAN(min_cluster_size=15, branch_detection_data=True).fit(data) + branch_detector = BranchDetector(min_branch_size=15).fit(clusterer) + plot(branch_detector.labels_) + +.. image:: images/how_to_detect_branches_9_0.png + + +Parameter selection +------------------- + +The ``BranchDetector``’s main parameters are very similar to HDBSCAN. +Most guidelines for tuning HDBSCAN\* also apply to the branch detector: + +- ``min_branch_size`` behaves like HDBSCAN\*’s ``min_cluster_size``. It + configures how many points branches need to contain. Values around 10 + to 25 points tend to work well. Lower values are useful when looking + for smaller structures. Higher values can be used to suppress noise + if present. +- ``branch_selection_method`` behaves like HDBSCAN\*’s + ``cluster_selection_method``. 
The leaf and Excess of Mass (EOM) + strategies are used to select branches from the condensed + hierarchies. By default, branches are only reflected in the final + labelling for clusters that have 3 or more branches (at least one + bifurcation). +- ``branch_selection_persistence`` replaces HDBSCAN\*’s + ``cluster_selection_epsilon``. This parameter can be used to suppress + branches with a short eccentricity range (y-range in the condensed + hierarchy plot). +- ``allow_single_branch`` behaves like HDBSCAN\*’s + ``allow_single_cluster`` and mostly affects the EOM selection + strategy. When enabled, clusters with bifurcations will be given a + single label if the root segment contains most eccentricity mass + (i.e., branches already merge far from the center and most points are + central). +- ``max_branch_size`` behaves like HDBSCAN\*’s ``max_cluster_size`` and + mostly affects the EOM selection strategy. Branches with more than + the specified number of points are skipped, selecting their + descendants in the hierarchy instead. + +Two parameters are unique to the ``BranchDetector`` class: + +- ``branch_detection_method`` determines which points are connected + within a cluster. Both density-based clustering and the branch detection + process need to determine which points are part of the same + density/eccentricity peak. HDBSCAN\* defines density in terms of the distance + between points, providing a natural way to define which points are connected at + some density value. Eccentricity does not have such a connection. So, we use + information from the clusters to determine which points should be connected + instead. + + - The ``"core"`` method selects all edges that could be part of the + cluster’s minimum spanning tree under HDBSCAN\*’s mutual + reachability distance. This graph contains the detected MST and + all ``min_samples``-nearest neighbours. 
+ - The ``"full"`` method connects all points with a mutual + reachability lower than the maximum distance in the cluster’s MST. + It represents all connectivity at the moment the last point joins + the cluster. + + These methods differ in their sensitivity, noise robustness, and + computational cost. The ``"core"`` method usually needs slightly + higher ``min_branch_size`` values to suppress noisy branches than the + ``"full"`` method. It is a good choice when branches span large + density ranges. + +- ``label_sides_as_branches`` determines whether the sides of an + elongated cluster without bifurcations (l-shape) are represented as + distinct subgroups. By default a cluster needs to have one + bifurcation (Y-shape) before the detected branches are represented in + the final labelling. + + +Useful attributes +----------------- + +Like the HDBSCAN class, the BranchDetector class contains several useful +attributes for exploring datasets. + +Branch hierarchy +~~~~~~~~~~~~~~~~ + +Branch hierarchies reflect the tree-shape of clusters. Like the cluster +hierarchy, branch hierarchies can be used to interpret which branches +exist. In addition, they reflect how far apart branches merge into the +cluster. + +.. code:: python + + idx = np.argmax([len(x) for x in branch_detector.branch_persistences_]) + branch_detector.cluster_condensed_trees_[idx].plot( + select_clusters=True, selection_palette=["C3", "C4", "C5"] + ) + plt.ylabel("Eccentricity") + plt.title(f"Branches in cluster {idx}") + plt.show() + +.. image:: images/how_to_detect_branches_13_0.png + +The length of the branches also says something about the compactness / +elongation of clusters. For example, the branch hierarchy for the +orange ~-shaped cluster is quite different from the hierarchy for +the central o-shaped cluster. + +.. 
code:: python + + plt.figure(figsize=(6, 3)) + plt.subplot(1, 2, 1) + idx = np.argmin([min(*x) for x in branch_detector.branch_persistences_]) + branch_detector.cluster_condensed_trees_[idx].plot(colorbar=False) + plt.ylim([0.3, 0]) + plt.ylabel("Eccentricity") + plt.title(f"Cluster {idx} (spherical)") + + plt.subplot(1, 2, 2) + idx = np.argmax([max(*x) for x in branch_detector.branch_persistences_]) + branch_detector.cluster_condensed_trees_[idx].plot(colorbar=False) + plt.ylim([0.3, 0]) + plt.ylabel("Eccentricity") + plt.title(f"Cluster {idx} (elongated)") + plt.show() + +.. image:: images/how_to_detect_branches_15_0.png + +Cluster approximation graphs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Branches are detected using a graph that approximates the connectivity +within a cluster. These graphs are available in the +``cluster_approximation_graph_`` property and can be used to visualise +data and the branch-detection process. The plotting function is based on +the networkx API and uses networkx functionality to compute a layout if +positions are not provided. Using UMAP to compute positions can be +faster and more expressive. Several helper functions for exporting to +numpy, pandas, and networkx are available. + +For example, a figure with points coloured by the final labelling: + +.. code:: python + + g = branch_detector.cluster_approximation_graph_ + g.plot(positions=data, node_size=5, edge_width=0.2, edge_alpha=0.2) + plt.show() + +.. image:: images/how_to_detect_branches_17_0.png + +Or, a figure with the edges coloured by centrality: + +.. code:: python + + g.plot( + positions=data, + node_alpha=0, + edge_color="centrality", + edge_cmap="turbo", + edge_width=0.2, + edge_alpha=0.2, + edge_vmax=100, + ) + plt.show() + +.. image:: images/how_to_detect_branches_19_0.png + + +Approximate predict +------------------- + +A branch-aware ``approximate_predict_branch`` function is available to +predict branch labels for new points. 
This function uses a fitted +BranchDetector object to first predict cluster labels and then the +branch labels. + +.. code:: python + + from hdbscan import approximate_predict_branch + + new_points = np.asarray([[0.4, 0.25], [0.23, 0.2], [-0.14, -0.2]]) + clusterer.generate_prediction_data() + labels, probs, cluster_labels, cluster_probs, branch_labels, branch_probs = ( + approximate_predict_branch(branch_detector, new_points) + ) + + plt.scatter( + new_points.T[0], + new_points.T[1], + 140, + labels % 10, + marker="p", + zorder=5, + cmap="tab10", + vmin=0, + vmax=9, + edgecolor="k", + ) + plot(branch_detector.labels_) + plt.show() + +.. image:: images/how_to_detect_branches_21_0.png diff --git a/docs/images/how_to_detect_branches_13_0.png b/docs/images/how_to_detect_branches_13_0.png new file mode 100644 index 00000000..16d409a4 Binary files /dev/null and b/docs/images/how_to_detect_branches_13_0.png differ diff --git a/docs/images/how_to_detect_branches_15_0.png b/docs/images/how_to_detect_branches_15_0.png new file mode 100644 index 00000000..63cf1f7e Binary files /dev/null and b/docs/images/how_to_detect_branches_15_0.png differ diff --git a/docs/images/how_to_detect_branches_17_0.png b/docs/images/how_to_detect_branches_17_0.png new file mode 100644 index 00000000..05c5c21d Binary files /dev/null and b/docs/images/how_to_detect_branches_17_0.png differ diff --git a/docs/images/how_to_detect_branches_19_0.png b/docs/images/how_to_detect_branches_19_0.png new file mode 100644 index 00000000..206e0828 Binary files /dev/null and b/docs/images/how_to_detect_branches_19_0.png differ diff --git a/docs/images/how_to_detect_branches_21_0.png b/docs/images/how_to_detect_branches_21_0.png new file mode 100644 index 00000000..a1d4e47a Binary files /dev/null and b/docs/images/how_to_detect_branches_21_0.png differ diff --git a/docs/images/how_to_detect_branches_3_0.png b/docs/images/how_to_detect_branches_3_0.png new file mode 100644 index 00000000..6ff165df Binary files 
/dev/null and b/docs/images/how_to_detect_branches_3_0.png differ diff --git a/docs/images/how_to_detect_branches_5_0.png b/docs/images/how_to_detect_branches_5_0.png new file mode 100644 index 00000000..fdb6c5de Binary files /dev/null and b/docs/images/how_to_detect_branches_5_0.png differ diff --git a/docs/images/how_to_detect_branches_7_0.png b/docs/images/how_to_detect_branches_7_0.png new file mode 100644 index 00000000..a331a189 Binary files /dev/null and b/docs/images/how_to_detect_branches_7_0.png differ diff --git a/docs/images/how_to_detect_branches_9_0.png b/docs/images/how_to_detect_branches_9_0.png new file mode 100644 index 00000000..cd199edf Binary files /dev/null and b/docs/images/how_to_detect_branches_9_0.png differ diff --git a/docs/index.rst b/docs/index.rst index c3434561..5f2c8d23 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,6 +26,7 @@ User Guide / Tutorial soft_clustering how_to_use_epsilon dbscan_from_hdbscan + how_to_detect_branches faq Background on Clustering with HDBSCAN diff --git a/hdbscan/__init__.py b/hdbscan/__init__.py index 2f6e10c9..99ea7211 100644 --- a/hdbscan/__init__.py +++ b/hdbscan/__init__.py @@ -5,5 +5,8 @@ membership_vector, all_points_membership_vectors, approximate_predict_scores) +from .branches import (BranchDetector, + detect_branches_in_clusters, + approximate_predict_branch) diff --git a/hdbscan/_hdbscan_linkage.pyx b/hdbscan/_hdbscan_linkage.pyx index 738ed6a2..66152e38 100644 --- a/hdbscan/_hdbscan_linkage.pyx +++ b/hdbscan/_hdbscan_linkage.pyx @@ -76,8 +76,8 @@ cpdef np.ndarray[np.double_t, ndim=2] mst_linkage_core_vector( cdef np.intp_t current_node cdef np.intp_t source_node - cdef np.intp_t right_node - cdef np.intp_t left_node + cdef np.intp_t right_node, right_source + cdef np.intp_t left_node, left_source cdef np.intp_t new_node cdef np.intp_t i cdef np.intp_t j @@ -124,7 +124,7 @@ cpdef np.ndarray[np.double_t, ndim=2] mst_linkage_core_vector( continue right_value = current_distances[j] - 
right_source = current_sources[j] + right_source = current_sources[j] left_value = dist_metric.dist(&raw_data_ptr[num_features * current_node], diff --git a/hdbscan/_hdbscan_tree.pyx b/hdbscan/_hdbscan_tree.pyx index aeb40518..d8746d53 100644 --- a/hdbscan/_hdbscan_tree.pyx +++ b/hdbscan/_hdbscan_tree.pyx @@ -301,6 +301,65 @@ cdef max_lambdas(np.ndarray tree): return deaths_arr +cdef max_eccentricities(np.ndarray tree, np.intp_t num_points): + """Maximum eccentricity for each segment in the branch hierarchy. + + Differs from max_lambda by using the max value in the sub-tree rather + than per segment. + """ + cdef np.ndarray sorted_parent_data + cdef np.ndarray[np.intp_t, ndim=1] sorted_parents + cdef np.ndarray[np.intp_t, ndim=1] sorted_children + cdef np.ndarray[np.double_t, ndim=1] sorted_lambdas + cdef np.intp_t[:] sorted_parents_view + cdef np.intp_t[:] sorted_child_view + cdef np.double_t[:] sorted_ecc_view + + cdef np.intp_t parent, child + cdef np.intp_t current_parent + cdef np.float64_t eccentricity + cdef np.float64_t max_lambda + + cdef np.ndarray[np.double_t, ndim=1] births_arr + cdef np.double_t[::1] births + + cdef np.intp_t largest_parent = tree['parent'].max() + + sorted_parent_data = np.sort(tree[['parent', 'lambda_val', 'child']], axis=0) + sorted_parents = sorted_parent_data['parent'] + sorted_children = sorted_parent_data['child'] + sorted_lambdas = sorted_parent_data['lambda_val'] + sorted_parents_view = sorted_parents[::-1] + sorted_child_view = sorted_children[::-1] + sorted_ecc_view = sorted_lambdas[::-1] + + births_arr = np.zeros(largest_parent + 1, dtype=np.double) + births = births_arr + + current_parent = -1 + max_eccentricity = 0 + + for parent, child, eccentricity in zip(sorted_parents_view, sorted_child_view, sorted_ecc_view): + # Use maximum density (= eccentricity) within branch rather than + # the maximum density (= eccentricity) of this segment in the condensed tree. 
+ # Need child to be processed first -> iterate from high to low parent! + if child >= num_points: + eccentricity = births[child] + if parent == current_parent: + max_eccentricity = max(max_eccentricity, eccentricity) + elif current_parent != -1: + births[current_parent] = max_eccentricity + current_parent = parent + max_eccentricity = eccentricity + else: + # Initialize + current_parent = parent + max_eccentricity = eccentricity + + births[current_parent] = max_eccentricity # value for last parent + return births_arr + + cdef class TreeUnionFind (object): cdef np.ndarray _data_arr @@ -388,8 +447,8 @@ cpdef np.ndarray[np.intp_t, ndim=1] labelling_at_cut( cluster = num_points for row in linkage: if row[2] < cut: - union_find.union_( row[0], cluster) - union_find.union_( row[1], cluster) + union_find.union_(np.intp(row[0]), cluster) + union_find.union_(np.intp(row[1]), cluster) cluster += 1 cluster_size = np.zeros(cluster, dtype=np.intp) @@ -483,10 +542,8 @@ cdef np.ndarray[np.intp_t, ndim=1] do_labelling( return result_arr -cdef get_probabilities(np.ndarray tree, dict cluster_map, np.ndarray labels): - +cdef get_probabilities(np.ndarray tree, dict cluster_map, np.ndarray labels, np.ndarray deaths): cdef np.ndarray[np.double_t, ndim=1] result - cdef np.ndarray[np.double_t, ndim=1] deaths cdef np.ndarray[np.double_t, ndim=1] lambda_array cdef np.ndarray[np.intp_t, ndim=1] child_array cdef np.ndarray[np.intp_t, ndim=1] parent_array @@ -503,7 +560,6 @@ cdef get_probabilities(np.ndarray tree, dict cluster_map, np.ndarray labels): lambda_array = tree['lambda_val'] result = np.zeros(labels.shape[0]) - deaths = max_lambdas(tree) root_cluster = parent_array.min() for n in range(tree.shape[0]): @@ -606,6 +662,7 @@ cpdef np.ndarray get_stability_scores(np.ndarray labels, set clusters, return result + cpdef list recurse_leaf_dfs(np.ndarray cluster_tree, np.intp_t current_node): children = cluster_tree[cluster_tree['parent'] == current_node]['child'] if len(children) == 0: @@ 
-620,22 +677,24 @@ cpdef list get_cluster_tree_leaves(np.ndarray cluster_tree): root = cluster_tree['parent'].min() return recurse_leaf_dfs(cluster_tree, root) + cpdef np.intp_t traverse_upwards(np.ndarray cluster_tree, np.double_t cluster_selection_epsilon, np.intp_t leaf, np.intp_t allow_single_cluster): root = cluster_tree['parent'].min() - parent = cluster_tree[cluster_tree['child'] == leaf]['parent'] + parent = cluster_tree[cluster_tree['child'] == leaf]['parent'][0] if parent == root: if allow_single_cluster: return parent else: return leaf #return node closest to root - parent_eps = 1/cluster_tree[cluster_tree['child'] == parent]['lambda_val'] + parent_eps = 1/cluster_tree[cluster_tree['child'] == parent]['lambda_val'][0] if parent_eps > cluster_selection_epsilon: return parent else: return traverse_upwards(cluster_tree, cluster_selection_epsilon, parent, allow_single_cluster) + cpdef set epsilon_search(set leaves, np.ndarray cluster_tree, np.double_t cluster_selection_epsilon, np.intp_t allow_single_cluster): selected_clusters = list() @@ -656,6 +715,78 @@ cpdef set epsilon_search(set leaves, np.ndarray cluster_tree, np.double_t cluste return set(selected_clusters) + +cpdef np.ndarray simplify_branch_hierarchy(np.ndarray condensed_tree, + np.double_t persistence_threshold): + """Iteratively remove branches with persistence below threshold. + + Takes the place of epsilon_search which cannot deal with non-zero births. 
+ """ + cdef bint flag + cdef np.intp_t leaf, sibling, parent, leaf_idx, sibling_idx + cdef np.double_t persistence + cdef np.ndarray cluster_tree, child_ids + cdef list leaves, persistences + cdef set processed + cdef np.intp_t num_points = condensed_tree['parent'].min() + + cdef np.ndarray[np.double_t, ndim=1] births_arr = max_eccentricities(condensed_tree, num_points) + cdef np.double_t[::1] births = births_arr + + while True: + processed = set() + cluster_tree = condensed_tree[condensed_tree['child_size'] > 1] + leaves = get_cluster_tree_leaves(cluster_tree) + persistences = [ + births[leaf] - cluster_tree['lambda_val'][cluster_tree['child'] == leaf][0] + for leaf in leaves + ] + flag = True + for leaf, persistence in zip(leaves, persistences): + if leaf in processed: + continue + if persistence < persistence_threshold: + flag = False + # Find parent and sibling + leaf_idx = np.argmax(condensed_tree['child'] == leaf) + parent = condensed_tree['parent'][leaf_idx] + child_ids = np.where((condensed_tree['parent'] == parent) & (condensed_tree['child_size'] > 1))[0] + sibling_idx = child_ids[1] if child_ids[0] == leaf_idx else child_ids[0] + sibling = condensed_tree['child'][sibling_idx] + + # Reset leaf and sibling rows + condensed_tree['child'][leaf_idx] = -999 + condensed_tree['child'][sibling_idx] = -999 + condensed_tree['parent'][leaf_idx] = -999 + condensed_tree['parent'][sibling_idx] = -999 + + # Update parent-values to reflect the merge + condensed_tree['parent'][(condensed_tree['parent'] == leaf)] = parent + condensed_tree['parent'][(condensed_tree['parent'] == sibling)] = parent + + processed.add(leaf) + processed.add(sibling) + + # Remove marked rows + condensed_tree = condensed_tree[condensed_tree['parent'] != -999] + if flag: + break + + return remap_cluster_ids(condensed_tree, num_points) + + +cdef np.ndarray remap_cluster_ids(np.ndarray condensed_tree, np.intp_t num_points): + """Ensures segments are numbered consecutively from 0 to n_clusters-1.""" 
+ max_parent = condensed_tree['parent'].max() + id_map = np.empty(max_parent + 1) + id_map[:num_points] = np.arange(num_points) + remaining_parents = np.unique(condensed_tree['parent']) + id_map[remaining_parents] = num_points + np.arange(remaining_parents.shape[0]) + condensed_tree['parent'] = id_map[condensed_tree['parent']] + condensed_tree['child'] = id_map[condensed_tree['child']] + return condensed_tree + + cpdef tuple get_clusters(np.ndarray tree, dict stability, cluster_selection_method='eom', allow_single_cluster=False, @@ -718,6 +849,7 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability, cdef np.intp_t num_points cdef np.ndarray labels cdef np.double_t max_lambda + cdef np.ndarray[np.double_t, ndim=1] deaths # Assume clusters are ordered by numeric id equivalent to # a topological sort of the tree; This is valid given the @@ -725,16 +857,16 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability, # if you do, change this accordingly! if allow_single_cluster: node_list = sorted(stability.keys(), reverse=True) - node_list = [int(n) for n in node_list] else: node_list = sorted(stability.keys(), reverse=True)[:-1] - node_list = [int(n) for n in node_list] + node_list = [int(n) for n in node_list] # (exclude root) cluster_tree = tree[tree['child_size'] > 1] is_cluster = {cluster: True for cluster in node_list} num_points = np.max(tree[tree['child_size'] == 1]['child']) + 1 max_lambda = np.max(tree['lambda_val']) + deaths = max_lambdas(tree) if max_cluster_size <= 0: max_cluster_size = num_points + 1 # Set to a value that will never be triggered @@ -802,7 +934,131 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability, labels = do_labelling(tree, clusters, cluster_map, allow_single_cluster, cluster_selection_epsilon, match_reference_implementation) - probs = get_probabilities(tree, reverse_cluster_map, labels) + probs = get_probabilities(tree, reverse_cluster_map, labels, deaths) stabilities = get_stability_scores(labels, clusters, stability, 
max_lambda) return (labels, probs, stabilities) + + +cpdef tuple get_branches(np.ndarray tree, + dict stability, + branch_selection_method='eom', + allow_single_branch=False, + max_branch_size=0): + """Extracts branches from a branch condensed tree. + + Given a tree and stability dict, produce the branch labels + (and probabilities) for a flat "clustering" based on the chosen + branch selection method. + + Parameters + ---------- + tree : numpy recarray + The condensed tree to extract flat clusters from. + Uses lambda_val key to store eccentricity values. + stability : dict + A dictionary mapping cluster_ids to stability values + branch_selection_method : string, optional (default 'eom') + The method of selecting clusters. The default is the + Excess of Mass algorithm specified by 'eom'. The alternate + option is 'leaf'. + allow_single_branch : boolean, optional (default False) + Whether to allow a single branch to be selected by the + Excess of Mass algorithm. + max_branch_size: int, optional (default 0) + The maximum size for clusters located by the EOM clusterer. Can + be overridden by the branch_selection_persistence parameter in + rare cases. + + Returns + ------- + labels : ndarray (n_samples,) + An integer array of branch labels, with -1 denoting noise. + probabilities : ndarray (n_samples,) + The cluster membership strength of each sample. + stabilities : ndarray (n_clusters,) + The cluster coherence strengths of each branch. 
+ """ + cdef list node_list + cdef np.ndarray cluster_tree + cdef np.ndarray child_selection + cdef dict is_cluster + cdef dict cluster_sizes + cdef float subtree_stability + cdef np.intp_t node + cdef np.intp_t sub_node + cdef np.intp_t cluster + cdef np.intp_t num_points + cdef np.ndarray labels + cdef np.double_t max_eccentricity + cdef np.ndarray[np.double_t, ndim=1] births + + # Assume clusters are ordered by numeric id equivalent to + # a topological sort of the tree; This is valid given the + # current implementation above, so don't change that ... or + # if you do, change this accordingly! + if allow_single_branch: + node_list = sorted(stability.keys(), reverse=True) + else: + node_list = sorted(stability.keys(), reverse=True)[:-1] # (exclude root) + node_list = [int(n) for n in node_list] + + cluster_tree = tree[tree['child_size'] > 1] + is_cluster = {cluster: True for cluster in node_list} + num_points = np.min(tree['parent']) + max_eccentricity = np.max(tree['lambda_val']) + births = max_eccentricities(tree, num_points) + + if max_branch_size <= 0: + max_branch_size = num_points + 1 # Set to a value that will never be triggered + cluster_sizes = { + child: child_size + for child, child_size in zip(cluster_tree['child'], cluster_tree['child_size']) + } + if allow_single_branch: + # Compute cluster size for the root node + cluster_sizes[node_list[-1]] = np.sum( + cluster_tree[cluster_tree['parent'] == node_list[-1]]['child_size'] + ) + + if branch_selection_method == 'eom': + for node in node_list: + child_selection = (cluster_tree['parent'] == node) + subtree_stability = np.sum([ + stability[child] for + child in cluster_tree['child'][child_selection]]) + if subtree_stability > stability[node] or cluster_sizes[node] > max_branch_size: + is_cluster[node] = False + stability[node] = subtree_stability + else: + for sub_node in bfs_from_cluster_tree(cluster_tree, node): + if sub_node != node: + is_cluster[sub_node] = False + elif branch_selection_method == 
'leaf': + leaves = get_cluster_tree_leaves(cluster_tree) + selected_clusters = set(leaves) + + # Allow single leaf + if len(selected_clusters) == 0 and allow_single_branch: + for c in is_cluster: + is_cluster[c] = False + is_cluster[tree['parent'].min()] = True + else: + for c in is_cluster: + if c in selected_clusters: + is_cluster[c] = True + else: + is_cluster[c] = False + else: + raise ValueError('Invalid Branch Selection Method: %s\n' + 'Should be one of: "eom", "leaf"\n' % branch_selection_method) + + clusters = set([c for c in is_cluster if is_cluster[c]]) + cluster_map = {c: n for n, c in enumerate(sorted(list(clusters)))} + reverse_cluster_map = {n: c for c, n in cluster_map.items()} + + labels = do_labelling(tree, clusters, cluster_map, allow_single_branch, 0.0, False) + probs = get_probabilities(tree, reverse_cluster_map, labels, births) + stabilities = get_stability_scores(labels, clusters, stability, max_eccentricity) + + return (labels, probs, stabilities) \ No newline at end of file diff --git a/hdbscan/branches.py b/hdbscan/branches.py new file mode 100644 index 00000000..e0643859 --- /dev/null +++ b/hdbscan/branches.py @@ -0,0 +1,1184 @@ +# Support branch detection within clusters. +import numpy as np + +from sklearn.base import BaseEstimator, ClusterMixin +from sklearn.neighbors import KDTree, BallTree +from scipy.sparse import coo_array +from scipy.sparse.csgraph import minimum_spanning_tree +from joblib import Memory +from joblib import Parallel, delayed +from joblib.parallel import cpu_count +from .dist_metrics import DistanceMetric +from ._hdbscan_linkage import label +from .plots import CondensedTree, SingleLinkageTree, ApproximationGraph +from .prediction import approximate_predict +from ._hdbscan_tree import ( + get_branches, + condense_tree, + recurse_leaf_dfs, + compute_stability, + simplify_branch_hierarchy, +) + + +class BranchDetectionData(object): + """Input data for branch detection functionality. 
+ + Recreates and caches internal data structures from the clustering stage. + + Parameters + ---------- + + data : array (n_samples, n_features) + The original data set that was clustered. + + labels : array (n_samples) + The cluster labels for every point in the data set. + + min_samples : int + The min_samples value used in clustering. + + tree_type : string, optional + Which type of space tree to use for core distance computation. + One of: + * ``kdtree`` + * ``balltree`` + + metric : string, optional + The metric used to determine distance for the clustering. + This is the metric that will be used for the space tree to determine + core distances etc. + + **kwargs : + Any further arguments to the metric. + + Attributes + ---------- + + all_finite : bool + Whether all values in the data set are finite (no infinite or NaN values). + + finite_index : array (n_samples) + The indices of the finite data points in the original data set. + + internal_to_raw : dict + A mapping from the finite data set indices to the original data set. + + tree : KDTree or BallTree + A space partitioning tree that can be queried for nearest neighbors if + the metric is supported by a KDTree or BallTree. + + neighbors : array (n_samples, min_samples) + The nearest neighbor for every non-noise point in the original data set. + + core_distances : array (n_samples) + The core distance for every non-noise point in the original data set. + + dist_metric : callable + Accelerated distance metric function. 
+ """ + + _tree_type_map = {"kdtree": KDTree, "balltree": BallTree} + + def __init__( + self, + data, + all_finite, + finite_index, + labels, + min_samples, + tree_type="kdtree", + metric="euclidean", + **kwargs, + ): + # Select finite data points + self.all_finite = all_finite + self.finite_index = finite_index + clean_data = data.astype(np.float64) + if not all_finite: + labels = labels[finite_index] + clean_data = clean_data[finite_index] + self.internal_to_raw = { + x: y for x, y in zip(range(len(finite_index)), finite_index) + } + else: + self.internal_to_raw = None + + # Construct tree + self.tree = self._tree_type_map[tree_type](clean_data, metric=metric, **kwargs) + self.dist_metric = DistanceMetric.get_metric(metric, **kwargs) + + # Allocate to maintain data point indices + self.core_distances = np.full(clean_data.shape[0], np.nan) + self.neighbors = np.full((clean_data.shape[0], min_samples), -1, dtype=np.int64) + + # Find neighbours for non-noise points + noise_mask = labels != -1 + if noise_mask.any(): + distances, self.neighbors[noise_mask, :] = self.tree.query( + clean_data[noise_mask], k=min_samples + ) + self.core_distances[noise_mask] = distances[:, -1] + + +def detect_branches_in_clusters( + clusterer, + min_branch_size=None, + allow_single_branch=False, + branch_detection_method="full", + branch_selection_method="eom", + branch_selection_persistence=0.0, + max_branch_size=0, + label_sides_as_branches=False, +): + """ + Performs a flare-detection post-processing step to detect branches within + clusters [1]_. + + For each cluster, a graph is constructed connecting the data points based on + their mutual reachability distances. Each edge is given a centrality value + based on how far it lies from the cluster's center. Then, the edges are + clustered as if that centrality was a distance, progressively removing the + 'center' of each cluster and seeing how many branches remain. 
+ + Parameters + ---------- + + clusterer : hdbscan.HDBSCAN + The clusterer object that has been fit to the data with branch detection + data generated. + + min_branch_size : int, optional (default=None) + The minimum number of samples in a group for that group to be + considered a branch; groupings smaller than this size will be seen as + points falling out of a branch. Defaults to the clusterer's min_cluster_size. + + allow_single_branch : bool, optional (default=False) + Analogous to ``allow_single_cluster``. + + branch_detection_method : str, optional (default='full') + Determines which graph is constructed to detect branches with. Valid + values are, ordered by increasing computation cost and decreasing + sensitivity to noise: + - ``core``: Contains the edges that connect each point to all other + points within a mutual reachability distance lower than or equal to + the point's core distance. This is the cluster's subgraph of the + k-NN graph over the entire data set (with k = ``min_samples``). + - ``full``: Contains all edges between points in each cluster with a + mutual reachability distance lower than or equal to the distance of + the most-distant point in each cluster. These graphs represent the + 0-dimensional simplicial complex of each cluster at the first point in + the filtration where they contain all their points. + + branch_selection_method : str, optional (default='eom') + The method used to select branches from the cluster's condensed tree. + The standard approach for FLASC is to use the ``eom`` approach. + Options are: + * ``eom`` + * ``leaf`` + + branch_selection_persistence : float, optional (default=0.0) + An eccentricity persistence threshold. Branches with a persistence below + this value will be merged. See [3]_ for more information. Note that this + should not be used if we want to predict the cluster labels for new + points in future (e.g.
using approximate_predict), as the + :func:`~flasc.prediction.approximate_predict` function is not aware of + this argument. + + max_branch_size : int, optional (default=0) + A limit to the size of clusters returned by the ``eom`` algorithm. + Has no effect when using ``leaf`` clustering (where clusters are + usually small regardless). Note that this should not be used if we + want to predict the cluster labels for new points in future (e.g. using + :func:`~flasc.prediction.approximate_predict`), as that function is + not aware of this argument. + + label_sides_as_branches : bool, optional (default=False) + When this flag is False, branches are only labelled for clusters with at + least three branches (i.e., at least y-shapes). Clusters with only two + branches represent l-shapes. The two branches describe the cluster's + outsides growing towards each other. Enabling this flag separates these + branches from each other in the produced labelling. + + Returns + ------- + labels : np.ndarray, shape (n_samples, ) + Labels that differentiate all subgroups (clusters and branches). Noisy + samples are given the label -1. + + probabilities : np.ndarray, shape (n_samples, ) + Probabilities considering both cluster and branch membership. Noisy + samples are assigned 0. + + branch_labels : np.ndarray, shape (n_samples, ) + Branch labels for each point. Noisy samples are given the label -1. + + branch_probabilities : np.ndarray, shape (n_samples, ) + Branch membership strengths for each point. Noisy samples are + assigned 0. + + branch_persistences : tuple (n_clusters) + A branch persistence (eccentricity range) for each detected branch. + + cluster_approximation_graphs : tuple (n_clusters) + The graphs used to detect branches in each cluster stored as a numpy + array with four columns: source, target, centrality, mutual reachability + distance. Points are labelled by their row-index into the input data.
+ The edges contained in the graphs depend on the ``branch_detection_method``: + - ``core``: Contains the edges that connect each point to all other + points in a cluster within a mutual reachability distance lower than + or equal to the point's core distance. This is an extension of the + minimum spanning tree introducing only edges with equal distances. The + reachability distance introduces ``num_points`` * ``min_samples`` of + such edges. + - ``full``: Contains all edges between points in each cluster with a + mutual reachability distance lower than or equal to the distance of + the most-distant point in each cluster. These graphs represent the + 0-dimensional simplicial complex of each cluster at the first point in + the filtration where they contain all their points. + + cluster_condensed_trees : tuple (n_clusters) + A condensed branch hierarchy for each cluster produced during the + branch detection step. Data points are numbered with in-cluster ids. + + cluster_linkage_trees : tuple (n_clusters) + A single linkage tree for each cluster produced during the branch + detection step, in the scipy hierarchical clustering format. + (see http://docs.scipy.org/doc/scipy/reference/cluster.hierarchy.html). + Data points are numbered with in-cluster ids. + + cluster_centralities : np.ndarray, shape (n_samples, ) + Centrality values for each point in a cluster. Overemphasizes points' + eccentricity within the cluster as the values are based on minimum + spanning trees that do not contain the equally distanced edges resulting + from the mutual reachability distance. + + cluster_points : list (n_clusters) + The data point row indices for each cluster. + + References + ---------- + .. [1] Bot, D. M., Peeters, J., Liesenborgs, J., & Aerts, J. (2023, November). + FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for + Detecting Branches in Clusters.
arXiv:2311.15887 + """ + # Check clusterer state + if clusterer._min_spanning_tree is None: + raise ValueError( + "Clusterer does not have an explicit minimum spanning tree!" + " Try fitting with branch_detection_data=True or" + " gen_min_span_tree=True set." + ) + if clusterer.branch_detection_data_ is None: + raise ValueError( + "Clusterer does not have branch detection data!" + " Try fitting with branch_detection_data=True set," + " or run generate_branch_detection_data on the clusterer." + ) + + # Validate parameters + if min_branch_size is None: + min_branch_size = clusterer.min_cluster_size + branch_selection_persistence = float(branch_selection_persistence) + if not (np.issubdtype(type(min_branch_size), np.integer) and min_branch_size >= 2): + raise ValueError( + f"min_branch_size must be an integer greater than or equal " + f"to 2, {min_branch_size} given." + ) + if not ( + np.issubdtype(type(branch_selection_persistence), np.floating) + and branch_selection_persistence >= 0.0 + ): + raise ValueError( + f"branch_selection_persistence must be a float greater than or equal to " + f"0.0, {branch_selection_persistence} given."
+ ) + if branch_selection_method not in ("eom", "leaf"): + raise ValueError( + f"Invalid branch_selection_method: {branch_selection_method}\n" + f'Should be one of: "eom", "leaf"\n' + ) + if branch_detection_method not in ("core", "full"): + raise ValueError( + f"Invalid ``branch_detection_method``: {branch_detection_method}\n" + 'Should be one of: "core", "full"\n' + ) + + # Extract state + memory = clusterer.memory + if isinstance(memory, str): + memory = Memory(memory, verbose=0) + num_clusters = len(clusterer.cluster_persistence_) + labels = clusterer.labels_ + probabilities = clusterer.probabilities_ + if not clusterer.branch_detection_data_.all_finite: + finite_index = clusterer.branch_detection_data_.finite_index + labels = labels[finite_index] + probabilities = probabilities[finite_index] + + # Configure parallelization + run_core = branch_detection_method == "core" + num_jobs = clusterer.core_dist_n_jobs + if num_jobs < 1: + num_jobs = max(cpu_count() + 1 + num_jobs, 1) + thread_pool = ( + SequentialPool() if run_core else Parallel(n_jobs=num_jobs, max_nbytes=None) + ) + + # Detect branches + ( + cluster_points, + cluster_centralities, + cluster_linkage_trees, + cluster_approximation_graphs, + ) = memory.cache(_compute_branch_linkage, ignore=["thread_pool"])( + labels, + probabilities, + clusterer._min_spanning_tree, + clusterer.branch_detection_data_.tree, + clusterer.branch_detection_data_.neighbors, + clusterer.branch_detection_data_.core_distances, + clusterer.branch_detection_data_.dist_metric, + num_clusters, + thread_pool, + run_core=run_core, + ) + ( + branch_labels, + branch_probabilities, + branch_persistences, + cluster_condensed_trees, + ) = memory.cache(_compute_branch_segmentation, ignore=["thread_pool"])( + cluster_linkage_trees, + thread_pool, + min_branch_size=min_branch_size, + allow_single_branch=allow_single_branch, + branch_selection_method=branch_selection_method, + branch_selection_persistence=branch_selection_persistence, + 
max_branch_size=max_branch_size, + ) + ( + labels, + probabilities, + branch_labels, + branch_probabilities, + cluster_centralities, + ) = memory.cache(_update_labelling)( + labels, + probabilities, + cluster_points, + cluster_centralities, + branch_labels, + branch_probabilities, + branch_persistences, + label_sides_as_branches=label_sides_as_branches, + ) + + # Maintain data indices for non-finite data + if not clusterer.branch_detection_data_.all_finite: + internal_to_raw = clusterer.branch_detection_data_.internal_to_raw + _remap_point_lists(cluster_points, internal_to_raw) + _remap_edge_lists(cluster_approximation_graphs, internal_to_raw) + + num_points = len(clusterer.labels_) + labels = _remap_labels(labels, finite_index, num_points) + probabilities = _remap_probabilities(probabilities, finite_index, num_points) + branch_labels = _remap_labels(branch_labels, finite_index, num_points) + branch_probabilities = _remap_probabilities( + branch_probabilities, finite_index, num_points + ) + cluster_centralities = _remap_probabilities( + cluster_centralities, finite_index, num_points + ) + + return ( + # Combined result + labels, + probabilities, + # Branching result + branch_labels, + branch_probabilities, + branch_persistences, + # Clusters to branches + cluster_approximation_graphs, + cluster_condensed_trees, + cluster_linkage_trees, + cluster_centralities, + cluster_points, + ) + + +def _compute_branch_linkage( + cluster_labels, + cluster_probabilities, + min_spanning_tree, + space_tree, + neighbors, + core_distances, + dist_metric, + num_clusters, + thread_pool, + run_core=False, +): + result = thread_pool( + delayed(_compute_branch_linkage_of_cluster)( + cluster_labels, + cluster_probabilities, + min_spanning_tree, + space_tree, + neighbors, + core_distances, + dist_metric, + run_core, + cluster_id, + ) + for cluster_id in range(num_clusters) + ) + if len(result): + return tuple(zip(*result)) + return (), (), (), () + + +def _compute_branch_linkage_of_cluster( 
+ cluster_labels, + cluster_probabilities, + min_spanning_tree, + space_tree, + neighbors, + core_distances, + dist_metric, + run_core, + cluster_id, +): + """Detect branches within one cluster.""" + # List points within cluster + cluster_mask = cluster_labels == cluster_id + cluster_points = np.where(cluster_mask)[0] + in_cluster_ids = np.full(cluster_labels.shape[0], -1, dtype=np.double) + in_cluster_ids[cluster_points] = np.arange(len(cluster_points), dtype=np.double) + + # Extract MST edges within cluster + parent_mask = cluster_labels[min_spanning_tree[:, 0].astype(np.intp)] == cluster_id + child_mask = cluster_labels[min_spanning_tree[:, 1].astype(np.intp)] == cluster_id + cluster_mst = min_spanning_tree[parent_mask & child_mask] + + # Compute in cluster centrality + points = space_tree.data.base[cluster_points] + centroid = np.average(points, weights=cluster_probabilities[cluster_mask], axis=0) + centralities = dist_metric.pairwise(centroid[None], points)[0, :] + centralities = 1 / centralities + + # Construct cluster approximation graph + if run_core: + edges = _extract_core_cluster_graph( + cluster_mst, core_distances, neighbors[cluster_points], in_cluster_ids + ) + else: + max_dist = cluster_mst.T[2].max() + edges = _extract_full_cluster_graph( + space_tree, core_distances, cluster_points, in_cluster_ids, max_dist + ) + np.maximum( + centralities[edges[:, 0].astype(np.intp)], + centralities[edges[:, 1].astype(np.intp)], + edges[:, 2], + ) + + # Extract centrality MST and compute single linkage + centrality_mst = minimum_spanning_tree( + coo_array( + (edges[:, 2], (edges[:, 0].astype(np.int32), edges[:, 1].astype(np.int32))), + shape=(len(cluster_points), len(cluster_points)), + ) + ).tocoo() + centrality_mst = np.column_stack( + (centrality_mst.row, centrality_mst.col, centrality_mst.data) + ) + centrality_mst = centrality_mst[np.argsort(centrality_mst.T[2]), :] + linkage_tree = label(centrality_mst) + + # Re-label edges with data ids + edges[:, 0] = 
cluster_points[edges[:, 0].astype(np.intp)] + edges[:, 1] = cluster_points[edges[:, 1].astype(np.intp)] + + # Return values + return cluster_points, centralities, linkage_tree, edges + + +def _extract_core_cluster_graph( + cluster_spanning_tree, + core_distances, + neighbors, + in_cluster_ids, +): + """Create a graph connecting all points within each point's core distance.""" + # Allocate output (won't be filled completely) + cluster_spanning_tree_view = cluster_spanning_tree + num_points = neighbors.shape[0] + num_neighbors = neighbors.shape[1] + count = cluster_spanning_tree_view.shape[0] + edges = np.zeros((count + num_points * num_neighbors, 4), dtype=np.double) + + # Fill (undirected) MST edges with within-cluster-ids + mst_parents = in_cluster_ids[cluster_spanning_tree[:, 0].astype(np.intp)] + mst_children = in_cluster_ids[cluster_spanning_tree[:, 1].astype(np.intp)] + np.minimum(mst_parents, mst_children, edges[:count, 0]) + np.maximum(mst_parents, mst_children, edges[:count, 1]) + + # Fill neighbors with within-cluster-ids + core_parent = np.repeat(np.arange(num_points, dtype=np.double), num_neighbors) + core_children = in_cluster_ids[neighbors.flatten()] + np.minimum(core_parent, core_children, edges[count:, 0]) + np.maximum(core_parent, core_children, edges[count:, 1]) + + # Fill mutual reachabilities + edges[:count, 3] = cluster_spanning_tree[:, 2] + np.maximum( + core_distances[edges[count:, 0].astype(np.intp)], + core_distances[edges[count:, 1].astype(np.intp)], + edges[count:, 3], + ) + + # Extract unique edges that stay within the cluster + edges = np.unique(edges[edges[:, 0] > -1.0, :], axis=0) + return edges + + +def _extract_full_cluster_graph( + space_tree, core_distances, cluster_points, in_cluster_ids, max_dist +): + # Query KDTree/BallTree for neighours within the distance + children_map, distances_map = space_tree.query_radius( + space_tree.data.base[cluster_points], r=max_dist + 1e-8, return_distance=True + ) + + # Count number of returned 
edges per point + num_children = np.zeros(len(cluster_points), dtype=np.intp) + for i, children in enumerate(children_map): + num_children[i] += len(children) + + # Create full edge list + full_parents = np.repeat( + np.arange(len(cluster_points), dtype=np.double), num_children + ) + full_children = in_cluster_ids[np.concatenate(children_map)] + full_distances = np.concatenate(distances_map) + + # Create output + mask = ( + (full_children != -1.0) + & (full_parents < full_children) + & (full_distances <= max_dist) + ) + edges = np.zeros((mask.sum(), 4), dtype=np.double) + edges[:, 0] = full_parents[mask] + edges[:, 1] = full_children[mask] + np.maximum( + np.maximum( + core_distances[edges[:, 0].astype(np.intp)], + core_distances[edges[:, 1].astype(np.intp)], + ), + full_distances[mask], + edges[:, 3], + ) + return edges + + +def _compute_branch_segmentation( + cluster_linkage_trees, + thread_pool, + min_branch_size=5, + allow_single_branch=False, + branch_selection_method="eom", + branch_selection_persistence=0.0, + max_branch_size=0, +): + """Extracts branches from the linkage hierarchies.""" + results = thread_pool( + delayed(_compute_branch_segmentation_of_cluster)( + cluster_linkage_tree, + min_branch_size=min_branch_size, + allow_single_branch=allow_single_branch, + branch_selection_method=branch_selection_method, + branch_selection_persistence=branch_selection_persistence, + max_branch_size=max_branch_size, + ) + for cluster_linkage_tree in cluster_linkage_trees + ) + if len(results): + return tuple(zip(*results)) + return (), (), (), () + + +def _compute_branch_segmentation_of_cluster( + cluster_linkage_tree, + min_branch_size=5, + allow_single_branch=False, + branch_selection_method="eom", + branch_selection_persistence=0.0, + max_branch_size=0, +): + """Select branches within one cluster.""" + condensed_tree = condense_tree(cluster_linkage_tree, min_branch_size) + if branch_selection_persistence > 0.0: + condensed_tree = simplify_branch_hierarchy( + 
condensed_tree, branch_selection_persistence + ) + stability = compute_stability(condensed_tree) + (labels, probabilities, persistences) = get_branches( + condensed_tree, + stability, + allow_single_branch=allow_single_branch, + branch_selection_method=branch_selection_method, + max_branch_size=max_branch_size, + ) + # Reset noise labels to k-cluster + labels[labels < 0] = len(persistences) + return (labels, probabilities, persistences, condensed_tree) + + +def _update_labelling( + cluster_labels, + cluster_probabilities, + cluster_points_, + cluster_centralities_, + branch_labels_, + branch_probabilities_, + branch_persistences_, + label_sides_as_branches=False, +): + """Updates the labelling with the detected branches.""" + # Allocate output + num_points = len(cluster_labels) + labels = -1 * np.ones(num_points, dtype=np.intp) + probabilities = cluster_probabilities.copy() + branch_labels = np.zeros(num_points, dtype=np.intp) + branch_probabilities = np.ones(num_points, dtype=np.double) + branch_centralities = np.zeros(num_points, dtype=np.double) + + # Compute the labels and probabilities + running_id = 0 + for _points, _centralities, _labels, _probs, _pers in zip( + cluster_points_, + cluster_centralities_, + branch_labels_, + branch_probabilities_, + branch_persistences_, + ): + num_branches = len(_pers) + branch_centralities[_points] = _centralities + if num_branches <= (1 if label_sides_as_branches else 2): + labels[_points] = running_id + running_id += 1 + else: + labels[_points] = _labels + running_id + branch_labels[_points] = _labels + branch_probabilities[_points] = _probs + probabilities[_points] += _probs + probabilities[_points] /= 2 + running_id += num_branches + 1 + + # Reorder other parts + return ( + labels, + probabilities, + branch_labels, + branch_probabilities, + branch_centralities, + ) + + +def _remap_edge_lists(edge_lists, internal_to_raw): + """ + Takes a list of edge lists and replaces the internal indices to raw indices. 
+ + Parameters + ---------- + edge_lists : list[np.ndarray] + A list of numpy edgelists with the first two columns indicating + datapoints. + internal_to_raw: dict + A mapping from internal integer index to the raw integer index. + """ + for graph in edge_lists: + for edge in graph: + edge[0] = internal_to_raw[edge[0]] + edge[1] = internal_to_raw[edge[1]] + + +def _remap_point_lists(point_lists, internal_to_raw): + """ + Takes a list of points lists and replaces the internal indices to raw indices. + + Parameters + ---------- + point_lists : list[np.ndarray] + A list of numpy arrays with point indices. + internal_to_raw: dict + A mapping from internal integer index to the raw integer index. + """ + for points in point_lists: + for idx in range(len(points)): + points[idx] = internal_to_raw[points[idx]] + + +def _remap_labels(old_labels, finite_index, num_points): + """Creates new label array with infinite points set to -1.""" + new_labels = np.full(num_points, -1) + new_labels[finite_index] = old_labels + return new_labels + + +def _remap_probabilities(old_probs, finite_index, num_points): + """Creates new probability array with infinite points set to 0.""" + new_probs = np.zeros(num_points) + new_probs[finite_index] = old_probs + return new_probs + + +class BranchDetector(BaseEstimator, ClusterMixin): + """Performs a flare-detection post-processing step to detect branches within + clusters [1]_. + + For each cluster, a graph is constructed connecting the data points based on + their mutual reachability distances. Each edge is given a centrality value + based on how far it lies from the cluster's center. Then, the edges are + clustered as if that centrality was a distance, progressively removing the + 'center' of each cluster and seeing how many branches remain. 
+ + Parameters + ---------- + min_branch_size : int, optional (default=None) + The minimum number of samples in a group for that group to be + considered a branch; groupings smaller than this size will be seen as + points falling out of a branch. Defaults to the clusterer's min_cluster_size. + + allow_single_branch : bool, optional (default=False) + Analogous to ``allow_single_cluster``. + + branch_detection_method : str, optional (default='full') + Determines which graph is constructed to detect branches with. Valid + values are, ordered by increasing computation cost and decreasing + sensitivity to noise: + - ``core``: Contains the edges that connect each point to all other + points within a mutual reachability distance lower than or equal to + the point's core distance. This is the cluster's subgraph of the + k-NN graph over the entire data set (with k = ``min_samples``). + - ``full``: Contains all edges between points in each cluster with a + mutual reachability distance lower than or equal to the distance of + the most-distant point in each cluster. These graphs represent the + 0-dimensional simplicial complex of each cluster at the first point in + the filtration where they contain all their points. + + branch_selection_method : str, optional (default='eom') + The method used to select branches from the cluster's condensed tree. + The standard approach for FLASC is to use the ``eom`` approach. + Options are: + * ``eom`` + * ``leaf`` + + branch_selection_persistence : float, optional (default=0.0) + An eccentricity persistence threshold. Branches with a persistence below + this value will be merged. See [3]_ for more information. Note that this + should not be used if we want to predict the cluster labels for new + points in future (e.g. using approximate_predict), as the + :func:`~flasc.prediction.approximate_predict` function is not aware of + this argument.
+ + max_branch_size : int, optional (default=0) + A limit to the size of clusters returned by the ``eom`` algorithm. + Has no effect when using ``leaf`` clustering (where clusters are + usually small regardless). Note that this should not be used if we + want to predict the cluster labels for new points in future (e.g. using + :func:`~flasc.prediction.approximate_predict`), as that function is + not aware of this argument. + + label_sides_as_branches : bool, optional (default=False) + When this flag is False, branches are only labelled for clusters with at + least three branches (i.e., at least y-shapes). Clusters with only two + branches represent l-shapes. The two branches describe the cluster's + outsides growing towards each other. Enabling this flag separates these + branches from each other in the produced labelling. + + Attributes + ---------- + labels_ : np.ndarray, shape (n_samples, ) + Labels that differentiate all subgroups (clusters and branches). Noisy + samples are given the label -1. + + probabilities_ : np.ndarray, shape (n_samples, ) + Probabilities considering both cluster and branch membership. Noisy + samples are assigned 0. + + branch_labels_ : np.ndarray, shape (n_samples, ) + Branch labels for each point. Noisy samples are given the label -1. + + branch_probabilities_ : np.ndarray, shape (n_samples, ) + Branch membership strengths for each point. Noisy samples are + assigned 0. + + branch_persistences_ : tuple (n_clusters) + A branch persistence (eccentricity range) for each detected branch. + + cluster_approximation_graphs_ : tuple (n_clusters) + The graphs used to detect branches in each cluster stored as a numpy + array with four columns: source, target, centrality, mutual reachability + distance. Points are labelled by their row-index into the input data.
+ The edges contained in the graphs depend on the ``branch_detection_method``: + - ``core``: Contains the edges that connect each point to all other + points in a cluster within a mutual reachability distance lower than + or equal to the point's core distance. This is an extension of the + minimum spanning tree introducing only edges with equal distances. The + reachability distance introduces ``num_points`` * ``min_samples`` of + such edges. + - ``full``: Contains all edges between points in each cluster with a + mutual reachability distance lower than or equal to the distance of + the most-distant point in each cluster. These graphs represent the + 0-dimensional simplicial complex of each cluster at the first point in + the filtration where they contain all their points. + + cluster_condensed_trees_ : tuple (n_clusters) + A condensed branch hierarchy for each cluster produced during the + branch detection step. Data points are numbered with in-cluster ids. + + cluster_linkage_trees_ : tuple (n_clusters) + A single linkage tree for each cluster produced during the branch + detection step, in the scipy hierarchical clustering format. + (see http://docs.scipy.org/doc/scipy/reference/cluster.hierarchy.html). + Data points are numbered with in-cluster ids. + + cluster_centralities_ : np.ndarray, shape (n_samples, ) + Centrality values for each point in a cluster. Overemphasizes points' + eccentricity within the cluster as the values are based on minimum + spanning trees that do not contain the equally distanced edges resulting + from the mutual reachability distance. + + cluster_points_ : list (n_clusters) + The data point row indices for each cluster. + + References + ---------- + .. [1] Bot, D. M., Peeters, J., Liesenborgs, J., & Aerts, J. (2023, November). + FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for + Detecting Branches in Clusters.
arXiv:2311.15887 + """ + + def __init__( + self, + min_branch_size=None, + allow_single_branch=False, + branch_detection_method="full", + branch_selection_method="eom", + branch_selection_persistence=0.0, + max_branch_size=0, + label_sides_as_branches=False, + ): + self.min_branch_size = min_branch_size + self.allow_single_branch = allow_single_branch + self.branch_detection_method = branch_detection_method + self.branch_selection_method = branch_selection_method + self.branch_selection_persistence = branch_selection_persistence + self.max_branch_size = max_branch_size + self.label_sides_as_branches = label_sides_as_branches + + self._cluster_approximation_graphs = None + self._cluster_condensed_trees = None + self._cluster_linkage_trees = None + self._branch_exemplars = None + + def fit(self, X, y=None): + """ + Perform a flare-detection post-processing step to detect branches within + clusters. + + Parameters + ---------- + X : HDBSCAN + A fitted HDBSCAN object with branch detection data generated. + + Returns + ------- + self : object + Returns self. + """ + self._clusterer = X + kwargs = self.get_params() + ( + self.labels_, + self.probabilities_, + self.branch_labels_, + self.branch_probabilities_, + self.branch_persistences_, + self._cluster_approximation_graphs, + self._cluster_condensed_trees, + self._cluster_linkage_trees, + self.cluster_centralities_, + self.cluster_points_, + ) = detect_branches_in_clusters(X, **kwargs) + + return self + + def fit_predict(self, X, y=None): + """ + Perform a flare-detection post-processing step to detect branches within + clusters [1]_. + + Parameters + ---------- + X : HDBSCAN + A fitted HDBSCAN object with branch detection data generated. + + Returns + ------- + labels : ndarray, shape (n_samples, ) + subgroup labels differentiated by cluster and branch. + """ + self.fit(X, y) + return self.labels_ + + def weighted_centroid(self, label_id, data=None): + """Provides an approximate representative point for a given branch. 
+ Note that this technique assumes a Euclidean metric for speed of + computation. For more general metrics use the ``weighted_medoid`` method + which is slower, but can work with the metric the model was trained with. + + Parameters + ---------- + label_id : int + The id of the cluster to compute a centroid for. + + data : np.ndarray (n_samples, n_features), optional (default=None) + A dataset to use instead of the raw data that was clustered on. + + Returns + ------- + centroid : array of shape (n_features,) + A representative centroid for cluster ``label_id``. + """ + if self.labels_ is None: + raise AttributeError("Model has not been fit to data") + if self._clusterer._raw_data is None and data is None: + raise AttributeError("Raw data not available") + if label_id == -1: + raise ValueError( + "Cannot calculate weighted centroid for -1 cluster " + "since it is a noise cluster" + ) + if data is None: + data = self._clusterer._raw_data + mask = self.labels_ == label_id + cluster_data = data[mask] + cluster_membership_strengths = self.probabilities_[mask] + + return np.average(cluster_data, weights=cluster_membership_strengths, axis=0) + + def weighted_medoid(self, label_id, data=None): + """Provides an approximate representative point for a given branch. + + Note that this technique can be very slow and memory intensive for large + clusters. The ``weighted_centroid`` method is faster, but assumes a + Euclidean metric. + + Parameters + ---------- + label_id : int + The id of the cluster to compute a medoid for. + + data : np.ndarray (n_samples, n_features), optional (default=None) + A dataset to use instead of the raw data that was clustered on. + + Returns + ------- + medoid : array of shape (n_features,) + A representative medoid for cluster ``label_id``.
+ """ + if self.labels_ is None: + raise AttributeError("Model has not been fit to data") + if self._clusterer._raw_data is None and data is None: + raise AttributeError("Raw data not available") + if label_id == -1: + raise ValueError( + "Cannot calculate weighted medoid for -1 cluster " + "since it is a noise cluster" + ) + if data is None: + data = self._clusterer._raw_data + mask = self.labels_ == label_id + cluster_data = data[mask] + cluster_membership_strengths = self.probabilities_[mask] + + dist_metric = self._clusterer.branch_detection_data_.dist_metric + dist_mat = dist_metric.pairwise(cluster_data) * cluster_membership_strengths + medoid_index = np.argmin(dist_mat.sum(axis=1)) + return cluster_data[medoid_index] + + @property + def cluster_approximation_graph_(self): + """See :class:`~hdbscan.branches.BranchDetector` for documentation.""" + if self._cluster_approximation_graphs is None: + raise AttributeError( + "No cluster approximation graph was generated; try running fit first." + ) + return ApproximationGraph( + self._cluster_approximation_graphs, + self.labels_, + self.probabilities_, + self._clusterer.labels_, + self._clusterer.probabilities_, + self.cluster_centralities_, + self.branch_labels_, + self.branch_probabilities_, + self._clusterer._raw_data, + ) + + @property + def cluster_condensed_trees_(self): + """See :class:`~hdbscan.branches.BranchDetector` for documentation.""" + if self._cluster_condensed_trees is None: + raise AttributeError( + "No cluster condensed trees were generated; try running fit first." + ) + return [ + CondensedTree(tree, self.branch_selection_method, self.allow_single_branch) + for tree in self._cluster_condensed_trees + ] + + @property + def cluster_linkage_trees_(self): + """See :class:`~hdbscan.branches.BranchDetector` for documentation.""" + if self._cluster_linkage_trees is None: + raise AttributeError( + "No cluster linkage trees were generated; try running fit first."
+ ) + return [SingleLinkageTree(tree) for tree in self._cluster_linkage_trees] + + @property + def branch_exemplars_(self): + """See :class:`~hdbscan.branches.BranchDetector` for documentation.""" + if self._branch_exemplars is not None: + return self._branch_exemplars + if self._clusterer._raw_data is None: + raise AttributeError( + "Branch exemplars not available with precomputed " "distances." + ) + if self._cluster_condensed_trees is None: + raise AttributeError("No branches detected; try running fit first.") + + num_clusters = len(self._cluster_condensed_trees) + branch_cluster_trees = [ + branch_tree[branch_tree["child_size"] > 1] + for branch_tree in self._cluster_condensed_trees + ] + selected_branch_ids = [ + sorted(branch_tree._select_clusters()) + for branch_tree in self.cluster_condensed_trees_ + ] + + self._branch_exemplars = [None] * num_clusters + + for i, points in enumerate(self.cluster_points_): + selected_branches = selected_branch_ids[i] + if len(selected_branches) <= (1 if self.label_sides_as_branches else 2): + continue + + self._branch_exemplars[i] = [] + raw_condensed_tree = self._cluster_condensed_trees[i] + + for branch in selected_branches: + _branch_exemplars = np.array([], dtype=np.intp) + for leaf in recurse_leaf_dfs(branch_cluster_trees[i], np.intp(branch)): + leaf_max_lambda = raw_condensed_tree["lambda_val"][ + raw_condensed_tree["parent"] == leaf + ].max() + candidates = raw_condensed_tree["child"][ + (raw_condensed_tree["parent"] == leaf) + & (raw_condensed_tree["lambda_val"] == leaf_max_lambda) + ] + _branch_exemplars = np.hstack([_branch_exemplars, candidates]) + ids = points[_branch_exemplars] + self._branch_exemplars[i].append(self._clusterer._raw_data[ids, :]) + + return self._branch_exemplars + + +def approximate_predict_branch(branch_detector, points_to_predict): + """Predict the cluster and branch label of new points. 
+ + Extends ``approximate_predict`` to also predict in which branch + new points lie (if the cluster they are part of has branches). + + Parameters + ---------- + branch_detector : BranchDetector + A clustering object that has been fit to vector input data. + + points_to_predict : array, or array-like (n_samples, n_features) + The new data points to predict cluster labels for. They should + have the same dimensionality as the original dataset over which + clusterer was fit. + + Returns + ------- + labels : array (n_samples,) + The predicted cluster and branch labels. + + probabilities : array (n_samples,) + The soft cluster scores for each. + + cluster_labels : array (n_samples,) + The predicted cluster labels. + + cluster_probabilities : array (n_samples,) + The soft cluster scores for each. + + branch_labels : array (n_samples,) + The predicted branch labels. + + branch_probabilities : array (n_samples,) + The soft branch scores for each. + """ + + cluster_labels, cluster_probabilities, connecting_points = approximate_predict( + branch_detector._clusterer, points_to_predict, return_connecting_points=True + ) + + num_predict = len(points_to_predict) + labels = np.empty(num_predict, dtype=np.intp) + probabilities = np.zeros(num_predict, dtype=np.double) + branch_labels = np.zeros(num_predict, dtype=np.intp) + branch_probabilities = np.ones(num_predict, dtype=np.double) + + min_num_branches = 2 if not branch_detector.label_sides_as_branches else 1 + for i, (label, prob, connecting_point) in enumerate( + zip(cluster_labels, cluster_probabilities, connecting_points) + ): + if label < 0: + labels[i] = -1 + elif len(branch_detector.branch_persistences_[label]) <= min_num_branches: + labels[i] = label + probabilities[i] = prob + else: + labels[i] = branch_detector.labels_[connecting_point] + branch_labels[i] = branch_detector.branch_labels_[connecting_point] + branch_probabilities[i] = branch_detector.branch_probabilities_[ + connecting_point + ] + probabilities[i] = 
(prob + branch_probabilities[i]) / 2 + return ( + labels, + probabilities, + cluster_labels, + cluster_probabilities, + branch_labels, + branch_probabilities, + ) + + +class SequentialPool: + """API of a Joblib Parallel pool but sequential execution""" + + def __init__(self): + self.n_jobs = 1 + + def __call__(self, jobs): + return [fun(*args, **kwargs) for (fun, args, kwargs) in jobs] diff --git a/hdbscan/flat.py b/hdbscan/flat.py index e5912266..fa84806d 100644 --- a/hdbscan/flat.py +++ b/hdbscan/flat.py @@ -340,7 +340,7 @@ def approximate_predict_flat(clusterer, k=2 * min_samples) for i in range(points_to_predict.shape[0]): - label, prob = _find_cluster_and_probability( + label, prob, neighbors = _find_cluster_and_probability( condensed_tree, prediction_data.cluster_tree, neighbor_indices[i], diff --git a/hdbscan/hdbscan_.py b/hdbscan/hdbscan_.py index c87b69d9..290b47fa 100644 --- a/hdbscan/hdbscan_.py +++ b/hdbscan/hdbscan_.py @@ -36,6 +36,7 @@ from .plots import CondensedTree, SingleLinkageTree, MinimumSpanningTree from .prediction import PredictionData +from .branches import BranchDetectionData KDTREE_VALID_METRICS = ["euclidean", "l2", "minkowski", "p", "manhattan", "cityblock", "l1", "chebyshev", "infinity"] BALLTREE_VALID_METRICS = KDTREE_VALID_METRICS + [ @@ -996,6 +997,11 @@ class HDBSCAN(BaseEstimator, ClusterMixin): to set this to True. (default False) + branch_detection_data : boolean, optional + Whether to generate extra cached data for detecting branch- + hierarchies within clusters. If you wish to use functions from + ``hdbscan.branches`` set this to True. 
(default False) + match_reference_implementation : bool, optional (default=False) There exist some interpretational differences between this HDBSCAN* implementation and the original authors reference @@ -1053,6 +1059,10 @@ class HDBSCAN(BaseEstimator, ClusterMixin): :func:`~hdbscan.prediction.membership_vector`, and :func:`~hdbscan.prediction.all_points_membership_vectors`). 
+ branch_detection_data_ : BranchDetectionData object + Cached data used for detecting branch-hierarchies within clusters. + Necessary only if you are using functions from ``hdbscan.branches``. + exemplars_ : list A list of exemplar points for clusters. Since HDBSCAN supports arbitrary shapes for clusters we cannot provide a single cluster @@ -1115,6 +1125,7 @@ def __init__( cluster_selection_method="eom", allow_single_cluster=False, prediction_data=False, + branch_detection_data=False, match_reference_implementation=False, **kwargs ): @@ -1135,6 +1146,7 @@ def __init__( self.allow_single_cluster = allow_single_cluster self.match_reference_implementation = match_reference_implementation self.prediction_data = prediction_data + self.branch_detection_data = branch_detection_data self._metric_kwargs = kwargs @@ -1144,6 +1156,8 @@ def __init__( self._raw_data = None self._outlier_scores = None self._prediction_data = None + self._finite_index = None + self._branch_detection_data = None self._relative_validity = None def fit(self, X, y=None): @@ -1171,12 +1185,12 @@ def fit(self, X, y=None): if ~self._all_finite: # Pass only the purely finite indices into hdbscan # We will later assign all non-finite points to the background -1 cluster - finite_index = get_finite_row_indices(X) - clean_data = X[finite_index] + self._finite_index = get_finite_row_indices(X) + clean_data = X[self._finite_index] internal_to_raw = { - x: y for x, y in zip(range(len(finite_index)), finite_index) + x: y for x, y in zip(range(len(self._finite_index)), self._finite_index) } - outliers = list(set(range(X.shape[0])) - set(finite_index)) + outliers = list(set(range(X.shape[0])) - set(self._finite_index)) else: clean_data = X elif issparse(X): @@ -1193,7 +1207,9 @@ def fit(self, X, y=None): # prediction data only applies to the persistent model, so remove # it from the keyword args we pass on the the function kwargs.pop("prediction_data", None) + kwargs.pop("branch_detection_data", None) 
kwargs.update(self._metric_kwargs) + kwargs['gen_min_span_tree'] |= self.branch_detection_data ( self.labels_, @@ -1213,15 +1229,17 @@ def fit(self, X, y=None): self._single_linkage_tree, internal_to_raw, outliers ) new_labels = np.full(X.shape[0], -1) - new_labels[finite_index] = self.labels_ + new_labels[self._finite_index] = self.labels_ self.labels_ = new_labels new_probabilities = np.zeros(X.shape[0]) - new_probabilities[finite_index] = self.probabilities_ + new_probabilities[self._finite_index] = self.probabilities_ self.probabilities_ = new_probabilities if self.prediction_data: self.generate_prediction_data() + if self.branch_detection_data: + self.generate_branch_detection_data() return self @@ -1275,6 +1293,38 @@ def generate_prediction_data(self): "than mere distances is required!" ) + def generate_branch_detection_data(self): + """ + Create data that caches intermediate results used for detecting + branches within clusters. This data is only useful if you are + intending to use functions from ``hdbscan.branches``. + """ + if self.metric in FAST_METRICS: + min_samples = self.min_samples or self.min_cluster_size + if self.metric in KDTREE_VALID_METRICS: + tree_type = "kdtree" + elif self.metric in BALLTREE_VALID_METRICS: + tree_type = "balltree" + else: + warn("Metric {} not supported for branch detection!".format(self.metric)) + return + + self._branch_detection_data = BranchDetectionData( + self._raw_data, + self._all_finite, + None if self._all_finite else self._finite_index, + self.labels_, + min_samples, + tree_type=tree_type, + metric=self.metric, + **self._metric_kwargs + ) + else: + warn( + "Branch detection for non-vector space inputs is not (yet)" + " implemented." + ) + def weighted_cluster_centroid(self, cluster_id): """Provide an approximate representative point for a given cluster. 
Note that this technique assumes a euclidean metric for speed of @@ -1385,6 +1435,13 @@ def prediction_data_(self): else: return self._prediction_data + @property + def branch_detection_data_(self): + if self._branch_detection_data is None: + raise AttributeError("No branch detection data was generated") + else: + return self._branch_detection_data + @property def outlier_scores_(self): if self._outlier_scores is not None: diff --git a/hdbscan/plots.py b/hdbscan/plots.py index 617721e5..1328aceb 100644 --- a/hdbscan/plots.py +++ b/hdbscan/plots.py @@ -898,3 +898,413 @@ def to_networkx(self): set_node_attributes(result, data_dict, 'data') return result + + +class ApproximationGraph: + """ + Cluster approximation graph describing the connectivity in clusters + that is used to detect branches. + + Parameters + ---------- + approximation_graphs : list[np.ndarray], shape (n_clusters), + + labels : np.ndarray, shape (n_samples, ) + cluster and branches labelling. + + probabilities : np.ndarray, shape (n_samples, ) + cluster and branches probabilities. + + cluster_labels : np.ndarray, shape (n_samples, ) + HDBSCAN* labelling. + + cluster_probabilities : np.ndarray, shape (n_samples, ) + HDBSCAN* probabilities. + + cluster_centralities : np.ndarray, shape (n_samples, ) + Within cluster centrality values. + + branch_labels : np.ndarray, shape (n_samples, ) + Within cluster branch labels for each point. + + branch_probabilities : np.ndarray, shape (n_samples, ) + Within cluster branch membership strengths for each point. + + Attributes + ---------- + point_mask : np.ndarray[bool], shape (n_samples) + A mask to extract points within clusters from the raw data. 
+ """ + + def __init__( + self, + approximation_graphs, + labels, + probabilities, + cluster_labels, + cluster_probabilities, + cluster_centralities, + branch_labels, + branch_probabilities, + raw_data=None, + ): + self._edges = np.core.records.fromarrays( + np.hstack( + ( + np.concatenate(approximation_graphs), + np.repeat( + np.arange(len(approximation_graphs)), + [g.shape[0] for g in approximation_graphs], + )[None].T, + ) + ).transpose(), + names="parent, child, centrality, mutual_reachability, cluster", + formats="intp, intp, double, double, intp", + ) + self.point_mask = cluster_labels >= 0 + self._raw_data = raw_data[self.point_mask, :] if raw_data is not None else None + self._points = np.core.records.fromarrays( + np.vstack( + ( + np.where(self.point_mask)[0], + labels[self.point_mask], + probabilities[self.point_mask], + cluster_labels[self.point_mask], + cluster_probabilities[self.point_mask], + cluster_centralities[self.point_mask], + branch_labels[self.point_mask], + branch_probabilities[self.point_mask], + ) + ), + names="id, label, probability, cluster_label, cluster_probability, cluster_centrality, branch_label, branch_probability", + formats="intp, intp, double, intp, double, double, intp, double", + ) + self._pos = None + + def plot( + self, + positions=None, + feature_names=None, + node_color="label", + node_vmin=None, + node_vmax=None, + node_cmap="viridis", + node_alpha=1, + # node_desat=None, + node_size=1, + node_marker="o", + edge_color="k", + edge_vmin=None, + edge_vmax=None, + edge_cmap="viridis", + edge_alpha=1, + edge_width=1, + ): + """ + Plots the Approximation graph, requires networkx and matplotlib. + + Parameters + ---------- + positions : np.ndarray, shape (n_samples, 2) (default = None) + A position for each data point in the graph or each data point in the + raw data. When None, the function attempts to compute graphviz' + sfdp layout, which requires pygraphviz to be installed and available. 
+ + node_color : str (default = 'label') + The point attribute to color the nodes by. Possible values: + - id + - label + - probability + - cluster_label + - cluster_probability + - cluster_centrality + - branch_label + - branch_probability, + - The input data's feature (if available) names if + ``feature_names`` is specified or ``feature_x`` for the x-th feature + if no ``feature_names`` are given, or anything matplotlib scatter + interprets as a color. + + node_vmin : float, (default = None) + The minimum value to use for normalizing node colors. + + node_vmax : float, (default = None) + The maximum value to use for normalizing node colors. + + node_cmap : str, (default = 'viridis') + The cmap to use for coloring nodes. + + node_alpha : float, (default = 1) + The node transparency value. + + node_size : float, (default = 1) + The node marker size value. + + node_marker : str, (default = 'o') + The node marker string. + + edge_color : str (default = 'k') + The edge attribute to color the edges by. Possible values: + - weight + - mutual_reachability + - centrality, + - cluster, + or anything matplotlib linecollection interprets as color. + + edge_vmin : float, (default = None) + The minimum value to use for normalizing edge colors. + + edge_vmax : float, (default = None) + The maximum value to use for normalizing edge colors. + + edge_cmap : str, (default = 'viridis') + The cmap to use for coloring edges. + + edge_alpha : float, (default = 1) + The edge transparency value. + + edge_width : float, (default = 1) + The edge line width size value. + """ + try: + import matplotlib.pyplot as plt + import matplotlib.collections as mc + except ImportError: + raise ImportError( + "You must install the matplotlib library to plot the Approximation Graph." 
+ ) + + # Extract node color data + if node_color is None: + pass + elif isinstance(node_color, str): + if node_color in self._points.dtype.names: + if "label" in node_color: + node_vmax = 9 + node_vmin = 0 + node_cmap = "tab10" + node_color = self._points[node_color] % 10 + else: + node_color = self._points[node_color] + elif ( + self._raw_data is not None + and feature_names is not None + and node_color in feature_names + ): + idx = feature_names.index(node_color) + node_color = self._raw_data[:, idx] + elif self._raw_data is not None and node_color.startswith("feature_"): + idx = int(node_color[8:]) + node_color = self._raw_data[:, idx] + elif len(node_color) == len(self.point_mask): + node_color = node_color[self.point_mask] + + # Extract edge color data + if isinstance(edge_color, str) and edge_color in self._edges.dtype.names: + edge_color = self._edges[edge_color] + + # Compute or extract layout + self._xs = np.nan * np.ones(len(self.point_mask)) + self._ys = np.nan * np.ones(len(self.point_mask)) + if positions is None: + try: + import networkx as nx + except ImportError: + raise ImportError( + "You must install the networkx to compute a sfdp layout." 
+ ) + if self._pos is None: + g = nx.Graph() + for row in self._edges: + g.add_edge( + row["parent"], + row["child"], + weight=1 / row["mutual_reachability"], + ) + self._pos = nx.nx_agraph.graphviz_layout(g, prog="sfdp") + for k, v in self._pos.items(): + self._xs[k] = v[0] + self._ys[k] = v[1] + else: + if positions.shape[0] == len(self.point_mask): + self._xs = positions[:, 0] + self._ys = positions[:, 1] + elif positions.shape[0] == len(self._points): + for i, d in enumerate(self._points["id"]): + self._xs[d] = positions[i, 0] + self._ys[d] = positions[i, 1] + else: + raise ValueError("Incorrect number of positions specified.") + source = self._edges["parent"] + target = self._edges["child"] + lc = mc.LineCollection( + list( + zip( + zip(self._xs[source], self._ys[source]), + zip(self._xs[target], self._ys[target]), + ) + ), + alpha=edge_alpha, + cmap=edge_cmap, + linewidths=edge_width, + zorder=0, + ) + lc.set_clim(edge_vmin, edge_vmax) + if isinstance(edge_color, str): + lc.set_edgecolor(edge_color) + else: + lc.set_array(edge_color) + if edge_alpha is not None: + lc.set_alpha(edge_alpha) + plt.gca().add_collection(lc) + plt.scatter( + self._xs[~self.point_mask], + self._ys[~self.point_mask], + node_size, + color='silver', + marker=node_marker, + alpha=node_alpha, + linewidth=0, + edgecolor='none', + ) + plt.scatter( + self._xs[self.point_mask], + self._ys[self.point_mask], + node_size, + node_color, + cmap=node_cmap, + marker=node_marker, + alpha=node_alpha, + linewidth=0, + edgecolor='none', + vmin=node_vmin, + vmax=node_vmax, + ) + plt.axis("off") + + def to_numpy(self): + """Converts the approximation graph to numpy arrays. 
+ + Returns + ------- + points : np.recarray, shape (n_points, 8) + A numpy record array with for each point its: + - id (row index), + - label, + - probability, + - cluster label, + - cluster probability, + - cluster centrality, + - branch label, + - branch probability + + edges : np.recarray, shape (n_edges, 5) + A numpy record array with for each edge its: + - parent point, + - child point, + - cluster centrality, + - mutual reachability, + - cluster label + """ + return self._points.copy(), self._edges.copy() + + def to_pandas(self): + """Converts the approximation graph to pandas data frames. + + Returns + ------- + points : pd.DataFrame, shape (n_points, 8) + A DataFrame with for each point its: + - id (row index), + - label, + - probability, + - cluster label, + - cluster probability, + - cluster centrality, + - branch label, + - branch probability + + edges : pd.DataFrame, shape (n_edges, 5) + A DataFrame with for each edge its: + - parent point, + - child point, + - cluster centrality, + - mutual reachability, + - cluster label + """ + try: + from pandas import DataFrame + except ImportError: + raise ImportError( + "You must have pandas installed to export pandas DataFrames" + ) + + points = DataFrame(self._points) + edges = DataFrame(self._edges) + return points, edges + + def to_networkx(self, feature_names=None): + """Convert to a NetworkX Graph object. + + Parameters + ---------- + feature_names : list[n_features] + Names to use for the data features if available. + + Returns + ------- + g : nx.Graph + A NetworkX Graph object containing the non-noise points and edges + within clusters. 
+ + Node attributes: + - label, + - probability, + - cluster label, + - cluster probability, + - cluster centrality, + - branch label, + - branch probability, + + Edge attributes: + - weight (1 / mutual_reachability), + - mutual_reachability, + - centrality, + - cluster label, + - + """ + try: + import networkx as nx + except ImportError: + raise ImportError( + "You must have networkx installed to export networkx graphs" + ) + + g = nx.Graph() + # Add edges + for row in self._edges: + g.add_edge( + row["parent"], + row["child"], + weight=1 / row["mutual_reachability"], + mutual_reachability=row["mutual_reachability"], + centrality=row["centrality"], + cluster=row["cluster"], + ) + + # Add FLASC features + for attr in self._points.dtype.names[1:]: + nx.set_node_attributes(g, dict(self._points[["id", attr]]), attr) + + # Add raw data features + if self._raw_data is not None: + if feature_names is None: + feature_names = [f"feature {i}" for i in range(self._raw_data.shape[1])] + for idx, name in enumerate(feature_names): + nx.set_node_attributes( + g, + dict(zip(self._points["id"], self._raw_data[:, idx])), + name, + ) + + return g \ No newline at end of file diff --git a/hdbscan/prediction.py b/hdbscan/prediction.py index 10cd6c60..b5b1fc52 100644 --- a/hdbscan/prediction.py +++ b/hdbscan/prediction.py @@ -325,10 +325,10 @@ def _find_cluster_and_probability(tree, cluster_tree, neighbor_indices, else: prob = 0.0 - return cluster_label, prob + return cluster_label, prob, nearest_neighbor -def approximate_predict(clusterer, points_to_predict): +def approximate_predict(clusterer, points_to_predict, return_connecting_points=False): """Predict the cluster label of new points. The returned labels will be those of the original clustering found by ``clusterer``, and therefore are not (necessarily) the cluster labels that would @@ -352,6 +352,10 @@ def approximate_predict(clusterer, points_to_predict): The new data points to predict cluster labels for. 
They should have the same dimensionality as the original dataset over which clusterer was fit. + + return_connecting_points : bool, optional + Whether to return the index of the nearest neighbor in the original + dataset for each of the ``points_to_predict``. Default is False Returns ------- @@ -361,6 +365,11 @@ def approximate_predict(clusterer, points_to_predict): probabilities : array (n_samples,) The soft cluster scores for each of the ``points_to_predict`` + neighbors : array (n_samples,) + The index of the nearest neighbor in the original dataset for each + of the ``points_to_predict``. Only returned if + ``return_connecting_points=True``. + See Also -------- :py:func:`hdbscan.predict.membership_vector` @@ -383,10 +392,15 @@ def approximate_predict(clusterer, points_to_predict): ' will be automatically predicted as noise.') labels = -1 * np.ones(points_to_predict.shape[0], dtype=np.int32) probabilities = np.zeros(points_to_predict.shape[0], dtype=np.float32) + if return_connecting_points: + neighbors = -1 * np.ones(points_to_predict.shape[0], dtype=np.int32) + return labels, probabilities, neighbors return labels, probabilities labels = np.empty(points_to_predict.shape[0], dtype=np.int32) probabilities = np.empty(points_to_predict.shape[0], dtype=np.float64) + if return_connecting_points: + neighbors = np.empty(points_to_predict.shape[0], dtype=np.int32) min_samples = clusterer.min_samples or clusterer.min_cluster_size neighbor_distances, neighbor_indices = \ @@ -394,7 +408,7 @@ def approximate_predict(clusterer, points_to_predict): k=2 * min_samples) for i in range(points_to_predict.shape[0]): - label, prob = _find_cluster_and_probability( + label, prob, neighbor = _find_cluster_and_probability( clusterer.condensed_tree_, clusterer.prediction_data_.cluster_tree, neighbor_indices[i], @@ -406,7 +420,11 @@ def approximate_predict(clusterer, points_to_predict): ) labels[i] = label probabilities[i] = prob + if return_connecting_points: + neighbors[i] = neighbor + 
if return_connecting_points: + return labels, probabilities, neighbors return labels, probabilities diff --git a/hdbscan/tests/test_branches.py b/hdbscan/tests/test_branches.py new file mode 100644 index 00000000..5a6d9a36 --- /dev/null +++ b/hdbscan/tests/test_branches.py @@ -0,0 +1,474 @@ +import numpy as np +from scipy import stats +from scipy import sparse +from scipy.spatial import distance +from sklearn.utils._testing import assert_raises +from sklearn.utils.estimator_checks import check_estimator +from hdbscan import ( + HDBSCAN, + BranchDetector, + detect_branches_in_clusters, + approximate_predict_branch, +) +from hdbscan.tests.test_hdbscan import ( + if_matplotlib, + if_networkx, + if_pandas, +) + +from sklearn.utils import check_random_state, shuffle as util_shuffle +from sklearn.datasets import make_blobs +from sklearn.preprocessing import StandardScaler + +from tempfile import mkdtemp +from functools import wraps +import numbers +import pytest + +import warnings + + +def if_pygraphviz(func): + """Test decorator that skips test if networkx or pygraphviz is not installed.""" + + @wraps(func) + def run_test(*args, **kwargs): + try: + import networkx + import pygraphviz + except ImportError: + pytest.skip("NetworkX or pygraphviz not available.") + else: + return func(*args, **kwargs) + + return run_test + + +def make_branches(n_samples=100, shuffle=True, noise=None, random_state=None): + if isinstance(n_samples, numbers.Integral): + n_samples_out = n_samples // 3 + n_samples_in = n_samples - n_samples_out + else: + try: + n_samples_out, n_samples_in = n_samples + except ValueError as e: + raise ValueError( + "`n_samples` can be either an int or a two-element tuple." 
+ ) from e + + generator = check_random_state(random_state) + + outer_circ_x = np.cos(np.linspace(np.pi / 2, np.pi, n_samples_out)) + outer_circ_y = np.sin(np.linspace(np.pi / 2, np.pi, n_samples_out)) - 1 + inner_circ_x = np.cos(np.linspace(0, np.pi, n_samples_in)) + inner_circ_y = 1 - np.sin(np.linspace(0, np.pi, n_samples_in)) + + X = np.vstack( + [ + np.append(outer_circ_x, inner_circ_x), + np.append(outer_circ_y, inner_circ_y), + ] + ).T + y = np.hstack( + [ + np.zeros(n_samples_out, dtype=np.intp), + np.ones(n_samples_in, dtype=np.intp), + ] + ) + + if shuffle: + X, y = util_shuffle(X, y, random_state=generator) + + if noise is not None: + X += generator.normal(scale=noise, size=X.shape) + + return X, y + + +def generate_noisy_data(): + blobs, yBlobs = make_blobs( + n_samples=50, + centers=[(-0.75, 2.25), (2.0, -0.5)], + cluster_std=0.2, + random_state=3, + ) + moons, _ = make_branches(n_samples=150, noise=0.06, random_state=3) + yMoons = np.full(moons.shape[0], 2) + np.random.seed(5) + noise = np.random.uniform(-1.0, 3.0, (50, 2)) + yNoise = np.full(50, -1) + return ( + np.vstack((blobs, moons, noise)), + np.concatenate((yBlobs, yMoons, yNoise)), + ) + + +X, y = generate_noisy_data() +X = StandardScaler().fit_transform(X) + +X_missing_data = X.copy() +X_missing_data[0] = [np.nan, 1] +X_missing_data[5] = [np.nan, np.nan] + +# --- Branch Detection Data + + +def test_branch_detection_data(): + """Check that the flag generates internal branch_detection_data.""" + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + branch_data = c.branch_detection_data_ + assert c.minimum_spanning_tree_ is not None + assert branch_data.all_finite == True + assert branch_data.core_distances.shape[0] == X.shape[0] + assert branch_data.neighbors.shape[0] == X.shape[0] + assert branch_data.neighbors.shape[1] == c.min_samples or c.min_cluster_size + assert branch_data.finite_index is None + + +def test_branch_detection_data_with_missing(): + """Check internal 
branch_detection_data recognizes missing data.""" + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X_missing_data) + branch_data = c.branch_detection_data_ + assert c.minimum_spanning_tree_ is not None + assert branch_data.all_finite == False + assert branch_data.core_distances.shape[0] == X.shape[0] - 2 + assert branch_data.neighbors.shape[0] == X.shape[0] - 2 + assert branch_data.neighbors.shape[1] == c.min_samples or c.min_cluster_size + assert branch_data.finite_index is not None + + +@pytest.mark.skip(reason="Unreachable code-branch cannot be tested.") +def test_branch_detection_data_with_non_tree_metric(): + """Check warning on unsupported metric.""" + with warnings.catch_warnings(record=True) as w: + # There are no fast metrics that are not supported by KDTree or BallTree! + # Cosine and arccos both crash HDBSCAN. They go down the BallTree path, but + # the implementation does not support them. + c = HDBSCAN( + min_cluster_size=5, branch_detection_data=True, metric="cosine" + ).fit(X) + assert "Metric cosine not supported for branch detection!" in str(w[-1].message) + assert c.minimum_spanning_tree_ is not None + assert_raises(AttributeError, lambda: c.branch_detection_data_) + + +def test_branch_detection_data_with_unsupported_input(): + """Check warning on unsupported inputs.""" + # Distance matrix + D = distance.squareform(distance.pdist(X)) + with warnings.catch_warnings(record=True) as w: + c = HDBSCAN( + min_cluster_size=5, metric="precomputed", branch_detection_data=True + ).fit(D) + assert ( + "Branch detection for non-vector space inputs is not (yet) implemented." 
+ in str(w[-1].message) + ) + + # Sparse matrix + D /= np.max(D) + threshold = stats.scoreatpercentile(D.flatten(), 50) + D[D >= threshold] = 0.0 + D = sparse.csr_matrix(D) + D.eliminate_zeros() + with warnings.catch_warnings(record=True) as w: + c = HDBSCAN( + min_cluster_size=5, metric="precomputed", branch_detection_data=True + ).fit(D) + assert ( + "Branch detection for non-vector space inputs is not (yet) implemented." + in str(w[-1].message) + ) + + +def test_generate_branch_detection_data(): + """Generate branch detection data function does not re-generate MST.""" + c = HDBSCAN(min_cluster_size=5).fit(X) + c.generate_branch_detection_data() + assert c.branch_detection_data_ is not None + assert_raises(AttributeError, lambda: c.minimum_spanning_tree_) + + +# --- Detecting Branches + + +def check_detected_groups(c, n_clusters=3, n_branches=6): + """Checks branch_detector output for main invariants.""" + assert len(np.unique(c.labels_)) - int(-1 in c.labels_) == n_branches + noise_mask = c.labels_ == -1 + assert (c.branch_labels_[noise_mask] == 0).all() + assert (c.branch_probabilities_[noise_mask] == 1.0).all() + assert (c.probabilities_[noise_mask] == 0).all() + assert len(c.cluster_points_) == n_clusters + assert len(c.branch_persistences_) == n_clusters + assert sum(len(ps) for ps in c.branch_persistences_) >= (n_branches - n_clusters) + + +def test_branch_detector(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector( + branch_detection_method="core", branch_selection_method="eom" + ).fit(c) + check_detected_groups(b, n_branches=7) + + b = BranchDetector( + branch_detection_method="full", branch_selection_method="eom" + ).fit(c) + check_detected_groups(b) + + b = BranchDetector( + branch_detection_method="core", branch_selection_method="leaf" + ).fit(c) + check_detected_groups(b, n_branches=9) + + b = BranchDetector( + branch_detection_method="full", branch_selection_method="leaf" + ).fit(c) + check_detected_groups(b) 
+ + +def test_min_branch_size(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector(min_branch_size=7).fit(c) + labels, counts = np.unique(b.labels_, return_counts=True) + assert (counts[labels >= 0] >= 7).all() + check_detected_groups(b) + + +def test_label_sides_as_branches(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector(label_sides_as_branches=True).fit(c) + check_detected_groups(b, n_branches=8) + + +def test_max_branch_size(): + """Suppresses one branch.""" + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector(label_sides_as_branches=True, max_branch_size=50).fit(c) + check_detected_groups(b, n_branches=7) + + +def test_allow_single_branch_with_persistence(): + # Generate single-cluster data + np.random.seed(0) + no_structure = np.random.rand(150, 2) + c = HDBSCAN( + min_samples=5, + min_cluster_size=150, + allow_single_cluster=True, + branch_detection_data=True, + ).fit(no_structure) + + # Without persistence, find 6 branches + b = BranchDetector( + min_branch_size=5, + branch_detection_method="core", + branch_selection_method="leaf", + ).fit(c) + unique_labels = np.unique(b.labels_) + assert len(unique_labels) == 6 + # Mac & Windows give 71, Linux gives 72. Probably different random values. 
+ num_noise = np.sum(b.branch_probabilities_ == 0) + assert (num_noise == 71) | (num_noise == 72) + + # Adding persistence removes some branches + b = BranchDetector( + min_branch_size=5, + branch_detection_method="core", + branch_selection_method="leaf", + branch_selection_persistence=0.1, + ).fit(c) + unique_labels = np.unique(b.labels_) + assert len(unique_labels) == 1 + assert np.sum(b.branch_probabilities_ == 0) == 0 + + +def test_badargs(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + c_nofit = HDBSCAN(min_cluster_size=5, branch_detection_data=True) + c_nobranch = HDBSCAN(min_cluster_size=5, gen_min_span_tree=True).fit(X) + c_nomst = HDBSCAN(min_cluster_size=5).fit(X) + c_nomst.generate_branch_detection_data() + + assert_raises(AttributeError, detect_branches_in_clusters, "fail") + assert_raises(AttributeError, detect_branches_in_clusters, None) + assert_raises(AttributeError, detect_branches_in_clusters, "fail") + assert_raises(ValueError, detect_branches_in_clusters, c_nofit) + assert_raises(AttributeError, detect_branches_in_clusters, c_nobranch) + assert_raises(ValueError, detect_branches_in_clusters, c_nomst) + assert_raises(ValueError, detect_branches_in_clusters, c, min_branch_size=-1) + assert_raises(ValueError, detect_branches_in_clusters, c, min_branch_size=0) + assert_raises(ValueError, detect_branches_in_clusters, c, min_branch_size=1) + assert_raises(ValueError, detect_branches_in_clusters, c, min_branch_size=2.0) + assert_raises(ValueError, detect_branches_in_clusters, c, min_branch_size="fail") + assert_raises( + ValueError, detect_branches_in_clusters, c, branch_selection_persistence=-1 + ) + assert_raises( + ValueError, detect_branches_in_clusters, c, branch_selection_persistence=-0.1 + ) + assert_raises( + ValueError, + detect_branches_in_clusters, + c, + branch_selection_method="something_else", + ) + assert_raises( + ValueError, + detect_branches_in_clusters, + c, + branch_detection_method="something_else", + ) + 
+ +# --- Branch Detector Functionality + + +def test_caching(): + cachedir = mkdtemp() + c = HDBSCAN(memory=cachedir, min_samples=5, branch_detection_data=True).fit(X) + b1 = BranchDetector().fit(c) + b2 = BranchDetector(allow_single_branch=True).fit(c) + n_groups1 = len(set(b1.labels_)) - int(-1 in b1.labels_) + n_groups2 = len(set(b2.labels_)) - int(-1 in b2.labels_) + assert n_groups1 == n_groups2 + + +def test_centroid_medoids(): + branch_centers = np.asarray( + [[-0.9, -1.0], [-0.9, 0.1], [-0.8, 1.9], [-0.5, 0.0], [1.7, -0.9]] + ) + + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector().fit(c) + + centroids = np.asarray([b.weighted_centroid(i) for i in range(5)]) + rounded = np.around(np.asarray(centroids), decimals=1) + corder = np.lexsort((rounded[:, 1], rounded[:, 0])) + np.all(np.abs(centroids[corder, :] - branch_centers) < 0.1) + + medoids = np.asarray([b.weighted_medoid(i) for i in range(5)]) + rounded = np.around(np.asarray(medoids), decimals=1) + corder = np.lexsort((rounded[:, 1], rounded[:, 0])) + np.all(np.abs(medoids[corder, :] - branch_centers) < 0.1) + + +def test_exemplars(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector().fit(c) + + branch_exemplars = b.branch_exemplars_ + assert branch_exemplars[0] is None + assert branch_exemplars[1] is None + assert len(branch_exemplars[2]) == 3 + assert len(b.branch_exemplars_) == 3 + + +def test_approximate_predict(): + c = HDBSCAN( + min_cluster_size=5, branch_detection_data=True, prediction_data=True + ).fit(X) + b = BranchDetector().fit(c) + + # A point on a branch (not noise) exact labels change per run + l, p, cl, cp, bl, bp = approximate_predict_branch(b, np.array([[-0.8, 0.0]])) + assert cl[0] > -1 + assert len(b.branch_persistences_[cl[0]]) > 2 + + # A point in a cluster + l, p, cl, cp, bl, bp = approximate_predict_branch(b, np.array([[-0.8, 2.0]])) + assert l[0] == cl[0] + assert bl[0] == 0 + assert bp[0] == 1.0 + + # 
A noise point + l, p, cl, cp, bl, bp = approximate_predict_branch(b, np.array([[1, 3.0]])) + assert l[0] == -1 + assert cl[0] == -1 + assert cp[0] == 0 + assert p[0] == 0.0 + assert cp[0] == 0.0 + assert bp[0] == 1.0 + + +# --- Attribute Output Formats + + +def test_trees_numpy_output_formats(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector().fit(c) + points, edges = b.cluster_approximation_graph_.to_numpy() + assert points.shape[0] <= X.shape[0] # Excludes noise points + for t in b.cluster_condensed_trees_: + t.to_numpy() + for t in b.cluster_linkage_trees_: + t.to_numpy() + + +def test_trees_pandas_output_formats(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector().fit(c) + if_pandas(b.cluster_approximation_graph_.to_pandas)() + for t in b.cluster_condensed_trees_: + if_pandas(t.to_pandas)() + for t in b.cluster_linkage_trees_: + if_pandas(t.to_pandas)() + + +def test_trees_networkx_output_formats(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector().fit(c) + if_networkx(b.cluster_approximation_graph_.to_networkx)() + for t in b.cluster_condensed_trees_: + if_networkx(t.to_networkx)() + for t in b.cluster_linkage_trees_: + if_networkx(t.to_networkx)() + + +# --- Attribute plots + + +def test_condensed_tree_plot(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector().fit(c) + for t in b.cluster_condensed_trees_: + if_matplotlib(t.plot)( + select_clusters=True, + label_clusters=True, + selection_palette=("r", "g", "b"), + cmap="Reds", + ) + if_matplotlib(t.plot)(log_size=True, colorbar=False, cmap="none") + + +def test_single_linkage_tree_plot(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector().fit(c) + for t in b.cluster_linkage_trees_: + if_matplotlib(t.plot)(cmap="Reds") + if_matplotlib(t.plot)( + vary_line_width=False, + truncate_mode="lastp", + p=10, + 
cmap="none", + colorbar=False, + ) + + +def test_cluster_approximation_graph_plot(): + c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) + b = BranchDetector().fit(c) + g = b.cluster_approximation_graph_ + if_matplotlib(g.plot)(positions=X) + if_pygraphviz(if_matplotlib(g.plot))(node_color="x", feature_names=["x", "y"]) + if_pygraphviz(if_matplotlib(g.plot))(node_color=X[:, 0]) + if_pygraphviz(if_matplotlib(g.plot))(edge_color="centrality", node_alpha=0) + if_pygraphviz(if_matplotlib(g.plot))( + edge_color=g._edges["centrality"], node_alpha=0 + ) + + +@pytest.mark.skip(reason="need to refactor to meet newer standards") +def test_branch_detector_is_sklearn_estimator(): + check_estimator(BranchDetector) + diff --git a/notebooks/How to detect branches.ipynb b/notebooks/How to detect branches.ipynb new file mode 100644 index 00000000..219ae57a --- /dev/null +++ b/notebooks/How to detect branches.ipynb @@ -0,0 +1,538 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to detect branches in clusters\n", + "\n", + "HDBSCAN\\* is often used to find subpopulations in exploratory data analysis\n", + "workflows. Not only clusters themselves, but also their shape can represent\n", + "meaningful subpopulations. For example, a Y-shaped cluster may represent an\n", + "evolving process with two distinct end-states. 
Detecting these branches can\n", + "reveal interesting patterns that are not captured by density-based clustering.\n", + "\n", + "For example, HDBSCAN\\* finds 4 clusters in the datasets below, which does not\n", + "inform us of the branching structure:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from hdbscan import HDBSCAN, BranchDetector" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [], + "source": [ + "def plot(labels):\n", + " \"\"\"Plots the data coloured by labels, with noise points in silver.\"\"\"\n", + " noise_mask = labels == -1\n", + " plt.scatter(data[noise_mask, 0], data[noise_mask, 1], 1, color=\"silver\")\n", + " plt.scatter(\n", + " data[~noise_mask, 0],\n", + " data[~noise_mask, 1],\n", + " 1,\n", + " labels[~noise_mask] % 10,\n", + " cmap=\"tab10\",\n", + " vmin=0,\n", + " vmax=9,\n", + " )\n", + " plt.axis(\"off\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Control points for line segments that merge three clusters\n", + "p0 = (0.13, -0.26)\n", + "p1 = (0.24, -0.12)\n", + "p2 = (0.32, 0.1)\n", + "\n", + "# Noisy points along lines between three clusters\n", + "segments = [\n", + " np.column_stack(\n", + " (np.linspace(p_start[0], p_end[0], 100), np.linspace(p_start[1], p_end[1], 100))\n", + " )\n", + " + np.random.normal(size=(100, 2), scale=0.01)\n", + " for p_start, p_end in [(p0, p1), (p1, p2)]\n", + "]\n", + "\n", + "# Original data with new segments\n", + "data = np.load(\"./clusterable_data.npy\")\n", + "data = np.concatenate((data, *segments))\n", + "\n", + "# HDBSCAN clusters\n", + "clusterer = HDBSCAN(min_cluster_size=15).fit(data)\n", + "plot(clusterer.labels_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, HDBSCAN\\*'s leaf clusters provide more detail. They segment the\n", + "points of different branches into distint clusters. However, the partitioning\n", + "and cluster hierarchy does not (necessarily) tell us how those clusters combine\n", + "into a larger shape." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "leaf_clusterer = HDBSCAN(min_cluster_size=15, cluster_selection_method='leaf').fit(data)\n", + "plot(leaf_clusterer.labels_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is where the branch detection post-processing step comes into play. The\n", + "functionality is described in detail by [Bot et\n", + "al](https://arxiv.org/abs/2311.15887). It operates on the detected clusters and\n", + "extracts a branch-hierarchy analogous to HDBSCAN*'s condensed cluster hierarchy.\n", + "The process is very similar to HDBSCAN* clustering, except that it operates on\n", + "an in-cluster eccentricity rather than a density measure. Where peaks in a\n", + "density profile correspond to clusters, the peaks in an eccentricity profile\n", + "correspond to branches:" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.tri as mtri\n", + "\n", + "eccentricities = np.zeros(data.shape[0])\n", + "for label in range(len(clusterer.cluster_persistence_)):\n", + " mask = clusterer.labels_ == label\n", + " centroid = np.average(\n", + " data[mask],\n", + " weights=clusterer.probabilities_[mask],\n", + " axis=0,\n", + " )\n", + " eccentricities[mask] = np.linalg.norm(data[mask] - centroid, axis=1)\n", + "\n", + "fig = plt.figure()\n", + "tri = mtri.Triangulation(data[:, 0], data[:, 1])\n", + "ax = fig.add_subplot(1, 1, 1, projection=\"3d\", computed_zorder=False)\n", + "ax.view_init(elev=45, azim=-100)\n", + "ax.scatter(\n", + " data.T[0],\n", + " data.T[1],\n", + " np.repeat(eccentricities.min(), data.shape[0]),\n", + " s=2,\n", + " edgecolor=\"none\",\n", + " linewidth=0,\n", + ")\n", + "ax.tricontour(tri, eccentricities, levels=np.linspace(0, eccentricities.max(), 15))\n", + "ax.set_xticklabels([])\n", + "ax.set_yticklabels([])\n", + "ax.set_zticklabels([])\n", + "zlim = ax.get_zlim()\n", + "ax.set_box_aspect(aspect=(3, 3, 1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using the branch detection functionality is fairly straightforward. First, run\n", + "hdbscan with parameter `branch_detection_data=True`. This tells hdbscan to cache\n", + "the internal data structures needed for the branch detection process. Then,\n", + "configure the ``BranchDetector`` class and fit is with the HDBSCAN object.\n", + "\n", + "The resulting partitioning reflects subgroups for clusters and their\n", + "branches:" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "clusterer = HDBSCAN(min_cluster_size=15, branch_detection_data=True).fit(data)\n", + "branch_detector = BranchDetector(min_branch_size=15).fit(clusterer)\n", + "plot(branch_detector.labels_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameter selection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `BranchDetector`'s main parameters are very similar to HDBSCAN\\*. Most\n", + "guidelines for tuning HDBSCAN\\* also apply for the branch detector:\n", + "\n", + "- `min_branch_size` behaves like HDBSCAN\\*'s `min_cluster_size`. It configures\n", + " how many points branches need to contain. Values around 10 to 25 points tend\n", + " to work well. Lower values are useful when looking for smaller structures.\n", + " Higher values can be used to suppress noise if present.\n", + "- `branch_selection_method` behaves like HDBSCAN\\*'s `cluster_selection_method`.\n", + " The leaf and Excess of Mass (EOM) strategies are used to select branches from\n", + " the condensed hierarchies. By default, branches are only reflected in the\n", + " final labelling for clusters that have 3 or more branches (at least one\n", + " bifurcation).\n", + "- `branch_selection_persistence` replaces HDBSCAN\\*'s `cluster_selection_epsilon`.\n", + " This parameter can be used to suppress branches with a short eccentricity\n", + " range (y-range in the condensed hierarchy plot).\n", + "- `allow_single_branch` behaves like HDBSCAN\\*'s `allow_single_cluster` and\n", + " mostly affects the EOM selection strategy. 
When enabled, clusters with\n", + " bifurcations will be given a single label if the root segment contains most\n", + " eccentricity mass (i.e., branches already merge far from the center and most\n", + " points are central).\n", + "- `max_branch_size` behaves like HDBSCAN\\*'s `max_cluster_size` and mostly\n", + " affects the EOM selection strategy. Branches with more than the specified\n", + " number of points are skipped, selecting their descendants in the hierarchy\n", + " instead.\n", + "\n", + "Two parameters are unique to the `BranchDetector` class:\n", + "\n", + "- `branch_detection_method` determines which points are connected within a\n", + " cluster. Both density-based clustering and the branch detection process need\n", + " to determine which points are part of the same density/eccentricity peak.\n", + " HDBSCAN\\* defines density in terms of the distance between points, providing\n", + " a natural way to define which points are connected at some density value.\n", + " Eccentricity does not have such a connection. So, we use information from the\n", + " clusters to determine which points should be connected instead.\n", + " - The `\"core\"` method selects all edges that could be part of the cluster's\n", + " minimum spanning tree under HDBSCAN\\*'s mutual reachability distance. This\n", + " graph contains the detected MST and all `min_samples`-nearest neighbours. \n", + " - The `\"full\"` method connects all points with a mutual reachability lower\n", + " than the maximum distance in the cluster's MST. It represents all connectivity\n", + " at the moment the last point joins the cluster. These methods differ in\n", + " their sensitivity, noise robustness, and computational cost. The `\"core\"`\n", + " method usually needs slightly higher `min_branch_size` values to suppress\n", + " noisy branches than the `\"full\"` method. 
It is a good choice when branches\n", + " span large density ranges.\n", + "- `label_sides_as_branches` determines whether the sides of an elongated cluster\n", + " without bifurcations (l-shape) are represented as distinct subgroups. By\n", + " default a cluster needs to have one bifurcation (Y-shape) before the detected\n", + " branches are represented in the final labelling." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Useful attributes\n", + "\n", + "Like the HDBSCAN class, the BranchDetector class contains several useful\n", + "attributes for exploring datasets.\n", + "\n", + "### Branch hierarchy\n", + "\n", + "Branch hierarchies reflect the tree-shape of clusters. Like the cluster\n", + "hierarchy, branch hierarchies can be used to interpret which branches exist. In\n", + "addition, they reflect how far apart branches merge into the cluster. " + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "idx = np.argmax([len(x) for x in branch_detector.branch_persistences_])\n", + "branch_detector.cluster_condensed_trees_[idx].plot(\n", + " select_clusters=True, selection_palette=[\"C3\", \"C4\", \"C5\"]\n", + ")\n", + "plt.ylabel(\"Eccentricity\")\n", + "plt.title(f\"Branches in cluster {idx}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The length of the branches also says something about the compactness /\n", + "elongatedness of clusters. For example, the branch hierarchy for the orange\n", + "~-shaped cluster is quite different from the same hierarcy for the central\n", + "o-shaped cluster." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(6, 3))\n", + "plt.subplot(1, 2, 1)\n", + "idx = np.argmin([min(*x) for x in branch_detector.branch_persistences_])\n", + "branch_detector.cluster_condensed_trees_[idx].plot(colorbar=False)\n", + "plt.ylim([0.3, 0])\n", + "plt.ylabel(\"Eccentricity\")\n", + "plt.title(f\"Cluster {idx} (spherical)\")\n", + "\n", + "plt.subplot(1, 2, 2)\n", + "idx = np.argmax([max(*x) for x in branch_detector.branch_persistences_])\n", + "branch_detector.cluster_condensed_trees_[idx].plot(colorbar=False)\n", + "plt.ylim([0.3, 0])\n", + "plt.ylabel(\"Eccentricity\")\n", + "plt.title(f\"Cluster {idx} (elongated)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cluster approximation graphs\n", + "\n", + "Branches are detected using a graph that approximates the connectivity within a\n", + "cluster. These graphs are available in the `cluster_approximation_graph_`\n", + "property and can be used to visualise data and the branch-detection process. The\n", + "plotting function is based on the networkx API and uses networkx functionality\n", + "to compute a layout if positions are not provided. Using UMAP to compute\n", + "positions can be faster and more expressive. Several helper functions for\n", + "exporting to numpy, pandas, and networkx are available.\n", + "\n", + "For example, a figure with points coloured by the final labelling:" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "g = branch_detector.cluster_approximation_graph_\n", + "g.plot(positions=data, node_size=5, edge_width=0.2, edge_alpha=0.2)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or, a figure with the edges coloured by centrality:" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "g.plot(\n", + " positions=data,\n", + " node_alpha=0,\n", + " edge_color=\"centrality\",\n", + " edge_cmap=\"turbo\",\n", + " edge_width=0.2,\n", + " edge_alpha=0.2,\n", + " edge_vmax=100,\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Approximate predict\n", + "\n", + "A branch-aware ``approximate_predict_branch`` function is available to predicts\n", + "branch labels for new points. This function uses a fitted BranchDetector object\n", + "to first predict cluster labels and then the branch labels." + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from hdbscan import approximate_predict_branch\n", + "\n", + "new_points = np.asarray([[0.4, 0.25], [0.23, 0.2], [-0.14, -0.2]])\n", + "clusterer.generate_prediction_data()\n", + "labels, probs, cluster_labels, cluster_probs, branch_labels, branch_probs = (\n", + " approximate_predict_branch(branch_detector, new_points)\n", + ")\n", + "\n", + "plt.scatter(\n", + " new_points.T[0],\n", + " new_points.T[1],\n", + " 140,\n", + " labels % 10,\n", + " marker=\"p\",\n", + " zorder=5,\n", + " cmap=\"tab10\",\n", + " vmin=0,\n", + " vmax=9,\n", + " edgecolor=\"k\",\n", + ")\n", + "plot(branch_detector.labels_)\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}