Skip to content

Commit

Permalink
Use adapter directly.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 14, 2020
1 parent d700ecc commit 155be2b
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 18 deletions.
1 change: 1 addition & 0 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ def set_data_from_cuda_interface(self, data):
)

def set_data_from_cuda_columnar(self, data):
'''Set data from CUDA columnar format.'''
interfaces_str = _cudf_array_interfaces(data)
_check_call(
_LIB.XGDMatrixSetDataCudaColumnar(
Expand Down
24 changes: 12 additions & 12 deletions src/data/iterative_device_dmatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,28 @@ namespace data {

#define DISPATCH_MEM(__Proxy, __Fn) \
[](DMatrixProxy const* proxy) -> decltype( \
(dmlc::get<CupyAdapterBatch>(proxy->Value())).__Fn()) { \
if (proxy->Value().type() == typeid(CupyAdapterBatch)) { \
return (dmlc::get<CupyAdapterBatch>(proxy->Value())).__Fn(); \
} else if (proxy->Value().type() == typeid(CudfAdapterBatch)) { \
return (dmlc::get<CudfAdapterBatch>(proxy->Value())).__Fn(); \
(dmlc::get<CupyAdapter>(proxy->Adapter()).Value()).__Fn()) { \
if (proxy->Adapter().type() == typeid(CupyAdapter)) { \
return (dmlc::get<CupyAdapter>(proxy->Adapter()).Value()).__Fn(); \
} else if (proxy->Adapter().type() == typeid(CudfAdapter)) { \
return (dmlc::get<CudfAdapter>(proxy->Adapter()).Value()).__Fn(); \
} else { \
LOG(FATAL) << "Unknown type"; \
} \
return 0; \
}(__Proxy)

#define DISPATCH_FN(__Proxy, __Fn, ...) \
[&](DMatrixProxy const* proxy) { \
if (proxy->Value().type() == typeid(CupyAdapterBatch)) { \
return __Fn((dmlc::get<CupyAdapterBatch>(proxy->Value())), \
if (proxy->Adapter().type() == typeid(CupyAdapter)) { \
return __Fn((dmlc::get<CupyAdapter>(proxy->Adapter()).Value()), \
__VA_ARGS__); \
} else if (proxy->Value().type() == typeid(CudfAdapterBatch)) { \
return __Fn((dmlc::get<CudfAdapterBatch>(proxy->Value())), \
} else if (proxy->Adapter().type() == typeid(CudfAdapter)) { \
return __Fn((dmlc::get<CudfAdapter>(proxy->Adapter()).Value()), \
__VA_ARGS__); \
} else { \
LOG(FATAL) << "Unknown type"; \
return __Fn((dmlc::get<CudfAdapterBatch>(proxy->Value())), \
return __Fn((dmlc::get<CudfAdapter>(proxy->Adapter()).Value()), \
__VA_ARGS__); \
} \
}(__Proxy)
Expand Down Expand Up @@ -63,7 +65,6 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
common::SketchContainer sketch_container(p_.max_bin, cols, rows);
auto device = 0;
while (iter.Next()) {
auto batch = proxy->Value();
DISPATCH_FN(proxy, common::AdapterDeviceSketch,
p_.max_bin, missing, rows, cols, device,
&sketch_container);
Expand All @@ -74,7 +75,6 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
pages_.reset(new std::vector<EllpackPage>);
iter.Reset();
while (iter.Next()) {
auto batch = proxy->Value();
auto rows = DISPATCH_MEM(proxy, NumRows);
dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(),
Expand Down
21 changes: 21 additions & 0 deletions src/data/proxy_dmatrix.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*!
* Copyright 2020 XGBoost contributors
*/
#include "proxy_dmatrix.h"
#include "device_adapter.cuh"

namespace xgboost {
namespace data {

/*!
 * \brief Build a CudfAdapter from `interface_str` and store it as this
 *        proxy's current data.
 *
 * \param interface_str Text handed unchanged to the CudfAdapter constructor.
 */
void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
  // Assign from a temporary instead of a named local: the adapter can then be
  // moved into `batch_` rather than being copied a second time.  The type held
  // by `batch_` is still data::CudfAdapter, so type checks on it are unchanged.
  this->batch_ = data::CudfAdapter(interface_str);
}

/*!
 * \brief Build a CupyAdapter from `interface_str` and store it as this
 *        proxy's current data.
 *
 * \param interface_str Text handed unchanged to the CupyAdapter constructor.
 */
void DMatrixProxy::FromCudaArray(std::string interface_str) {
  // Assign from a temporary instead of a named local: the adapter can then be
  // moved into `batch_` rather than being copied a second time.  The type held
  // by `batch_` is still data::CupyAdapter, so type checks on it are unchanged.
  this->batch_ = data::CupyAdapter(interface_str);
}

} // namespace data
} // namespace xgboost
16 changes: 12 additions & 4 deletions src/data/proxy_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,24 @@ class DMatrixProxy : public DMatrix {
void SetInfo(const char* key, std::string const& interface_str) override {
this->Info().SetInfo(key, interface_str);
}

#if defined(XGBOOST_USE_CUDA)
void FromCudaColumnar(std::string interface_str);
void FromCudaArray(std::string interface_str);
#endif // defined(XGBOOST_USE_CUDA)

void SetData(char const* c_interface) {
common::AssertGPUSupport();
#if defined(XGBOOST_USE_CUDA)
std::string interface_str = c_interface;
Json json_array_interface =
Json::Load({interface_str.c_str(), interface_str.size()});
if (IsA<Array>(json_array_interface)) {
LOG(FATAL) << "Not implemented";
this->FromCudaColumnar(interface_str);
} else {
ArrayInterface interface{get<Object const>(json_array_interface)};
this->batch_ = CupyAdapterBatch{interface, &this->info_};
this->FromCudaArray(interface_str);
}
#endif // defined(XGBOOST_USE_CUDA)
}

MetaInfo& Info() override { return info_; }
Expand Down Expand Up @@ -91,7 +99,7 @@ class DMatrixProxy : public DMatrix {
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(nullptr));
}

dmlc::any Value() const {
dmlc::any Adapter() const {
return batch_;
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
if (batch.BaseMargin() != nullptr) {
auto& base_margin = info_.base_margin_.HostVector();
base_margin.insert(base_margin.end(), batch.BaseMargin(),
batch.BaseMargin() + batch.Size());
batch.BaseMargin() + batch.Size());
}
if (batch.Qid() != nullptr) {
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
Expand Down
2 changes: 1 addition & 1 deletion tests/python-gpu/test_from_cudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def as_array(self):

def as_array_labels(self):
import cupy
return cupy.concat(self._labels)
return cupy.concatenate(self._labels)

def data(self):
'''Utility function for obtaining current batch of data.'''
Expand Down

0 comments on commit 155be2b

Please sign in to comment.