Unify template parameter dispatch for FIL inference and shared memory footprint estimation (#4013)

Authors:
  - Levs Dolgovs (https://github.com/levsnv)

Approvers:
  - Andy Adinets (https://github.com/canonizer)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #4013
levsnv authored Oct 27, 2021
1 parent 4775124 commit 92a40e0
Showing 5 changed files with 158 additions and 104 deletions.
5 changes: 4 additions & 1 deletion cpp/include/cuml/fil/fil.h
@@ -72,6 +72,9 @@ struct forest;
/** forest_t is the predictor handle */
typedef forest* forest_t;

/** MAX_N_ITEMS determines the maximum allowed value for tl_params::n_items */
constexpr int MAX_N_ITEMS = 4;

/** treelite_params_t are parameters for importing treelite models */
struct treelite_params_t {
// algo is the inference algorithm
@@ -94,7 +97,7 @@ struct treelite_params_t {
// can only be a power of 2
int threads_per_tree;
// n_items is how many input samples (items) any thread processes. If 0 is given,
// choose most (up to 4) that fit into shared memory.
// choose most (up to MAX_N_ITEMS) that fit into shared memory.
int n_items;
// if non-nullptr, *pforest_shape_str will be set to caller-owned string that
// contains forest shape
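As a usage illustration, a caller can leave n_items at 0 and let FIL pick the value itself. A minimal hedged sketch (assuming the usual ML::fil namespace; fields elided from this hunk are left at their zero-initialized defaults):

// Hypothetical caller-side setup; only fields visible in this hunk are touched.
ML::fil::treelite_params_t tl_params{};  // zero-initialize everything
tl_params.threads_per_tree  = 2;         // must be a power of 2
tl_params.n_items           = 0;         // 0 => pick up to MAX_N_ITEMS that fit in shared memory
tl_params.pforest_shape_str = nullptr;   // don't request a forest shape string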
118 changes: 114 additions & 4 deletions cpp/src/fil/common.cuh
@@ -130,12 +130,17 @@ struct shmem_size_params {
/// are the input columns prefetched into shared
/// memory before inferring the row in question
bool cols_in_shmem = true;
// are there categorical inner nodes? doesn't currently affect shared memory size,
// but participates in template dispatch and may affect it later
bool cats_present = false;
/// log2_threads_per_tree determines how many threads work on a single tree
/// at once inside a block (sharing trees means splitting input rows)
int log2_threads_per_tree = 0;
/// n_items is how many input samples (items) any thread processes. If 0 is given,
/// choose the most (up to 4) that reasonably fit into shared memory. See init_n_items()
/// choose the most (up to MAX_N_ITEMS) that reasonably fit into shared memory. See init_n_items()
int n_items = 0;
// block_dim_x is the CUDA block size. Set by dispatch_on_leaf_algo(...)
int block_dim_x = 0;
/// shm_sz is the associated shared memory footprint
int shm_sz = INT_MAX;

@@ -147,9 +152,6 @@
{
return cols_in_shmem ? sizeof(float) * sdata_stride() * n_items << log2_threads_per_tree : 0;
}
void compute_smem_footprint();
template <int NITEMS>
size_t get_smem_footprint();
template <int NITEMS, leaf_algo_t leaf_algo>
size_t get_smem_footprint();
};
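To make cols_shmem_size() concrete, a hedged worked example (all numbers hypothetical; sdata_stride() is roughly the padded column count and is defined outside this hunk):

// Hypothetical: sdata_stride() == 50, n_items == 4, log2_threads_per_tree == 1
// cols_shmem_size() == sizeof(float) * 50 * 4 << 1 == 800 << 1 == 1600 bytes
// (<< binds last, so the product doubles once per log2_threads_per_tree step)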
@@ -175,6 +177,114 @@ struct predict_params : shmem_size_params {
int num_blocks;
};

constexpr leaf_algo_t next_leaf_algo(leaf_algo_t algo)
{
return static_cast<leaf_algo_t>(algo + 1);
}

template <bool COLS_IN_SHMEM_ = false,
bool CATS_SUPPORTED_ = false,
leaf_algo_t LEAF_ALGO_ = MIN_LEAF_ALGO,
int N_ITEMS_ = 1>
struct KernelTemplateParams {
static const bool COLS_IN_SHMEM = COLS_IN_SHMEM_;
static const bool CATS_SUPPORTED = CATS_SUPPORTED_;
static const leaf_algo_t LEAF_ALGO = LEAF_ALGO_;
static const int N_ITEMS = N_ITEMS_;

template <bool _cats_supported>
using ReplaceCatsSupported =
KernelTemplateParams<COLS_IN_SHMEM, _cats_supported, LEAF_ALGO, N_ITEMS>;
using NextLeafAlgo =
KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, next_leaf_algo(LEAF_ALGO), N_ITEMS>;
template <leaf_algo_t NEW_LEAF_ALGO>
using ReplaceLeafAlgo =
KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, NEW_LEAF_ALGO, N_ITEMS>;
using IncNItems = KernelTemplateParams<COLS_IN_SHMEM, CATS_SUPPORTED, LEAF_ALGO, N_ITEMS + 1>;
};
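A hedged, standalone illustration of this type-level algebra (not part of the change itself; values follow the defaults above):

using P0 = KernelTemplateParams<>;          // <false, false, MIN_LEAF_ALGO, 1>
using P1 = P0::ReplaceCatsSupported<true>;  // flips CATS_SUPPORTED only
using P2 = P1::IncNItems;                   // N_ITEMS: 1 -> 2
static_assert(P2::CATS_SUPPORTED && P2::N_ITEMS == 2, "only the named field changes");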

// inherit from this struct to pass a functor to dispatch_on_fil_template_params();
// the compiler will prevent defining a .run() method with a different return type
template <typename T>
struct dispatch_functor {
typedef T return_t;
template <class KernelParams = KernelTemplateParams<>>
T run(predict_params);
};

namespace dispatch {

template <class KernelParams, class Func, class T = typename Func::return_t>
T dispatch_on_n_items(Func func, predict_params params)
{
if (params.n_items == KernelParams::N_ITEMS) {
return func.template run<KernelParams>(params);
} else if constexpr (KernelParams::N_ITEMS < MAX_N_ITEMS) {
    return dispatch_on_n_items<typename KernelParams::IncNItems>(func, params);
} else {
ASSERT(false, "n_items > %d or < 1", MAX_N_ITEMS);
}
return T(); // appeasing the compiler
}
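dispatch_on_n_items() is an instance of a general trick: turning a runtime integer into a compile-time constant by recursing on a template parameter under if constexpr. A self-contained sketch of the pattern (illustrative only, not FIL code):

#include <cstdio>

template <int N, int MAX = 4>
void run_with_constant(int n)
{
  if (n == N) {
    std::printf("selected compile-time N = %d\n", N);  // N is usable as a template argument here
  } else if constexpr (N < MAX) {
    run_with_constant<N + 1, MAX>(n);  // try the next candidate
  }
}

// run_with_constant<1>(3);  // lands in the N == 3 instantiation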

template <class KernelParams, class Func, class T = typename Func::return_t>
T dispatch_on_leaf_algo(Func func, predict_params params)
{
if (params.leaf_algo == KernelParams::LEAF_ALGO) {
if constexpr (KernelParams::LEAF_ALGO == GROVE_PER_CLASS) {
if (params.num_classes <= FIL_TPB) {
params.block_dim_x = FIL_TPB - FIL_TPB % params.num_classes;
      using Next = typename KernelParams::template ReplaceLeafAlgo<GROVE_PER_CLASS_FEW_CLASSES>;
return dispatch_on_n_items<Next>(func, params);
} else {
params.block_dim_x = FIL_TPB;
      using Next = typename KernelParams::template ReplaceLeafAlgo<GROVE_PER_CLASS_MANY_CLASSES>;
return dispatch_on_n_items<Next>(func, params);
}
} else {
params.block_dim_x = FIL_TPB;
return dispatch_on_n_items<KernelParams>(func, params);
}
} else if constexpr (next_leaf_algo(KernelParams::LEAF_ALGO) <= MAX_LEAF_ALGO) {
    return dispatch_on_leaf_algo<typename KernelParams::NextLeafAlgo>(func, params);
} else {
ASSERT(false, "internal error: dispatch: invalid leaf_algo %d", params.leaf_algo);
}
return T(); // appeasing the compiler
}
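The rounding above keeps the block size a multiple of num_classes so each class owns an equal share of the block's threads. A worked example with hypothetical values (assuming FIL_TPB == 256, the usual FIL thread-block size):

// num_classes == 5  (<= FIL_TPB): block_dim_x = 256 - 256 % 5 = 255, i.e. 51 threads per class
// num_classes == 300 (> FIL_TPB): GROVE_PER_CLASS_MANY_CLASSES branch, block_dim_x = 256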

template <class KernelParams, class Func, class T = typename Func::return_t>
T dispatch_on_cats_supported(Func func, predict_params params)
{
  return params.cats_present
           ? dispatch_on_leaf_algo<typename KernelParams::template ReplaceCatsSupported<true>>(
               func, params)
           : dispatch_on_leaf_algo<typename KernelParams::template ReplaceCatsSupported<false>>(
               func, params);
}

template <class Func, class T = typename Func::return_t>
T dispatch_on_cols_in_shmem(Func func, predict_params params)
{
return params.cols_in_shmem
? dispatch_on_cats_supported<KernelTemplateParams<true>>(func, params)
: dispatch_on_cats_supported<KernelTemplateParams<false>>(func, params);
}

} // namespace dispatch

template <class Func, class T = typename Func::return_t>
T dispatch_on_fil_template_params(Func func, predict_params params)
{
return dispatch::dispatch_on_cols_in_shmem(func, params);
}

// For an example of a Func declaration, see compute_smem_footprint below;
// its .run(predict_params) method is defined in infer.cu.
struct compute_smem_footprint : dispatch_functor<int> {
template <class KernelParams = KernelTemplateParams<>>
int run(predict_params);
};
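A hedged usage sketch of how such a functor flows through the dispatch chain (probe_n_items is a hypothetical functor, not FIL code):

// Hypothetical functor: reports the compile-time N_ITEMS the chain selected.
struct probe_n_items : dispatch_functor<int> {
  template <class KernelParams = KernelTemplateParams<>>
  int run(predict_params)
  {
    return KernelParams::N_ITEMS;
  }
};

// predict_params p = ...;  // p.n_items = 3, p.leaf_algo and p.cats_present set validly
// int n = dispatch_on_fil_template_params(probe_n_items(), p);  // n == 3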

// infer() calls the inference kernel with the parameters on the stream
template <typename storage_type>
void infer(storage_type forest, predict_params params, cudaStream_t stream);
14 changes: 10 additions & 4 deletions cpp/src/fil/fil.cu
@@ -97,6 +97,11 @@ __global__ void transform_k(float* preds,
preds[i] = result;
}

// needed to avoid expanding the dispatch template into unresolved
// compute_smem_footprint::run() calls. In infer.cu, we don't export those symbols,
// but rather one symbol for the whole template specialization, as below.
extern template int dispatch_on_fil_template_params(compute_smem_footprint, predict_params);
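This is the standard explicit-instantiation split; a minimal standalone sketch of the same mechanism (illustrative names, not FIL code):

// visible template definition (header)
template <typename T> T twice(T x) { return x + x; }

// consuming translation unit: promise the instantiation exists elsewhere, emit nothing here
extern template int twice<int>(int);

// exactly one translation unit (the analogue of infer.cu) provides the definition:
template int twice<int>(int);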

struct forest {
forest(const raft::handle_t& h) : vector_leaf_(0, h.get_stream()), cat_sets_(h.get_stream()) {}

@@ -125,14 +130,14 @@ struct forest {
shmem_size_params& ssp_ = predict_proba ? proba_ssp_ : class_ssp_;
ssp_.predict_proba = predict_proba;
shmem_size_params ssp = ssp_;
// if n_items was not provided, try from 1 to 4. Otherwise, use as-is.
// if n_items was not provided, try from 1 to MAX_N_ITEMS. Otherwise, use as-is.
int min_n_items = ssp.n_items == 0 ? 1 : ssp.n_items;
int max_n_items =
ssp.n_items == 0 ? (algo_ == algo_t::BATCH_TREE_REORG ? 4 : 1) : ssp.n_items;
ssp.n_items == 0 ? (algo_ == algo_t::BATCH_TREE_REORG ? MAX_N_ITEMS : 1) : ssp.n_items;
for (bool cols_in_shmem : {false, true}) {
ssp.cols_in_shmem = cols_in_shmem;
for (ssp.n_items = min_n_items; ssp.n_items <= max_n_items; ++ssp.n_items) {
ssp.compute_smem_footprint();
ssp.shm_sz = dispatch_on_fil_template_params(compute_smem_footprint(), ssp);
if (ssp.shm_sz < max_shm) ssp_ = ssp;
}
}
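Because the outer loop visits cols_in_shmem = false before true and the inner loop grows n_items, the last configuration that still fits in max_shm wins, preferring columns in shared memory and the largest workable n_items. With hypothetical footprints of 18/34/50 KiB for n_items = 1/2/3 (columns in shared memory) and a 48 KiB limit, the search settles on n_items == 2.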
@@ -168,6 +173,7 @@ struct forest {
proba_ssp_.leaf_algo = params->leaf_algo;
proba_ssp_.num_cols = params->num_cols;
proba_ssp_.num_classes = params->num_classes;
proba_ssp_.cats_present = cat_sets.cats_present();
class_ssp_ = proba_ssp_;

int device = h.get_device();
@@ -301,7 +307,7 @@ struct forest {
params.num_outputs = params.num_classes;
do_transform = (ot != output_t::RAW && ot != output_t::SOFTMAX) || global_bias != 0.0f;
break;
default: ASSERT(false, "internal error: invalid leaf_algo_");
default: ASSERT(false, "internal error: predict: invalid leaf_algo %d", params.leaf_algo);
}
} else {
if (params.leaf_algo == leaf_algo_t::FLOAT_UNARY_BINARY) {
120 changes: 26 additions & 94 deletions cpp/src/fil/infer.cu
@@ -851,114 +851,46 @@ size_t shmem_size_params::get_smem_footprint()
size_t accumulate_footprint =
tree_aggregator_t<NITEMS, leaf_algo>::smem_accumulate_footprint(num_classes) +
cols_shmem_size();

return std::max(accumulate_footprint, finalize_footprint);
}

template <int NITEMS>
size_t shmem_size_params::get_smem_footprint()
{
switch (leaf_algo) {
case FLOAT_UNARY_BINARY: return get_smem_footprint<NITEMS, FLOAT_UNARY_BINARY>();
case CATEGORICAL_LEAF: return get_smem_footprint<NITEMS, CATEGORICAL_LEAF>();
case GROVE_PER_CLASS:
if (num_classes > FIL_TPB) return get_smem_footprint<NITEMS, GROVE_PER_CLASS_MANY_CLASSES>();
return get_smem_footprint<NITEMS, GROVE_PER_CLASS_FEW_CLASSES>();
case VECTOR_LEAF: return get_smem_footprint<NITEMS, VECTOR_LEAF>();
default: ASSERT(false, "internal error: unexpected leaf_algo_t");
}
}

void shmem_size_params::compute_smem_footprint()
template <class KernelParams>
int compute_smem_footprint::run(predict_params ssp)
{
switch (n_items) {
case 1: shm_sz = get_smem_footprint<1>(); break;
case 2: shm_sz = get_smem_footprint<2>(); break;
case 3: shm_sz = get_smem_footprint<3>(); break;
case 4: shm_sz = get_smem_footprint<4>(); break;
default: ASSERT(false, "internal error: n_items > 4");
}
return ssp.template get_smem_footprint<KernelParams::N_ITEMS, KernelParams::LEAF_ALGO>();
}

template <leaf_algo_t leaf_algo, bool COLS_IN_SHMEM, bool CATS_SUPPORTED, typename storage_type>
void infer_k_nitems_launcher(storage_type forest,
predict_params params,
cudaStream_t stream,
int block_dim_x)
{
switch (params.n_items) {
case 1:
infer_k<1, leaf_algo, COLS_IN_SHMEM, CATS_SUPPORTED>
<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest, params);
break;
case 2:
infer_k<2, leaf_algo, COLS_IN_SHMEM, CATS_SUPPORTED>
<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest, params);
break;
case 3:
infer_k<3, leaf_algo, COLS_IN_SHMEM, CATS_SUPPORTED>
<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest, params);
break;
case 4:
infer_k<4, leaf_algo, COLS_IN_SHMEM, CATS_SUPPORTED>
<<<params.num_blocks, block_dim_x, params.shm_sz, stream>>>(forest, params);
break;
default: ASSERT(false, "internal error: nitems > 4");
}
CUDA_CHECK(cudaPeekAtLastError());
}
// make sure all possible get_smem_footprint instantiations are emitted here
template int dispatch_on_fil_template_params(compute_smem_footprint, predict_params);

template <leaf_algo_t leaf_algo, bool COLS_IN_SHMEM, typename storage_type>
void infer_k_categorical_launcher(storage_type forest,
predict_params params,
cudaStream_t stream,
int blockdim_x)
{
if (forest.cats_present()) {
infer_k_nitems_launcher<leaf_algo, COLS_IN_SHMEM, true>(forest, params, stream, blockdim_x);
} else {
infer_k_nitems_launcher<leaf_algo, COLS_IN_SHMEM, false>(forest, params, stream, blockdim_x);
template <typename storage_type>
struct infer_k_storage_template : dispatch_functor<void> {
storage_type forest;
cudaStream_t stream;
infer_k_storage_template(storage_type forest_, cudaStream_t stream_)
: forest(forest_), stream(stream_)
{
}
}

template <leaf_algo_t leaf_algo, typename storage_type>
void infer_k_cols_launcher(storage_type forest,
predict_params params,
cudaStream_t stream,
int blockdim_x)
{
params.num_blocks = params.num_blocks != 0 ? params.num_blocks
: raft::ceildiv(int(params.num_rows), params.n_items);
if (params.cols_in_shmem) {
infer_k_categorical_launcher<leaf_algo, true>(forest, params, stream, blockdim_x);
} else {
infer_k_categorical_launcher<leaf_algo, false>(forest, params, stream, blockdim_x);
template <class KernelParams = KernelTemplateParams<>>
void run(predict_params params)
{
params.num_blocks = params.num_blocks != 0
? params.num_blocks
: raft::ceildiv(int(params.num_rows), params.n_items);
infer_k<KernelParams::N_ITEMS,
KernelParams::LEAF_ALGO,
KernelParams::COLS_IN_SHMEM,
KernelParams::CATS_SUPPORTED>
<<<params.num_blocks, params.block_dim_x, params.shm_sz, stream>>>(forest, params);
CUDA_CHECK(cudaPeekAtLastError());
}
}
};
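With num_blocks left at 0, the launch covers the batch with one block per n_items rows; e.g. (hypothetical numbers) num_rows == 10000 and n_items == 4 give raft::ceildiv(10000, 4) == 2500 blocks.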

template <typename storage_type>
void infer(storage_type forest, predict_params params, cudaStream_t stream)
{
switch (params.leaf_algo) {
case FLOAT_UNARY_BINARY:
infer_k_cols_launcher<FLOAT_UNARY_BINARY>(forest, params, stream, FIL_TPB);
break;
case GROVE_PER_CLASS:
if (params.num_classes > FIL_TPB) {
params.leaf_algo = GROVE_PER_CLASS_MANY_CLASSES;
infer_k_cols_launcher<GROVE_PER_CLASS_MANY_CLASSES>(forest, params, stream, FIL_TPB);
} else {
params.leaf_algo = GROVE_PER_CLASS_FEW_CLASSES;
infer_k_cols_launcher<GROVE_PER_CLASS_FEW_CLASSES>(
forest, params, stream, FIL_TPB - FIL_TPB % params.num_classes);
}
break;
case CATEGORICAL_LEAF:
infer_k_cols_launcher<CATEGORICAL_LEAF>(forest, params, stream, FIL_TPB);
break;
case VECTOR_LEAF: infer_k_cols_launcher<VECTOR_LEAF>(forest, params, stream, FIL_TPB); break;
default: ASSERT(false, "internal error: invalid leaf_algo");
}
dispatch_on_fil_template_params(infer_k_storage_template<storage_type>(forest, stream), params);
}

template void infer<dense_storage>(dense_storage forest,
5 changes: 4 additions & 1 deletion cpp/src/fil/internal.cuh
@@ -212,6 +212,8 @@ struct alignas(8) sparse_node8 : base_node {
and how FIL aggregates them into class margins/regression result/best class
**/
enum leaf_algo_t {
/** For iteration purposes */
MIN_LEAF_ALGO = 0,
/** storing a class probability or regression summand. We add all margins
together and determine regression result or use threshold to determine
one of the two classes. **/
Expand Down Expand Up @@ -239,6 +241,7 @@ enum leaf_algo_t {
/** Leaf contains an index into a vector of class probabilities. **/
VECTOR_LEAF = 5,
// to be extended
MAX_LEAF_ALGO = 5
};
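The MIN_LEAF_ALGO/MAX_LEAF_ALGO sentinels let templates walk the enum, as the dispatch code in common.cuh does. A hedged sketch of such a walk (illustrative only; next_leaf_algo() is the constexpr helper from common.cuh):

template <leaf_algo_t A = MIN_LEAF_ALGO>
void for_each_leaf_algo()
{
  // ... use the compile-time constant A ...
  if constexpr (A < MAX_LEAF_ALGO) { for_each_leaf_algo<next_leaf_algo(A)>(); }
}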

template <leaf_algo_t leaf_algo>
@@ -300,7 +303,7 @@ struct forest_params_t {
// at once inside a block (sharing trees means splitting input rows)
int threads_per_tree;
// n_items is how many input samples (items) any thread processes. If 0 is given,
// choose most (up to 4) that fit into shared memory.
// choose most (up to MAX_N_ITEMS) that fit into shared memory.
int n_items;
};
