Skip to content

Commit

Permalink
Fix slice and get info.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 17, 2020
1 parent c245eb8 commit e0a87e9
Show file tree
Hide file tree
Showing 14 changed files with 180 additions and 166 deletions.
5 changes: 3 additions & 2 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,10 @@ getinfo <- function(object, ...) UseMethod("getinfo")
getinfo.xgb.DMatrix <- function(object, name, ...) {
if (typeof(name) != "character" ||
length(name) != 1 ||
!name %in% c('label', 'weight', 'base_margin', 'nrow')) {
!name %in% c('label', 'weight', 'base_margin', 'nrow',
'label_lower_bound', 'label_upper_bound')) {
stop("getinfo: name must be one of the following\n",
" 'label', 'weight', 'base_margin', 'nrow'")
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
}
if (name != "nrow"){
ret <- .Call(XGDMatrixGetInfo_R, object, name)
Expand Down
8 changes: 7 additions & 1 deletion R-package/tests/testthat/test_dmatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
labels <- getinfo(dtest, 'label')
expect_equal(test_label, getinfo(dtest, 'label'))

expect_true(setinfo(dtest, 'label_lower_bound', test_label))
expect_equal(test_label, getinfo(dtest, 'label_lower_bound'))

expect_true(setinfo(dtest, 'label_upper_bound', test_label))
expect_equal(test_label, getinfo(dtest, 'label_upper_bound'))

expect_true(length(getinfo(dtest, 'weight')) == 0)
expect_true(length(getinfo(dtest, 'base_margin')) == 0)

Expand All @@ -59,7 +65,7 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
expect_error(setinfo(dtest, 'group', test_label))

# providing character values will give a warning
expect_warning( setinfo(dtest, 'weight', rep('a', nrow(test_data))) )
expect_warning(setinfo(dtest, 'weight', rep('a', nrow(test_data))))

# any other label should error
expect_error(setinfo(dtest, 'asdf', test_label))
Expand Down
6 changes: 5 additions & 1 deletion include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class MetaInfo {

/*! \brief default constructor */
MetaInfo() = default;
MetaInfo(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo const& that) {
this->num_row_ = that.num_row_;
this->num_col_ = that.num_col_;
Expand All @@ -89,6 +91,8 @@ class MetaInfo {
this->base_margin_.Copy(that.base_margin_);
return *this;
}

MetaInfo Slice(common::Span<int32_t const> ridxs) const;
/*!
* \brief Get weight of each instances.
* \param i Instance index.
Expand Down Expand Up @@ -491,7 +495,7 @@ class DMatrix {
const std::string& cache_prefix = "",
size_t page_size = kPageSize);


virtual DMatrix* Slice(common::Span<int32_t const> ridxs) = 0;
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;

Expand Down
6 changes: 1 addition & 5 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,7 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
<< "slice does not support group structure";
}
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat))
<< "Slice only supported for SimpleDMatrix currently.";
data::DMatrixSliceAdapter adapter(dmat, {idxset, static_cast<size_t>(len)});
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
*out = new std::shared_ptr<DMatrix>(dmat->Slice({idxset, len}));
API_END();
}

Expand Down
87 changes: 0 additions & 87 deletions src/data/adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,93 +599,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
dmlc::RowBlock<uint32_t> block_;
std::unique_ptr<FileAdapterBatch> batch_;
};

class DMatrixSliceAdapterBatch {
public:
// Fetch metainfo values according to sliced rows
template <typename T>
std::vector<T> Gather(const std::vector<T>& in) {
if (in.empty()) return {};

std::vector<T> out(this->Size());
for (auto i = 0ull; i < this->Size(); i++) {
out[i] = in[ridx_set[i]];
}
return out;
}
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
common::Span<const int> ridx_set)
: batch(batch), ridx_set(ridx_set) {
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
}

class Line {
public:
Line(const SparsePage::Inst& inst, size_t row_idx)
: inst_(inst), row_idx_(row_idx) {}

size_t Size() { return inst_.size(); }
COOTuple GetElement(size_t idx) {
return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue};
}

private:
SparsePage::Inst inst_;
size_t row_idx_;
};
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
const float* Labels() const {
if (batch_labels.empty()) {
return nullptr;
}
return batch_labels.data();
}
const float* Weights() const {
if (batch_weights.empty()) {
return nullptr;
}
return batch_weights.data();
}
const uint64_t* Qid() const { return nullptr; }
const float* BaseMargin() const {
if (batch_base_margin.empty()) {
return nullptr;
}
return batch_base_margin.data();
}

size_t Size() const { return ridx_set.size(); }
const SparsePage& batch;
common::Span<const int> ridx_set;
std::vector<float> batch_labels;
std::vector<float> batch_weights;
std::vector<float> batch_base_margin;
};

// Group pointer is not exposed
// This is because external bindings currently manipulate the group values
// manually when slicing This could potentially be moved to internal C++ code if
// needed

class DMatrixSliceAdapter
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
public:
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
: dmat_(dmat),
ridx_set_(ridx_set),
batch_(*dmat_->GetBatches<SparsePage>().begin(), dmat_, ridx_set) {}
const DMatrixSliceAdapterBatch& Value() const override { return batch_; }
// Indicates a number of rows/columns must be inferred
size_t NumRows() const { return ridx_set_.size(); }
size_t NumColumns() const { return dmat_->Info().num_col_; }

private:
DMatrix* dmat_;
common::Span<const int> ridx_set_;
DMatrixSliceAdapterBatch batch_;
};
}; // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_ADAPTER_H_
50 changes: 47 additions & 3 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,53 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_);
}

template <typename T>
std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
if (in.empty()) {
return {};
}
auto size = ridxs.size();
std::vector<T> out(size * stride);
for (auto i = 0ull; i < size; i++) {
auto ridx = ridxs[i];
for (size_t j = 0; j < stride; ++j) {
out[i * stride +j] = in[ridx * stride + j];
}
}
return out;
}

MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
MetaInfo out;
out.num_row_ = ridxs.size();
out.num_col_ = this->num_col_;
// Groups is maintained by a higher level Python function. We should aim at deprecating
// the slice function.
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
out.labels_upper_bound_.HostVector() =
Gather(this->labels_upper_bound_.HostVector(), ridxs);
out.labels_lower_bound_.HostVector() =
Gather(this->labels_lower_bound_.HostVector(), ridxs);
// weights
if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
auto& h_weights = out.weights_.HostVector();
// Assuming all groups are available.
out.weights_.HostVector() = h_weights;
} else {
out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs);
}

if (this->base_margin_.Size() != this->num_row_) {
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
<< "Incorrect size of base margin vector.";
size_t stride = this->base_margin_.Size() / this->num_row_;
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride);
} else {
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
}
return out;
}

// try to load group information from file, if exists
inline bool MetaTryLoadGroup(const std::string& fname,
std::vector<unsigned>* group) {
Expand Down Expand Up @@ -459,9 +506,6 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::IteratorAdapter>(
data::IteratorAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
Expand Down
4 changes: 4 additions & 0 deletions src/data/device_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class DeviceDMatrix : public DMatrix {

bool EllpackExists() const override { return true; }
bool SparsePageExists() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
return nullptr;
}

private:
BatchSet<SparsePage> GetRowBatches() override {
Expand Down
23 changes: 21 additions & 2 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ MetaInfo& SimpleDMatrix::Info() { return info_; }

const MetaInfo& SimpleDMatrix::Info() const { return info_; }

DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
auto out = new SimpleDMatrix;
SparsePage& out_page = out->sparse_page_;
for (auto const &page : this->GetBatches<SparsePage>()) {
page.data.HostVector();
page.offset.HostVector();
auto& h_data = out_page.data.HostVector();
auto& h_offset = out_page.offset.HostVector();
size_t rptr{0};
for (auto ridx : ridxs) {
auto inst = page[ridx];
rptr += inst.size();
std::copy(inst.begin(), inst.end(), std::back_inserter(h_data));
h_offset.emplace_back(rptr);
}
out->Info() = this->Info().Slice(ridxs);
out->Info().num_nonzero_ = h_offset.back();
}
return out;
}

BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto begin_iter = BatchIterator<SparsePage>(
Expand Down Expand Up @@ -174,8 +195,6 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing,
int nthread);
} // namespace data
Expand Down
2 changes: 2 additions & 0 deletions src/data/simple_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace data {
// Used for single batch data.
class SimpleDMatrix : public DMatrix {
public:
SimpleDMatrix() = default;
template <typename AdapterT>
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);

Expand All @@ -32,6 +33,7 @@ class SimpleDMatrix : public DMatrix {
const MetaInfo& Info() const override;

bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;

/*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01;
Expand Down
4 changes: 4 additions & 0 deletions src/data/sparse_page_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class SparsePageDMatrix : public DMatrix {
const MetaInfo& Info() const override;

bool SingleColBlock() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
}

private:
BatchSet<SparsePage> GetRowBatches() override;
Expand Down
25 changes: 0 additions & 25 deletions tests/cpp/data/test_adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,6 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
EXPECT_EQ(inst[3].index, 3);
}

TEST(CAPI, DMatrixSliceAdapterFromSimpleDMatrix) {
auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatrix();

std::vector<int> ridx_set = {1, 3, 5};
data::DMatrixSliceAdapter adapter(p_dmat.get(),
{ridx_set.data(), ridx_set.size()});
EXPECT_EQ(adapter.NumRows(), ridx_set.size());

adapter.BeforeFirst();
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
adapter.Next();
auto &adapter_batch = adapter.Value();
for (auto i = 0ull; i < adapter_batch.Size(); i++) {
auto inst = batch[ridx_set[i]];
auto line = adapter_batch.GetLine(i);
ASSERT_EQ(inst.size(), line.Size());
for (auto j = 0ull; j < line.Size(); j++) {
EXPECT_EQ(inst[j].fvalue, line.GetElement(j).value);
EXPECT_EQ(inst[j].index, line.GetElement(j).column_idx);
EXPECT_EQ(i, line.GetElement(j).row_idx);
}
}
}
}

// A mock for JVM data iterator.
class DataIterForTest {
std::vector<float> data_ {1, 2, 3, 4, 5};
Expand Down
1 change: 0 additions & 1 deletion tests/cpp/data/test_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,4 @@ TEST(DMatrix, Uri) {
ASSERT_EQ(dmat->Info().num_col_, kCols);
ASSERT_EQ(dmat->Info().num_row_, kRows);
}

} // namespace xgboost
Loading

0 comments on commit e0a87e9

Please sign in to comment.