Unify template parameter dispatch for FIL inference and shared memory footprint estimation #4013

Merged
Merged 38 commits on Oct 27, 2021
Changes from 16 commits
Commits
38 commits
53e0683
try 1
levsnv Apr 17, 2021
31b885f
draft of set-and-launch
levsnv May 25, 2021
589f9ad
Merge remote-tracking branch 'rapidsai/branch-21.06' into extra-share…
levsnv May 25, 2021
26480b0
set carveout and occupancy-affecting preferred cache config before ev…
levsnv May 26, 2021
a037f16
other review comments
levsnv May 26, 2021
2a1d622
DRY: rewrote in terms of dispatch_on_FIL_template_params<func, storag…
levsnv Jun 12, 2021
bd3c505
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into extra-sh…
levsnv Jun 12, 2021
5cf38d3
style, clean up diff
levsnv Jun 12, 2021
7959f4d
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into extra-sh…
levsnv Jun 14, 2021
258e674
fixed bugs and linker issues
levsnv Jun 15, 2021
e0f53ea
removed unnecessary specialization in dispatch
levsnv Jun 15, 2021
ace36c0
simplified code to template-based dispatch
levsnv Jun 24, 2021
98509b0
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into dispatch
levsnv Jun 24, 2021
bcd05bc
reverted max_shm changes
levsnv Jun 24, 2021
36b27f1
simplified templates, fixed bug
levsnv Jun 24, 2021
06650f8
remove unnecessary forward declarations
levsnv Jun 25, 2021
4dd4f8a
halfway change
levsnv Jun 29, 2021
1931928
recording template lambda-based attempt
levsnv Jun 29, 2021
ca473db
Revert "recording template lambda-based attempt"
levsnv Jun 29, 2021
971b19a
wrapped template params into a struct
levsnv Jul 2, 2021
12dfaba
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into dispatch
levsnv Jul 2, 2021
d8505df
moved runtime args to constructor, separated ::run(...), added defaul…
levsnv Jul 3, 2021
1a0014d
Merge remote-tracking branch 'rapidsai/branch-21.10' into dispatch
levsnv Sep 29, 2021
5f2458f
KernelParams::inc_*
levsnv Sep 29, 2021
1b87aa0
dispatch_on_cats_present
levsnv Sep 30, 2021
5a03998
fix several issues
levsnv Oct 7, 2021
293ec2a
Merge branch 'branch-21.12' of github.com:rapidsai/cuml into dispatch
levsnv Oct 7, 2021
dbcaa7b
extern template void dispatch_on_fil_template_params(compute_smem_foo…
levsnv Oct 8, 2021
798171f
change-by-reference into accept-and-return-by-value
levsnv Oct 9, 2021
7e78203
variable renames
levsnv Oct 9, 2021
6ae0914
unnecessary changes
levsnv Oct 9, 2021
0978fc0
finish case adjustments
levsnv Oct 9, 2021
46eb819
NextLeafAlgo
levsnv Oct 9, 2021
4614ca2
MAX_N_ITEMS
levsnv Oct 9, 2021
9528148
stray changes
levsnv Oct 9, 2021
0b6cea8
removed decltype
levsnv Oct 20, 2021
a5a6ca6
next_leaf_algo, LEAF_ALGO_INVALID->MAX_LEAF_ALGO, ...
levsnv Oct 20, 2021
2b4555f
misc
levsnv Oct 20, 2021
88 changes: 85 additions & 3 deletions cpp/src/fil/common.cuh
@@ -135,6 +135,8 @@ struct shmem_size_params {
/// n_items is how many input samples (items) any thread processes. If 0 is given,
/// choose the largest reasonable value (<= 4) that fits into shared memory. See init_n_items()
int n_items = 0;
// blockdim_x is the CUDA block size
int blockdim_x = 0;
/// shm_sz is the associated shared memory footprint
int shm_sz = INT_MAX;

@@ -146,9 +148,6 @@
? sizeof(float) * sdata_stride() * n_items << log2_threads_per_tree
: 0;
}
void compute_smem_footprint();
template <int NITEMS>
size_t get_smem_footprint();
template <int NITEMS, leaf_algo_t leaf_algo>
size_t get_smem_footprint();
};
@@ -173,6 +172,89 @@ struct predict_params : shmem_size_params {
int num_blocks;
};

namespace dispatch {

template <template <bool, leaf_algo_t, int> class Func, typename storage_type,
bool cols_in_shmem, leaf_algo_t leaf_algo, int n_items,
typename... Args>
void dispatch_on_n_items(predict_params& params, Args... args) {
ASSERT(params.n_items <= 4, "internal error: n_items > 4");
if (params.n_items == n_items) {
Func<cols_in_shmem, leaf_algo, n_items>::template run<storage_type>(
params, args...);
} else if constexpr (n_items < 4) {
dispatch_on_n_items<Func, storage_type, cols_in_shmem, leaf_algo,
n_items + 1>(params, args...);
}
}

template <template <bool, leaf_algo_t, int> class Func, typename storage_type,
bool cols_in_shmem, typename... Args>
void dispatch_on_leaf_algo(predict_params& params, Args... args) {
switch (params.leaf_algo) {
case FLOAT_UNARY_BINARY:
params.blockdim_x = FIL_TPB;
dispatch_on_n_items<Func, storage_type, cols_in_shmem, FLOAT_UNARY_BINARY,
1>(params, args...);
break;
case GROVE_PER_CLASS:
if (params.num_classes > FIL_TPB) {
params.blockdim_x = FIL_TPB;
dispatch_on_n_items<Func, storage_type, cols_in_shmem,
GROVE_PER_CLASS_MANY_CLASSES, 1>(params, args...);
} else {
params.blockdim_x = FIL_TPB - FIL_TPB % params.num_classes;
dispatch_on_n_items<Func, storage_type, cols_in_shmem,
GROVE_PER_CLASS_FEW_CLASSES, 1>(params, args...);
}
break;
case CATEGORICAL_LEAF:
params.blockdim_x = FIL_TPB;
dispatch_on_n_items<Func, storage_type, cols_in_shmem, CATEGORICAL_LEAF,
1>(params, args...);
break;
case VECTOR_LEAF:
params.blockdim_x = FIL_TPB;
dispatch_on_n_items<Func, storage_type, cols_in_shmem, VECTOR_LEAF, 1>(
params, args...);
break;
default:
ASSERT(false, "internal error: dispatch: invalid leaf_algo %d",
params.leaf_algo);
}
}

template <template <bool, leaf_algo_t, int> class Func, typename storage_type,
typename... Args>
void dispatch_on_cols_in_shmem(predict_params& params, Args... args) {
if (params.cols_in_shmem)
dispatch_on_leaf_algo<Func, storage_type, true>(params, args...);
else
dispatch_on_leaf_algo<Func, storage_type, false>(params, args...);
}

} // namespace dispatch

template <template <bool, leaf_algo_t, int> class Func, typename storage_type,
typename... Args>
void dispatch_on_fil_template_params(predict_params& params, Args... args) {
dispatch::dispatch_on_cols_in_shmem<Func, storage_type>(params, args...);
}

// All get_smem_footprint instantiations must be generated in infer.cu, where those
// methods are defined. The only way to guarantee that is to explicitly instantiate
// dispatch_on_fil_template_params<compute_smem_footprint, ...> in infer.cu. That, in
// turn, requires this struct (with at least its run method declared) to be visible
// wherever the dispatch is called, and its full definition to be visible where the
// explicit instantiation occurs. We simply define it here in common.cuh.
template <bool cols_in_shmem, leaf_algo_t leaf_algo, int n_items>
struct compute_smem_footprint {
template <typename storage_type>
static void run(predict_params& ssp) {
ssp.shm_sz = ssp.get_smem_footprint<n_items, leaf_algo>();
}
};

// infer() calls the inference kernel with the parameters on the stream
template <typename storage_type>
void infer(storage_type forest, predict_params params, cudaStream_t stream);
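
To illustrate the interface this dispatcher expects (a sketch, not code from this PR; log_chosen_params is a hypothetical name): any Func is a class template over <cols_in_shmem, leaf_algo, n_items> exposing a static run<storage_type>() method, exactly like compute_smem_footprint above. A minimal functor that merely reports which combination the dispatch chain selected might look like this, assuming the declarations from common.cuh and <cstdio>:

// hypothetical functor, for illustration only
template <bool cols_in_shmem, leaf_algo_t leaf_algo, int n_items>
struct log_chosen_params {
  template <typename storage_type>
  static void run(predict_params& params)
  {
    // cols_in_shmem, leaf_algo and n_items are compile-time constants here,
    // while blockdim_x was filled in at runtime by dispatch_on_leaf_algo
    printf("cols_in_shmem=%d leaf_algo=%d n_items=%d blockdim_x=%d\n",
           int(cols_in_shmem), int(leaf_algo), n_items, params.blockdim_x);
  }
};

// usage: dispatch_on_fil_template_params<log_chosen_params, dense_storage>(params);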
11 changes: 8 additions & 3 deletions cpp/src/fil/fil.cu
@@ -74,6 +74,9 @@ __global__ void transform_k(float* preds, size_t n, output_t output,
preds[i] = result;
}

extern template void dispatch_on_fil_template_params<
compute_smem_footprint, dense_storage>(predict_params&);

struct forest {
void init_n_items(int device) {
int max_shm_std = 48 * 1024; // 48 KiB
@@ -99,7 +102,7 @@ struct forest {
for (bool predict_proba : {false, true}) {
shmem_size_params& ssp_ = predict_proba ? proba_ssp_ : class_ssp_;
ssp_.predict_proba = predict_proba;
shmem_size_params ssp = ssp_;
predict_params ssp = ssp_;
// if n_items was not provided, try from 1 to 4. Otherwise, use as-is.
int min_n_items = ssp.n_items == 0 ? 1 : ssp.n_items;
int max_n_items = ssp.n_items == 0
@@ -109,7 +112,8 @@
ssp.cols_in_shmem = cols_in_shmem;
for (ssp.n_items = min_n_items; ssp.n_items <= max_n_items;
++ssp.n_items) {
ssp.compute_smem_footprint();
dispatch_on_fil_template_params<compute_smem_footprint,
dense_storage>(ssp);
if (ssp.shm_sz < max_shm) ssp_ = ssp;
}
}
@@ -276,7 +280,8 @@ struct forest {
global_bias != 0.0f;
break;
default:
ASSERT(false, "internal error: invalid leaf_algo_");
ASSERT(false, "internal error: predict: invalid leaf_algo %d",
params.leaf_algo);
}
} else {
if (params.leaf_algo == leaf_algo_t::FLOAT_UNARY_BINARY) {
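
The extern template declaration in fil.cu above, paired with the explicit instantiation in infer.cu below, is the standard C++ split between an explicit instantiation declaration and its definition: fil.cu can call the dispatcher without compiling its body, and the single instantiation is emitted in infer.cu, where the get_smem_footprint bodies live. A generic sketch of the pattern with hypothetical names (Footprint, dispatch) rather than the FIL types:

// dispatch.h — template definition, included by both translation units
struct Footprint { void run(); };  // stand-in for compute_smem_footprint + predict_params
template <typename T>
void dispatch(T& t) { t.run(); }

// caller.cpp (fil.cu's role) — uses the template; the extern declaration
// suppresses implicit instantiation, so the body is not compiled here
extern template void dispatch<Footprint>(Footprint&);
void use(Footprint& fp) { dispatch(fp); }

// provider.cpp (infer.cu's role) — the one explicit instantiation the linker resolves to
template void dispatch<Footprint>(Footprint&);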
124 changes: 19 additions & 105 deletions cpp/src/fil/infer.cu
@@ -750,118 +750,32 @@ size_t shmem_size_params::get_smem_footprint() {
tree_aggregator_t<NITEMS, leaf_algo>::smem_accumulate_footprint(
num_classes) +
cols_shmem_size();

return std::max(accumulate_footprint, finalize_footprint);
}

template <int NITEMS>
size_t shmem_size_params::get_smem_footprint() {
switch (leaf_algo) {
case FLOAT_UNARY_BINARY:
return get_smem_footprint<NITEMS, FLOAT_UNARY_BINARY>();
case CATEGORICAL_LEAF:
return get_smem_footprint<NITEMS, CATEGORICAL_LEAF>();
case GROVE_PER_CLASS:
if (num_classes > FIL_TPB)
return get_smem_footprint<NITEMS, GROVE_PER_CLASS_MANY_CLASSES>();
return get_smem_footprint<NITEMS, GROVE_PER_CLASS_FEW_CLASSES>();
case VECTOR_LEAF:
return get_smem_footprint<NITEMS, VECTOR_LEAF>();
default:
ASSERT(false, "internal error: unexpected leaf_algo_t");
// make sure to instantiate all possible get_smem_footprint instantiations
template void dispatch_on_fil_template_params<compute_smem_footprint,
dense_storage>(predict_params&);

template <bool cols_in_shmem, leaf_algo_t leaf_algo, int n_items>
struct infer_k_launcher {
template <typename storage_type>
static void run(predict_params& params, storage_type forest,
cudaStream_t stream) {
params.num_blocks = params.num_blocks != 0
? params.num_blocks
: raft::ceildiv(int(params.num_rows), params.n_items);
infer_k<n_items, leaf_algo, cols_in_shmem, storage_type>
<<<params.num_blocks, params.blockdim_x, params.shm_sz, stream>>>(forest,
params);
CUDA_CHECK(cudaPeekAtLastError());
}
}

void shmem_size_params::compute_smem_footprint() {
switch (n_items) {
case 1:
shm_sz = get_smem_footprint<1>();
break;
case 2:
shm_sz = get_smem_footprint<2>();
break;
case 3:
shm_sz = get_smem_footprint<3>();
break;
case 4:
shm_sz = get_smem_footprint<4>();
break;
default:
ASSERT(false, "internal error: n_items > 4");
}
}

template <leaf_algo_t leaf_algo, bool cols_in_shmem, typename storage_type>
void infer_k_nitems_launcher(storage_type forest, predict_params params,
cudaStream_t stream, int block_dim_x) {
switch (params.n_items) {
case 1:
infer_k<1, leaf_algo, cols_in_shmem>
<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
params);
break;
case 2:
infer_k<2, leaf_algo, cols_in_shmem>
<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
params);
break;
case 3:
infer_k<3, leaf_algo, cols_in_shmem>
<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
params);
break;
case 4:
infer_k<4, leaf_algo, cols_in_shmem>
<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest,
params);
break;
default:
ASSERT(false, "internal error: nitems > 4");
}
CUDA_CHECK(cudaPeekAtLastError());
}

template <leaf_algo_t leaf_algo, typename storage_type>
void infer_k_launcher(storage_type forest, predict_params params,
cudaStream_t stream, int blockdim_x) {
params.num_blocks = params.num_blocks != 0
? params.num_blocks
: raft::ceildiv(int(params.num_rows), params.n_items);
if (params.cols_in_shmem) {
infer_k_nitems_launcher<leaf_algo, true>(forest, params, stream,
blockdim_x);
} else {
infer_k_nitems_launcher<leaf_algo, false>(forest, params, stream,
blockdim_x);
}
}
};

template <typename storage_type>
void infer(storage_type forest, predict_params params, cudaStream_t stream) {
switch (params.leaf_algo) {
case FLOAT_UNARY_BINARY:
infer_k_launcher<FLOAT_UNARY_BINARY>(forest, params, stream, FIL_TPB);
break;
case GROVE_PER_CLASS:
if (params.num_classes > FIL_TPB) {
params.leaf_algo = GROVE_PER_CLASS_MANY_CLASSES;
infer_k_launcher<GROVE_PER_CLASS_MANY_CLASSES>(forest, params, stream,
FIL_TPB);
} else {
params.leaf_algo = GROVE_PER_CLASS_FEW_CLASSES;
infer_k_launcher<GROVE_PER_CLASS_FEW_CLASSES>(
forest, params, stream, FIL_TPB - FIL_TPB % params.num_classes);
}
break;
case CATEGORICAL_LEAF:
infer_k_launcher<CATEGORICAL_LEAF>(forest, params, stream, FIL_TPB);
break;
case VECTOR_LEAF:
infer_k_launcher<VECTOR_LEAF>(forest, params, stream, FIL_TPB);
break;
default:
ASSERT(false, "internal error: invalid leaf_algo");
}
dispatch_on_fil_template_params<infer_k_launcher, storage_type>(
params, forest, stream);
}

template void infer<dense_storage>(dense_storage forest, predict_params params,
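
The core trick in dispatch_on_n_items is turning a runtime n_items in [1, 4] into a template argument by a linear recursion bounded with if constexpr, so only four specializations are ever instantiated. A self-contained sketch of the same technique, with hypothetical names and a generic lambda instead of the struct-based Func the PR settled on (the commit history shows a lambda-based attempt was tried and reverted):

#include <type_traits>

// try each candidate N in turn; `if constexpr (N < 4)` bounds the recursion, so
// exactly the specializations for N = 1..4 are instantiated
template <int N = 1, typename F>
void with_constexpr_n_items(int n_items, F f)
{
  if (n_items == N) {
    f(std::integral_constant<int, N>{});        // hand N over as a compile-time value
  } else if constexpr (N < 4) {
    with_constexpr_n_items<N + 1>(n_items, f);  // keep searching
  }
  // n_items outside [1, 4] silently does nothing here; the real code ASSERTs instead
}

// usage (launch_kernel is hypothetical):
//   with_constexpr_n_items(n_items, [](auto n) { launch_kernel<decltype(n)::value>(); });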