Dask device dmatrix (#5901)

* Fix softprob with empty dmatrix.

trivialfis authored Jul 17, 2020
1 parent e471056 commit 7c26861
Showing 12 changed files with 392 additions and 149 deletions.
143 changes: 142 additions & 1 deletion python-package/xgboost/dask.py
@@ -14,6 +14,7 @@
import platform
import logging
from collections import defaultdict
from collections.abc import Sequence
from threading import Thread

import numpy
@@ -28,7 +29,7 @@
from .compat import CUDF_concat
from .compat import lazy_isinstance

from .core import DMatrix, Booster, _expect
from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
from .training import train as worker_train
from .tracker import RabitTracker
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
@@ -357,6 +358,146 @@ def get_worker_data_shape(self, worker):
return (rows, cols)


class DaskPartitionIter(DataIter): # pylint: disable=R0902
'''A data iterator for `DaskDeviceQuantileDMatrix`.
'''
def __init__(self, data, label=None, weight=None, base_margin=None,
label_lower_bound=None, label_upper_bound=None,
feature_names=None, feature_types=None):
self._data = data
self._labels = label
self._weights = weight
self._base_margin = base_margin
self._label_lower_bound = label_lower_bound
self._label_upper_bound = label_upper_bound
self._feature_names = feature_names
self._feature_types = feature_types

assert isinstance(self._data, Sequence)

types = (Sequence, type(None))
assert isinstance(self._labels, types)
assert isinstance(self._weights, types)
assert isinstance(self._base_margin, types)
assert isinstance(self._label_lower_bound, types)
assert isinstance(self._label_upper_bound, types)

self._iter = 0 # set iterator to 0
super().__init__()

def data(self):
'''Utility function for obtaining current batch of data.'''
return self._data[self._iter]

def labels(self):
'''Utility function for obtaining current batch of label.'''
if self._labels is not None:
return self._labels[self._iter]
return None

def weights(self):
        '''Utility function for obtaining current batch of weight.'''
if self._weights is not None:
return self._weights[self._iter]
return None

def base_margins(self):
'''Utility function for obtaining current batch of base_margin.'''
if self._base_margin is not None:
return self._base_margin[self._iter]
return None

def label_lower_bounds(self):
'''Utility function for obtaining current batch of label_lower_bound.
'''
if self._label_lower_bound is not None:
return self._label_lower_bound[self._iter]
return None

def label_upper_bounds(self):
'''Utility function for obtaining current batch of label_upper_bound.
'''
if self._label_upper_bound is not None:
return self._label_upper_bound[self._iter]
return None

def reset(self):
'''Reset the iterator'''
self._iter = 0

def next(self, input_data):
'''Yield next batch of data'''
if self._iter == len(self._data):
# Return 0 when there's no more batch.
return 0
if self._feature_names:
feature_names = self._feature_names
else:
if hasattr(self.data(), 'columns'):
feature_names = self.data().columns.format()
else:
feature_names = None
input_data(data=self.data(), label=self.labels(),
weight=self.weights(), group=None,
label_lower_bound=self.label_lower_bounds(),
label_upper_bound=self.label_upper_bounds(),
feature_names=feature_names,
feature_types=self._feature_types)
self._iter += 1
return 1


class DaskDeviceQuantileDMatrix(DaskDMatrix):
'''Specialized data type for `gpu_hist` tree method. This class is
used to reduce the memory usage by eliminating data copies.
Internally the data is merged by weighted GK sketching. So the
number of partitions from dask may affect training accuracy as GK
    generates error for each merge.

    .. versionadded:: 1.2.0

    Parameters
    ----------
max_bin: Number of bins for histogram construction.
'''
def __init__(self, client, data, label=None, weight=None,
missing=None,
feature_names=None,
feature_types=None,
max_bin=256):
super().__init__(client=client, data=data, label=label, weight=weight,
missing=missing,
feature_names=feature_names,
feature_types=feature_types)
self.max_bin = max_bin

def get_worker_data(self, worker):
if worker.address not in set(self.worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(self.worker_map.keys()))
LOGGER.warning(msg)
import cupy # pylint: disable=import-error
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
feature_names=self.feature_names,
feature_types=self.feature_types,
max_bin=self.max_bin)
return d

data, labels, weights = self.get_worker_parts(worker)
it = DaskPartitionIter(data=data, label=labels, weight=weights)

dmatrix = DeviceQuantileDMatrix(it,
missing=self.missing,
feature_names=self.feature_names,
feature_types=self.feature_types,
nthread=worker.nthreads,
max_bin=self.max_bin)
return dmatrix


def _get_rabit_args(worker_map, client):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
host = distributed_comm.get_address_host(client.scheduler.address)
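For orientation, here is a minimal usage sketch of the new DaskDeviceQuantileDMatrix added above. It is illustrative only and not part of this commit: it assumes a dask_cuda LocalCUDACluster, CuPy-backed dask arrays, and the gpu_hist tree method, and the helper name train_on_gpu is made up for the example.

import dask.array as da
import cupy
import xgboost as xgb
from dask.distributed import Client
from dask_cuda import LocalCUDACluster


def train_on_gpu(client):
    # Toy data for illustration: CuPy-backed dask arrays keep each
    # partition on a worker GPU, which is the input this DMatrix is
    # designed to consume without extra copies.
    X = da.random.random((100_000, 20), chunks=(10_000, 20)).map_blocks(cupy.asarray)
    y = da.random.random((100_000,), chunks=(10_000,)).map_blocks(cupy.asarray)

    # Each partition is quantized directly on its worker instead of
    # being materialized into a regular DaskDMatrix first.
    dtrain = xgb.dask.DaskDeviceQuantileDMatrix(client, X, y, max_bin=256)

    output = xgb.dask.train(client,
                            {'tree_method': 'gpu_hist'},
                            dtrain,
                            num_boost_round=10)
    return output['booster']


if __name__ == '__main__':
    with LocalCUDACluster() as cluster, Client(cluster) as client:
        booster = train_on_gpu(client)

As the class docstring notes, the per-partition sketches are merged with weighted GK quantiles, so the dask partitioning can slightly affect the chosen split points compared to building a single DMatrix from the whole dataset.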
2 changes: 1 addition & 1 deletion python-package/xgboost/data.py
@@ -15,7 +15,7 @@


def _warn_unused_missing(data, missing):
if not (np.isnan(missing) or None):
if (not np.isnan(missing)) or (missing is None):
warnings.warn(
'`missing` is not used for current input data type:' +
str(type(data)))
1 change: 1 addition & 0 deletions src/common/quantile.cuh
@@ -85,6 +85,7 @@ class SketchContainer {
// Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_);
this->columns_ptr_.Resize(num_columns + 1);
CHECK_GE(device, 0);
timer_.Init(__func__);
}
/* \brief Return GPU ID for this container. */
5 changes: 3 additions & 2 deletions src/data/array_interface.h
@@ -114,7 +114,6 @@ class ArrayInterfaceHandler {
get<Array const>(
obj.at("data"))
.at(0))));
CHECK(p_data);
return p_data;
}

@@ -224,6 +223,9 @@
auto shape = ExtractShape(column);

T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
if (!p_data) {
CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
}
return common::Span<T>{p_data, shape.first * shape.second};
}
};
@@ -234,7 +236,6 @@ class ArrayInterface {
bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
CHECK(data) << "Column is null";
auto shape = ArrayInterfaceHandler::ExtractShape(column);
num_rows = shape.first;
num_cols = shape.second;
9 changes: 9 additions & 0 deletions src/data/data.cc
@@ -488,6 +488,15 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
group_ptr.end());
}

if (!that.feature_names.empty()) {
this->feature_names = that.feature_names;
}
if (!that.feature_type_names.empty()) {
this->feature_type_names = that.feature_type_names;
auto &h_feature_types = feature_types.HostVector();
LoadFeatureType(this->feature_type_names, &h_feature_types);
}
}

void MetaInfo::Validate(int32_t device) const {
3 changes: 3 additions & 0 deletions src/data/data.cu
@@ -69,6 +69,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
<< "Meta info " << key << " should be dense, found validity mask";
CHECK_EQ(array_interface.num_cols, 1)
<< "Meta info should be a single column.";
if (array_interface.num_rows == 0) {
return;
}

if (key == "label") {
CopyInfoImpl(array_interface, &labels_);
11 changes: 9 additions & 2 deletions src/data/device_adapter.cuh
@@ -122,10 +122,14 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
std::vector<ArrayInterface> columns;
auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
num_rows_ = first_column.num_rows;
if (num_rows_ == 0) {
return;
}

device_idx_ = dh::CudaGetPointerDevice(first_column.data);
CHECK_NE(device_idx_, -1);
dh::safe_cuda(cudaSetDevice(device_idx_));
num_rows_ = first_column.num_rows;
for (auto& json_col : json_columns) {
auto column = ArrayInterface(get<Object const>(json_col));
columns.push_back(column);
@@ -183,9 +187,12 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
Json json_array_interface =
Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
array_interface_ = ArrayInterface(get<Object const>(json_array_interface), false);
batch_ = CupyAdapterBatch(array_interface_);
if (array_interface_.num_rows == 0) {
return;
}
device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
CHECK_NE(device_idx_, -1);
batch_ = CupyAdapterBatch(array_interface_);
}
const CupyAdapterBatch& Value() const override { return batch_; }

48 changes: 25 additions & 23 deletions src/data/iterative_device_dmatrix.cu
@@ -62,43 +62,46 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t batches = 0;
size_t accumulated_rows = 0;
bst_feature_t cols = 0;
int32_t device = -1;
int32_t device = GenericParameter::kCpuId;
int32_t current_device_;
dh::safe_cuda(cudaGetDevice(&current_device_));
auto get_device = [&]() -> int32_t {
int32_t d = GenericParameter::kCpuId ? current_device_ : device;
return d;
};

while (iter.Next()) {
device = proxy->DeviceIdx();
dh::safe_cuda(cudaSetDevice(device));
dh::safe_cuda(cudaSetDevice(get_device()));
if (cols == 0) {
cols = num_cols();
rabit::Allreduce<rabit::op::Max>(&cols, 1);
} else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
}
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
sketch_containers.emplace_back(batch_param_.max_bin, cols, num_rows(), get_device());
auto* p_sketch = &sketch_containers.back();
proxy->Info().weights_.SetDevice(device);
proxy->Info().weights_.SetDevice(get_device());
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketch(value, batch_param_.max_bin,
proxy->Info(), missing, p_sketch);
});

auto batch_rows = num_rows();
accumulated_rows += batch_rows;
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(),
row_counts.size());
row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
return GetRowCounts(value, row_counts_span,
device, missing);
get_device(), missing);
}));
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
row_counts.end());
batches++;
}

if (device < 0) { // error or empty
this->page_.reset(new EllpackPage);
return;
}

common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
iter.Reset();
dh::safe_cuda(cudaSetDevice(get_device()));
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, get_device());
for (auto const& sketch : sketch_containers) {
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
final_sketch.FixError();
@@ -113,14 +116,14 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
this->info_.num_row_ = accumulated_rows;
this->info_.num_nonzero_ = nnz;

auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows]() {
auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows,
get_device]() {
if (!page_) {
// Should be put inside the while loop to protect against empty batch. In
// that case device id is invalid.
page_.reset(new EllpackPage);
*(page_->Impl()) =
EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), row_stride,
accumulated_rows);
*(page_->Impl()) = EllpackPageImpl(get_device(), cuts, this->IsDense(),
row_stride, accumulated_rows);
}
};

@@ -130,21 +133,20 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t n_batches_for_verification = 0;
while (iter.Next()) {
init_page();
auto device = proxy->DeviceIdx();
dh::safe_cuda(cudaSetDevice(device));
dh::safe_cuda(cudaSetDevice(get_device()));
auto rows = num_rows();
dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(),
row_counts.size());
Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, device, missing);
return GetRowCounts(value, row_counts_span, get_device(), missing);
});
auto is_dense = this->IsDense();
auto new_impl = Dispatch(proxy, [&](auto const &value) {
return EllpackPageImpl(value, missing, device, is_dense, nthread,
row_counts_span, row_stride, rows, cols, cuts);
return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
row_counts_span, row_stride, rows, cols, cuts);
});
size_t num_elements = page_->Impl()->Copy(device, &new_impl, offset);
size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
offset += num_elements;

proxy->Info().num_row_ = num_rows();
(Diffs for the remaining changed files are not shown.)
