From 53e0683155a4b0707fc36c104168ada5c12da2e4 Mon Sep 17 00:00:00 2001
From: Levs Dolgovs
Date: Fri, 16 Apr 2021 23:29:06 -0700
Subject: [PATCH 01/30] try 1

---
 cpp/src/fil/common.cuh  |  4 ++--
 cpp/src/fil/fil.cu      |  5 +----
 cpp/src/fil/infer.cu    | 42 ++++++++++++++++++++++++++++++++++++-----
 cpp/test/sg/fil_test.cu |  2 ++
 4 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh
index 0beca695fc..3a487f050a 100644
--- a/cpp/src/fil/common.cuh
+++ b/cpp/src/fil/common.cuh
@@ -131,9 +131,9 @@ struct shmem_size_params {
   }
   void compute_smem_footprint();
   template
-  size_t get_smem_footprint();
+  int get_smem_footprint();
   template
-  size_t get_smem_footprint();
+  int get_smem_footprint();
 };
 
 // predict_params are parameters for prediction
diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu
index 4db300d6fb..639509a59d 100644
--- a/cpp/src/fil/fil.cu
+++ b/cpp/src/fil/fil.cu
@@ -72,13 +72,10 @@ __global__ void transform_k(float* preds, size_t n, output_t output,
 
 struct forest {
   void init_n_items(int device) {
-    int max_shm_std = 48 * 1024;  // 48 KiB
     /// the most shared memory a kernel can request on the GPU in question
     int max_shm = 0;
     CUDA_CHECK(cudaDeviceGetAttribute(
       &max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
-    // TODO(canonizer): use >48KiB shared memory if available
-    max_shm = std::min(max_shm, max_shm_std);
 
     // searching for the most items per block while respecting the shared
     // memory limits creates a full linear programming problem.
@@ -93,7 +90,7 @@ struct forest {
          ssp.n_items <= (algo_ == algo_t::BATCH_TREE_REORG ? 4 : 1);
          ++ssp.n_items) {
       ssp.compute_smem_footprint();
-      if (ssp.shm_sz < max_shm) ssp_ = ssp;
+      if (ssp.shm_sz <= max_shm) ssp_ = ssp;
     }
   }
   ASSERT(max_shm >= ssp_.shm_sz,
diff --git a/cpp/src/fil/infer.cu b/cpp/src/fil/infer.cu
index edf8337a1e..019b4f6c29 100644
--- a/cpp/src/fil/infer.cu
+++ b/cpp/src/fil/infer.cu
@@ -576,21 +576,53 @@ __global__ void infer_k(storage_type forest, predict_params params) {
   }
 }
 
+void set_carveout(void* kernel, int footprint, int max_shm) {
+  CUDA_CHECK(
+    cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout,
+                         // footprint in % of max_shm, rounding up
+                         (100 * footprint + max_shm - 1) / max_shm));
+  CUDA_CHECK(cudaFuncSetAttribute(
+    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, footprint));
+}
+
+template
+void set_carveouts(int footprint) {
+  int device = 0;
+  CUDA_CHECK(cudaGetDevice(&device));
+  int max_shm = 0;
+  CUDA_CHECK(cudaDeviceGetAttribute(
+    &max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
+  if (footprint > max_shm) return;
+
+  set_carveout((void*)infer_k, footprint, max_shm);
+  set_carveout((void*)infer_k, footprint, max_shm);
+  set_carveout((void*)infer_k, footprint, max_shm);
+}
+
 template
-size_t shmem_size_params::get_smem_footprint() {
-  size_t finalize_footprint =
+int shmem_size_params::get_smem_footprint() {
+  int finalize_footprint =
     tree_aggregator_t::smem_finalize_footprint(
       cols_shmem_size(), num_classes, predict_proba);
-  size_t accumulate_footprint =
+  int accumulate_footprint =
     tree_aggregator_t::smem_accumulate_footprint(
       num_classes) +
     cols_shmem_size();
-  return std::max(accumulate_footprint, finalize_footprint);
+  int footprint = std::max(accumulate_footprint, finalize_footprint);
+  int max_shm_std = 48 * 1024;  // 48 KiB available on any architecture
+  if (footprint > max_shm_std) {
+    // for no cols_in_shmem, it is a matter of supporting this config at all
+    set_carveouts(footprint);
+    // for cols_in_shmem, it will accelerate performance
+    set_carveouts(footprint);
+    // This much may not suffice, in which case set_carveouts will do nothing.
+  }
+  return footprint;
 }
 
 template
-size_t shmem_size_params::get_smem_footprint() {
+int shmem_size_params::get_smem_footprint() {
   switch (leaf_algo) {
     case FLOAT_UNARY_BINARY:
       return get_smem_footprint();
diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu
index 1c32e3224e..215deacd77 100644
--- a/cpp/test/sg/fil_test.cu
+++ b/cpp/test/sg/fil_test.cu
@@ -749,6 +749,8 @@ std::vector predict_dense_inputs = {
   FIL_TEST_PARAMS(num_rows = 103, num_cols = 100'000, depth = 5, num_trees = 1,
                   algo = BATCH_TREE_REORG, leaf_algo = CATEGORICAL_LEAF,
                   num_classes = 3),
+  // use shared memory opt-in carveout if available, or infer out of L1 cache
+  FIL_TEST_PARAMS(num_cols = ((48 + 1) * 1024) / sizeof(float), algo = NAIVE),
 };
 
 TEST_P(PredictDenseFilTest, Predict) { compare(); }

From 31b885f7c3244456a778ff91ab0b2c2ca1c837ba Mon Sep 17 00:00:00 2001
From: Levs Dolgovs
Date: Tue, 25 May 2021 16:15:42 -0700
Subject: [PATCH 02/30] draft of set-and-launch

---
 cpp/src/fil/fil.cu   |  9 +++++++++
 cpp/src/fil/infer.cu | 49 ++++++++++++-------------------------------------
 2 files changed, 21 insertions(+), 37 deletions(-)

diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu
index 639509a59d..47c2701e8c 100644
--- a/cpp/src/fil/fil.cu
+++ b/cpp/src/fil/fil.cu
@@ -112,6 +112,11 @@ struct forest {
     fixed_block_count_ = blocks_per_sm * sm_count;
   }
 
+  void init_max_shm(int device) {
+    CUDA_CHECK(cudaDeviceGetAttribute(
+      &max_shm_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
+  }
+
   void init_common(const raft::handle_t& h, const forest_params_t* params) {
     depth_ = params->depth;
     num_trees_ = params->num_trees;
@@ -127,6 +132,7 @@ struct forest {
     int device = h.get_device();
     init_n_items(device);  // n_items takes priority over blocks_per_sm
     init_fixed_block_count(device, params->blocks_per_sm);
+    init_max_shm(device);
   }
 
   virtual void infer(predict_params params, cudaStream_t stream) = 0;
@@ -250,6 +256,8 @@ struct forest {
     }
   }
 
+  int max_shm() { return max_shm_; }
+
   virtual void free(const raft::handle_t& h) = 0;
 
   virtual ~forest() {}
@@ -261,6 +269,7 @@ struct forest {
   float global_bias_ = 0;
   shmem_size_params class_ssp_, proba_ssp_;
   int fixed_block_count_ = 0;
+  int max_shm_ = 0;
 };
 
 struct dense_forest : forest {
diff --git a/cpp/src/fil/infer.cu b/cpp/src/fil/infer.cu
index 019b4f6c29..f00665bb08 100644
--- a/cpp/src/fil/infer.cu
+++ b/cpp/src/fil/infer.cu
@@ -585,20 +585,6 @@ void set_carveout(void* kernel, int footprint, int max_shm) {
     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, footprint));
 }
 
-template
-void set_carveouts(int footprint) {
-  int device = 0;
-  CUDA_CHECK(cudaGetDevice(&device));
-  int max_shm = 0;
-  CUDA_CHECK(cudaDeviceGetAttribute(
-    &max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
-  if (footprint > max_shm) return;
-
-  set_carveout((void*)infer_k, footprint, max_shm);
-  set_carveout((void*)infer_k, footprint, max_shm);
-  set_carveout((void*)infer_k, footprint, max_shm);
-}
-
 template
 int shmem_size_params::get_smem_footprint() {
   int finalize_footprint =
@@ -608,17 +594,7 @@ int shmem_size_params::get_smem_footprint() {
     tree_aggregator_t::smem_accumulate_footprint(
       num_classes) +
     cols_shmem_size();
-
-  int footprint = std::max(accumulate_footprint, finalize_footprint);
-  int max_shm_std = 48 * 1024;  // 48 KiB available on any architecture
-  if (footprint > max_shm_std) {
-    // for no cols_in_shmem, it is a matter of supporting this config at all
-    set_carveouts(footprint);
-    // for cols_in_shmem, it will accelerate performance
-    set_carveouts(footprint);
-    // This much may not suffice, in which case set_carveouts will do nothing.
-  }
-  return footprint;
+  return std::max(accumulate_footprint, finalize_footprint);
 }
 
 template
@@ -659,30 +635,29 @@ void shmem_size_params::compute_smem_footprint() {
 template
 void infer_k_nitems_launcher(storage_type forest, predict_params params,
                              cudaStream_t stream, int block_dim_x) {
+  void (*kernel)(storage_type, predict_params);
   switch (params.n_items) {
     case 1:
-      infer_k<1, leaf_algo, cols_in_shmem>
-        <<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
-                                                                    params);
+      kernel = infer_k<1, leaf_algo, cols_in_shmem>;
       break;
    case 2:
-      infer_k<2, leaf_algo, cols_in_shmem>
-        <<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
-                                                                    params);
+      kernel = infer_k<2, leaf_algo, cols_in_shmem>;
      break;
    case 3:
-      infer_k<3, leaf_algo, cols_in_shmem>
-        <<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
-                                                                    params);
+      kernel = infer_k<3, leaf_algo, cols_in_shmem>;
      break;
    case 4:
-      infer_k<4, leaf_algo, cols_in_shmem>
-        <<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
-                                                                    params);
+      kernel = infer_k<4, leaf_algo, cols_in_shmem>;
      break;
    default:
      ASSERT(false, "internal error: nitems > 4");
  }
+  // Two forests might be using the same handle, so
+  // large batch will run fastest if we set just before launching.
+  // This will not cause a race condition between setting and launching.
+  set_carveout((void*)kernel, params.shm_sz, forest.max_shm());
+  kernel<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
+                                                                    params);
   CUDA_CHECK(cudaPeekAtLastError());
 }

From 26480b07b6f6e468d7fce6c92f9e4da189df5927 Mon Sep 17 00:00:00 2001
From: Levs Dolgovs
Date: Tue, 25 May 2021 19:03:40 -0700
Subject: [PATCH 03/30] set carveout and occupancy-affecting preferred cache config before every inference

---
 cpp/src/fil/common.cuh |  2 ++
 cpp/src/fil/fil.cu     | 20 +++++++-------------
 cpp/src/fil/infer.cu   | 49 +++++++++++++++++++++++-------------------------
 3 files changed, 36 insertions(+), 35 deletions(-)

diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh
index 3a487f050a..61ebb048e2 100644
--- a/cpp/src/fil/common.cuh
+++ b/cpp/src/fil/common.cuh
@@ -123,6 +123,8 @@ struct shmem_size_params {
   bool cols_in_shmem = true;
   /// n_items is the most items per thread that fit into shared memory
   int n_items = 0;
+  /// max_shm is the maximum opt-in shared memory on the device
+  int max_shm = 0;
   /// shm_sz is the associated shared memory footprint
   int shm_sz = INT_MAX;
 
diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu
index aa626eb3eb..a48fa34953 100644
--- a/cpp/src/fil/fil.cu
+++ b/cpp/src/fil/fil.cu
@@ -76,11 +76,6 @@ __global__ void transform_k(float* preds, size_t n, output_t output,
 
 struct forest {
   void init_n_items(int device) {
-    /// the most shared memory a kernel can request on the GPU in question
-    int max_shm = 0;
-    CUDA_CHECK(cudaDeviceGetAttribute(
-      &max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
-
     // searching for the most items per block while respecting the shared
     // memory limits creates a full linear programming problem.
     // solving it in a single equation looks less tractable than this
@@ -94,10 +89,10 @@ struct forest {
          ssp.n_items <= (algo_ == algo_t::BATCH_TREE_REORG ? 4 : 1);
          ++ssp.n_items) {
       ssp.compute_smem_footprint();
-      if (ssp.shm_sz <= max_shm) ssp_ = ssp;
+      if (ssp.shm_sz <= ssp.max_shm) ssp_ = ssp;
     }
   }
-  ASSERT(max_shm >= ssp_.shm_sz,
+  ASSERT(ssp_.max_shm >= ssp_.shm_sz,
          "FIL out of shared memory. Perhaps the maximum number of \n"
         "supported classes is exceeded? 5'000 would still be safe.");
   }
 
@@ -116,12 +111,13 @@ struct forest {
     fixed_block_count_ = blocks_per_sm * sm_count;
   }
 
-  void init_max_shm(int device) {
-    CUDA_CHECK(cudaDeviceGetAttribute(
-      &max_shm_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
-  }
+  void init_max_shm(int device) {}
 
   void init_common(const raft::handle_t& h, const forest_params_t* params) {
+    int device = h.get_device();
+    CUDA_CHECK(cudaDeviceGetAttribute(
+      &proba_ssp_.max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
+
     depth_ = params->depth;
     num_trees_ = params->num_trees;
     algo_ = params->algo;
@@ -133,10 +129,8 @@ struct forest {
     proba_ssp_.num_classes = params->num_classes;
     class_ssp_ = proba_ssp_;
 
-    int device = h.get_device();
     init_n_items(device);  // n_items takes priority over blocks_per_sm
     init_fixed_block_count(device, params->blocks_per_sm);
-    init_max_shm(device);
   }
 
   virtual void infer(predict_params params, cudaStream_t stream) = 0;
diff --git a/cpp/src/fil/infer.cu b/cpp/src/fil/infer.cu
index 2c401a865e..d4015fe00c 100644
--- a/cpp/src/fil/infer.cu
+++ b/cpp/src/fil/infer.cu
@@ -16,6 +16,7 @@
 #include
 #include
+#include <mutex>
 #include
 #include
@@ -24,6 +25,8 @@
 namespace ML {
 namespace fil {
 
+std::mutex shmem_carveout_mutex;
+
 // vec wraps float[N] for cub::BlockReduce
 template
 struct vec;
@@ -575,11 +578,13 @@ __global__ void infer_k(storage_type forest, predict_params params) {
   }
 }
 
 void set_carveout(void* kernel, int footprint, int max_shm) {
-  CUDA_CHECK(
+  // ensure optimal occupancy in case default allows less blocks/SM
+  CUDA_CHECK_NO_THROW(
     cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout,
                          // footprint in % of max_shm, rounding up
                          (100 * footprint + max_shm - 1) / max_shm));
-  CUDA_CHECK(cudaFuncSetAttribute(
+  // even if the footprint < 48'000, ensure that we reset after previous forest
+  CUDA_CHECK_NO_THROW(cudaFuncSetAttribute(
     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, footprint));
 }
@@ -633,30 +638,30 @@ void shmem_size_params::compute_smem_footprint() {
 template
 void infer_k_nitems_launcher(storage_type forest, predict_params params,
                              cudaStream_t stream, int block_dim_x) {
-  void (*kernel)(storage_type, predict_params);
-  switch (params.n_items) {
-    case 1:
-      kernel = infer_k<1, leaf_algo, cols_in_shmem>;
-      break;
-    case 2:
-      kernel = infer_k<2, leaf_algo, cols_in_shmem>;
-      break;
-    case 3:
-      kernel = infer_k<3, leaf_algo, cols_in_shmem>;
-      break;
-    case 4:
-      kernel = infer_k<4, leaf_algo, cols_in_shmem>;
-      break;
-    default:
-      ASSERT(false, "internal error: nitems > 4");
-  }
+  void (*kernels[])(storage_type, predict_params) = {
+    nullptr,
+    infer_k<1, leaf_algo, cols_in_shmem, storage_type>,
+    infer_k<2, leaf_algo, cols_in_shmem, storage_type>,
+    infer_k<3, leaf_algo, cols_in_shmem, storage_type>,
+    infer_k<4, leaf_algo, cols_in_shmem, storage_type>,
+  };
+  ASSERT(params.n_items <= 4, "internal error: nitems > 4");
+  void (*kernel)(storage_type, predict_params) = kernels[params.n_items];
   // Two forests might be using the same handle, so
   // large batch will run fastest if we set just before launching.
-  // This will not cause a race condition between setting and launching.
+  // This will not cause a race condition between setting and launching despite
+  // CPU-GPU asynchronicity.
+  shmem_carveout_mutex.lock();
-  set_carveout((void*)kernel, params.shm_sz, forest.max_shm());
+  set_carveout((void*)kernel, params.shm_sz, params.max_shm);
   kernel<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
                                                                     params);
-  CUDA_CHECK(cudaPeekAtLastError());
+  CUDA_CHECK_NO_THROW(cudaPeekAtLastError());
+  shmem_carveout_mutex.unlock();  // a CUDA error should not hang other threads
+  if (cudaPeekAtLastError() != cudaSuccess) {
+    // a wrong thread might throw, it's OK
+    throw raft::cuda_error(
+      "CUDA error in ML::fil::predict() (see stdout for details)");
+  }
 }

From a037f169da84d08cf8475430f726177d1d24d427 Mon Sep 17 00:00:00 2001
From: Levs Dolgovs
Date: Tue, 25 May 2021 19:13:01 -0700
Subject: [PATCH 04/30] other review comments

---
 cpp/src/fil/common.cuh  |  6 +++---
 cpp/src/fil/infer.cu    | 12 ++++++------
 cpp/test/sg/fil_test.cu |  3 ++-
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh
index 61ebb048e2..53ebe7be30 100644
--- a/cpp/src/fil/common.cuh
+++ b/cpp/src/fil/common.cuh
@@ -126,16 +126,16 @@ struct shmem_size_params {
   /// max_shm is the maximum opt-in shared memory on the device
   int max_shm = 0;
   /// shm_sz is the associated shared memory footprint
-  int shm_sz = INT_MAX;
+  size_t shm_sz = INT_MAX;
 
   __host__ __device__ size_t cols_shmem_size() {
     return cols_in_shmem ? sizeof(float) * num_cols * n_items : 0;
   }
   void compute_smem_footprint();
   template
-  int get_smem_footprint();
+  size_t get_smem_footprint();
   template
-  int get_smem_footprint();
+  size_t get_smem_footprint();
 };
 
 // predict_params are parameters for prediction
diff --git a/cpp/src/fil/infer.cu b/cpp/src/fil/infer.cu
index d4015fe00c..5ef32f19b5 100644
--- a/cpp/src/fil/infer.cu
+++ b/cpp/src/fil/infer.cu
@@ -577,23 +577,23 @@ __global__ void infer_k(storage_type forest, predict_params params) {
   }
 }
 
-void set_carveout(void* kernel, int footprint, int max_shm) {
+void set_carveout(void* kernel, size_t footprint, int max_shm) {
   // ensure optimal occupancy in case default allows less blocks/SM
   CUDA_CHECK_NO_THROW(
     cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout,
                          // footprint in % of max_shm, rounding up
                          (100 * footprint + max_shm - 1) / max_shm));
-  // even if the footprint < 48'000, ensure that we reset after previous forest
+  // even if the footprint < 48 * 1024, ensure that we reset after previous forest
   CUDA_CHECK_NO_THROW(cudaFuncSetAttribute(
     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, footprint));
 }
 
 template
-int shmem_size_params::get_smem_footprint() {
-  int finalize_footprint =
+size_t shmem_size_params::get_smem_footprint() {
+  size_t finalize_footprint =
     tree_aggregator_t::smem_finalize_footprint(
       cols_shmem_size(), num_classes, predict_proba);
-  int accumulate_footprint =
+  size_t accumulate_footprint =
     tree_aggregator_t::smem_accumulate_footprint(
       num_classes) +
     cols_shmem_size();
@@ -601,7 +601,7 @@ int shmem_size_params::get_smem_footprint() {
 }
 
 template
-int shmem_size_params::get_smem_footprint() {
+size_t shmem_size_params::get_smem_footprint() {
   switch (leaf_algo) {
     case FLOAT_UNARY_BINARY:
       return get_smem_footprint();
diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu
index 8d8b1a41c9..63c2eaeb32 100644
--- a/cpp/test/sg/fil_test.cu
+++ b/cpp/test/sg/fil_test.cu
@@ -763,7 +763,8 @@ std::vector predict_dense_inputs = {
                   algo = BATCH_TREE_REORG, leaf_algo = CATEGORICAL_LEAF,
                   num_classes = 3),
   // use shared memory opt-in carveout if available, or infer out of L1 cache
-  FIL_TEST_PARAMS(num_cols = ((48 + 1) * 1024) / sizeof(float), algo = NAIVE),
+  FIL_TEST_PARAMS(num_rows = 103, num_cols = ((48 + 1) * 1024) / sizeof(float),
+                  algo = NAIVE),
 };
 
 TEST_P(PredictDenseFilTest, Predict) { compare(); }

From 2a1d622105aa094f9614bba1c7a585a4e890f1b7 Mon Sep 17 00:00:00 2001
From: Levs Dolgovs
Date: Fri, 11 Jun 2021 23:30:59 -0700
Subject: [PATCH 05/30] DRY: rewrote in terms of dispatch_on_FIL_template_params(predict_params, ...)

---
 cpp/src/fil/common.cuh   |  89 ++++++++++++++++++++++++++-
 cpp/src/fil/fil.cu       |  36 ++++++++++-
 cpp/src/fil/infer.cu     | 127 +++++----------------------------------
 cpp/src/fil/internal.cuh |   6 ++
 4 files changed, 140 insertions(+), 118 deletions(-)

diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh
index 53ebe7be30..65e66055ba 100644
--- a/cpp/src/fil/common.cuh
+++ b/cpp/src/fil/common.cuh
@@ -125,15 +125,14 @@ struct shmem_size_params {
   int n_items = 0;
   /// max_shm is the maximum opt-in shared memory on the device
   int max_shm = 0;
+  // blockdim_x is the CUDA block size
+  int blockdim_x = 0;
   /// shm_sz is the associated shared memory footprint
   size_t shm_sz = INT_MAX;
 
   __host__ __device__ size_t cols_shmem_size() {
     return cols_in_shmem ? sizeof(float) * num_cols * n_items : 0;
   }
-  void compute_smem_footprint();
-  template
-  size_t get_smem_footprint();
   template
   size_t get_smem_footprint();
 };
 
 // predict_params are parameters for prediction
@@ -158,6 +157,90 @@ struct predict_params : shmem_size_params {
   int num_blocks;
 };
 
+namespace dispatch {
+
+template <