Use context in SetInfo. #7687

Merged: 1 commit, Mar 24, 2022
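
This PR threads an explicit Context (an alias for GenericParameter introduced below) through MetaInfo::SetInfo and its helpers, so meta information is ingested using the DMatrix's configured thread count and device instead of hard-coded defaults. A minimal sketch of the updated call pattern, using only signatures visible in this diff; the helper function, headers, and values are hypothetical:

#include <xgboost/data.h>
#include <xgboost/generic_parameters.h>

#include <vector>

// Hypothetical helper: in real code the context would come from DMatrix::Ctx().
void SetLabelsSketch(xgboost::MetaInfo* info) {
  xgboost::Context ctx;  // alias for GenericParameter; carries nthread and gpu_id
  std::vector<float> labels{0.f, 1.f, 1.f};
  info->SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, labels.size());
}
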
19 changes: 10 additions & 9 deletions include/xgboost/data.h
@@ -148,13 +148,13 @@ class MetaInfo {
* \param dtype The type of the source data.
* \param num Number of elements in the source array.
*/
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
/*!
* \brief Set information in the meta info with array interface.
* \param key The key of the information.
* \param interface_str String representation of json format array interface.
*/
void SetInfo(StringView key, StringView interface_str);
void SetInfo(Context const& ctx, StringView key, StringView interface_str);

void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
const void** out_dptr) const;
@@ -176,8 +176,8 @@ class MetaInfo {
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);

private:
void SetInfoFromHost(StringView key, Json arr);
void SetInfoFromCUDA(StringView key, Json arr);
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);

/*! \brief argsort of labels */
mutable std::vector<size_t> label_order_cache_;
@@ -478,12 +478,13 @@ class DMatrix {
DMatrix() = default;
/*! \brief meta information of the dataset */
virtual MetaInfo& Info() = 0;
virtual void SetInfo(const char *key, const void *dptr, DataType dtype,
size_t num) {
this->Info().SetInfo(key, dptr, dtype, num);
virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, dptr, dtype, num);
}
virtual void SetInfo(const char* key, std::string const& interface_str) {
this->Info().SetInfo(key, StringView{interface_str});
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, StringView{interface_str});
}
/*! \brief meta information of the dataset */
virtual const MetaInfo& Info() const = 0;
Expand All @@ -494,7 +495,7 @@ class DMatrix {
* \brief Get the context object of this DMatrix. The context is created during construction of
* DMatrix with user specified `nthread` parameter.
*/
virtual GenericParameter const* Ctx() const = 0;
virtual Context const* Ctx() const = 0;

/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
2 changes: 2 additions & 0 deletions include/xgboost/generic_parameters.h
@@ -75,6 +75,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
.describe("Enable checking whether parameters are used or not.");
}
};

using Context = GenericParameter;
} // namespace xgboost

#endif // XGBOOST_GENERIC_PARAMETERS_H_
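
The Context alias is the functional addition in this header. A short hedged sketch of how it can be used in place of GenericParameter; the standalone construction and the nthread value are illustrative, not part of this PR:

#include <xgboost/base.h>                // for xgboost::Args
#include <xgboost/generic_parameters.h>

void ContextAliasSketch() {
  xgboost::Context ctx;                                     // same type as GenericParameter
  ctx.UpdateAllowUnknown(xgboost::Args{{"nthread", "8"}});  // hypothetical configuration
  auto n_threads = ctx.Threads();  // what SetInfoFromHost in data.cc now consumes
  bool on_cpu = (ctx.gpu_id == xgboost::Context::kCpuId);
  (void)n_threads;
  (void)on_cpu;
}
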
42 changes: 17 additions & 25 deletions src/c_api/c_api.cc
@@ -485,35 +485,30 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname,
API_END();
}

XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char* field,
const bst_float* info,
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const bst_float *info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len);
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
API_END();
}

XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
char const* field,
char const* interface_c_str) {
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, char const *field,
char const *interface_c_str) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, interface_c_str);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, interface_c_str);
API_END();
}

XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char* field,
const unsigned* info,
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
API_END();
}

@@ -549,25 +544,22 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
API_END();
}

XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void const *data, xgboost::bst_ulong size,
int type) {
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
xgboost::bst_ulong size, int type) {
API_BEGIN();
CHECK_HANDLE();
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
CHECK(type >= 1 && type <= 4);
info.SetInfo(field, data, static_cast<DataType>(type), size);
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
API_END();
}

XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group,
xgboost::bst_ulong len) {
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
API_END();
}

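
These C API entry points no longer reach into MetaInfo directly; they go through DMatrix::SetInfo, which supplies the matrix's own context (see the data.h hunk above). A hedged sketch of the resulting call chain from a C API user's point of view; the handle, row count, and weight values are hypothetical:

#include <xgboost/c_api.h>

#include <cstddef>
#include <vector>

// Call chain after this PR (names from the diff):
//   XGDMatrixSetFloatInfo -> DMatrix::SetInfo -> MetaInfo::SetInfo(ctx, ...) with ctx = *Ctx().
void SetWeightsViaCApiSketch(DMatrixHandle handle, std::size_t n_rows) {
  std::vector<float> weights(n_rows, 1.0f);
  XGDMatrixSetFloatInfo(handle, "weight", weights.data(),
                        static_cast<bst_ulong>(weights.size()));
}
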
51 changes: 26 additions & 25 deletions src/data/data.cc
@@ -409,7 +409,7 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,

namespace {
template <int32_t D, typename T>
void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
ArrayInterface<D> array{arr_interface};
if (array.n == 0) {
p_out->Reshape(array.shape);
@@ -428,16 +428,15 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
return;
}
p_out->Reshape(array.shape);
auto t = p_out->View(GenericParameter::kCpuId);
auto t = p_out->View(Context::kCpuId);
CHECK(t.CContiguous());
// FIXME(jiamingy): Remove the use of this default thread.
linalg::ElementWiseTransformHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) {
linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
});
}
} // namespace

void MetaInfo::SetInfo(StringView key, StringView interface_str) {
void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_str) {
Json j_interface = Json::Load(interface_str);
bool is_cuda{false};
if (IsA<Array>(j_interface)) {
@@ -454,16 +453,16 @@ void MetaInfo::SetInfo(StringView key, StringView interface_str) {
}

if (is_cuda) {
this->SetInfoFromCUDA(key, j_interface);
this->SetInfoFromCUDA(ctx, key, j_interface);
} else {
this->SetInfoFromHost(key, j_interface);
this->SetInfoFromHost(ctx, key, j_interface);
}
}

void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
// multi-dim float info
if (key == "base_margin") {
CopyTensorInfoImpl(arr, &this->base_margin_);
CopyTensorInfoImpl(ctx, arr, &this->base_margin_);
// FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of
// input shape. This issue is CPU only since CUDA uses array interface from day 1.
//
@@ -477,7 +476,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
}
return;
} else if (key == "label") {
CopyTensorInfoImpl(arr, &this->labels);
CopyTensorInfoImpl(ctx, arr, &this->labels);
if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) {
CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels.";
size_t n_targets = this->labels.Size() / this->num_row_;
@@ -491,7 +490,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
// uint info
if (key == "group") {
linalg::Tensor<bst_group_t, 1> t;
CopyTensorInfoImpl(arr, &t);
CopyTensorInfoImpl(ctx, arr, &t);
auto const& h_groups = t.Data()->HostVector();
group_ptr_.clear();
group_ptr_.resize(h_groups.size() + 1, 0);
@@ -501,7 +500,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
return;
} else if (key == "qid") {
linalg::Tensor<bst_group_t, 1> t;
CopyTensorInfoImpl(arr, &t);
CopyTensorInfoImpl(ctx, arr, &t);
bool non_dec = true;
auto const& query_ids = t.Data()->HostVector();
for (size_t i = 1; i < query_ids.size(); ++i) {
@@ -526,7 +525,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
}
// float info
linalg::Tensor<float, 1> t;
CopyTensorInfoImpl<1>(arr, &t);
CopyTensorInfoImpl<1>(ctx, arr, &t);
if (key == "weight") {
this->weights_ = std::move(*t.Data());
auto const& h_weights = this->weights_.ConstHostVector();
@@ -548,36 +547,38 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
}
}

void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
size_t num) {
auto proc = [&](auto cast_d_ptr) {
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t =
linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, GenericParameter::kCpuId);
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
CHECK(t.CContiguous());
Json interface { linalg::ArrayInterface(t) };
Json interface {
linalg::ArrayInterface(t)
};
assert(ArrayInterface<1>{interface}.is_contiguous);
return interface;
};
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
switch (dtype) {
case xgboost::DataType::kFloat32: {
auto cast_ptr = reinterpret_cast<const float*>(dptr);
this->SetInfoFromHost(key, proc(cast_ptr));
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kDouble: {
auto cast_ptr = reinterpret_cast<const double*>(dptr);
this->SetInfoFromHost(key, proc(cast_ptr));
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt32: {
auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
this->SetInfoFromHost(key, proc(cast_ptr));
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt64: {
auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
this->SetInfoFromHost(key, proc(cast_ptr));
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
default:
@@ -724,9 +725,7 @@ void MetaInfo::Validate(int32_t device) const {
"doesn't equal to actual number of rows given by data.";
}
auto check_device = [device](HostDeviceVector<float> const& v) {
CHECK(v.DeviceIdx() == GenericParameter::kCpuId ||
device == GenericParameter::kCpuId ||
v.DeviceIdx() == device)
CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
<< "Data is resided on a different device than `gpu_id`. "
<< "Device that data is on: " << v.DeviceIdx() << ", "
<< "`gpu_id` for XGBoost: " << device;
@@ -769,7 +768,9 @@ void MetaInfo::Validate(int32_t device) const {
}

#if !defined(XGBOOST_USE_CUDA)
void MetaInfo::SetInfoFromCUDA(StringView key, Json arr) { common::AssertGPUSupport(); }
void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json arr) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)

using DMatrixThreadLocal =
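
Besides the pointer-based overload above, the array-interface overload SetInfo(ctx, key, interface_str) now also takes the context and forwards it to SetInfoFromHost or SetInfoFromCUDA. A hedged sketch of feeding it a host array: the JSON follows the NumPy-style __array_interface__ layout the parser expects, but the exact string construction, field values, and StringView conversions here are illustrative:

#include <xgboost/data.h>
#include <xgboost/generic_parameters.h>

#include <cstdint>
#include <string>
#include <vector>

// Hypothetical host-side use of the array-interface overload; the data is copied
// inside SetInfo, so the vector only has to stay valid for the duration of the call.
void SetLabelFromArrayInterfaceSketch(xgboost::MetaInfo* info, xgboost::Context const& ctx) {
  std::vector<float> y{0.f, 1.f, 1.f};
  std::string arr = std::string("{\"data\": [") +
                    std::to_string(reinterpret_cast<std::uintptr_t>(y.data())) +
                    ", false], \"shape\": [3], \"typestr\": \"<f4\", \"version\": 3}";
  info->SetInfo(ctx, xgboost::StringView{"label"}, xgboost::StringView{arr});
}
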
3 changes: 2 additions & 1 deletion src/data/data.cu
@@ -115,7 +115,8 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
}
} // namespace

void MetaInfo::SetInfoFromCUDA(StringView key, Json array) {
// Context is not used until we have CUDA stream.
void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
// multi-dim float info
if (key == "base_margin") {
CopyTensorInfoImpl(array, &base_margin_);
10 changes: 5 additions & 5 deletions src/data/iterative_device_dmatrix.cu
@@ -43,18 +43,18 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t batches = 0;
size_t accumulated_rows = 0;
bst_feature_t cols = 0;
int32_t device = GenericParameter::kCpuId;

int32_t current_device;
dh::safe_cuda(cudaGetDevice(&current_device));
auto get_device = [&]() -> int32_t {
int32_t d = (device == GenericParameter::kCpuId) ? current_device : device;
CHECK_NE(d, GenericParameter::kCpuId);
int32_t d = (ctx_.gpu_id == Context::kCpuId) ? current_device : ctx_.gpu_id;
CHECK_NE(d, Context::kCpuId);
return d;
};

while (iter.Next()) {
device = proxy->DeviceIdx();
CHECK_LT(device, common::AllVisibleGPUs());
ctx_.gpu_id = proxy->DeviceIdx();
CHECK_LT(ctx_.gpu_id, common::AllVisibleGPUs());
dh::safe_cuda(cudaSetDevice(get_device()));
if (cols == 0) {
cols = num_cols();
6 changes: 2 additions & 4 deletions src/data/iterative_device_dmatrix.h
@@ -21,6 +21,7 @@ namespace data {

class IterativeDeviceDMatrix : public DMatrix {
MetaInfo info_;
Context ctx_;
BatchParam batch_param_;
std::shared_ptr<EllpackPage> page_;

@@ -72,10 +73,7 @@ class IterativeDeviceDMatrix : public DMatrix {
MetaInfo &Info() override { return info_; }
MetaInfo const &Info() const override { return info_; }

GenericParameter const *Ctx() const override {
LOG(FATAL) << "`IterativeDMatrix` doesn't have context.";
return nullptr;
}
Context const *Ctx() const override { return &ctx_; }
};

#if !defined(XGBOOST_USE_CUDA)
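
Because the base-class DMatrix::SetInfo in data.h now dereferences this->Ctx(), IterativeDeviceDMatrix can no longer keep its old fatal Ctx(); it stores a ctx_ whose gpu_id is filled in from the iterator's proxy during Initialize (previous hunk). A brief hedged illustration; the helper and its arguments are hypothetical:

#include <xgboost/data.h>

#include <cstddef>

// DMatrix::SetInfo dereferences this->Ctx(); IterativeDeviceDMatrix now returns its
// stored ctx_ there, so the call below also works for iterator-backed matrices.
void SetWeightsOnDMatrixSketch(xgboost::DMatrix* m, float const* w, std::size_t n) {
  m->SetInfo("weight", w, xgboost::DataType::kFloat32, n);
}
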
11 changes: 5 additions & 6 deletions src/data/proxy_dmatrix.cu
@@ -1,5 +1,5 @@
/*!
* Copyright 2020 XGBoost contributors
* Copyright 2020-2022, XGBoost contributors
*/
#include "proxy_dmatrix.h"
#include "device_adapter.cuh"
@@ -11,24 +11,23 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
std::shared_ptr<data::CudfAdapter> adapter {new data::CudfAdapter(interface_str)};
auto const& value = adapter->Value();
this->batch_ = adapter;
device_ = adapter->DeviceIdx();
ctx_.gpu_id = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
if (device_ < 0) {
if (ctx_.gpu_id < 0) {
CHECK_EQ(this->Info().num_row_, 0);
}
}

void DMatrixProxy::FromCudaArray(std::string interface_str) {
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
this->batch_ = adapter;
device_ = adapter->DeviceIdx();
ctx_.gpu_id = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
if (device_ < 0) {
if (ctx_.gpu_id < 0) {
CHECK_EQ(this->Info().num_row_, 0);
}
}

} // namespace data
} // namespace xgboost