From 9559f81377a5470be128465938a284ce1cdf9396 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 9 Jan 2020 14:20:13 +1300 Subject: [PATCH] Move SimpleDMatrix constructor to .cc file (#5188) --- src/data/simple_dmatrix.cc | 91 ++++++++++++++++++++++++++++++++++++++ src/data/simple_dmatrix.h | 77 +------------------------------- 2 files changed, 92 insertions(+), 76 deletions(-) diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 3ee7eda19e3e..bb83400a0757 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -75,5 +75,96 @@ BatchSet SimpleDMatrix::GetEllpackBatches(const BatchParam& param) } bool SimpleDMatrix::SingleColBlock() const { return true; } + +template +SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { + // Set number of threads but keep old value so we can reset it after + const int nthreadmax = omp_get_max_threads(); + if (nthread <= 0) nthread = nthreadmax; + int nthread_original = omp_get_max_threads(); + omp_set_num_threads(nthread); + + source_.reset(new SimpleCSRSource()); + SimpleCSRSource& mat = *reinterpret_cast(source_.get()); + std::vector qids; + uint64_t default_max = std::numeric_limits::max(); + uint64_t last_group_id = default_max; + bst_uint group_size = 0; + auto& offset_vec = mat.page_.offset.HostVector(); + auto& data_vec = mat.page_.data.HostVector(); + uint64_t inferred_num_columns = 0; + + adapter->BeforeFirst(); + // Iterate over batches of input data + while (adapter->Next()) { + auto& batch = adapter->Value(); + auto batch_max_columns = mat.page_.Push(batch, missing, nthread); + inferred_num_columns = std::max(batch_max_columns, inferred_num_columns); + // Append meta information if available + if (batch.Labels() != nullptr) { + auto& labels = mat.info.labels_.HostVector(); + labels.insert(labels.end(), batch.Labels(), + batch.Labels() + batch.Size()); + } + if (batch.Weights() != nullptr) { + auto& weights = mat.info.weights_.HostVector(); + weights.insert(weights.end(), batch.Weights(), + batch.Weights() + batch.Size()); + } + if (batch.Qid() != nullptr) { + qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); + // get group + for (size_t i = 0; i < batch.Size(); ++i) { + const uint64_t cur_group_id = batch.Qid()[i]; + if (last_group_id == default_max || last_group_id != cur_group_id) { + mat.info.group_ptr_.push_back(group_size); + } + last_group_id = cur_group_id; + ++group_size; + } + } + } + + if (last_group_id != default_max) { + if (group_size > mat.info.group_ptr_.back()) { + mat.info.group_ptr_.push_back(group_size); + } + } + + // Deal with empty rows/columns if necessary + if (adapter->NumColumns() == kAdapterUnknownSize) { + mat.info.num_col_ = inferred_num_columns; + } else { + mat.info.num_col_ = adapter->NumColumns(); + } + // Synchronise worker columns + rabit::Allreduce(&mat.info.num_col_, 1); + + if (adapter->NumRows() == kAdapterUnknownSize) { + mat.info.num_row_ = offset_vec.size() - 1; + } else { + if (offset_vec.empty()) { + offset_vec.emplace_back(0); + } + + while (offset_vec.size() - 1 < adapter->NumRows()) { + offset_vec.emplace_back(offset_vec.back()); + } + mat.info.num_row_ = adapter->NumRows(); + } + mat.info.num_nonzero_ = data_vec.size(); + omp_set_num_threads(nthread_original); +} + +template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, + int nthread); +template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, + int nthread); +template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing, + int nthread); +template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, + int nthread); +template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, + int nthread); } // namespace data } // namespace xgboost diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index 46849e84be17..65f525f33a57 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -30,82 +30,7 @@ class SimpleDMatrix : public DMatrix { : source_(std::move(source)) {} template - explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { - // Set number of threads but keep old value so we can reset it after - const int nthreadmax = omp_get_max_threads(); - if (nthread <= 0) nthread = nthreadmax; - int nthread_original = omp_get_max_threads(); - omp_set_num_threads(nthread); - - source_.reset(new SimpleCSRSource()); - SimpleCSRSource& mat = *reinterpret_cast(source_.get()); - std::vector qids; - uint64_t default_max = std::numeric_limits::max(); - uint64_t last_group_id = default_max; - bst_uint group_size = 0; - auto& offset_vec = mat.page_.offset.HostVector(); - auto& data_vec = mat.page_.data.HostVector(); - uint64_t inferred_num_columns = 0; - - adapter->BeforeFirst(); - // Iterate over batches of input data - while (adapter->Next()) { - auto& batch = adapter->Value(); - auto batch_max_columns = mat.page_.Push(batch, missing, nthread); - inferred_num_columns = std::max(batch_max_columns, inferred_num_columns); - // Append meta information if available - if (batch.Labels() != nullptr) { - auto& labels = mat.info.labels_.HostVector(); - labels.insert(labels.end(), batch.Labels(), batch.Labels() + batch.Size()); - } - if (batch.Weights() != nullptr) { - auto& weights = mat.info.weights_.HostVector(); - weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size()); - } - if (batch.Qid() != nullptr) { - qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); - // get group - for (size_t i = 0; i < batch.Size(); ++i) { - const uint64_t cur_group_id = batch.Qid()[i]; - if (last_group_id == default_max || last_group_id != cur_group_id) { - mat.info.group_ptr_.push_back(group_size); - } - last_group_id = cur_group_id; - ++group_size; - } - } - } - - if (last_group_id != default_max) { - if (group_size > mat.info.group_ptr_.back()) { - mat.info.group_ptr_.push_back(group_size); - } - } - - // Deal with empty rows/columns if necessary - if (adapter->NumColumns() == kAdapterUnknownSize) { - mat.info.num_col_ = inferred_num_columns; - } else { - mat.info.num_col_ = adapter->NumColumns(); - } - // Synchronise worker columns - rabit::Allreduce(&mat.info.num_col_, 1); - - if (adapter->NumRows() == kAdapterUnknownSize) { - mat.info.num_row_ = offset_vec.size() - 1; - } else { - if (offset_vec.empty()) { - offset_vec.emplace_back(0); - } - - while (offset_vec.size() - 1 < adapter->NumRows()) { - offset_vec.emplace_back(offset_vec.back()); - } - mat.info.num_row_ = adapter->NumRows(); - } - mat.info.num_nonzero_ = data_vec.size(); - omp_set_num_threads(nthread_original); - } + explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread); MetaInfo& Info() override;