Skip to content

Commit

Permalink
Fix base row id.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 6, 2022
1 parent 53f562b commit a83ba8d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
17 changes: 9 additions & 8 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,27 +100,28 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
}

auto batch = data::SparsePageAdapterBatch{page.GetView()};
this->PushRowPageImpl(batch, OptionalWeights{weights}, info.num_nonzero_, info.num_col_, is_dense,
[](auto) { return true; });
this->PushRowPageImpl(batch, page.base_rowid, OptionalWeights{weights}, page.data.Size(),
info.num_col_, is_dense, [](auto) { return true; });
monitor_.Stop(__func__);
}

template <typename Batch>
void HostSketchContainer::PushAdapterBatch(Batch const &batch, MetaInfo const &info, size_t nnz,
float missing) {
void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid,
MetaInfo const &info, size_t nnz, float missing) {
auto const &h_weights = (use_group_ind_ ? UnrollGroupWeights(info) : info.weights_.HostVector());

auto is_valid = data::IsValidFunctor{missing};
auto weights = OptionalWeights{Span<float const>{h_weights}};
// the nnz from info is not reliable as sketching might be the first place to go through
// the data.
auto is_dense = nnz == info.num_col_ * info.num_row_;
this->PushRowPageImpl(batch, weights, nnz, info.num_col_, is_dense, is_valid);
this->PushRowPageImpl(batch, base_rowid, weights, nnz, info.num_col_, is_dense, is_valid);
}

#define INSTANTIATE(_type) \
template void HostSketchContainer::PushAdapterBatch<data::_type>( \
data::_type const &batch, MetaInfo const &info, size_t nnz, float missing);
#define INSTANTIATE(_type) \
template void HostSketchContainer::PushAdapterBatch<data::_type>( \
data::_type const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz, \
float missing);

INSTANTIATE(ArrayAdapterBatch)
INSTANTIATE(CSRArrayAdapterBatch)
Expand Down
9 changes: 5 additions & 4 deletions src/common/quantile.h
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,8 @@ class SketchContainerImpl {
std::vector<int32_t> *p_num_cuts);

template <typename Batch, typename IsValid>
void PushRowPageImpl(Batch const &batch, OptionalWeights weights, size_t nnz, size_t n_features,
bool is_dense, IsValid is_valid) {
void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
size_t n_features, bool is_dense, IsValid is_valid) {
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);

dmlc::OMPException exc;
Expand All @@ -837,7 +837,7 @@ class SketchContainerImpl {
if (begin < end && end <= n_features) {
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
auto const &line = batch.GetLine(ridx);
auto w = weights[ridx];
auto w = weights[ridx + base_rowid];
if (is_dense) {
for (size_t ii = begin; ii < end; ii++) {
auto elem = line.GetElement(ii);
Expand Down Expand Up @@ -883,7 +883,8 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
bool use_group, Span<float const> hessian, int32_t n_threads);

template <typename Batch>
void PushAdapterBatch(Batch const &batch, MetaInfo const& info, size_t nnz, float missing);
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz,
float missing);
};

/**
Expand Down

0 comments on commit a83ba8d

Please sign in to comment.