Add branch detection functionality #648

Merged
merged 2 commits into from
Aug 5, 2024
52 changes: 52 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,40 @@ Based on the paper:
*"Rates of convergence for the cluster tree."*
In Advances in Neural Information Processing Systems, 2010.

----------------
Branch detection
----------------

The hdbscan package supports a branch-detection post-processing step
by `Bot et al. <https://arxiv.org/abs/2311.15887>`_. Cluster shapes,
such as branching structures, can reveal interesting patterns
that are not expressed in density-based cluster hierarchies. The
BranchDetector class mimics the HDBSCAN API and can be used to
detect branching hierarchies in clusters. It provides condensed
branch hierarchies, branch persistences, and branch memberships and
supports joblib's caching functionality. A notebook
`demonstrating the BranchDetector is available <http://nbviewer.jupyter.org/github/scikit-learn-contrib/hdbscan/blob/master/notebooks/How%20to%20detect%20branches.ipynb>`_.

Example usage:

.. code:: python

import hdbscan
from sklearn.datasets import make_blobs

data, _ = make_blobs(1000)

clusterer = hdbscan.HDBSCAN(branch_detection_data=True).fit(data)
branch_detector = hdbscan.BranchDetector().fit(clusterer)
branch_detector.cluster_approximation_graph_.plot(edge_width=0.1)


Based on the paper:
D. M. Bot, J. Peeters, J. Liesenborgs and J. Aerts
*"FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN\* for Detecting Branches in Clusters"*
Arxiv 2311.15887, 2023.


----------
Installing
----------
Expand Down Expand Up @@ -300,6 +334,24 @@ To reference the high performance algorithm developed in this library please cit
organization={IEEE}
}

If you used the branch-detection functionality in this codebase in a scientific publication and wish to cite it, please use the `Arxiv preprint <https://arxiv.org/abs/2311.15887>`_:

D. M. Bot, J. Peeters, J. Liesenborgs and J. Aerts
*"FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN\* for Detecting Branches in Clusters"*
Arxiv 2311.15887, 2023.

.. code:: bibtex

@misc{bot2023flasc,
title={FLASC: A Flare-Sensitive Clustering Algorithm: Extending HDBSCAN* for Detecting Branches in Clusters},
author={D. M. Bot and J. Peeters and J. Liesenborgs and J. Aerts},
year={2023},
eprint={2311.15887},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2311.15887},
}

---------
Licensing
---------
12 changes: 12 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,15 @@ and the prediction module.

.. automodule:: hdbscan.prediction
:members:


Branch detection
----------------

The branches module contains classes for detecting branches within clusters.

.. automodule:: hdbscan.branches
:members:

.. autoclass:: hdbscan.plots.ApproximationGraph
:members:
242 changes: 242 additions & 0 deletions docs/how_to_detect_branches.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
How to detect branches in clusters
==================================

HDBSCAN\* is often used to find subpopulations in exploratory data
analysis workflows. Not only clusters themselves, but also their shape
can represent meaningful subpopulations. For example, a Y-shaped cluster
may represent an evolving process with two distinct end-states.
Detecting these branches can reveal interesting patterns that are not
captured by density-based clustering.

For example, HDBSCAN\* finds 4 clusters in the datasets below, which
does not inform us of the branching structure:

.. image:: images/how_to_detect_branches_3_0.png

Alternatively, HDBSCAN\*’s leaf clusters provide more detail. They
segment the points of different branches into distinct clusters. However,
the partitioning and cluster hierarchy do not (necessarily) tell us how
those clusters combine into a larger shape.

.. image:: images/how_to_detect_branches_5_0.png

This is where the branch detection post-processing step comes into play.
The functionality is described in detail by `Bot et
al <https://arxiv.org/abs/2311.15887>`__. It operates on the detected
clusters and extracts a branch-hierarchy analogous to HDBSCAN\*’s
condensed cluster hierarchy. The process is very similar to HDBSCAN\*
clustering, except that it operates on an in-cluster eccentricity rather
than a density measure. Where peaks in a density profile correspond to
clusters, the peaks in an eccentricity profile correspond to branches:

.. image:: images/how_to_detect_branches_7_0.png
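
To make the notion of eccentricity concrete, the following minimal sketch
computes a per-point eccentricity for a toy Y-shaped cluster. It assumes the
simplest possible definition, the Euclidean distance of each point to the
unweighted cluster centroid; the ``eccentricity`` helper and the toy data are
illustrative only and are not part of hdbscan, whose ``BranchDetector``
computes its own eccentricity internally.

```python
import math


def eccentricity(points):
    """Euclidean distance of each point to the centroid of its cluster."""
    n = len(points)
    dim = len(points[0])
    centroid = [sum(p[d] for p in points) / n for d in range(dim)]
    return [math.dist(p, centroid) for p in points]


# A toy Y-shaped cluster: a centre point and three arms of four points each.
cluster = [(0, 0)]
cluster += [(-t, t) for t in range(1, 5)]  # upper-left arm
cluster += [(t, t) for t in range(1, 5)]   # upper-right arm
cluster += [(0, -t) for t in range(1, 5)]  # lower arm

ecc = eccentricity(cluster)

# The three most eccentric points are the arm tips: each arm forms a peak
# in the eccentricity profile, just as a cluster forms a density peak.
order = sorted(range(len(cluster)), key=lambda i: ecc[i], reverse=True)
tips = [cluster[i] for i in order[:3]]
print(tips)
```

Points near the centre get low eccentricity values and points at the end of
an arm get high ones, which is why the arms show up as separate peaks.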

Using the branch detection functionality is fairly straightforward.
First, run hdbscan with the parameter ``branch_detection_data=True``. This
tells hdbscan to cache the internal data structures needed for the
branch detection process. Then, configure the ``BranchDetector`` class
and fit it with the HDBSCAN object.

The resulting partitioning reflects subgroups for clusters and their
branches:

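.. code:: python

    from hdbscan import HDBSCAN, BranchDetector

    clusterer = HDBSCAN(min_cluster_size=15, branch_detection_data=True).fit(data)
    branch_detector = BranchDetector(min_branch_size=15).fit(clusterer)
    # `plot` is not part of hdbscan; it is assumed to be a user-defined
    # scatter-plot helper that colours `data` by the given labels.
    plot(branch_detector.labels_)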

.. image:: images/how_to_detect_branches_9_0.png


Parameter selection
-------------------

The ``BranchDetector``’s main parameters are very similar to HDBSCAN\*’s.
Most guidelines for tuning HDBSCAN\* also apply to the branch detector:

- ``min_branch_size`` behaves like HDBSCAN\*’s ``min_cluster_size``. It
configures how many points branches need to contain. Values around 10
to 25 points tend to work well. Lower values are useful when looking
for smaller structures. Higher values can be used to suppress noise
if present.
- ``branch_selection_method`` behaves like HDBSCAN\*’s
``cluster_selection_method``. The leaf and Excess of Mass (EOM)
strategies are used to select branches from the condensed
hierarchies. By default, branches are only reflected in the final
labelling for clusters that have 3 or more branches (at least one
bifurcation).
- ``branch_selection_persistence`` replaces HDBSCAN\*’s
``cluster_selection_epsilon``. This parameter can be used to suppress
branches with a short eccentricity range (y-range in the condensed
hierarchy plot).
- ``allow_single_branch`` behaves like HDBSCAN\*’s
``allow_single_cluster`` and mostly affects the EOM selection
strategy. When enabled, clusters with bifurcations will be given a
single label if the root segment contains most eccentricity mass
(i.e., branches already merge far from the center and most points are
central).
- ``max_branch_size`` behaves like HDBSCAN\*’s ``max_cluster_size`` and
mostly affects the EOM selection strategy. Branches with more than
the specified number of points are skipped, selecting their
descendants in the hierarchy instead.
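
As a rough illustration of how these parameters fit together, the sketch
below collects them in one configuration. The parameter names are taken from
the list above; the values, the ``"eom"`` string (assumed to match HDBSCAN's
``cluster_selection_method`` options), and the ``detect_branches`` helper are
illustrative assumptions rather than recommended settings.

```python
# Parameter names follow the list above; the values are illustrative
# assumptions, not tuned recommendations.
branch_detector_kwargs = {
    "min_branch_size": 15,                 # like HDBSCAN's min_cluster_size
    "branch_selection_method": "eom",      # assumed options: "eom" or "leaf"
    "branch_selection_persistence": 0.05,  # suppress short-lived branches
    "allow_single_branch": False,          # like allow_single_cluster
}


def detect_branches(data, min_cluster_size=15):
    """Hypothetical helper: cluster `data`, then detect branches."""
    # Imported lazily so the configuration above can be inspected
    # without hdbscan being installed.
    from hdbscan import HDBSCAN, BranchDetector

    clusterer = HDBSCAN(
        min_cluster_size=min_cluster_size,
        branch_detection_data=True,  # cache the data the detector needs
    ).fit(data)
    return BranchDetector(**branch_detector_kwargs).fit(clusterer)
```

Starting from ``min_branch_size`` equal to ``min_cluster_size`` and adjusting
one parameter at a time keeps the effect of each change easy to attribute.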

Two parameters are unique to the ``BranchDetector`` class:

- ``branch_detection_method`` determines which points are connected
within a cluster. Both density-based clustering and the branch detection
process need to determine which points are part of the same
density/eccentricity peak. HDBSCAN\* defines density in terms of the distance
between points, providing a natural way to define which points are connected at
some density value. Eccentricity does not provide such a notion of
connectivity, so we use information from the clusters to determine which
points should be connected instead.

- The ``"core"`` method selects all edges that could be part of the
cluster’s minimum spanning tree under HDBSCAN\*’s mutual
reachability distance. This graph contains the detected MST and
all ``min_samples``-nearest neighbours.
- The ``"full"`` method connects all points with a mutual
reachability distance lower than the maximum distance in the cluster’s MST.
It represents all connectivity at the moment the last point joins
the cluster.

These methods differ in their sensitivity, noise robustness, and
computational cost. The ``"core"`` method usually needs slightly
higher ``min_branch_size`` values to suppress noisy branches than the
``"full"`` method. It is a good choice when branches span large
density ranges.

- ``label_sides_as_branches`` determines whether the sides of an
elongated cluster without bifurcations (l-shape) are represented as
distinct subgroups. By default a cluster needs to have one
bifurcation (Y-shape) before the detected branches are represented in
the final labelling.
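
The ``"full"`` approximation graph described above can be sketched in plain
Python. This is a simplified illustration under stated assumptions, not the
package's implementation: it uses brute-force Euclidean distances, takes the
core distance as the distance to the ``min_samples``-th other point, and keeps
edges at or below the largest MST edge so that the MST edges themselves are
retained. All function names are hypothetical.

```python
import math


def mutual_reachability(points, min_samples):
    """Dense mutual reachability matrix from brute-force distances."""
    n = len(points)
    dist = [[math.dist(p, q) for q in points] for p in points]
    # Core distance: distance to the min_samples-th other point
    # (index 0 of each sorted row is the point itself).
    core = [sorted(row)[min_samples] for row in dist]
    return [
        [max(core[i], core[j], dist[i][j]) if i != j else 0.0 for j in range(n)]
        for i in range(n)
    ]


def max_mst_edge(weights):
    """Largest edge weight in the minimum spanning tree (Prim's algorithm)."""
    n = len(weights)
    in_tree = [False] * n
    best = [math.inf] * n
    best[0] = 0.0
    largest = 0.0
    for _ in range(n):
        # Add the cheapest point that is not yet in the tree.
        u = min((i for i in range(n) if not in_tree[i]), key=best.__getitem__)
        in_tree[u] = True
        largest = max(largest, best[u])
        for v in range(n):
            if not in_tree[v] and weights[u][v] < best[v]:
                best[v] = weights[u][v]
    return largest


def full_approximation_edges(points, min_samples=2):
    """Edges whose mutual reachability is within the largest MST edge."""
    mreach = mutual_reachability(points, min_samples)
    threshold = max_mst_edge(mreach)
    n = len(points)
    return [
        (i, j)
        for i in range(n)
        for j in range(i + 1, n)
        if mreach[i][j] <= threshold
    ]


# Five points on a line, treated as a single cluster.
line = [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)]
edges = full_approximation_edges(line)
print(edges)
```

The resulting graph contains every MST edge plus all shorter shortcuts, which
is what makes it denser than a ``"core"``-style graph built only from the MST
and nearest neighbours.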


Useful attributes
-----------------

Like the HDBSCAN class, the BranchDetector class contains several useful
attributes for exploring datasets.

Branch hierarchy
~~~~~~~~~~~~~~~~

Branch hierarchies reflect the tree-shape of clusters. Like the cluster
hierarchy, branch hierarchies can be used to interpret which branches
exist. In addition, they reflect how far apart branches merge into the
cluster.

.. code:: python

idx = np.argmax([len(x) for x in branch_detector.branch_persistences_])
branch_detector.cluster_condensed_trees_[idx].plot(
select_clusters=True, selection_palette=["C3", "C4", "C5"]
)
plt.ylabel("Eccentricity")
plt.title(f"Branches in cluster {idx}")
plt.show()

.. image:: images/how_to_detect_branches_13_0.png

The length of the branches also says something about the compactness /
elongatedness of clusters. For example, the branch hierarchy for the
orange ~-shaped cluster is quite different from the hierarchy for
the central o-shaped cluster.

.. code:: python

plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
idx = np.argmin([min(x) for x in branch_detector.branch_persistences_])
branch_detector.cluster_condensed_trees_[idx].plot(colorbar=False)
plt.ylim([0.3, 0])
plt.ylabel("Eccentricity")
plt.title(f"Cluster {idx} (spherical)")

plt.subplot(1, 2, 2)
idx = np.argmax([max(x) for x in branch_detector.branch_persistences_])
branch_detector.cluster_condensed_trees_[idx].plot(colorbar=False)
plt.ylim([0.3, 0])
plt.ylabel("Eccentricity")
plt.title(f"Cluster {idx} (elongated)")
plt.show()

.. image:: images/how_to_detect_branches_15_0.png

Cluster approximation graphs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Branches are detected using a graph that approximates the connectivity
within a cluster. These graphs are available in the
``cluster_approximation_graph_`` property and can be used to visualise
data and the branch-detection process. The plotting function is based on
the networkx API and uses networkx functionality to compute a layout if
positions are not provided. Using UMAP to compute positions can be
faster and more expressive. Several helper functions for exporting to
numpy, pandas, and networkx are available.

For example, a figure with points coloured by the final labelling:

.. code:: python

g = branch_detector.cluster_approximation_graph_
g.plot(positions=data, node_size=5, edge_width=0.2, edge_alpha=0.2)
plt.show()

.. image:: images/how_to_detect_branches_17_0.png

Or, a figure with the edges coloured by centrality:

.. code:: python

g.plot(
positions=data,
node_alpha=0,
edge_color="centrality",
edge_cmap="turbo",
edge_width=0.2,
edge_alpha=0.2,
edge_vmax=100,
)
plt.show()

.. image:: images/how_to_detect_branches_19_0.png


Approximate predict
-------------------

A branch-aware ``approximate_predict_branch`` function is available to
predict branch labels for new points. This function uses a fitted
BranchDetector object to first predict cluster labels and then the
branch labels.

.. code:: python

from hdbscan import approximate_predict_branch

new_points = np.asarray([[0.4, 0.25], [0.23, 0.2], [-0.14, -0.2]])
clusterer.generate_prediction_data()
labels, probs, cluster_labels, cluster_probs, branch_labels, branch_probs = (
approximate_predict_branch(branch_detector, new_points)
)

plt.scatter(
new_points.T[0],
new_points.T[1],
140,
labels % 10,
marker="p",
zorder=5,
cmap="tab10",
vmin=0,
vmax=9,
edgecolor="k",
)
plot(branch_detector.labels_)
plt.show()

.. image:: images/how_to_detect_branches_21_0.png
Binary file added docs/images/how_to_detect_branches_13_0.png
Binary file added docs/images/how_to_detect_branches_15_0.png
Binary file added docs/images/how_to_detect_branches_17_0.png
Binary file added docs/images/how_to_detect_branches_19_0.png
Binary file added docs/images/how_to_detect_branches_21_0.png
Binary file added docs/images/how_to_detect_branches_3_0.png
Binary file added docs/images/how_to_detect_branches_5_0.png
Binary file added docs/images/how_to_detect_branches_7_0.png
Binary file added docs/images/how_to_detect_branches_9_0.png
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ User Guide / Tutorial
soft_clustering
how_to_use_epsilon
dbscan_from_hdbscan
how_to_detect_branches
faq

Background on Clustering with HDBSCAN
3 changes: 3 additions & 0 deletions hdbscan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@
membership_vector,
all_points_membership_vectors,
approximate_predict_scores)
from .branches import (BranchDetector,
detect_branches_in_clusters,
approximate_predict_branch)


6 changes: 3 additions & 3 deletions hdbscan/_hdbscan_linkage.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ cpdef np.ndarray[np.double_t, ndim=2] mst_linkage_core_vector(

cdef np.intp_t current_node
cdef np.intp_t source_node
cdef np.intp_t right_node
cdef np.intp_t left_node
cdef np.intp_t right_node, right_source
cdef np.intp_t left_node, left_source
cdef np.intp_t new_node
cdef np.intp_t i
cdef np.intp_t j
Expand Down Expand Up @@ -124,7 +124,7 @@ cpdef np.ndarray[np.double_t, ndim=2] mst_linkage_core_vector(
continue

right_value = current_distances[j]
right_source = current_sources[j]
right_source = <np.intp_t> current_sources[j]
Contributor Author

The changes here solved the following test errors on my machine:

FAILED hdbscan/tests/test_hdbscan.py::test_hdbscan_prims_kdtree - TypeError: 'float' object cannot be interpreted as an integer
FAILED hdbscan/tests/test_hdbscan.py::test_hdbscan_prims_balltree - TypeError: 'float' object cannot be interpreted as an integer
FAILED hdbscan/tests/test_hdbscan.py::test_hdbscan_high_dimensional - TypeError: 'float' object cannot be interpreted as an integer
FAILED hdbscan/tests/test_rsl.py::test_rsl_prims_balltree - TypeError: 'float' object cannot be interpreted as an integer
FAILED hdbscan/tests/test_rsl.py::test_rsl_prims_kdtree - TypeError: 'float' object cannot be interpreted as an integer

These tests do not appear to be broken on the test runner for the master branch but were crashing my machine.
Could this be related to a specific Cython version?
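
For context, the same error message can be reproduced in plain Python. This is only an analogue, not the Cython code path itself: it assumes the source indices are held in a floating-point buffer, and `sources` and `as_index` are hypothetical names used for illustration.

```python
import operator

# Indices stored as floats, as they would be in a float64 buffer (assumption).
sources = [0.0, 2.0, 1.0]


def as_index(value):
    """Cast a float-valued index to int, analogous to the `<np.intp_t>` cast."""
    return int(value)


try:
    # operator.index is what Python applies wherever an integer is required.
    operator.index(sources[1])
except TypeError as err:
    message = str(err)

position = as_index(sources[1])
print(message)
print(position)
```

The explicit cast makes the float-to-integer conversion intentional instead of relying on an implicit coercion that may differ between compiler versions.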


left_value = dist_metric.dist(&raw_data_ptr[num_features *
current_node],