Skip to content

Commit

Permalink
fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 7, 2022
1 parent 86715a5 commit 8392d80
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/c_api/c_api_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ inline float GetMissing(Json const &config) {
class XGBoostAPIGuard {
#if defined(XGBOOST_USE_CUDA)
int32_t device_id_ {0};

void SetGPUAttribute();
void RestoreGPUAttribute();
#else
Expand Down
2 changes: 1 addition & 1 deletion src/common/column_matrix.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2017 by Contributors
* Copyright 2017-2022 by Contributors
* \file column_matrix.h
* \brief Utility for fast column-wise access
* \author Philip Cho
Expand Down
2 changes: 1 addition & 1 deletion src/common/partition_builder.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2021 by Contributors
* Copyright 2021-2022 by Contributors
* \file row_set.h
* \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen
Expand Down
9 changes: 8 additions & 1 deletion src/data/ellpack_page_source.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2019-2021 XGBoost contributors
* Copyright 2019-2022 XGBoost contributors
*/
#include <memory>
#include <utility>
Expand All @@ -12,6 +12,13 @@ namespace data {
void EllpackPageSource::Fetch() {
dh::safe_cuda(cudaSetDevice(param_.gpu_id));
if (!this->ReadCache()) {
if (count_ != 0 && !sync_) {
// source is initialized to be the 0th page during construction, so when count_ is 0
// there's no need to increment the source.
++(*source_);
}
// This is not read from cache so we still need it to be synced with sparse page source.
CHECK_EQ(count_, source_->Iter());
auto const &csr = source_->Page();
this->page_.reset(new EllpackPage{});
auto *impl = this->page_->Impl();
Expand Down
22 changes: 12 additions & 10 deletions src/data/ellpack_page_source.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2019-2021 by XGBoost Contributors
* Copyright 2019-2022 by XGBoost Contributors
*/

#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
Expand All @@ -25,15 +25,17 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
std::unique_ptr<common::HistogramCuts> cuts_;

public:
EllpackPageSource(
float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param,
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense,
size_t row_stride, common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)},
feature_types_{feature_types}, cuts_{std::move(cuts)} {
EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param,
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense, size_t row_stride,
common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
is_dense_{is_dense},
row_stride_{row_stride},
param_{std::move(param)},
feature_types_{feature_types},
cuts_{std::move(cuts)} {
this->source_ = source;
this->Fetch();
}
Expand Down
2 changes: 1 addition & 1 deletion src/data/gradient_index_format.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2021 XGBoost contributors
* Copyright 2021-2022 XGBoost contributors
*/
#include "sparse_page_writer.h"
#include "gradient_index.h"
Expand Down
6 changes: 4 additions & 2 deletions src/data/gradient_index_page_source.cc
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#include "gradient_index_page_source.h"

namespace xgboost {
namespace data {
void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
if (count_ != 0) {
if (count_ != 0 && !sync_) {
// source is initialized to be the 0th page during construction, so when count_ is 0
// there's no need to increment the source.
++(*source_);
}
// This is not read from cache so we still need it to be synced with sparse page source.
Expand Down
4 changes: 2 additions & 2 deletions src/data/gradient_index_page_source.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#ifndef XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
#define XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
Expand All @@ -26,7 +26,7 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
common::Span<FeatureType const> feature_types, float sparse_thresh,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
!std::isnan(sparse_thresh)),
std::isnan(sparse_thresh)),
cuts_{std::move(cuts)},
is_dense_{is_dense},
max_bin_per_feat_{max_bin_per_feat},
Expand Down
1 change: 0 additions & 1 deletion src/data/sparse_page_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam
auto sorted_sketch = param.regen;
auto cuts =
common::SketchOnDMatrix(this, param.max_bin, ctx_.Threads(), sorted_sketch, param.hess);
this->InitializeSparsePage(); // reset after use.

batch_param_ = param;
ghist_index_source_.reset();
Expand Down
9 changes: 3 additions & 6 deletions src/data/sparse_page_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
protected:
std::shared_ptr<SparsePageSource> source_;
using Super = SparsePageSourceImpl<S>;
bool sync_{true}; // synchronize the row page.
// synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
// so we avoid fetching it.
bool sync_{true};

public:
PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
Expand All @@ -307,11 +309,6 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S> {

++this->count_;
this->at_end_ = this->count_ == this->n_batches_;
if (this->at_end_) {
CHECK_EQ(this->count_, this->n_batches_);
} else {
CHECK_LT(this->count_, this->n_batches_);
}

if (this->at_end_) {
this->cache_info_->Commit();
Expand Down
2 changes: 1 addition & 1 deletion src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
#define XGBOOST_TREE_HIST_HISTOGRAM_H_
Expand Down

0 comments on commit 8392d80

Please sign in to comment.