Use Span<Entry const> instead of Inst in SparsePage.
trivialfis committed Aug 5, 2018
1 parent bfb5371 commit 1d7fafc
Showing 19 changed files with 76 additions and 85 deletions.
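Note (not part of the commit): the heart of the change is in include/xgboost/data.h below, where the hand-rolled Inst struct (a raw pointer plus a length field) becomes an alias for the project's common::Span type, so every call site switches from the data/length fields to the data()/size() member functions. The following is a minimal, self-contained sketch of that before/after shape; MiniSpan is a hypothetical stand-in for common::Span<Entry const>, not the actual xgboost type.

#include <cstddef>
#include <iostream>
#include <vector>

struct Entry { unsigned index; float fvalue; };

// MiniSpan is a hypothetical stand-in for xgboost's common::Span<Entry const>,
// used here only to illustrate the field -> member-function change at call sites.
template <typename T>
class MiniSpan {
 public:
  using index_type = std::size_t;
  MiniSpan(T* data, index_type size) : data_(data), size_(size) {}
  T* data() const { return data_; }          // replaces the old Inst::data field
  index_type size() const { return size_; }  // replaces the old Inst::length field
  T& operator[](index_type i) const { return data_[i]; }

 private:
  T* data_;
  index_type size_;
};

using Inst = MiniSpan<Entry const>;  // mirrors: using Inst = common::Span<Entry const>;

int main() {
  std::vector<Entry> row{{0, 1.0f}, {3, 2.5f}, {7, 0.5f}};
  Inst inst{row.data(), row.size()};
  for (std::size_t i = 0; i < inst.size(); ++i) {  // was: i < inst.length
    std::cout << inst[i].index << ":" << inst[i].fvalue << " ";
  }
  std::cout << "\n";
  return 0;
}

The practical effect is that SparsePage no longer maintains its own one-off view type and instead shares the span view used elsewhere in the code base.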
31 changes: 11 additions & 20 deletions include/xgboost/data.h
@@ -15,6 +15,7 @@
#include <string>
#include <vector>
#include "./base.h"
#include "../../src/common/span.h"

namespace xgboost {
// forward declare learner.
@@ -155,24 +156,14 @@ class SparsePage {
std::vector<Entry> data;

size_t base_rowid;

/*! \brief an instance of sparse vector in the batch */
struct Inst {
/*! \brief pointer to the elements*/
const Entry *data{nullptr};
/*! \brief length of the instance */
bst_uint length{0};
/*! \brief constructor */
Inst() = default;
Inst(const Entry *data, bst_uint length) : data(data), length(length) {}
/*! \brief get i-th pair in the sparse vector*/
inline const Entry& operator[](size_t i) const {
return data[i];
}
};
using Inst = common::Span<Entry const>;

/*! \brief get i-th row from the batch */
inline Inst operator[](size_t i) const {
return {data.data() + offset[i], static_cast<bst_uint>(offset[i + 1] - offset[i])};
return {data.data() + offset[i],
static_cast<Inst::index_type>(offset[i + 1] - offset[i])};
}

/*! \brief constructor */
@@ -234,12 +225,12 @@ class SparsePage {
* \param inst an instance row
*/
inline void Push(const Inst &inst) {
offset.push_back(offset.back() + inst.length);
offset.push_back(offset.back() + inst.size());
size_t begin = data.size();
data.resize(begin + inst.length);
if (inst.length != 0) {
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data,
sizeof(Entry) * inst.length);
data.resize(begin + inst.size());
if (inst.size() != 0) {
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data(),
sizeof(Entry) * inst.size());
}
}

@@ -328,7 +319,7 @@ class DMatrix {
* \brief check if column access is supported, if not, initialize column access.
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
* this is a hint information that can be ignored by the implementation.
* \param sorted If column features should be in sorted order
* \param sorted If column features should be in sorted order
* \return Number of column blocks in the column access.
*/
virtual void InitColAccess(size_t max_row_perbatch, bool sorted) = 0;
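Note (not part of the commit): offset and data in SparsePage form a CSR layout, so operator[] above slices row i out as a view of offset[i + 1] - offset[i] entries, and Push appends a view onto the end of both arrays. Below is a self-contained sketch of that slicing and appending; RowView and MiniSparsePage are hypothetical simplifications (a plain (pointer, length) pair in place of common::Span), not the real classes.

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

struct Entry { unsigned index; float fvalue; };

// Hypothetical (pointer, length) pair standing in for common::Span<Entry const>.
using RowView = std::pair<const Entry*, std::size_t>;

struct MiniSparsePage {
  std::vector<std::size_t> offset{0};  // CSR row offsets, offset[0] == 0
  std::vector<Entry> data;             // concatenated row entries

  // Same slicing as SparsePage::operator[]: row i covers
  // data[offset[i]] .. data[offset[i + 1]).
  RowView operator[](std::size_t i) const {
    return {data.data() + offset[i], offset[i + 1] - offset[i]};
  }

  // Same bookkeeping as SparsePage::Push (vector::insert here instead of the
  // resize + std::memcpy in the real code).
  void Push(const RowView& inst) {
    offset.push_back(offset.back() + inst.second);
    data.insert(data.end(), inst.first, inst.first + inst.second);
  }
};

int main() {
  MiniSparsePage page;
  std::vector<Entry> row{{1, 0.5f}, {4, 2.0f}};
  page.Push({row.data(), row.size()});
  RowView r = page[0];
  std::cout << "row 0 has " << r.second << " entries\n";  // prints 2
  return 0;
}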
4 changes: 2 additions & 2 deletions include/xgboost/tree_model.h
@@ -574,14 +574,14 @@ inline void RegTree::FVec::Init(size_t size) {
}

inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].fvalue = inst[i].fvalue;
}
}

inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
for (bst_uint i = 0; i < inst.length; ++i) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= data_.size()) continue;
data_[inst[i].index].flag = -1;
}
8 changes: 4 additions & 4 deletions src/c_api/c_api.cc
@@ -682,10 +682,10 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int ridx = idxset[i];
auto inst = batch[ridx];
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size());
ret.page_.data.insert(ret.page_.data.end(), inst.data,
inst.data + inst.length);
ret.page_.offset.push_back(ret.page_.offset.back() + inst.length);
ret.info.num_nonzero_ += inst.length;
ret.page_.data.insert(ret.page_.data.end(), inst.data(),
inst.data() + inst.size());
ret.page_.offset.push_back(ret.page_.offset.back() + inst.size());
ret.info.num_nonzero_ += inst.size();

if (src.info.labels_.size() != 0) {
ret.info.labels_.push_back(src.info.labels_[ridx]);
8 changes: 4 additions & 4 deletions src/common/hist_util.cc
@@ -48,7 +48,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
size_t ridx = batch.base_rowid + i;
SparsePage::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
for (bst_uint j = 0; j < inst.size(); ++j) {
if (inst[j].index >= begin && inst[j].index < end) {
sketchs[inst[j].index].Push(inst[j].fvalue, info.GetWeight(ridx));
}
@@ -129,7 +129,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
auto batch = iter->Value();
const size_t rbegin = row_ptr.size() - 1;
for (size_t i = 0; i < batch.Size(); ++i) {
row_ptr.push_back(batch[i].length + row_ptr.back());
row_ptr.push_back(batch[i].size() + row_ptr.back());
}
index.resize(row_ptr.back());

@@ -143,8 +143,8 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
size_t ibegin = row_ptr[rbegin + i];
size_t iend = row_ptr[rbegin + i + 1];
SparsePage::Inst inst = batch[i];
CHECK_EQ(ibegin + inst.length, iend);
for (bst_uint j = 0; j < inst.length; ++j) {
CHECK_EQ(ibegin + inst.size(), iend);
for (bst_uint j = 0; j < inst.size(); ++j) {
unsigned fid = inst[j].index;
auto cbegin = cut->cut.begin() + cut->row_ptr[fid];
auto cend = cut->cut.begin() + cut->row_ptr[fid + 1];
4 changes: 2 additions & 2 deletions src/data/simple_dmatrix.cc
@@ -58,7 +58,7 @@ void SimpleDMatrix::MakeOneBatch(SparsePage* pcol, bool sorted) {
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
for (bst_uint j = 0; j < inst.size(); ++j) {
builder.AddBudget(inst[j].index, tid);
}
}
@@ -72,7 +72,7 @@ void SimpleDMatrix::MakeOneBatch(SparsePage* pcol, bool sorted) {
for (long i = 0; i < static_cast<long>(batch.Size()); ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
for (bst_uint j = 0; j < inst.size(); ++j) {
builder.Push(
inst[j].index,
Entry(static_cast<bst_uint>(batch.base_rowid + i), inst[j].fvalue),
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.h
@@ -45,7 +45,7 @@ class SimpleDMatrix : public DMatrix {

size_t GetColSize(size_t cidx) const override {
auto& batch = *col_iter_.column_page_;
return batch[cidx].length;
return batch[cidx].size();
}

float GetColDensity(size_t cidx) const override {
4 changes: 2 additions & 2 deletions src/gbm/gblinear.cc
@@ -166,7 +166,7 @@ class GBLinear : public GradientBooster {
for (int gid = 0; gid < ngroup; ++gid) {
bst_float *p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns];
// calculate linear terms' contributions
for (bst_uint c = 0; c < inst.length; ++c) {
for (bst_uint c = 0; c < inst.size(); ++c) {
if (inst[c].index >= model_.param.num_feature) continue;
p_contribs[inst[c].index] = inst[c].fvalue * model_[inst[c].index][gid];
}
@@ -268,7 +268,7 @@ class GBLinear : public GradientBooster {
inline void Pred(const SparsePage::Inst &inst, bst_float *preds, int gid,
bst_float base) {
bst_float psum = model_.bias()[gid] + base;
for (bst_uint i = 0; i < inst.length; ++i) {
for (bst_uint i = 0; i < inst.size(); ++i) {
if (inst[i].index >= model_.param.num_feature) continue;
psum += inst[i].fvalue * model_[inst[i].index][gid];
}
10 changes: 5 additions & 5 deletions src/linear/coordinate_common.h
@@ -69,7 +69,7 @@ inline std::pair<double, double> GetGradient(int group_idx, int num_group, int f
while (iter->Next()) {
auto batch = iter->Value();
auto col = batch[fidx];
const auto ndata = static_cast<bst_omp_uint>(col.length);
const auto ndata = static_cast<bst_omp_uint>(col.size());
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_float v = col[j].fvalue;
auto &p = gpair[col[j].index * num_group + group_idx];
@@ -100,7 +100,7 @@ inline std::pair<double, double> GetGradientParallel(int group_idx, int num_grou
while (iter->Next()) {
auto batch = iter->Value();
auto col = batch[fidx];
const auto ndata = static_cast<bst_omp_uint>(col.length);
const auto ndata = static_cast<bst_omp_uint>(col.size());
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_float v = col[j].fvalue;
@@ -159,7 +159,7 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
auto batch = iter->Value();
auto col = batch[fidx];
// update grad value
const auto num_row = static_cast<bst_omp_uint>(col.length);
const auto num_row = static_cast<bst_omp_uint>(col.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < num_row; ++j) {
GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
@@ -331,7 +331,7 @@ class GreedyFeatureSelector : public FeatureSelector {
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const auto col = batch[i];
const bst_uint ndata = col.length;
const bst_uint ndata = col.size();
auto &sums = gpair_sums_[group_idx * nfeat + i];
for (bst_uint j = 0u; j < ndata; ++j) {
const bst_float v = col[j].fvalue;
@@ -399,7 +399,7 @@ class ThriftyFeatureSelector : public FeatureSelector {
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nfeat; ++i) {
const auto col = batch[i];
const bst_uint ndata = col.length;
const bst_uint ndata = col.size();
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
auto &sums = gpair_sums_[gid * nfeat + i];
for (bst_uint j = 0u; j < ndata; ++j) {
8 changes: 4 additions & 4 deletions src/linear/updater_gpu_coordinate.cu
@@ -118,13 +118,13 @@ class DeviceShard {
return e1.index < e2.index;
};
auto column_begin =
std::lower_bound(col.data, col.data + col.length,
std::lower_bound(col.data(), col.data() + col.size(),
Entry(row_begin, 0.0f), cmp);
auto column_end =
std::upper_bound(col.data, col.data + col.length,
std::upper_bound(col.data(), col.data() + col.size(),
Entry(row_end, 0.0f), cmp);
column_segments.push_back(
std::make_pair(column_begin - col.data, column_end - col.data));
std::make_pair(column_begin - col.data(), column_end - col.data()));
row_ptr_.push_back(row_ptr_.back() + column_end - column_begin);
}
ba_.Allocate(device_idx, param.silent, &data_, row_ptr_.back(), &gpair_,
@@ -134,7 +134,7 @@
auto col = batch[fidx];
auto seg = column_segments[fidx];
dh::safe_cuda(cudaMemcpy(
data_.Data() + row_ptr_[fidx], col.data + seg.first,
data_.Data() + row_ptr_[fidx], col.data() + seg.first,
sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
}
// Rescale indices with respect to current shard
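Note (not part of the commit): in src/linear/updater_gpu_coordinate.cu above, each device shard binary-searches a feature column for the entries whose row index falls in the shard's row range; with the span, that is std::lower_bound/std::upper_bound over col.data() .. col.data() + col.size() with a comparator on Entry::index. A self-contained sketch of that pattern, with a plain std::vector standing in for the column span and made-up row numbers, illustrative only:

#include <algorithm>
#include <iostream>
#include <vector>

struct Entry { unsigned index; float fvalue; };

int main() {
  // A feature column sorted by row index, as in the column-access pages.
  std::vector<Entry> col{{2, 1.f}, {5, 2.f}, {7, 3.f}, {11, 4.f}, {20, 5.f}};
  unsigned row_begin = 5, row_end = 12;  // row range assigned to this shard

  auto cmp = [](const Entry& a, const Entry& b) { return a.index < b.index; };
  auto first = std::lower_bound(col.data(), col.data() + col.size(),
                                Entry{row_begin, 0.0f}, cmp);
  auto last = std::upper_bound(col.data(), col.data() + col.size(),
                               Entry{row_end, 0.0f}, cmp);
  // Offsets into the column covered by this shard: prints "segment [1, 4)"
  // because the entries at offsets 1..3 have row indices 5, 7 and 11.
  std::cout << "segment [" << (first - col.data()) << ", "
            << (last - col.data()) << ")\n";
  return 0;
}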
4 changes: 2 additions & 2 deletions src/linear/updater_shotgun.cc
@@ -92,7 +92,7 @@ class ShotgunUpdater : public LinearUpdater {
auto col = batch[ii];
for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0;
for (bst_uint j = 0; j < col.length; ++j) {
for (bst_uint j = 0; j < col.size(); ++j) {
GradientPair &p = gpair[col[j].index * ngroup + gid];
if (p.GetHess() < 0.0f) continue;
const bst_float v = col[j].fvalue;
@@ -107,7 +107,7 @@
if (dw == 0.f) continue;
w += dw;
// update grad values
for (bst_uint j = 0; j < col.length; ++j) {
for (bst_uint j = 0; j < col.size(); ++j) {
GradientPair &p = gpair[col[j].index * ngroup + gid];
if (p.GetHess() < 0.0f) continue;
p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
10 changes: 5 additions & 5 deletions src/tree/updater_basemaker-inl.h
@@ -49,9 +49,9 @@ class BaseMaker: public TreeUpdater {
auto batch = iter->Value();
for (bst_uint fid = 0; fid < batch.Size(); ++fid) {
auto c = batch[fid];
if (c.length != 0) {
if (c.size() != 0) {
fminmax_[fid * 2 + 0] = std::max(-c[0].fvalue, fminmax_[fid * 2 + 0]);
fminmax_[fid * 2 + 1] = std::max(c[c.length - 1].fvalue, fminmax_[fid * 2 + 1]);
fminmax_[fid * 2 + 1] = std::max(c[c.size() - 1].fvalue, fminmax_[fid * 2 + 1]);
}
}
}
@@ -106,7 +106,7 @@ class BaseMaker: public TreeUpdater {
inline static int NextLevel(const SparsePage::Inst &inst, const RegTree &tree, int nid) {
const RegTree::Node &n = tree[nid];
bst_uint findex = n.SplitIndex();
for (unsigned i = 0; i < inst.length; ++i) {
for (unsigned i = 0; i < inst.size(); ++i) {
if (findex == inst[i].index) {
if (inst[i].fvalue < n.SplitCond()) {
return n.LeftChild();
@@ -250,7 +250,7 @@ class BaseMaker: public TreeUpdater {
auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid);

if (it != sorted_split_set.end() && *it == fid) {
const auto ndata = static_cast<bst_omp_uint>(col.length);
const auto ndata = static_cast<bst_omp_uint>(col.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
@@ -308,7 +308,7 @@ class BaseMaker: public TreeUpdater {
auto batch = iter->Value();
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.length);
const auto ndata = static_cast<bst_omp_uint>(col.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
20 changes: 10 additions & 10 deletions src/tree/updater_colmaker.cc
@@ -269,7 +269,7 @@ class ColMaker: public TreeUpdater {
const std::vector<GradientPair> &gpair) {
// TODO(tqchen): double check stats order.
const MetaInfo& info = fmat.Info();
const bool ind = col.length != 0 && col.data[0].fvalue == col.data[col.length - 1].fvalue;
const bool ind = col.size() != 0 && col[0].fvalue == col[col.size() - 1].fvalue;
bool need_forward = param_.NeedForwardSearch(fmat.GetColDensity(fid), ind);
bool need_backward = param_.NeedBackwardSearch(fmat.GetColDensity(fid), ind);
const std::vector<int> &qexpand = qexpand_;
@@ -281,8 +281,8 @@
for (int j : qexpand) {
temp[j].stats.Clear();
}
bst_uint step = (col.length + this->nthread_ - 1) / this->nthread_;
bst_uint end = std::min(col.length, step * (tid + 1));
bst_uint step = (col.size() + this->nthread_ - 1) / this->nthread_;
bst_uint end = std::min(static_cast<bst_uint>(col.size()), step * (tid + 1));
for (bst_uint i = tid * step; i < end; ++i) {
const bst_uint ridx = col[i].index;
const int nid = position_[ridx];
@@ -363,8 +363,8 @@
GradStats c(param_), cright(param_);
const int tid = omp_get_thread_num();
std::vector<ThreadEntry> &temp = stemp_[tid];
bst_uint step = (col.length + this->nthread_ - 1) / this->nthread_;
bst_uint end = std::min(col.length, step * (tid + 1));
bst_uint step = (col.size() + this->nthread_ - 1) / this->nthread_;
bst_uint end = std::min(static_cast<bst_uint>(col.size()), step * (tid + 1));
for (bst_uint i = tid * step; i < end; ++i) {
const bst_uint ridx = col[i].index;
const int nid = position_[ridx];
@@ -620,13 +620,13 @@ class ColMaker: public TreeUpdater {
int fid = feat_set[i];
const int tid = omp_get_thread_num();
auto c = batch[fid];
const bool ind = c.length != 0 && c.data[0].fvalue == c.data[c.length - 1].fvalue;
const bool ind = c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue;
if (param_.NeedForwardSearch(fmat.GetColDensity(fid), ind)) {
this->EnumerateSplit(c.data, c.data + c.length, +1,
this->EnumerateSplit(c.data(), c.data() + c.size(), +1,
fid, gpair, info, stemp_[tid]);
}
if (param_.NeedBackwardSearch(fmat.GetColDensity(fid), ind)) {
this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1,
this->EnumerateSplit(c.data() + c.size() - 1, c.data() - 1, -1,
fid, gpair, info, stemp_[tid]);
}
}
@@ -734,7 +734,7 @@ class ColMaker: public TreeUpdater {
auto batch = iter->Value();
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.length);
const auto ndata = static_cast<bst_omp_uint>(col.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
@@ -865,7 +865,7 @@ class DistColMaker : public ColMaker {
auto batch = iter->Value();
for (auto fid : fsplits) {
auto col = batch[fid];
const auto ndata = static_cast<bst_omp_uint>(col.length);
const auto ndata = static_cast<bst_omp_uint>(col.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
2 changes: 1 addition & 1 deletion src/tree/updater_gpu.cu
@@ -669,7 +669,7 @@ class GPUMaker : public TreeUpdater {
auto batch = iter->Value();
for (int i = 0; i < batch.Size(); i++) {
auto col = batch[i];
for (const Entry* it = col.data; it != col.data + col.length;
for (const Entry* it = col.data(); it != col.data() + col.size();
it++) {
int inst_id = static_cast<int>(it->index);
fval->push_back(it->fvalue);