Skip to content

Commit

Permalink
Remove the need to load sparse page.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 29, 2021
1 parent 34084af commit 8d2f0b5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/data/gradient_index_page_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ namespace xgboost {
namespace data {
void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
if (count_ != 0) {
++(*source_);
}
CHECK_EQ(count_, source_->Iter());
auto const& csr = source_->Page();
this->page_.reset(new GHistIndexMatrix());
CHECK_NE(cuts_.Values().size(), 0);
Expand Down
9 changes: 5 additions & 4 deletions src/data/gradient_index_page_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#include <memory>
#include <utility>

#include "sparse_page_source.h"
#include "gradient_index.h"
#include "sparse_page_source.h"

namespace xgboost {
namespace data {
Expand All @@ -25,7 +25,8 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat,
common::Span<FeatureType const> feature_types, float sparse_thresh,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
!std::isnan(sparse_thresh)),
cuts_{std::move(cuts)},
is_dense_{is_dense},
max_bin_per_feat_{max_bin_per_feat},
Expand All @@ -37,6 +38,6 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {

void Fetch() final;
};
} // namespace data
} // namespace xgboost
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
29 changes: 23 additions & 6 deletions src/data/sparse_page_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
size_t fetch_it = count_;

for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) { continue; }
if (ring_->at(fetch_it).valid()) {
continue;
}
auto const *self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() {
Expand All @@ -139,8 +142,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
return page;
});
}
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(),
[](auto const &f) { return f.valid(); }),
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
n_prefetch_batches)
<< "Sparse DMatrix assumes forward iteration.";
page_ = (*ring_)[count_].get();
Expand Down Expand Up @@ -289,15 +291,28 @@ template <typename S>
class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
protected:
std::shared_ptr<SparsePageSource> source_;
using Super = SparsePageSourceImpl<S>;
bool sync_{true}; // synchronize the row page.

public:
using SparsePageSourceImpl<S>::SparsePageSourceImpl;
PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
std::shared_ptr<Cache> cache, bool sync)
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}

PageSourceIncMixIn& operator++() final {
TryLockGuard guard{this->single_threaded_};
++(*source_);
if (sync_) {
++(*source_);
}

++this->count_;
this->at_end_ = source_->AtEnd();
this->at_end_ = this->count_ == this->n_batches_;
if (this->at_end_) {
CHECK_EQ(this->count_, this->n_batches_);
} else {
CHECK_LT(this->count_, this->n_batches_);
}

if (this->at_end_) {
this->cache_info_->Commit();
Expand All @@ -308,7 +323,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
} else {
this->Fetch();
}
CHECK_EQ(source_->Iter(), this->count_);
if (sync_) {
CHECK_EQ(source_->Iter(), this->count_);
}
return *this;
}
};
Expand Down

0 comments on commit 8d2f0b5

Please sign in to comment.