Use context in SetInfo. (#7687)
* Use the name `Context`.
* Pass a context object into `SetInfo`.
* Add context to proxy matrix.
* Add context to iterative DMatrix.

This removes the use of the default number of threads during `SetInfo`, as a follow-up to removing the global omp variable, and prepares for CUDA stream semantics. Currently, XGBoost uses the legacy default CUDA stream; we will gradually replace it with non-blocking streams in the future.
trivialfis authored Mar 24, 2022
1 parent f5b2028 commit 6457559
Showing 19 changed files with 142 additions and 142 deletions.
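
Before the per-file diffs, a minimal, hypothetical call-site sketch of the API change; the function wrapper and sample values are assumptions for illustration, not code from this commit:

#include <xgboost/data.h>                // MetaInfo, DataType
#include <xgboost/generic_parameters.h>  // Context (alias of GenericParameter)
#include <vector>

void SetLabelSketch() {
  xgboost::MetaInfo info;
  xgboost::Context ctx;  // carries the per-DMatrix nthread setting
  std::vector<float> labels{0.0f, 1.0f, 1.0f};
  // Before this commit:
  //   info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, labels.size());
  // After this commit, the context is passed explicitly:
  info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, labels.size());
}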
19 changes: 10 additions & 9 deletions include/xgboost/data.h
@@ -148,13 +148,13 @@ class MetaInfo {
* \param dtype The type of the source data.
* \param num Number of elements in the source array.
*/
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
/*!
* \brief Set information in the meta info with array interface.
* \param key The key of the information.
* \param interface_str String representation of json format array interface.
*/
void SetInfo(StringView key, StringView interface_str);
void SetInfo(Context const& ctx, StringView key, StringView interface_str);

void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
const void** out_dptr) const;
@@ -176,8 +176,8 @@ class MetaInfo {
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);

private:
void SetInfoFromHost(StringView key, Json arr);
void SetInfoFromCUDA(StringView key, Json arr);
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);

/*! \brief argsort of labels */
mutable std::vector<size_t> label_order_cache_;
@@ -478,12 +478,13 @@ class DMatrix {
DMatrix() = default;
/*! \brief meta information of the dataset */
virtual MetaInfo& Info() = 0;
virtual void SetInfo(const char *key, const void *dptr, DataType dtype,
size_t num) {
this->Info().SetInfo(key, dptr, dtype, num);
virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, dptr, dtype, num);
}
virtual void SetInfo(const char* key, std::string const& interface_str) {
this->Info().SetInfo(key, StringView{interface_str});
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, StringView{interface_str});
}
/*! \brief meta information of the dataset */
virtual const MetaInfo& Info() const = 0;
@@ -494,7 +495,7 @@
* \brief Get the context object of this DMatrix. The context is created during construction of
* DMatrix with user specified `nthread` parameter.
*/
virtual GenericParameter const* Ctx() const = 0;
virtual Context const* Ctx() const = 0;

/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
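
As a usage note for the array-interface overload declared above, a hedged sketch: the JSON follows the numpy __array_interface__ protocol and its field values here are illustrative, not from this commit.

#include <xgboost/data.h>
#include <xgboost/generic_parameters.h>
#include <cstdint>
#include <string>
#include <vector>

void SetWeightFromArrayInterfaceSketch(xgboost::MetaInfo* info, xgboost::Context const& ctx) {
  std::vector<float> weights{1.0f, 0.5f, 2.0f};
  // "data" is [pointer, readonly]; "typestr" "<f4" means little-endian float32.
  std::string array_interface =
      R"({"data": [)" + std::to_string(reinterpret_cast<std::uintptr_t>(weights.data())) +
      R"(, false], "shape": [3], "typestr": "<f4", "version": 3})";
  info->SetInfo(ctx, "weight", xgboost::StringView{array_interface});
}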
2 changes: 2 additions & 0 deletions include/xgboost/generic_parameters.h
@@ -75,6 +75,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
.describe("Enable checking whether parameters are used or not.");
}
};

using Context = GenericParameter;
} // namespace xgboost

#endif // XGBOOST_GENERIC_PARAMETERS_H_
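
A small usage sketch of the new alias, assuming the dmlc parameter API for configuration:

#include <xgboost/generic_parameters.h>

void ConfigureContextSketch() {
  xgboost::Context ctx;  // same type as xgboost::GenericParameter
  ctx.UpdateAllowUnknown(xgboost::Args{{"nthread", "8"}});
  // ctx.Threads() is what SetInfo now consults instead of the global OpenMP default.
}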
42 changes: 17 additions & 25 deletions src/c_api/c_api.cc
@@ -485,35 +485,30 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname,
API_END();
}

XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char* field,
const bst_float* info,
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const bst_float *info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len);
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
API_END();
}

XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
char const* field,
char const* interface_c_str) {
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, char const *field,
char const *interface_c_str) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, interface_c_str);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, interface_c_str);
API_END();
}

XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char* field,
const unsigned* info,
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
API_END();
}

@@ -549,25 +544,22 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
API_END();
}

XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void const *data, xgboost::bst_ulong size,
int type) {
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
xgboost::bst_ulong size, int type) {
API_BEGIN();
CHECK_HANDLE();
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
CHECK(type >= 1 && type <= 4);
info.SetInfo(field, data, static_cast<DataType>(type), size);
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
API_END();
}

XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group,
xgboost::bst_ulong len) {
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
API_END();
}
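
For reference, a hedged sketch of driving these entry points from the C API; error-code checks are omitted and `handle` is assumed to be a valid DMatrixHandle created elsewhere:

#include <xgboost/c_api.h>
#include <vector>

void SetLabelsViaCApiSketch(DMatrixHandle handle) {
  std::vector<float> labels{0.0f, 1.0f, 1.0f, 0.0f};
  XGDMatrixSetFloatInfo(handle, "label", labels.data(),
                        static_cast<bst_ulong>(labels.size()));
  // Generic dense entry point; type 1 maps to DataType::kFloat32 in the range check above.
  XGDMatrixSetDenseInfo(handle, "label", labels.data(),
                        static_cast<bst_ulong>(labels.size()), /*type=*/1);
}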

51 changes: 26 additions & 25 deletions src/data/data.cc
@@ -409,7 +409,7 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,

namespace {
template <int32_t D, typename T>
void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
ArrayInterface<D> array{arr_interface};
if (array.n == 0) {
p_out->Reshape(array.shape);
@@ -428,16 +428,15 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
return;
}
p_out->Reshape(array.shape);
auto t = p_out->View(GenericParameter::kCpuId);
auto t = p_out->View(Context::kCpuId);
CHECK(t.CContiguous());
// FIXME(jiamingy): Remove the use of this default thread.
linalg::ElementWiseTransformHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) {
linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
});
}
} // namespace

void MetaInfo::SetInfo(StringView key, StringView interface_str) {
void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_str) {
Json j_interface = Json::Load(interface_str);
bool is_cuda{false};
if (IsA<Array>(j_interface)) {
@@ -454,16 +453,16 @@
}

if (is_cuda) {
this->SetInfoFromCUDA(key, j_interface);
this->SetInfoFromCUDA(ctx, key, j_interface);
} else {
this->SetInfoFromHost(key, j_interface);
this->SetInfoFromHost(ctx, key, j_interface);
}
}

void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
// multi-dim float info
if (key == "base_margin") {
CopyTensorInfoImpl(arr, &this->base_margin_);
CopyTensorInfoImpl(ctx, arr, &this->base_margin_);
// FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of
// input shape. This issue is CPU only since CUDA uses array interface from day 1.
//
@@ -477,7 +476,7 @@
}
return;
} else if (key == "label") {
CopyTensorInfoImpl(arr, &this->labels);
CopyTensorInfoImpl(ctx, arr, &this->labels);
if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) {
CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels.";
size_t n_targets = this->labels.Size() / this->num_row_;
@@ -491,7 +490,7 @@
// uint info
if (key == "group") {
linalg::Tensor<bst_group_t, 1> t;
CopyTensorInfoImpl(arr, &t);
CopyTensorInfoImpl(ctx, arr, &t);
auto const& h_groups = t.Data()->HostVector();
group_ptr_.clear();
group_ptr_.resize(h_groups.size() + 1, 0);
@@ -501,7 +500,7 @@
return;
} else if (key == "qid") {
linalg::Tensor<bst_group_t, 1> t;
CopyTensorInfoImpl(arr, &t);
CopyTensorInfoImpl(ctx, arr, &t);
bool non_dec = true;
auto const& query_ids = t.Data()->HostVector();
for (size_t i = 1; i < query_ids.size(); ++i) {
@@ -526,7 +525,7 @@
}
// float info
linalg::Tensor<float, 1> t;
CopyTensorInfoImpl<1>(arr, &t);
CopyTensorInfoImpl<1>(ctx, arr, &t);
if (key == "weight") {
this->weights_ = std::move(*t.Data());
auto const& h_weights = this->weights_.ConstHostVector();
@@ -548,36 +547,38 @@
}
}

void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
size_t num) {
auto proc = [&](auto cast_d_ptr) {
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t =
linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, GenericParameter::kCpuId);
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
CHECK(t.CContiguous());
Json interface { linalg::ArrayInterface(t) };
Json interface {
linalg::ArrayInterface(t)
};
assert(ArrayInterface<1>{interface}.is_contiguous);
return interface;
};
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
switch (dtype) {
case xgboost::DataType::kFloat32: {
auto cast_ptr = reinterpret_cast<const float*>(dptr);
this->SetInfoFromHost(key, proc(cast_ptr));
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kDouble: {
auto cast_ptr = reinterpret_cast<const double*>(dptr);
this->SetInfoFromHost(key, proc(cast_ptr));
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt32: {
auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
this->SetInfoFromHost(key, proc(cast_ptr));
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt64: {
auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
this->SetInfoFromHost(key, proc(cast_ptr));
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
default:
@@ -724,9 +725,7 @@ void MetaInfo::Validate(int32_t device) const {
"doesn't equal to actual number of rows given by data.";
}
auto check_device = [device](HostDeviceVector<float> const& v) {
CHECK(v.DeviceIdx() == GenericParameter::kCpuId ||
device == GenericParameter::kCpuId ||
v.DeviceIdx() == device)
CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
<< "Data is resided on a different device than `gpu_id`. "
<< "Device that data is on: " << v.DeviceIdx() << ", "
<< "`gpu_id` for XGBoost: " << device;
@@ -769,7 +768,9 @@
}

#if !defined(XGBOOST_USE_CUDA)
void MetaInfo::SetInfoFromCUDA(StringView key, Json arr) { common::AssertGPUSupport(); }
void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json arr) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)

using DMatrixThreadLocal =
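
The CopyTensorInfoImpl hunk above replaces common::OmpGetNumThreads(0) with ctx.Threads(). A minimal sketch of that pattern, not XGBoost's actual ElementWiseTransformHost, just to show the thread count arriving as an explicit argument:

#include <cstdint>
#include <vector>

template <typename T, typename Fn>
void ElementWiseTransformSketch(std::vector<T>* data, std::int32_t n_threads, Fn&& fn) {
  // n_threads comes from the per-DMatrix Context rather than a global default.
#pragma omp parallel for num_threads(n_threads)
  for (std::int64_t i = 0; i < static_cast<std::int64_t>(data->size()); ++i) {
    (*data)[i] = fn(i, (*data)[i]);
  }
}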
3 changes: 2 additions & 1 deletion src/data/data.cu
@@ -115,7 +115,8 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
}
} // namespace

void MetaInfo::SetInfoFromCUDA(StringView key, Json array) {
// Context is not used until we have CUDA stream.
void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
// multi-dim float info
if (key == "base_margin") {
CopyTensorInfoImpl(array, &base_margin_);
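
SetInfoFromCUDA now accepts a context but does not use it yet; per the commit message, the longer-term plan is to move from the legacy default stream to non-blocking streams. A minimal sketch of creating such a stream, an assumption about the future direction rather than code in this commit:

#include <cuda_runtime.h>

cudaStream_t MakeNonBlockingStreamSketch() {
  cudaStream_t stream;
  // Non-blocking streams do not implicitly synchronize with the legacy default stream.
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
  return stream;
}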
10 changes: 5 additions & 5 deletions src/data/iterative_device_dmatrix.cu
@@ -43,18 +43,18 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t batches = 0;
size_t accumulated_rows = 0;
bst_feature_t cols = 0;
int32_t device = GenericParameter::kCpuId;

int32_t current_device;
dh::safe_cuda(cudaGetDevice(&current_device));
auto get_device = [&]() -> int32_t {
int32_t d = (device == GenericParameter::kCpuId) ? current_device : device;
CHECK_NE(d, GenericParameter::kCpuId);
int32_t d = (ctx_.gpu_id == Context::kCpuId) ? current_device : ctx_.gpu_id;
CHECK_NE(d, Context::kCpuId);
return d;
};

while (iter.Next()) {
device = proxy->DeviceIdx();
CHECK_LT(device, common::AllVisibleGPUs());
ctx_.gpu_id = proxy->DeviceIdx();
CHECK_LT(ctx_.gpu_id, common::AllVisibleGPUs());
dh::safe_cuda(cudaSetDevice(get_device()));
if (cols == 0) {
cols = num_cols();
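
The get_device lambda above resolves the device from ctx_.gpu_id, falling back to the currently active CUDA device when the context is still on the CPU (kCpuId is -1). An illustrative standalone restatement, not the in-tree code:

#include <cuda_runtime.h>

int ResolveDeviceSketch(int ctx_gpu_id) {
  if (ctx_gpu_id >= 0) {
    return ctx_gpu_id;  // the context is already pinned to a GPU
  }
  int current_device = 0;
  cudaGetDevice(&current_device);  // otherwise use the active CUDA device
  return current_device;
}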
6 changes: 2 additions & 4 deletions src/data/iterative_device_dmatrix.h
@@ -21,6 +21,7 @@ namespace data {

class IterativeDeviceDMatrix : public DMatrix {
MetaInfo info_;
Context ctx_;
BatchParam batch_param_;
std::shared_ptr<EllpackPage> page_;

@@ -72,10 +73,7 @@ class IterativeDeviceDMatrix : public DMatrix {
MetaInfo &Info() override { return info_; }
MetaInfo const &Info() const override { return info_; }

GenericParameter const *Ctx() const override {
LOG(FATAL) << "`IterativeDMatrix` doesn't have context.";
return nullptr;
}
Context const *Ctx() const override { return &ctx_; }
};

#if !defined(XGBOOST_USE_CUDA)
11 changes: 5 additions & 6 deletions src/data/proxy_dmatrix.cu
@@ -1,5 +1,5 @@
/*!
* Copyright 2020 XGBoost contributors
* Copyright 2020-2022, XGBoost contributors
*/
#include "proxy_dmatrix.h"
#include "device_adapter.cuh"
@@ -11,24 +11,23 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
std::shared_ptr<data::CudfAdapter> adapter {new data::CudfAdapter(interface_str)};
auto const& value = adapter->Value();
this->batch_ = adapter;
device_ = adapter->DeviceIdx();
ctx_.gpu_id = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
if (device_ < 0) {
if (ctx_.gpu_id < 0) {
CHECK_EQ(this->Info().num_row_, 0);
}
}

void DMatrixProxy::FromCudaArray(std::string interface_str) {
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
this->batch_ = adapter;
device_ = adapter->DeviceIdx();
ctx_.gpu_id = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
if (device_ < 0) {
if (ctx_.gpu_id < 0) {
CHECK_EQ(this->Info().num_row_, 0);
}
}

} // namespace data
} // namespace xgboost