diff --git a/CHANGELOG.md b/CHANGELOG.md index 9538154871..2432b5a090 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,11 +13,13 @@ - PR #444: Added supervised training to UMAP - PR #460: Random Forest & Decision Trees (Single-GPU, Classification - PR #491: added doxygen build target for ml-prims +- PR #505: Adding R-Squared Score to python interface - PR #507: Added coordinate descent for lasso and elastic-net - PR #511: a minmax ml-prim - PR #520: Add local build script to mimic gpuCI - PR #503: Added column-wise matrix sort primitive - PR #525: Add docs build script to cuML +- PR #528: Remove current KMeans and replace it with a new single GPU implementation built using ML primitives ## Improvements @@ -41,6 +43,7 @@ - PR #524: Use cmake find blas and find lapack to pass configure options to faiss - PR #527: Added notes on UMAP differences from reference implementation - PR #540: Use latest release version in update-version CI script +- PR #552: Re-enable assert in kmeans tests with xfail as needed ## Bug Fixes @@ -72,6 +75,9 @@ - PR #519: README.md Updates and adding BUILD.md back - PR #526: Fix the issue of wrong results when fit and transform of PCA are called separately - PR #531: Fixing missing arguments in updateDevice() for RF +- PR #543: Exposing dbscan batch size through cython API and fixing broken batching +- PR #551: Made use of ZLIB_LIBRARIES consistent between ml_test and ml_mg_test +- PR #557: Modified CI script to run cuML tests before building mlprims and removed lapack flag # cuML 0.6.0 (22 Mar 2019) diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index a45df810af..bf40030833 100644 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -82,16 +82,6 @@ logger "Build cuML..." cd $WORKSPACE/python python setup.py build_ext --inplace -logger "Build ml-prims tests..." -mkdir -p $WORKSPACE/ml-prims/build -cd $WORKSPACE/ml-prims/build -cmake $GPU_ARCH .. - -logger "Clean up make..." -make clean -logger "Make ml-prims test..." -make -j${PARALLEL_LEVEL} - ################################################################################ # TEST - Run GoogleTest and py.tests for libcuml and cuML @@ -104,12 +94,25 @@ logger "GoogleTest for libcuml..." cd $WORKSPACE/cuML/build GTEST_OUTPUT="xml:${WORKSPACE}/test-results/libcuml_cpp/" ./ml_test - logger "Python py.test for cuML..." cd $WORKSPACE/python py.test --cache-clear --junitxml=${WORKSPACE}/junit-cuml.xml -v +################################################################################ +# TEST - Build and run ml-prim tests +################################################################################ + +logger "Build ml-prims tests..." +mkdir -p $WORKSPACE/ml-prims/build +cd $WORKSPACE/ml-prims/build +cmake $GPU_ARCH .. + +logger "Clean up make..." +make clean +logger "Make ml-prims test..." +make -j${PARALLEL_LEVEL} + logger "Run ml-prims test..." cd $WORKSPACE/ml-prims/build GTEST_OUTPUT="xml:${WORKSPACE}/test-results/ml-prims/" ./test/mlcommon_test diff --git a/cuML/CMakeLists.txt b/cuML/CMakeLists.txt index 9e21a83971..a75bf92068 100644 --- a/cuML/CMakeLists.txt +++ b/cuML/CMakeLists.txt @@ -172,7 +172,7 @@ endif(NOT CMAKE_CXX11_ABI) include (ExternalProject) ExternalProject_Add(faiss SOURCE_DIR ${FAISS_DIR} - CONFIGURE_COMMAND LIBS=-pthread CPPFLAGS=-w LDFLAGS=-L${CMAKE_INSTALL_PREFIX}/lib ${FAISS_DIR}/configure --prefix=${CMAKE_CURRENT_BINARY_DIR}/faiss --with-blas=${BLAS_LIBRARIES} --with-lapack=${LAPACK_LIBRARIES} --with-cuda=${CUDA_TOOLKIT_ROOT_DIR} --quiet + CONFIGURE_COMMAND LIBS=-pthread CPPFLAGS=-w LDFLAGS=-L${CMAKE_INSTALL_PREFIX}/lib ${FAISS_DIR}/configure --prefix=${CMAKE_CURRENT_BINARY_DIR}/faiss --with-blas=${BLAS_LIBRARIES} --with-cuda=${CUDA_TOOLKIT_ROOT_DIR} --quiet PREFIX ${CMAKE_CURRENT_BINARY_DIR}/faiss/ BUILD_COMMAND $(MAKE) INSTALL_COMMAND $(MAKE) -s install > /dev/null || $(MAKE) && cd gpu && $(MAKE) -s install > /dev/null || $(MAKE) @@ -209,7 +209,7 @@ file(GLOB_RECURSE cuml_test_cuda_sources "test/*.cu") file(GLOB_RECURSE cuml_mg_test_cuda_sources "test_mg/*.cu") ################################################################################################### -# - build libcuml shared library ------------------------------------------------------------------ +# - build libcuml++ shared library ------------------------------------------------------------------ set(CUML_CPP_TARGET "cuml++") add_library(${CUML_CPP_TARGET} SHARED src/pca/pca.cu @@ -223,6 +223,7 @@ add_library(${CUML_CPP_TARGET} SHARED src/common/cuml_api.cpp src/umap/umap.cu src/solver/solver.cu + src/metrics/metrics.cu src/decisiontree/decisiontree.cu src/randomforest/randomforest.cu) @@ -235,8 +236,7 @@ set(CUML_LINK_LIBRARIES ${CUDA_nvgraph_LIBRARY} ${ZLIB_LIBRARIES} gpufaisslib - faisslib - ${BLAS_LIBRARIES}) + faisslib) if(OPENMP_FOUND) set(CUML_LINK_LIBRARIES ${CUML_LINK_LIBRARIES} OpenMP::OpenMP_CXX pthread) @@ -260,10 +260,9 @@ target_link_libraries(ml_test ${CUDA_cusparse_LIBRARY} ${CUDA_nvgraph_LIBRARY} faisslib - ${BLAS_LIBRARIES} ${CUML_CPP_TARGET} pthread - z) + ${ZLIB_LIBRARIES}) ################################################################################################### # - build test executable ------------------------------------------------------------------------- @@ -281,7 +280,6 @@ target_link_libraries(ml_mg_test ${CUDA_nvgraph_LIBRARY} gpufaisslib faisslib - ${BLAS_LIBRARIES} ${CUML_CPP_TARGET} pthread ${ZLIB_LIBRARIES}) diff --git a/cuML/examples/dbscan/dbscan_example.cpp b/cuML/examples/dbscan/dbscan_example.cpp index e6cb432326..4f7e4ec1ef 100644 --- a/cuML/examples/dbscan/dbscan_example.cpp +++ b/cuML/examples/dbscan/dbscan_example.cpp @@ -78,12 +78,13 @@ void printUsage() << "-num_samples -num_features " << "[-min_pts ] " << "[-eps ] " + << "[-max_bytes_per_batch ] " << std::endl; return; } void loadDefaultDataset(std::vector& inputData, size_t& nRows, - size_t& nCols, int& minPts, float& eps) + size_t& nCols, int& minPts, float& eps, size_t& max_bytes_per_batch) { constexpr size_t NUM_ROWS = 25; constexpr size_t NUM_COLS = 3; @@ -122,6 +123,7 @@ void loadDefaultDataset(std::vector& inputData, size_t& nRows, nCols = NUM_COLS; minPts = MIN_PTS; eps = EPS; + max_bytes_per_batch = 0; // allow algorithm to set this inputData.insert(inputData.begin(), data, data+nRows*nCols); } @@ -167,6 +169,7 @@ int main(int argc, char * argv[]) get_argval(argv, argv+argc,"-input", std::string("")); int minPts = get_argval(argv, argv+argc, "-min_pts", 3); float eps = get_argval(argv, argv+argc, "-eps", 1.0f); + size_t max_bytes_per_batch = get_argval(argv, argv+argc, "-max_bytes_per_batch", (size_t)2e7); { cudaError_t cudaStatus = cudaSuccess; @@ -215,7 +218,7 @@ int main(int argc, char * argv[]) // Samples file not specified, run with defaults std::cout << "Samples file not specified. (-input option)" << std::endl; std::cout << "Running with default dataset:" << std::endl; - loadDefaultDataset(h_inputData, nRows, nCols, minPts, eps); + loadDefaultDataset(h_inputData, nRows, nCols, minPts, eps, max_bytes_per_batch); } else if(nRows == 0 || nCols == 0) { @@ -274,10 +277,10 @@ int main(int argc, char * argv[]) << "Number of samples - " << nRows << std::endl << "Number of features - " << nCols << std::endl << "min_pts - " << minPts << std::endl - << "eps - " << eps - << std::endl; + << "eps - " << eps << std::endl + << "max_bytes_per_batch - " << max_bytes_per_batch << std::endl; - ML::dbscanFit(cumlHandle, d_inputData, nRows, nCols, eps, minPts, d_labels); + ML::dbscanFit(cumlHandle, d_inputData, nRows, nCols, eps, minPts, d_labels, max_bytes_per_batch); CUDA_RT_CALL( cudaMemcpyAsync(h_labels.data(), d_labels, nRows*sizeof(int), cudaMemcpyDeviceToHost, stream) ); CUDA_RT_CALL( cudaStreamSynchronize(stream) ); diff --git a/cuML/examples/kmeans/kmeans_example.cpp b/cuML/examples/kmeans/kmeans_example.cpp index 84e8dcaf88..c804164ab1 100644 --- a/cuML/examples/kmeans/kmeans_example.cpp +++ b/cuML/examples/kmeans/kmeans_example.cpp @@ -23,7 +23,17 @@ #include -#include +#ifdef HAVE_CUB +#include +#endif //HAVE_CUB + +#ifdef HAVE_RMM +#include +#include +#endif //HAVE_RMM + +#include +#include #ifndef CUDA_RT_CALL #define CUDA_RT_CALL( call ) \ @@ -54,6 +64,38 @@ bool get_arg(char ** begin, char ** end, const std::string& arg) { return false; } +class cachingDeviceAllocator : public ML::deviceAllocator { +public: + cachingDeviceAllocator() +#ifdef HAVE_CUB + , _allocator(8, 3, cub::CachingDeviceAllocator::INVALID_BIN, cub::CachingDeviceAllocator::INVALID_SIZE) +#endif //HAVE_CUB + {} + + virtual void* allocate( std::size_t n, cudaStream_t stream ) { + void* ptr = 0; +#ifdef HAVE_CUB + _allocator.DeviceAllocate( &ptr, n, stream ); +#else //!HAVE_CUB + CUDA_RT_CALL( cudaMalloc( &ptr, n ) ); +#endif //HAVE_CUB + return ptr; + } + + virtual void deallocate( void* p, std::size_t, cudaStream_t ) { +#ifdef HAVE_CUB + _allocator.DeviceFree(p); +#else //!HAVE_CUB + CUDA_RT_CALL( cudaFree(p) ); +#endif //HAVE_CUB + } + +#ifdef HAVE_CUB +private: + cub::CachingDeviceAllocator _allocator; +#endif //HAVE_CUB +}; + int main(int argc, char * argv[]) { const int dev_id = get_argval(argv, argv+argc,"-dev_id", 0); @@ -78,6 +120,16 @@ int main(int argc, char * argv[]) std::cerr<<"ERROR: Could not initialize CUDA on device: "< h_srcdata; @@ -100,73 +152,81 @@ int main(int argc, char * argv[]) bool results_correct = true; if ( 0 == h_srcdata.size() || (num_rows*num_cols) == h_srcdata.size() ) { - int k_max = k; - double threshold = 1.0E-4; //Scikit-Learn default //Input parameters copied from kmeans_test.cu if (0 == h_srcdata.size()) { k = 2; - k_max = k; max_iterations = 300; threshold = 0.05; } - int dopredict = 0; - int verbose = 0; - int seed = 1; - int n_gpu = 1; - int init_from_data = 0; + int seed = 0; + int metric = 1; + ML::kmeans::InitMethod method = ML::kmeans::InitMethod::Random; //Inputs copied from kmeans_test.cu - size_t mTrain = 4; - size_t n = 2; - char ord = 'c'; // here c means col order, NOT C (vs F) order + size_t n_samples = 4; + size_t n_features = 2; if (0 == h_srcdata.size()) { - h_srcdata = {1.0,1.0,3.0,4.0, 1.0,2.0,2.0,3.0}; + h_srcdata = {1.0,1.0, + 3.0,4.0, + 1.0,2.0, + 2.0,3.0}; } else { - ord = 'r'; // r means row order - mTrain = num_rows; - n = num_cols; + n_samples = num_rows; + n_features = num_cols; } std::cout<<"Run KMeans with k="< std::numeric_limits::epsilon() ) { std::cerr<<"ERROR: h_centroids_ref["< > cluster_stats(k,std::make_pair(static_cast(0),0.0)); double global_inertia = 0.0; size_t max_points = 0; - for ( size_t i = 0; i < mTrain; ++i ) { + for ( size_t i = 0; i < n_samples; ++i ) { int label = h_pred_labels[i]; cluster_stats[label].first += 1; max_points = std::max(cluster_stats[label].first,max_points); double sd = 0.0; - for ( int j = 0; j < n; ++j ) { - const double cluster_centroid_comp = h_pred_centroids[label*n+j]; - const double point_comp = h_srcdata[i*n+j]; + for ( int j = 0; j < n_features; ++j ) { + const double cluster_centroid_comp = h_pred_centroids[label*n_features+j]; + const double point_comp = h_srcdata[i*n_features+j]; sd += (cluster_centroid_comp-point_comp)*(cluster_centroid_comp-point_comp); } cluster_stats[label].second += sd; @@ -217,18 +277,24 @@ int main(int argc, char * argv[]) std::cout<<"Global inertia = "< +#include + +namespace ML { + +template +class Tensor { +public: + enum { NumDim = Dim }; + typedef DataT* DataPtrT; + + __host__ + ~Tensor(){ + if (_state == AllocState::Owner){ + if(memory_type(_data) == cudaMemoryTypeDevice){ + _dAllocator->deallocate(_data, this->getSizeInBytes(), _stream); + }else if(memory_type(_data) == cudaMemoryTypeHost){ + _hAllocator->deallocate(_data, this->getSizeInBytes(), _stream); + } + } + } + + __host__ + Tensor(DataPtrT data, + const std::vector &sizes) + : _data(data), _state(AllocState::NotOwner){ + + static_assert(Dim > 0, + "must have > 0 dimensions"); + + ASSERT(sizes.size() == Dim, + "invalid argument: # of entries in the input argument 'sizes' must match the tensor dimension" ); + + for (int i = 0; i < Dim; ++i) { + _size[i] = sizes[i]; + } + + _stride[Dim - 1] = (IndexT) 1; + for (int j = Dim - 2; j >= 0; --j) { + _stride[j] = _stride[j + 1] * _size[j + 1]; + } + } + + // allocate the data using the allocator and release when the object goes out of scope + // allocating tensor is the owner of the data + __host__ + Tensor(const std::vector &sizes, + std::shared_ptr allocator, + cudaStream_t stream): + _stream(stream), + _dAllocator(allocator), + _state(AllocState::Owner){ + + static_assert(Dim > 0, + "must have > 0 dimensions"); + + ASSERT(sizes.size() == Dim, + "dimension mismatch" ); + + for (int i = 0; i < Dim; ++i) { + _size[i] = sizes[i]; + } + + _stride[Dim - 1] = (IndexT) 1; + for (int j = Dim - 2; j >= 0; --j) { + _stride[j] = _stride[j + 1] * _size[j + 1]; + } + + _data = static_cast(_dAllocator->allocate(this->getSizeInBytes(), + _stream)); + + CUDA_CHECK( cudaStreamSynchronize( _stream ) ); + + ASSERT(this->data() || (this->getSizeInBytes() == 0), + "device allocation failed"); + } + + /// returns the total number of elements contained within our data + __host__ size_t + numElements() const { + size_t num = (size_t) getSize(0); + + for (int i = 1; i < Dim; ++i) { + num *= (size_t) getSize(i); + } + + return num; + } + + /// returns the size of a given dimension, `[0, Dim - 1]` + __host__ inline IndexT getSize(int i) const { + return _size[i]; + } + + /// returns the stride array + __host__ inline const IndexT* strides() const { + return _stride; + } + + /// returns the stride array. + __host__ inline const IndexT getStride(int i) const { + return _stride[i]; + } + + /// returns the total size in bytes of our data + __host__ size_t getSizeInBytes() const { + return numElements() * sizeof(DataT); + } + + + /// returns a raw pointer to the start of our data + __host__ inline DataPtrT data() { + return _data; + } + + /// returns a raw pointer to the start of our data. + __host__ inline DataPtrT begin() { + return _data; + } + + /// returns a raw pointer to the end of our data + __host__ inline DataPtrT end() { + return data() + numElements(); + } + + /// returns a raw pointer to the start of our data (const) + __host__ inline + DataPtrT data() const { + return _data; + } + + /// returns a raw pointer to the end of our data (const) + __host__ inline DataPtrT end() const { + return data() + numElements(); + } + + /// returns the size array. + __host__ inline const IndexT* sizes() const { + return _size; + } + + template + __host__ Tensor + view(const std::vector &sizes, + const std::vector &start_pos){ + ASSERT(sizes.size() == NewDim, + "invalid view requested"); + ASSERT(start_pos.size() == Dim, + "dimensionality of the position if incorrect"); + + // calc offset at start_pos + uint32_t offset = 0; + for(uint32_t dim = 0; dim < Dim; ++dim){ + offset += start_pos[dim] * getStride(dim); + } + DataPtrT newData = this->data() + offset; + + + // The total size of the new view must be the <= total size of the old view + size_t curSize = numElements(); + size_t newSize = 1; + + for (auto s : sizes) { + newSize *= s; + } + + ASSERT(newSize <= curSize, + "invalid view requested"); + + return Tensor(newData, sizes); + } + + +private: + enum AllocState { + /// This tensor itself owns the memory, which must be freed via + /// cudaFree + Owner, + + /// This tensor itself is not an owner of the memory; there is + /// nothing to free + NotOwner + }; + + +protected: + + std::shared_ptr _dAllocator; + std::shared_ptr _hAllocator; + + /// Raw pointer to where the tensor data begins + DataPtrT _data; + + /// Array of strides (in sizeof(T) terms) per each dimension + IndexT _stride[Dim]; + + /// Size per each dimension + IndexT _size[Dim]; + + AllocState _state; + + cudaStream_t _stream; +}; + +}; // end namespace ML + diff --git a/cuML/src/dbscan/adjgraph/algo.h b/cuML/src/dbscan/adjgraph/algo.h index ad4a99e672..45a8aaf260 100644 --- a/cuML/src/dbscan/adjgraph/algo.h +++ b/cuML/src/dbscan/adjgraph/algo.h @@ -42,7 +42,7 @@ __global__ void adj_graph_kernel(Pack data, int batchSize) { Type scan_id = data.ex_scan[row]; for(int i=0; i data, int batchSize, auto execution_policy = thrust::cuda::par(alloc).on(stream); exclusive_scan(execution_policy, dev_vd, dev_vd + batchSize, dev_ex_scan); adj_graph_kernel<<>>(data, batchSize); + CUDA_CHECK(cudaPeekAtLastError()); } } // End Algo diff --git a/cuML/src/dbscan/dbscan.cu b/cuML/src/dbscan/dbscan.cu index 8fe3fc8348..e81169613b 100644 --- a/cuML/src/dbscan/dbscan.cu +++ b/cuML/src/dbscan/dbscan.cu @@ -24,29 +24,28 @@ namespace ML { using namespace Dbscan; - void dbscanFit(const cumlHandle& handle, float *input, int n_rows, int n_cols, float eps, int min_pts, - int *labels) { - dbscanFitImpl(handle.getImpl(), input, n_rows, n_cols, eps, min_pts, labels, handle.getStream()); + int *labels, size_t max_bytes_per_batch = 0, bool verbose) { + dbscanFitImpl(handle.getImpl(), input, n_rows, n_cols, eps, min_pts, labels, max_bytes_per_batch, handle.getStream(), verbose); } void dbscanFit(const cumlHandle& handle, double *input, int n_rows, int n_cols, double eps, int min_pts, - int *labels) { - dbscanFitImpl(handle.getImpl(), input, n_rows, n_cols, eps, min_pts, labels, handle.getStream()); + int *labels, size_t max_bytes_per_batch = 0, bool verbose) { + dbscanFitImpl(handle.getImpl(), input, n_rows, n_cols, eps, min_pts, labels, max_bytes_per_batch, handle.getStream(), verbose); } }; // end namespace ML extern "C" cumlError_t cumlSpDbscanFit(cumlHandle_t handle, float *input, int n_rows, int n_cols, float eps, int min_pts, - int *labels) { + int *labels, size_t max_bytes_per_batch, bool verbose) { cumlError_t status; ML::cumlHandle *handle_ptr; std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); if (status == CUML_SUCCESS) { try { - dbscanFit(*handle_ptr, input, n_rows, n_cols, eps, min_pts, labels); + dbscanFit(*handle_ptr, input, n_rows, n_cols, eps, min_pts, labels, max_bytes_per_batch, verbose); } //TODO: Implement this //catch (const MLCommon::Exception& e) @@ -64,14 +63,14 @@ extern "C" cumlError_t cumlSpDbscanFit(cumlHandle_t handle, float *input, int n_ } extern "C" cumlError_t cumlDpDbscanFit(cumlHandle_t handle, double *input, int n_rows, int n_cols, double eps, int min_pts, - int *labels) { + int *labels, size_t max_bytes_per_batch, bool verbose) { cumlError_t status; ML::cumlHandle *handle_ptr; std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); if (status == CUML_SUCCESS) { try { - dbscanFit(*handle_ptr, input, n_rows, n_cols, eps, min_pts, labels); + dbscanFit(*handle_ptr, input, n_rows, n_cols, eps, min_pts, labels, max_bytes_per_batch, verbose); } //TODO: Implement this //catch (const MLCommon::Exception& e) diff --git a/cuML/src/dbscan/dbscan.h b/cuML/src/dbscan/dbscan.h index cf6d407f34..903009f7b3 100644 --- a/cuML/src/dbscan/dbscan.h +++ b/cuML/src/dbscan/dbscan.h @@ -23,32 +23,52 @@ namespace ML { using namespace Dbscan; +// Default max mem set to a reasonable value for a 16gb card. +static const size_t DEFAULT_MAX_MEM_BYTES = 13e9; -int computeBatchCount(int n_rows) { +template +int computeBatchCount(int n_rows, size_t max_bytes_per_batch) { int n_batches = 1; // There seems to be a weird overflow bug with cutlass gemm kernels // hence, artifically limiting to a smaller batchsize! ///TODO: in future, when we bump up the underlying cutlass version, this should go away // paving way to cudaMemGetInfo based workspace allocation - static const size_t MaxElems = (size_t)1024 * 1024 * 1024 * 2; // 2e9 while(true) { - size_t batchSize = ceildiv(n_rows, n_batches); - if(batchSize * n_rows < MaxElems) + if(batchSize * n_rows * sizeof(T) < max_bytes_per_batch || batchSize == 1) break; ++n_batches; } - return n_batches; } template -void dbscanFitImpl(const ML::cumlHandle_impl& handle, T *input, int n_rows, int n_cols, T eps, int min_pts, int *labels, cudaStream_t stream) { +void dbscanFitImpl(const ML::cumlHandle_impl& handle, T *input, + int n_rows, int n_cols, + T eps, int min_pts, + int *labels, + size_t max_bytes_per_batch, + cudaStream_t stream, + bool verbose) { int algoVd = 1; int algoAdj = 1; int algoCcl = 2; - int n_batches = computeBatchCount(n_rows); + + if(max_bytes_per_batch <= 0) + // @todo: Query device for remaining memory + max_bytes_per_batch = DEFAULT_MAX_MEM_BYTES; + + int n_batches = computeBatchCount(n_rows, max_bytes_per_batch); + + if(verbose) { + size_t batchSize = ceildiv(n_rows, n_batches); + if(n_batches > 1) { + std::cout << "Running batched training on " << n_batches << " batches w/ "; + std::cout << batchSize * n_rows * sizeof(T) << " bytes." << std::endl; + } + } + size_t workspaceSize = Dbscan::run(handle, input, n_rows, n_cols, eps, min_pts, labels, algoVd, algoAdj, algoCcl, NULL, n_batches, stream); diff --git a/cuML/src/dbscan/dbscan.hpp b/cuML/src/dbscan/dbscan.hpp index b81a7b63b3..805bacfcdb 100644 --- a/cuML/src/dbscan/dbscan.hpp +++ b/cuML/src/dbscan/dbscan.hpp @@ -19,11 +19,39 @@ namespace ML{ +/** + * @brief Fits a DBSCAN model on an input feature matrix and outputs the labels. + * @param handle: cuml handle to use across the algorithm + * @param input: row-major input feature matrix + * @param n_rows: number of samples in the input feature matrix + * @param n_cols: number of features in the input feature matrix + * @param eps: the epsilon value to use for epsilon-neighborhood determination + * @param min_pts: minimum number of points to determine a cluster + * @param labels: (size n_rows) output labels array + * @param max_mem_bytes: the maximum number of bytes to be used for each batch of + * the pairwise distance calculation. This enables the trade off between + * memory usage and algorithm execution time. + * @param verbose: print useful information as algorithm executes + */ void dbscanFit(const cumlHandle& handle, float *input, int n_rows, int n_cols, float eps, int min_pts, - int *labels); + int *labels, size_t max_bytes_per_batch, bool verbose = false); +/** + * @brief Fits a DBSCAN model on an input feature matrix and outputs the labels. + * @param handle: cuml handle to use across the algorithm + * @param input: row-major input feature matrix + * @param n_rows: number of samples in the input feature matrix + * @param n_cols: number of features in the input feature matrix + * @param eps: the epsilon value to use for epsilon-neighborhood determination + * @param min_pts: minimum number of points to determine a cluster + * @param labels: (size n_rows) output labels array + * @param max_mem_bytes: the maximum number of bytes to be used for each batch of + * the pairwise distance calculation. This enables the trade off between + * memory usage and algorithm execution time. + * @param verbose: print useful information as algorithm executes + */ void dbscanFit(const cumlHandle& handle, double *input, int n_rows, int n_cols, double eps, int min_pts, - int *labels); + int *labels, size_t max_bytes_per_batch, bool verbose = false); } diff --git a/cuML/src/dbscan/dbscan_api.h b/cuML/src/dbscan/dbscan_api.h index 7aea19709c..4c58c5ae65 100644 --- a/cuML/src/dbscan/dbscan_api.h +++ b/cuML/src/dbscan/dbscan_api.h @@ -22,11 +22,13 @@ extern "C" { #endif //Single precision version of DBSCAN fit -cumlError_t cumlSpDbscanFit(cumlHandle_t handle, float* input, int n_rows, int n_cols, float eps, int min_pts, int *labels); +cumlError_t cumlSpDbscanFit(cumlHandle_t handle, float* input, + int n_rows, int n_cols, float eps, int min_pts, int *labels, size_t max_bytes_per_batch); //Double precision version of DBSCAN fit -cumlError_t cumlDpDbscanFit(cumlHandle_t handle, double *input, int n_rows, int n_cols, double eps, int min_pts, int *labels); +cumlError_t cumlDpDbscanFit(cumlHandle_t handle, double *input, + int n_rows, int n_cols, double eps, int min_pts, int *labels, size_t max_bytes_per_batch); #ifdef __cplusplus } -#endif \ No newline at end of file +#endif diff --git a/cuML/src/dbscan/runner.h b/cuML/src/dbscan/runner.h index d61dd9e68c..24316073cb 100644 --- a/cuML/src/dbscan/runner.h +++ b/cuML/src/dbscan/runner.h @@ -48,11 +48,12 @@ template */ size_t run(const ML::cumlHandle_impl& handle, Type_f* x, Type N, Type D, Type_f eps, Type minPts, Type* labels, int algoVd, int algoAdj, int algoCcl, void* workspace, int nBatches, cudaStream_t stream) { + const size_t align = 256; int batchSize = ceildiv(N, nBatches); size_t adjSize = alignTo(sizeof(bool) * N * batchSize, align); - size_t corePtsSize = alignTo(sizeof(bool) * N, align); - size_t visitedSize = alignTo(sizeof(bool) * N, align); + size_t corePtsSize = alignTo(sizeof(bool) * batchSize, align); + size_t visitedSize = alignTo(sizeof(bool) * batchSize, align); size_t xaSize = alignTo(sizeof(bool) * N, align); size_t mSize = alignTo(sizeof(bool), align); size_t vdSize = alignTo(sizeof(Type) * (batchSize + 1), align); @@ -89,21 +90,24 @@ size_t run(const ML::cumlHandle_impl& handle, Type_f* x, Type N, Type D, Type_f MLCommon::device_buffer adj_graph(handle.getDeviceAllocator(), stream); int startVertexId = i * batchSize; int nPoints = min(N-startVertexId, batchSize); + if(nPoints <= 0) continue; VertexDeg::run(handle, adj, vd, x, eps, N, D, algoVd, startVertexId, nPoints, stream); - MLCommon::updateHost(&curradjlen, vd + nPoints, 1, stream); - CUDA_CHECK(cudaStreamSynchronize(stream)); + + MLCommon::updateHost(&curradjlen, vd + nPoints, 1, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); // Running AdjGraph - // TODO -: To come up with a mechanism as to reduce and reuse adjgraph mallocs if (curradjlen > adjlen || adj_graph.data() == NULL) { adjlen = curradjlen; adj_graph.resize(adjlen, stream); } + AdjGraph::run(handle, adj, vd, adj_graph.data(), ex_scan, N, minPts, core_pts, algoAdj, nPoints, stream); + // Running Labelling Label::run(handle, adj, vd, adj_graph.data(), ex_scan, N, minPts, core_pts, visited, labels, xa, fa, m, map_id, algoCcl, startVertexId, diff --git a/cuML/src/dbscan/vertexdeg/algo.h b/cuML/src/dbscan/vertexdeg/algo.h index 53763a721b..33f7cb3a74 100644 --- a/cuML/src/dbscan/vertexdeg/algo.h +++ b/cuML/src/dbscan/vertexdeg/algo.h @@ -55,7 +55,8 @@ void launcher(const ML::cumlHandle_impl& handle, Pack data, int startVe (value_t val, // current value in gemm matrix int global_c_idx) { // index of output in global memory int acc = val <= eps2; - int vd_offset = global_c_idx / n; // bucket offset for the vertex degrees + int vd_offset = global_c_idx - (n * (global_c_idx / n)); // bucket offset for the vertex degrees + atomicAdd(vd+vd_offset, acc); atomicAdd(vd+n, acc); return bool(acc); diff --git a/cuML/src/kmeans/kmeans-inl.cuh b/cuML/src/kmeans/kmeans-inl.cuh new file mode 100644 index 0000000000..6ba59c6400 --- /dev/null +++ b/cuML/src/kmeans/kmeans-inl.cuh @@ -0,0 +1,1362 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace ML { + +//@todo: Use GLOG once https://github.com/rapidsai/cuml/issues/100 is addressed. +#define LOG(verbose, fmt, ...) \ + do{ \ + if(verbose){ \ + std::string msg; \ + char verboseMsg[2048]; \ + std::sprintf(verboseMsg, fmt, ##__VA_ARGS__); \ + msg += verboseMsg; \ + std::cerr << msg; \ + } \ + } while (0) + +namespace kmeans{ +namespace detail{ +typedef cutlass::Shape<8, 128, 128> OutputTile_8_128_128; + +template +struct SamplingOp{ + DataT *rnd; + int *flag; + DataT cluster_cost; + int oversampling_factor; + + CUB_RUNTIME_FUNCTION __forceinline__ + SamplingOp(DataT c, int l, DataT *rand, int* ptr) + : cluster_cost(c), + oversampling_factor(l), + rnd(rand), + flag(ptr) { + } + + __host__ __device__ __forceinline__ + bool operator()(const cub::KeyValuePair &a) const { + DataT prob_threshold = (DataT) rnd[a.key]; + + DataT prob_x = (( oversampling_factor * a.value) / cluster_cost); + + return !flag[a.key] && (prob_x > prob_threshold); + } +}; + +template +struct KeyValueIndexOp{ + __host__ __device__ __forceinline__ + IndexT operator()(const cub::KeyValuePair &a) const { + return a.key; + } +}; + + +// @todo: Decide this based on avail memory and perf analysis.. +template +CountT +getDataBatchSize(CountT n_samples, CountT n_features){ + return std::min((CountT)(1<<16), n_samples); +} + +// Computes the intensity histogram from a sequence of labels +template +void +countLabels(const cumlHandle_impl& handle, + SampleIteratorT labels, + CounterT *count, + int n_samples, + int n_clusters, + MLCommon::device_buffer &workspace, + cudaStream_t stream) { + int num_levels = n_clusters + 1; + int lower_level = 0; + int upper_level = n_clusters; + + size_t temp_storage_bytes = 0; + CUDA_CHECK( cub::DeviceHistogram::HistogramEven(nullptr, + temp_storage_bytes, + labels, + count, + num_levels, + lower_level, + upper_level, + n_samples, + stream) ); + + workspace.resize(temp_storage_bytes, stream); + + CUDA_CHECK( cub::DeviceHistogram::HistogramEven(workspace.data(), + temp_storage_bytes, + labels, + count, + num_levels, + lower_level, + upper_level, + n_samples, + stream) ); +} + + + +template +Tensor +sampleCentroids(const cumlHandle_impl &handle, + Tensor &X, + Tensor &minClusterDistance, + Tensor &isSampleCentroid, + typename kmeans::detail::SamplingOp &select_op, + MLCommon::device_buffer &workspace, + cudaStream_t stream){ + int n_local_samples = X.getSize(0); + int n_features = X.getSize(1); + + Tensor nSelected({1}, + handle.getDeviceAllocator(), stream); + + cub::ArgIndexInputIterator ip_itr(minClusterDistance.data()); + Tensor< cub::KeyValuePair, 1 > sampledMinClusterDistance({n_local_samples}, + handle.getDeviceAllocator(), stream); + size_t temp_storage_bytes = 0; + CUDA_CHECK( cub::DeviceSelect::If(nullptr, + temp_storage_bytes, + ip_itr, + sampledMinClusterDistance.data(), + nSelected.data(), + n_local_samples, + select_op, + stream) ); + + workspace.resize(temp_storage_bytes, stream); + + CUDA_CHECK( cub::DeviceSelect::If(workspace.data(), + temp_storage_bytes, + ip_itr, + sampledMinClusterDistance.data(), + nSelected.data(), + n_local_samples, + select_op, + stream) ); + + + int nPtsSampledInRank = 0; + MLCommon::copy(&nPtsSampledInRank, + nSelected.data(), + nSelected.numElements(), + stream); + CUDA_CHECK( cudaStreamSynchronize(stream) ); + + int* rawPtr_isSampleCentroid = isSampleCentroid.data(); + ML::thrustAllocatorAdapter alloc( handle.getDeviceAllocator(), stream ); + auto execution_policy = thrust::cuda::par(alloc).on(stream); + thrust::for_each_n(execution_policy, + sampledMinClusterDistance.begin(), + nPtsSampledInRank, + [=] __device__ (cub::KeyValuePair val) { + rawPtr_isSampleCentroid[val.key] = 1; + }); + + Tensor inRankCp({nPtsSampledInRank, n_features}, + handle.getDeviceAllocator(), stream); + + + Matrix::gather(X.data(), + X.getSize(1), + X.getSize(0), + sampledMinClusterDistance.data(), + nPtsSampledInRank, + inRankCp.data(), + [=] __device__ (cub::KeyValuePair val) { // MapTransformOp + return val.key; + }, + stream); + + return inRankCp; +} + + +template +DataT +computeClusterCost(const cumlHandle_impl &handle, + Tensor &minClusterDistance, + MLCommon::device_buffer &workspace, + ReductionOpT reduction_op, + cudaStream_t stream){ + Tensor clusterCostD({1}, + handle.getDeviceAllocator(), stream); + + size_t temp_storage_bytes = 0; + CUDA_CHECK( cub::DeviceReduce::Reduce(nullptr, + temp_storage_bytes, + minClusterDistance.data(), + clusterCostD.data(), + minClusterDistance.numElements(), + reduction_op, + DataT(), + stream) ); + + workspace.resize(temp_storage_bytes, stream); + + CUDA_CHECK( cub::DeviceReduce::Reduce(workspace.data(), + temp_storage_bytes, + minClusterDistance.data(), + clusterCostD.data(), + minClusterDistance.numElements(), + reduction_op, + DataT(), + stream) ); + + DataT clusterCostInRank; + MLCommon::copy(&clusterCostInRank, + clusterCostD.data(), + clusterCostD.numElements(), + stream); + + CUDA_CHECK( cudaStreamSynchronize(stream) ); + return clusterCostInRank; +} + + +// calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', +// result will be stored in 'pairwiseDistance[n x k]' +template +void +pairwiseDistanceImpl(const cumlHandle_impl& handle, + Tensor &X, + Tensor ¢roids, + Tensor &pairwiseDistance, + MLCommon::device_buffer &workspace, + cudaStream_t stream){ + + auto n_samples = X.getSize(0); + auto n_features = X.getSize(1); + auto n_clusters = centroids.getSize(0); + + size_t worksize = Distance::getWorkspaceSize + (X.data(), + centroids.data(), + n_samples, + n_clusters, + n_features); + + workspace.resize(worksize, stream); + + + Distance::distance + + (X.data(), + centroids.data(), + pairwiseDistance.data(), + n_samples, + n_clusters, + n_features, + workspace.data(), + worksize, + stream); + +} + +// calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', +// result will be stored in 'pairwiseDistance[n x k]' +template +void +pairwiseDistance(const cumlHandle_impl& handle, + Tensor &X, + Tensor ¢roids, + Tensor &pairwiseDistance, + MLCommon::device_buffer &workspace, + MLCommon::Distance::DistanceType metric, + cudaStream_t stream){ + + auto n_samples = X.getSize(0); + auto n_features = X.getSize(1); + auto n_clusters = centroids.getSize(0); + + ASSERT(X.getSize(1) == centroids.getSize(1), + "# features in dataset and centroids are different (must be same)"); + + + if(metric == Distance::DistanceType::EucExpandedL2){ + pairwiseDistanceImpl(handle, X, centroids, pairwiseDistance, workspace, stream); + }else if(metric == Distance::DistanceType::EucExpandedL2Sqrt){ + pairwiseDistanceImpl(handle, X, centroids, pairwiseDistance, workspace, stream); + }else if(metric == Distance::DistanceType::EucExpandedCosine){ + pairwiseDistanceImpl(handle, X, centroids, pairwiseDistance, workspace, stream); + }else if(metric == Distance::DistanceType::EucUnexpandedL1){ + pairwiseDistanceImpl(handle, X, centroids, pairwiseDistance, workspace, stream); + }else{ + THROW("unknown distance metric"); + } +} + +// Calculates a pair for every sample in input 'X' where key is an index to an sample in 'centroids' (index of the nearest centroid) and 'value' is the distance between the sample and the 'centroid[key]' +template +void +minClusterAndDistance(const cumlHandle_impl &handle, + Tensor &X, + Tensor ¢roids, + Tensor &pairwiseDistance, + Tensor, 1, IndexT> &minClusterAndDistance, + MLCommon::device_buffer &workspace, + MLCommon::Distance::DistanceType metric, + cudaStream_t stream){ + auto n_samples = X.getSize(0); + auto n_features = X.getSize(1); + auto n_clusters = centroids.getSize(0); + auto dataBatchSize = kmeans::detail::getDataBatchSize(n_samples, n_features); + + // tile over the input dataset + for (auto dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + + // # of samples for the current batch + auto ns = std::min(dataBatchSize, n_samples - dIdx); + + // datasetView [ns x n_features] - view representing the current batch of input dataset + auto datasetView = X.template view<2>({ns, n_features}, + {dIdx, 0}); + + // distanceView [ns x n_clusters] + auto distanceView = pairwiseDistance.template view<2>({ns, n_clusters}, + {0, 0}); + + // minClusterAndDistanceView [ns x n_clusters] + auto minClusterAndDistanceView = minClusterAndDistance.template view<1>({ns}, + {dIdx}); + + + // calculate pairwise distance between cluster centroids and current batch of input dataset + kmeans::detail::pairwiseDistance(handle, + datasetView, + centroids, + distanceView, + workspace, + metric, + stream); + + // argmin reduction returning pair + // calculates the closest centroid and the distance to the closent centroid + cub::KeyValuePair initial_value(0, std::numeric_limits::max()); + LinAlg::coalescedReduction(minClusterAndDistanceView.data(), + distanceView.data(), + distanceView.getSize(1), + distanceView.getSize(0), + initial_value, + stream, + false, + [=] __device__ (const DataT val, const IndexT i) { + cub::KeyValuePair pair; + pair.key = i; + pair.value = val; + return pair; + }, + [=] __device__ (cub::KeyValuePair a, + cub::KeyValuePair b) { + return (b.value < a.value) ? b : a; + }, + [=] __device__ (cub::KeyValuePair pair) { + return pair; + }); + + } +} + +template +void +minClusterDistance(const cumlHandle_impl &handle, + Tensor &X, + Tensor ¢roids, + Tensor &pairwiseDistance, + Tensor &minClusterDistance, + MLCommon::device_buffer &workspace, + MLCommon::Distance::DistanceType metric, + cudaStream_t stream){ + auto n_samples = X.getSize(0); + auto n_features = X.getSize(1); + auto nc = centroids.getSize(0); + + auto dataBatchSize = kmeans::detail::getDataBatchSize(n_samples, n_features); + + // tile over the input data and calculate distance matrix [n_samples x n_clusters] + for (int dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + + // # of samples for the current batch + int ns = std::min(dataBatchSize, X.getSize(0) - dIdx); + + // datasetView [ns x n_features] - view representing the current batch of input dataset + auto datasetView = X.template view<2>({ns, n_features}, + {dIdx, 0}); + + // minClusterDistanceView [ns x n_clusters] + auto minClusterDistanceView = minClusterDistance.template view<1>({ns}, + {dIdx}); + + // calculate pairwise distance between cluster centroids and current batch of input dataset + kmeans::detail::pairwiseDistance(handle, + datasetView, + centroids, + pairwiseDistance, + workspace, + metric, + stream); + + LinAlg::coalescedReduction(minClusterDistanceView.data(), + pairwiseDistance.data(), + nc, // leading dimension of pairwiseDistance + ns, // second dimension of pairwiseDistance + std::numeric_limits::max(), + stream, + false, + [=] __device__ (DataT val, int i) { // MainLambda + return val; + }, + [=] __device__ (DataT a, DataT b) { // ReduceLambda + return (b < a) ? b : a; + }, + [=] __device__ (DataT val) { // FinalLambda + return val; + }); + + } +} + +// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores in 'out' +// does not modify the input +template +void +shuffleAndGather(const cumlHandle_impl &handle, + Tensor &in, + Tensor &out, + size_t n_samples_to_gather, + int seed, + cudaStream_t stream, + MLCommon::device_buffer *workspace = nullptr){ + auto n_samples = in.getSize(0); + auto n_features = in.getSize(1); + + Tensor indices({n_samples}, + handle.getDeviceAllocator(), + stream); + + if(workspace){ + // shuffle indices on device using ml-prims + Random::permute(indices.data(), + nullptr, + nullptr, + in.getSize(1), + in.getSize(0), + true, + stream); + }else{ + // shuffle indices on host and copy to device... + host_buffer ht_indices(handle.getHostAllocator(), stream, n_samples); + + std::iota(ht_indices.begin(), ht_indices.end(), 0); + + std::mt19937 gen(seed); + std::shuffle(ht_indices.begin(), ht_indices.end(), gen); + + MLCommon::copy(indices.data(), ht_indices.data(), + indices.numElements(), stream); + } + + Matrix::gather(in.data(), + in.getSize(1), + in.getSize(0), + indices.data(), + n_samples_to_gather, + out.data(), + stream); +} + + +template +void +countSamplesInCluster(const cumlHandle_impl& handle, + Tensor &X, + Tensor ¢roids, + MLCommon::device_buffer &workspace, + MLCommon::Distance::DistanceType metric, + Tensor &sampleCountInCluster, + cudaStream_t stream){ + auto n_samples = X.getSize(0); + auto n_features = X.getSize(1); + auto n_clusters = centroids.getSize(0); + + int dataBatchSize = kmeans::detail::getDataBatchSize(n_samples, n_features); + + // stores (key, value) pair corresponding to each sample where + // - key is the index of nearest cluster + // - value is the distance to the nearest cluster + Tensor, 1, IndexT> minClusterAndDistance({n_samples}, + handle.getDeviceAllocator(), + stream); + + // temporary buffer to store distance matrix, destructor releases the resource + Tensor pairwiseDistance({dataBatchSize, n_clusters}, + handle.getDeviceAllocator(), + stream); + + // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] is a pair where + // 'key' is index to an sample in 'centroids' (index of the nearest centroid) and + // 'value' is the distance between the sample 'X[i]' and the 'centroid[key]' + kmeans::detail::minClusterAndDistance(handle, + X, + centroids, + pairwiseDistance, + minClusterAndDistance, + workspace, + metric, + stream); + + // Using TransformInputIteratorT to dereference an array of cub::KeyValuePair and converting them to just return the Key to be used in reduce_rows_by_key prims + kmeans::detail::KeyValueIndexOp conversion_op; + cub::TransformInputIterator, + cub::KeyValuePair*> itr(minClusterAndDistance.data(), + conversion_op); + + + // count # of samples in each cluster + kmeans::detail::countLabels(handle, + itr, + sampleCountInCluster.data(), + n_samples, n_clusters, + workspace, + stream); + +} + +template +void +kmeansPlusPlus(const cumlHandle_impl& handle, + int verbose, + int n_clusters, + int seed, + Tensor &C, + Tensor &weights, + MLCommon::Distance::DistanceType metric, + MLCommon::device_buffer &workspace, + MLCommon::device_buffer ¢roidsRawData, + cudaStream_t stream){ + auto n_pot_centroids = C.getSize(0); // # of potential centroids + auto n_features = C.getSize(1); + + // temporary buffer for probabilities + Tensor prob({n_pot_centroids}, + handle.getDeviceAllocator(), + stream); + + ML::thrustAllocatorAdapter alloc( handle.getDeviceAllocator(), stream ); + auto execution_policy = thrust::cuda::par(alloc).on(stream); + thrust::transform(execution_policy, + weights.begin(), + weights.end(), + prob.begin(), + [] __device__(int weight){ + return static_cast(weight); + }); + + host_buffer h_prob(handle.getHostAllocator(), stream); + h_prob.resize(n_pot_centroids, stream); + + std::mt19937 gen(seed); + + // reset buffer to store the chosen centroid + centroidsRawData.resize(n_clusters * n_features, stream); + + Tensor minClusterDistance({n_pot_centroids}, + handle.getDeviceAllocator(), + stream); + + + int dataBatchSize = kmeans::detail::getDataBatchSize(n_pot_centroids, n_features); + + device_buffer pairwiseDistanceRaw(handle.getDeviceAllocator(), stream); + pairwiseDistanceRaw.resize(dataBatchSize * n_clusters, + stream); + + int n_pts_sampled = 0; + for (int iter = 0; iter < n_clusters; iter++) { + LOG(verbose, + "KMeans++ - Iteraton %d/%d\n", iter, n_clusters); + + MLCommon::copy(h_prob.data(), + prob.data(), + prob.numElements(), + stream); + CUDA_CHECK(cudaStreamSynchronize(stream) ); + + std::discrete_distribution<> d(h_prob.begin(), h_prob.end()); + // d(gen) returns random # between [0...n_pot_centroids], mod is unncessary but just placing it to avoid untested behaviors + int cIdx = d(gen) % n_pot_centroids; + + LOG(verbose, + "Chosing centroid-%d randomly from %d potential centroids\n", + cIdx, n_pot_centroids); + + auto curCentroid = C.template view<2>({1, n_features}, + {cIdx, 0}); + + MLCommon::copy(centroidsRawData.data() + n_pts_sampled * n_features, + curCentroid.data(), + curCentroid.numElements(), + stream); + n_pts_sampled++; + + auto centroids = std::move(Tensor(centroidsRawData.data(), {n_pts_sampled, n_features})); + + Tensor pairwiseDistance((DataT *)pairwiseDistanceRaw.data(), + {dataBatchSize, centroids.getSize(0)}); + + kmeans::detail::minClusterDistance(handle, + C, + centroids, + pairwiseDistance, + minClusterDistance, + workspace, + metric, + stream); + + + DataT clusteringCost + = kmeans::detail::computeClusterCost(handle, + minClusterDistance, + workspace, + [] __device__(const DataT &a, + const DataT &b){ + return a + b; + }, + stream); + + + cub::ArgIndexInputIterator itr_w(weights.data()); + + ML::thrustAllocatorAdapter alloc( handle.getDeviceAllocator(), stream ); + auto execution_policy = thrust::cuda::par(alloc).on(stream); + thrust::transform(execution_policy, + minClusterDistance.begin(), + minClusterDistance.end(), + itr_w, + prob.begin(), + [=] __device__ (const DataT &minDist, + const cub::KeyValuePair &weight) { + if(weight.key == cIdx){ + // sample was chosen in the previous iteration, so reset the weights to avoid future selection... + return static_cast(0); + }else{ + return weight.value * minDist/clusteringCost; + } + }); + } +} + +} // end namespace detail +} // end namespace kmeans + +template +KMeans::~KMeans(){ + cudaStream_t stream = _handle.getStream(); + _workspace.release(stream); + _centroidsRawData.release(stream); + _labelsRawData.release(stream); +} + +// constructor +template +KMeans::KMeans(const ML::cumlHandle_impl &handle, + int k, + MLCommon::Distance::DistanceType metric, + kmeans::InitMethod init, + int max_iterations, + double tolerance, + int seed, + int verbose) + : _handle(handle), + n_clusters(k), + tol(tolerance), + max_iter(max_iterations), + _init (init), + _metric (metric), + _verbose(verbose), + _workspace(handle.getDeviceAllocator(), handle.getStream()), + _centroidsRawData(handle.getDeviceAllocator(), handle.getStream()), + _labelsRawData(handle.getDeviceAllocator(), handle.getStream()){ + options.oversampling_factor = 2.0 * n_clusters; + if (seed < 0) { + std::random_device rd; + options.seed = rd(); + }else{ + options.seed = seed; + } + + inertia = std::numeric_limits::infinity();; + n_iter = 0; + _n_features = 0; +} + +template +void +KMeans::setCentroids(const DataT *X, + int n_samples, + int n_features){ + cudaStream_t stream = _handle.getStream(); + + n_clusters = n_samples; + options.oversampling_factor = 2.0 * n_clusters; + _n_features = n_features; + + _centroidsRawData.resize(n_clusters * n_features, stream); + + MLCommon::copy(_centroidsRawData.begin(), X, + n_samples * n_features, stream); + + auto centroids = std::move(Tensor(_centroidsRawData.data(), {n_clusters, n_features})); +} + + +template +DataT* +KMeans::centroids(){ + return _centroidsRawData.data(); +} + + +// Selects 'n_clusters' samples randomly from X +template +void +KMeans::initRandom(Tensor &X){ + cudaStream_t stream = _handle.getStream(); + auto n_features = X.getSize(1); + + // allocate centroids buffer + _centroidsRawData.resize(n_clusters * n_features, stream); + auto centroids = std::move(Tensor(_centroidsRawData.data(), {n_clusters, n_features})); + + kmeans::detail::shuffleAndGather(_handle, + X, + centroids, + n_clusters, + options.seed, + stream); +} + + +/* + * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm + * Scalable kmeans++ pseudocode + * 1: C = sample a point uniformly at random from X + * 2: psi = phi_X (C) + * 3: for O( log(psi) ) times do + * 4: C' = sample each point x in X independently with probability p_x = l * ( d^2(x, C) / phi_X (C) ) + * 5: C = C U C' + * 6: end for + * 7: For x in C, set w_x to be the number of points in X closer to x than any other point in C + * 8: Recluster the weighted points in C into k clusters + */ + +template +void +KMeans::initKMeansPlusPlus(Tensor &X){ + cudaStream_t stream = _handle.getStream(); + auto n_samples = X.getSize(0); + auto n_features = X.getSize(1); + + Random::Rng rng(options.seed, + options.gtype); + + // <<<< Step-1 >>> : C <- sample a point uniformly at random from X + std::mt19937 gen(options.seed); + std::uniform_int_distribution<> dis(0, + n_samples - 1); + + int cIdx = dis(gen); + auto initialCentroid = X.template view<2>({1, n_features}, + {cIdx, 0}); + + // flag the sample that is chosen as initial centroid + host_buffer h_isSampleCentroid(_handle.getHostAllocator(), stream, n_samples); + std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); + h_isSampleCentroid[cIdx] = 1; + + // flag the sample that is chosen as initial centroid + Tensor isSampleCentroid({n_samples}, + _handle.getDeviceAllocator(), + stream); + + MLCommon::copy(isSampleCentroid.data(), h_isSampleCentroid.data(), + isSampleCentroid.numElements(), stream); + + device_buffer centroidsRawData(_handle.getDeviceAllocator(), stream); + + // reset buffer to store the chosen centroid + centroidsRawData.reserve(n_clusters * n_features, stream); + centroidsRawData.resize(initialCentroid.numElements(), stream); + MLCommon::copy(centroidsRawData.begin(), initialCentroid.data(), + initialCentroid.numElements(), stream); + + auto potentialCentroids = std::move(Tensor(centroidsRawData.data(), {initialCentroid.getSize(0), initialCentroid.getSize(1)})); + // <<< End of Step-1 >>> + + + int dataBatchSize = kmeans::detail::getDataBatchSize(n_samples, n_features); + + device_buffer pairwiseDistanceRaw(_handle.getDeviceAllocator(), stream); + pairwiseDistanceRaw.resize(dataBatchSize * n_clusters, + stream); + + + Tensor minClusterDistance({n_samples}, + _handle.getDeviceAllocator(), stream); + Tensor uniformRands({n_samples}, + _handle.getDeviceAllocator(), stream); + Tensor clusterCostD({1}, + _handle.getDeviceAllocator(), stream); + + // <<< Step-2 >>>: psi <- phi_X (C) + Tensor pairwiseDistance((DataT *)pairwiseDistanceRaw.data(), + {dataBatchSize, potentialCentroids.getSize(0)}); + + kmeans::detail::minClusterDistance(_handle, + X, + potentialCentroids, + pairwiseDistance, + minClusterDistance, + _workspace, + _metric, + stream); + + + DataT phi = kmeans::detail::computeClusterCost(_handle, + minClusterDistance, + _workspace, + [] __device__(const DataT &a, const DataT &b){ + return a+b; + }, + stream); + + + // Scalable kmeans++ paper claims 8 rounds is sufficient + int niter = std::min(8, (int)ceil(log(phi))); + + // <<< End of Step-2 >>> + + // <<<< Step-3 >>> : for O( log(psi) ) times do + for(int iter = 0; + iter < niter; + ++iter){ + LOG(_verbose, + "KMeans|| - Iteration %d: # potential centroids sampled - %d\n", iter, potentialCentroids.getSize(0)); + + pairwiseDistanceRaw.resize(dataBatchSize * potentialCentroids.getSize(0), + stream); + Tensor pairwiseDistance((DataT *)pairwiseDistanceRaw.data(), + {dataBatchSize, potentialCentroids.getSize(0)}); + + kmeans::detail::minClusterDistance(_handle, + X, + potentialCentroids, + pairwiseDistance, + minClusterDistance, + _workspace, + _metric, + stream); + + + DataT clusterCost = kmeans::detail::computeClusterCost(_handle, + minClusterDistance, + _workspace, + [] __device__(const DataT &a, const DataT &b){ + return a+b; + }, + stream); + + // <<<< Step-4 >>> : Sample each point x in X independently and identify new potentialCentroids + rng.uniform(uniformRands.data(), uniformRands.getSize(0), (DataT)0, (DataT)1, stream); + kmeans::detail::SamplingOp select_op(clusterCost, + options.oversampling_factor, + uniformRands.data(), + isSampleCentroid.data()); + + auto Cp = kmeans::detail::sampleCentroids(_handle, + X, + minClusterDistance, + isSampleCentroid, + select_op, + _workspace, + stream); + /// <<<< End of Step-4 >>>> + + + /// <<<< Step-5 >>> : C = C U C' + // append the data in Cp to the buffer holding the potentialCentroids + centroidsRawData.resize(centroidsRawData.size() + Cp.numElements(), stream); + MLCommon::copy(centroidsRawData.end() - Cp.numElements(), + Cp.data(), + Cp.numElements(), + stream); + + int tot_centroids = potentialCentroids.getSize(0) + Cp.getSize(0); + potentialCentroids = std::move(Tensor(centroidsRawData.data(), {tot_centroids, n_features})); + /// <<<< End of Step-5 >>> + } /// <<<< Step-6 >>> + + if(potentialCentroids.getSize(0) > n_clusters){ + // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X + // temporary buffer to store the sample count per cluster, destructor releases the resource + Tensor weights({potentialCentroids.getSize(0)}, + _handle.getDeviceAllocator(), + stream); + + + kmeans::detail::countSamplesInCluster(_handle, X, potentialCentroids, _workspace, _metric, weights, stream); + + // <<< end of Step-7 >>> + + //Step-8: Recluster the weighted points in C into k clusters + _centroidsRawData.resize(n_clusters * n_features, stream); + kmeans::detail::kmeansPlusPlus(_handle, _verbose, n_clusters, options.seed, potentialCentroids, weights, _metric, _workspace, _centroidsRawData, stream); + }else{ + ASSERT(potentialCentroids.getSize(0) < n_clusters, + "failed to converge, restart with different seed"); + + _centroidsRawData.resize(n_clusters * n_features, stream); + MLCommon::copy(_centroidsRawData.data(), + potentialCentroids.data(), + potentialCentroids.numElements(), + stream); + } +} + + +template +__host__ +void +KMeans::fit(Tensor& X){ + cudaStream_t stream = _handle.getStream(); + auto n_samples = X.getSize(0); + auto n_features = X.getSize(1); + + auto dataBatchSize = kmeans::detail::getDataBatchSize(n_samples, n_features); + + // stores (key, value) pair corresponding to each sample where + // - key is the index of nearest cluster + // - value is the distance to the nearest cluster + Tensor, 1, IndexT> minClusterAndDistance({n_samples}, + _handle.getDeviceAllocator(), + stream); + + // temporary buffer to store distance matrix, destructor releases the resource + Tensor pairwiseDistance({dataBatchSize, n_clusters}, + _handle.getDeviceAllocator(), + stream); + + // temporary buffer to store intermediate centroids, destructor releases the resource + Tensor newCentroids({n_clusters, n_features}, + _handle.getDeviceAllocator(), + stream); + + // temporary buffer to store the sample count per cluster, destructor releases the resource + Tensor sampleCountInCluster({n_clusters}, + _handle.getDeviceAllocator(), + stream); + + + DataT priorClusteringCost = 0; + for(n_iter = 0; n_iter < max_iter; ++n_iter){ + + auto centroids = std::move(Tensor(_centroidsRawData.data(), {n_clusters, n_features})); + + // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] is a pair where + // 'key' is index to an sample in 'centroids' (index of the nearest centroid) and + // 'value' is the distance between the sample 'X[i]' and the 'centroid[key]' + kmeans::detail::minClusterAndDistance(_handle, + X, + centroids, + pairwiseDistance, + minClusterAndDistance, + _workspace, + _metric, + stream); + + + // Using TransformInputIteratorT to dereference an array of cub::KeyValuePair and converting them to just return the Key to be used in reduce_rows_by_key prims + kmeans::detail::KeyValueIndexOp conversion_op; + cub::TransformInputIterator, + cub::KeyValuePair*> itr(minClusterAndDistance.data(), + conversion_op); + + _workspace.resize(n_samples, stream); + + // Calculates sum of all the samples assigned to cluster-i and store the result in newCentroids[i] + LinAlg::reduce_rows_by_key(X.data(), + X.getSize(1), + itr, + _workspace.data(), + X.getSize(0), + X.getSize(1), + n_clusters, + newCentroids.data(), + stream); + + + // count # of samples in each cluster + kmeans::detail::countLabels(_handle, + itr, + sampleCountInCluster.data(), + n_samples, n_clusters, + _workspace, + stream); + + + + // Computes newCentroids[i] = newCentroids[i]/sampleCountInCluster[i] where + // newCentroids[n_samples x n_features] - 2D array, newCentroids[i] has sum of all the samples assigned to cluster-i + // sampleCountInCluster[n_clusters] - 1D array, sampleCountInCluster[i] contains # of samples in cluster-i. + // Note - when sampleCountInCluster[i] is 0, newCentroid[i] is reset to 0 + + // transforms int values in sampleCountInCluster to its inverse and more importantly to DataT because matrixVectorOp supports only when matrix and vector are of same type + _workspace.resize(sampleCountInCluster.numElements() * sizeof(DataT), stream); + auto sampleCountInClusterInverse = std::move(Tensor((DataT *)_workspace.data(), {n_clusters})); + + ML::thrustAllocatorAdapter alloc( _handle.getDeviceAllocator(), stream ); + auto execution_policy = thrust::cuda::par(alloc).on(stream); + thrust::transform(execution_policy, + sampleCountInCluster.begin(), + sampleCountInCluster.end(), + sampleCountInClusterInverse.begin(), + [=] __device__ (int count){ + if(count == 0) + return static_cast(0); + else + return static_cast(1.0)/static_cast(count); + }); + + LinAlg::matrixVectorOp(newCentroids.data(), + newCentroids.data(), + sampleCountInClusterInverse.data(), + newCentroids.getSize(1), + newCentroids.getSize(0), + true, + false, + [=] __device__ (DataT mat, DataT vec){ + return mat * vec; + }, + stream); + + // copy the centroids[i] to newCentroids[i] when sampleCountInCluster[i] is 0 + cub::ArgIndexInputIterator itr_sc(sampleCountInCluster.data()); + Matrix::gather_if(centroids.data(), + centroids.getSize(1), + centroids.getSize(0), + itr_sc, + itr_sc, + sampleCountInCluster.numElements(), + newCentroids.data(), + [=] __device__ (cub::KeyValuePair map){ // predicate + // copy when the # of samples in the cluster is 0 + if(map.value == 0) + return true; + else + return false; + }, + [=] __device__ (cub::KeyValuePair map) { // map + return map.key; + }, + stream); + + + // compute the squared norm between the newCentroids and the original centroids, destructor releases the resource + Tensor sqrdNorm({1}, _handle.getDeviceAllocator(), stream); + LinAlg::mapThenSumReduce(sqrdNorm.data(), + newCentroids.numElements(), + [=] __device__(const DataT a, const DataT b){ + DataT diff = a - b; + return diff * diff; + }, + stream, + centroids.data(), + newCentroids.data()); + + + DataT sqrdNormError = 0; + MLCommon::copy(&sqrdNormError, + sqrdNorm.data(), + sqrdNorm.numElements(), + stream); + + MLCommon::copy(_centroidsRawData.data(), + newCentroids.data(), + newCentroids.numElements(), + stream); + + bool done = false; + if(options.inertia_check){ + // calculate cluster cost phi_x(C) + const cub::KeyValuePair clusteringCost + = kmeans::detail::computeClusterCost(_handle, + minClusterAndDistance, + _workspace, + [] __device__(const cub::KeyValuePair &a, + const cub::KeyValuePair &b){ + cub::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + }, + stream); + + DataT curClusteringCost = clusteringCost.value; + + ASSERT(curClusteringCost != (DataT)0.0, + "Too few points and centriods being found is getting 0 cost from centers\n"); + + + if(n_iter > 0){ + DataT delta = curClusteringCost/priorClusteringCost; + if(delta > 1 - tol) done = true; + } + priorClusteringCost = curClusteringCost; + } + + CUDA_CHECK( cudaStreamSynchronize(stream) ); + if (sqrdNormError < tol) done = true; + + if(done){ + LOG(_verbose, "Threshold triggered after %d iterations. Terminating early.\n", n_iter); + break; + } + } +} + + + +template +__host__ +void +KMeans::fit(const DataT *X, + int n_samples, + int n_features){ + if(n_clusters >= n_samples){ + n_clusters = n_samples; + options.oversampling_factor = 2.0 * n_clusters; + } + _n_features = n_features; + + ASSERT(options.oversampling_factor > 0, + "oversampling factor must be > 0 (requested %d)", + (int) options.oversampling_factor); + + + ASSERT(memory_type(X) == cudaMemoryTypeDevice, + "input data must be device accessible"); + + Tensor data((DataT *)X, + {n_samples, n_features}); + + + if(_init == kmeans::InitMethod::Random){ + //initializing with random samples from input dataset + LOG(_verbose, "KMeans.fit: initialize cluster centers by randomly choosing from the input data.\n"); + initRandom(data); + }else if(_init == kmeans::InitMethod::KMeansPlusPlus){ + // default method to initialize is kmeans++ + LOG(_verbose, "KMeans.fit: initialize cluster centers using k-means++ algorithm.\n"); + initKMeansPlusPlus(data); + }else if(_init == kmeans::InitMethod::Array){ + LOG(_verbose, "KMeans.fit: initialize cluster centers from the ndarray array input passed to init arguement.\n"); + ASSERT(_centroidsRawData.data(), + "call setCentroids() method to set the initial centroids or use the other initialized methods"); + }else{ + THROW("unknown initialization method to select initial centers"); + } + + + fit(data); +} + + +template +__host__ +void +KMeans::predict(Tensor &X){ + cudaStream_t stream = _handle.getStream(); + auto n_samples = X.getSize(0); + auto n_features = X.getSize(1); + auto dataBatchSize = kmeans::detail::getDataBatchSize(n_samples, n_features); + + // reset inertia + inertia = std::numeric_limits::infinity(); + + auto centroids = std::move(Tensor(_centroidsRawData.data(), {n_clusters, n_features})); + + Tensor, 1> minClusterAndDistance({n_samples}, + _handle.getDeviceAllocator(), + stream); + Tensor pairwiseDistance({dataBatchSize, n_clusters}, + _handle.getDeviceAllocator(), + stream); + + // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] is a pair where + // 'key' is index to an sample in 'centroids' (index of the nearest centroid) and + // 'value' is the distance between the sample 'X[i]' and the 'centroid[key]' + kmeans::detail::minClusterAndDistance(_handle, + X, + centroids, + pairwiseDistance, + minClusterAndDistance, + _workspace, + _metric, + stream); + + + // calculate cluster cost phi_x(C) + const cub::KeyValuePair clusteringCost + = kmeans::detail::computeClusterCost(_handle, + minClusterAndDistance, + _workspace, + [] __device__(const cub::KeyValuePair &a, + const cub::KeyValuePair &b){ + cub::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + }, + stream); + + // Cluster cost phi_x(C) from all ranks + inertia = clusteringCost.value; + + _labelsRawData.resize(n_samples, stream); + + auto labels = std::move(Tensor(_labelsRawData.data(), {n_samples})); + ML::thrustAllocatorAdapter alloc( _handle.getDeviceAllocator(), stream ); + auto execution_policy = thrust::cuda::par(alloc).on(stream); + thrust::transform(execution_policy, + minClusterAndDistance.begin(), + minClusterAndDistance.end(), + labels.begin(), + [=] __device__ (cub::KeyValuePair pair){ + return pair.key; + }); +} + + +template +__host__ +void +KMeans::predict(const DataT *X, + int n_samples, + int n_features, + IndexT *labelsRawPtr){ + cudaStream_t stream = _handle.getStream(); + + ASSERT(_n_features == n_features, + "model is trained for %d-dimensional data (provided data is %d-dimensional)", _n_features, n_features); + + ASSERT(n_clusters > 0, + "no clusters exist"); + + Tensor data((DataT *)X, {n_samples, n_features}); + + predict(data); + + auto labels = std::move(Tensor(_labelsRawData.data(), {n_samples})); + + MLCommon::copy(labelsRawPtr, + labels.data(), + labels.numElements(), + stream); +} + +template +__host__ +void +KMeans::transform(const DataT *X, + int n_samples, + int n_features, + DataT *X_new){ + cudaStream_t stream = _handle.getStream(); + + ASSERT(_n_features == n_features, + "model is trained for %d-dimensional data (provided data is %d-dimensional)", _n_features, n_features); + + ASSERT(n_clusters > 0, + "no clusters exist"); + + ASSERT(memory_type(X) == cudaMemoryTypeDevice, + "input data must be device accessible"); + + ASSERT(memory_type(X) == cudaMemoryTypeDevice, + "output data storage must be device accessible"); + + + auto centroids = std::move(Tensor(_centroidsRawData.data(), {n_clusters, n_features})); + + Tensor dataset((DataT *)X, {n_samples, n_features}); + Tensor pairwiseDistance((DataT *)X_new, {n_samples, n_clusters}); + + auto dataBatchSize = kmeans::detail::getDataBatchSize(n_samples, n_features); + + + // tile over the input data and calculate distance matrix [n_samples x n_clusters] + for (int dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + // # of samples for the current batch + int ns = std::min(dataBatchSize, n_samples - dIdx); + + // datasetView [ns x n_features] - view representing the current batch of input dataset + auto datasetView = dataset.template view<2>({ns, n_features}, + {dIdx, 0}); + + // pairwiseDistanceView [ns x n_clusters] + auto pairwiseDistanceView = pairwiseDistance.template view<2>({ns, n_clusters}, + {dIdx, 0}); + + + // calculate pairwise distance between cluster centroids and current batch of input dataset + kmeans::detail::pairwiseDistance(_handle, + datasetView, + centroids, + pairwiseDistanceView, + _workspace, + _metric, + stream); + } +} +}; // end namespace ML diff --git a/cuML/src/kmeans/kmeans.cu b/cuML/src/kmeans/kmeans.cu index dd05259272..6e6963ae3b 100644 --- a/cuML/src/kmeans/kmeans.cu +++ b/cuML/src/kmeans/kmeans.cu @@ -1,1437 +1,201 @@ -/*! - * Copyright 2017-2018 H2O.ai, Inc. - * License Apache License Version 2.0 (see LICENSE for details) - */ -#include -#include -#include -#include -#include -#include -#include "cuda.h" -#include -#include -#include "kmeans_c.h" -#include "kmeans_impl.h" -#include "kmeans_general.h" -#include "kmeans.h" -#include -#include -#include -#include -#include -#include "utils.h" -#include - -cudaStream_t cuda_stream[MAX_NGPUS]; - -/** - * METHODS FOR DATA COPYING AND GENERATION - */ - -template -void random_data(int verbose, thrust::device_vector &array, int m, int n) { - thrust::host_vector host_array(m * n); - for (int i = 0; i < m * n; i++) { - host_array[i] = (T) rand() / (T) RAND_MAX; - } - array = host_array; -} - -/** - * Copies data from srcdata to array - * @tparam T - * @param verbose Logging level - * @param ord Column on row order of data - * @param array Destination array - * @param srcdata Source data - * @param q Shard number (from 0 to n_gpu) - * @param n - * @param npergpu - * @param d - */ -template -void copy_data(int verbose, const char ord, thrust::device_vector &array, - const T *srcdata, int q, int n, size_t npergpu, int d) { - if (ord == 'c') { - thrust::host_vector host_array(npergpu * d); - log_debug(verbose, "Copy data COL ORDER -> ROW ORDER"); - - for (size_t i = 0; i < npergpu * d; i++) { - size_t indexi = i % d; // col - size_t indexj = i / d + q * npergpu; // row (shifted by which gpu) - host_array[i] = srcdata[indexi * n + indexj]; - } - array = host_array; - } else { - log_debug(verbose, "Copy data ROW ORDER not changed"); - thrust::host_vector host_array(srcdata + q * npergpu * d, - srcdata + q * npergpu * d + npergpu * d); - array = host_array; - } -} - -/** - * Like copy_data but shuffles the data according to mapping from v - * @tparam T - * @param verbose - * @param v - * @param ord - * @param array - * @param srcdata - * @param q - * @param n - * @param npergpu - * @param d - */ -template -void copy_data_shuffled(int verbose, std::vector v, const char ord, - thrust::device_vector &array, const T *srcdata, int q, int n, - int npergpu, int d) { - thrust::host_vector host_array(npergpu * d); - if (ord == 'c') { - log_debug(verbose, "Copy data shuffle COL ORDER -> ROW ORDER"); - - for (int i = 0; i < npergpu; i++) { - for (size_t j = 0; j < d; j++) { - host_array[i * d + j] = srcdata[v[q * npergpu + i] + j * n]; // shift by which gpu - } - } - } else { - log_debug(verbose, "Copy data shuffle ROW ORDER not changed"); - - for (int i = 0; i < npergpu; i++) { - for (size_t j = 0; j < d; j++) { - host_array[i * d + j] = srcdata[v[q * npergpu + i] * d + j]; // shift by which gpu - } - } - } - array = host_array; -} - -template -void copy_centroids_shuffled(int verbose, std::vector v, const char ord, - thrust::device_vector &array, const T *srcdata, int n, int k, - int d) { - copy_data_shuffled(verbose, v, ord, array, srcdata, 0, n, k, d); -} - -/** - * Copies centroids from initial training set randomly. - * @tparam T - * @param verbose - * @param seed - * @param ord - * @param array - * @param srcdata - * @param q - * @param n - * @param npergpu - * @param d - * @param k - */ -template -void random_centroids(int verbose, int seed, const char ord, - thrust::device_vector &array, const T *srcdata, int q, int n, - int npergpu, int d, int k) { - thrust::host_vector host_array(k * d); - if (seed < 0) { - std::random_device rd; //Will be used to obtain a seed for the random number engine - seed = rd(); - } - std::mt19937 gen(seed); - std::uniform_int_distribution<> dis(0, n - 1); // random i in range from 0..n-1 (i.e. only 1 gpu gets centroids) - - if (ord == 'c') { - log_debug(verbose, "Random centroids COL ORDER -> ROW ORDER"); - for (int i = 0; i < k; i++) { // clusters - size_t reali = dis(gen); // + q*npergpu; // row sampled (called indexj above) - for (size_t j = 0; j < d; j++) { // cols - host_array[i * d + j] = srcdata[reali + j * n]; - } - } - } else { - log_debug(verbose, "Random centroids ROW ORDER not changed"); - for (int i = 0; i < k; i++) { // rows - size_t reali = dis(gen); // + q*npergpu ; // row sampled - for (size_t j = 0; j < d; j++) { // cols - host_array[i * d + j] = srcdata[reali * d + j]; - } - } - } - array = host_array; -} - -/** - * KMEANS METHODS FIT, PREDICT, TRANSFORM - */ - -#define __HBAR__ \ - "----------------------------------------------------------------------------\n" - -namespace h2o4gpukmeans { - -template -int kmeans_find_clusters(int verbose, const char ord, int seed, - thrust::device_vector **data, thrust::device_vector **labels, - thrust::device_vector **d_centroids, - thrust::device_vector **data_dots, size_t rows, size_t cols, - int init_from_data, int k, int k_max, T threshold, const T *srcdata, - int n_gpu, std::vector dList, T &residual, std::vector v, - int max_iterations); - -template -int kmeans_fit(int verbose, int seed, int gpu_idtry, int n_gputry, size_t rows, - size_t cols, const char ord, int k, int k_max, int max_iterations, - int init_from_data, T threshold, const T *srcdata, T **pred_centroids, - int **pred_labels); - -template -int pick_point_idx_weighted(int seed, std::vector *data, - thrust::host_vector weights) { - T weighted_sum = 0; - - for (int i = 0; i < weights.size(); i++) { - if (data) { - weighted_sum += (data->data()[i] * weights.data()[i]); - } else { - weighted_sum += weights.data()[i]; - } - } - - T best_prob = 0.0; - int best_prob_idx = 0; - - std::mt19937 mt(seed); - std::uniform_real_distribution<> dist(0.0, 1.0); - - int i = 0; - for (i = 0; i <= weights.size(); i++) { - if (weights.size() == i) { - break; - } - - T prob_threshold = (T) dist(mt); - - T data_val = weights.data()[i]; - if (data) { - data_val *= data->data()[i]; - } - - T prob_x = (data_val / weighted_sum); - - if (prob_x > prob_threshold) { - break; - } - - if (prob_x >= best_prob) { - best_prob = prob_x; - best_prob_idx = i; - } - } - - return weights.size() == i ? best_prob_idx : i; -} - -/** - * Copies cols records, starting at position idx*cols from data to centroids. Removes them afterwards from data. - * Removes record from weights at position idx. - * @tparam T - * @param idx - * @param cols - * @param data - * @param weights - * @param centroids - */ -template -void add_centroid(int idx, int cols, thrust::host_vector &data, - thrust::host_vector &weights, std::vector ¢roids) { - for (int i = 0; i < cols; i++) { - centroids.push_back(data[idx * cols + i]); - } - weights[idx] = 0; -} - -struct square_root: public thrust::unary_function { - __host__ __device__ - float operator()(float x) const { - return sqrtf(x); - } -}; - -template -void filterByDot(int d, int k, int *numChosen, thrust::device_vector &dists, - thrust::device_vector ¢roids, - thrust::device_vector ¢roid_dots) { - - float alpha = 1.0f; - float beta = 0.0f; - - CUDACHECK(cudaSetDevice(0)); - kmeans::detail::make_self_dots(k, d, centroids, centroid_dots); - - thrust::transform(centroid_dots.begin(), centroid_dots.begin() + k, - centroid_dots.begin(), square_root()); - - cublasStatus_t stat = - safe_cublas( - cublasSgemm(kmeans::detail::cublas_handle[0], CUBLAS_OP_T, CUBLAS_OP_N, k, k, d, &alpha, thrust::raw_pointer_cast(centroids.data()), d, thrust::raw_pointer_cast(centroids.data()), d, &beta, thrust::raw_pointer_cast(dists.data()), k)) - ; //Has to be k or k - - //Check cosine angle between two vectors, must be < .9 - kmeans::detail::checkCosine(d, k, numChosen, dists, centroids, - centroid_dots); -} - -template -struct min_calc_functor { - T* all_costs_ptr; - T* min_costs_ptr; - T max = std::numeric_limits::max(); - int potential_k_rows; - int rows_per_run; - - min_calc_functor(T* _all_costs_ptr, T* _min_costs_ptr, - int _potential_k_rows, int _rows_per_run) { - all_costs_ptr = _all_costs_ptr; - min_costs_ptr = _min_costs_ptr; - potential_k_rows = _potential_k_rows; - rows_per_run = _rows_per_run; - } - - __host__ __device__ - void operator()(int idx) const { - T best = max; - for (int j = 0; j < potential_k_rows; j++) { - best = min(best, std::abs(all_costs_ptr[j * rows_per_run + idx])); - } - min_costs_ptr[idx] = min(min_costs_ptr[idx], best); - } -}; - -/** - * K-Means|| initialization method implementation as described in "Scalable K-Means++". - * - * This is a probabilistic method, which tries to choose points as much spread out as possible as centroids. +/* + * Copyright (c) 2019, NVIDIA CORPORATION. * - * In case it finds more than k centroids a K-Means++ algorithm is ran on potential centroids to pick k best suited ones. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf + * http://www.apache.org/licenses/LICENSE-2.0 * - * @tparam T - * @param verbose - * @param seed - * @param ord - * @param data - * @param data_dots - * @param centroids - * @param rows - * @param cols - * @param k - * @param num_gpu - * @param threshold - */ -template -thrust::host_vector kmeans_parallel(int verbose, int seed, const char ord, - thrust::device_vector **data, thrust::device_vector **data_dots, - size_t rows, int cols, int k, int num_gpu, T threshold) { - if (seed < 0) { - std::random_device rd; - int seed = rd(); - } - - size_t rows_per_gpu = rows / num_gpu; - - std::mt19937 gen(seed); - std::uniform_int_distribution<> dis(0, rows - 1); - - // Find the position (GPU idx and idx on that GPU) of the initial centroid - int first_center = dis(gen); - int first_center_idx = first_center % rows_per_gpu; - int first_center_gpu = first_center / rows_per_gpu; - - log_verbose(verbose, "KMeans|| - Initial centroid %d on GPU %d.", - first_center_idx, first_center_gpu); - - // Copies the initial centroid to potential centroids vector. That vector will store all potential centroids found - // in the previous iteration. - thrust::host_vector h_potential_centroids(cols); - std::vector> h_potential_centroids_per_gpu(num_gpu); - - CUDACHECK(cudaSetDevice(first_center_gpu)); - - thrust::copy((*data[first_center_gpu]).begin() + first_center_idx * cols, - (*data[first_center_gpu]).begin() + (first_center_idx + 1) * cols, - h_potential_centroids.begin()); - - thrust::host_vector h_all_potential_centroids = h_potential_centroids; - - // Initial the cost-to-potential-centroids and cost-to-closest-potential-centroid matrices. Initial cost is +infinity - std::vector> d_min_costs(num_gpu); - for (int q = 0; q < num_gpu; q++) { - CUDACHECK(cudaSetDevice(q)); - d_min_costs[q].resize(rows_per_gpu); - thrust::fill(d_min_costs[q].begin(), d_min_costs[q].end(), - std::numeric_limits::max()); - } - - double t0 = timer(); - - int curr_k = h_potential_centroids.size() / cols; - int max_k = k; - - while (curr_k < max_k) { - T total_min_cost = 0.0; - - int new_potential_centroids = 0; -#pragma omp parallel for - for (int i = 0; i < num_gpu; i++) { - CUDACHECK(cudaSetDevice(i)); - - thrust::device_vector d_potential_centroids = - h_potential_centroids; - - int potential_k_rows = d_potential_centroids.size() / cols; - - // Compute all the costs to each potential centroid from previous iteration - thrust::device_vector centroid_dots(potential_k_rows); - - kmeans::detail::batch_calculate_distances(verbose, 0, rows_per_gpu, - cols, potential_k_rows, *data[i], d_potential_centroids, - *data_dots[i], centroid_dots, - [&](int rows_per_run, size_t offset, thrust::device_vector &pairwise_distances) { - // Find the closest potential center cost for each row - auto min_cost_counter = thrust::make_counting_iterator(0); - auto all_costs_ptr = thrust::raw_pointer_cast(pairwise_distances.data()); - auto min_costs_ptr = thrust::raw_pointer_cast(d_min_costs[i].data() + offset); - thrust::for_each(min_cost_counter, - min_cost_counter + rows_per_run, - // Functor instead of a lambda b/c nvcc is complaining about - // nesting a __device__ lambda inside a regular lambda - min_calc_functor(all_costs_ptr, min_costs_ptr, potential_k_rows, rows_per_run)); - }); - } - - for (int i = 0; i < num_gpu; i++) { - CUDACHECK(cudaSetDevice(i)); - total_min_cost += thrust::reduce(d_min_costs[i].begin(), - d_min_costs[i].end()); - } - - log_verbose(verbose, "KMeans|| - Total min cost from centers %g.", - total_min_cost); - - if (total_min_cost == (T) 0.0) { - thrust::host_vector final_centroids(0); - if (verbose) { - fprintf(stderr, - "Too few points and centriods being found is getting 0 cost from centers\n"); - fflush(stderr); - } - - return final_centroids; - } - - std::set copy_from_gpus; -#pragma omp parallel for - for (int i = 0; i < num_gpu; i++) { - CUDACHECK(cudaSetDevice(i)); - - // Count how many potential centroids there are using probabilities - // The further the row is from the closest cluster center the higher the probability - auto pot_cent_filter_counter = thrust::make_counting_iterator(0); - auto min_costs_ptr = thrust::raw_pointer_cast( - d_min_costs[i].data()); -int pot_cent_num = thrust::count_if( - pot_cent_filter_counter, - pot_cent_filter_counter + rows_per_gpu, [=]__device__(int idx){ - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution<> dist(0.0, 1.0); - int device; - cudaGetDevice(&device); - rng.discard(idx + device * rows_per_gpu); - T prob_threshold = (T) dist(rng); - - T prob_x = (( 2.0 * k * min_costs_ptr[idx]) / total_min_cost); - - return prob_x > prob_threshold; - } - ); - - log_debug(verbose, "KMeans|| - Potential centroids on GPU %d = %d.", - i, pot_cent_num); - - if (pot_cent_num > 0) { - copy_from_gpus.insert(i); - - // Copy all potential cluster centers - thrust::device_vector d_new_potential_centroids( - pot_cent_num * cols); - - auto range = thrust::make_counting_iterator(0); - thrust::copy_if( - (*data[i]).begin(), (*data[i]).end(), range, - d_new_potential_centroids.begin(), [=] __device__(int idx) { - int row = idx / cols; - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution<> dist(0.0, 1.0); - int device; - cudaGetDevice(&device); - rng.discard(row + device * rows_per_gpu); - T prob_threshold = (T) dist(rng); - - T prob_x = (( 2.0 * k * min_costs_ptr[row]) / total_min_cost); - - return prob_x > prob_threshold; - }); - - h_potential_centroids_per_gpu[i].clear(); - h_potential_centroids_per_gpu[i].resize( - d_new_potential_centroids.size()); - - new_potential_centroids += d_new_potential_centroids.size(); - - thrust::copy(d_new_potential_centroids.begin(), - d_new_potential_centroids.end(), - h_potential_centroids_per_gpu[i].begin()); - - } - - } - - log_verbose(verbose, "KMeans|| - New potential centroids %d.", - new_potential_centroids); - - // Gather potential cluster centers from all GPUs - if (new_potential_centroids > 0) { - h_potential_centroids.clear(); - h_potential_centroids.resize(new_potential_centroids); - - int old_pot_centroids_size = h_all_potential_centroids.size(); - h_all_potential_centroids.resize( - old_pot_centroids_size + new_potential_centroids); - - int offset = 0; - for (int i = 0; i < num_gpu; i++) { - if (copy_from_gpus.find(i) != copy_from_gpus.end()) { - thrust::copy(h_potential_centroids_per_gpu[i].begin(), - h_potential_centroids_per_gpu[i].end(), - h_potential_centroids.begin() + offset); - offset += h_potential_centroids_per_gpu[i].size(); - } - } - - CUDACHECK(cudaSetDevice(0)); - thrust::device_vector new_centroids = h_potential_centroids; - - thrust::device_vector new_centroids_dist( - (new_potential_centroids / cols) - * (new_potential_centroids / cols)); - thrust::device_vector new_centroids_dot( - new_potential_centroids / cols); - - int numChosen = new_potential_centroids / cols; - filterByDot(cols, numChosen, &numChosen, new_centroids_dist, - new_centroids, new_centroids_dot); - - thrust::host_vector h_new_centroids = new_centroids; - h_all_potential_centroids.resize( - old_pot_centroids_size + (numChosen * cols)); - thrust::copy(h_new_centroids.begin(), - h_new_centroids.begin() + (numChosen * cols), - h_all_potential_centroids.begin() + old_pot_centroids_size); - curr_k = curr_k + numChosen; - - } else { - thrust::host_vector final_centroids(0); - if (verbose) { - fprintf(stderr, - "Too few points , not able to find centroid candidate \n"); - fflush(stderr); - } - - return final_centroids; - } - } - - thrust::host_vector final_centroids(0); - int potential_centroids_num = h_all_potential_centroids.size() / cols; - - final_centroids.resize(k * cols); - thrust::copy(h_all_potential_centroids.begin(), - h_all_potential_centroids.begin() + (max_k * cols), - final_centroids.begin()); - - return final_centroids; -} - -volatile std::atomic_int flaggpu(0); - -inline void my_function_gpu(int sig) { // can be called asynchronously - fprintf(stderr, "Caught signal %d. Terminating shortly.\n", sig); - flaggpu = 1; -} - -std::vector kmeans_init(int verbose, int *final_n_gpu, int n_gputry, - int gpu_idtry, int rows) { - if (rows > std::numeric_limits::max()) { - fprintf(stderr, "rows > %d not implemented\n", - std::numeric_limits::max()); - fflush(stderr); - exit(0); - } - - std::signal(SIGINT, my_function_gpu); - std::signal(SIGTERM, my_function_gpu); - - // no more gpus than visible gpus - int n_gpuvis; - cudaGetDeviceCount(&n_gpuvis); - int n_gpu = std::min(n_gpuvis, n_gputry); - - // no more than rows - n_gpu = std::min(n_gpu, rows); - - if (verbose) { - std::cout << n_gpu << " gpus." << std::endl; - } - - int gpu_id = gpu_idtry % n_gpuvis; - - // setup GPU list to use - std::vector dList(n_gpu); - for (int idx = 0; idx < n_gpu; idx++) { - int device_idx = (gpu_id + idx) % n_gpuvis; - dList[idx] = device_idx; - } - - *final_n_gpu = n_gpu; - return dList; -} - -template -H2O4GPUKMeans::H2O4GPUKMeans(const T *A, int k, int n, int d) { - _A = A; - _k = k; - _n = n; - _d = d; -} - -template -int kmeans_find_clusters(int verbose, const char ord, int seed, - thrust::device_vector **data, thrust::device_vector **labels, - thrust::device_vector **d_centroids, - thrust::device_vector **data_dots, size_t rows, size_t cols, - int init_from_data, int k, int k_max, T threshold, const T *srcdata, - int n_gpu, std::vector dList, T &residual, std::vector v, - int max_iterations) { - int bytecount = cols * k * sizeof(T); - if (0 == init_from_data) { - - log_debug(verbose, "KMeans - Using random initialization."); - - int masterq = 0; - CUDACHECK(cudaSetDevice(dList[masterq])); - // DM: simply copies first k rows data into GPU_0 - copy_centroids_shuffled(verbose, v, ord, *d_centroids[masterq], - &srcdata[0], rows, k, cols); - - // DM: can remove all of this - // Copy centroids to all devices - std::vector streams; - streams.resize(n_gpu); -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - if (q == masterq) - continue; - - CUDACHECK(cudaSetDevice(dList[q])); - if (verbose > 0) { - std::cout << "Copying centroid data to device: " << dList[q] - << std::endl; - } - - streams[q] = reinterpret_cast(malloc( - sizeof(cudaStream_t))); - cudaStreamCreate(streams[q]); - cudaMemcpyPeerAsync(thrust::raw_pointer_cast(&(*d_centroids[q])[0]), - dList[q], - thrust::raw_pointer_cast(&(*d_centroids[masterq])[0]), - dList[masterq], bytecount, *(streams[q])); - } -//#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - if (q == masterq) - continue; - cudaSetDevice(dList[q]); - cudaStreamDestroy(*(streams[q])); -#if(DEBUGKMEANS) - thrust::host_vector h_centroidq=*d_centroids[q]; - for(int ii=0;ii final_centroids = kmeans_parallel(verbose, seed, - ord, data, data_dots, rows, cols, k, n_gpu, threshold); - if (final_centroids.size() == 0) { - if (verbose) { - fprintf(stderr, - "kmeans || failed to find %d number of cluster points \n", - k); - fflush(stderr); - } - - residual = 0.0; - return 0; - } - -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(dList[q])); - cudaMemcpy(thrust::raw_pointer_cast(&(*d_centroids[q])[0]), - thrust::raw_pointer_cast(&final_centroids[0]), bytecount, - cudaMemcpyHostToDevice); - } - - } - -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(dList[q])); - labels[q] = new thrust::device_vector(rows / n_gpu); - } - - double t0 = timer(); - - int iter = kmeans::kmeans(verbose, &flaggpu, rows, cols, k, k_max, data, - labels, d_centroids, data_dots, dList, n_gpu, max_iterations, - threshold, true); - - if (iter < 0) { - log_error(verbose, "KMeans algorithm failed."); - return iter; - } - - // Calculate the residual - size_t rows_per_gpu = rows / n_gpu; - std::vector> d_min_costs(n_gpu); - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(q)); - d_min_costs[q].resize(rows_per_gpu); - thrust::fill(d_min_costs[q].begin(), d_min_costs[q].end(), - std::numeric_limits::max()); - } -#pragma omp parallel for - for (int i = 0; i < n_gpu; i++) { - CUDACHECK(cudaSetDevice(i)); - - int potential_k_rows = k; - // Compute all the costs to each potential centroid from previous iteration - thrust::device_vector centroid_dots(potential_k_rows); - - kmeans::detail::batch_calculate_distances(verbose, 0, rows_per_gpu, - cols, k, *data[i], *d_centroids[i], *data_dots[i], - centroid_dots, - [&](int rows_per_run, size_t offset, thrust::device_vector &pairwise_distances) { - // Find the closest potential center cost for each row - auto min_cost_counter = thrust::make_counting_iterator(0); - auto all_costs_ptr = thrust::raw_pointer_cast(pairwise_distances.data()); - auto min_costs_ptr = thrust::raw_pointer_cast(d_min_costs[i].data() + offset); - thrust::for_each(min_cost_counter, - min_cost_counter + rows_per_run, - // Functor instead of a lambda b/c nvcc is complaining about - // nesting a __device__ lambda inside a regular lambda - min_calc_functor(all_costs_ptr, min_costs_ptr, potential_k_rows, rows_per_run)); - }); - } - - residual = 0.0; - for (int i = 0; i < n_gpu; i++) { - CUDACHECK(cudaSetDevice(i)); - residual += thrust::reduce(d_min_costs[i].begin(), - d_min_costs[i].end()); - } - - double timefit = static_cast(timer() - t0); - - if (verbose) { - std::cout << " Time fit: " << timefit << " s" << std::endl; - fprintf(stderr, "Time fir: %g \n", timefit); - fflush(stderr); - } - - return iter; -} - -template -int kmeans_fit(int verbose, int seed, int gpu_idtry, int n_gputry, size_t rows, - size_t cols, const char ord, int k, int k_max, int max_iterations, - int init_from_data, T threshold, const T *srcdata, T **pred_centroids, - int **pred_labels) { - // init random seed if use the C function rand() - if (seed >= 0) { - srand(seed); - } else { - srand(unsigned(time(NULL))); - } - - // no more clusters than rows - if (k_max > rows) { - k_max = static_cast(rows); - fprintf(stderr, - "Number of clusters adjusted to be equal to number of rows.\n"); - fflush(stderr); - } - - int n_gpu; - // only creates a list of GPUs to use. can be removed for single GPU - std::vector dList = kmeans_init(verbose, &n_gpu, n_gputry, gpu_idtry, - rows); - - double t0t = timer(); - thrust::device_vector *data[n_gpu]; - thrust::device_vector *labels[n_gpu]; - thrust::device_vector *d_centroids[n_gpu]; - thrust::device_vector *data_dots[n_gpu]; -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(dList[q])); - data[q] = new thrust::device_vector(rows / n_gpu * cols); - d_centroids[q] = new thrust::device_vector(k_max * cols); - data_dots[q] = new thrust::device_vector(rows / n_gpu); - - kmeans::detail::labels_init(); - } - - log_debug(verbose, "Number of points: %d", rows); - log_debug(verbose, "Number of dimensions: %d", cols); - log_debug(verbose, "Number of clusters: %d", k); - log_debug(verbose, "Max number of clusters: %d", k_max); - log_debug(verbose, "Max. number of iterations: %d", max_iterations); - log_debug(verbose, "Stopping threshold: %d", threshold); - - std::vector v(rows); - std::iota(std::begin(v), std::end(v), 0); // Fill with 0, 1, ..., rows. - - if (seed >= 0) { - std::shuffle(v.begin(), v.end(), std::default_random_engine(seed)); - } else { - std::random_shuffle(v.begin(), v.end()); - } - - // Copy the data to devices -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(dList[q])); - if (verbose) { - std::cout << "Copying data to device: " << dList[q] << std::endl; - } - - copy_data(verbose, ord, *data[q], &srcdata[0], q, rows, rows / n_gpu, - cols); - - // Pre-compute the data matrix norms - kmeans::detail::make_self_dots(rows / n_gpu, cols, *data[q], - *data_dots[q]); - } - - // Host memory - thrust::host_vector results(k_max + 1, (T) (1e20)); - - // Loop to find *best* k - // Perform k-means in binary search - int left = k; //must be at least 2 - int right = k_max; //int(floor(len(data)/2)) #assumption of clusters of size 2 at least - int mid = int(floor((right + left) / 2)); - int oldmid = mid; - int tests; - int iter = 0; - T objective[2]; // 0= left of mid, 1= right of mid - T minres = 0; - T residual = 0.0; - - if (left == 1) - left = 2; // at least do 2 clusters - // eval left edge - - iter = kmeans_find_clusters(verbose, ord, seed, data, labels, d_centroids, - data_dots, rows, cols, init_from_data, left, k_max, threshold, - srcdata, n_gpu, dList, residual, v, max_iterations); - results[left] = residual; - - if (left != right) { - //eval right edge - residual = 0.0; - iter = kmeans_find_clusters(verbose, ord, seed, data, labels, - d_centroids, data_dots, rows, cols, init_from_data, right, - k_max, threshold, srcdata, n_gpu, dList, residual, v, - max_iterations); - int tmp_left = left; - int tmp_right = right; - T tmp_residual = 0.0; - - while ((residual == 0.0) && (right > 0)) { - right = (tmp_left + tmp_right) / 2; - // This k is already explored and need not be explored again - if (right == tmp_left) { - residual = tmp_residual; - right = tmp_left; - break; - } - iter = kmeans_find_clusters(verbose, ord, seed, data, labels, - d_centroids, data_dots, rows, cols, init_from_data, right, - k_max, threshold, srcdata, n_gpu, dList, residual, v, - max_iterations); - results[right] = residual; - - if (residual == 0.0) { - tmp_right = right; - } else { - tmp_left = right; - tmp_residual = residual; - - if (abs(tmp_left - tmp_right) == 1) { - break; - } - } - // Escape from an infinite loop if we come across - if (tmp_left == tmp_right) { - residual = tmp_residual; - right = tmp_left; - break; - } - - residual = 0.0; - } - results[right] = residual; - minres = residual * 0.9; - mid = int(floor((right + left) / 2)); - oldmid = mid; - } - - // binary search - while (left < right - 1) { - tests = 0; - while (results[mid] > results[left] && tests < 3) { - - iter = kmeans_find_clusters(verbose, ord, seed, data, labels, - d_centroids, data_dots, rows, cols, init_from_data, mid, - k_max, threshold, srcdata, n_gpu, dList, residual, v, - max_iterations); - results[mid] = residual; - if (results[mid] > results[left] && (mid + 1) < right) { - mid += 1; - results[mid] = 1e20; - } else if (results[mid] > results[left] && (mid - 1) > left) { - mid -= 1; - results[mid] = 1e20; - } - tests += 1; - } - objective[0] = abs(results[left] - results[mid]) - / (results[left] - minres); - objective[0] /= mid - left; - objective[1] = abs(results[mid] - results[right]) - / (results[mid] - minres); - objective[1] /= right - mid; - if (objective[0] > 1.2 * objective[1]) { //abs(resid_reduction[left]-resid_reduction[mid])/(mid-left)) { - // our point is in the left-of-mid side - right = mid; - } else { - left = mid; - } - oldmid = mid; - mid = int(floor((right + left) / 2)); - } - - int k_final = 0; - k_final = right; - if (results[left] < results[oldmid]) - k_final = left; - - // if k_star isn't what we just ran, re-run to get correct centroids and dist data on return-> this saves memory - if (k_final != oldmid) { - iter = kmeans_find_clusters(verbose, ord, seed, data, labels, - d_centroids, data_dots, rows, cols, init_from_data, k_final, - k_max, threshold, srcdata, n_gpu, dList, residual, v, - max_iterations); - } - - double timetransfer = static_cast(timer() - t0t); - - double t1 = timer(); - - // copy result of centroids (sitting entirely on each device) back to host - // TODO FIXME: When do delete ctr and h_labels memory??? - thrust::host_vector *ctr = new thrust::host_vector(*d_centroids[0]); - *pred_centroids = ctr->data(); - - // copy assigned labels - thrust::host_vector *h_labels = new thrust::host_vector(rows); -//#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - int offset = labels[q]->size() * q; - h_labels->insert(h_labels->begin() + offset, labels[q]->begin(), - labels[q]->end()); - } - - *pred_labels = h_labels->data(); - - // debug - if (verbose >= H2O4GPU_LOG_VERBOSE) { - for (unsigned int ii = 0; ii < k; ii++) { - fprintf(stderr, "ii=%d of k=%d ", ii, k); - for (unsigned int jj = 0; jj < cols; jj++) { - fprintf(stderr, "%g ", (*pred_centroids)[cols * ii + jj]); - } - fprintf(stderr, "\n"); - fflush(stderr); - } - - printf("Number of iteration: %d\n", iter); - } - -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(dList[q])); - delete (data[q]); - delete (labels[q]); - delete (d_centroids[q]); - delete (data_dots[q]); - kmeans::detail::labels_close(); - } - - double timecleanup = static_cast(timer() - t1); - - if (verbose) { - fprintf(stderr, "Timetransfer: %g Timecleanup: %g\n", timetransfer, - timecleanup); - fflush(stderr); - } - - return k_final; -} - -template -int kmeans_predict(int verbose, int gpu_idtry, int n_gputry, size_t rows, - size_t cols, const char ord, int k, const T *srcdata, - const T *centroids, int **pred_labels) { - // Print centroids - if (verbose >= H2O4GPU_LOG_VERBOSE) { - std::cout << std::endl; - for (int i = 0; i < cols * k; i++) { - std::cout << centroids[i] << " "; - if (i % cols == 1) { - std::cout << std::endl; - } - } - } - - int n_gpu; - std::vector dList = kmeans_init(verbose, &n_gpu, n_gputry, gpu_idtry, - rows); - - thrust::device_vector *d_data[n_gpu]; - thrust::device_vector *d_centroids[n_gpu]; - thrust::device_vector *data_dots[n_gpu]; - thrust::device_vector *centroid_dots[n_gpu]; - thrust::host_vector *h_labels = new thrust::host_vector(0); - std::vector> d_labels(n_gpu); - -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(dList[q])); - kmeans::detail::labels_init(); - - data_dots[q] = new thrust::device_vector(rows / n_gpu); - centroid_dots[q] = new thrust::device_vector(k); - - d_centroids[q] = new thrust::device_vector(k * cols); - d_data[q] = new thrust::device_vector(rows / n_gpu * cols); - - copy_data(verbose, 'r', *d_centroids[q], ¢roids[0], 0, k, k, cols); - - copy_data(verbose, ord, *d_data[q], &srcdata[0], q, rows, rows / n_gpu, - cols); - - kmeans::detail::make_self_dots(rows / n_gpu, cols, *d_data[q], - *data_dots[q]); - - d_labels[q].resize(rows / n_gpu); - - kmeans::detail::batch_calculate_distances(verbose, q, rows / n_gpu, - cols, k, *d_data[q], *d_centroids[q], *data_dots[q], - *centroid_dots[q], - [&](int n, size_t offset, thrust::device_vector &pairwise_distances) { - kmeans::detail::relabel(n, k, pairwise_distances, d_labels[q], offset); - }); - - } - - for (int q = 0; q < n_gpu; q++) { - h_labels->insert(h_labels->end(), d_labels[q].begin(), - d_labels[q].end()); - } - - *pred_labels = h_labels->data(); - -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - safe_cuda(cudaSetDevice(dList[q])); - kmeans::detail::labels_close(); - delete (data_dots[q]); - delete (centroid_dots[q]); - delete (d_centroids[q]); - delete (d_data[q]); - } - - return 0; -} - -template -int kmeans_transform(int verbose, int gpu_idtry, int n_gputry, size_t rows, - size_t cols, const char ord, int k, const T *srcdata, - const T *centroids, T **preds) { - // Print centroids - if (verbose >= H2O4GPU_LOG_VERBOSE) { - std::cout << std::endl; - for (int i = 0; i < cols * k; i++) { - std::cout << centroids[i] << " "; - if (i % cols == 1) { - std::cout << std::endl; - } - } - } - - int n_gpu; - std::vector dList = kmeans_init(verbose, &n_gpu, n_gputry, gpu_idtry, - rows); - - thrust::device_vector *d_data[n_gpu]; - thrust::device_vector *d_centroids[n_gpu]; - thrust::device_vector *d_pairwise_distances[n_gpu]; - thrust::device_vector *data_dots[n_gpu]; - thrust::device_vector *centroid_dots[n_gpu]; -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(dList[q])); - kmeans::detail::labels_init(); - - data_dots[q] = new thrust::device_vector(rows / n_gpu); - centroid_dots[q] = new thrust::device_vector(k); - d_pairwise_distances[q] = new thrust::device_vector( - rows / n_gpu * k); - - d_centroids[q] = new thrust::device_vector(k * cols); - d_data[q] = new thrust::device_vector(rows / n_gpu * cols); - - copy_data(verbose, 'r', *d_centroids[q], ¢roids[0], 0, k, k, cols); - - copy_data(verbose, ord, *d_data[q], &srcdata[0], q, rows, rows / n_gpu, - cols); - - kmeans::detail::make_self_dots(rows / n_gpu, cols, *d_data[q], - *data_dots[q]); - - // TODO batch this - kmeans::detail::calculate_distances(verbose, q, rows / n_gpu, cols, k, - *d_data[q], 0, *d_centroids[q], *data_dots[q], - *centroid_dots[q], *d_pairwise_distances[q]); - } - -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - CUDACHECK(cudaSetDevice(dList[q])); - thrust::transform((*d_pairwise_distances[q]).begin(), - (*d_pairwise_distances[q]).end(), - (*d_pairwise_distances[q]).begin(), square_root()); - } - - // Move the resulting labels into host memory from all devices - thrust::host_vector *h_pairwise_distances = new thrust::host_vector( - 0); -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - h_pairwise_distances->insert(h_pairwise_distances->end(), - d_pairwise_distances[q]->begin(), - d_pairwise_distances[q]->end()); - } - *preds = h_pairwise_distances->data(); - - // Print centroids - if (verbose >= H2O4GPU_LOG_VERBOSE) { - std::cout << std::endl; - for (int i = 0; i < rows * cols; i++) { - std::cout << h_pairwise_distances->data()[i] << " "; - if (i % cols == 1) { - std::cout << std::endl; - } - } - } - -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - safe_cuda(cudaSetDevice(dList[q])); - kmeans::detail::labels_close(); - delete (d_pairwise_distances[q]); - delete (data_dots[q]); - delete (centroid_dots[q]); - delete (d_centroids[q]); - delete (d_data[q]); - } - - return 0; -} - -template -int makePtr_dense(int dopredict, int verbose, int seed, int gpu_idtry, - int n_gputry, size_t rows, size_t cols, const char ord, int k, - int k_max, int max_iterations, int init_from_data, T threshold, - const T *srcdata, const T *centroids, T **pred_centroids, - int **pred_labels) { - if (dopredict == 0) { - return kmeans_fit(verbose, seed, gpu_idtry, n_gputry, rows, cols, ord, - k, k_max, max_iterations, init_from_data, threshold, srcdata, - pred_centroids, pred_labels); - } else { - return kmeans_predict(verbose, gpu_idtry, n_gputry, rows, cols, ord, k, - srcdata, centroids, pred_labels); - } -} - -template int -makePtr_dense(int dopredict, int verbose, int seed, int gpu_id, - int n_gpu, size_t rows, size_t cols, const char ord, int k, int k_max, - int max_iterations, int init_from_data, float threshold, - const float *srcdata, const float *centroids, float **pred_centroids, - int **pred_labels); - -template int -makePtr_dense(int dopredict, int verbose, int seed, int gpu_id, - int n_gpu, size_t rows, size_t cols, const char ord, int k, int k_max, - int max_iterations, int init_from_data, double threshold, - const double *srcdata, const double *centroids, double **pred_centroids, - int **pred_labels); - -template int kmeans_fit(int verbose, int seed, int gpu_idtry, - int n_gputry, size_t rows, size_t cols, const char ord, int k, - int k_max, int max_iterations, int init_from_data, float threshold, - const float *srcdata, float **pred_centroids, int **pred_labels); - -template int kmeans_fit(int verbose, int seed, int gpu_idtry, - int n_gputry, size_t rows, size_t cols, const char ord, int k, - int k_max, int max_iterations, int init_from_data, double threshold, - const double *srcdata, double **pred_centroids, int **pred_labels); - -template int kmeans_find_clusters(int verbose, const char ord, int seed, - thrust::device_vector **data, - thrust::device_vector **labels, - thrust::device_vector **d_centroids, - thrust::device_vector **data_dots, size_t rows, size_t cols, - int init_from_data, int k, int k_max, float threshold, - const float *srcdata, int n_gpu, std::vector dList, - float &residual, std::vector v, int max_iterations); - -template int kmeans_find_clusters(int verbose, const char ord, int seed, - thrust::device_vector **data, - thrust::device_vector **labels, - thrust::device_vector **d_centroids, - thrust::device_vector **data_dots, size_t rows, size_t cols, - int init_from_data, int k, int k_max, double threshold, - const double *srcdata, int n_gpu, std::vector dList, - double &residual, std::vector v, int max_iterations); - -template int kmeans_predict(int verbose, int gpu_idtry, int n_gputry, - size_t rows, size_t cols, const char ord, int k, const float *srcdata, - const float *centroids, int **pred_labels); - -template int kmeans_predict(int verbose, int gpu_idtry, int n_gputry, - size_t rows, size_t cols, const char ord, int k, const double *srcdata, - const double *centroids, int **pred_labels); - -template int kmeans_transform(int verbose, int gpu_id, int n_gpu, - size_t m, size_t n, const char ord, int k, const float *src_data, - const float *centroids, float **preds); - -template int kmeans_transform(int verbose, int gpu_id, int n_gpu, - size_t m, size_t n, const char ord, int k, const double *src_data, - const double *centroids, double **preds); - -// Explicit template instantiation. -#if !defined(H2O4GPU_DOUBLE) || H2O4GPU_DOUBLE == 1 - -template -class H2O4GPUKMeans ; - -#endif - -#if !defined(H2O4GPU_SINGLE) || H2O4GPU_SINGLE == 1 - -template -class H2O4GPUKMeans ; - -#endif - -int get_n_gpus(int n_gputry) { - int nDevices; - cudaGetDeviceCount(&nDevices); - - if (n_gputry < 0) { // get all available GPUs - return nDevices; - } else if (n_gputry > nDevices) { - return nDevices; - } else { - return n_gputry; - } -} - -} // namespace h2o4gpukmeans - -namespace ML { -/* - * Interface for other languages + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -// Fit and Predict -void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, - int n_gpu, size_t mTrain, size_t n, const char ord, int k, int k_max, - int max_iterations, int init_from_data, float threshold, - const float *srcdata, const float *centroids, float *pred_centroids, - int *pred_labels) { - float *h_srcdata = (float*) malloc(mTrain * n * sizeof(float)); - cudaMemcpy((void*)h_srcdata, (void*)srcdata, mTrain*n * sizeof(float), cudaMemcpyDeviceToHost); - - float *h_centroids = nullptr; - if (dopredict) { - h_centroids = (float*) malloc(k * n * sizeof(float)); - cudaMemcpy((void*) h_centroids, (void*) centroids, - k * n * sizeof(float), cudaMemcpyDeviceToHost); - } - - int *h_pred_labels = nullptr; - float *h_pred_centroids = nullptr; - int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); - int actual_k = h2o4gpukmeans::makePtr_dense(dopredict, verbose, seed, - gpu_id, actual_n_gpu, mTrain, n, ord, k, k_max, max_iterations, - init_from_data, threshold, h_srcdata, h_centroids, - &h_pred_centroids, &h_pred_labels); - - cudaSetDevice(gpu_id); - - if (dopredict == 0) { - cudaMemcpy(pred_centroids, h_pred_centroids, k * n * sizeof(float), - cudaMemcpyHostToDevice); - } - - cudaMemcpy(pred_labels, h_pred_labels, mTrain * sizeof(int), - cudaMemcpyHostToDevice); - - free(h_srcdata); - if (dopredict) { - free(h_centroids); - } -} - -void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, - int n_gpu, size_t mTrain, size_t n, const char ord, int k, int k_max, - int max_iterations, int init_from_data, double threshold, - const double *srcdata, const double *centroids, double *pred_centroids, - int *pred_labels) { - - double *h_srcdata = (double*) malloc(mTrain * n * sizeof(double)); - cudaMemcpy((void*)h_srcdata, (void*)srcdata, mTrain*n * sizeof(double), cudaMemcpyDeviceToHost); - - double *h_centroids = nullptr; - if (dopredict) { - h_centroids = (double*) malloc(k * n * sizeof(double)); - cudaMemcpy((void*) h_centroids, (void*) centroids, - k * n * sizeof(double), cudaMemcpyDeviceToHost); - } - - int *h_pred_labels = nullptr; - double *h_pred_centroids = nullptr; - int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); - int actual_k = h2o4gpukmeans::makePtr_dense(dopredict, verbose, - seed, gpu_id, actual_n_gpu, mTrain, n, ord, k, k_max, - max_iterations, init_from_data, threshold, h_srcdata, h_centroids, - &h_pred_centroids, &h_pred_labels); - - cudaSetDevice(gpu_id); - // int dev = -1; - // cudaGetDevice(&dev); - // printf("device: %d\n", dev); - - if (dopredict == 0) { - cudaMemcpy(pred_centroids, h_pred_centroids, k * n * sizeof(double), - cudaMemcpyHostToDevice); - } - - cudaMemcpy(pred_labels, h_pred_labels, mTrain * sizeof(int), - cudaMemcpyHostToDevice); - - free(h_srcdata); - if (dopredict) { - free(h_centroids); - } - -} - -// Transform -void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, - const char ord, int k, const float *src_data, const float *centroids, - float *preds) { - float *h_srcdata = (float*) malloc(m * n * sizeof(float)); - cudaMemcpy((void*)h_srcdata, (void*)src_data, m*n * sizeof(float), cudaMemcpyDeviceToHost); - - float *h_centroids = (float*) malloc(k * n * sizeof(float)); - cudaMemcpy((void*) h_centroids, (void*) centroids, k * n * sizeof(float), - cudaMemcpyDeviceToHost); - - // const float *h_srcdata = src_data; - // const float *h_centroids = centroids; - - float *h_preds = nullptr; - int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); - h2o4gpukmeans::kmeans_transform(verbose, gpu_id, actual_n_gpu, m, n, - ord, k, h_srcdata, h_centroids, &h_preds); - - cudaSetDevice(gpu_id); - - cudaMemcpy(preds, h_preds, m * k * sizeof(float), cudaMemcpyHostToDevice); - - free(h_srcdata); - free(h_centroids); -} - -void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, - const char ord, int k, const double *src_data, const double *centroids, - double *preds) { - - double *h_srcdata = (double*) malloc(m * n * sizeof(double)); - cudaMemcpy((void*)h_srcdata, (void*)src_data, m*n * sizeof(double), cudaMemcpyDeviceToHost); - - double *h_centroids = (double*) malloc(k * n * sizeof(double)); - cudaMemcpy((void*) h_centroids, (void*) centroids, k * n * sizeof(double), - cudaMemcpyDeviceToHost); - - // const double *h_srcdata = src_data; - // const double *h_centroids = centroids; - - double *h_preds = nullptr; - int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); - h2o4gpukmeans::kmeans_transform(verbose, gpu_id, actual_n_gpu, m, n, - ord, k, h_srcdata, h_centroids, &h_preds); - - cudaSetDevice(gpu_id); - - cudaMemcpy(preds, h_preds, m * k * sizeof(double), cudaMemcpyHostToDevice); - - free(h_srcdata); - free(h_centroids); -} - -} // end namespace ML +#include "kmeans.cuh" + +namespace ML{ +namespace kmeans{ + + +void fit_predict(const ML::cumlHandle& handle, + int n_clusters, + int metric, + kmeans::InitMethod init, + int max_iter, + double tol, + int seed, + const float *X, + int n_samples, + int n_features, + float *centroids, + int *labels, + int verbose){ + const ML::cumlHandle_impl &h = handle.getImpl(); + ML::detail::streamSyncer _(h); + cudaStream_t stream = h.getStream(); + + ML::KMeans kmeans_obj(h, n_clusters, static_cast(metric), init, max_iter, tol, seed, verbose); + + if(kmeans::InitMethod::Array == init){ + ASSERT(centroids != nullptr, + "centroids array is null (require a valid array of centroids for the requested initialization method)"); + kmeans_obj.setCentroids(centroids, n_clusters, n_features); + } + + kmeans_obj.fit(X, n_samples, n_features); + if(labels){ + kmeans_obj.predict(X, n_samples, n_features, labels); + } + + MLCommon::copy(centroids, kmeans_obj.centroids(), n_clusters * n_features, stream); + +} + +void fit_predict(const ML::cumlHandle& handle, + int n_clusters, + int metric, + kmeans::InitMethod init, + int max_iter, + double tol, + int seed, + const double *X, + int n_samples, + int n_features, + double *centroids, + int *labels, + int verbose){ + const ML::cumlHandle_impl &h = handle.getImpl(); + ML::detail::streamSyncer _(h); + cudaStream_t stream = h.getStream(); + + ML::KMeans kmeans_obj(h, n_clusters, static_cast(metric), init, max_iter, tol, seed, verbose); + + if(kmeans::InitMethod::Array == init){ + ASSERT(centroids != nullptr, + "centroids array is null (require a valid array of centroids for the requested initialization method)"); + kmeans_obj.setCentroids(centroids, n_clusters, n_features); + } + + kmeans_obj.fit(X, n_samples, n_features); + if(labels){ + kmeans_obj.predict(X, n_samples, n_features, labels); + } + + MLCommon::copy(centroids, kmeans_obj.centroids(), n_clusters * n_features, stream); +} + +void fit(const ML::cumlHandle& handle, + int n_clusters, + int metric, + kmeans::InitMethod init, + int max_iter, + double tol, + int seed, + const float *X, + int n_samples, + int n_features, + float *centroids, + int verbose){ + fit_predict(handle, n_clusters, metric, init, max_iter, tol, seed, X, n_samples, n_features, centroids, nullptr, verbose); +} + +void fit(const ML::cumlHandle& handle, + int n_clusters, + int metric, + kmeans::InitMethod init, + int max_iter, + double tol, + int seed, + const double *X, + int n_samples, + int n_features, + double *centroids, + int verbose){ + fit_predict(handle, n_clusters, metric, init, max_iter, tol, seed, X, n_samples, n_features, centroids, nullptr, verbose); +} + + + +void predict(const ML::cumlHandle& handle, + float *centroids, + int n_clusters, + const float *X, + int n_samples, + int n_features, + int metric, + int *labels, + int verbose){ + const ML::cumlHandle_impl &h = handle.getImpl(); + ML::detail::streamSyncer _(h); + cudaStream_t stream = h.getStream(); + + ML::KMeans kmeans_obj(h, n_clusters, static_cast(metric)); + + kmeans_obj.setCentroids(centroids, n_clusters, n_features); + + kmeans_obj.predict(X, n_samples, n_features, labels); +} + + +void predict(const ML::cumlHandle& handle, + double *centroids, + int n_clusters, + const double *X, + int n_samples, + int n_features, + int metric, + int *labels, + int verbose){ + const ML::cumlHandle_impl &h = handle.getImpl(); + ML::detail::streamSyncer _(h); + cudaStream_t stream = h.getStream(); + + ML::KMeans kmeans_obj(h, n_clusters, static_cast(metric)); + kmeans_obj.setCentroids(centroids, n_clusters, n_features); + + kmeans_obj.predict(X, n_samples, n_features, labels); +} + + +void transform(const ML::cumlHandle& handle, + const float *centroids, + int n_clusters, + const float *X, + int n_samples, + int n_features, + int metric, + float *X_new, + int verbose){ + const ML::cumlHandle_impl &h = handle.getImpl(); + ML::detail::streamSyncer _(h); + cudaStream_t stream = h.getStream(); + + ML::KMeans kmeans_obj(h, n_clusters, static_cast(metric)); + kmeans_obj.setCentroids(centroids, n_clusters, n_features); + kmeans_obj.transform(X, n_samples, n_features, X_new); +} + +void transform(const ML::cumlHandle& handle, + const double *centroids, + int n_clusters, + const double *X, + int n_samples, + int n_features, + int metric, + double *X_new, + int verbose){ + const ML::cumlHandle_impl &h = handle.getImpl(); + ML::detail::streamSyncer _(h); + cudaStream_t stream = h.getStream(); + + ML::KMeans kmeans_obj(h, n_clusters, static_cast(metric)); + kmeans_obj.setCentroids(centroids, n_clusters, n_features); + kmeans_obj.transform(X, n_samples, n_features, X_new); +} + + +}; // end namespace kmeans +}; // end namespace ML diff --git a/cuML/src/kmeans/kmeans.cuh b/cuML/src/kmeans/kmeans.cuh new file mode 100644 index 0000000000..c6c042d4f6 --- /dev/null +++ b/cuML/src/kmeans/kmeans.cuh @@ -0,0 +1,217 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "kmeans.hpp" + +namespace ML { +using namespace MLCommon; + +/** + * @tparam DataT + * @tparam IndexT + */ +template +class KMeans{ +public: + + struct Options{ + int seed; + int oversampling_factor; + Random::GeneratorType gtype; + bool inertia_check; + public: + Options(){ + seed = -1; + oversampling_factor = 0; + gtype = Random::GeneratorType::GenPhilox; + inertia_check = false; + } + }; + + + /** + * @param[in] cumlHandle The handle to the cuML library context that manages the CUDA resources. + * @param[in] n_clusters The number of clusters to form as well as the number of centroids to generate (default:8). + * @param[in] metric Metric to use for distance computation. Any metric from MLCommon::Distance::DistanceType can be used + * @param[in] init Method for initialization, defaults to k-means++: + * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm to select the initial cluster centers. + * - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at random from the input data for the initial centroids. + * - InitMethod::Array (ndarray): Use the user provided input for initial centroids, it should be set using setCentroids method. + * @param[in] max_iter Maximum number of iterations for the k-means algorithm. + * @param[in] tol Relative tolerance with regards to inertia to declare convergence. + * @param[in] seed Seed to the random number generator. + */ + KMeans(const ML::cumlHandle_impl &cumlHandle, + int n_clusters = 8, + MLCommon::Distance::DistanceType metric = MLCommon::Distance::DistanceType::EucExpandedL2Sqrt, + kmeans::InitMethod init = kmeans::InitMethod::KMeansPlusPlus, + int max_iter = 300, + double tol = 1e-4, + int seed = -1, + int verbose = 0); + + + /** + * @brief Compute k-means clustering. + * + * @param[in] X Training instances to cluster. It must be noted that the data must be in row-major format and stored in device accessible location. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each sample. + */ + void + fit(const DataT *X, + int n_samples, + int n_features); + + + /** + * @brief Predict the closest cluster each sample in X belongs to. + * + * @param[in] X New data to predict. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each sample. + * @param[out] labels Index of the cluster each sample belongs to. + */ + void + predict(const DataT *X, + int n_samples, + int n_features, + IndexT *labels); + + /** + * @brief Transform X to a cluster-distance space. In the new space, each dimension is the distance to the cluster centers. + * + * @param[in] X New data to transform. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each sample. + * @param[out] X_new X transformed in the new space (output size is [n_samples x n_clusters]. + */ + void + transform(const DataT *X, + int n_samples, + int n_features, + DataT *X_new); + + + /** + * @brief Set centroids to be user provided value X + * + * @param[in] X New data to centroids. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each sample. + */ + void + setCentroids(const DataT *X, + int n_samples, + int n_features); + + // returns the raw pointer to the generated centroids + DataT* + centroids(); + + + // release local resources + ~KMeans(); + + + // internal functions (ideally must be private methods, but its not possible today because of the lambda limitations of CUDA compiler) + void + fit(Tensor &); + + void + predict(Tensor &); + + void + initRandom(Tensor &); + + void + initKMeansPlusPlus(Tensor &); + + + +private: + + // Maximum number of iterations of the k-means algorithm for a single run. + int max_iter; + + // Relative tolerance with regards to inertia to declare convergence. + double tol; + + // Sum of squared distances of samples to their closest cluster center. + double inertia; + + // Number of iterations run. + int n_iter; + + // number of clusters to form as well as the number of centroids to generate. + int n_clusters; + + // number of features or the dimensions of the input. + int _n_features; + + // verbosity mode. + int _verbose; + + // Device-accessible allocation of expandable storage used as temorary buffers + MLCommon::device_buffer _workspace; + + // underlying expandable storage that holds centroids data + MLCommon::device_buffer _centroidsRawData; + + // underlying expandable storage that holds labels + MLCommon::device_buffer _labelsRawData; + + const ML::cumlHandle_impl &_handle; + + // method for initialization, defaults to 'k-means++': + kmeans::InitMethod _init; + + MLCommon::Distance::DistanceType _metric; + + // optional + Options options; +}; +} + +#include "kmeans-inl.cuh" + diff --git a/cuML/src/kmeans/kmeans.h b/cuML/src/kmeans/kmeans.h deleted file mode 100644 index e653791285..0000000000 --- a/cuML/src/kmeans/kmeans.h +++ /dev/null @@ -1,91 +0,0 @@ -/*! - * Copyright 2017-2018 H2O.ai, Inc. - * License Apache License Version 2.0 (see LICENSE for details) - */ -#pragma once -#include -#include -#include -#include "kmeans_labels.h" -#include "kmeans_centroids.h" - -template -struct count_functor { - T* pairwise_distances_ptr; - int* counts_ptr; - int k; - int rows_per_run; - - count_functor(T* _pairwise_distances_ptr, int* _counts_ptr, int _k, - int _rows_per_run) { - pairwise_distances_ptr = _pairwise_distances_ptr; - counts_ptr = _counts_ptr; - k = _k; - rows_per_run = _rows_per_run; - } - - __device__ - void operator()(int idx) const { - int closest_centroid_idx = 0; - T best_distance = pairwise_distances_ptr[idx]; - // FIXME potentially slow due to striding - for (int i = 1; i < k; i++) { - T distance = pairwise_distances_ptr[idx + i * rows_per_run]; - - if (distance < best_distance) { - best_distance = distance; - closest_centroid_idx = i; - } - } - atomicAdd(&counts_ptr[closest_centroid_idx], 1); - } -}; - -/** - * Calculates closest centroid for each record and counts how many points are assigned to each centroid. - * @tparam T - * @param verbose - * @param num_gpu - * @param rows_per_gpu - * @param cols - * @param data - * @param data_dots - * @param centroids - * @param weights - * @param pairwise_distances - * @param labels - */ -template -void count_pts_per_centroid(int verbose, int num_gpu, int rows_per_gpu, - int cols, thrust::device_vector **data, - thrust::device_vector **data_dots, thrust::host_vector centroids, - thrust::host_vector &weights) { - int k = centroids.size() / cols; -#pragma omp parallel for - for (int i = 0; i < num_gpu; i++) { - thrust::host_vector weights_tmp(weights.size()); - - CUDACHECK(cudaSetDevice(i)); - thrust::device_vector centroid_dots(k); - thrust::device_vector d_centroids = centroids; - thrust::device_vector counts(k); - - kmeans::detail::batch_calculate_distances(verbose, 0, rows_per_gpu, - cols, k, *data[i], d_centroids, *data_dots[i], centroid_dots, - [&](int rows_per_run, size_t offset, thrust::device_vector &pairwise_distances) { - auto counting = thrust::make_counting_iterator(0); - auto counts_ptr = thrust::raw_pointer_cast(counts.data()); - auto pairwise_distances_ptr = thrust::raw_pointer_cast(pairwise_distances.data()); - thrust::for_each(counting, - counting + rows_per_run, - count_functor(pairwise_distances_ptr, counts_ptr, k, rows_per_run) - ); - }); - - kmeans::detail::memcpy(weights_tmp, counts); - kmeans::detail::streamsync(i); - for (int p = 0; p < k; p++) { - weights[p] += weights_tmp[p]; - } - } -} diff --git a/cuML/src/kmeans/kmeans.hpp b/cuML/src/kmeans/kmeans.hpp new file mode 100644 index 0000000000..f91e62396e --- /dev/null +++ b/cuML/src/kmeans/kmeans.hpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace ML { + +namespace kmeans{ + +enum InitMethod{KMeansPlusPlus, Random, Array}; + +/** + * @brief Compute k-means clustering and optionally predicts cluster index for each sample in the input. + * + * @param[in] cumlHandle The handle to the cuML library context that manages the CUDA resources. + * @param[in] n_clusters The number of clusters to form as well as the number of centroids to generate (default:8). + * @param[in] metric Metric to use for distance computation. Any metric from MLCommon::Distance::DistanceType can be used + * @param[in] init Method for initialization, defaults to k-means++: + * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm to select the initial cluster centers. + * - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at random from the input data for the initial centroids. + * - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. + * @param[in] max_iter Maximum number of iterations for the k-means algorithm. + * @param[in] tol Relative tolerance with regards to inertia to declare convergence. + * @param[in] seed Seed to the random number generator. + * @param[in] X Training instances to cluster. It must be noted that the data must be in row-major format and stored in device accessible location. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each sample. + * @param[in|out] centroids [in] When init is InitMethod::Array, use centroids as the initial cluster centers + * [out] Otherwise, generated centroids from the kmeans algorithm is stored at the address pointed by 'centroids'. + * @param[out] labels [optional] Index of the cluster each sample in X belongs to. + */ +void fit_predict(const ML::cumlHandle& handle, + int n_clusters, + int metric, + InitMethod init, + int max_iter, + double tol, + int seed, + const float *X, + int n_samples, + int n_features, + float *centroids, + int *labels = 0, + int verbose = 0); + +void fit_predict(const ML::cumlHandle& handle, + int n_clusters, + int metric, + InitMethod init, + int max_iter, + double tol, + int seed, + const double *X, + int n_samples, + int n_features, + double *centroids, + int *labels = 0, + int verbose = 0); + + +/** + * @brief Compute k-means clustering. + * + * @param[in] cumlHandle The handle to the cuML library context that manages the CUDA resources. + * @param[in] n_clusters The number of clusters to form as well as the number of centroids to generate (default:8). + * @param[in] metric Metric to use for distance computation. Any metric from MLCommon::Distance::DistanceType can be used + * @param[in] init Method for initialization, defaults to k-means++: + * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm to select the initial cluster centers. + * - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at random from the input data for the initial centroids. + * - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. + * @param[in] max_iter Maximum number of iterations for the k-means algorithm. + * @param[in] tol Relative tolerance with regards to inertia to declare convergence. + * @param[in] seed Seed to the random number generator. + * @param[in] X Training instances to cluster. It must be noted that the data must be in row-major format and stored in device accessible location. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each sample. + * @param[in|out] centroids [in] When init is InitMethod::Array, use centroids as the initial cluster centers + * [out] Otherwise, generated centroids from the kmeans algorithm is stored at the address pointed by 'centroids'. + */ +void fit(const ML::cumlHandle& handle, + int n_clusters, + int metric, + InitMethod init, + int max_iter, + double tol, + int seed, + const float *X, + int n_samples, + int n_features, + float *centroids, + int verbose = 0); + +void fit(const ML::cumlHandle& handle, + int n_clusters, + int metric, + InitMethod init, + int max_iter, + double tol, + int seed, + const double *X, + int n_samples, + int n_features, + double *centroids, + int verbose = 0); + + + + +/** + * @brief Predict the closest cluster each sample in X belongs to. + * + * @param[in] cumlHandle The handle to the cuML library context that manages the CUDA resources. + * @param[in] centroids Cluster centroids. It must be noted that the data must be in row-major format and stored in device accessible location. + * @param[in] n_clusters The number of clusters. + * @param[in] X Training instances to cluster. It must be noted that the data must be in row-major format and stored in device accessible location. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each sample in 'X' (value should be same as the dimension for each cluster centers in 'centroids'). + * @param[in] metric Metric to use for distance computation. Any metric from MLCommon::Distance::DistanceType can be used + * @param[out] labels Index of the cluster each sample in X belongs to. + */ +void predict(const ML::cumlHandle& handle, + float *centroids, + int n_clusters, + const float *X, + int n_samples, + int n_features, + int metric, + int *labels, + int verbose = 0); + + +void predict(const ML::cumlHandle& handle, + double *centroids, + int n_clusters, + const double *X, + int n_samples, + int n_features, + int metric, + int *labels, + int verbose = 0); + + + + +/** + * @brief Transform X to a cluster-distance space. + * + * @param[in] cumlHandle The handle to the cuML library context that manages the CUDA resources. + * @param[in] centroids Cluster centroids. It must be noted that the data must be in row-major format and stored in device accessible location. + * @param[in] n_clusters The number of clusters. + * @param[in] X Training instances to cluster. It must be noted that the data must be in row-major format and stored in device accessible location. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each sample in 'X' (it should be same as the dimension for each cluster centers in 'centroids'). + * @param[in] metric Metric to use for distance computation. Any metric from MLCommon::Distance::DistanceType can be used + * @param[out] X_new X transformed in the new space.. + */ +void transform(const ML::cumlHandle& handle, + const float *centroids, + int n_clusters, + const float *X, + int n_samples, + int n_features, + int metric, + float *X_new, + int verbose = 0); + + +void transform(const ML::cumlHandle& handle, + const double *centroids, + int n_clusters, + const double *X, + int n_samples, + int n_features, + int metric, + double *X_new, + int verbose = 0); + + +};// end namespace kmeans +};// end namespace ML diff --git a/cuML/src/kmeans/kmeans_c.h b/cuML/src/kmeans/kmeans_c.h deleted file mode 100644 index 13d10d1599..0000000000 --- a/cuML/src/kmeans/kmeans_c.h +++ /dev/null @@ -1,76 +0,0 @@ -/*! - * Copyright 2017-2018 H2O.ai, Inc. - * License Apache License Version 2.0 (see LICENSE for details) - */ - -#pragma once -#ifdef __JETBRAINS_IDE__ -#define __host__ -#define __device__ -#endif - -#include -#include -#include "timer.h" - -namespace h2o4gpukmeans { - -template -class H2O4GPUKMeans { -private: - // Data - const M *_A; - int _k; - int _n; - int _d; -public: - H2O4GPUKMeans(const M *A, int k, int n, int d); -}; - -template -class H2O4GPUKMeansCPU { -private: - // Data - const M *_A; - int _k; - int _n; - int _d; -public: - H2O4GPUKMeansCPU(const M *A, int k, int n, int d); -}; - -template -int makePtr_dense(int dopredict, int verbose, int seed, int gpu_id, int n_gpu, - size_t rows, size_t cols, const char ord, int k, int max_iterations, - int init_from_data, T threshold, const T *srcdata, const T *centroids, - T **pred_centroids, int **pred_labels); - -template -int kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, - const char ord, int k, const T *srcdata, const T *centroids, T **preds); - -} // namespace h2o4gpukmeans - -namespace ML { - -void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, - int n_gpu, size_t mTrain, size_t n, const char ord, int k, int k_max, - int max_iterations, int init_from_data, float threshold, - const float *srcdata, const float *centroids, float *pred_centroids, - int *pred_labels); - -void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, - int n_gpu, size_t mTrain, size_t n, const char ord, int k, int k_max, - int max_iterations, int init_from_data, double threshold, - const double *srcdata, const double *centroids, double *pred_centroids, - int *pred_labels); - -void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, - const char ord, int k, const float *srcdata, const float *centroids, - float *preds); - -void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, - const char ord, int k, const double *srcdata, const double *centroids, - double *preds); - -} diff --git a/cuML/src/kmeans/kmeans_centroids.h b/cuML/src/kmeans/kmeans_centroids.h deleted file mode 100644 index 4fecdbdcaf..0000000000 --- a/cuML/src/kmeans/kmeans_centroids.h +++ /dev/null @@ -1,187 +0,0 @@ -/*! - * Modifications Copyright 2017-2018 H2O.ai, Inc. - */ -// original code from https://github.com/NVIDIA/kmeans (Apache V2.0 License) -#pragma once -#include -#include -#include "kmeans_labels.h" - -inline __device__ double my_atomic_add(double *address, double val) { - unsigned long long int *address_as_ull = (unsigned long long int *) address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); -} - -inline __device__ float my_atomic_add(float *address, float val) { - return (atomicAdd(address, val)); -} - -namespace kmeans { -namespace detail { - -template -__device__ __forceinline__ -void update_centroid(int label, int dimension, int d, T accumulator, - T *centroids, int count, int *counts) { - int index = label * d + dimension; - T *target = centroids + index; - my_atomic_add(target, accumulator); - if (dimension == 0) { - atomicAdd(counts + label, count); - } -} - -template -__global__ void calculate_centroids(int n, int d, int k, T *data, - int *ordered_labels, int *ordered_indices, T *centroids, int *counts) { - int in_flight = blockDim.y * gridDim.y; - int labels_per_row = (n - 1) / in_flight + 1; - for (int dimension = threadIdx.x; dimension < d; dimension += blockDim.x) { - T accumulator = 0; - int count = 0; - int global_id = threadIdx.y + blockIdx.y * blockDim.y; - int start = global_id * labels_per_row; - int end = (global_id + 1) * labels_per_row; - end = (end > n) ? n : end; - int prior_label; - if (start < n) { - prior_label = ordered_labels[start]; - - for (int label_number = start; label_number < end; label_number++) { - int label = ordered_labels[label_number]; - if (label != prior_label) { - update_centroid(prior_label, dimension, d, accumulator, - centroids, count, counts); - accumulator = 0; - count = 0; - } - - T value = data[dimension + ordered_indices[label_number] * d]; - accumulator += value; - prior_label = label; - count++; - } - update_centroid(prior_label, dimension, d, accumulator, centroids, - count, counts); - } - } -} - -template -__global__ void revert_zeroed_centroids(int d, int k, T *tmp_centroids, - T *centroids, int *counts) { - int global_id_x = threadIdx.x + blockIdx.x * blockDim.x; - int global_id_y = threadIdx.y + blockIdx.y * blockDim.y; - if ((global_id_x < d) && (global_id_y < k)) { - if (counts[global_id_y] < 1) { - centroids[global_id_x + d * global_id_y] = tmp_centroids[global_id_x - + d * global_id_y]; - } - } -} - -template -__global__ void scale_centroids(int d, int k, int *counts, T *centroids) { - int global_id_x = threadIdx.x + blockIdx.x * blockDim.x; - int global_id_y = threadIdx.y + blockIdx.y * blockDim.y; - if ((global_id_x < d) && (global_id_y < k)) { - int count = counts[global_id_y]; - //To avoid introducing divide by zero errors - //If a centroid has no weight, we'll do no normalization - //This will keep its coordinates defined. - if (count < 1) { - count = 1; - } - T scale = 1.0 / T(count); - centroids[global_id_x + d * global_id_y] *= scale; - } -} - -// Scale - should be true when running on a single GPU -template -void find_centroids(int q, int n, int d, int k, int k_max, - thrust::device_vector &data, thrust::device_vector &labels, - thrust::device_vector ¢roids, thrust::device_vector &range, - thrust::device_vector &indices, thrust::device_vector &counts, - bool scale) { - int dev_num; - cudaGetDevice(&dev_num); - - // If no scaling then this step will be handled on the host - // when aggregating centroids from all GPUs - // Cache original centroids in case some centroids are not present in labels - // and would get zeroed - thrust::device_vector tmp_centroids; - if (scale) { - tmp_centroids = thrust::device_vector(k_max * d); - memcpy(tmp_centroids, centroids); - } - - memcpy(indices, range); - // TODO the rest of the algorithm doesn't necessarily require labels/data to be sorted - // but *might* make if faster due to less atomic updates - thrust::sort_by_key(labels.begin(), labels.end(), indices.begin()); - // TODO cub is faster but sort_by_key_int isn't sorting, possibly a bug -// mycub::sort_by_key_int(labels, indices); - -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif - - // Need to zero this - the algo uses this array to accumulate values for each centroid - // which are then averaged to get a new centroid - memzero(centroids); - memzero(counts); - - //Calculate centroids - int n_threads_x = 64; // TODO FIXME - int n_threads_y = 16; // TODO FIXME - //XXX Number of blocks here is hard coded at 30 - //This should be taken care of more thoughtfully. - calculate_centroids<<>>(n, d, k, thrust::raw_pointer_cast(data.data()), - thrust::raw_pointer_cast(labels.data()), - thrust::raw_pointer_cast(indices.data()), - thrust::raw_pointer_cast(centroids.data()), - thrust::raw_pointer_cast(counts.data())); - -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif - - // Scaling should take place on the GPU if n_gpus=1 so we don't - // move centroids and counts between host and device all the time for nothing - if (scale) { - // Revert only if `scale`, otherwise this will be taken care of in the host - // Revert reverts centroids for which count is equal 0 - revert_zeroed_centroids<<>>(d, k, - thrust::raw_pointer_cast(tmp_centroids.data()), - thrust::raw_pointer_cast(centroids.data()), - thrust::raw_pointer_cast(counts.data())); - -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif - - //Averages the centroids - scale_centroids<<>>(d, k, - thrust::raw_pointer_cast(counts.data()), - thrust::raw_pointer_cast(centroids.data())); -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif - } -} - -} -} diff --git a/cuML/src/kmeans/kmeans_general.h b/cuML/src/kmeans/kmeans_general.h deleted file mode 100644 index 4c5a558338..0000000000 --- a/cuML/src/kmeans/kmeans_general.h +++ /dev/null @@ -1,26 +0,0 @@ -/*! - * Copyright 2017-2018 H2O.ai, Inc. - * License Apache License Version 2.0 (see LICENSE for details) - */ -#pragma once -#include "logger.h" -#define MAX_NGPUS 16 - -#define VERBOSE 0 -#define CHECK 1 -#define DEBUGKMEANS 0 - -// TODO(pseudotensor): Avoid throw for python exception handling. Need to avoid all exit's and return exit code all the way back. -#define gpuErrchk(ans) { gpu_assert((ans), __FILE__, __LINE__); } -#define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__); -#define safe_cublas(ans) throw_on_cublas_error((ans), __FILE__, __LINE__); - -#define CUDACHECK(cmd) do { \ - cudaError_t e = cmd; \ - if( e != cudaSuccess ) { \ - printf("Cuda failure %s:%d '%s'\n", \ - __FILE__,__LINE__,cudaGetErrorString(e));\ - fflush( stdout ); \ - exit(EXIT_FAILURE); \ - } \ - } while(0) diff --git a/cuML/src/kmeans/kmeans_impl.h b/cuML/src/kmeans/kmeans_impl.h deleted file mode 100644 index 9b3b943188..0000000000 --- a/cuML/src/kmeans/kmeans_impl.h +++ /dev/null @@ -1,256 +0,0 @@ -/*! - * Modifications Copyright 2017-2018 H2O.ai, Inc. - */ -// original code from https://github.com/NVIDIA/kmeans (Apache V2.0 License) -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include "kmeans_centroids.h" -#include "kmeans_labels.h" -#include "kmeans_general.h" - -namespace kmeans { - -//! kmeans clusters data into k groups -/*! - - \param n Number of data points - \param d Number of dimensions - \param k Number of clusters - \param data Data points, in row-major order. This vector must have - size n * d, and since it's in row-major order, data point x occupies - positions [x * d, (x + 1) * d) in the vector. The vector is passed - by reference since it is shared with the caller and not copied. - \param labels Cluster labels. This vector has size n. - The vector is passed by reference since it is shared with the caller - and not copied. - \param centroids Centroid locations, in row-major order. This - vector must have size k * d, and since it's in row-major order, - centroid x occupies positions [x * d, (x + 1) * d) in the - vector. The vector is passed by reference since it is shared - with the caller and not copied. - \param threshold This controls early termination of the kmeans - iterations. If the ratio of points being reassigned to a different - centroid is less than the threshold, than the iterations are - terminated. Defaults to 1e-3. - \param max_iterations Maximum number of iterations to run - \return The number of iterations actually performed. - */ - -template -int kmeans(int verbose, volatile std::atomic_int *flag, int n, int d, int k, - int k_max, thrust::device_vector **data, - thrust::device_vector **labels, - thrust::device_vector **centroids, - thrust::device_vector **data_dots, std::vector dList, int n_gpu, - int max_iterations, double threshold = 1e-3, bool do_per_iter_check = - true) { - - thrust::device_vector *centroid_dots[n_gpu]; - thrust::device_vector *labels_copy[n_gpu]; - thrust::device_vector *range[n_gpu]; - thrust::device_vector *indices[n_gpu]; - thrust::device_vector *counts[n_gpu]; - thrust::device_vector d_old_centroids; - - thrust::host_vector h_counts(k); - thrust::host_vector h_counts_tmp(k); - thrust::host_vector h_centroids(k * d); - h_centroids = *centroids[0]; // all should be equal - thrust::host_vector h_centroids_tmp(k_max * d); - - T *d_distance_sum[n_gpu]; - - bool unable_alloc = false; -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - log_debug(verbose, "Before kmeans() Allocation: gpu: %d", q); - - safe_cuda(cudaSetDevice(dList[q])); - safe_cuda(cudaMalloc(&d_distance_sum[q], sizeof(T))); - - try { - centroid_dots[q] = new thrust::device_vector(k); - labels_copy[q] = new thrust::device_vector(n / n_gpu); - range[q] = new thrust::device_vector(n / n_gpu); - counts[q] = new thrust::device_vector(k); - indices[q] = new thrust::device_vector(n / n_gpu); - } catch (thrust::system_error &e) { - log_error(verbose, - "Unable to allocate memory for gpu: %d | n/n_gpu: %d | k: %d | d: %d | error: %s", - q, n / n_gpu, k, d, e.what()); - unable_alloc = true; - // throw std::runtime_error(ss.str()); - } catch (std::bad_alloc &e) { - log_error(verbose, - "Unable to allocate memory for gpu: %d | n/n_gpu: %d | k: %d | d: %d | error: %s", - q, n / n_gpu, k, d, e.what()); - unable_alloc = true; - //throw std::runtime_error(ss.str()); - } - - if (!unable_alloc) { - //Create and save "range" for initializing labels - thrust::copy(thrust::counting_iterator(0), - thrust::counting_iterator(n / n_gpu), - (*range[q]).begin()); - } - } - - if (unable_alloc) - return (-1); - - log_debug(verbose, "Before kmeans() Iterations"); - - int i = 0; - bool done = false; - for (; i < max_iterations; i++) { - log_verbose(verbose, "KMeans - Iteration %d/%d", i, max_iterations); - - if (*flag) - continue; - - safe_cuda(cudaSetDevice(dList[0])); - d_old_centroids = *centroids[dList[0]]; - -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - safe_cuda(cudaSetDevice(dList[q])); - - detail::batch_calculate_distances(verbose, q, n / n_gpu, d, k, - *data[q], *centroids[q], *data_dots[q], *centroid_dots[q], - [&](int n, size_t offset, thrust::device_vector &pairwise_distances) { - detail::relabel(n, k, pairwise_distances, *labels[q], offset); - }); - - log_verbose(verbose, "KMeans - Relabeled."); - - detail::memcpy(*labels_copy[q], *labels[q]); - detail::find_centroids(q, n / n_gpu, d, k, k_max, *data[q], - *labels_copy[q], *centroids[q], *range[q], *indices[q], - *counts[q], n_gpu <= 1); - } - - // Scale the centroids on host - if (n_gpu > 1) { - //Average the centroids from each device - for (int p = 0; p < k; p++) - h_counts[p] = 0.0; - for (int q = 0; q < n_gpu; q++) { - safe_cuda(cudaSetDevice(dList[q])); - detail::memcpy(h_counts_tmp, *counts[q]); - detail::streamsync(dList[q]); - for (int p = 0; p < k; p++) - h_counts[p] += h_counts_tmp[p]; - } - - // Zero the centroids only if any of the GPUs actually updated them - for (int p = 0; p < k; p++) { - for (int r = 0; r < d; r++) { - if (h_counts[p] != 0) { - h_centroids[p * d + r] = 0.0; - } - } - } - - for (int q = 0; q < n_gpu; q++) { - safe_cuda(cudaSetDevice(dList[q])); - detail::memcpy(h_centroids_tmp, *centroids[q]); - detail::streamsync(dList[q]); - for (int p = 0; p < k; p++) { - for (int r = 0; r < d; r++) { - if (h_counts[p] != 0) { - h_centroids[p * d + r] += - h_centroids_tmp[p * d + r]; - } - } - } - } - - for (int p = 0; p < k; p++) { - for (int r = 0; r < d; r++) { - // If 0 counts that means we leave the original centroids - if (h_counts[p] == 0) { - h_counts[p] = 1; - } - h_centroids[p * d + r] /= h_counts[p]; - } - } - - //Copy the averaged centroids to each device -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - safe_cuda(cudaSetDevice(dList[q])); - detail::memcpy(*centroids[q], h_centroids); - } - } - - // whether to perform per iteration check - if (do_per_iter_check) { - safe_cuda(cudaSetDevice(dList[0])); - - T - squared_norm = thrust::inner_product( - d_old_centroids.begin(), d_old_centroids.end(), - (*centroids[0]).begin(), - (T) 0.0, - thrust::plus(), - [=]__device__(T left, T right) { - T diff = left - right; - return diff * diff; - } - ); - - if (squared_norm < threshold) { - if (verbose) { - std::cout << "Threshold triggered. Terminating early." - << std::endl; - } - done = true; - } - } - - if (*flag) { - fprintf(stderr, "Signal caught. Terminated early.\n"); - fflush(stderr); - *flag = 0; // set flag - done = true; - } - - if (done || i == max_iterations - 1) { - // Final relabeling - uses final centroids -#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - safe_cuda(cudaSetDevice(dList[q])); - detail::batch_calculate_distances(verbose, q, n / n_gpu, d, k, - *data[q], *centroids[q], *data_dots[q], - *centroid_dots[q], - [&](int n, size_t offset, thrust::device_vector &pairwise_distances) { - detail::relabel(n, k, pairwise_distances, *labels[q], offset); - }); - } - break; - } - } - -//#pragma omp parallel for - for (int q = 0; q < n_gpu; q++) { - safe_cuda(cudaSetDevice(dList[q])); - delete (centroid_dots[q]); - delete (labels_copy[q]); - delete (range[q]); - delete (counts[q]); - delete (indices[q]); - } - - log_debug(verbose, "Iterations: %d", i); - - return i; -} - -} diff --git a/cuML/src/kmeans/kmeans_labels.h b/cuML/src/kmeans/kmeans_labels.h deleted file mode 100644 index 0270586874..0000000000 --- a/cuML/src/kmeans/kmeans_labels.h +++ /dev/null @@ -1,907 +0,0 @@ -/*! - * Modifications Copyright 2017-2018 H2O.ai, Inc. - */ -// original code from https://github.com/NVIDIA/kmeans (Apache V2.0 License) -#pragma once -#include -#include "cub/cub.cuh" -#include -#include -#include -#include -#include -#include "kmeans_general.h" -#include -#include - -inline void gpu_assert(cudaError_t code, const char *file, int line, - bool abort = true) { - if (code != cudaSuccess) { - fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, - line); - std::stringstream ss; - ss << file << "(" << line << ")"; - std::string file_and_line; - ss >> file_and_line; - thrust::system_error(code, thrust::cuda_category(), file_and_line); - } -} - -inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, - int line) { - if (code != cudaSuccess) { - std::stringstream ss; - ss << file << "(" << line << ")"; - std::string file_and_line; - ss >> file_and_line; - thrust::system_error(code, thrust::cuda_category(), file_and_line); - } - - return code; -} - -#ifdef CUBLAS_API_H_ -// cuBLAS API errors -static const char *cudaGetErrorEnum(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - } - - return ""; -} -#endif - -inline cublasStatus_t throw_on_cublas_error(cublasStatus_t code, - const char *file, int line) { - - if (code != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, "cublas error: %s %s %d\n", cudaGetErrorEnum(code), - file, line); - std::stringstream ss; - ss << file << "(" << line << ")"; - std::string file_and_line; - ss >> file_and_line; - thrust::system_error(code, thrust::cuda_category(), file_and_line); - } - - return code; -} - -extern cudaStream_t cuda_stream[MAX_NGPUS]; - -template -extern __global__ void debugMark() { -} -; - -namespace kmeans { -namespace detail { - -void labels_init(); -void labels_close(); - -extern cublasHandle_t cublas_handle[MAX_NGPUS]; -template -void memcpy(thrust::host_vector > &H, - thrust::device_vector > &D) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - safe_cuda( - cudaMemcpyAsync(thrust::raw_pointer_cast(H.data()), - thrust::raw_pointer_cast(D.data()), sizeof(T) * D.size(), - cudaMemcpyDeviceToHost, cuda_stream[dev_num])); -} - -template -void memcpy(thrust::device_vector > &D, - thrust::host_vector > &H) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - safe_cuda( - cudaMemcpyAsync(thrust::raw_pointer_cast(D.data()), - thrust::raw_pointer_cast(H.data()), sizeof(T) * H.size(), - cudaMemcpyHostToDevice, cuda_stream[dev_num])); -} -template -void memcpy(thrust::device_vector > &Do, - thrust::device_vector > &Di) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - safe_cuda( - cudaMemcpyAsync(thrust::raw_pointer_cast(Do.data()), - thrust::raw_pointer_cast(Di.data()), sizeof(T) * Di.size(), - cudaMemcpyDeviceToDevice, cuda_stream[dev_num])); -} -template -void memzero(thrust::device_vector >& D) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - safe_cuda( - cudaMemsetAsync(thrust::raw_pointer_cast(D.data()), 0, - sizeof(T) * D.size(), cuda_stream[dev_num])); -} -void streamsync(int dev_num); - -//n: number of points -//d: dimensionality of points -//data: points, laid out in row-major order (n rows, d cols) -//dots: result vector (n rows) -// NOTE: -//Memory accesses in this function are uncoalesced!! -//This is because data is in row major order -//However, in k-means, it's called outside the optimization loop -//on the large data array, and inside the optimization loop it's -//called only on a small array, so it doesn't really matter. -//If this becomes a performance limiter, transpose the data somewhere -template -__global__ void self_dots(int n, int d, T* data, T* dots) { - T accumulator = 0; - int global_id = blockDim.x * blockIdx.x + threadIdx.x; - - if (global_id < n) { - for (int i = 0; i < d; i++) { - T value = data[i + global_id * d]; - accumulator += value * value; - } - dots[global_id] = accumulator; - } -} - -template -void make_self_dots(int n, int d, thrust::device_vector& data, - thrust::device_vector& dots) { - int dev_num; -#define MAX_BLOCK_THREADS0 256 - const int GRID_SIZE = (n - 1) / MAX_BLOCK_THREADS0 + 1; - safe_cuda(cudaGetDevice(&dev_num)); - self_dots<<>>(n, d, - thrust::raw_pointer_cast(data.data()), - thrust::raw_pointer_cast(dots.data())); -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif - -} - -#define MAX_BLOCK_THREADS 32 -template -__global__ void all_dots(int n, int k, T* data_dots, T* centroid_dots, - T* dots) { - __shared__ T local_data_dots[MAX_BLOCK_THREADS]; - __shared__ T local_centroid_dots[MAX_BLOCK_THREADS]; - // if(threadIdx.x==0 && threadIdx.y==0 && blockIdx.x==0) printf("inside %d %d %d\n",threadIdx.x,blockIdx.x,blockDim.x); - - int data_index = threadIdx.x + blockIdx.x * blockDim.x; - if ((data_index < n) && (threadIdx.y == 0)) { - local_data_dots[threadIdx.x] = data_dots[data_index]; - } - - int centroid_index = threadIdx.x + blockIdx.y * blockDim.y; - if ((centroid_index < k) && (threadIdx.y == 1)) { - local_centroid_dots[threadIdx.x] = centroid_dots[centroid_index]; - } - - __syncthreads(); - - centroid_index = threadIdx.y + blockIdx.y * blockDim.y; - // printf("data_index=%d centroid_index=%d\n",data_index,centroid_index); - if ((data_index < n) && (centroid_index < k)) { - dots[data_index + centroid_index * n] = local_data_dots[threadIdx.x] - + local_centroid_dots[threadIdx.y]; - } -} - -template -void make_all_dots(int n, int k, size_t offset, - thrust::device_vector& data_dots, - thrust::device_vector& centroid_dots, - thrust::device_vector& dots) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - const int BLOCK_THREADSX = MAX_BLOCK_THREADS; // BLOCK_THREADSX*BLOCK_THREADSY<=1024 on modern arch's (sm_61) - const int BLOCK_THREADSY = MAX_BLOCK_THREADS; - const int GRID_SIZEX = (n - 1) / BLOCK_THREADSX + 1; // on old arch's this has to be less than 2^16=65536 - const int GRID_SIZEY = (k - 1) / BLOCK_THREADSY + 1; // this has to be less than 2^16=65536 - // printf("pre all_dots: %d %d %d %d\n",GRID_SIZEX,GRID_SIZEY,BLOCK_THREADSX,BLOCK_THREADSY); fflush(stdout); - all_dots<<>>(n, - k, thrust::raw_pointer_cast(data_dots.data() + offset), - thrust::raw_pointer_cast(centroid_dots.data()), - thrust::raw_pointer_cast(dots.data())); -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif -} -; - -#define WARP_SIZE 32 -#define BLOCK_SIZE 1024 -#define BSIZE_DIV_WSIZE (BLOCK_SIZE/WARP_SIZE) -#define IDX(i,j,lda) ((i)+(j)*(lda)) -template -__global__ void rejectByCosines(int k, int *accept, T *dists, - T *centroid_dots) { - - // Global indices - int gidx, gidy; - - // Lengths from cosine_dots - float lenA, lenB; - // Threshold - float thresh = 0.9; - - // Observation vector is determined by global y-index - gidy = threadIdx.y + blockIdx.y * blockDim.y; - while (gidy < k) { - // Get lengths from global memory, stored in centroid_dots - lenA = centroid_dots[gidy]; - - gidx = threadIdx.x + blockIdx.x * blockDim.x; - while (gidx < gidy) { - lenB = centroid_dots[gidx]; - if (lenA > 1e-8 && lenB > 1e-8) - dists[IDX(gidx, gidy, k)] /= lenA * lenB; - if (dists[IDX(gidx, gidy, k)] > thresh - && ((lenA < 2.0 * lenB) && (lenB < 2.0 * lenA))) - accept[gidy] = 0; - gidx += blockDim.x * gridDim.x; - } - // Move to another centroid - gidy += blockDim.y * gridDim.y; - } - -} - -template -void checkCosine(int d, int k, int *numChosen, thrust::device_vector &dists, - thrust::device_vector ¢roids, - thrust::device_vector ¢roid_dots) { - - dim3 blockDim, gridDim; - - int h_accept[k]; - - thrust::device_vector accept(k); - thrust::fill(accept.begin(), accept.begin() + k, 1); - //printf("after fill accept\n"); - // Divide dists by centroid lengths to get cosine matrix - blockDim.x = WARP_SIZE; - blockDim.y = BLOCK_SIZE / WARP_SIZE; - blockDim.z = 1; - gridDim.x = min((k + WARP_SIZE - 1) / WARP_SIZE, 65535); - gridDim.y = min((k + BSIZE_DIV_WSIZE - 1) / BSIZE_DIV_WSIZE, 65535); - gridDim.z = 1; - - rejectByCosines<<>>(k, - thrust::raw_pointer_cast(accept.data()), - thrust::raw_pointer_cast(dists.data()), - thrust::raw_pointer_cast(centroid_dots.data())); - -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif - - *numChosen = thrust::reduce(accept.begin(), accept.begin() + k); - - CUDACHECK( - cudaMemcpy(h_accept, thrust::raw_pointer_cast(accept.data()), - k * sizeof(int), cudaMemcpyDeviceToHost)); - - int skipcopy = 1; - for (int z = 0; z < *numChosen; ++z) { - if (h_accept[z] == 0) - skipcopy = 0; - } - - if (!skipcopy && (*numChosen > 1 && *numChosen < k)) { - int i, j; - int* candidate_map = new int[d * (*numChosen)]; - j = 0; - for (i = 0; i < k; ++i) { - if (h_accept[i]) { - for (int m = 0; m < d; ++m) - candidate_map[j * d + m] = i * d + m; - j += 1; - } - } - - thrust::device_vector d_candidate_map(d * (*numChosen)); - CUDACHECK(cudaMemcpy(thrust::raw_pointer_cast(d_candidate_map.data()), candidate_map, d*(*numChosen)*sizeof(int), cudaMemcpyHostToDevice)) - ; - - thrust::device_vector cent_copy(dists); - - thrust::copy_n(centroids.begin(), d * k, cent_copy.begin()); - -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif - - // Gather accepted centroid candidates into centroid memory - thrust::gather(d_candidate_map.begin(), - d_candidate_map.begin() + d * (*numChosen), cent_copy.begin(), - centroids.begin()); - } -} - -template -void calculate_distances(int verbose, int q, size_t n, int d, int k, - thrust::device_vector& data, size_t data_offset, - thrust::device_vector& centroids, - thrust::device_vector& data_dots, - thrust::device_vector& centroid_dots, - thrust::device_vector& pairwise_distances); - -template -void batch_calculate_distances(int verbose, int q, size_t n, int d, int k, - thrust::device_vector &data, thrust::device_vector ¢roids, - thrust::device_vector &data_dots, - thrust::device_vector ¢roid_dots, F functor) { - int fudges_size = 4; - double fudges[] = { 1.0, 0.75, 0.5, 0.25 }; - for (const double fudge : fudges) { - try { - // Get info about available memory - // This part of the algo can be very memory consuming - // We might need to batch it - size_t free_byte; - size_t total_byte; - CUDACHECK(cudaMemGetInfo(&free_byte, &total_byte)); - free_byte = free_byte * fudge; - - size_t required_byte = n * k * sizeof(T); - - size_t runs = std::ceil(required_byte / (double) free_byte); - - log_verbose(verbose, - "Batch calculate distance - Rows %ld | K %ld | Data size %d", - n, k, sizeof(T)); - - log_verbose(verbose, - "Batch calculate distance - Free memory %zu | Required memory %zu | Runs %d", - free_byte, required_byte, runs); - - size_t offset = 0; - size_t rows_per_run = n / runs; - thrust::device_vector pairwise_distances(rows_per_run * k); - - for (int run = 0; run < runs; run++) { - if (run + 1 == runs && n % rows_per_run != 0) { - rows_per_run = n % rows_per_run; - } - - thrust::fill_n(pairwise_distances.begin(), - pairwise_distances.size(), (T) 0.0); - - log_verbose(verbose, "Batch calculate distance - Allocated"); - - kmeans::detail::calculate_distances(verbose, 0, rows_per_run, d, - k, data, offset, centroids, data_dots, centroid_dots, - pairwise_distances); - - log_verbose(verbose, - "Batch calculate distance - Distances calculated"); - - functor(rows_per_run, offset, pairwise_distances); - - log_verbose(verbose, "Batch calculate distance - Functor ran"); - - offset += rows_per_run; - } - } catch (const std::bad_alloc& e) { - cudaGetLastError(); - if (fudges[fudges_size - 1] != fudge) { - log_warn(verbose, - "Batch calculate distance - Failed to allocate memory for pairwise distances - retrying."); - continue; - } else { - log_error(verbose, - "Batch calculate distance - Failed to allocate memory for pairwise distances - exiting."); - throw e; - } - } - - return; - } -} - -template -__global__ void make_new_labels(int n, int k, T* pairwise_distances, - int* labels) { - T min_distance = FLT_MAX; //std::numeric_limits::max(); // might be ok TODO FIXME - T min_idx = -1; - int global_id = threadIdx.x + blockIdx.x * blockDim.x; - if (global_id < n) { - for (int c = 0; c < k; c++) { - T distance = pairwise_distances[c * n + global_id]; - if (distance < min_distance) { - min_distance = distance; - min_idx = c; - } - } - labels[global_id] = min_idx; - } -} - -template -void relabel(int n, int k, thrust::device_vector& pairwise_distances, - thrust::device_vector& labels, size_t offset) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); -#define MAX_BLOCK_THREADS2 256 - const int GRID_SIZE = (n - 1) / MAX_BLOCK_THREADS2 + 1; - make_new_labels<<>>( - n, k, thrust::raw_pointer_cast(pairwise_distances.data()), - thrust::raw_pointer_cast(labels.data() + offset)); -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif -} - -} -} -namespace mycub { - -extern void *d_key_alt_buf[MAX_NGPUS]; -extern unsigned int key_alt_buf_bytes[MAX_NGPUS]; -extern void *d_value_alt_buf[MAX_NGPUS]; -extern unsigned int value_alt_buf_bytes[MAX_NGPUS]; -extern void *d_temp_storage[MAX_NGPUS]; -extern size_t temp_storage_bytes[MAX_NGPUS]; -extern void *d_temp_storage2[MAX_NGPUS]; -extern size_t temp_storage_bytes2[MAX_NGPUS]; -extern bool cub_initted; - -void sort_by_key_int(thrust::device_vector& keys, - thrust::device_vector& values); - -template -void sort_by_key(thrust::device_vector& keys, - thrust::device_vector& values) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - cudaStream_t this_stream = cuda_stream[dev_num]; - int SIZE = keys.size(); - if (key_alt_buf_bytes[dev_num] < sizeof(T) * SIZE) { - if (d_key_alt_buf[dev_num]) - safe_cuda(cudaFree(d_key_alt_buf[dev_num])); - safe_cuda(cudaMalloc(&d_key_alt_buf[dev_num], sizeof(T) * SIZE)); - key_alt_buf_bytes[dev_num] = sizeof(T) * SIZE; - std::cout << "Malloc key_alt_buf" << std::endl; - } - if (value_alt_buf_bytes[dev_num] < sizeof(U) * SIZE) { - if (d_value_alt_buf[dev_num]) - safe_cuda(cudaFree(d_value_alt_buf[dev_num])); - safe_cuda(cudaMalloc(&d_value_alt_buf[dev_num], sizeof(U) * SIZE)); - value_alt_buf_bytes[dev_num] = sizeof(U) * SIZE; - std::cout << "Malloc value_alt_buf" << std::endl; - } - cub::DoubleBuffer d_keys(thrust::raw_pointer_cast(keys.data()), - (T*) d_key_alt_buf[dev_num]); - cub::DoubleBuffer d_values(thrust::raw_pointer_cast(values.data()), - (U*) d_value_alt_buf[dev_num]); - cudaError_t err; - - // Determine temporary device storage requirements for sorting operation - //if (temp_storage_bytes[dev_num] == 0) { - void *d_temp; - size_t temp_bytes; - err = cub::DeviceRadixSort::SortPairs(d_temp_storage[dev_num], temp_bytes, - d_keys, d_values, SIZE, 0, sizeof(T) * 8, this_stream); - // Allocate temporary storage for sorting operation - safe_cuda(cudaMalloc(&d_temp, temp_bytes)); - d_temp_storage[dev_num] = d_temp; - temp_storage_bytes[dev_num] = temp_bytes; - std::cout << "Malloc temp_storage. " << temp_storage_bytes[dev_num] - << " bytes" << std::endl; - std::cout << "d_temp_storage[" << dev_num << "] = " - << d_temp_storage[dev_num] << std::endl; - if (err) { - std::cout << "Error " << err << " in SortPairs 1" << std::endl; - std::cout << cudaGetErrorString(err) << std::endl; - } - //} - // Run sorting operation - err = cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys, d_values, - SIZE, 0, sizeof(T) * 8, this_stream); - if (err) - std::cout << "Error in SortPairs 2" << std::endl; - //cub::DeviceRadixSort::SortPairs(d_temp_storage[dev_num], temp_storage_bytes[dev_num], d_keys, - // d_values, SIZE, 0, sizeof(T)*8, this_stream); - -} -template -void sum_reduce(thrust::device_vector& values, T* sum) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - if (!d_temp_storage2[dev_num]) { - cub::DeviceReduce::Sum(d_temp_storage2[dev_num], - temp_storage_bytes2[dev_num], - thrust::raw_pointer_cast(values.data()), sum, values.size(), - cuda_stream[dev_num]); - // Allocate temporary storage for sorting operation - safe_cuda( - cudaMalloc(&d_temp_storage2[dev_num], - temp_storage_bytes2[dev_num])); - } - cub::DeviceReduce::Sum(d_temp_storage2[dev_num], - temp_storage_bytes2[dev_num], - thrust::raw_pointer_cast(values.data()), sum, values.size(), - cuda_stream[dev_num]); -} -void cub_init(); -void cub_close(); - -void cub_init(int dev); -void cub_close(int dev); -} - -namespace kmeans { -namespace detail { - -template -struct absolute_value { - __host__ __device__ - void operator()(T &x) const { - x = (x > 0 ? x : -x); - } -}; - -cublasHandle_t cublas_handle[MAX_NGPUS]; - -void labels_init() { - cublasStatus_t stat; - cudaError_t err; - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - stat = cublasCreate(&detail::cublas_handle[dev_num]); - if (stat != CUBLAS_STATUS_SUCCESS) { - std::cout << "CUBLAS initialization failed" << std::endl; - exit(1); - } - err = safe_cuda(cudaStreamCreate(&cuda_stream[dev_num])) - ; - if (err != cudaSuccess) { - std::cout << "Stream creation failed" << std::endl; - - } - cublasSetStream(cublas_handle[dev_num], cuda_stream[dev_num]); - mycub::cub_init(dev_num); -} - -void labels_close() { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - safe_cublas(cublasDestroy(cublas_handle[dev_num])); - safe_cuda(cudaStreamDestroy(cuda_stream[dev_num])); - mycub::cub_close(dev_num); -} - -void streamsync(int dev_num) { - cudaStreamSynchronize(cuda_stream[dev_num]); -} - -/** - * Matrix multiplication: alpha * A^T * B + beta * C - * Optimized for tall and skinny matrices - * - * @tparam float_t - * @param A - * @param B - * @param C - * @param alpha - * @param beta - * @param n - * @param d - * @param k - * @param max_block_rows - * @return - */ -template -__global__ void matmul(const float_t *A, const float_t *B, float_t *C, - const float_t alpha, const float_t beta, int n, int d, int k, - int max_block_rows) { - - constexpr int smem_alignment = sizeof(float_t) > 8 ? sizeof(float_t) : 8; - extern __shared__ __align__(smem_alignment) unsigned char my_smem[]; - float_t *shared = reinterpret_cast(my_smem); - - float_t *s_A = shared; - float_t *s_B = shared + max_block_rows * d; - - for (int i = threadIdx.x; i < d * k; i += blockDim.x) { - s_B[i] = B[i]; - } - - size_t block_start_row_index = blockIdx.x * max_block_rows; - size_t block_rows = max_block_rows; - - if (blockIdx.x == gridDim.x - 1 && n % max_block_rows != 0) { - block_rows = n % max_block_rows; - } - - for (size_t i = threadIdx.x; i < d * block_rows; i += blockDim.x) { - s_A[i] = alpha * A[d * block_start_row_index + i]; - } - - __syncthreads(); - - float_t elem_c = 0; - - int col_c = threadIdx.x % k; - size_t abs_row_c = block_start_row_index + threadIdx.x / k; - int row_c = threadIdx.x / k; - - // Thread/Block combination either too far for data array - // Or is calculating for index that should be calculated in a different blocks - in some edge cases - // "col_c * n + abs_row_c" can yield same result in different thread/block combinations - if (abs_row_c >= n || threadIdx.x >= block_rows * k) { - return; - } - - for (size_t i = 0; i < d; i++) { - elem_c += s_B[d * col_c + i] * s_A[d * row_c + i]; - } - - C[col_c * n + abs_row_c] = beta * C[col_c * n + abs_row_c] + elem_c; - -} - -template<> -void calculate_distances(int verbose, int q, size_t n, int d, int k, - thrust::device_vector &data, size_t data_offset, - thrust::device_vector ¢roids, - thrust::device_vector &data_dots, - thrust::device_vector ¢roid_dots, - thrust::device_vector &pairwise_distances) { - detail::make_self_dots(k, d, centroids, centroid_dots); - detail::make_all_dots(n, k, data_offset, data_dots, centroid_dots, - pairwise_distances); - - //||x-y||^2 = ||x||^2 + ||y||^2 - 2 x . y - //pairwise_distances has ||x||^2 + ||y||^2, so beta = 1 - //The dgemm calculates x.y for all x and y, so alpha = -2.0 - double alpha = -2.0; - double beta = 1.0; - //If the data were in standard column major order, we'd do a - //centroids * data ^ T - //But the data is in row major order, so we have to permute - //the arguments a little - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - - bool do_cublas = true; - if (k <= 16 && d <= 64) { - const int BLOCK_SIZE_MUL = 128; - int block_rows = std::min((size_t) BLOCK_SIZE_MUL / k, n); - int grid_size = std::ceil(static_cast(n) / block_rows); - - int shared_size_B = d * k * sizeof(double); - size_t shared_size_A = block_rows * d * sizeof(double); - if (shared_size_B + shared_size_A < (1 << 15)) { - - matmul<<>>( - thrust::raw_pointer_cast(data.data() + data_offset * d), - thrust::raw_pointer_cast(centroids.data()), - thrust::raw_pointer_cast(pairwise_distances.data()), alpha, - beta, n, d, k, block_rows); - do_cublas = false; - } - } - - if (do_cublas) { - cublasStatus_t stat = - safe_cublas( - cublasDgemm(detail::cublas_handle[dev_num], CUBLAS_OP_T, CUBLAS_OP_N, n, k, d, &alpha, thrust::raw_pointer_cast(data.data() + data_offset * d), d, //Has to be n or d - thrust::raw_pointer_cast(centroids.data()), d,//Has to be k or d - &beta, thrust::raw_pointer_cast(pairwise_distances.data()), n)) - ; //Has to be n or k - - if (stat != CUBLAS_STATUS_SUCCESS) { - std::cout << "Invalid Dgemm" << std::endl; - exit(1); - } - } - - thrust::for_each(pairwise_distances.begin(), pairwise_distances.end(), - absolute_value()); // in-place transformation to ensure all distances are positive indefinite - -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif -} - -template<> -void calculate_distances(int verbose, int q, size_t n, int d, int k, - thrust::device_vector &data, size_t data_offset, - thrust::device_vector ¢roids, - thrust::device_vector &data_dots, - thrust::device_vector ¢roid_dots, - thrust::device_vector &pairwise_distances) { - detail::make_self_dots(k, d, centroids, centroid_dots); - detail::make_all_dots(n, k, data_offset, data_dots, centroid_dots, - pairwise_distances); - - //||x-y||^2 = ||x||^2 + ||y||^2 - 2 x . y - //pairwise_distances has ||x||^2 + ||y||^2, so beta = 1 - //The dgemm calculates x.y for all x and y, so alpha = -2.0 - float alpha = -2.0; - float beta = 1.0; - //If the data were in standard column major order, we'd do a - //centroids * data ^ T - //But the data is in row major order, so we have to permute - //the arguments a little - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - - if (k <= 16 && d <= 64) { - const int BLOCK_SIZE_MUL = 128; - int block_rows = std::min((size_t) BLOCK_SIZE_MUL / k, n); - int grid_size = std::ceil(static_cast(n) / block_rows); - - int shared_size_B = d * k * sizeof(float); - int shared_size_A = block_rows * d * sizeof(float); - - matmul<<>>( - thrust::raw_pointer_cast(data.data() + data_offset * d), - thrust::raw_pointer_cast(centroids.data()), - thrust::raw_pointer_cast(pairwise_distances.data()), alpha, - beta, n, d, k, block_rows); - } else { - cublasStatus_t stat = - safe_cublas( - cublasSgemm(detail::cublas_handle[dev_num], CUBLAS_OP_T, CUBLAS_OP_N, n, k, d, &alpha, thrust::raw_pointer_cast(data.data() + data_offset * d), d, //Has to be n or d - thrust::raw_pointer_cast(centroids.data()), d,//Has to be k or d - &beta, thrust::raw_pointer_cast(pairwise_distances.data()), n)) - ; //Has to be n or k - - if (stat != CUBLAS_STATUS_SUCCESS) { - std::cout << "Invalid Sgemm" << std::endl; - exit(1); - } - } - - thrust::for_each(pairwise_distances.begin(), pairwise_distances.end(), - absolute_value()); // in-place transformation to ensure all distances are positive indefinite - -#if(CHECK) - gpuErrchk(cudaGetLastError()); -#endif -} - -} -} - -namespace mycub { - -void *d_key_alt_buf[MAX_NGPUS]; -unsigned int key_alt_buf_bytes[MAX_NGPUS]; -void *d_value_alt_buf[MAX_NGPUS]; -unsigned int value_alt_buf_bytes[MAX_NGPUS]; -void *d_temp_storage[MAX_NGPUS]; -size_t temp_storage_bytes[MAX_NGPUS]; -void *d_temp_storage2[MAX_NGPUS]; -size_t temp_storage_bytes2[MAX_NGPUS]; -bool cub_initted; -void cub_init() { - // std::cout <<"CUB init" << std::endl; - for (int q = 0; q < MAX_NGPUS; q++) { - d_key_alt_buf[q] = NULL; - key_alt_buf_bytes[q] = 0; - d_value_alt_buf[q] = NULL; - value_alt_buf_bytes[q] = 0; - d_temp_storage[q] = NULL; - temp_storage_bytes[q] = 0; - d_temp_storage2[q] = NULL; - temp_storage_bytes2[q] = 0; - } - cub_initted = true; -} - -void cub_init(int dev) { - d_key_alt_buf[dev] = NULL; - key_alt_buf_bytes[dev] = 0; - d_value_alt_buf[dev] = NULL; - value_alt_buf_bytes[dev] = 0; - d_temp_storage[dev] = NULL; - temp_storage_bytes[dev] = 0; - d_temp_storage2[dev] = NULL; - temp_storage_bytes2[dev] = 0; -} - -void cub_close() { - for (int q = 0; q < MAX_NGPUS; q++) { - if (d_key_alt_buf[q]) - safe_cuda(cudaFree(d_key_alt_buf[q])); - if (d_value_alt_buf[q]) - safe_cuda(cudaFree(d_value_alt_buf[q])); - if (d_temp_storage[q]) - safe_cuda(cudaFree(d_temp_storage[q])); - if (d_temp_storage2[q]) - safe_cuda(cudaFree(d_temp_storage2[q])); - d_temp_storage[q] = NULL; - d_temp_storage2[q] = NULL; - } - cub_initted = false; -} - -void cub_close(int dev) { - if (d_key_alt_buf[dev]) - safe_cuda(cudaFree(d_key_alt_buf[dev])); - if (d_value_alt_buf[dev]) - safe_cuda(cudaFree(d_value_alt_buf[dev])); - if (d_temp_storage[dev]) - safe_cuda(cudaFree(d_temp_storage[dev])); - if (d_temp_storage2[dev]) - safe_cuda(cudaFree(d_temp_storage2[dev])); - d_temp_storage[dev] = NULL; - d_temp_storage2[dev] = NULL; -} - -void sort_by_key_int(thrust::device_vector &keys, - thrust::device_vector &values) { - int dev_num; - safe_cuda(cudaGetDevice(&dev_num)); - cudaStream_t this_stream = cuda_stream[dev_num]; - int SIZE = keys.size(); - //int *d_key_alt_buf, *d_value_alt_buf; - if (key_alt_buf_bytes[dev_num] < sizeof(int) * SIZE) { - if (d_key_alt_buf[dev_num]) - safe_cuda(cudaFree(d_key_alt_buf[dev_num])); - safe_cuda(cudaMalloc(&d_key_alt_buf[dev_num], sizeof(int) * SIZE)); - key_alt_buf_bytes[dev_num] = sizeof(int) * SIZE; - } - if (value_alt_buf_bytes[dev_num] < sizeof(int) * SIZE) { - if (d_value_alt_buf[dev_num]) - safe_cuda(cudaFree(d_value_alt_buf[dev_num])); - safe_cuda(cudaMalloc(&d_value_alt_buf[dev_num], sizeof(int) * SIZE)); - value_alt_buf_bytes[dev_num] = sizeof(int) * SIZE; - } - cub::DoubleBuffer d_keys(thrust::raw_pointer_cast(keys.data()), - (int *) d_key_alt_buf[dev_num]); - cub::DoubleBuffer d_values(thrust::raw_pointer_cast(values.data()), - (int *) d_value_alt_buf[dev_num]); - - // Determine temporary device storage requirements for sorting operation - if (!d_temp_storage[dev_num]) { - cub::DeviceRadixSort::SortPairs(d_temp_storage[dev_num], - temp_storage_bytes[dev_num], d_keys, d_values, SIZE, 0, - sizeof(int) * 8, this_stream); - // Allocate temporary storage for sorting operation - safe_cuda( - cudaMalloc(&d_temp_storage[dev_num], - temp_storage_bytes[dev_num])); - } - // Run sorting operation - cub::DeviceRadixSort::SortPairs(d_temp_storage[dev_num], - temp_storage_bytes[dev_num], d_keys, d_values, SIZE, 0, - sizeof(int) * 8, this_stream); - // Sorted keys and values are referenced by d_keys.Current() and d_values.Current() - - keys.data() = thrust::device_pointer_cast(d_keys.Current()); - values.data() = thrust::device_pointer_cast(d_values.Current()); -} - -} diff --git a/cuML/src/kmeans/logger.h b/cuML/src/kmeans/logger.h deleted file mode 100644 index 75b6067aa6..0000000000 --- a/cuML/src/kmeans/logger.h +++ /dev/null @@ -1,51 +0,0 @@ -/*! - * Copyright 2017-2018 H2O.ai, Inc. - * License Apache License Version 2.0 (see LICENSE for details) - */ -#pragma once - -#include -#include -#include -#include -#include - -#define H2O4GPU_LOG_NOTHING 0 // Fatals are errors terminating the program immediately -#define H2O4GPU_LOG_FATAL 100 // Fatals are errors terminating the program immediately -#define H2O4GPU_LOG_ERROR 200 // Errors are when the program may not exit -#define H2O4GPU_LOG_INFO 300 // Info -#define H2O4GPU_LOG_WARN 400 // Warns about unwanted, but not dangerous, state/behaviour -#define H2O4GPU_LOG_DEBUG 500 // Most basic debug information -#define H2O4GPU_LOG_VERBOSE 600 // Everything possible - -#define log_fatal(desired_level, ...) log(desired_level, H2O4GPU_LOG_FATAL, __FILE__, __LINE__, __VA_ARGS__) -#define log_error(desired_level, ...) log(desired_level, H2O4GPU_LOG_ERROR, __FILE__, __LINE__, __VA_ARGS__) -#define log_info(desired_level, ...) log(desired_level, H2O4GPU_LOG_INFO, __FILE__, __LINE__, __VA_ARGS__) -#define log_warn(desired_level, ...) log(desired_level, H2O4GPU_LOG_WARN, __FILE__, __LINE__, __VA_ARGS__) -#define log_debug(desired_level, ...) log(desired_level, H2O4GPU_LOG_DEBUG, __FILE__, __LINE__, __VA_ARGS__) -#define log_verbose(desired_level, ...) log(desired_level, H2O4GPU_LOG_VERBOSE, __FILE__, __LINE__, __VA_ARGS__) - -static const char *levels[] = { "NOTHING", "FATAL", "ERROR", "INFO", "WARN", - "DEBUG", "VERBOSE" }; - -bool should_log(const int desired_lvl, const int verbosity) { - return verbosity > H2O4GPU_LOG_NOTHING && verbosity <= desired_lvl; -} - -void log(int desired_level, int level, const char *file, int line, - const char *fmt, ...) { - if (should_log(desired_level, level)) { - time_t now = time(NULL); - struct tm *local_time = localtime(&now); - - va_list args; - char buf[16]; - buf[strftime(buf, sizeof(buf), "%H:%M:%S", local_time)] = '\0'; - fprintf(stderr, "%s %-5s %s:%d: ", buf, levels[level / 100], file, - line); - va_start(args, fmt); - vfprintf(stderr, fmt, args); - va_end(args); - fprintf(stderr, "\n"); - } -} diff --git a/cuML/src/kmeans/timer.h b/cuML/src/kmeans/timer.h deleted file mode 100644 index 14daabba64..0000000000 --- a/cuML/src/kmeans/timer.h +++ /dev/null @@ -1,18 +0,0 @@ -/*! - * Modifications Copyright 2017-2018 H2O.ai, Inc. - */ -#ifndef TIMER_H_ -#define TIMER_H_ - -#include -#include - -template -T timer() { - struct timeval tv; - gettimeofday(&tv, NULL); - return static_cast(tv.tv_sec) - + static_cast(tv.tv_usec) * static_cast(1e-6); -} - -#endif // UTIL_H_ diff --git a/cuML/src/kmeans/utils.h b/cuML/src/kmeans/utils.h deleted file mode 100644 index a3b7b25b3d..0000000000 --- a/cuML/src/kmeans/utils.h +++ /dev/null @@ -1,59 +0,0 @@ -/*! - * Copyright 2017-2018 H2O.ai, Inc. - * License Apache License Version 2.0 (see LICENSE for details) - */ -#pragma once -#include -// #include "cblas.h" -#include "logger.h" - -template -void self_dot(std::vector array_in, int n, int dim, std::vector& dots) { - for (int pt = 0; pt < n; pt++) { - T sum = 0.0; - for (int i = 0; i < dim; i++) { - sum += array_in[pt * dim + i] * array_in[pt * dim + i]; - } - dots[pt] = sum; - } -} - -// void compute_distances(std::vector data_in, -// std::vector centroids_in, -// std::vector &pairwise_distances, -// int n, int dim, int k) { -// std::vector data_dots(n); -// std::vector centroid_dots(k); -// self_dot(data_in, n, dim, data_dots); -// self_dot(centroids_in, k, dim, centroid_dots); -// for (int nn=0; nn data_in, -// std::vector centroids_in, -// std::vector &pairwise_distances, -// int n, int dim, int k) { -// std::vector data_dots(n); -// std::vector centroid_dots(k); -// self_dot(data_in, n, dim, data_dots); -// self_dot(centroids_in, k, dim, centroid_dots); -// for (int nn=0; nn namespace ML { @@ -24,20 +24,40 @@ namespace ML { /** * Calculates the "Coefficient of Determination" (R-Squared) score - * normalizing the sum of squared errors by the total sum of squares. + * normalizing the sum of squared errors by the total sum of squares + * with single precision. * * This score indicates the proportionate amount of variation in an * expected response variable is explained by the independent variables * in a linear regression model. The larger the R-squared value, the * more variability is explained by the linear regression model. * + * @param handle: cumlHandle * @param y: Array of ground-truth response variables * @param y_hat: Array of predicted response variables * @param n: Number of elements in y and y_hat * @return: The R-squared value. */ - T r2_scoreFloat(float *y, float *y_hat, int n) { - return MLCommon::Metrics::r2_score(y, y_hat, n); - } + float r2_score_py(const cumlHandle& handle, float *y, float *y_hat, int n); + + + /** + * Calculates the "Coefficient of Determination" (R-Squared) score + * normalizing the sum of squared errors by the total sum of squares + * with double precision. + * + * This score indicates the proportionate amount of variation in an + * expected response variable is explained by the independent variables + * in a linear regression model. The larger the R-squared value, the + * more variability is explained by the linear regression model. + * + * @param handle: cumlHandle + * @param y: Array of ground-truth response variables + * @param y_hat: Array of predicted response variables + * @param n: Number of elements in y and y_hat + * @return: The R-squared value. + */ + double r2_score_py(const cumlHandle& handle, double *y, double *y_hat, int n); + } } diff --git a/cuML/src/ml_cuda_utils.h b/cuML/src/ml_cuda_utils.h index 5eaa3d4874..a4ae636500 100644 --- a/cuML/src/ml_cuda_utils.h +++ b/cuML/src/ml_cuda_utils.h @@ -15,12 +15,32 @@ */ #include - +#include "utils.h" namespace ML { - int get_device(const void *ptr) { - cudaPointerAttributes att; - cudaPointerGetAttributes(&att, ptr); - return att.device; +int get_device(const void *ptr) { + cudaPointerAttributes att; + cudaPointerGetAttributes(&att, ptr); + return att.device; +} + +cudaMemoryType +memory_type(const void *p){ + cudaPointerAttributes att; + cudaError_t err = cudaPointerGetAttributes(&att, p); + ASSERT(err == cudaSuccess || + err == cudaErrorInvalidValue, "%s", cudaGetErrorString(err)); + + if (err == cudaErrorInvalidValue) { + // Make sure the current thread error status has been reset + err = cudaGetLastError(); + ASSERT(err == cudaErrorInvalidValue, "%s", cudaGetErrorString(err)); } +#if CUDA_VERSION >= 10000 + return att.type; +#else + return att.memoryType; +#endif +} + } diff --git a/cuML/test/dbscan_test.cu b/cuML/test/dbscan_test.cu index ab22fedfc5..fc5f0d375f 100644 --- a/cuML/test/dbscan_test.cu +++ b/cuML/test/dbscan_test.cu @@ -42,6 +42,7 @@ template return os; } + template class DbscanTest: public ::testing::TestWithParam > { protected: @@ -69,10 +70,14 @@ protected: int min_pts = 2; cumlHandle handle; handle.setStream(stream); - dbscanFit(handle, data, params.n_row, params.n_col, eps, min_pts, labels); + + // forces a batch size of 2 + size_t max_bytes_per_batch = 4 * sizeof(T) * 8; + + dbscanFit(handle, data, params.n_row, params.n_col, eps, min_pts, labels, max_bytes_per_batch, true); CUDA_CHECK( cudaStreamSynchronize(stream) ); - CUDA_CHECK( cudaStreamDestroy(stream) ); + CUDA_CHECK( cudaStreamDestroy(stream) ); } void SetUp() override { @@ -99,6 +104,7 @@ const std::vector > inputsd2 = { { 0.05, 6 * 2, 6, 2, 1234ULL }}; + typedef DbscanTest DbscanTestF; TEST_P(DbscanTestF, Result) { ASSERT_TRUE( diff --git a/cuML/test/kmeans_test.cu b/cuML/test/kmeans_test.cu index 16ae03bbb8..4ca3416b80 100644 --- a/cuML/test/kmeans_test.cu +++ b/cuML/test/kmeans_test.cu @@ -27,105 +27,120 @@ using namespace MLCommon; template struct KmeansInputs { int n_clusters; - T tol; - int n_row; - int n_col; + T tol; + int n_row; + int n_col; }; template class KmeansTest: public ::testing::TestWithParam > { protected: - void basicTest() { - params = ::testing::TestWithParam>::GetParam(); + void basicTest() { + params = ::testing::TestWithParam>::GetParam(); int m = params.n_row; int n = params.n_col; int k = params.n_clusters; - // make space for outputs : pred_centroids, pred_labels - // and reference output : labels_ref - allocate(d_srcdata, n * m); - allocate(labels_fit, m); - allocate(labels_ref_fit, m); - allocate(pred_centroids, k * n); - allocate(centroids_ref, k * n); + // make space for outputs : d_centroids, d_labels + // and reference output : d_labels_ref + allocate(d_srcdata, n * m); + allocate(d_labels, m); + allocate(d_labels_ref, m); + allocate(d_centroids, k * n); + allocate(d_centroids_ref, k * n); // make testdata on host - std::vector h_srcdata = {1.0,1.0,3.0,4.0, 1.0,2.0,2.0,3.0}; + std::vector h_srcdata = {1.0,1.0, + 3.0,4.0, + 1.0,2.0, + 2.0,3.0}; h_srcdata.resize(n * m); updateDevice(d_srcdata, h_srcdata.data(), m*n, stream); // make and assign reference output - std::vector h_labels_ref_fit = {1, 1, 0, 0}; - h_labels_ref_fit.resize(m); - updateDevice(labels_ref_fit, h_labels_ref_fit.data(), m, stream); + std::vector h_labels_ref = {0, 1, 0, 1}; + h_labels_ref.resize(m); + updateDevice(d_labels_ref, h_labels_ref.data(), m, stream); - std::vector h_centroids_ref = {3.5,2.5, 1.0,1.5}; + std::vector h_centroids_ref = {1.0,1.5, + 2.5,3.5}; h_centroids_ref.resize(k * n); - updateDevice(centroids_ref, h_centroids_ref.data(), k * n, stream); + updateDevice(d_centroids_ref, h_centroids_ref.data(), k * n, stream); + + cumlHandle handle; + handle.setStream(stream); // The actual kmeans api calls - // fit - make_ptr_kmeans(0, verbose, seed, gpu_id, n_gpu, m, n, - ord, k, k, max_iterations, - init_from_data, params.tol, d_srcdata, nullptr, pred_centroids, labels_fit); + // fit + kmeans::fit_predict(handle, + k, + metric, + init, + max_iterations, + params.tol, + seed, + d_srcdata, + m, + n, + d_centroids, + d_labels); + CUDA_CHECK( cudaStreamSynchronize(stream) ); } - void SetUp() override { - CUDA_CHECK(cudaStreamCreate(&stream)); - basicTest(); - } + void SetUp() override { + CUDA_CHECK(cudaStreamCreate(&stream)); + basicTest(); + } - void TearDown() override { + void TearDown() override { CUDA_CHECK(cudaFree(d_srcdata)); - CUDA_CHECK(cudaFree(labels_fit)); - CUDA_CHECK(cudaFree(pred_centroids)); - CUDA_CHECK(cudaFree(labels_ref_fit)); - CUDA_CHECK(cudaFree(centroids_ref)); - CUDA_CHECK(cudaStreamDestroy(stream)); - } + CUDA_CHECK(cudaFree(d_labels)); + CUDA_CHECK(cudaFree(d_centroids)); + CUDA_CHECK(cudaFree(d_labels_ref)); + CUDA_CHECK(cudaFree(d_centroids_ref)); + CUDA_CHECK(cudaStreamDestroy(stream)); + } protected: - KmeansInputs params; - T *d_srcdata; - int *labels_fit, *labels_ref_fit; - T *pred_centroids, *centroids_ref; + KmeansInputs params; + T *d_srcdata; + int *d_labels, *d_labels_ref; + T *d_centroids, *d_centroids_ref; int verbose = 0; - int seed = 1; - int gpu_id = 0; - int n_gpu = -1; - char ord = 'c'; // here c means col order, NOT C (vs F) order + int seed = 0; int max_iterations = 300; - int init_from_data = 0; + kmeans::InitMethod init = kmeans::InitMethod::Random; + int metric = 1; cudaStream_t stream; }; const std::vector > inputsf2 = { - { 2, 0.05f, 4, 2 }}; + { 2, 0.05f, 4, 2 }}; const std::vector > inputsd2 = { - { 2, 0.05, 4, 2 }}; + { 2, 0.05, 4, 2 }}; // FIXME: These tests are disabled due to being too sensitive to RNG: // https://github.com/rapidsai/cuml/issues/71 typedef KmeansTest KmeansTestF; -TEST_P(KmeansTestF, DISABLED_Fit) { - ASSERT_TRUE( - devArrMatch(labels_ref_fit, labels_fit, params.n_row, - CompareApproxAbs(params.tol))); - ASSERT_TRUE( - devArrMatch(centroids_ref, pred_centroids, params.n_clusters * params.n_col, - CompareApproxAbs(params.tol))); +TEST_P(KmeansTestF, Result) { + ASSERT_TRUE( + devArrMatch(d_labels_ref, d_labels, params.n_row, + CompareApproxAbs(params.tol))); + ASSERT_TRUE( + devArrMatch(d_centroids_ref, d_centroids, params.n_clusters * params.n_col, + CompareApproxAbs(params.tol))); } typedef KmeansTest KmeansTestD; -TEST_P(KmeansTestD, DISABLED_Fit) { - ASSERT_TRUE( - devArrMatch(labels_ref_fit, labels_fit, params.n_row, - CompareApproxAbs(params.tol))); - ASSERT_TRUE( - devArrMatch(centroids_ref, pred_centroids, params.n_clusters * params.n_col, - CompareApproxAbs(params.tol))); +TEST_P(KmeansTestD, Result) { + ASSERT_TRUE( + devArrMatch(d_labels_ref, d_labels, params.n_row, + CompareApproxAbs(params.tol))); + ASSERT_TRUE( + devArrMatch(d_centroids_ref, d_centroids, params.n_clusters * params.n_col, + CompareApproxAbs(params.tol))); } INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); diff --git a/ml-prims/src/linalg/reduce_rows_by_key.h b/ml-prims/src/linalg/reduce_rows_by_key.h index b27e043773..0251e8f40e 100644 --- a/ml-prims/src/linalg/reduce_rows_by_key.h +++ b/ml-prims/src/linalg/reduce_rows_by_key.h @@ -342,7 +342,7 @@ void sum_rows_by_key_large_nkeys_rowmajor( const DataIteratorT d_A, grid.z = std::min(grid.z, MAX_BLOCKS); //std::cout << "block = " << block.x << ", " << block.y << std::endl; //std::cout << "grid = " << grid.x << ", " << grid.y << ", " << grid.z << std::endl; - cudaMemset(d_sums, 0, sizeof(DataType)*nkeys*ncols); + cudaMemsetAsync(d_sums, 0, sizeof(DataType)*nkeys*ncols, st); sum_rows_by_key_large_nkeys_kernel_rowmajor<<>>(d_A, lda, d_keys, nrows, ncols, key_offset, nkeys, d_sums); } diff --git a/ml-prims/src/metrics/metrics.h b/ml-prims/src/score/scores.h similarity index 93% rename from ml-prims/src/metrics/metrics.h rename to ml-prims/src/score/scores.h index 28947a80fa..ca2227ad38 100644 --- a/ml-prims/src/metrics/metrics.h +++ b/ml-prims/src/score/scores.h @@ -27,7 +27,7 @@ #include namespace MLCommon { - namespace Metrics { + namespace Score { template /** @@ -56,14 +56,14 @@ namespace MLCommon { MLCommon::allocate(sse_arr, n); MLCommon::LinAlg::eltwiseSub(sse_arr, y, y_hat, n, stream); - MLCommon::LinAlg::powerScalar(sse_arr, sse_arr, 2.0f, n, stream); + MLCommon::LinAlg::powerScalar(sse_arr, sse_arr, math_t(2.0), n, stream); CUDA_CHECK(cudaPeekAtLastError()); math_t *ssto_arr; MLCommon::allocate(ssto_arr, n); MLCommon::LinAlg::subtractDevScalar(ssto_arr, y, y_bar, n, stream); - MLCommon::LinAlg::powerScalar(ssto_arr, ssto_arr, 2.0f, n, stream); + MLCommon::LinAlg::powerScalar(ssto_arr, ssto_arr, math_t(2.0), n, stream); CUDA_CHECK(cudaPeekAtLastError()); thrust::device_ptr d_sse = thrust::device_pointer_cast(sse_arr); diff --git a/ml-prims/test/CMakeLists.txt b/ml-prims/test/CMakeLists.txt index 586f3a69d1..9bf642c15b 100644 --- a/ml-prims/test/CMakeLists.txt +++ b/ml-prims/test/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright (c) 2018-2019, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,18 +43,22 @@ add_executable(mlcommon_test gather.cu gemm.cu grid_sync.cu + hinge.cu kselection.cu + linearReg.cu + log.cu + logisticReg.cu map_then_reduce.cu math.cu matrix.cu matrix_vector_op.cu mean.cu mean_center.cu - metrics.cu minmax.cu mvg.cu multiply.cu norm.cu + penalty.cu permute.cu power.cu reduce.cu @@ -62,6 +66,8 @@ add_executable(mlcommon_test rng.cu rng_int.cu rsvd.cu + score.cu + sigmoid.cu sqrt.cu stddev.cu strided_reduction.cu @@ -70,12 +76,6 @@ add_executable(mlcommon_test svd.cu transpose.cu unary_op.cu - hinge.cu - linearReg.cu - log.cu - logisticReg.cu - penalty.cu - sigmoid.cu weighted_mean.cu ) diff --git a/ml-prims/test/metrics.cu b/ml-prims/test/score.cu similarity index 82% rename from ml-prims/test/metrics.cu rename to ml-prims/test/score.cu index c85135d5f6..eb5dd87b1c 100644 --- a/ml-prims/test/metrics.cu +++ b/ml-prims/test/score.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "metrics/metrics.h" +#include "score/scores.h" #include #include "random/rng.h" #include "test_utils.h" @@ -24,9 +24,9 @@ #include namespace MLCommon { -namespace Metrics { +namespace Score { -class MetricsTest : public ::testing::Test { +class ScoreTest : public ::testing::Test { protected: void SetUp() override {} @@ -34,8 +34,8 @@ protected: }; -typedef MetricsTest MetricsTestHighScore; -TEST(MetricsTestHighScore, Result) { +typedef ScoreTest ScoreTestHighScore; +TEST(ScoreTestHighScore, Result) { float y[5] = {0.1, 0.2, 0.3, 0.4, 0.5}; float y_hat[5] = {0.12, 0.22, 0.32, 0.42, 0.52}; @@ -51,13 +51,13 @@ TEST(MetricsTestHighScore, Result) { MLCommon::updateDevice(d_y_hat, y_hat, 5, stream); MLCommon::updateDevice(d_y, y, 5, stream); - float result = MLCommon::Metrics::r2_score(d_y, d_y_hat, 5, stream); + float result = MLCommon::Score::r2_score(d_y, d_y_hat, 5, stream); ASSERT_TRUE(result == 0.98f); CUDA_CHECK(cudaStreamDestroy(stream)); } -typedef MetricsTest MetricsTestLowScore; -TEST(MetricsTestLowScore, Result) { +typedef ScoreTest ScoreTestLowScore; +TEST(ScoreTestLowScore, Result) { float y[5] = {0.1, 0.2, 0.3, 0.4, 0.5}; float y_hat[5] = {0.012, 0.022, 0.032, 0.042, 0.052}; @@ -73,7 +73,7 @@ TEST(MetricsTestLowScore, Result) { MLCommon::updateDevice(d_y_hat, y_hat, 5, stream); MLCommon::updateDevice(d_y, y, 5, stream); - float result = MLCommon::Metrics::r2_score(d_y, d_y_hat, 5, stream); + float result = MLCommon::Score::r2_score(d_y, d_y_hat, 5, stream); std::cout << "Result: " << result - -3.4012f << std::endl; ASSERT_TRUE(result - -3.4012f < 0.00001); diff --git a/python/cuml/__init__.py b/python/cuml/__init__.py index 65395ff614..690e0f7f7f 100644 --- a/python/cuml/__init__.py +++ b/python/cuml/__init__.py @@ -31,6 +31,8 @@ from cuml.linear_model.lasso import Lasso from cuml.linear_model.elastic_net import ElasticNet +from cuml.metrics.regression import r2_score + from cuml.neighbors.nearest_neighbors import NearestNeighbors from cuml.utils.pointer_utils import device_of_gpu_matrix diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index a9db1646fa..2e624451f6 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -43,7 +43,9 @@ cdef extern from "dbscan/dbscan.hpp" namespace "ML": int n_cols, float eps, int min_pts, - int *labels) + int *labels, + size_t max_bytes_per_batch, + bool verbose) cdef void dbscanFit(cumlHandle& handle, double *input, @@ -51,7 +53,9 @@ cdef extern from "dbscan/dbscan.hpp" namespace "ML": int n_cols, double eps, int min_pts, - int *labels) + int *labels, + size_t max_bytes_per_batch, + bool verbose) class DBSCAN(Base): @@ -103,6 +107,14 @@ class DBSCAN(Base): an important core point (including the point itself). verbose : bool Whether to print debug spews + max_bytes_per_batch : (optional) int64 + Calculate batch size using no more than this number of bytes for the pairwise + distance computation. This enables the trade-off between runtime and memory usage for + making the N^2 pairwise distance computations more tractable for large numbers of samples. + If you are experiencing out of memory errors when running DBSCAN, you can set this + value based on the memory size of your device. Note: this option does not set + the maximum total memory used in the DBSCAN computation and so this value will not + be able to be set to the total memory available on the device. Attributes ----------- @@ -126,12 +138,18 @@ class DBSCAN(Base): For additional docs, see `scikitlearn's DBSCAN `_. """ - def __init__(self, eps=0.5, handle=None, min_samples=5, verbose=False): + def __init__(self, eps=0.5, handle=None, min_samples=5, verbose=False, max_bytes_per_batch = None): super(DBSCAN, self).__init__(handle, verbose) self.eps = eps self.min_samples = min_samples self.labels_ = None self.labels_array = None + self.max_bytes_per_batch = max_bytes_per_batch + self.verbose = verbose + + # C++ API expects this to be numeric. + if self.max_bytes_per_batch is None: + self.max_bytes_per_batch = 0; def _get_ctype_ptr(self, obj): # The manner to access the pointers in the gdf's might change, so @@ -182,6 +200,8 @@ class DBSCAN(Base): self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) self.labels_array = self.labels_._column._data.to_gpu_array() cdef uintptr_t labels_ptr = self._get_ctype_ptr(self.labels_array) + + if self.gdf_datatype.type == np.float32: dbscanFit(handle_[0], input_ptr, @@ -189,7 +209,9 @@ class DBSCAN(Base): self.n_cols, self.eps, self.min_samples, - labels_ptr) + labels_ptr, + self.max_bytes_per_batch, + self.verbose) else: dbscanFit(handle_[0], input_ptr, @@ -197,7 +219,9 @@ class DBSCAN(Base): self.n_cols, self.eps, self.min_samples, - labels_ptr) + labels_ptr, + self.max_bytes_per_batch, + self.verbose) # make sure that the `dbscanFit` is complete before the following delete # call happens self.handle.sync() diff --git a/python/cuml/cluster/kmeans.pyx b/python/cuml/cluster/kmeans.pyx index 512c8319fe..97895ce513 100644 --- a/python/cuml/cluster/kmeans.pyx +++ b/python/cuml/cluster/kmeans.pyx @@ -31,62 +31,107 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free -cdef extern from "kmeans/kmeans_c.h" namespace "ML": - - cdef void make_ptr_kmeans( - int dopredict, - int verbose, - int seed, - int gpu_id, - int n_gpu, - size_t mTrain, - size_t n, - const char ord, - int k, - int k_max, - int max_iterations, - int init_from_data, - float threshold, - const float *srcdata, - const float *centroids, - float *pred_centroids, - int *pred_labels - ) - - cdef void make_ptr_kmeans( - int dopredict, - int verbose, - int seed, - int gpu_id, - int n_gpu, - size_t mTrain, - size_t n, - const char ord, - int k, - int k_max, - int max_iterations, - int init_from_data, - double threshold, - const double *srcdata, - const double *centroids, - double *pred_centroids, - int *pred_labels - ) - - - cdef void kmeans_transform(int verbose, - int gpu_id, int n_gpu, - size_t m, size_t n, const char ord, int k, - const float *src_data, const float *centroids, - float *preds) - - cdef void kmeans_transform(int verbose, - int gpu_id, int n_gpu, - size_t m, size_t n, const char ord, int k, - const double *src_data, const double *centroids, - double *preds) - -class KMeans: +from cuml.common.base import Base +from cuml.common.handle cimport cumlHandle + +cdef extern from "kmeans/kmeans.hpp" namespace "ML::kmeans": + + enum InitMethod: + KMeansPlusPlus, Random, Array + + cdef void fit_predict(cumlHandle& handle, + int n_clusters, + int metric, + InitMethod init, + int max_iter, + double tol, + int seed, + const float *X, + int n_samples, + int n_features, + float *centroids, + int *labels, + int verbose) + + cdef void fit_predict(cumlHandle& handle, + int n_clusters, + int metric, + InitMethod init, + int max_iter, + double tol, + int seed, + const double *X, + int n_samples, + int n_features, + double *centroids, + int *labels, + int verbose); + + cdef void fit(cumlHandle& handle, + int n_clusters, + int metric, + InitMethod init, + int max_iter, + double tol, + int seed, + const float *X, + int n_samples, + int n_features, + float *centroids, + int verbose) + + cdef void fit(cumlHandle& handle, + int n_clusters, + int metric, + InitMethod init, + int max_iter, + double tol, + int seed, + const double *X, int n_samples, int n_features, + double *centroids, + int verbose) + + cdef void predict(cumlHandle& handle, + float *centroids, + int n_clusters, + const float *X, + int n_samples, + int n_features, + int metric, + int *labels, + int verbose) + + cdef void predict(cumlHandle& handle, + double *centroids, + int n_clusters, + const double *X, + int n_samples, + int n_features, + int metric, + int *labels, + int verbose) + + cdef void transform(cumlHandle& handle, + const float *centroids, + int n_clusters, + const float *X, + int n_samples, + int n_features, + int metric, + float *X_new, + int verbose) + + cdef void transform(cumlHandle& handle, + const double *centroids, + int n_clusters, + const double *X, + int n_samples, + int n_features, + int metric, + double *X_new, + int verbose) + +class KMeans(Base): """ KMeans is a basic but powerful clustering method which is optimized via Expectation Maximization. @@ -96,7 +141,7 @@ class KMeans: cuML's KMeans expects a cuDF DataFrame, and supports the fast KMeans++ intialization method. This method is more stable than randomnly selecting K points. - + Examples -------- @@ -150,19 +195,21 @@ class KMeans: labels: - 0 1 - 1 1 - 2 0 - 3 0 + 0 0 + 1 0 + 2 1 + 3 1 cluster_centers: 0 1 - 0 3.5 2.5 - 1 1.0 1.5 - + 0 1.0 1.5 + 1 3.5 2.5 + Parameters ---------- + handle : cuml.Handle + If it is None, a new one is created just for this class. n_clusters : int (default = 8) The number of centroids or clusters you want. max_iter : int (default = 300) @@ -175,16 +222,16 @@ class KMeans: If you want results to be the same when you restart Python, select a state. precompute_distances : boolean (default = 'auto') Not supported yet. - init : 'kmeans++' - Uses fast and stable kmeans++ intialization. More options will be supported later. + init : {'scalable-kmeans++', 'k-means||' , 'random' or an ndarray} (default = 'scalable-k-means++') + 'scalable-k-means++' or 'k-means||': Uses fast and stable scalable kmeans++ intialization. + 'random': Choose 'n_cluster' observations (rows) at random from data for the initial centroids. + If an ndarray is passed, it should be of shape (n_clusters, n_features) and gives the initial centers. n_init : int (default = 1) Number of times intialization is run. More is slower, but can be better. algorithm : "auto" Currently uses full EM, but will support others later. n_gpu : int (default = 1) - Number of GPUs to use. If -1 is used, uses all GPUs. More usage is faster. - gpu_id : int (default = 0) - Which GPU to use if n_gpu == 1. + Number of GPUs to use. Currently uses single GPU, but will support multiple GPUs later. Attributes @@ -192,25 +239,26 @@ class KMeans: cluster_centers_ : array The coordinates of the final clusters. This represents of "mean" of each data cluster. labels_ : array - Which cluster each datapoint belongs to. + Which cluster each datapoint belongs to. Notes ------ KMeans requires n_clusters to be specified. This means one needs to approximately guess or know - how many clusters a dataset has. If one is not sure, one can start with a small number of clusters, and + how many clusters a dataset has. If one is not sure, one can start with a small number of clusters, and visualize the resulting clusters with PCA, UMAP or T-SNE, and verify that they look appropriate. - + **Applications of KMeans** - + The biggest advantage of KMeans is its speed and simplicity. That is why KMeans is many practitioner's first choice of a clustering algorithm. KMeans has been extensively used when the number of clusters is approximately known, such as in big data clustering tasks, image segmentation and medical clustering. - - + + For additional docs, see `scikitlearn's Kmeans `_. """ - def __init__(self, n_clusters=8, max_iter=300, tol=1e-4, verbose=0, random_state=1, precompute_distances='auto', init='kmeans++', n_init=1, algorithm='auto', n_gpu=1, gpu_id=0): + def __init__(self, handle=None, n_clusters=8, max_iter=300, tol=1e-4, verbose=0, random_state=1, precompute_distances='auto', init='scalable-k-means++', n_init=1, algorithm='auto', n_gpu=1): + super(KMeans, self).__init__(handle, verbose); self.n_clusters = n_clusters self.verbose = verbose self.random_state = random_state @@ -225,9 +273,6 @@ class KMeans: self.labels_ = None self.cluster_centers_ = None self.n_gpu = n_gpu - self.gpu_id = gpu_id - warnings.warn("This version of K-means will be depricated in 0.7 for stability reasons. A new version based on our lessons learned will be available in 0.7. The current version will get no new bug fixes or improvements. The new version will follow the same API.", - FutureWarning) def _get_ctype_ptr(self, obj): # The manner to access the pointers in the gdf's might change, so @@ -268,53 +313,71 @@ class KMeans: input_ptr = self._get_ctype_ptr(X_m) + cdef cumlHandle* handle_ = self.handle.getHandle() self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) - self.cluster_centers_ = cuda.to_device(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) + if (isinstance(self.init, cudf.DataFrame)): + if(len(self.init) != self.n_clusters): + raise ValueError('The shape of the initial centers (%s) ' + 'does not match the number of clusters %i' + % (self.init.shape, self.n_clusters)) + init_value = Array + self.cluster_centers_ = cuda.device_array(self.n_clusters * self.n_cols, dtype=self.gdf_datatype) + self.cluster_centers_.copy_to_device(numba_utils.row_matrix(self.init)) + elif (isinstance(self.init, np.ndarray)): + if(self.init.shape[0] != self.n_clusters): + raise ValueError('The shape of the initial centers (%s) ' + 'does not match the number of clusters %i' + % (self.init.shape, self.n_clusters)) + init_value = Array; + self.cluster_centers_ = cuda.to_device(self.init.flatten()) + elif (self.init in ['scalable-k-means++', 'k-means||']): + init_value = KMeansPlusPlus + self.cluster_centers_ = cuda.to_device(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) + elif (self.init == 'random'): + init_value = Random + self.cluster_centers_ = cuda.to_device(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) + else: + raise TypeError('initialization method not supported') + cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(self.cluster_centers_) + if self.gdf_datatype.type == np.float32: - make_ptr_kmeans( - 0, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - b'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 1, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - input_ptr, # srcdata - # ptr2, # srcdata - 0, # centroids - cluster_centers_ptr, # pred_centroids - # 0, # pred_centroids - labels_ptr) # pred_labels + fit_predict( + handle_[0], + self.n_clusters, # n_clusters + 0, # distance metric as squared L2: @todo - support other metrics + init_value, # init method + self.max_iter, # max_iterations + self.tol, # threshold + self.random_state, # seed + input_ptr, # srcdata + self.n_rows, # n_samples (rows) + self.n_cols, # n_features (cols) + cluster_centers_ptr, # pred_centroids); + labels_ptr, # pred_labels + self.verbose) + elif self.gdf_datatype.type == np.float64: + fit_predict( + handle_[0], + self.n_clusters, # n_clusters + 0, # distance metric as squared L2: @todo - support other metrics + init_value, # init method + self.max_iter, # max_iterations + self.tol, # threshold + self.random_state, # seed + input_ptr, # srcdata + self.n_rows, # n_samples (rows) + self.n_cols, # n_features (cols) + cluster_centers_ptr, # pred_centroids); + labels_ptr, # pred_labels + self.verbose) else: - make_ptr_kmeans( - 0, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - b'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 1, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - input_ptr, # srcdata - 0, # centroids - cluster_centers_ptr, # pred_centroids - labels_ptr) # pred_labels + raise TypeError("supports only float32 and float64 input, but input of type '%s' passed." % (str(self.gdf_datatype.type))) + self.handle.sync() cluster_centers_gdf = cudf.DataFrame() for i in range(0, self.n_cols): cluster_centers_gdf[str(i)] = self.cluster_centers_[i:self.n_clusters*self.n_cols:self.n_cols] @@ -366,6 +429,7 @@ class KMeans: input_ptr = self._get_ctype_ptr(X_m) + cdef cumlHandle* handle_ = self.handle.getHandle() clust_mat = numba_utils.row_matrix(self.cluster_centers_) cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(clust_mat) @@ -373,46 +437,31 @@ class KMeans: cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) if self.gdf_datatype.type == np.float32: - make_ptr_kmeans( - 1, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - b'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 0, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - # input_ptr, # srcdata - input_ptr, # srcdata - # ptr2, # srcdata - cluster_centers_ptr, # centroids - 0, # pred_centroids - labels_ptr) # pred_labels + predict( + handle_[0], + cluster_centers_ptr, # pred_centroids + self.n_clusters, # n_clusters + input_ptr, # srcdata + self.n_rows, # n_samples (rows) + self.n_cols, # n_features (cols) + 0, # distance metric as squared L2: @todo - support other metrics + labels_ptr, # pred_labels + self.verbose) + elif self.gdf_datatype.type == np.float64: + predict( + handle_[0], + cluster_centers_ptr, # pred_centroids + self.n_clusters, # n_clusters + input_ptr, # srcdata + self.n_rows, # n_samples (rows) + self.n_cols, # n_features (cols) + 0, # distance metric as squared L2: @todo - support other metrics + labels_ptr, # pred_labels + self.verbose) else: - make_ptr_kmeans( - 1, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - b'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 0, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - input_ptr, # srcdata - cluster_centers_ptr, # centroids - 0, # pred_centroids - labels_ptr) # pred_labels + raise TypeError("supports only float32 and float64 input, but input of type '%s' passed." % (str(self.gdf_datatype.type))) + self.handle.sync() del(X_m) del(clust_mat) return self.labels_ @@ -447,6 +496,7 @@ class KMeans: input_ptr = self._get_ctype_ptr(X_m) + cdef cumlHandle* handle_ = self.handle.getHandle() clust_mat = numba_utils.row_matrix(self.cluster_centers_) cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(clust_mat) @@ -456,33 +506,34 @@ class KMeans: cdef uintptr_t preds_ptr = self._get_ctype_ptr(preds_data) if self.gdf_datatype.type == np.float32: - kmeans_transform( - self.verbose, # verbose - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - b'r', # ord - self.n_clusters, # k - input_ptr, # srcdata - cluster_centers_ptr, # centroids - preds_ptr) # preds + transform( + handle_[0], + cluster_centers_ptr, # centroids + self.n_clusters, # n_clusters + input_ptr, # srcdata + self.n_rows, # n_samples (rows) + self.n_cols, # n_features (cols) + 1, # distance metric as L2-norm/euclidean distance: @todo - support other metrics + preds_ptr, # transformed output + self.verbose) + elif self.gdf_datatype.type == np.float64: + transform( + handle_[0], + cluster_centers_ptr, # centroids + self.n_clusters, # n_clusters + input_ptr, # srcdata + self.n_rows, # n_samples (rows) + self.n_cols, # n_features (cols) + 1, # distance metric as L2-norm/euclidean distance: @todo - support other metrics + preds_ptr, # transformed output + self.verbose) else: - kmeans_transform( - self.verbose, # verbose - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - b'r', # ord - self.n_clusters, # k - input_ptr, # srcdata - cluster_centers_ptr, # centroids - preds_ptr) # preds + raise TypeError("supports only float32 and float64 input, but input of type '%s' passed." % (str(self.gdf_datatype.type))) + self.handle.sync() preds_gdf = cudf.DataFrame() for i in range(0, self.n_clusters): - preds_gdf[str(i)] = preds_data[i*self.n_rows:(i+1)*self.n_rows] + preds_gdf[str(i)] = preds_data[i:self.n_rows * self.n_clusters:self.n_clusters] del(X_m) del(clust_mat) diff --git a/python/cuml/linear_model/ridge.pyx b/python/cuml/linear_model/ridge.pyx index fa5ae1bb18..e44d28ce33 100644 --- a/python/cuml/linear_model/ridge.pyx +++ b/python/cuml/linear_model/ridge.pyx @@ -29,6 +29,8 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free +from cuml.metrics.base import RegressorMixin + cdef extern from "glm/glm.hpp" namespace "ML::GLM": @@ -71,7 +73,7 @@ cdef extern from "glm/glm.hpp" namespace "ML::GLM": double *preds) -class Ridge: +class Ridge(RegressorMixin): """ Ridge extends LinearRegression by providing L2 regularization on the coefficients when @@ -155,15 +157,15 @@ class Ridge: The estimated coefficients for the linear regression model. intercept_ : array The independent term. If fit_intercept_ is False, will be 0. - + Notes ------ Ridge provides L2 regularization. This means that the coefficients can shrink to become very very small, but not zero. This can cause issues of interpretabiliy on the coefficients. Consider using Lasso, or thresholding small coefficients to zero. - + **Applications of Ridge** - + Ridge Regression is used in the same way as LinearRegression, but is used more frequently as it does not suffer from multicollinearity issues. Ridge is used in insurance premium prediction, stock market analysis and much more. diff --git a/python/cuml/metrics/__init__.py b/python/cuml/metrics/__init__.py new file mode 100644 index 0000000000..c2c1f4f134 --- /dev/null +++ b/python/cuml/metrics/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright (c) 2019, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from cuml.metrics.regression import r2_score diff --git a/python/cuml/metrics/base.py b/python/cuml/metrics/base.py new file mode 100644 index 0000000000..a31a3b0b30 --- /dev/null +++ b/python/cuml/metrics/base.py @@ -0,0 +1,40 @@ +# +# Copyright (c) 2018-2019, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +class RegressorMixin: + """Mixin class for regression estimators in""" + _estimator_type = "regressor" + + def score(self, X, y, **kwargs): + """Scoring function for linear classifiers + + Returns the coefficient of determination R^2 of the prediction. + + Parameters + ---------- + X : [cudf.DataFrame] + Test samples on which we predict + y : [cudf.Series] + True values for predict(X) + + Returns + ------- + score : float + R^2 of self.predict(X) wrt. y. + """ + from cuml.metrics.regression import r2_score + return r2_score(y.to_gpu_array(), self.predict(X).to_gpu_array()) diff --git a/python/cuml/metrics/regression.pxd b/python/cuml/metrics/regression.pxd new file mode 100644 index 0000000000..41aa7092de --- /dev/null +++ b/python/cuml/metrics/regression.pxd @@ -0,0 +1,34 @@ +# +# Copyright (c) 2019, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from cuml.common.handle cimport cumlHandle + +cdef extern from "metrics/metrics.hpp" namespace "ML::Metrics": + + float r2_score_py(const cumlHandle& handle, + float *y, + float *y_hat, + int n) + + double r2_score_py(const cumlHandle& handle, + double *y, + double *y_hat, + int n) diff --git a/python/cuml/metrics/regression.pyx b/python/cuml/metrics/regression.pyx new file mode 100644 index 0000000000..69165f469f --- /dev/null +++ b/python/cuml/metrics/regression.pyx @@ -0,0 +1,65 @@ +# +# Copyright (c) 2019, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from cuml.common.handle cimport cumlHandle +import cuml.common.handle +from libc.stdint cimport uintptr_t + +from cuml.metrics cimport regression + + +def r2_score(y, y_hat, handle=None): + handle = cuml.common.handle.Handle() if handle is None else handle + cdef cumlHandle* handle_ = handle.getHandle() + + cdef uintptr_t y_ptr = y.device_ctypes_pointer.value + cdef uintptr_t y_hat_ptr = y_hat.device_ctypes_pointer.value + + cdef float result_f32 + cdef double result_f64 + + n = len(y) + + if y.dtype == 'float32': + + result_f32 = regression.r2_score_py(handle_[0], + y_ptr, + y_hat_ptr, + n) + + result = result_f32 + + else: + result_f64 = regression.r2_score_py(handle_[0], + y_ptr, + y_hat_ptr, + n) + + result = result_f64 + + + return result + + + + + + diff --git a/python/cuml/test/test_dbscan.py b/python/cuml/test/test_dbscan.py index 3128c0da79..ee428d7770 100644 --- a/python/cuml/test/test_dbscan.py +++ b/python/cuml/test/test_dbscan.py @@ -28,10 +28,15 @@ 'noisy_circles', 'no_structure'] +@pytest.mark.parametrize('max_bytes_per_batch', [10, 200, 2e6]) @pytest.mark.parametrize('datatype', [np.float32, np.float64]) @pytest.mark.parametrize('input_type', ['dataframe', 'ndarray']) -def test_dbscan_predict(datatype, input_type, run_stress, - run_quality): + +@pytest.mark.parametrize('use_handle', [True, False]) +def test_dbscan_predict(datatype, input_type, use_handle, max_bytes_per_batch + run_stress, run_quality): + + # max_bytes_per_batch sizes: 10=6 batches, 200=2 batches, 2e6=1 batch n_samples = 10000 n_feats = 50 if run_stress: @@ -40,13 +45,18 @@ def test_dbscan_predict(datatype, input_type, run_stress, elif run_quality: X, y = make_blobs(n_samples=n_samples, n_features=n_feats, random_state=0) - + else: X = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]], dtype=datatype) skdbscan = skDBSCAN(eps=3, min_samples=10) sk_labels = skdbscan.fit_predict(X) - cudbscan = cuDBSCAN(eps=3, min_samples=10) + + + handle, stream = get_handle(use_handle) + cudbscan = cuDBSCAN(handle=handle, eps=3, min_samples=2, + max_bytes_per_batch=max_bytes_per_batch) + if input_type == 'dataframe': X = pd.DataFrame( {'fea%d' % i: X[0:, i] for i in range(X.shape[1])}) diff --git a/python/cuml/test/test_kmeans.py b/python/cuml/test/test_kmeans.py index ee6943d447..7d8d0c4745 100644 --- a/python/cuml/test/test_kmeans.py +++ b/python/cuml/test/test_kmeans.py @@ -20,7 +20,10 @@ from sklearn.preprocessing import StandardScaler from cuml.test.utils import fit_predict, get_pattern, clusters_equal -dataset_names = ['noisy_moons', 'varied', 'aniso', 'noisy_circles', 'blobs'] + +dataset_names = ['blobs', 'noisy_circles'] + \ + [pytest.param(ds, marks=pytest.mark.xfail) + for ds in ['noisy_moons', 'varied', 'aniso']] @pytest.mark.parametrize('name', dataset_names) @@ -31,19 +34,9 @@ def test_kmeans_sklearn_comparison(name, run_stress, run_quality): 'damping': .9, 'preference': -200, 'n_neighbors': 10, - 'n_clusters': 20} - n_samples = 10000 - if run_stress: - pat = get_pattern(name, n_samples*50) - params = default_base.copy() - params.update(pat[1]) - X, y = pat[0] + 'n_clusters': 3} - elif run_quality: - pat = get_pattern(name, n_samples) - params = default_base.copy() - params.update(pat[1]) - X, y = pat[0] + pat = get_pattern(name, 10000) else: pat = get_pattern(name, np.int32(n_samples/2)) @@ -68,7 +61,7 @@ def test_kmeans_sklearn_comparison(name, run_stress, run_quality): clustering_algorithms[1][0], X) if name == 'noisy_circles': - assert (np.sum(sk_y_pred) - np.sum(cu_y_pred))/len(sk_y_pred) < 1e-10 + assert (np.sum(sk_y_pred) - np.sum(cu_y_pred))/len(sk_y_pred) < 2e-3 else: - clusters_equal(sk_y_pred, cu_y_pred, params['n_clusters']) + assert clusters_equal(sk_y_pred, cu_y_pred, params['n_clusters']) diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py new file mode 100644 index 0000000000..25417f61db --- /dev/null +++ b/python/cuml/test/test_metrics.py @@ -0,0 +1,74 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import cuml +import numpy as np +import pytest + +from cuml.test.utils import get_handle + +from numba import cuda + + +@pytest.mark.parametrize('datatype', [np.float32, np.float64]) +@pytest.mark.parametrize('use_handle', [True, False]) +def test_r2_score(datatype, use_handle): + a = np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=datatype) + b = np.array([0.12, 0.22, 0.32, 0.42, 0.52], dtype=datatype) + + a_dev = cuda.to_device(a) + b_dev = cuda.to_device(b) + + handle, stream = get_handle(use_handle) + + score = cuml.metrics.r2_score(a_dev, b_dev, handle=handle) + + np.testing.assert_almost_equal(score, 0.98, decimal=7) + + +@pytest.mark.skip(reason="Debugging NN test core dump") +def test_sklearn_search(): + """Test ensures scoring function works with sklearn machinery + """ + import numpy as np + from cuml import Ridge as cumlRidge + import cudf + from sklearn import datasets + from sklearn.model_selection import train_test_split, GridSearchCV + diabetes = datasets.load_diabetes() + X_train, X_test, y_train, y_test = train_test_split(diabetes.data, + diabetes.target, + test_size=0.2, + shuffle=False, + random_state=1) + + alpha = np.array([1.0]) + fit_intercept = True + normalize = False + + params = {'alpha': np.logspace(-3, -1, 10)} + cu_clf = cumlRidge(alpha=alpha, fit_intercept=fit_intercept, + normalize=normalize, solver="eig") + + assert getattr(cu_clf, 'score', False) + sk_cu_grid = GridSearchCV(cu_clf, params, cv=5, iid=False) + + record_data = (('fea%d' % i, X_train[:, i]) for i in + range(X_train.shape[1])) + gdf_data = cudf.DataFrame(record_data) + gdf_train = cudf.DataFrame(dict(train=y_train)) + + sk_cu_grid.fit(gdf_data, gdf_train.train) + assert sk_cu_grid.best_params_ == {'alpha': 0.1} diff --git a/python/cuml/test/test_umap.py b/python/cuml/test/test_umap.py index fa7f25114b..7fb0c5ee67 100644 --- a/python/cuml/test/test_umap.py +++ b/python/cuml/test/test_umap.py @@ -34,6 +34,15 @@ dataset_names = ['iris', 'digits', 'wine', 'blobs'] +def test_blobs_cluster(): + data, labels = datasets.make_blobs( + n_samples=500, n_features=10, centers=5) + embedding = UMAP(verbose=True).fit_transform(data) + score = adjusted_rand_score(labels, + KMeans(5).fit_predict(embedding)) + assert score == 1.0 + + def test_umap_fit_transform_score(run_stress, run_quality): if run_stress: @@ -68,9 +77,8 @@ def test_umap_fit_transform_score(run_stress, run_quality): def test_supervised_umap_trustworthiness_on_iris(): iris = datasets.load_iris() data = iris.data - embedding = UMAP(n_neighbors=10, min_dist=0.01).fit_transform( - data, iris.target - ) + embedding = UMAP(n_neighbors=10, min_dist=0.01, + verbose=True).fit_transform(data, iris.target) trust = trustworthiness(iris.data, embedding, 10) assert trust >= 0.97 @@ -80,9 +88,8 @@ def test_semisupervised_umap_trustworthiness_on_iris(): data = iris.data target = iris.target.copy() target[25:75] = -1 - embedding = UMAP(n_neighbors=10, min_dist=0.01).fit_transform( - data, target - ) + embedding = UMAP(n_neighbors=10, min_dist=0.01, + verbose=True).fit_transform(data, target) trust = trustworthiness(iris.data, embedding, 10) assert trust >= 0.97 @@ -90,7 +97,8 @@ def test_semisupervised_umap_trustworthiness_on_iris(): def test_umap_trustworthiness_on_iris(): iris = datasets.load_iris() data = iris.data - embedding = UMAP(n_neighbors=10, min_dist=0.01).fit_transform(data) + embedding = UMAP(n_neighbors=10, min_dist=0.01, + verbose=True).fit_transform(data) trust = trustworthiness(iris.data, embedding, 10) @@ -141,14 +149,13 @@ def test_umap_data_formats(input_type, should_downcast, elif run_quality: X, y = datasets.make_blobs(n_samples=int(n_samples/10), n_features=n_feats, random_state=0) - else: # For now, FAISS based nearest_neighbors only supports single precision digits = datasets.load_digits(n_class=9) X = digits["data"].astype(dtype) umap = UMAP_cuml(n_neighbors=3, n_components=2, - should_downcast=should_downcast) + should_downcast=should_downcast, verbose=True) if input_type == 'dataframe': X_pd = pd.DataFrame( @@ -179,7 +186,8 @@ def test_umap_downcast_fails(input_type, run_stress, run_quality): dtype=np.float64) # Test fit() fails with double precision when should_downcast set to False - umap = UMAP_cuml(should_downcast=False) + umap = UMAP_cuml(should_downcast=False, verbose=True) + if input_type == 'dataframe': X = cudf.DataFrame.from_pandas(pd.DataFrame(X))