[REVIEW] Exposing KL divergence in TSNE #4208

Merged · 16 commits · Oct 28, 2021

Changes from 14 commits
2 changes: 1 addition & 1 deletion ci/checks/copyright.py
@@ -38,7 +38,7 @@
    re.compile(r"[.]flake8[.]cython$"),
    re.compile(r"meta[.]yaml$")
]
-ExemptFiles = []
+ExemptFiles = ['cpp/src/tsne/cannylab/bh.cu']

# this will break starting at year 10000, which is probably OK :)
CheckSimple = re.compile(
10 changes: 8 additions & 2 deletions cpp/include/cuml/manifold/tsne.h
@@ -117,6 +117,8 @@ struct TSNEParams {
 * @param[in] knn_indices Array containing nearest neighbors indices.
 * @param[in] knn_dists Array containing nearest neighbors distances.
 * @param[in] params Parameters for TSNE model
+ * @param[out] kl_div (optional) Output pointer for the Kullback–Leibler
+ *                    divergence; written only when non-null
 *
* The CUDA implementation is derived from the excellent CannyLabs open source
* implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs
@@ -132,7 +134,8 @@ void TSNE_fit(const raft::handle_t& handle,
              int p,
              int64_t* knn_indices,
              float* knn_dists,
-             TSNEParams& params);
+             TSNEParams& params,
+             float* kl_div = nullptr);

/**
* @brief Dimensionality reduction via TSNE using either Barnes Hut O(NlogN)
@@ -149,6 +152,8 @@ void TSNE_fit(const raft::handle_t& handle,
 * @param[in] knn_indices Array containing nearest neighbors indices.
 * @param[in] knn_dists Array containing nearest neighbors distances.
 * @param[in] params Parameters for TSNE model
+ * @param[out] kl_div (optional) Output pointer for the Kullback–Leibler
+ *                    divergence; written only when non-null
 *
* The CUDA implementation is derived from the excellent CannyLabs open source
* implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs
@@ -167,6 +172,7 @@ void TSNE_fit_sparse(const raft::handle_t& handle,
                     int p,
                     int* knn_indices,
                     float* knn_dists,
-                    TSNEParams& params);
+                    TSNEParams& params,
+                    float* kl_div = nullptr);

} // namespace ML
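
For reference, the divergence exposed through the new `kl_div` parameter is the objective t-SNE minimizes: the Kullback–Leibler divergence between the input-space affinities P and the embedding-space affinities Q, with ν the embedding's degrees of freedom (the `dof` used by the kernels below):

$$
\mathrm{KL}(P \parallel Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}, \qquad q_{ij} = \frac{\left(1 + \lVert y_i - y_j \rVert^2 / \nu\right)^{-(\nu+1)/2}}{\sum_{k \neq l} \left(1 + \lVert y_k - y_l \rVert^2 / \nu\right)^{-(\nu+1)/2}}
$$

A caller could read the value back as below (a minimal sketch: the handle, data pointers, and `params` setup are elided, and the argument names ahead of `p` are assumed, not shown in this hunk):

```cpp
float kl_div = 0.0f;
ML::TSNE_fit(handle, X, Y, n, p, knn_indices, knn_dists, params, &kl_div);
// kl_div now holds KL(P || Q) of the final embedding; leaving the
// parameter as its default nullptr skips the computation entirely.
```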
28 changes: 18 additions & 10 deletions cpp/src/tsne/barnes_hut_kernels.cuh
@@ -681,27 +681,35 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
                                     const value_t* restrict Y2,
                                     value_t* restrict attract1,
                                     value_t* restrict attract2,
-                                    const value_idx NNZ)
+                                    value_t* restrict Qs,
+                                    const value_idx NNZ,
+                                    const value_t dof)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
  if (index >= NNZ) return;
  const auto i = ROW[index];
  const auto j = COL[index];

  const value_t y1d = Y1[i] - Y1[j];
  const value_t y2d = Y2[i] - Y2[j];
-  value_t squared_euclidean_dist = y1d * y1d + y2d * y2d;
+  value_t dist = y1d * y1d + y2d * y2d;
  // As a sum of squares, SED is mathematically >= 0. There might be a source of
  // NaNs upstream though, so until we find and fix them, enforce that trait.
-  if (!(squared_euclidean_dist >= 0)) squared_euclidean_dist = 0.0f;
-  const value_t PQ = __fdividef(VAL[index], squared_euclidean_dist + 1.0f);
+  if (!(dist >= 0)) dist = 0.0f;

-  // TODO: Calculate Kullback-Leibler divergence
-  // TODO: Convert attractive forces to CSR format
+  const value_t P  = VAL[index];
+  const value_t Q  = compute_q(dist, dof);
+  const value_t PQ = P * Q;

  // Apply forces
-  atomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
-  atomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j]));
+  atomicAdd(&attract1[i], PQ * y1d);
+  atomicAdd(&attract2[i], PQ * y2d);
+
+  if (Qs) {  // when computing KL div
+    Qs[index] = Q;
+  }
+
+  // TODO: Convert attractive forces to CSR format
}
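
`compute_q` is called above but defined outside this hunk. A plausible definition, consistent with the Student's t kernel and with the `__fdividef(VAL[index], dist + 1.0f)` it replaces, is sketched here (an assumption, not the PR's verbatim code); note the caller passes `dof = fmaxf(params.dim - 1, 1)`:

```cpp
// Hypothetical sketch: unnormalized Student's t affinity
//   q(d) = (dof / (dof + d))^((dof + 1) / 2),
// which for dof == 1 reduces to 1 / (1 + d), the old behaviour.
template <typename value_t>
__device__ value_t compute_q(value_t dist, value_t dof)
{
  const value_t exponent = (dof + 1.0f) / 2.0f;
  return __powf(dof / (dof + dist), exponent);
}
```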

37 changes: 27 additions & 10 deletions cpp/src/tsne/barnes_hut_tsne.cuh
@@ -17,6 +17,7 @@

#include <raft/cudart_utils.h>
#include <cuml/common/logger.hpp>
+#include <raft/linalg/eltwise.cuh>
#include "barnes_hut_kernels.cuh"
#include "utils.cuh"

@@ -36,17 +37,19 @@ namespace TSNE {
*/

template <typename value_idx, typename value_t>
-void Barnes_Hut(value_t* VAL,
-                const value_idx* COL,
-                const value_idx* ROW,
-                const value_idx NNZ,
-                const raft::handle_t& handle,
-                value_t* Y,
-                const value_idx n,
-                const TSNEParams& params)
+value_t Barnes_Hut(value_t* VAL,
+                   const value_idx* COL,
+                   const value_idx* ROW,
+                   const value_idx NNZ,
+                   const raft::handle_t& handle,
+                   value_t* Y,
+                   const value_idx n,
+                   const TSNEParams& params)
{
cudaStream_t stream = handle.get_stream();

+  value_t kl_div = 0;

// Get device properties
//---------------------------------------------------
const int blocks = raft::getMultiProcessorCount();
@@ -120,6 +123,10 @@ void Barnes_Hut(value_t* VAL,
raft::copy(YY.data() + nnodes + 1, Y + n, n, stream);
}

+  // Qs and KL_divs intentionally alias the same scratch buffer
+  rmm::device_uvector<value_t> tmp(NNZ, stream);
+  value_t* Qs      = tmp.data();
+  value_t* KL_divs = tmp.data();

// Set cache levels for faster algorithm execution
//---------------------------------------------------
CUDA_CHECK(
@@ -259,12 +266,13 @@ void Barnes_Hut(value_t* VAL,
START_TIMER;
BH::Find_Normalization<<<1, 1, 0, stream>>>(Z_norm.data(), n);
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(Reduction_time);

START_TIMER;
-    // TODO: Calculate Kullback-Leibler divergence
-    // For general embedding dimensions
+    bool last_iter = iter == params.max_iter - 1;

BH::attractive_kernel_bh<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
VAL,
COL,
@@ -273,10 +281,17 @@
YY.data() + nnodes + 1,
attr_forces.data(),
attr_forces.data() + n,
-      NNZ);
+      last_iter ? Qs : nullptr,
+      NNZ,
+      fmaxf(params.dim - 1, 1));
CUDA_CHECK(cudaPeekAtLastError());
END_TIMER(attractive_time);

+    if (last_iter) {
+      kl_div = compute_kl_div(VAL, Qs, KL_divs, NNZ, stream);
+      CUDA_CHECK(cudaPeekAtLastError());
+    }

START_TIMER;
BH::IntegrationKernel<<<blocks * FACTOR6, THREADS6, 0, stream>>>(learning_rate,
momentum,
@@ -302,6 +317,8 @@ void Barnes_Hut(value_t* VAL,
// Copy final YY into true output Y
raft::copy(Y, YY.data(), n, stream);
raft::copy(Y + n, YY.data() + nnodes + 1, n, stream);

+  return kl_div;
}

} // namespace TSNE
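
`compute_kl_div` comes from `utils.cuh` and is not shown in this diff. Below is a minimal Thrust-based sketch of such a helper, under two assumptions: `Ps` (the `VAL` array) is already normalized, and `Qs` holds the unnormalized q values written by `attractive_kernel_bh`. Since `KL_divs` may alias `Qs` (both point at `tmp.data()` above), only element-wise operations touch them:

```cpp
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/transform.h>

// Hypothetical sketch, not the PR's verbatim implementation.
// Requires nvcc's --extended-lambda for the device lambda.
template <typename value_t, typename value_idx>
value_t compute_kl_div(const value_t* Ps, const value_t* Qs, value_t* KL_divs,
                       const value_idx NNZ, cudaStream_t stream)
{
  auto policy = thrust::cuda::par.on(stream);
  // Z turns the raw q values into a distribution over the NNZ edges.
  const value_t Z = thrust::reduce(policy, Qs, Qs + NNZ, value_t{0});
  // Per-edge term p * log(p / (q / Z)), treating 0 * log(0) as 0.
  thrust::transform(policy, Ps, Ps + NNZ, Qs, KL_divs,
                    [Z] __device__(value_t p, value_t q) -> value_t {
                      return (p > 0 && q > 0) ? p * log(p * Z / q) : value_t{0};
                    });
  return thrust::reduce(policy, KL_divs, KL_divs + NNZ, value_t{0});
}
```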