-
Notifications
You must be signed in to change notification settings - Fork 527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unify template parameter dispatch for FIL inference and shared memory footprint estimation #4013
Changes from 7 commits
53e0683
31b885f
589f9ad
26480b0
a037f16
2a1d622
bd3c505
5cf38d3
7959f4d
258e674
e0f53ea
ace36c0
98509b0
bcd05bc
36b27f1
06650f8
4dd4f8a
1931928
ca473db
971b19a
12dfaba
d8505df
1a0014d
5f2458f
1b87aa0
5a03998
293ec2a
dbcaa7b
798171f
7e78203
6ae0914
0978fc0
46eb819
4614ca2
9528148
0b6cea8
a5a6ca6
2b4555f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,16 +130,16 @@ struct shmem_size_params { | |
/// are the input columns are prefetched into shared | ||
/// memory before inferring the row in question | ||
bool cols_in_shmem = true; | ||
// are there categorical inner nodes? doesnt' currently affect shared memory size, | ||
// are there categorical inner nodes? doesn't currently affect shared memory size, | ||
// but participates in template dispatch and may affect it later | ||
bool cats_present; | ||
bool cats_present = false; | ||
/// log2_threads_per_tree determines how many threads work on a single tree | ||
/// at once inside a block (sharing trees means splitting input rows) | ||
int log2_threads_per_tree = 0; | ||
/// n_items is how many input samples (items) any thread processes. If 0 is given, | ||
/// choose the reasonable most (<=4) that fit into shared memory. See init_n_items() | ||
/// choose the reasonable most (<= MAX_N_ITEMS) that fit into shared memory. See init_n_items() | ||
int n_items = 0; | ||
// block_dim_x is the CUDA block size | ||
// block_dim_x is the CUDA block size. Set by dispatch_on_leaf_algo(...) | ||
int block_dim_x = 0; | ||
levsnv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// shm_sz is the associated shared memory footprint | ||
int shm_sz = INT_MAX; | ||
|
@@ -177,105 +177,95 @@ struct predict_params : shmem_size_params { | |
int num_blocks; | ||
}; | ||
|
||
template <bool COLS_IN_SHMEM = false, | ||
bool CATS_SUPPORTED = false, | ||
int LEAF_ALGO = 0, | ||
int N_ITEMS = 1> | ||
struct KernelTemplateParameters { | ||
static const bool cols_in_shmem = COLS_IN_SHMEM; | ||
static const bool cats_supported = CATS_SUPPORTED; | ||
static const leaf_algo_t leaf_algo = static_cast<leaf_algo_t>(LEAF_ALGO); | ||
static const int n_items = N_ITEMS; | ||
template <bool COLS_IN_SHMEM_ = false, | ||
bool CATS_SUPPORTED_ = false, | ||
int LEAF_ALGO_ = 0, | ||
levsnv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
int N_ITEMS_ = 1> | ||
struct KernelTemplateParams { | ||
static const bool COLS_IN_SHMEM = COLS_IN_SHMEM_; | ||
static const bool CATS_SUPPORTED = CATS_SUPPORTED_; | ||
static const leaf_algo_t LEAF_ALGO = static_cast<leaf_algo_t>(LEAF_ALGO_); | ||
static const int N_ITEMS = N_ITEMS_; | ||
|
||
template <bool _cats_supported> | ||
using replace_cats_supported = | ||
KernelTemplateParameters<cols_in_shmem, _cats_supported, leaf_algo, n_items>; | ||
using inc_leaf_algo = | ||
KernelTemplateParameters<cols_in_shmem, cats_supported, leaf_algo + 1, n_items>; | ||
using ReplaceCatsSupported = | ||
KernelTemplateParams<COLS_IN_SHMEM, _cats_supported, LEAF_ALGO, N_ITEMS>; | ||
using NextLeafAlgo = KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, LEAF_ALGO + 1, N_ITEMS>; | ||
levsnv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
template <int _leaf_algo> | ||
levsnv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
using replace_leaf_algo = | ||
KernelTemplateParameters<cols_in_shmem, cats_supported, _leaf_algo, n_items>; | ||
using inc_n_items = | ||
KernelTemplateParameters<cols_in_shmem, cats_supported, leaf_algo, n_items + 1>; | ||
using ReplaceLeafAlgo = KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, _leaf_algo, N_ITEMS>; | ||
using IncNItems = KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, LEAF_ALGO, N_ITEMS + 1>; | ||
}; | ||
|
||
namespace dispatch { | ||
|
||
template <class KernelParams, class Func> | ||
void dispatch_on_n_items(Func func, predict_params& params) | ||
auto dispatch_on_n_items(Func func, predict_params params) -> decltype(func.run(params)) | ||
levsnv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
if (params.n_items == KernelParams::n_items) { | ||
func.template run<KernelParams>(params); | ||
} else if constexpr (KernelParams::n_items < 4) { | ||
dispatch_on_n_items<class KernelParams::inc_n_items>(func, params); | ||
if (params.n_items == KernelParams::N_ITEMS) { | ||
return func.template run<KernelParams>(params); | ||
} else if constexpr (KernelParams::N_ITEMS < MAX_N_ITEMS) { | ||
return dispatch_on_n_items<class KernelParams::IncNItems>(func, params); | ||
} else { | ||
ASSERT(false, "internal error: n_items > 4 or < 1"); | ||
ASSERT(false, "internal error: n_items > %d or < 1", MAX_N_ITEMS); | ||
} | ||
return func.run(params); // appeasing the compiler | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Is there a better way of appeasing the compiler?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I could remove [inline code lost in extraction]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Do we need [inline code lost in extraction]? In any case, even with [inline code lost in extraction; remainder of comment truncated] |
||
} | ||
|
||
template <class KernelParams, class Func> | ||
void dispatch_on_leaf_algo(Func func, predict_params& params) | ||
auto dispatch_on_leaf_algo(Func func, predict_params params) -> decltype(func.run(params)) | ||
{ | ||
if (params.leaf_algo == KernelParams::leaf_algo) { | ||
if constexpr (KernelParams::leaf_algo == GROVE_PER_CLASS) { | ||
if (params.leaf_algo == KernelParams::LEAF_ALGO) { | ||
if constexpr (KernelParams::LEAF_ALGO == GROVE_PER_CLASS) { | ||
if (params.num_classes <= FIL_TPB) { | ||
params.block_dim_x = FIL_TPB - FIL_TPB % params.num_classes; | ||
using Next = typename KernelParams::replace_leaf_algo<GROVE_PER_CLASS_FEW_CLASSES>; | ||
dispatch_on_n_items<Next>(func, params); | ||
using Next = typename KernelParams::ReplaceLeafAlgo<GROVE_PER_CLASS_FEW_CLASSES>; | ||
return dispatch_on_n_items<Next>(func, params); | ||
} else { | ||
params.block_dim_x = FIL_TPB; | ||
using Next = typename KernelParams::replace_leaf_algo<GROVE_PER_CLASS_MANY_CLASSES>; | ||
dispatch_on_n_items<Next>(func, params); | ||
using Next = typename KernelParams::ReplaceLeafAlgo<GROVE_PER_CLASS_MANY_CLASSES>; | ||
return dispatch_on_n_items<Next>(func, params); | ||
} | ||
} else { | ||
params.block_dim_x = FIL_TPB; | ||
dispatch_on_n_items<KernelParams>(func, params); | ||
return dispatch_on_n_items<KernelParams>(func, params); | ||
} | ||
} else if constexpr (KernelParams::leaf_algo + 1 < static_cast<int>(LEAF_ALGO_INVALID)) { | ||
dispatch_on_leaf_algo<class KernelParams::inc_leaf_algo>(func, params); | ||
} else if constexpr (KernelParams::NextLeafAlgo::LEAF_ALGO < LEAF_ALGO_INVALID) { | ||
levsnv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return dispatch_on_leaf_algo<class KernelParams::NextLeafAlgo>(func, params); | ||
} else { | ||
ASSERT(false, "internal error: dispatch: invalid leaf_algo %d", params.leaf_algo); | ||
} | ||
return func.run(params); // appeasing the compiler | ||
levsnv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
template <class KernelParams, class Func> | ||
void dispatch_on_cats_supported(Func func, predict_params& params) | ||
auto dispatch_on_cats_supported(Func func, predict_params params) -> decltype(func.run(params)) | ||
{ | ||
if (params.cats_present) | ||
dispatch_on_leaf_algo<typename KernelParams::replace_cats_supported<true>>(func, params); | ||
else | ||
dispatch_on_leaf_algo<typename KernelParams::replace_cats_supported<false>>(func, params); | ||
return params.cats_present | ||
? dispatch_on_leaf_algo<typename KernelParams::ReplaceCatsSupported<true>>(func, params) | ||
: dispatch_on_leaf_algo<typename KernelParams::ReplaceCatsSupported<false>>(func, | ||
params); | ||
} | ||
|
||
template <class Func> | ||
void dispatch_on_cols_in_shmem(Func func, predict_params& params) | ||
auto dispatch_on_cols_in_shmem(Func func, predict_params params) -> decltype(func.run(params)) | ||
{ | ||
if (params.cols_in_shmem) | ||
dispatch_on_cats_supported<KernelTemplateParameters<true>>(func, params); | ||
else | ||
dispatch_on_cats_supported<KernelTemplateParameters<false>>(func, params); | ||
return params.cols_in_shmem | ||
? dispatch_on_cats_supported<KernelTemplateParams<true>>(func, params) | ||
: dispatch_on_cats_supported<KernelTemplateParams<false>>(func, params); | ||
} | ||
|
||
} // namespace dispatch | ||
|
||
template <class Func> | ||
void dispatch_on_fil_template_params(Func func, predict_params& params) | ||
auto dispatch_on_fil_template_params(Func func, predict_params params) -> decltype(func.run(params)) | ||
{ | ||
dispatch::dispatch_on_cols_in_shmem(func, params); | ||
return dispatch::dispatch_on_cols_in_shmem(func, params); | ||
} | ||
|
||
/* For an example of Func, see this: | ||
* | ||
* We need to instantiate all get_smem_footprint instantiations in infer.cu. | ||
* The only guarantee is by instantiating | ||
* dispatch_on_FIL_template<compute_smem_footprint... in infer.cu. This | ||
* requires a declaration of this struct with the declaration of the `run` template | ||
* (i.e. all but one line) visible from infer.cu, as well as this full | ||
* definition visible from fil.cu. We'll just define it in common.cuh. | ||
*/ | ||
// For an example of Func, see this: | ||
struct compute_smem_footprint { | ||
template <class KernelParams> | ||
void run(predict_params& ssp); | ||
template <class KernelParams = KernelTemplateParams<>> | ||
int run(predict_params ssp); | ||
}; | ||
|
||
// infer() calls the inference kernel with the parameters on the stream | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it needed externally? Or can it be moved to `internal.cuh` or `common.cuh`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd keep it external, to make it easy to check the maximum number of `n_items` that one can provide in C++ (and the Python API doesn't check such ranges).
Actually, the `n_items` assert is not an internal error, it's an API error. I fixed it.