Skip to content

Commit

Permalink
Initial support for external memory in gradient index. (#7183)
Browse files Browse the repository at this point in the history
* Add hessian to batch param in preparation of new approx impl.
* Extract a push method for gradient index matrix.
* Use span instead of vector ref for hessian in sketching.
* Create a binary format for gradient index.
trivialfis authored Sep 13, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent a0dcf6f commit 3515931
Showing 26 changed files with 546 additions and 171 deletions.
2 changes: 2 additions & 0 deletions amalgamation/xgboost-all0.cc
Original file line number Diff line number Diff line change
@@ -38,6 +38,8 @@
#include "../src/data/sparse_page_raw_format.cc"
#include "../src/data/ellpack_page.cc"
#include "../src/data/gradient_index.cc"
#include "../src/data/gradient_index_page_source.cc"
#include "../src/data/gradient_index_format.cc"
#include "../src/data/sparse_page_dmatrix.cc"
#include "../src/data/proxy_dmatrix.cc"

19 changes: 17 additions & 2 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015 by Contributors
* Copyright (c) 2015-2021 by Contributors
* \file data.h
* \brief The input data structure of xgboost.
* \author Tianqi Chen
@@ -214,12 +214,27 @@ struct BatchParam {
int gpu_id;
/*! \brief Maximum number of bins per feature for histograms. */
int max_bin{0};
/*! \brief Hessian, used for sketching with future approx implementation. */
common::Span<float> hess;
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
bool regen {false};

BatchParam() = default;
BatchParam(int32_t device, int32_t max_bin)
: gpu_id{device}, max_bin{max_bin} {}
/**
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
* the span is changed, so caller should keep the span for each iteration.
*/
BatchParam(int32_t device, int32_t max_bin, common::Span<float> hessian,
bool regenerate = false)
: gpu_id{device}, max_bin{max_bin}, hess{hessian}, regen{regenerate} {}

bool operator!=(const BatchParam& other) const {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
if (hess.empty() && other.hess.empty()) {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
}
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
}
};

11 changes: 9 additions & 2 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
@@ -111,7 +111,7 @@ class HistogramCuts {
};

inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
std::vector<float> const &hessian = {}) {
Span<float> const hessian = {}) {
HistogramCuts out;
auto const& info = m->Info();
const auto threads = omp_get_max_threads();
@@ -136,7 +136,7 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
return out;
}

enum BinTypeSize {
enum BinTypeSize : uint32_t {
kUint8BinsTypeSize = 1,
kUint16BinsTypeSize = 2,
kUint32BinsTypeSize = 4
@@ -207,6 +207,13 @@ struct Index {
return data_.end();
}

std::vector<uint8_t>::iterator begin() { // NOLINT
return data_.begin();
}
std::vector<uint8_t>::iterator end() { // NOLINT
return data_.end();
}

private:
static uint32_t GetValueFromUint8(void *t, size_t i) {
return reinterpret_cast<uint8_t*>(t)[i];
12 changes: 6 additions & 6 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
@@ -94,26 +94,26 @@ std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
namespace {
// Function to merge hessian and sample weights
std::vector<float> MergeWeights(MetaInfo const &info,
std::vector<float> const &hessian,
Span<float> const hessian,
bool use_group, int32_t n_threads) {
CHECK_EQ(hessian.size(), info.num_row_);
std::vector<float> results(hessian.size());
auto const &group_ptr = info.group_ptr_;
auto const& weights = info.weights_.HostVector();
auto get_weight = [&](size_t i) { return weights.empty() ? 1.0f : weights[i]; };
if (use_group) {
auto const &group_weights = info.weights_.HostVector();
CHECK_GE(group_ptr.size(), 2);
CHECK_EQ(group_ptr.back(), hessian.size());
size_t cur_group = 0;
for (size_t i = 0; i < hessian.size(); ++i) {
results[i] = hessian[i] * group_weights[cur_group];
results[i] = hessian[i] * get_weight(cur_group);
if (i == group_ptr[cur_group + 1]) {
cur_group++;
}
}
} else {
auto const &sample_weights = info.weights_.HostVector();
ParallelFor(hessian.size(), n_threads, Sched::Auto(),
[&](auto i) { results[i] = hessian[i] * sample_weights[i]; });
[&](auto i) { results[i] = hessian[i] * get_weight(i); });
}
return results;
}
@@ -141,7 +141,7 @@ std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
} // anonymous namespace

void HostSketchContainer::PushRowPage(
SparsePage const &page, MetaInfo const &info, std::vector<float> const &hessian) {
SparsePage const &page, MetaInfo const &info, Span<float> hessian) {
monitor_.Start(__func__);
bst_feature_t n_columns = info.num_col_;
auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
2 changes: 1 addition & 1 deletion src/common/quantile.h
Original file line number Diff line number Diff line change
@@ -760,7 +760,7 @@ class HostSketchContainer {

/* \brief Push a CSR matrix. */
void PushRowPage(SparsePage const &page, MetaInfo const &info,
std::vector<float> const &hessian = {});
Span<float> const hessian = {});

void MakeCuts(HistogramCuts* cuts);
};
2 changes: 2 additions & 0 deletions src/data/data.cc
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::CSCPage>);
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SortedCSCPage>);
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>);
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::GHistIndexMatrix>);
} // namespace dmlc

namespace {
@@ -1089,5 +1090,6 @@ namespace data {

// List of files that will be force linked in static links.
DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);
DMLC_REGISTRY_LINK_TAG(gradient_index_format);
} // namespace data
} // namespace xgboost
18 changes: 7 additions & 11 deletions src/data/ellpack_page_raw_format.cu
Original file line number Diff line number Diff line change
@@ -4,8 +4,9 @@
#include <xgboost/data.h>
#include <dmlc/registry.h>

#include "./ellpack_page.cuh"
#include "./sparse_page_writer.h"
#include "ellpack_page.cuh"
#include "sparse_page_writer.h"
#include "histogram_cut_format.h"

namespace xgboost {
namespace data {
@@ -17,9 +18,9 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
public:
bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
auto* impl = page->Impl();
fi->Read(&impl->Cuts().cut_values_.HostVector());
fi->Read(&impl->Cuts().cut_ptrs_.HostVector());
fi->Read(&impl->Cuts().min_vals_.HostVector());
if (!ReadHistogramCuts(&impl->Cuts(), fi)) {
return false;
}
fi->Read(&impl->n_rows);
fi->Read(&impl->is_dense);
fi->Read(&impl->row_stride);
@@ -33,12 +34,7 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
size_t Write(const EllpackPage& page, dmlc::Stream* fo) override {
size_t bytes = 0;
auto* impl = page.Impl();
fo->Write(impl->Cuts().cut_values_.ConstHostVector());
bytes += impl->Cuts().cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
fo->Write(impl->Cuts().cut_ptrs_.ConstHostVector());
bytes += impl->Cuts().cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
fo->Write(impl->Cuts().min_vals_.ConstHostVector());
bytes += impl->Cuts().min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
bytes += WriteHistogramCuts(impl->Cuts(), fo);
fo->Write(impl->n_rows);
bytes += sizeof(impl->n_rows);
fo->Write(impl->is_dense);
2 changes: 1 addition & 1 deletion src/data/ellpack_page_source.h
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
size_t row_stride, common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
is_dense_{is_dense}, row_stride_{row_stride}, param_{param},
is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)},
feature_types_{feature_types}, cuts_{std::move(cuts)} {
this->source_ = source;
this->Fetch();
Loading

0 comments on commit 3515931

Please sign in to comment.