Move SimpleDMatrix constructor to .cc file (#5188)
RAMitchell authored Jan 9, 2020
1 parent 9049c7c commit 9559f81
Showing 2 changed files with 92 additions and 76 deletions.
91 changes: 91 additions & 0 deletions src/data/simple_dmatrix.cc
@@ -75,5 +75,96 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
}

bool SimpleDMatrix::SingleColBlock() const { return true; }

template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Set number of threads but keep old value so we can reset it after
const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread = nthreadmax;
int nthread_original = omp_get_max_threads();
omp_set_num_threads(nthread);

source_.reset(new SimpleCSRSource());
SimpleCSRSource& mat = *reinterpret_cast<SimpleCSRSource*>(source_.get());
std::vector<uint64_t> qids;
uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
bst_uint group_size = 0;
auto& offset_vec = mat.page_.offset.HostVector();
auto& data_vec = mat.page_.data.HostVector();
uint64_t inferred_num_columns = 0;

adapter->BeforeFirst();
// Iterate over batches of input data
while (adapter->Next()) {
auto& batch = adapter->Value();
auto batch_max_columns = mat.page_.Push(batch, missing, nthread);
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
// Append meta information if available
if (batch.Labels() != nullptr) {
auto& labels = mat.info.labels_.HostVector();
labels.insert(labels.end(), batch.Labels(),
batch.Labels() + batch.Size());
}
if (batch.Weights() != nullptr) {
auto& weights = mat.info.weights_.HostVector();
weights.insert(weights.end(), batch.Weights(),
batch.Weights() + batch.Size());
}
if (batch.Qid() != nullptr) {
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
// get group
for (size_t i = 0; i < batch.Size(); ++i) {
const uint64_t cur_group_id = batch.Qid()[i];
if (last_group_id == default_max || last_group_id != cur_group_id) {
mat.info.group_ptr_.push_back(group_size);
}
last_group_id = cur_group_id;
++group_size;
}
}
}

if (last_group_id != default_max) {
if (group_size > mat.info.group_ptr_.back()) {
mat.info.group_ptr_.push_back(group_size);
}
}

// Deal with empty rows/columns if necessary
if (adapter->NumColumns() == kAdapterUnknownSize) {
mat.info.num_col_ = inferred_num_columns;
} else {
mat.info.num_col_ = adapter->NumColumns();
}
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&mat.info.num_col_, 1);

if (adapter->NumRows() == kAdapterUnknownSize) {
mat.info.num_row_ = offset_vec.size() - 1;
} else {
if (offset_vec.empty()) {
offset_vec.emplace_back(0);
}

while (offset_vec.size() - 1 < adapter->NumRows()) {
offset_vec.emplace_back(offset_vec.back());
}
mat.info.num_row_ = adapter->NumRows();
}
mat.info.num_nonzero_ = data_vec.size();
omp_set_num_threads(nthread_original);
}

template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
int nthread);
} // namespace data
} // namespace xgboost
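
Note: because the templated constructor body now lives in simple_dmatrix.cc, its definition is no longer visible from the header, so the `template SimpleDMatrix::SimpleDMatrix(...)` lines above explicitly instantiate it for every adapter type the library uses; any other adapter type would fail at link time. A minimal, self-contained sketch of the same pattern, using hypothetical names (SimpleTable, VectorAdapter) rather than XGBoost's:

    #include <cstddef>
    #include <vector>

    // --- would live in simple_table.h: declaration only ---
    struct VectorAdapter {
      std::vector<float> data;
      std::size_t NumRows() const { return data.size(); }
    };

    class SimpleTable {
     public:
      template <typename AdapterT>
      explicit SimpleTable(AdapterT* adapter);   // body defined in the .cc file
      std::size_t rows() const { return rows_; }

     private:
      std::size_t rows_ = 0;
    };

    // --- would live in simple_table.cc: definition plus explicit instantiation ---
    template <typename AdapterT>
    SimpleTable::SimpleTable(AdapterT* adapter) {
      rows_ = adapter->NumRows();
    }

    // Without this line, a caller in another translation unit constructing
    // SimpleTable from a VectorAdapter would hit an undefined-symbol link error,
    // because the template body is no longer emitted via the header.
    template SimpleTable::SimpleTable(VectorAdapter* adapter);

    int main() {
      VectorAdapter a{{1.f, 2.f, 3.f}};
      SimpleTable t(&a);
      return t.rows() == 3 ? 0 : 1;
    }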
77 changes: 1 addition & 76 deletions src/data/simple_dmatrix.h
@@ -30,82 +30,7 @@ class SimpleDMatrix : public DMatrix {
: source_(std::move(source)) {}

template <typename AdapterT>
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Set number of threads but keep old value so we can reset it after
const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread = nthreadmax;
int nthread_original = omp_get_max_threads();
omp_set_num_threads(nthread);

source_.reset(new SimpleCSRSource());
SimpleCSRSource& mat = *reinterpret_cast<SimpleCSRSource*>(source_.get());
std::vector<uint64_t> qids;
uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
bst_uint group_size = 0;
auto& offset_vec = mat.page_.offset.HostVector();
auto& data_vec = mat.page_.data.HostVector();
uint64_t inferred_num_columns = 0;

adapter->BeforeFirst();
// Iterate over batches of input data
while (adapter->Next()) {
auto& batch = adapter->Value();
auto batch_max_columns = mat.page_.Push(batch, missing, nthread);
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
// Append meta information if available
if (batch.Labels() != nullptr) {
auto& labels = mat.info.labels_.HostVector();
labels.insert(labels.end(), batch.Labels(), batch.Labels() + batch.Size());
}
if (batch.Weights() != nullptr) {
auto& weights = mat.info.weights_.HostVector();
weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size());
}
if (batch.Qid() != nullptr) {
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
// get group
for (size_t i = 0; i < batch.Size(); ++i) {
const uint64_t cur_group_id = batch.Qid()[i];
if (last_group_id == default_max || last_group_id != cur_group_id) {
mat.info.group_ptr_.push_back(group_size);
}
last_group_id = cur_group_id;
++group_size;
}
}
}

if (last_group_id != default_max) {
if (group_size > mat.info.group_ptr_.back()) {
mat.info.group_ptr_.push_back(group_size);
}
}

// Deal with empty rows/columns if necessary
if (adapter->NumColumns() == kAdapterUnknownSize) {
mat.info.num_col_ = inferred_num_columns;
} else {
mat.info.num_col_ = adapter->NumColumns();
}
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&mat.info.num_col_, 1);

if (adapter->NumRows() == kAdapterUnknownSize) {
mat.info.num_row_ = offset_vec.size() - 1;
} else {
if (offset_vec.empty()) {
offset_vec.emplace_back(0);
}

while (offset_vec.size() - 1 < adapter->NumRows()) {
offset_vec.emplace_back(offset_vec.back());
}
mat.info.num_row_ = adapter->NumRows();
}
mat.info.num_nonzero_ = data_vec.size();
omp_set_num_threads(nthread_original);
}
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);

MetaInfo& Info() override;

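As an aside, the query-id handling carried over into the .cc constructor can be exercised on its own: each change in qid records a new boundary in group_ptr_, and the total row count is appended at the end. A standalone sketch of that boundary-building loop, assuming a flat vector of query ids (BuildGroupPtr is a hypothetical helper, not part of this commit):

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <vector>

    // Mirror of the qid loop in the moved constructor: push a boundary whenever
    // the query id changes, then append the total row count as the final entry.
    std::vector<uint64_t> BuildGroupPtr(const std::vector<uint64_t>& qids) {
      const uint64_t unset = std::numeric_limits<uint64_t>::max();
      uint64_t last_group_id = unset;
      uint64_t group_size = 0;
      std::vector<uint64_t> group_ptr;
      for (uint64_t qid : qids) {
        if (last_group_id == unset || last_group_id != qid) {
          group_ptr.push_back(group_size);
        }
        last_group_id = qid;
        ++group_size;
      }
      if (last_group_id != unset && group_size > group_ptr.back()) {
        group_ptr.push_back(group_size);
      }
      return group_ptr;
    }

    int main() {
      // Rows belonging to queries 7, 7, 7, 9, 9 yield boundaries 0 3 5.
      for (uint64_t p : BuildGroupPtr({7, 7, 7, 9, 9})) std::cout << p << ' ';
      std::cout << '\n';
    }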
