Dask device dmatrix #5868

Closed
wants to merge 6 commits into from
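This pull request adds DaskDeviceQuantileDMatrix, a DaskDMatrix variant for the `gpu_hist` tree method that constructs a DeviceQuantileDMatrix on each worker through a batch iterator instead of materialising extra copies of the data. A minimal usage sketch, distilled from the updated demo below (it assumes a dask_cuda GPU cluster with cupy available; sizes and parameter values are illustrative):

from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from dask import array as da
import cupy
import xgboost as xgb
from xgboost.dask import DaskDeviceQuantileDMatrix

with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
    with Client(cluster) as client:
        # Inputs must already live in device memory: map each dask chunk
        # to a cupy array.
        X = da.random.random((100000, 100), chunks=10000).map_blocks(cupy.array)
        y = da.random.random((100000,), chunks=10000).map_blocks(cupy.array)
        dtrain = DaskDeviceQuantileDMatrix(client, X, y, max_bin=256)
        output = xgb.dask.train(client, {'tree_method': 'gpu_hist'}, dtrain,
                                num_boost_round=100, evals=[(dtrain, 'train')])
        booster, history = output['booster'], output['history']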
84 changes: 75 additions & 9 deletions demo/dask/gpu_training.py
@@ -2,38 +2,102 @@
from dask.distributed import Client
from dask import array as da
import xgboost as xgb
from xgboost.dask import DaskDMatrix
import numpy as np
from xgboost.dask import DaskDMatrix, DaskDeviceQuantileDMatrix
# We use sklearn instead of dask-ml for this demo, but for real-world use
# dask-ml should be preferred.
from sklearn.metrics import mean_squared_error
from time import time


def main(client):
# generate some random data for demonstration
ROUNDS = 100


def generate_random(chunks=100):
m = 100000
n = 100
X = da.random.random(size=(m, n), chunks=100)
y = da.random.random(size=(m, ), chunks=100)
X = da.random.random(size=(m, n), chunks=chunks)
y = da.random.random(size=(m, ), chunks=chunks)
return X, y


def assert_non_increasing(L, tolerance=1e-4):
    # The metric may rise by at most `tolerance` between consecutive rounds.
    assert all((y - x) < tolerance for x, y in zip(L, L[1:]))


def train_with_dask_dmatrix(client):
# generate some random data for demonstration
X, y = generate_random()

    # DaskDMatrix acts like a normal DMatrix, working as a proxy for the
    # local DMatrices scattered around the workers.
start = time()
dtrain = DaskDMatrix(client, X, y)
end = time()
print('Constructing DMatrix:', end - start)

    # Use the train method from xgboost.dask instead of xgboost.  This
    # distributed version of train returns a dictionary containing the
    # resulting booster and the evaluation history obtained from the
    # evaluation metrics.
start = time()
output = xgb.dask.train(client,
{'verbosity': 2,
# Golden line for GPU training
'tree_method': 'gpu_hist'},
dtrain,
num_boost_round=ROUNDS, evals=[(dtrain, 'train')])
end = time()
print('Training:', end - start)

bst = output['booster']
history = output['history']

    # The `output` dict can also be passed directly into `predict`.
prediction = xgb.dask.predict(client, bst, dtrain)
mse = mean_squared_error(y_pred=prediction.compute(), y_true=y.compute())
print('Evaluation history:', history)
return mse


def train_with_dask_device_dmatrix(client):
import cupy
# generate some random data for demonstration
X, y = generate_random(10000)

X = X.map_blocks(cupy.array)
y = y.map_blocks(cupy.array)

    # DaskDeviceQuantileDMatrix helps reduce memory usage when the input
    # comes directly from the device.
start = time()
dtrain = DaskDeviceQuantileDMatrix(client, X, y)
end = time()
print('Constructing DaskDeviceQuantileDMatrix:', end - start)

    # Use the train method from xgboost.dask instead of xgboost.  This
    # distributed version of train returns a dictionary containing the
    # resulting booster and the evaluation history obtained from the
    # evaluation metrics.
start = time()
output = xgb.dask.train(client,
{'verbosity': 2,
# Golden line for GPU training
'tree_method': 'gpu_hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])
num_boost_round=ROUNDS, evals=[(dtrain, 'train')])
end = time()
print('Training:', end - start)
bst = output['booster']
history = output['history']
    assert_non_increasing(history['train']['rmse'])

    # The `output` dict can also be passed directly into `predict`.
prediction = xgb.dask.predict(client, bst, dtrain)
prediction = prediction.compute()
mse = mean_squared_error(y_pred=prediction.compute(),
y_true=y.map_blocks(cupy.asnumpy).compute())
print('Evaluation history:', history)
return prediction
return mse


if __name__ == '__main__':
@@ -42,4 +106,6 @@ def main(client):
# process.
with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
with Client(cluster) as client:
main(client)
mse_dmatrix = train_with_dask_dmatrix(client)
mse_iter = train_with_dask_device_dmatrix(client)
assert np.isclose(mse_iter, mse_dmatrix, atol=1e-3)
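The two library files below provide the plumbing for this demo. The iterator protocol that core.py extends is small: next() pushes one batch through the input_data callback and returns 1, or returns 0 once the batches are exhausted, while reset() rewinds to the first batch. A minimal, hypothetical in-memory iterator modeled on the DaskPartitionIter added to dask.py further down (the class name and batch layout are invented for illustration):

from xgboost.core import DataIter

class InMemoryIter(DataIter):
    '''Sketch: feed pre-loaded device batches to a DeviceQuantileDMatrix.'''
    def __init__(self, batches, labels):
        self._batches = batches  # e.g. a list of cupy arrays
        self._labels = labels    # label arrays aligned with `batches`
        self._it = 0
        super().__init__()

    def reset(self):
        '''Rewind to the first batch.'''
        self._it = 0

    def next(self, input_data):
        '''Hand one batch to XGBoost through the `input_data` callback.'''
        if self._it == len(self._batches):
            return 0  # 0 signals the end of iteration
        input_data(data=self._batches[self._it],
                   label=self._labels[self._it])
        self._it += 1
        return 1  # 1 signals that a batch was produced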
29 changes: 20 additions & 9 deletions python-package/xgboost/core.py
@@ -4,13 +4,14 @@
"""Core XGBoost Library."""
import collections
# pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping # Python 3
from collections.abc import Mapping
# pylint: enable=no-name-in-module,import-error
import ctypes
import os
import re
import sys
import json
import warnings

import numpy as np
import scipy.sparse
@@ -267,7 +268,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None):
raise TypeError('Can not handle data from {}'.format(
type(data).__name__)) from e
else:
import warnings
warnings.warn(
'Unknown data type: ' + str(type(data)) +
            ', converting it to csr_matrix')
@@ -331,7 +331,7 @@ def next_wrapper(self, this): # pylint: disable=unused-argument
'''A wrapper for user defined `next` function.

`this` is not used in Python. ctypes can handle `self` of a Python
member function automatically when converting a it to c function
member function automatically when converting it to c function
pointer.

'''
@@ -340,7 +341,8 @@ def next_wrapper(self, this): # pylint: disable=unused-argument

def data_handle(data, label=None, weight=None, base_margin=None,
group=None,
label_lower_bound=None, label_upper_bound=None):
label_lower_bound=None, label_upper_bound=None,
feature_names=None, feature_types=None):
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
# pylint: disable=protected-access
self.proxy._set_data_from_cuda_columnar(data)
@@ -358,14 +359,18 @@ def data_handle(data, label=None, weight=None, base_margin=None,
base_margin=base_margin,
group=group,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)
label_upper_bound=label_upper_bound,
feature_names=feature_names,
feature_types=feature_types)
try:
            # Defer the exception in order to return 0 and stop the iteration.
            # An exception inside a ctypes callback function has no effect
            # except for printing to stderr (it doesn't stop the execution).
ret = self.next(data_handle) # pylint: disable=not-callable
except Exception as e: # pylint: disable=broad-except
tb = sys.exc_info()[2]
            # On Dask the worker is restarted and the traceback information
            # would otherwise be lost.
self.exception = e.with_traceback(tb)
return 0
return ret
@@ -469,7 +474,7 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
assert self.handle, 'Failed to construct a DMatrix.'

if not can_handle_meta:
self.set_info(label, weight, base_margin)
self.set_info(label=label, weight=weight, base_margin=base_margin)

self.feature_names = feature_names
self.feature_types = feature_types
@@ -497,7 +502,9 @@ def set_info(self,
label=None, weight=None, base_margin=None,
group=None,
label_lower_bound=None,
label_upper_bound=None):
label_upper_bound=None,
feature_names=None,
feature_types=None):
'''Set meta info for DMatrix.'''
if label is not None:
self.set_label(label)
@@ -511,6 +518,10 @@
self.set_float_info('label_lower_bound', label_lower_bound)
if label_upper_bound is not None:
self.set_float_info('label_upper_bound', label_upper_bound)
if feature_names is not None:
self.feature_names = feature_names
if feature_types is not None:
self.feature_types = feature_types

def get_float_info(self, field):
"""Get float property from the DMatrix.
@@ -830,8 +841,8 @@ def feature_names(self, feature_names):

if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col():
msg = 'feature_names must have the same length as data'
if len(feature_names) != self.num_col() and self.num_col() != 0:
msg = f'feature_names must have the same length as data'
raise ValueError(msg)
# prohibit to use symbols may affect to parse. e.g. []<
if not all(isinstance(f, STRING_TYPES) and
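One consequence of the set_info extension above, shown as a hedged one-liner (dmat, y and w are illustrative names for an existing DMatrix and its metadata): feature metadata can now be attached in the same call as labels and weights, which is what the iterator callback path relies on.

dmat.set_info(label=y, weight=w,
              feature_names=['f{}'.format(i) for i in range(dmat.num_col())],
              feature_types=['float'] * dmat.num_col())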
143 changes: 142 additions & 1 deletion python-package/xgboost/dask.py
@@ -14,6 +14,7 @@
import platform
import logging
from collections import defaultdict
from collections.abc import Sequence
from threading import Thread

import numpy
@@ -28,7 +29,7 @@
from .compat import CUDF_concat
from .compat import lazy_isinstance

from .core import DMatrix, Booster, _expect
from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
from .training import train as worker_train
from .tracker import RabitTracker
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
@@ -357,6 +358,146 @@ def get_worker_data_shape(self, worker):
return (rows, cols)


class DaskPartitionIter(DataIter):
'''A data iterator for `DaskDeviceQuantileDMatrix`.
'''
def __init__(self, data, label=None, weight=None, base_margin=None,
label_lower_bound=None, label_upper_bound=None,
feature_names=None, feature_types=None):
self._data = data
self._labels = label
self._weights = weight
self._base_margin = base_margin
self._label_lower_bound = label_lower_bound
self._label_upper_bound = label_upper_bound
self._feature_names = feature_names
self._feature_types = feature_types

assert isinstance(self._data, Sequence)

types = (Sequence, type(None))
assert isinstance(self._labels, types)
assert isinstance(self._weights, types)
assert isinstance(self._base_margin, types)
assert isinstance(self._label_lower_bound, types)
assert isinstance(self._label_upper_bound, types)

        self._iter = 0  # index of the current batch
super().__init__()

def data(self):
'''Utility function for obtaining current batch of data.'''
return self._data[self._iter]

def labels(self):
'''Utility function for obtaining current batch of label.'''
if self._labels is not None:
return self._labels[self._iter]
return None

def weights(self):
        '''Utility function for obtaining current batch of weight.'''
if self._weights is not None:
return self._weights[self._iter]
return None

def base_margins(self):
'''Utility function for obtaining current batch of base_margin.'''
if self._base_margin is not None:
return self._base_margin[self._iter]
return None

def label_lower_bounds(self):
'''Utility function for obtaining current batch of label_lower_bound.
'''
if self._label_lower_bound is not None:
return self._label_lower_bound[self._iter]
return None

def label_upper_bounds(self):
'''Utility function for obtaining current batch of label_upper_bound.
'''
if self._label_upper_bound is not None:
return self._label_upper_bound[self._iter]
return None

def reset(self):
'''Reset the iterator'''
self._iter = 0

def next(self, input_data):
'''Yield next batch of data'''
if self._iter == len(self._data):
# Return 0 when there's no more batch.
return 0
if self._feature_names:
feature_names = self._feature_names
else:
if hasattr(self.data(), 'columns'):
feature_names = self.data().columns.format()
else:
feature_names = None
input_data(data=self.data(), label=self.labels(),
weight=self.weights(), group=None,
label_lower_bound=self.label_lower_bounds(),
label_upper_bound=self.label_upper_bounds(),
feature_names=feature_names,
feature_types=self._feature_types)
self._iter += 1
return 1


class DaskDeviceQuantileDMatrix(DaskDMatrix):
    '''A specialized data type for the `gpu_hist` tree method. This class
    reduces memory usage by eliminating data copies. Internally the data
    is merged by weighted GK sketching, so the number of partitions from
    Dask may affect training accuracy, as GK sketching introduces error
    with each merge.

.. versionadded:: 1.2.0

Parameters
----------
max_bin: Number of bins for histogram construction.

'''
def __init__(self, client, data, label=None, weight=None,
missing=None,
feature_names=None,
feature_types=None,
max_bin=256):
super().__init__(client=client, data=data, label=label, weight=weight,
missing=missing,
feature_names=feature_names,
feature_types=feature_types)
self.max_bin = max_bin

def get_worker_data(self, worker):
if worker.address not in set(self.worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(self.worker_map.keys()))
LOGGER.warning(msg)
import cupy # pylint: disable=import-error
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
feature_names=self.feature_names,
feature_types=self.feature_types,
max_bin=self.max_bin)
return d

data, labels, weights = self.get_worker_parts(worker)
it = DaskPartitionIter(data=data, label=labels, weight=weights)

dmatrix = DeviceQuantileDMatrix(it,
missing=self.missing,
feature_names=self.feature_names,
feature_types=self.feature_types,
nthread=worker.nthreads,
max_bin=self.max_bin)
return dmatrix


def _get_rabit_args(worker_map, client):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
host = distributed_comm.get_address_host(client.scheduler.address)
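Because every GK sketch merge introduces error (see the DaskDeviceQuantileDMatrix docstring above), coarser partitioning can be worth trying when accuracy matters more than parallelism in data loading. A hedged sketch using dask's rechunk, with illustrative chunk sizes:

# Fewer, larger partitions mean fewer sketch merges per worker.
X = X.rechunk((50000, X.shape[1]))
y = y.rechunk(50000)
# A larger max_bin gives finer histograms at some memory cost.
dtrain = DaskDeviceQuantileDMatrix(client, X, y, max_bin=512)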