diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 1655f9d0ec89..42c299dd2703 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -148,13 +148,13 @@ class MetaInfo { * \param dtype The type of the source data. * \param num Number of elements in the source array. */ - void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num); + void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num); /*! * \brief Set information in the meta info with array interface. * \param key The key of the information. * \param interface_str String representation of json format array interface. */ - void SetInfo(StringView key, StringView interface_str); + void SetInfo(Context const& ctx, StringView key, StringView interface_str); void GetInfo(char const* key, bst_ulong* out_len, DataType dtype, const void** out_dptr) const; @@ -176,8 +176,8 @@ class MetaInfo { void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column); private: - void SetInfoFromHost(StringView key, Json arr); - void SetInfoFromCUDA(StringView key, Json arr); + void SetInfoFromHost(Context const& ctx, StringView key, Json arr); + void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr); /*! \brief argsort of labels */ mutable std::vector label_order_cache_; @@ -478,12 +478,13 @@ class DMatrix { DMatrix() = default; /*! \brief meta information of the dataset */ virtual MetaInfo& Info() = 0; - virtual void SetInfo(const char *key, const void *dptr, DataType dtype, - size_t num) { - this->Info().SetInfo(key, dptr, dtype, num); + virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { + auto const& ctx = *this->Ctx(); + this->Info().SetInfo(ctx, key, dptr, dtype, num); } virtual void SetInfo(const char* key, std::string const& interface_str) { - this->Info().SetInfo(key, StringView{interface_str}); + auto const& ctx = *this->Ctx(); + this->Info().SetInfo(ctx, key, StringView{interface_str}); } /*! \brief meta information of the dataset */ virtual const MetaInfo& Info() const = 0; @@ -494,7 +495,7 @@ class DMatrix { * \brief Get the context object of this DMatrix. The context is created during construction of * DMatrix with user specified `nthread` parameter. */ - virtual GenericParameter const* Ctx() const = 0; + virtual Context const* Ctx() const = 0; /** * \brief Gets batches. Use range based for loop over BatchSet to access individual batches. diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h index 8f8cd09124d5..0375ecfafdc2 100644 --- a/include/xgboost/generic_parameters.h +++ b/include/xgboost/generic_parameters.h @@ -75,6 +75,8 @@ struct GenericParameter : public XGBoostParameter { .describe("Enable checking whether parameters are used or not."); } }; + +using Context = GenericParameter; } // namespace xgboost #endif // XGBOOST_GENERIC_PARAMETERS_H_ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 86d763a6af1e..a11602a56610 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -485,35 +485,30 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname, API_END(); } -XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, - const char* field, - const bst_float* info, +XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const bst_float *info, xgboost::bst_ulong len) { API_BEGIN(); CHECK_HANDLE(); - static_cast*>(handle) - ->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len); + auto const& p_fmat = *static_cast *>(handle); + p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len); API_END(); } -XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, - char const* field, - char const* interface_c_str) { +XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, char const *field, + char const *interface_c_str) { API_BEGIN(); CHECK_HANDLE(); - static_cast*>(handle) - ->get()->Info().SetInfo(field, interface_c_str); + auto const &p_fmat = *static_cast *>(handle); + p_fmat->SetInfo(field, interface_c_str); API_END(); } -XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, - const char* field, - const unsigned* info, +XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *info, xgboost::bst_ulong len) { API_BEGIN(); CHECK_HANDLE(); - static_cast*>(handle) - ->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len); + auto const &p_fmat = *static_cast *>(handle); + p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len); API_END(); } @@ -549,25 +544,22 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, API_END(); } -XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, - void const *data, xgboost::bst_ulong size, - int type) { +XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data, + xgboost::bst_ulong size, int type) { API_BEGIN(); CHECK_HANDLE(); - auto &info = static_cast *>(handle)->get()->Info(); + auto const &p_fmat = *static_cast *>(handle); CHECK(type >= 1 && type <= 4); - info.SetInfo(field, data, static_cast(type), size); + p_fmat->SetInfo(field, data, static_cast(type), size); API_END(); } -XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, - const unsigned* group, - xgboost::bst_ulong len) { +XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) { API_BEGIN(); CHECK_HANDLE(); LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead."; - static_cast*>(handle) - ->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len); + auto const &p_fmat = *static_cast *>(handle); + p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len); API_END(); } diff --git a/src/data/data.cc b/src/data/data.cc index 57940199f0f4..ff5b2c8030cb 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -409,7 +409,7 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname, namespace { template -void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { +void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor* p_out) { ArrayInterface array{arr_interface}; if (array.n == 0) { p_out->Reshape(array.shape); @@ -428,16 +428,15 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { return; } p_out->Reshape(array.shape); - auto t = p_out->View(GenericParameter::kCpuId); + auto t = p_out->View(Context::kCpuId); CHECK(t.CContiguous()); - // FIXME(jiamingy): Remove the use of this default thread. - linalg::ElementWiseTransformHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) { + linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) { return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, t.Shape())); }); } } // namespace -void MetaInfo::SetInfo(StringView key, StringView interface_str) { +void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_str) { Json j_interface = Json::Load(interface_str); bool is_cuda{false}; if (IsA(j_interface)) { @@ -454,16 +453,16 @@ void MetaInfo::SetInfo(StringView key, StringView interface_str) { } if (is_cuda) { - this->SetInfoFromCUDA(key, j_interface); + this->SetInfoFromCUDA(ctx, key, j_interface); } else { - this->SetInfoFromHost(key, j_interface); + this->SetInfoFromHost(ctx, key, j_interface); } } -void MetaInfo::SetInfoFromHost(StringView key, Json arr) { +void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) { // multi-dim float info if (key == "base_margin") { - CopyTensorInfoImpl(arr, &this->base_margin_); + CopyTensorInfoImpl(ctx, arr, &this->base_margin_); // FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of // input shape. This issue is CPU only since CUDA uses array interface from day 1. // @@ -477,7 +476,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) { } return; } else if (key == "label") { - CopyTensorInfoImpl(arr, &this->labels); + CopyTensorInfoImpl(ctx, arr, &this->labels); if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) { CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels."; size_t n_targets = this->labels.Size() / this->num_row_; @@ -491,7 +490,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) { // uint info if (key == "group") { linalg::Tensor t; - CopyTensorInfoImpl(arr, &t); + CopyTensorInfoImpl(ctx, arr, &t); auto const& h_groups = t.Data()->HostVector(); group_ptr_.clear(); group_ptr_.resize(h_groups.size() + 1, 0); @@ -501,7 +500,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) { return; } else if (key == "qid") { linalg::Tensor t; - CopyTensorInfoImpl(arr, &t); + CopyTensorInfoImpl(ctx, arr, &t); bool non_dec = true; auto const& query_ids = t.Data()->HostVector(); for (size_t i = 1; i < query_ids.size(); ++i) { @@ -526,7 +525,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) { } // float info linalg::Tensor t; - CopyTensorInfoImpl<1>(arr, &t); + CopyTensorInfoImpl<1>(ctx, arr, &t); if (key == "weight") { this->weights_ = std::move(*t.Data()); auto const& h_weights = this->weights_.ConstHostVector(); @@ -548,13 +547,15 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) { } } -void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { +void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, + size_t num) { auto proc = [&](auto cast_d_ptr) { using T = std::remove_pointer_t; - auto t = - linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, GenericParameter::kCpuId); + auto t = linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, Context::kCpuId); CHECK(t.CContiguous()); - Json interface { linalg::ArrayInterface(t) }; + Json interface { + linalg::ArrayInterface(t) + }; assert(ArrayInterface<1>{interface}.is_contiguous); return interface; }; @@ -562,22 +563,22 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t switch (dtype) { case xgboost::DataType::kFloat32: { auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(key, proc(cast_ptr)); + this->SetInfoFromHost(ctx, key, proc(cast_ptr)); break; } case xgboost::DataType::kDouble: { auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(key, proc(cast_ptr)); + this->SetInfoFromHost(ctx, key, proc(cast_ptr)); break; } case xgboost::DataType::kUInt32: { auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(key, proc(cast_ptr)); + this->SetInfoFromHost(ctx, key, proc(cast_ptr)); break; } case xgboost::DataType::kUInt64: { auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(key, proc(cast_ptr)); + this->SetInfoFromHost(ctx, key, proc(cast_ptr)); break; } default: @@ -724,9 +725,7 @@ void MetaInfo::Validate(int32_t device) const { "doesn't equal to actual number of rows given by data."; } auto check_device = [device](HostDeviceVector const& v) { - CHECK(v.DeviceIdx() == GenericParameter::kCpuId || - device == GenericParameter::kCpuId || - v.DeviceIdx() == device) + CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device) << "Data is resided on a different device than `gpu_id`. " << "Device that data is on: " << v.DeviceIdx() << ", " << "`gpu_id` for XGBoost: " << device; @@ -769,7 +768,9 @@ void MetaInfo::Validate(int32_t device) const { } #if !defined(XGBOOST_USE_CUDA) -void MetaInfo::SetInfoFromCUDA(StringView key, Json arr) { common::AssertGPUSupport(); } +void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json arr) { + common::AssertGPUSupport(); +} #endif // !defined(XGBOOST_USE_CUDA) using DMatrixThreadLocal = diff --git a/src/data/data.cu b/src/data/data.cu index 55c1c80d0e23..cf574dd50f1e 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -115,7 +115,8 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector* p_ } } // namespace -void MetaInfo::SetInfoFromCUDA(StringView key, Json array) { +// Context is not used until we have CUDA stream. +void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) { // multi-dim float info if (key == "base_margin") { CopyTensorInfoImpl(array, &base_margin_); diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index 0f7b6d790492..5c8612d63214 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -43,18 +43,18 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin size_t batches = 0; size_t accumulated_rows = 0; bst_feature_t cols = 0; - int32_t device = GenericParameter::kCpuId; + int32_t current_device; dh::safe_cuda(cudaGetDevice(¤t_device)); auto get_device = [&]() -> int32_t { - int32_t d = (device == GenericParameter::kCpuId) ? current_device : device; - CHECK_NE(d, GenericParameter::kCpuId); + int32_t d = (ctx_.gpu_id == Context::kCpuId) ? current_device : ctx_.gpu_id; + CHECK_NE(d, Context::kCpuId); return d; }; while (iter.Next()) { - device = proxy->DeviceIdx(); - CHECK_LT(device, common::AllVisibleGPUs()); + ctx_.gpu_id = proxy->DeviceIdx(); + CHECK_LT(ctx_.gpu_id, common::AllVisibleGPUs()); dh::safe_cuda(cudaSetDevice(get_device())); if (cols == 0) { cols = num_cols(); diff --git a/src/data/iterative_device_dmatrix.h b/src/data/iterative_device_dmatrix.h index 69b2793be02e..ba2d4a92f9da 100644 --- a/src/data/iterative_device_dmatrix.h +++ b/src/data/iterative_device_dmatrix.h @@ -21,6 +21,7 @@ namespace data { class IterativeDeviceDMatrix : public DMatrix { MetaInfo info_; + Context ctx_; BatchParam batch_param_; std::shared_ptr page_; @@ -72,10 +73,7 @@ class IterativeDeviceDMatrix : public DMatrix { MetaInfo &Info() override { return info_; } MetaInfo const &Info() const override { return info_; } - GenericParameter const *Ctx() const override { - LOG(FATAL) << "`IterativeDMatrix` doesn't have context."; - return nullptr; - } + Context const *Ctx() const override { return &ctx_; } }; #if !defined(XGBOOST_USE_CUDA) diff --git a/src/data/proxy_dmatrix.cu b/src/data/proxy_dmatrix.cu index 6fbd721007d0..84f1fcb0d527 100644 --- a/src/data/proxy_dmatrix.cu +++ b/src/data/proxy_dmatrix.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2020 XGBoost contributors + * Copyright 2020-2022, XGBoost contributors */ #include "proxy_dmatrix.h" #include "device_adapter.cuh" @@ -11,10 +11,10 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) { std::shared_ptr adapter {new data::CudfAdapter(interface_str)}; auto const& value = adapter->Value(); this->batch_ = adapter; - device_ = adapter->DeviceIdx(); + ctx_.gpu_id = adapter->DeviceIdx(); this->Info().num_col_ = adapter->NumColumns(); this->Info().num_row_ = adapter->NumRows(); - if (device_ < 0) { + if (ctx_.gpu_id < 0) { CHECK_EQ(this->Info().num_row_, 0); } } @@ -22,13 +22,12 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) { void DMatrixProxy::FromCudaArray(std::string interface_str) { std::shared_ptr adapter(new CupyAdapter(interface_str)); this->batch_ = adapter; - device_ = adapter->DeviceIdx(); + ctx_.gpu_id = adapter->DeviceIdx(); this->Info().num_col_ = adapter->NumColumns(); this->Info().num_row_ = adapter->NumRows(); - if (device_ < 0) { + if (ctx_.gpu_id < 0) { CHECK_EQ(this->Info().num_row_, 0); } } - } // namespace data } // namespace xgboost diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index fdf274980f9c..8a6f67f144d0 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -1,5 +1,5 @@ /*! - * Copyright 2020-2021 XGBoost contributors + * Copyright 2020-2022, XGBoost contributors */ #ifndef XGBOOST_DATA_PROXY_DMATRIX_H_ #define XGBOOST_DATA_PROXY_DMATRIX_H_ @@ -45,7 +45,7 @@ class DataIterProxy { class DMatrixProxy : public DMatrix { MetaInfo info_; dmlc::any batch_; - int32_t device_ { xgboost::GenericParameter::kCpuId }; + Context ctx_; #if defined(XGBOOST_USE_CUDA) void FromCudaColumnar(std::string interface_str); @@ -53,7 +53,7 @@ class DMatrixProxy : public DMatrix { #endif // defined(XGBOOST_USE_CUDA) public: - int DeviceIdx() const { return device_; } + int DeviceIdx() const { return ctx_.gpu_id; } void SetData(char const* c_interface) { common::AssertGPUSupport(); @@ -67,7 +67,7 @@ class DMatrixProxy : public DMatrix { this->FromCudaArray(interface_str); } if (this->info_.num_row_ == 0) { - this->device_ = GenericParameter::kCpuId; + this->ctx_.gpu_id = Context::kCpuId; } #endif // defined(XGBOOST_USE_CUDA) } @@ -79,10 +79,7 @@ class DMatrixProxy : public DMatrix { MetaInfo& Info() override { return info_; } MetaInfo const& Info() const override { return info_; } - GenericParameter const* Ctx() const override { - LOG(FATAL) << "`ProxyDMatrix` doesn't have context."; - return nullptr; - } + Context const* Ctx() const override { return &ctx_; } bool SingleColBlock() const override { return true; } bool EllpackExists() const override { return true; } diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 8d42a4220116..a373ff0196e5 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -149,10 +149,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size()); } if (batch.BaseMargin() != nullptr) { - info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(), - batch.BaseMargin() + batch.Size(), - {batch.Size()}, - GenericParameter::kCpuId}; + info_.base_margin_ = decltype(info_.base_margin_){ + batch.BaseMargin(), batch.BaseMargin() + batch.Size(), {batch.Size()}, Context::kCpuId}; } if (batch.Qid() != nullptr) { qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index 8bb438481e80..25546f96469d 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -31,7 +31,7 @@ class SimpleDMatrix : public DMatrix { MetaInfo& Info() override; const MetaInfo& Info() const override; - GenericParameter const* Ctx() const override { return &ctx_; } + Context const* Ctx() const override { return &ctx_; } bool SingleColBlock() const override { return true; } DMatrix* Slice(common::Span ridxs) override; @@ -63,7 +63,7 @@ class SimpleDMatrix : public DMatrix { } private: - GenericParameter ctx_; + Context ctx_; }; } // namespace data } // namespace xgboost diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 245032009897..797910836a73 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -69,7 +69,7 @@ class SparsePageDMatrix : public DMatrix { XGDMatrixCallbackNext *next_; float missing_; - GenericParameter ctx_; + Context ctx_; std::string cache_prefix_; uint32_t n_batches_ {0}; // sparse page is the source to other page types, we make a special member function. @@ -100,7 +100,7 @@ class SparsePageDMatrix : public DMatrix { MetaInfo& Info() override; const MetaInfo& Info() const override; - GenericParameter const* Ctx() const override { return &ctx_; } + Context const* Ctx() const override { return &ctx_; } bool SingleColBlock() const override { return false; } DMatrix *Slice(common::Span) override { diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 719425dee84a..9c48096bf251 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -149,8 +149,7 @@ TEST(CutsBuilder, SearchGroupInd) { group[2] = 7; group[3] = 5; - p_mat->Info().SetInfo( - "group", group.data(), DataType::kUInt32, kNumGroups); + p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups); HistogramCuts hmat; @@ -350,6 +349,7 @@ void TestSketchFromWeights(bool with_group) { common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0)); MetaInfo info; + Context ctx; auto& h_weights = info.weights_.HostVector(); if (with_group) { h_weights.resize(kGroups); @@ -363,7 +363,7 @@ void TestSketchFromWeights(bool with_group) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); } info.num_row_ = kRows; @@ -371,10 +371,10 @@ void TestSketchFromWeights(bool with_group) { // Assign weights. if (with_group) { - m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups); } - m->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + m->SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size()); m->Info().num_col_ = kCols; m->Info().num_row_ = kRows; ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 4ab9b2c9ebcf..f02bff547c5a 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -520,7 +520,7 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups); HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0); h_weights.clear(); @@ -550,6 +550,7 @@ void TestAdapterSketchFromWeights(bool with_group) { RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface( &storage); MetaInfo info; + Context ctx; auto& h_weights = info.weights_.HostVector(); if (with_group) { h_weights.resize(kGroups); @@ -563,7 +564,7 @@ void TestAdapterSketchFromWeights(bool with_group) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); } info.weights_.SetDevice(0); @@ -582,10 +583,10 @@ void TestAdapterSketchFromWeights(bool with_group) { auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); if (with_group) { - dmat->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); } - dmat->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size()); dmat->Info().num_col_ = kCols; dmat->Info().num_row_ = kRows; ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 2f17f6bfe376..62146b571fb8 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -12,28 +12,29 @@ #include "xgboost/base.h" TEST(MetaInfo, GetSet) { + xgboost::Context ctx; xgboost::MetaInfo info; double double2[2] = {1.0, 2.0}; EXPECT_EQ(info.labels.Size(), 0); - info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2); + info.SetInfo(ctx, "label", double2, xgboost::DataType::kFloat32, 2); EXPECT_EQ(info.labels.Size(), 2); float float2[2] = {1.0f, 2.0f}; EXPECT_EQ(info.GetWeight(1), 1.0f) << "When no weights are given, was expecting default value 1"; - info.SetInfo("weight", float2, xgboost::DataType::kFloat32, 2); + info.SetInfo(ctx, "weight", float2, xgboost::DataType::kFloat32, 2); EXPECT_EQ(info.GetWeight(1), 2.0f); uint32_t uint32_t2[2] = {1U, 2U}; EXPECT_EQ(info.base_margin_.Size(), 0); - info.SetInfo("base_margin", uint32_t2, xgboost::DataType::kUInt32, 2); + info.SetInfo(ctx, "base_margin", uint32_t2, xgboost::DataType::kUInt32, 2); EXPECT_EQ(info.base_margin_.Size(), 2); uint64_t uint64_t2[2] = {1U, 2U}; EXPECT_EQ(info.group_ptr_.size(), 0); - info.SetInfo("group", uint64_t2, xgboost::DataType::kUInt64, 2); + info.SetInfo(ctx, "group", uint64_t2, xgboost::DataType::kUInt64, 2); ASSERT_EQ(info.group_ptr_.size(), 3); EXPECT_EQ(info.group_ptr_[2], 3); @@ -73,6 +74,8 @@ TEST(MetaInfo, GetSetFeature) { TEST(MetaInfo, SaveLoadBinary) { xgboost::MetaInfo info; + xgboost::Context ctx; + uint64_t constexpr kRows { 64 }, kCols { 32 }; auto generator = []() { static float f = 0; @@ -80,9 +83,9 @@ TEST(MetaInfo, SaveLoadBinary) { }; std::vector values (kRows); std::generate(values.begin(), values.end(), generator); - info.SetInfo("label", values.data(), xgboost::DataType::kFloat32, kRows); - info.SetInfo("weight", values.data(), xgboost::DataType::kFloat32, kRows); - info.SetInfo("base_margin", values.data(), xgboost::DataType::kFloat32, kRows); + info.SetInfo(ctx, "label", values.data(), xgboost::DataType::kFloat32, kRows); + info.SetInfo(ctx, "weight", values.data(), xgboost::DataType::kFloat32, kRows); + info.SetInfo(ctx, "base_margin", values.data(), xgboost::DataType::kFloat32, kRows); info.num_row_ = kRows; info.num_col_ = kCols; @@ -210,13 +213,14 @@ TEST(MetaInfo, LoadQid) { TEST(MetaInfo, CPUQid) { xgboost::MetaInfo info; + xgboost::Context ctx; info.num_row_ = 100; std::vector qid(info.num_row_, 0); for (size_t i = 0; i < qid.size(); ++i) { qid[i] = i; } - info.SetInfo("qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_); + info.SetInfo(ctx, "qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_); ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1); ASSERT_EQ(info.group_ptr_.front(), 0); ASSERT_EQ(info.group_ptr_.back(), info.num_row_); @@ -232,12 +236,15 @@ TEST(MetaInfo, Validate) { info.num_nonzero_ = 12; info.num_col_ = 3; std::vector groups (11); - info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, 11); + xgboost::Context ctx; + info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, 11); EXPECT_THROW(info.Validate(0), dmlc::Error); std::vector labels(info.num_row_ + 1); EXPECT_THROW( - { info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); }, + { + info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); + }, dmlc::Error); // Make overflow data, which can happen when users pass group structure as int @@ -247,14 +254,13 @@ TEST(MetaInfo, Validate) { groups.push_back(1562500); } groups.push_back(static_cast(-1)); - EXPECT_THROW(info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, - groups.size()), + EXPECT_THROW(info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()), dmlc::Error); #if defined(XGBOOST_USE_CUDA) info.group_ptr_.clear(); labels.resize(info.num_row_); - info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_); + info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_); info.labels.SetDevice(0); EXPECT_THROW(info.Validate(1), dmlc::Error); @@ -263,12 +269,13 @@ TEST(MetaInfo, Validate) { d_groups.DevicePointer(); // pull to device std::string arr_interface_str{ArrayInterfaceStr( xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))}; - EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error); + EXPECT_THROW(info.SetInfo(ctx, "group", xgboost::StringView{arr_interface_str}), dmlc::Error); #endif // defined(XGBOOST_USE_CUDA) } TEST(MetaInfo, HostExtend) { xgboost::MetaInfo lhs, rhs; + xgboost::Context ctx; size_t const kRows = 100; lhs.labels.Reshape(kRows); lhs.num_row_ = kRows; @@ -282,8 +289,8 @@ TEST(MetaInfo, HostExtend) { for (size_t g = 0; g < kRows / per_group; ++g) { groups.emplace_back(per_group); } - lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size()); - rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size()); + lhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()); + rhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()); lhs.Extend(rhs, true, true); ASSERT_EQ(lhs.num_row_, kRows * 2); @@ -300,5 +307,5 @@ TEST(MetaInfo, HostExtend) { } namespace xgboost { -TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(GenericParameter::kCpuId); } +TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(Context::kCpuId); } } // namespace xgboost diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu index c02597eef1fd..434b63f64299 100644 --- a/tests/cpp/data/test_metainfo.cu +++ b/tests/cpp/data/test_metainfo.cu @@ -25,14 +25,13 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); - column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); + column["strides"] = Array(std::vector{Json(Integer{static_cast(sizeof(T))})}); column["version"] = 3; column["typestr"] = String(typestr); auto p_d_data = d_data.data().get(); - std::vector j_data { - Json(Integer(reinterpret_cast(p_d_data))), - Json(Boolean(false))}; + std::vector j_data{Json(Integer{reinterpret_cast(p_d_data)}), + Json(Boolean(false))}; column["data"] = j_data; column["stream"] = nullptr; Json array(std::vector{column}); @@ -45,12 +44,13 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons TEST(MetaInfo, FromInterface) { cudaSetDevice(0); + Context ctx; thrust::device_vector d_data; std::string str = PrepareData(" expected_group_ptr = {0, 4, 7, 9, 10}; EXPECT_EQ(info.group_ptr_, expected_group_ptr); } @@ -89,10 +89,11 @@ TEST(MetaInfo, GPUStridedData) { TEST(MetaInfo, Group) { cudaSetDevice(0); MetaInfo info; + Context ctx; thrust::device_vector d_uint; std::string uint_str = PrepareData(" d_int64; std::string int_str = PrepareData(" d_float; std::string float_str = PrepareData(" qid(info.num_row_, 0); for (size_t i = 0; i < qid.size(); ++i) { @@ -127,7 +129,7 @@ TEST(MetaInfo, GPUQid) { Json array{std::vector{column}}; std::string array_str; Json::Dump(array, &array_str); - info.SetInfo("qid", array_str.c_str()); + info.SetInfo(ctx, "qid", array_str.c_str()); ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1); ASSERT_EQ(info.group_ptr_.front(), 0); ASSERT_EQ(info.group_ptr_.back(), info.num_row_); @@ -142,11 +144,12 @@ TEST(MetaInfo, DeviceExtend) { dh::safe_cuda(cudaSetDevice(0)); size_t const kRows = 100; MetaInfo lhs, rhs; + Context ctx; thrust::device_vector d_data; std::string str = PrepareData("HostCanRead()); lhs.num_row_ = kRows; rhs.num_row_ = kRows; diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h index bb86e16eaefa..6e45b5062b2b 100644 --- a/tests/cpp/data/test_metainfo.h +++ b/tests/cpp/data/test_metainfo.h @@ -16,6 +16,8 @@ namespace xgboost { inline void TestMetaInfoStridedData(int32_t device) { MetaInfo info; + Context ctx; + ctx.UpdateAllowUnknown(Args{{"gpu_id", std::to_string(device)}}); { // labels linalg::Tensor labels; @@ -25,7 +27,7 @@ inline void TestMetaInfoStridedData(int32_t device) { auto t_labels = labels.View(device).Slice(linalg::All(), 0, linalg::All()); ASSERT_EQ(t_labels.Shape().size(), 2); - info.SetInfo("label", StringView{ArrayInterfaceStr(t_labels)}); + info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)}); auto const& h_result = info.labels.View(-1); ASSERT_EQ(h_result.Shape().size(), 2); auto in_labels = labels.View(-1); @@ -46,7 +48,7 @@ inline void TestMetaInfoStridedData(int32_t device) { std::iota(h_qid.begin(), h_qid.end(), 0); auto s = qid.View(device).Slice(linalg::All(), 0); auto str = ArrayInterfaceStr(s); - info.SetInfo("qid", StringView{str}); + info.SetInfo(ctx, "qid", StringView{str}); auto const& h_result = info.group_ptr_; ASSERT_EQ(h_result.size(), s.Size() + 1); } @@ -59,7 +61,7 @@ inline void TestMetaInfoStridedData(int32_t device) { auto t_margin = base_margin.View(device).Slice(linalg::All(), 0, linalg::All()); ASSERT_EQ(t_margin.Shape().size(), 2); - info.SetInfo("base_margin", StringView{ArrayInterfaceStr(t_margin)}); + info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)}); auto const& h_result = info.base_margin_.View(-1); ASSERT_EQ(h_result.Shape().size(), 2); auto in_margin = base_margin.View(-1); diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index c337312a1154..ed2f86c6c123 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -257,7 +257,7 @@ TEST(Dart, Prediction) { for (size_t i = 0; i < kRows; ++i) { labels[i] = i % 2; } - p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows); + p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows); auto learner = std::unique_ptr(Learner::Create({p_mat})); learner->SetParam("booster", "dart"); diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index eaba41b6aa48..987626df86bf 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -74,11 +74,9 @@ TEST(Learner, CheckGroup) { labels[i] = i % 2; } - p_mat->Info().SetInfo( - "weight", static_cast(weight.data()), DataType::kFloat32, kNumGroups); - p_mat->Info().SetInfo( - "group", group.data(), DataType::kUInt32, kNumGroups); - p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kNumRows); + p_mat->SetInfo("weight", static_cast(weight.data()), DataType::kFloat32, kNumGroups); + p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups); + p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kNumRows); std::vector> mat = {p_mat}; auto learner = std::unique_ptr(Learner::Create(mat)); @@ -88,7 +86,7 @@ TEST(Learner, CheckGroup) { group.resize(kNumGroups+1); group[3] = 4; group[4] = 1; - p_mat->Info().SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1); + p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1); EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat)); } @@ -105,7 +103,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT for (size_t i = 0; i < num_row; ++i) { labels[i] = i % 2; } - dmat->Info().SetInfo("label", labels.data(), DataType::kFloat32, num_row); + dmat->SetInfo("label", labels.data(), DataType::kFloat32, num_row); std::vector> mat{dmat}; auto learner = std::unique_ptr(Learner::Create(mat)); learner->SetParams(Args{{"objective", "binary:logistic"}});