From 041a7fc26cc5714e31759ec016d40bed25d923f6 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Tue, 6 Jul 2021 10:04:06 -0700 Subject: [PATCH] HDBSCAN bug on A100 (#4024) While this issue only appeared on A100, it could have appeared on any other GPU. In this kernel https://github.com/rapidsai/cuml/blob/c6f992a5fcbccf5677ca6d639af6b84e93aa8108/cpp/src/hdbscan/detail/kernels/condense.cuh#L85, we launch a thread for every node of a binary tree on the GPU. The problem that then occurs is: 1. Each node marks itself out of the frontier https://github.com/rapidsai/cuml/blob/c6f992a5fcbccf5677ca6d639af6b84e93aa8108/cpp/src/hdbscan/detail/kernels/condense.cuh#L94 2. For every node that is not a leaf, it marks its left and right children into the frontier https://github.com/rapidsai/cuml/blob/c6f992a5fcbccf5677ca6d639af6b84e93aa8108/cpp/src/hdbscan/detail/kernels/condense.cuh#L117 This is UB because the thread for a non-leaf node could be marking that node out of the frontier at the same time as its parent's thread tries to mark it into the frontier. Edit: Dropped the `__threadfence()` solution as it wasn't fully correct. A separate `next_frontier` array is used instead to keep track of the frontier for the next BFS iteration. Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) - Corey J. 
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuml/pull/4024 --- cpp/src/hdbscan/detail/condense.cuh | 13 ++++++++++--- cpp/src/hdbscan/detail/kernels/condense.cuh | 6 +++--- cpp/src/hdbscan/detail/kernels/select.cuh | 10 ++++------ cpp/src/hdbscan/detail/select.cuh | 12 ++++++++++-- python/cuml/test/test_hdbscan.py | 7 ------- 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/cpp/src/hdbscan/detail/condense.cuh b/cpp/src/hdbscan/detail/condense.cuh index e5aa447c98..b6f0f79fbb 100644 --- a/cpp/src/hdbscan/detail/condense.cuh +++ b/cpp/src/hdbscan/detail/condense.cuh @@ -83,8 +83,10 @@ void build_condensed_hierarchy( "Cannot find single-linkage solution."); rmm::device_uvector frontier(root + 1, stream); + rmm::device_uvector next_frontier(root + 1, stream); thrust::fill(exec_policy, frontier.begin(), frontier.end(), false); + thrust::fill(exec_policy, next_frontier.begin(), next_frontier.end(), false); // Array to propagate the lambda of subtrees actively being collapsed // through multiple bfs iterations. @@ -121,9 +123,14 @@ void build_condensed_hierarchy( // TODO: Investigate whether it would be worth performing a gather/argmatch in order // to schedule only the number of threads needed. 
(it might not be worth it) condense_hierarchy_kernel<<>>( - frontier.data(), ignore.data(), relabel.data(), children, delta, sizes, - n_leaves, min_cluster_size, out_parent.data(), out_child.data(), - out_lambda.data(), out_size.data()); + frontier.data(), next_frontier.data(), ignore.data(), relabel.data(), + children, delta, sizes, n_leaves, min_cluster_size, out_parent.data(), + out_child.data(), out_lambda.data(), out_size.data()); + + thrust::copy(exec_policy, next_frontier.begin(), next_frontier.end(), + frontier.begin()); + thrust::fill(exec_policy, next_frontier.begin(), next_frontier.end(), + false); n_elements_to_traverse = thrust::reduce(exec_policy, frontier.data(), frontier.data() + root + 1, 0); diff --git a/cpp/src/hdbscan/detail/kernels/condense.cuh b/cpp/src/hdbscan/detail/kernels/condense.cuh index c656bcbd0b..eaa2adf0ac 100644 --- a/cpp/src/hdbscan/detail/kernels/condense.cuh +++ b/cpp/src/hdbscan/detail/kernels/condense.cuh @@ -83,7 +83,7 @@ __device__ inline value_t get_lambda(value_idx node, value_idx num_points, */ template __global__ void condense_hierarchy_kernel( - bool *frontier, value_t *ignore, value_idx *relabel, + bool *frontier, bool *next_frontier, value_t *ignore, value_idx *relabel, const value_idx *children, const value_t *deltas, const value_idx *sizes, int n_leaves, int min_cluster_size, value_idx *out_parent, value_idx *out_child, value_t *out_lambda, value_idx *out_count) { @@ -114,8 +114,8 @@ __global__ void condense_hierarchy_kernel( value_idx right_child = children[((node - n_leaves) * 2) + 1]; // flip frontier for children - frontier[left_child] = true; - frontier[right_child] = true; + next_frontier[left_child] = true; + next_frontier[right_child] = true; // propagate ignore down to children ignore[left_child] = should_ignore ? 
subtree_lambda : -1; diff --git a/cpp/src/hdbscan/detail/kernels/select.cuh b/cpp/src/hdbscan/detail/kernels/select.cuh index 53eb1eaa40..ebb2930a38 100644 --- a/cpp/src/hdbscan/detail/kernels/select.cuh +++ b/cpp/src/hdbscan/detail/kernels/select.cuh @@ -33,11 +33,9 @@ namespace Select { * @param[in] n_clusters number of clusters */ template -__global__ void propagate_cluster_negation_kernel(const value_idx *indptr, - const value_idx *children, - int *frontier, - int *is_cluster, - int n_clusters) { +__global__ void propagate_cluster_negation_kernel( + const value_idx *indptr, const value_idx *children, int *frontier, + int *next_frontier, int *is_cluster, int n_clusters) { int cluster = blockDim.x * blockIdx.x + threadIdx.x; if (cluster < n_clusters && frontier[cluster]) { @@ -47,7 +45,7 @@ __global__ void propagate_cluster_negation_kernel(const value_idx *indptr, value_idx children_stop = indptr[cluster + 1]; for (int i = children_start; i < children_stop; i++) { value_idx child = children[i]; - frontier[child] = true; + next_frontier[child] = true; is_cluster[child] = false; } } diff --git a/cpp/src/hdbscan/detail/select.cuh b/cpp/src/hdbscan/detail/select.cuh index 538cb38fa4..01076aeacf 100644 --- a/cpp/src/hdbscan/detail/select.cuh +++ b/cpp/src/hdbscan/detail/select.cuh @@ -68,6 +68,9 @@ void perform_bfs(const raft::handle_t &handle, const value_idx *indptr, auto stream = handle.get_stream(); auto thrust_policy = rmm::exec_policy(stream); + rmm::device_uvector next_frontier(n_clusters, stream); + thrust::fill(thrust_policy, next_frontier.begin(), next_frontier.end(), 0); + value_idx n_elements_to_traverse = thrust::reduce(thrust_policy, frontier, frontier + n_clusters, 0); @@ -78,11 +81,16 @@ void perform_bfs(const raft::handle_t &handle, const value_idx *indptr, size_t grid = raft::ceildiv(n_clusters, tpb); while (n_elements_to_traverse > 0) { - bfs_kernel<<>>(indptr, children, frontier, is_cluster, - n_clusters); + bfs_kernel<<>>( + indptr, children, 
frontier, next_frontier.data(), is_cluster, n_clusters); + + thrust::copy(thrust_policy, next_frontier.begin(), next_frontier.end(), + frontier); + thrust::fill(thrust_policy, next_frontier.begin(), next_frontier.end(), 0); n_elements_to_traverse = thrust::reduce(thrust_policy, frontier, frontier + n_clusters, 0); + CUDA_CHECK(cudaStreamSynchronize(stream)); } } diff --git a/python/cuml/test/test_hdbscan.py b/python/cuml/test/test_hdbscan.py index 4c3a0f9f52..740d888c46 100644 --- a/python/cuml/test/test_hdbscan.py +++ b/python/cuml/test/test_hdbscan.py @@ -32,10 +32,6 @@ import cupy as cp -import rmm - -IS_A100 = 'A100' in rmm._cuda.gpu.deviceGetName(0) - test_datasets = { "digits": datasets.load_digits(), "boston": datasets.load_boston(), @@ -56,7 +52,6 @@ def assert_cluster_counts(sk_agg, cuml_agg, digits=25): np.testing.assert_almost_equal(sk_counts, cu_counts, decimal=-1 * digits) -@pytest.mark.skipif(IS_A100, reason="Skipping test on A100") @pytest.mark.parametrize('nrows', [500]) @pytest.mark.parametrize('ncols', [25]) @pytest.mark.parametrize('nclusters', [2, 5]) @@ -114,7 +109,6 @@ def test_hdbscan_blobs(nrows, ncols, nclusters, np.sort(cuml_agg.cluster_persistence_), rtol=0.01, atol=0.01) -@pytest.mark.skipif(IS_A100, reason="Skipping test on A100") @pytest.mark.parametrize('dataset', test_datasets.values()) @pytest.mark.parametrize('cluster_selection_epsilon', [0.0, 50.0, 150.0]) @pytest.mark.parametrize('min_samples_cluster_size_bounds', [(150, 150, 0), @@ -167,7 +161,6 @@ def test_hdbscan_sklearn_datasets(dataset, np.sort(cuml_agg.cluster_persistence_), rtol=0.1, atol=0.1) -@pytest.mark.skipif(IS_A100, reason="Skipping test on A100") @pytest.mark.parametrize('nrows', [1000]) @pytest.mark.parametrize('dataset', dataset_names) @pytest.mark.parametrize('min_samples', [15])