diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index 4c125832d2ea..4201a830276c 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -188,9 +188,10 @@ getinfo <- function(object, ...) UseMethod("getinfo") getinfo.xgb.DMatrix <- function(object, name, ...) { if (typeof(name) != "character" || length(name) != 1 || - !name %in% c('label', 'weight', 'base_margin', 'nrow')) { + !name %in% c('label', 'weight', 'base_margin', 'nrow', + 'label_lower_bound', 'label_upper_bound')) { stop("getinfo: name must be one of the following\n", - " 'label', 'weight', 'base_margin', 'nrow'") + " 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'") } if (name != "nrow"){ ret <- .Call(XGDMatrixGetInfo_R, object, name) diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index 27889165cf2a..c06358962959 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -50,6 +50,12 @@ test_that("xgb.DMatrix: getinfo & setinfo", { labels <- getinfo(dtest, 'label') expect_equal(test_label, getinfo(dtest, 'label')) + expect_true(setinfo(dtest, 'label_lower_bound', test_label)) + expect_equal(test_label, getinfo(dtest, 'label_lower_bound')) + + expect_true(setinfo(dtest, 'label_upper_bound', test_label)) + expect_equal(test_label, getinfo(dtest, 'label_upper_bound')) + expect_true(length(getinfo(dtest, 'weight')) == 0) expect_true(length(getinfo(dtest, 'base_margin')) == 0) @@ -59,7 +65,7 @@ test_that("xgb.DMatrix: getinfo & setinfo", { expect_error(setinfo(dtest, 'group', test_label)) # providing character values will give a warning - expect_warning( setinfo(dtest, 'weight', rep('a', nrow(test_data))) ) + expect_warning(setinfo(dtest, 'weight', rep('a', nrow(test_data)))) # any other label should error expect_error(setinfo(dtest, 'asdf', test_label)) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index c66f60aff9bc..c2a80576c395 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -73,6 +73,8 @@ class MetaInfo { /*! \brief default constructor */ MetaInfo() = default; + MetaInfo(MetaInfo&& that) = default; + MetaInfo& operator=(MetaInfo&& that) = default; MetaInfo& operator=(MetaInfo const& that) { this->num_row_ = that.num_row_; this->num_col_ = that.num_col_; @@ -89,6 +91,8 @@ class MetaInfo { this->base_margin_.Copy(that.base_margin_); return *this; } + + MetaInfo Slice(common::Span ridxs) const; /*! * \brief Get weight of each instances. * \param i Instance index. @@ -491,7 +495,7 @@ class DMatrix { const std::string& cache_prefix = "", size_t page_size = kPageSize); - + virtual DMatrix* Slice(common::Span ridxs) = 0; /*! \brief page size 32 MB */ static const size_t kPageSize = 32UL << 20UL; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 78ed862361ea..754e27ef02b3 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -181,11 +181,7 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, << "slice does not support group structure"; } DMatrix* dmat = static_cast*>(handle)->get(); - CHECK(dynamic_cast(dmat)) - << "Slice only supported for SimpleDMatrix currently."; - data::DMatrixSliceAdapter adapter(dmat, {idxset, static_cast(len)}); - *out = new std::shared_ptr( - DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 1)); + *out = new std::shared_ptr(dmat->Slice({idxset, len})); API_END(); } diff --git a/src/data/adapter.h b/src/data/adapter.h index 9d26d5e14126..e252e1da7a1f 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -599,93 +599,6 @@ class IteratorAdapter : public dmlc::DataIter { dmlc::RowBlock block_; std::unique_ptr batch_; }; - -class DMatrixSliceAdapterBatch { - public: - // Fetch metainfo values according to sliced rows - template - std::vector Gather(const std::vector& in) { - if (in.empty()) return {}; - - std::vector out(this->Size()); - for (auto i = 0ull; i < this->Size(); i++) { - out[i] = in[ridx_set[i]]; - } - return out; - } - DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat, - common::Span ridx_set) - : batch(batch), ridx_set(ridx_set) { - batch_labels = this->Gather(dmat->Info().labels_.HostVector()); - batch_weights = this->Gather(dmat->Info().weights_.HostVector()); - batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector()); - } - - class Line { - public: - Line(const SparsePage::Inst& inst, size_t row_idx) - : inst_(inst), row_idx_(row_idx) {} - - size_t Size() { return inst_.size(); } - COOTuple GetElement(size_t idx) { - return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue}; - } - - private: - SparsePage::Inst inst_; - size_t row_idx_; - }; - Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); } - const float* Labels() const { - if (batch_labels.empty()) { - return nullptr; - } - return batch_labels.data(); - } - const float* Weights() const { - if (batch_weights.empty()) { - return nullptr; - } - return batch_weights.data(); - } - const uint64_t* Qid() const { return nullptr; } - const float* BaseMargin() const { - if (batch_base_margin.empty()) { - return nullptr; - } - return batch_base_margin.data(); - } - - size_t Size() const { return ridx_set.size(); } - const SparsePage& batch; - common::Span ridx_set; - std::vector batch_labels; - std::vector batch_weights; - std::vector batch_base_margin; -}; - -// Group pointer is not exposed -// This is because external bindings currently manipulate the group values -// manually when slicing This could potentially be moved to internal C++ code if -// needed - -class DMatrixSliceAdapter - : public detail::SingleBatchDataIter { - public: - DMatrixSliceAdapter(DMatrix* dmat, common::Span ridx_set) - : dmat_(dmat), - ridx_set_(ridx_set), - batch_(*dmat_->GetBatches().begin(), dmat_, ridx_set) {} - const DMatrixSliceAdapterBatch& Value() const override { return batch_; } - // Indicates a number of rows/columns must be inferred - size_t NumRows() const { return ridx_set_.size(); } - size_t NumColumns() const { return dmat_->Info().num_col_; } - - private: - DMatrix* dmat_; - common::Span ridx_set_; - DMatrixSliceAdapterBatch batch_; -}; }; // namespace data } // namespace xgboost #endif // XGBOOST_DATA_ADAPTER_H_ diff --git a/src/data/data.cc b/src/data/data.cc index 405b88c7d248..3af47e2b6600 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -205,6 +205,53 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_); } +template +std::vector Gather(const std::vector &in, common::Span ridxs, size_t stride = 1) { + if (in.empty()) { + return {}; + } + auto size = ridxs.size(); + std::vector out(size * stride); + for (auto i = 0ull; i < size; i++) { + auto ridx = ridxs[i]; + for (size_t j = 0; j < stride; ++j) { + out[i * stride +j] = in[ridx * stride + j]; + } + } + return out; +} + +MetaInfo MetaInfo::Slice(common::Span ridxs) const { + MetaInfo out; + out.num_row_ = ridxs.size(); + out.num_col_ = this->num_col_; + // Groups is maintained by a higher level Python function. We should aim at deprecating + // the slice function. + out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs); + out.labels_upper_bound_.HostVector() = + Gather(this->labels_upper_bound_.HostVector(), ridxs); + out.labels_lower_bound_.HostVector() = + Gather(this->labels_lower_bound_.HostVector(), ridxs); + // weights + if (this->weights_.Size() + 1 == this->group_ptr_.size()) { + auto& h_weights = out.weights_.HostVector(); + // Assuming all groups are available. + out.weights_.HostVector() = h_weights; + } else { + out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs); + } + + if (this->base_margin_.Size() != this->num_row_) { + CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0) + << "Incorrect size of base margin vector."; + size_t stride = this->base_margin_.Size() / this->num_row_; + out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride); + } else { + out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs); + } + return out; +} + // try to load group information from file, if exists inline bool MetaTryLoadGroup(const std::string& fname, std::vector* group) { @@ -459,9 +506,6 @@ template DMatrix* DMatrix::Create( template DMatrix* DMatrix::Create( data::FileAdapter* adapter, float missing, int nthread, const std::string& cache_prefix, size_t page_size); -template DMatrix* DMatrix::Create( - data::DMatrixSliceAdapter* adapter, float missing, int nthread, - const std::string& cache_prefix, size_t page_size); template DMatrix* DMatrix::Create( data::IteratorAdapter* adapter, float missing, int nthread, const std::string& cache_prefix, size_t page_size); diff --git a/src/data/device_dmatrix.h b/src/data/device_dmatrix.h index 2442175a713d..781461baaa9e 100644 --- a/src/data/device_dmatrix.h +++ b/src/data/device_dmatrix.h @@ -31,6 +31,10 @@ class DeviceDMatrix : public DMatrix { bool EllpackExists() const override { return true; } bool SparsePageExists() const override { return false; } + DMatrix *Slice(common::Span ridxs) override { + LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix."; + return nullptr; + } private: BatchSet GetRowBatches() override { diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index d0c1396dd2a6..ca65eb4cc930 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -16,6 +16,27 @@ MetaInfo& SimpleDMatrix::Info() { return info_; } const MetaInfo& SimpleDMatrix::Info() const { return info_; } +DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { + auto out = new SimpleDMatrix; + SparsePage& out_page = out->sparse_page_; + for (auto const &page : this->GetBatches()) { + page.data.HostVector(); + page.offset.HostVector(); + auto& h_data = out_page.data.HostVector(); + auto& h_offset = out_page.offset.HostVector(); + size_t rptr{0}; + for (auto ridx : ridxs) { + auto inst = page[ridx]; + rptr += inst.size(); + std::copy(inst.begin(), inst.end(), std::back_inserter(h_data)); + h_offset.emplace_back(rptr); + } + out->Info() = this->Info().Slice(ridxs); + out->Info().num_nonzero_ = h_offset.back(); + } + return out; +} + BatchSet SimpleDMatrix::GetRowBatches() { // since csr is the default data structure so `source_` is always available. auto begin_iter = BatchIterator( @@ -174,8 +195,6 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, int nthread); template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread); -template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing, - int nthread); template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing, int nthread); } // namespace data diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index e0b41a4ef914..9d2130b4195e 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -19,6 +19,7 @@ namespace data { // Used for single batch data. class SimpleDMatrix : public DMatrix { public: + SimpleDMatrix() = default; template explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread); @@ -32,6 +33,7 @@ class SimpleDMatrix : public DMatrix { const MetaInfo& Info() const override; bool SingleColBlock() const override { return true; } + DMatrix* Slice(common::Span ridxs) override; /*! \brief magic number used to identify SimpleDMatrix binary files */ static const int kMagic = 0xffffab01; diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 05d318313cc8..393172658695 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -37,6 +37,10 @@ class SparsePageDMatrix : public DMatrix { const MetaInfo& Info() const override; bool SingleColBlock() const override { return false; } + DMatrix *Slice(common::Span ridxs) override { + LOG(FATAL) << "Slicing DMatrix is not supported for external memory."; + return nullptr; + } private: BatchSet GetRowBatches() override; diff --git a/tests/cpp/data/test_adapter.cc b/tests/cpp/data/test_adapter.cc index 2d5e33ac0084..89bd3e7869b2 100644 --- a/tests/cpp/data/test_adapter.cc +++ b/tests/cpp/data/test_adapter.cc @@ -67,31 +67,6 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) { EXPECT_EQ(inst[3].index, 3); } -TEST(CAPI, DMatrixSliceAdapterFromSimpleDMatrix) { - auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatrix(); - - std::vector ridx_set = {1, 3, 5}; - data::DMatrixSliceAdapter adapter(p_dmat.get(), - {ridx_set.data(), ridx_set.size()}); - EXPECT_EQ(adapter.NumRows(), ridx_set.size()); - - adapter.BeforeFirst(); - for (auto &batch : p_dmat->GetBatches()) { - adapter.Next(); - auto &adapter_batch = adapter.Value(); - for (auto i = 0ull; i < adapter_batch.Size(); i++) { - auto inst = batch[ridx_set[i]]; - auto line = adapter_batch.GetLine(i); - ASSERT_EQ(inst.size(), line.Size()); - for (auto j = 0ull; j < line.Size(); j++) { - EXPECT_EQ(inst[j].fvalue, line.GetElement(j).value); - EXPECT_EQ(inst[j].index, line.GetElement(j).column_idx); - EXPECT_EQ(i, line.GetElement(j).row_idx); - } - } - } -} - // A mock for JVM data iterator. class DataIterForTest { std::vector data_ {1, 2, 3, 4, 5}; diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index 041cd9dbc4b1..d01da568c421 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -125,5 +125,4 @@ TEST(DMatrix, Uri) { ASSERT_EQ(dmat->Info().num_col_, kCols); ASSERT_EQ(dmat->Info().num_row_, kRows); } - } // namespace xgboost diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index ebcd6678c666..691dc8545eca 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -1,11 +1,12 @@ // Copyright by Contributors #include #include -#include "../../../src/data/simple_dmatrix.h" +#include +#include "xgboost/base.h" +#include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/adapter.h" #include "../helpers.h" -#include "xgboost/base.h" using namespace xgboost; // NOLINT @@ -218,45 +219,64 @@ TEST(SimpleDMatrix, FromFile) { } TEST(SimpleDMatrix, Slice) { - const int kRows = 6; - const int kCols = 2; - auto p_dmat = RandomDataGenerator(kRows, kCols, 1.0).GenerateDMatrix(); - auto &labels = p_dmat->Info().labels_.HostVector(); - auto &weights = p_dmat->Info().weights_.HostVector(); - auto &base_margin = p_dmat->Info().base_margin_.HostVector(); + size_t constexpr kRows {16}; + size_t constexpr kCols {8}; + size_t constexpr kClasses {3}; + auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); + auto& weights = p_m->Info().weights_.HostVector(); weights.resize(kRows); - labels.resize(kRows); - base_margin.resize(kRows); - std::iota(labels.begin(), labels.end(), 0); - std::iota(weights.begin(), weights.end(), 0); - std::iota(base_margin.begin(), base_margin.end(), 0); - - std::vector ridx_set = {1, 3, 5}; - data::DMatrixSliceAdapter adapter(p_dmat.get(), - {ridx_set.data(), ridx_set.size()}); - EXPECT_EQ(adapter.NumRows(), ridx_set.size()); - data::SimpleDMatrix new_dmat(&adapter, - std::numeric_limits::quiet_NaN(), 1); - - EXPECT_EQ(new_dmat.Info().num_row_, ridx_set.size()); - - auto &old_batch = *p_dmat->GetBatches().begin(); - auto &new_batch = *new_dmat.GetBatches().begin(); - for (auto i = 0ull; i < ridx_set.size(); i++) { - EXPECT_EQ(new_dmat.Info().labels_.HostVector()[i], - p_dmat->Info().labels_.HostVector()[ridx_set[i]]); - EXPECT_EQ(new_dmat.Info().weights_.HostVector()[i], - p_dmat->Info().weights_.HostVector()[ridx_set[i]]); - EXPECT_EQ(new_dmat.Info().base_margin_.HostVector()[i], - p_dmat->Info().base_margin_.HostVector()[ridx_set[i]]); - auto old_inst = old_batch[ridx_set[i]]; - auto new_inst = new_batch[i]; - ASSERT_EQ(old_inst.size(), new_inst.size()); - for (auto j = 0ull; j < old_inst.size(); j++) { - EXPECT_EQ(old_inst[j], new_inst[j]); + std::iota(weights.begin(), weights.end(), 0.0f); + + auto& lower = p_m->Info().labels_lower_bound_.HostVector(); + auto& upper = p_m->Info().labels_upper_bound_.HostVector(); + lower.resize(kRows); + upper.resize(kRows); + + std::iota(lower.begin(), lower.end(), 0.0f); + std::iota(upper.begin(), upper.end(), 1.0f); + + auto& margin = p_m->Info().base_margin_.HostVector(); + margin.resize(kRows * kClasses); + + std::array ridxs {1, 3, 5}; + std::unique_ptr out { p_m->Slice(ridxs) }; + ASSERT_EQ(out->Info().labels_.Size(), ridxs.size()); + ASSERT_EQ(out->Info().labels_lower_bound_.Size(), ridxs.size()); + ASSERT_EQ(out->Info().labels_upper_bound_.Size(), ridxs.size()); + ASSERT_EQ(out->Info().base_margin_.Size(), ridxs.size() * kClasses); + + for (auto const& in_page : p_m->GetBatches()) { + for (auto const &out_page : out->GetBatches()) { + for (size_t i = 0; i < ridxs.size(); ++i) { + auto ridx = ridxs[i]; + auto out_inst = out_page[i]; + auto in_inst = in_page[ridx]; + ASSERT_EQ(out_inst.size(), in_inst.size()) << i; + for (size_t j = 0; j < in_inst.size(); ++j) { + ASSERT_EQ(in_inst[j].fvalue, out_inst[j].fvalue); + ASSERT_EQ(in_inst[j].index, out_inst[j].index); + } + + ASSERT_EQ(p_m->Info().labels_lower_bound_.HostVector().at(ridx), + out->Info().labels_lower_bound_.HostVector().at(i)); + ASSERT_EQ(p_m->Info().labels_upper_bound_.HostVector().at(ridx), + out->Info().labels_upper_bound_.HostVector().at(i)); + ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx), + out->Info().weights_.HostVector().at(i)); + + auto& out_margin = out->Info().base_margin_.HostVector(); + for (size_t j = 0; j < kClasses; ++j) { + auto in_beg = ridx * kClasses; + ASSERT_EQ(out_margin.at(i * kClasses + j), margin.at(in_beg + j)); + } + } } } -}; + + ASSERT_EQ(out->Info().num_col_, out->Info().num_col_); + ASSERT_EQ(out->Info().num_row_, ridxs.size()); + ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense +} TEST(SimpleDMatrix, SaveLoadBinary) { dmlc::TemporaryDirectory tempdir; diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 2becb674f7c9..c1640d4e3b45 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -71,7 +71,34 @@ def test_np_view(self): assert (from_view.shape == from_array.shape) assert (from_view == from_array).all() - def test_feature_names(self): + def test_slice(self): + X = rng.randn(100, 100) + y = rng.randint(low=0, high=3, size=100) + d = xgb.DMatrix(X, y) + eval_res_0 = {} + booster = xgb.train( + {'num_class': 3, 'objective': 'multi:softprob'}, d, + num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_0) + + predt = booster.predict(d) + predt = predt.reshape(100 * 3, 1) + d.set_base_margin(predt) + + ridxs = [1, 2, 3, 4, 5, 6] + d = d.slice(ridxs) + sliced_margin = d.get_float_info('base_margin') + assert sliced_margin.shape[0] == len(ridxs) * 3 + + eval_res_1 = {} + xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, d, + num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_1) + + eval_res_0 = eval_res_0['d']['merror'] + eval_res_1 = eval_res_1['d']['merror'] + for i in range(len(eval_res_0)): + assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02 + + def test_feature_names_slice(self): data = np.random.randn(5, 5) # different length