Skip to content

Commit

Permalink
cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 26, 2020
1 parent 179fa88 commit 0d46ffd
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ bst_float PredValue(const SparsePage::Inst &inst,
return psum;
}

template <size_t kUnrollLen = 8>
struct SparsePageView {
SparsePage const* page;
bst_row_t base_rowid;
static size_t constexpr kUnroll = kUnrollLen;

explicit SparsePageView(SparsePage const *p)
: page{p}, base_rowid{page->base_rowid} {
Expand All @@ -56,13 +58,16 @@ struct SparsePageView {
size_t Size() const { return page->Size(); }
};

template <typename Adapter, size_t kUnroll = 8>
template <typename Adapter, size_t kUnrollLen = 8>
class AdapterView {
Adapter* adapter_;
float missing_;
common::Span<Entry> workspace_;
std::vector<size_t> current_unroll_;

public:
static size_t constexpr kUnroll = kUnrollLen;

public:
explicit AdapterView(Adapter *adapter, float missing,
common::Span<Entry> workplace)
Expand Down Expand Up @@ -108,8 +113,8 @@ void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds,
CHECK_EQ(model.param.size_leaf_vector, 0)
<< "size_leaf_vector is enforced to 0 so far";
// parallel over local batch
constexpr int kUnroll = 8;
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
auto constexpr kUnroll = DataView::kUnroll;
const bst_omp_uint rest = nsize % kUnroll;
if (nsize >= kUnroll) {
#pragma omp parallel for schedule(static)
Expand All @@ -118,13 +123,13 @@ void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds,
RegTree::FVec &feats = thread_temp[tid];
int64_t ridx[kUnroll];
SparsePage::Inst inst[kUnroll];
for (int k = 0; k < kUnroll; ++k) {
for (size_t k = 0; k < kUnroll; ++k) {
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
}
for (int k = 0; k < kUnroll; ++k) {
for (size_t k = 0; k < kUnroll; ++k) {
inst[k] = batch[i + k];
}
for (int k = 0; k < kUnroll; ++k) {
for (size_t k = 0; k < kUnroll; ++k) {
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx[k] * num_group + gid;
preds[offset] += PredValue(inst[k], model.trees, model.tree_info, gid,
Expand Down Expand Up @@ -167,7 +172,8 @@ class CPUPredictor : public Predictor {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
CHECK_EQ(out_preds->size(),
p_fmat->Info().num_row_ * model.learner_model_param_->num_output_group);
PredictBatchKernel(SparsePageView{&batch}, out_preds, model, tree_begin,
size_t constexpr kUnroll = 8;
PredictBatchKernel(SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
tree_end, &thread_temp_);
}
}
Expand Down Expand Up @@ -265,7 +271,6 @@ class CPUPredictor : public Predictor {
PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const {
auto threads = omp_get_max_threads();
size_t constexpr kUnroll = 8;
auto m = dmlc::get<Adapter>(x);
CHECK_EQ(m.NumColumns(), model.learner_model_param_->num_feature)
<< "Number of columns in data must equal to trained model.";
Expand All @@ -277,6 +282,7 @@ class CPUPredictor : public Predictor {
auto &predictions = out_preds->predictions.HostVector();
std::vector<RegTree::FVec> thread_temp;
InitThreadTemp(threads, model.learner_model_param_->num_feature, &thread_temp);
size_t constexpr kUnroll = 8;
PredictBatchKernel(AdapterView<Adapter, kUnroll>(
&m, missing, common::Span<Entry>{workspace}),
&predictions, model, tree_begin, tree_end, &thread_temp);
Expand Down

0 comments on commit 0d46ffd

Please sign in to comment.