Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix slice and get info. #5552

Merged
merged 1 commit into from
Apr 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,10 @@ getinfo <- function(object, ...) UseMethod("getinfo")
getinfo.xgb.DMatrix <- function(object, name, ...) {
if (typeof(name) != "character" ||
length(name) != 1 ||
!name %in% c('label', 'weight', 'base_margin', 'nrow')) {
!name %in% c('label', 'weight', 'base_margin', 'nrow',
'label_lower_bound', 'label_upper_bound')) {
stop("getinfo: name must be one of the following\n",
" 'label', 'weight', 'base_margin', 'nrow'")
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
}
if (name != "nrow"){
ret <- .Call(XGDMatrixGetInfo_R, object, name)
Expand Down
8 changes: 7 additions & 1 deletion R-package/tests/testthat/test_dmatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
labels <- getinfo(dtest, 'label')
expect_equal(test_label, getinfo(dtest, 'label'))

expect_true(setinfo(dtest, 'label_lower_bound', test_label))
expect_equal(test_label, getinfo(dtest, 'label_lower_bound'))

expect_true(setinfo(dtest, 'label_upper_bound', test_label))
expect_equal(test_label, getinfo(dtest, 'label_upper_bound'))

expect_true(length(getinfo(dtest, 'weight')) == 0)
expect_true(length(getinfo(dtest, 'base_margin')) == 0)

Expand All @@ -59,7 +65,7 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
expect_error(setinfo(dtest, 'group', test_label))

# providing character values will give a warning
expect_warning( setinfo(dtest, 'weight', rep('a', nrow(test_data))) )
expect_warning(setinfo(dtest, 'weight', rep('a', nrow(test_data))))

# any other label should error
expect_error(setinfo(dtest, 'asdf', test_label))
Expand Down
6 changes: 5 additions & 1 deletion include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class MetaInfo {

/*! \brief default constructor */
MetaInfo() = default;
MetaInfo(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo const& that) {
this->num_row_ = that.num_row_;
this->num_col_ = that.num_col_;
Expand All @@ -89,6 +91,8 @@ class MetaInfo {
this->base_margin_.Copy(that.base_margin_);
return *this;
}

MetaInfo Slice(common::Span<int32_t const> ridxs) const;
/*!
* \brief Get weight of each instances.
* \param i Instance index.
Expand Down Expand Up @@ -491,7 +495,7 @@ class DMatrix {
const std::string& cache_prefix = "",
size_t page_size = kPageSize);


virtual DMatrix* Slice(common::Span<int32_t const> ridxs) = 0;
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;

Expand Down
6 changes: 1 addition & 5 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,7 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
<< "slice does not support group structure";
}
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat))
<< "Slice only supported for SimpleDMatrix currently.";
data::DMatrixSliceAdapter adapter(dmat, {idxset, static_cast<size_t>(len)});
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
*out = new std::shared_ptr<DMatrix>(dmat->Slice({idxset, len}));
API_END();
}

Expand Down
87 changes: 0 additions & 87 deletions src/data/adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,93 +599,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
dmlc::RowBlock<uint32_t> block_;
std::unique_ptr<FileAdapterBatch> batch_;
};

class DMatrixSliceAdapterBatch {
public:
// Fetch metainfo values according to sliced rows
template <typename T>
std::vector<T> Gather(const std::vector<T>& in) {
if (in.empty()) return {};

std::vector<T> out(this->Size());
for (auto i = 0ull; i < this->Size(); i++) {
out[i] = in[ridx_set[i]];
}
return out;
}
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
common::Span<const int> ridx_set)
: batch(batch), ridx_set(ridx_set) {
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
}

class Line {
public:
Line(const SparsePage::Inst& inst, size_t row_idx)
: inst_(inst), row_idx_(row_idx) {}

size_t Size() { return inst_.size(); }
COOTuple GetElement(size_t idx) {
return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue};
}

private:
SparsePage::Inst inst_;
size_t row_idx_;
};
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
const float* Labels() const {
if (batch_labels.empty()) {
return nullptr;
}
return batch_labels.data();
}
const float* Weights() const {
if (batch_weights.empty()) {
return nullptr;
}
return batch_weights.data();
}
const uint64_t* Qid() const { return nullptr; }
const float* BaseMargin() const {
if (batch_base_margin.empty()) {
return nullptr;
}
return batch_base_margin.data();
}

size_t Size() const { return ridx_set.size(); }
const SparsePage& batch;
common::Span<const int> ridx_set;
std::vector<float> batch_labels;
std::vector<float> batch_weights;
std::vector<float> batch_base_margin;
};

// Group pointer is not exposed
// This is because external bindings currently manipulate the group values
// manually when slicing This could potentially be moved to internal C++ code if
// needed

class DMatrixSliceAdapter
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
public:
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
: dmat_(dmat),
ridx_set_(ridx_set),
batch_(*dmat_->GetBatches<SparsePage>().begin(), dmat_, ridx_set) {}
const DMatrixSliceAdapterBatch& Value() const override { return batch_; }
// Indicates a number of rows/columns must be inferred
size_t NumRows() const { return ridx_set_.size(); }
size_t NumColumns() const { return dmat_->Info().num_col_; }

private:
DMatrix* dmat_;
common::Span<const int> ridx_set_;
DMatrixSliceAdapterBatch batch_;
};
}; // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_ADAPTER_H_
50 changes: 47 additions & 3 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,53 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_);
}

template <typename T>
std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
if (in.empty()) {
return {};
}
auto size = ridxs.size();
std::vector<T> out(size * stride);
for (auto i = 0ull; i < size; i++) {
auto ridx = ridxs[i];
for (size_t j = 0; j < stride; ++j) {
out[i * stride +j] = in[ridx * stride + j];
}
}
return out;
}

MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
MetaInfo out;
out.num_row_ = ridxs.size();
out.num_col_ = this->num_col_;
// Groups is maintained by a higher level Python function. We should aim at deprecating
// the slice function.
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
out.labels_upper_bound_.HostVector() =
Gather(this->labels_upper_bound_.HostVector(), ridxs);
out.labels_lower_bound_.HostVector() =
Gather(this->labels_lower_bound_.HostVector(), ridxs);
// weights
if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
auto& h_weights = out.weights_.HostVector();
// Assuming all groups are available.
out.weights_.HostVector() = h_weights;
} else {
out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs);
}

if (this->base_margin_.Size() != this->num_row_) {
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
<< "Incorrect size of base margin vector.";
size_t stride = this->base_margin_.Size() / this->num_row_;
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride);
} else {
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
}
return out;
}

// try to load group information from file, if exists
inline bool MetaTryLoadGroup(const std::string& fname,
std::vector<unsigned>* group) {
Expand Down Expand Up @@ -459,9 +506,6 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::IteratorAdapter>(
data::IteratorAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
Expand Down
4 changes: 4 additions & 0 deletions src/data/device_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class DeviceDMatrix : public DMatrix {

bool EllpackExists() const override { return true; }
bool SparsePageExists() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
return nullptr;
}

private:
BatchSet<SparsePage> GetRowBatches() override {
Expand Down
23 changes: 21 additions & 2 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ MetaInfo& SimpleDMatrix::Info() { return info_; }

const MetaInfo& SimpleDMatrix::Info() const { return info_; }

DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
auto out = new SimpleDMatrix;
SparsePage& out_page = out->sparse_page_;
for (auto const &page : this->GetBatches<SparsePage>()) {
page.data.HostVector();
page.offset.HostVector();
auto& h_data = out_page.data.HostVector();
auto& h_offset = out_page.offset.HostVector();
size_t rptr{0};
for (auto ridx : ridxs) {
auto inst = page[ridx];
rptr += inst.size();
std::copy(inst.begin(), inst.end(), std::back_inserter(h_data));
h_offset.emplace_back(rptr);
}
out->Info() = this->Info().Slice(ridxs);
out->Info().num_nonzero_ = h_offset.back();
}
return out;
}

BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto begin_iter = BatchIterator<SparsePage>(
Expand Down Expand Up @@ -174,8 +195,6 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing,
int nthread);
} // namespace data
Expand Down
2 changes: 2 additions & 0 deletions src/data/simple_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace data {
// Used for single batch data.
class SimpleDMatrix : public DMatrix {
public:
SimpleDMatrix() = default;
template <typename AdapterT>
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);

Expand All @@ -32,6 +33,7 @@ class SimpleDMatrix : public DMatrix {
const MetaInfo& Info() const override;

bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;

/*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01;
Expand Down
4 changes: 4 additions & 0 deletions src/data/sparse_page_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class SparsePageDMatrix : public DMatrix {
const MetaInfo& Info() const override;

bool SingleColBlock() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
}

private:
BatchSet<SparsePage> GetRowBatches() override;
Expand Down
25 changes: 0 additions & 25 deletions tests/cpp/data/test_adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,6 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
EXPECT_EQ(inst[3].index, 3);
}

TEST(CAPI, DMatrixSliceAdapterFromSimpleDMatrix) {
auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatrix();

std::vector<int> ridx_set = {1, 3, 5};
data::DMatrixSliceAdapter adapter(p_dmat.get(),
{ridx_set.data(), ridx_set.size()});
EXPECT_EQ(adapter.NumRows(), ridx_set.size());

adapter.BeforeFirst();
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
adapter.Next();
auto &adapter_batch = adapter.Value();
for (auto i = 0ull; i < adapter_batch.Size(); i++) {
auto inst = batch[ridx_set[i]];
auto line = adapter_batch.GetLine(i);
ASSERT_EQ(inst.size(), line.Size());
for (auto j = 0ull; j < line.Size(); j++) {
EXPECT_EQ(inst[j].fvalue, line.GetElement(j).value);
EXPECT_EQ(inst[j].index, line.GetElement(j).column_idx);
EXPECT_EQ(i, line.GetElement(j).row_idx);
}
}
}
}

// A mock for JVM data iterator.
class DataIterForTest {
std::vector<float> data_ {1, 2, 3, 4, 5};
Expand Down
1 change: 0 additions & 1 deletion tests/cpp/data/test_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,4 @@ TEST(DMatrix, Uri) {
ASSERT_EQ(dmat->Info().num_col_, kCols);
ASSERT_EQ(dmat->Info().num_row_, kRows);
}

} // namespace xgboost
Loading