Skip to content

Commit

Permalink
broadcast subsampled feature correctly
Browse files (browse the repository at this point in the history)
  • Loading branch information
Nan Zhu committed Jan 9, 2019
1 parent 00d8a50 commit 853a758
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/common/random.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <numeric>
#include <random>

#include "io.h"

namespace xgboost {
namespace common {
/*!
Expand Down Expand Up @@ -93,15 +95,24 @@ class ColumnSampler {
if (colsample == 1.0f) return p_features;
const auto& features = *p_features;
CHECK_GT(features.size(), 0);
int n = std::max(1, static_cast<int>(colsample * features.size()));
unsigned long n = std::max(1, static_cast<int>(colsample * features.size()));
auto p_new_features = std::make_shared<std::vector<int>>();
auto& new_features = *p_new_features;
new_features.resize(features.size());
std::copy(features.begin(), features.end(), new_features.begin());
std::shuffle(new_features.begin(), new_features.end(), common::GlobalRandom());
new_features.resize(n);
std::sort(new_features.begin(), new_features.end());

new_features.resize(static_cast<unsigned long>(n));
// std::sort(new_features.begin(), new_features.end());

// sync the subsampled columns
std::string s_cache;
common::MemoryBufferStream fc(&s_cache);
dmlc::Stream& fs = fc;
if (rabit::GetRank() == 0) {
fs.Write(new_features);
}
rabit::Broadcast(&s_cache, 0);
fs.Read(&new_features);
// ensure that new_features are the same across ranks
rabit::Broadcast(&new_features, 0);

Expand Down

0 comments on commit 853a758

Please sign in to comment.