diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 874d140e05a7..0646bc9a9fbc 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -41,7 +41,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p proxy, [](auto const &value) { return value.NumCols(); }); }; - for (auto const &page : this->GetRowBatches()) { + for (auto const &page : this->GetRowBatchesImpl()) { this->info_.Extend(std::move(proxy->Info()), false, false); n_features = std::max(n_features, num_cols()); n_samples += num_rows(); @@ -77,12 +77,16 @@ void SparsePageDMatrix::InitializeSparsePage() { this->n_batches_, cache_info_.at(id)); } -BatchSet SparsePageDMatrix::GetRowBatches() { +BatchSet SparsePageDMatrix::GetRowBatchesImpl() { this->InitializeSparsePage(); auto begin_iter = BatchIterator(sparse_page_source_); return BatchSet(BatchIterator(begin_iter)); } +BatchSet SparsePageDMatrix::GetRowBatches() { + return this->GetRowBatchesImpl(); +} + BatchSet SparsePageDMatrix::GetColumnBatches() { auto id = MakeCache(this, ".col.page", cache_prefix_, &cache_info_); CHECK_NE(this->Info().num_col_, 0); diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index f40f311fa96b..a364e7816804 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -39,6 +39,8 @@ class SparsePageDMatrix : public DMatrix { size_t n_batches_ {0}; void InitializeSparsePage(); + // Non virtual version that can be used in constructor + BatchSet GetRowBatchesImpl(); public: explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy,