Make `HistCutMatrix::Init' be aware of groups.
* Add checks for group size.
trivialfis committed Feb 9, 2019
1 parent 3320a52 commit 40c086f
Showing 3 changed files with 90 additions and 9 deletions.
32 changes: 25 additions & 7 deletions src/common/hist_util.cc
@@ -35,30 +35,48 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {

  const int nthread = omp_get_max_threads();

-  auto nstep = static_cast<unsigned>((info.num_col_ + nthread - 1) / nthread);
-  auto ncol = static_cast<unsigned>(info.num_col_);
+  unsigned const nstep =
+      static_cast<unsigned>((info.num_col_ + nthread - 1) / nthread);
+  unsigned const ncol = static_cast<unsigned>(info.num_col_);
  sketchs.resize(info.num_col_);
  for (auto& s : sketchs) {
    s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor));
  }

  const auto& weights = info.weights_.HostVector();
+
  for (const auto &batch : p_fmat->GetRowBatches()) {
-#pragma omp parallel num_threads(nthread)
+    #pragma omp parallel num_threads(nthread)
    {
      CHECK_EQ(nthread, omp_get_num_threads());
      auto tid = static_cast<unsigned>(omp_get_thread_num());
      unsigned begin = std::min(nstep * tid, ncol);
      unsigned end = std::min(nstep * (tid + 1), ncol);
+
+      // Data groups, used in ranking.
+      std::vector<bst_uint> const& groups = info.group_ptr_;
+      size_t const num_groups = groups.size() == 0 ? 0 : groups.size() - 1;
+      size_t group_ind =  // index into groups
+          tid * static_cast<unsigned>((num_groups + nthread - 1) / nthread);
+      // Do we need to use this index?
+      bool const use_group_ind = num_groups != 0 && weights.size() != info.num_row_;
      // do not iterate if no columns are assigned to the thread
      if (begin < end && end <= ncol) {
        for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
          size_t ridx = batch.base_rowid + i;
          SparsePage::Inst inst = batch[i];
-          for (auto& ins : inst) {
-            if (ins.index >= begin && ins.index < end) {
-              sketchs[ins.index].Push(ins.fvalue,
-                                      weights.size() > 0 ? weights[ridx] : 1.0f);
+          if (use_group_ind &&
+              groups[group_ind] == ridx &&
+              // maximum equals to weights.size() - 1
+              group_ind < num_groups - 1) {
+            // move to next group
+            group_ind++;
+          }
+          for (auto& entry : inst) {
+            if (entry.index >= begin && entry.index < end) {
+              size_t w_idx = use_group_ind ? group_ind : ridx;
+              sketchs[entry.index].Push(entry.fvalue,
+                                        weights.size() > 0 ? weights.at(w_idx) : 1.0f);
            }
          }
        }
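Background on the change: for ranking data, info.group_ptr_ stores CSR-style cumulative boundaries. With G groups it has G + 1 entries beginning with 0, and rows [group_ptr_[g], group_ptr_[g+1]) belong to group g. When weights come one per group instead of one per row (detected above by weights.size() != info.num_row_), the quantile sketch has to translate each row index into a group index before fetching that row's weight; that is the job of the group_ind cursor. Below is a minimal standalone sketch of the same row-to-group walk, done eagerly for all rows rather than inline in the sketching loop. PerRowWeights is an illustrative name, not XGBoost code, and the cursor arithmetic here is the straightforward serial form, not the per-thread variant used in the commit.

  #include <cassert>
  #include <cstddef>
  #include <vector>

  // Illustrative only: expand one weight per group into one weight per row.
  // group_ptr = {0, 2, 5, 6} encodes three groups covering rows [0,2), [2,5)
  // and [5,6), mirroring the layout of MetaInfo::group_ptr_.
  std::vector<float> PerRowWeights(std::vector<unsigned> const& group_ptr,
                                   std::vector<float> const& group_weights) {
    assert(group_ptr.size() == group_weights.size() + 1);
    std::size_t const num_rows = group_ptr.back();
    std::vector<float> row_weights(num_rows);
    std::size_t group_ind = 0;  // index of the group containing the current row
    for (std::size_t ridx = 0; ridx < num_rows; ++ridx) {
      while (ridx >= group_ptr[group_ind + 1]) {
        ++group_ind;  // crossed a group boundary (also skips empty groups)
      }
      row_weights[ridx] = group_weights[group_ind];
    }
    return row_weights;
  }

  int main() {
    std::vector<unsigned> const group_ptr{0, 2, 5, 6};
    std::vector<float> const group_weights{0.5f, 2.0f, 1.0f};
    auto const rw = PerRowWeights(group_ptr, group_weights);
    assert(rw[1] == 0.5f);  // rows 0-1 take group 0's weight
    assert(rw[4] == 2.0f);  // rows 2-4 take group 1's weight
    assert(rw[5] == 1.0f);  // row 5 takes group 2's weight
    return 0;
  }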
22 changes: 22 additions & 0 deletions src/learner.cc
@@ -474,12 +474,16 @@ class LearnerImpl : public Learner {

  void UpdateOneIter(int iter, DMatrix* train) override {
    monitor_.Start("UpdateOneIter");
+
+    // TODO(trivialfis): Merge the duplicated code with BoostOneIter
    CHECK(ModelInitialized())
        << "Always call InitModel or LoadModel before update";
    if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
      common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
    }
+    this->ValidateDMatrix(train);
    this->PerformTreeMethodHeuristic(train);
+
    monitor_.Start("PredictRaw");
    this->PredictRaw(train, &preds_);
    monitor_.Stop("PredictRaw");
@@ -493,10 +497,15 @@
  void BoostOneIter(int iter, DMatrix* train,
                    HostDeviceVector<GradientPair>* in_gpair) override {
    monitor_.Start("BoostOneIter");
+
    CHECK(ModelInitialized())
        << "Always call InitModel or LoadModel before update";
    if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
      common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
    }
+    this->ValidateDMatrix(train);
+    this->PerformTreeMethodHeuristic(train);
+
    gbm_->DoBoost(train, in_gpair);
    monitor_.Stop("BoostOneIter");
  }
@@ -736,6 +745,19 @@ class LearnerImpl : public Learner {
    gbm_->PredictBatch(data, out_preds, ntree_limit);
  }

+  void ValidateDMatrix(DMatrix* p_fmat) {
+    MetaInfo const& info = p_fmat->Info();
+    auto const& weights = info.weights_.HostVector();
+    if (info.group_ptr_.size() != 0 && weights.size() != 0) {
+      CHECK(weights.size() == info.group_ptr_.size() - 1)
+          << "\n"
+          << "weights size: " << weights.size() << ", "
+          << "groups size: " << info.group_ptr_.size() - 1 << ", "
+          << "num rows: " << p_fmat->Info().num_row_ << "\n"
+          << "Number of weights should be equal to number of groups.";
+    }
+  }
+
  // model parameter
  LearnerModelParam mparam_;
  // training parameter
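The new ValidateDMatrix encodes the invariant the sketch builder relies on: since group_ptr_ stores G + 1 cumulative boundaries for G groups, per-group weights must have exactly group_ptr_.size() - 1 entries. A freestanding restatement of the check under that layout assumption; CheckGroupWeights is a made-up name, not part of the codebase, and it throws where the committed code uses CHECK.

  #include <cstddef>
  #include <stdexcept>
  #include <string>
  #include <vector>

  // Illustrative only: reject per-group weights whose count does not match
  // the number of groups implied by CSR-style boundaries.
  void CheckGroupWeights(std::vector<unsigned> const& group_ptr,
                         std::vector<float> const& weights) {
    if (group_ptr.empty() || weights.empty()) {
      return;  // no groups or no weights: nothing to validate
    }
    std::size_t const num_groups = group_ptr.size() - 1;
    if (weights.size() != num_groups) {
      throw std::runtime_error(
          "weights size: " + std::to_string(weights.size()) +
          ", groups size: " + std::to_string(num_groups) +
          "; number of weights should be equal to number of groups.");
    }
  }

As in the committed code, the check is skipped when either groups or weights are absent, since per-row weights (one per row, no groups) remain valid.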
45 changes: 43 additions & 2 deletions tests/cpp/test_learner.cc
@@ -6,7 +6,7 @@

namespace xgboost {

-TEST(learner, Test) {
+TEST(Learner, Basic) {
typedef std::pair<std::string, std::string> arg;
auto args = {arg("tree_method", "exact")};
auto mat_ptr = CreateDMatrix(10, 10, 0);
@@ -17,7 +17,7 @@ TEST(learner, Test) {
delete mat_ptr;
}

-TEST(learner, SelectTreeMethod) {
+TEST(Learner, SelectTreeMethod) {
using arg = std::pair<std::string, std::string>;
auto mat_ptr = CreateDMatrix(10, 10, 0);
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {*mat_ptr};
@@ -51,4 +51,45 @@ TEST(learner, SelectTreeMethod) {
delete mat_ptr;
}

+TEST(Learner, CheckGroup) {
+  using arg = std::pair<std::string, std::string>;
+  size_t constexpr kNumGroups = 4;
+  size_t constexpr kNumRows = 17;
+  size_t constexpr kNumCols = 15;
+
+  auto pp_mat = CreateDMatrix(kNumRows, kNumCols, 0);
+  auto& p_mat = *pp_mat;
+  std::vector<bst_float> weight(kNumGroups);
+  std::vector<bst_int> group(kNumGroups);
+  group[0] = 2;
+  group[1] = 3;
+  group[2] = 7;
+  group[3] = 5;
+  std::vector<bst_float> labels (kNumRows);
+  for (size_t i = 0; i < kNumRows; ++i) {
+    labels[i] = i % 2;
+  }
+
+  p_mat->Info().SetInfo(
+      "weight", static_cast<void*>(weight.data()), DataType::kFloat32, kNumGroups);
+  p_mat->Info().SetInfo(
+      "group", group.data(), DataType::kUInt32, kNumGroups);
+  p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kNumRows);
+
+  std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {p_mat};
+  auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
+  learner->Configure({arg{"objective", "rank:pairwise"}});
+  learner->InitModel();
+
+  EXPECT_NO_THROW(learner->UpdateOneIter(0, p_mat.get()));
+
+  group.resize(kNumGroups+1);
+  group[3] = 4;
+  group[4] = 1;
+  p_mat->Info().SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
+  EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat.get()));
+
+  delete pp_mat;
+}

} // namespace xgboost
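A note on the numbers in the test: the passing case partitions the 17 rows into groups of sizes 2 + 3 + 7 + 5 = 17 and supplies one weight per group, so UpdateOneIter succeeds. The failing case re-partitions the same rows into five groups (2 + 3 + 7 + 4 + 1 = 17) while still supplying only four weights, so the new ValidateDMatrix check trips and the update throws.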
