Empty dataset handling.
trivialfis committed Jul 20, 2021
1 parent 8629f51 commit d8059f4
Showing 7 changed files with 36 additions and 11 deletions.
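
The net effect of the commit: constructing a DMatrix from a zero-row GPU input no longer fails on an invalid CUDA device ordinal. A minimal repro sketch, not part of the commit itself, assuming a CUDA-enabled xgboost build with cupy available:

    import cupy as cp
    import xgboost as xgb

    # Zero rows: no device memory backs the array, so xgboost cannot infer
    # a device id from the data pointer. With this change it falls back
    # gracefully instead of calling cudaSetDevice(-1).
    X = cp.empty((0, 4), dtype=cp.float32)
    y = cp.empty((0,), dtype=cp.float32)

    dtrain = xgb.DMatrix(X, label=y)
    assert dtrain.num_row() == 0
    assert dtrain.num_col() == 4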
14 changes: 12 additions & 2 deletions src/data/data.cu
@@ -19,11 +19,16 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
     cudaPointerAttributes attr;
     dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
     int32_t ptr_device = attr.device;
-    dh::safe_cuda(cudaSetDevice(ptr_device));
+    if (ptr_device >= 0) {
+      dh::safe_cuda(cudaSetDevice(ptr_device));
+    }
     return ptr_device;
   };
   auto ptr_device = SetDeviceToPtr(column.data);
 
+  if (column.num_rows == 0) {
+    return;
+  }
   out->SetDevice(ptr_device);
   out->Resize(column.num_rows);
@@ -123,7 +128,12 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
       << "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
   ArrayInterface array_interface(interface_str);
   std::string key{c_key};
-  array_interface.AsColumnVector();
+  if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
+        (array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
+    // Not an empty column, transform it.
+    array_interface.AsColumnVector();
+  }
 
   CHECK(!array_interface.valid.Data())
       << "Meta info " << key << " should be dense, found validity mask";
   if (array_interface.num_rows == 0) {
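
Why ptr_device can be negative above: for a zero-size array, producers such as cupy typically expose a null data pointer through __cuda_array_interface__, and cudaPointerGetAttributes cannot attribute a null pointer to any device, so attr.device stays negative. A short illustration of that assumption (cupy behavior, not code from this commit):

    import cupy as cp

    X = cp.empty((0, 4), dtype=cp.float32)
    ptr, _ = X.__cuda_array_interface__["data"]
    print(ptr)  # expected 0: nothing is allocated for a zero-size array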
6 changes: 3 additions & 3 deletions src/data/device_adapter.cuh
@@ -154,7 +154,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
 
   size_t NumRows() const { return num_rows_; }
   size_t NumColumns() const { return columns_.size(); }
-  size_t DeviceIdx() const { return device_idx_; }
+  int32_t DeviceIdx() const { return device_idx_; }
 
  private:
   CudfAdapterBatch batch_;
@@ -202,12 +202,12 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
 
   size_t NumRows() const { return array_interface_.num_rows; }
   size_t NumColumns() const { return array_interface_.num_cols; }
-  size_t DeviceIdx() const { return device_idx_; }
+  int32_t DeviceIdx() const { return device_idx_; }
 
  private:
   ArrayInterface array_interface_;
   CupyAdapterBatch batch_;
-  int device_idx_;
+  int32_t device_idx_ {-1};
 };
 
 // Returns maximum row length
1 change: 1 addition & 0 deletions src/data/ellpack_page_source.cu
@@ -10,6 +10,7 @@
 namespace xgboost {
 namespace data {
 void EllpackPageSource::Fetch() {
+  dh::safe_cuda(cudaSetDevice(param_.gpu_id));
   if (!this->ReadCache()) {
     auto const &csr = source_->Page();
     this->page_.reset(new EllpackPage{});
6 changes: 6 additions & 0 deletions src/data/proxy_dmatrix.cu
@@ -14,6 +14,9 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
   device_ = adapter->DeviceIdx();
   this->Info().num_col_ = adapter->NumColumns();
   this->Info().num_row_ = adapter->NumRows();
+  if (device_ < 0) {
+    CHECK_EQ(this->Info().num_row_, 0);
+  }
 }
 
 void DMatrixProxy::FromCudaArray(std::string interface_str) {
@@ -22,6 +25,9 @@ void DMatrixProxy::FromCudaArray(std::string interface_str) {
   device_ = adapter->DeviceIdx();
   this->Info().num_col_ = adapter->NumColumns();
   this->Info().num_row_ = adapter->NumRows();
+  if (device_ < 0) {
+    CHECK_EQ(this->Info().num_row_, 0);
+  }
 }
 
 }  // namespace data
9 changes: 6 additions & 3 deletions src/data/simple_dmatrix.cu
@@ -16,7 +16,10 @@ namespace data {
 // be supported in future. Does not currently support inferring row/column size
 template <typename AdapterT>
 SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
-  dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
+  auto device =
+      adapter->DeviceIdx() < 0 ? dh::CurrentDevice() : adapter->DeviceIdx();
+  CHECK_GE(device, 0);
+  dh::safe_cuda(cudaSetDevice(device));
 
   CHECK(adapter->NumRows() != kAdapterUnknownSize);
   CHECK(adapter->NumColumns() != kAdapterUnknownSize);
@@ -27,8 +30,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
   // Enforce single batch
   CHECK(!adapter->Next());
 
-  info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), adapter->DeviceIdx(),
-                                        missing, sparse_page_.get());
+  info_.num_nonzero_ =
+      CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get());
   info_.num_col_ = adapter->NumColumns();
   info_.num_row_ = adapter->NumRows();
   // Synchronise worker columns
5 changes: 5 additions & 0 deletions src/data/sparse_page_source.cu
@@ -20,6 +20,11 @@ size_t NFeaturesDevice(DMatrixProxy *proxy) {
 
 void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page) {
   auto device = proxy->DeviceIdx();
+  if (device < 0) {
+    device = dh::CurrentDevice();
+  }
+  CHECK_GE(device, 0);
+
   Dispatch(proxy, [&](auto const &value) {
     CopyToSparsePage(value, device, missing, page);
   });
6 changes: 3 additions & 3 deletions tests/python/test_data_iterator.py
@@ -57,12 +57,12 @@ def run_data_iterator(
     n_features: int,
     n_batches: int,
     tree_method: str,
-    cupy: bool,
+    use_cupy: bool,
 ) -> None:
     n_rounds = 2
 
     it = IteratorForTest(
-        *make_batches(n_samples_per_batch, n_features, n_batches, cupy)
+        *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy)
     )
     if n_batches == 0:
         with pytest.raises(ValueError, match="1 batch"):
@@ -103,8 +103,8 @@ def run_data_iterator(
     if tree_method != "gpu_hist":
         rtol = 1e-1  # flaky
     else:
-        np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-3)
         rtol = 1e-6
+        np.testing.assert_allclose(it_predt, arr_predt)
 
     np.testing.assert_allclose(
         results_from_it["Train"]["rmse"],
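
For the n_batches == 0 path above, the test only pins down that construction raises a ValueError mentioning "1 batch". A hypothetical minimal iterator against the 1.5-era callback API (EmptyIter is illustrative, not part of the test suite, and the exact error text is inferred from the test's match pattern):

    import xgboost as xgb

    class EmptyIter(xgb.DataIter):
        # Yields no batches, mirroring run_data_iterator with n_batches == 0.
        def next(self, input_data) -> int:
            return 0  # 0 ends iteration before any batch is produced
        def reset(self) -> None:
            pass

    try:
        xgb.DeviceQuantileDMatrix(EmptyIter())
    except ValueError as err:
        print(err)  # expected to mention needing at least 1 batch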
