Unify template parameter dispatch for FIL inference and shared memory footprint estimation #4013

Merged: 38 commits, merged on Oct 27, 2021
Changes shown below are from 7 commits

Commits (38)
53e0683
try 1
levsnv Apr 17, 2021
31b885f
draft of set-and-launch
levsnv May 25, 2021
589f9ad
Merge remote-tracking branch 'rapidsai/branch-21.06' into extra-share…
levsnv May 25, 2021
26480b0
set carveout and occupancy-affecting preferred cache config before ev…
levsnv May 26, 2021
a037f16
other review comments
levsnv May 26, 2021
2a1d622
DRY: rewrote in terms of dispatch_on_FIL_template_params<func, storag…
levsnv Jun 12, 2021
bd3c505
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into extra-sh…
levsnv Jun 12, 2021
5cf38d3
style, clean up diff
levsnv Jun 12, 2021
7959f4d
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into extra-sh…
levsnv Jun 14, 2021
258e674
fixed bugs and linker issues
levsnv Jun 15, 2021
e0f53ea
removed unnecessary specialization in dispatch
levsnv Jun 15, 2021
ace36c0
simplified code to template-based dispatch
levsnv Jun 24, 2021
98509b0
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into dispatch
levsnv Jun 24, 2021
bcd05bc
reverted max_shm changes
levsnv Jun 24, 2021
36b27f1
simplified templates, fixed bug
levsnv Jun 24, 2021
06650f8
remove unnecessary forward declarations
levsnv Jun 25, 2021
4dd4f8a
halfway change
levsnv Jun 29, 2021
1931928
recording template lambda-based attempt
levsnv Jun 29, 2021
ca473db
Revert "recording template lambda-based attempt"
levsnv Jun 29, 2021
971b19a
wrapped template params into a struct
levsnv Jul 2, 2021
12dfaba
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into dispatch
levsnv Jul 2, 2021
d8505df
moved runtime args to constructor, separated ::run(...), added defaul…
levsnv Jul 3, 2021
1a0014d
Merge remote-tracking branch 'rapidsai/branch-21.10' into dispatch
levsnv Sep 29, 2021
5f2458f
KernelParams::inc_*
levsnv Sep 29, 2021
1b87aa0
dispatch_on_cats_present
levsnv Sep 30, 2021
5a03998
fix several issues
levsnv Oct 7, 2021
293ec2a
Merge branch 'branch-21.12' of github.com:rapidsai/cuml into dispatch
levsnv Oct 7, 2021
dbcaa7b
extern template void dispatch_on_fil_template_params(compute_smem_foo…
levsnv Oct 8, 2021
798171f
change-by-reference into accept-and-return-by-value
levsnv Oct 9, 2021
7e78203
variable renames
levsnv Oct 9, 2021
6ae0914
unnecessary changes
levsnv Oct 9, 2021
0978fc0
finish case adjustments
levsnv Oct 9, 2021
46eb819
NextLeafAlgo
levsnv Oct 9, 2021
4614ca2
MAX_N_ITEMS
levsnv Oct 9, 2021
9528148
stray changes
levsnv Oct 9, 2021
0b6cea8
removed decltype
levsnv Oct 20, 2021
a5a6ca6
next_leaf_algo, LEAF_ALGO_INVALID->MAX_LEAF_ALGO, ...
levsnv Oct 20, 2021
2b4555f
misc
levsnv Oct 20, 2021
4 changes: 3 additions & 1 deletion cpp/include/cuml/fil/fil.h
@@ -72,6 +72,8 @@ struct forest;
/** forest_t is the predictor handle */
typedef forest* forest_t;

constexpr int MAX_N_ITEMS = 4;
Comment (Contributor):
Is it needed externally? Or can it be moved to internal.cuh or common.cuh?

Reply (levsnv, Contributor Author):
I'd keep it external so the maximum n_items a caller can provide is easy to check from C++ (the Python API doesn't check such ranges).
Actually, the n_items assert is not an internal error, it's an API error. I fixed it.
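
(Editorial aside, not part of the PR: the kind of caller-side range check the reply above has in mind, which keeping MAX_N_ITEMS in the public header makes possible. The helper name, the include set, and the use of std::invalid_argument are illustrative assumptions; namespace qualification of MAX_N_ITEMS is omitted.)

#include <stdexcept>

// Validate the user-supplied n_items before filling treelite_params_t.
// 0 asks FIL to pick n_items automatically; any other value must lie in [1, MAX_N_ITEMS].
void check_n_items_range(int n_items)
{
  if (n_items != 0 && (n_items < 1 || n_items > MAX_N_ITEMS)) {
    throw std::invalid_argument("n_items must be 0 or in [1, MAX_N_ITEMS]");
  }
}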


/** treelite_params_t are parameters for importing treelite models */
struct treelite_params_t {
// algo is the inference algorithm
@@ -94,7 +96,7 @@ struct treelite_params_t {
// can only be a power of 2
int threads_per_tree;
// n_items is how many input samples (items) any thread processes. If 0 is given,
// choose most (up to 4) that fit into shared memory.
// choose most (up to MAX_N_ITEMS) that fit into shared memory.
int n_items;
// if non-nullptr, *pforest_shape_str will be set to caller-owned string that
// contains forest shape
110 changes: 50 additions & 60 deletions cpp/src/fil/common.cuh
@@ -130,16 +130,16 @@ struct shmem_size_params {
/// are the input columns are prefetched into shared
/// memory before inferring the row in question
bool cols_in_shmem = true;
// are there categorical inner nodes? doesnt' currently affect shared memory size,
// are there categorical inner nodes? doesn't currently affect shared memory size,
// but participates in template dispatch and may affect it later
bool cats_present;
bool cats_present = false;
/// log2_threads_per_tree determines how many threads work on a single tree
/// at once inside a block (sharing trees means splitting input rows)
int log2_threads_per_tree = 0;
/// n_items is how many input samples (items) any thread processes. If 0 is given,
/// choose the reasonable most (<=4) that fit into shared memory. See init_n_items()
/// choose the reasonable most (<= MAX_N_ITEMS) that fit into shared memory. See init_n_items()
int n_items = 0;
// block_dim_x is the CUDA block size
// block_dim_x is the CUDA block size. Set by dispatch_on_leaf_algo(...)
int block_dim_x = 0;
/// shm_sz is the associated shared memory footprint
int shm_sz = INT_MAX;
@@ -177,105 +177,95 @@ struct predict_params : shmem_size_params {
int num_blocks;
};

template <bool COLS_IN_SHMEM = false,
bool CATS_SUPPORTED = false,
int LEAF_ALGO = 0,
int N_ITEMS = 1>
struct KernelTemplateParameters {
static const bool cols_in_shmem = COLS_IN_SHMEM;
static const bool cats_supported = CATS_SUPPORTED;
static const leaf_algo_t leaf_algo = static_cast<leaf_algo_t>(LEAF_ALGO);
static const int n_items = N_ITEMS;
template <bool COLS_IN_SHMEM_ = false,
bool CATS_SUPPORTED_ = false,
int LEAF_ALGO_ = 0,
int N_ITEMS_ = 1>
struct KernelTemplateParams {
static const bool COLS_IN_SHMEM = COLS_IN_SHMEM_;
static const bool CATS_SUPPORTED = CATS_SUPPORTED_;
static const leaf_algo_t LEAF_ALGO = static_cast<leaf_algo_t>(LEAF_ALGO_);
static const int N_ITEMS = N_ITEMS_;

template <bool _cats_supported>
using replace_cats_supported =
KernelTemplateParameters<cols_in_shmem, _cats_supported, leaf_algo, n_items>;
using inc_leaf_algo =
KernelTemplateParameters<cols_in_shmem, cats_supported, leaf_algo + 1, n_items>;
using ReplaceCatsSupported =
KernelTemplateParams<COLS_IN_SHMEM, _cats_supported, LEAF_ALGO, N_ITEMS>;
using NextLeafAlgo = KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, LEAF_ALGO + 1, N_ITEMS>;
template <int _leaf_algo>
using replace_leaf_algo =
KernelTemplateParameters<cols_in_shmem, cats_supported, _leaf_algo, n_items>;
using inc_n_items =
KernelTemplateParameters<cols_in_shmem, cats_supported, leaf_algo, n_items + 1>;
using ReplaceLeafAlgo = KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, _leaf_algo, N_ITEMS>;
using IncNItems = KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, LEAF_ALGO, N_ITEMS + 1>;
};

namespace dispatch {

template <class KernelParams, class Func>
void dispatch_on_n_items(Func func, predict_params& params)
auto dispatch_on_n_items(Func func, predict_params params) -> decltype(func.run(params))
{
if (params.n_items == KernelParams::n_items) {
func.template run<KernelParams>(params);
} else if constexpr (KernelParams::n_items < 4) {
dispatch_on_n_items<class KernelParams::inc_n_items>(func, params);
if (params.n_items == KernelParams::N_ITEMS) {
return func.template run<KernelParams>(params);
} else if constexpr (KernelParams::N_ITEMS < MAX_N_ITEMS) {
return dispatch_on_n_items<class KernelParams::IncNItems>(func, params);
} else {
ASSERT(false, "internal error: n_items > 4 or < 1");
ASSERT(false, "internal error: n_items > %d or < 1", MAX_N_ITEMS);
}
return func.run(params); // appeasing the compiler
Comment (Contributor):
Is there a better way of appeasing the compiler?

Reply (levsnv, Contributor Author, Oct 15, 2021):
I could remove the elses, since the branches return, and flip the if constexpr condition. But I thought you prefer else after return?
Alternatively, return dispatch_on_n_items<KernelParams>(func, params) might work.
Lastly, perhaps template <class KernelParams, class Func, class T = decltype(declval<Func>().run(params))> might resolve both this and the signature decltype question?

Reply (canonizer, Contributor, Oct 19, 2021):
Do we need T to be in the template parameters? I guess yes if we want to use it in the return type, and no if we only need it in the function body.

In any case, even with T in the template parameters, it looks better if return T() or return T{} works. Perhaps you could add a documentation comment as to why you're doing this.
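
(Editorial sketch, not part of the diff: one way to combine the options discussed above — drop the else after the early return, and let the unreachable ASSERT branch value-initialize the result instead of calling func.run() purely to appease the compiler. It reuses the names and ASSERT macro from the surrounding code; whether this reads better than the merged version is the style question debated here.)

template <class KernelParams, class Func>
auto dispatch_on_n_items(Func func, predict_params params) -> decltype(func.run(params))
{
  if (params.n_items == KernelParams::N_ITEMS) { return func.template run<KernelParams>(params); }
  if constexpr (KernelParams::N_ITEMS < MAX_N_ITEMS) {
    return dispatch_on_n_items<typename KernelParams::IncNItems>(func, params);
  } else {
    ASSERT(false, "internal error: n_items > %d or < 1", MAX_N_ITEMS);
    // unreachable; value-initializing the deduced result type also works when it is void
    return decltype(func.run(params))();
  }
}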

}

template <class KernelParams, class Func>
void dispatch_on_leaf_algo(Func func, predict_params& params)
auto dispatch_on_leaf_algo(Func func, predict_params params) -> decltype(func.run(params))
{
if (params.leaf_algo == KernelParams::leaf_algo) {
if constexpr (KernelParams::leaf_algo == GROVE_PER_CLASS) {
if (params.leaf_algo == KernelParams::LEAF_ALGO) {
if constexpr (KernelParams::LEAF_ALGO == GROVE_PER_CLASS) {
if (params.num_classes <= FIL_TPB) {
params.block_dim_x = FIL_TPB - FIL_TPB % params.num_classes;
using Next = typename KernelParams::replace_leaf_algo<GROVE_PER_CLASS_FEW_CLASSES>;
dispatch_on_n_items<Next>(func, params);
using Next = typename KernelParams::ReplaceLeafAlgo<GROVE_PER_CLASS_FEW_CLASSES>;
return dispatch_on_n_items<Next>(func, params);
} else {
params.block_dim_x = FIL_TPB;
using Next = typename KernelParams::replace_leaf_algo<GROVE_PER_CLASS_MANY_CLASSES>;
dispatch_on_n_items<Next>(func, params);
using Next = typename KernelParams::ReplaceLeafAlgo<GROVE_PER_CLASS_MANY_CLASSES>;
return dispatch_on_n_items<Next>(func, params);
}
} else {
params.block_dim_x = FIL_TPB;
dispatch_on_n_items<KernelParams>(func, params);
return dispatch_on_n_items<KernelParams>(func, params);
}
} else if constexpr (KernelParams::leaf_algo + 1 < static_cast<int>(LEAF_ALGO_INVALID)) {
dispatch_on_leaf_algo<class KernelParams::inc_leaf_algo>(func, params);
} else if constexpr (KernelParams::NextLeafAlgo::LEAF_ALGO < LEAF_ALGO_INVALID) {
return dispatch_on_leaf_algo<class KernelParams::NextLeafAlgo>(func, params);
} else {
ASSERT(false, "internal error: dispatch: invalid leaf_algo %d", params.leaf_algo);
}
return func.run(params); // appeasing the compiler
}

template <class KernelParams, class Func>
void dispatch_on_cats_supported(Func func, predict_params& params)
auto dispatch_on_cats_supported(Func func, predict_params params) -> decltype(func.run(params))
{
if (params.cats_present)
dispatch_on_leaf_algo<typename KernelParams::replace_cats_supported<true>>(func, params);
else
dispatch_on_leaf_algo<typename KernelParams::replace_cats_supported<false>>(func, params);
return params.cats_present
? dispatch_on_leaf_algo<typename KernelParams::ReplaceCatsSupported<true>>(func, params)
: dispatch_on_leaf_algo<typename KernelParams::ReplaceCatsSupported<false>>(func,
params);
}

template <class Func>
void dispatch_on_cols_in_shmem(Func func, predict_params& params)
auto dispatch_on_cols_in_shmem(Func func, predict_params params) -> decltype(func.run(params))
{
if (params.cols_in_shmem)
dispatch_on_cats_supported<KernelTemplateParameters<true>>(func, params);
else
dispatch_on_cats_supported<KernelTemplateParameters<false>>(func, params);
return params.cols_in_shmem
? dispatch_on_cats_supported<KernelTemplateParams<true>>(func, params)
: dispatch_on_cats_supported<KernelTemplateParams<false>>(func, params);
}

} // namespace dispatch

template <class Func>
void dispatch_on_fil_template_params(Func func, predict_params& params)
auto dispatch_on_fil_template_params(Func func, predict_params params) -> decltype(func.run(params))
{
dispatch::dispatch_on_cols_in_shmem(func, params);
return dispatch::dispatch_on_cols_in_shmem(func, params);
}

/* For an example of Func, see this:
*
* We need to instantiate all get_smem_footprint instantiations in infer.cu.
* The only guarantee is by instantiating
* dispatch_on_FIL_template<compute_smem_footprint... in infer.cu. This
* requires a declaration of this struct with the declaration of the `run` template
* (i.e. all but one line) visible from infer.cu, as well as this full
* definition visible from fil.cu. We'll just define it in common.cuh.
*/
// For an example of Func, see this:
struct compute_smem_footprint {
template <class KernelParams>
void run(predict_params& ssp);
template <class KernelParams = KernelTemplateParams<>>
int run(predict_params ssp);
};

// infer() calls the inference kernel with the parameters on the stream
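
(Editorial aside, not part of the diff: the dispatch chain above is an instance of the usual "runtime value to compile-time constant" pattern — compare the runtime field against the current template constant and recurse with the next candidate on a mismatch, so only the matching instantiation of Func::run is ever called. A self-contained reduction to a single parameter, with all names hypothetical:)

#include <cassert>

constexpr int MAX_N_ITEMS_DEMO = 4;

// Try N_ITEMS = 1..MAX_N_ITEMS_DEMO until the runtime value matches, then call
// run<N_ITEMS>() so the value is available as a template (compile-time) argument.
template <int N_ITEMS = 1, class Func>
auto dispatch_n_items_demo(Func func, int n_items)
{
  if (n_items == N_ITEMS) {
    return func.template run<N_ITEMS>();
  } else if constexpr (N_ITEMS < MAX_N_ITEMS_DEMO) {
    return dispatch_n_items_demo<N_ITEMS + 1>(func, n_items);
  } else {
    assert(!"n_items out of range");
    return func.template run<N_ITEMS>();  // unreachable; keeps every path returning a value
  }
}

struct footprint_demo {
  template <int N_ITEMS>
  int run() const { return 64 * N_ITEMS; }  // pretend per-item shared-memory bytes
};

// Example: dispatch_n_items_demo(footprint_demo{}, 3) calls run<3>() and yields 192.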
25 changes: 12 additions & 13 deletions cpp/src/fil/fil.cu
@@ -98,9 +98,9 @@ __global__ void transform_k(float* preds,
}

// needed to avoid expanding the dispatch template into unresolved
// compute_smem_footprint::run<KernelParams>() calls. In infer.cu, we don't export those symbols,
// compute_smem_footprint::run() calls. In infer.cu, we don't export those symbols,
// but rather one symbol for the whole template specialization, as below.
extern template void dispatch_on_fil_template_params(compute_smem_footprint, predict_params&);
extern template int dispatch_on_fil_template_params(compute_smem_footprint, predict_params);
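
(Editorial aside, not part of the diff: the extern template declaration above is the standard explicit-instantiation split. A minimal hypothetical reduction, collapsed into one listing with comments marking which translation unit each piece corresponds to; all *_demo names are made up:)

// common.cuh analogue: full definition of the dispatch template, visible everywhere
template <class Func>
int dispatch_tu_demo(Func func, int params) { return func.run(params); }

struct footprint_tu_demo {
  int run(int params);  // declared here; the body lives only in the infer.cu analogue
};

// fil.cu analogue: use the dispatcher, but suppress implicit instantiation so the
// symbol is taken from the explicit instantiation below (in real code, from infer.cu)
extern template int dispatch_tu_demo(footprint_tu_demo, int);
int estimate_demo(int params) { return dispatch_tu_demo(footprint_tu_demo{}, params); }

// infer.cu analogue: define run() and emit the single explicit instantiation
int footprint_tu_demo::run(int params) { return 2 * params; }
template int dispatch_tu_demo(footprint_tu_demo, int);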

struct forest {
forest(const raft::handle_t& h) : vector_leaf_(0, h.get_stream()), cat_sets_(h.get_stream()) {}
@@ -130,16 +130,15 @@ struct forest {
shmem_size_params& ssp_ = predict_proba ? proba_ssp_ : class_ssp_;
ssp_.predict_proba = predict_proba;
shmem_size_params ssp = ssp_;
// if n_items was not provided, try from 1 to 4. Otherwise, use as-is.
// if n_items was not provided, try from 1 to MAX_N_ITEMS. Otherwise, use as-is.
int min_n_items = ssp.n_items == 0 ? 1 : ssp.n_items;
int max_n_items =
ssp.n_items == 0 ? (algo_ == algo_t::BATCH_TREE_REORG ? 4 : 1) : ssp.n_items;
ssp.n_items == 0 ? (algo_ == algo_t::BATCH_TREE_REORG ? MAX_N_ITEMS : 1) : ssp.n_items;
for (bool cols_in_shmem : {false, true}) {
ssp.cols_in_shmem = cols_in_shmem;
for (ssp.n_items = min_n_items; ssp.n_items <= max_n_items; ++ssp.n_items) {
predict_params pp = ssp;
dispatch_on_fil_template_params(compute_smem_footprint(), pp);
if (pp.shm_sz < max_shm) ssp_ = pp;
ssp.shm_sz = dispatch_on_fil_template_params(compute_smem_footprint(), ssp);
if (ssp.shm_sz < max_shm) ssp_ = ssp;
}
}
ASSERT(max_shm >= ssp_.shm_sz,
@@ -163,11 +162,6 @@ struct forest {
const std::vector<float>& vector_leaf,
const forest_params_t* params)
{
int device = h.get_device();
cudaStream_t stream = h.get_stream();
// categorical features
cat_sets_ = cat_sets_device_owner(cat_sets, stream);

depth_ = params->depth;
num_trees_ = params->num_trees;
algo_ = params->algo;
@@ -179,9 +173,11 @@ struct forest {
proba_ssp_.leaf_algo = params->leaf_algo;
proba_ssp_.num_cols = params->num_cols;
proba_ssp_.num_classes = params->num_classes;
proba_ssp_.cats_present = cat_sets_.accessor().cats_present();
proba_ssp_.cats_present = cat_sets.cats_present();
class_ssp_ = proba_ssp_;

int device = h.get_device();
cudaStream_t stream = h.get_stream();
init_n_items(device); // n_items takes priority over blocks_per_sm
init_fixed_block_count(device, params->blocks_per_sm);

@@ -195,6 +191,9 @@ struct forest {
cudaMemcpyHostToDevice,
stream));
}

// categorical features
cat_sets_ = cat_sets_device_owner(cat_sets, stream);
}

virtual void infer(predict_params params, cudaStream_t stream) = 0;
27 changes: 12 additions & 15 deletions cpp/src/fil/infer.cu
@@ -36,7 +36,7 @@
#endif // __CUDA_ARCH__
#endif // CUDA_PRAGMA_UNROLL

#define INLINE_CONFIG __noinline__
#define INLINE_CONFIG __forceinline__

namespace ML {
namespace fil {
@@ -786,7 +786,7 @@ __device__ INLINE_CONFIG void load_data(float* sdata,
template <int NITEMS,
leaf_algo_t leaf_algo,
bool cols_in_shmem,
bool cats_supported,
bool CATS_SUPPORTED,
class storage_type>
__global__ void infer_k(storage_type forest, predict_params params)
{
@@ -823,7 +823,7 @@ __global__ void infer_k(storage_type forest, predict_params params)
typedef typename leaf_output_t<leaf_algo>::T pred_t;
vec<NITEMS, pred_t> prediction;
if (tree < forest.num_trees() && thread_num_rows != 0) {
prediction = infer_one_tree<NITEMS, cats_supported, pred_t>(
prediction = infer_one_tree<NITEMS, CATS_SUPPORTED, pred_t>(
forest[tree],
cols_in_shmem ? sdata + thread_row0 * sdata_stride : block_input + thread_row0 * num_cols,
cols_in_shmem ? sdata_stride : num_cols,
@@ -855,32 +855,29 @@ size_t shmem_size_params::get_smem_footprint()
}

template <class KernelParams>
void compute_smem_footprint::run(predict_params& ssp)
int compute_smem_footprint::run(predict_params ssp)
{
// need GROVE_PER_CLASS_*_CLASSES
if constexpr (KernelParams::leaf_algo != GROVE_PER_CLASS) {
ssp.shm_sz = ssp.template get_smem_footprint<KernelParams::n_items, KernelParams::leaf_algo>();
}
return ssp.template get_smem_footprint<KernelParams::N_ITEMS, KernelParams::LEAF_ALGO>();
}

// make sure to instantiate all possible get_smem_footprint instantiations
template void dispatch_on_fil_template_params(compute_smem_footprint, predict_params&);
template int dispatch_on_fil_template_params(compute_smem_footprint, predict_params);

template <typename storage_type>
struct infer_k_storage_template {
storage_type forest;
cudaStream_t stream;

template <class KernelParams>
void run(predict_params& params)
template <class KernelParams = KernelTemplateParams<>>
void run(predict_params params)
{
params.num_blocks = params.num_blocks != 0
? params.num_blocks
: raft::ceildiv(int(params.num_rows), params.n_items);
infer_k<KernelParams::n_items,
KernelParams::leaf_algo,
KernelParams::cols_in_shmem,
KernelParams::cats_supported>
infer_k<KernelParams::N_ITEMS,
KernelParams::LEAF_ALGO,
KernelParams::COLS_IN_SHMEM,
KernelParams::CATS_SUPPORTED>
<<<params.num_blocks, params.block_dim_x, params.shm_sz, stream>>>(forest, params);
CUDA_CHECK(cudaPeekAtLastError());
}
2 changes: 1 addition & 1 deletion cpp/src/fil/internal.cuh
@@ -301,7 +301,7 @@ struct forest_params_t {
// at once inside a block (sharing trees means splitting input rows)
int threads_per_tree;
// n_items is how many input samples (items) any thread processes. If 0 is given,
// choose most (up to 4) that fit into shared memory.
// choose most (up to MAX_N_ITEMS) that fit into shared memory.
int n_items;
};
