From 029a8b533f69dd74e571200e3c2b94d92fa9eab4 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 16 Jul 2020 15:17:31 +0800 Subject: [PATCH] Simplify the data backends. (#5893) --- python-package/xgboost/core.py | 210 +-- python-package/xgboost/data.py | 1313 +++++++++-------- src/data/proxy_dmatrix.cu | 4 + .../test_device_quantile_dmatrix.py | 7 +- tests/python-gpu/test_from_cudf.py | 16 +- tests/python-gpu/test_from_cupy.py | 8 +- tests/python/test_basic.py | 32 +- tests/python/test_with_pandas.py | 12 +- 8 files changed, 793 insertions(+), 809 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 7e14bba6ade1..15e7c6c08108 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -4,13 +4,14 @@ """Core XGBoost Library.""" import collections # pylint: disable=no-name-in-module,import-error -from collections.abc import Mapping # Python 3 +from collections.abc import Mapping # pylint: enable=no-name-in-module,import-error import ctypes import os import re import sys import json +import warnings import numpy as np import scipy.sparse @@ -267,7 +268,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None): raise TypeError('Can not handle data from {}'.format( type(data).__name__)) from e else: - import warnings warnings.warn( 'Unknown data type: ' + str(type(data)) + ', coverting it to csr_matrix') @@ -279,27 +279,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None): return data -# Either object has cuda array interface or contains columns with interfaces -def _has_cuda_array_interface(data): - return hasattr(data, '__cuda_array_interface__') or \ - lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame') - - -def _cudf_array_interfaces(df): - '''Extract CuDF __cuda_array_interface__''' - interfaces = [] - if lazy_isinstance(df, 'cudf.core.series', 'Series'): - interfaces.append(df.__cuda_array_interface__) - else: - for col in df: - interface = df[col].__cuda_array_interface__ - if 'mask' in interface: - interface['mask'] = interface['mask'].__cuda_array_interface__ - interfaces.append(interface) - interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8') - return interfaces_str - - class DataIter: '''The interface for user defined data iterator. Currently is only supported by Device DMatrix. @@ -331,7 +310,7 @@ def next_wrapper(self, this): # pylint: disable=unused-argument '''A wrapper for user defined `next` function. `this` is not used in Python. ctypes can handle `self` of a Python - member function automatically when converting a it to c function + member function automatically when converting it to c function pointer. 
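 
        Returning 1 from the wrapped ``next`` asks for another batch and 0
        stops the iteration; an exception raised inside the callback is
        stored on ``self.exception`` and re-raised after the C API call
        returns.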
'''
@@ -340,32 +319,30 @@ def next_wrapper(self, this):    # pylint: disable=unused-argument
 
         def data_handle(data, label=None, weight=None, base_margin=None,
                         group=None,
-                        label_lower_bound=None, label_upper_bound=None):
-            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
-                # pylint: disable=protected-access
-                self.proxy._set_data_from_cuda_columnar(data)
-            elif lazy_isinstance(data, 'cudf.core.series', 'Series'):
-                # pylint: disable=protected-access
-                self.proxy._set_data_from_cuda_columnar(data)
-            elif lazy_isinstance(data, 'cupy.core.core', 'ndarray'):
-                # pylint: disable=protected-access
-                self.proxy._set_data_from_cuda_interface(data)
-            else:
-                raise TypeError(
-                    'Value type is not supported for data iterator:' +
-                    str(type(self._handle)), type(data))
+                        label_lower_bound=None, label_upper_bound=None,
+                        feature_names=None, feature_types=None):
+            from .data import dispatch_device_quantile_dmatrix_set_data
+            from .data import _device_quantile_transform
+            data, feature_names, feature_types = _device_quantile_transform(
+                data, feature_names, feature_types
+            )
+            dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
             self.proxy.set_info(label=label, weight=weight,
                                 base_margin=base_margin,
                                 group=group,
                                 label_lower_bound=label_lower_bound,
-                                label_upper_bound=label_upper_bound)
+                                label_upper_bound=label_upper_bound,
+                                feature_names=feature_names,
+                                feature_types=feature_types)
 
         try:
-            # Deffer the exception in order to return 0 and stop the iteration.
+            # Defer the exception in order to return 0 and stop the iteration.
             # Exception inside a ctype callback function has no effect except
             # for printing to stderr (doesn't stop the execution).
             ret = self.next(data_handle)  # pylint: disable=not-callable
         except Exception as e:            # pylint: disable=broad-except
             tb = sys.exc_info()[2]
+            # On dask the worker is restarted and somehow the information is
+            # lost.
             self.exception = e.with_traceback(tb)
             return 0
         return ret
@@ -453,41 +430,20 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
             self.handle = None
             return
 
-        handler = self._get_data_handler(data)
-        can_handle_meta = False
-        if handler is None:
-            data = _convert_unknown_data(data, None)
-            handler = self._get_data_handler(data)
-        try:
-            handler.handle_meta(label, weight, base_margin)
-            can_handle_meta = True
-        except NotImplementedError:
-            can_handle_meta = False
-
-        self.handle, feature_names, feature_types = handler.handle_input(
-            data, feature_names, feature_types)
-        assert self.handle, 'Failed to construct a DMatrix.'
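
For orientation, the `data_handle` closure above is what a user-defined iterator receives as `input_data`. A minimal sketch of the protocol (hypothetical `BatchIter` name; assumes `cupy` is installed, and follows the `SingleBatchInternalIter` pattern introduced later in this patch):

```python
import cupy as cp
import xgboost as xgb
from xgboost.core import DataIter  # internal location at this point in the tree

class BatchIter(DataIter):  # hypothetical user-defined iterator
    def __init__(self, batches, labels):
        self.batches = batches  # list of 2-D cupy arrays
        self.labels = labels    # list of 1-D cupy arrays
        self.it = 0
        super().__init__()

    def next(self, input_data):
        # `input_data` is the data_handle closure; returning 0 stops iteration.
        if self.it == len(self.batches):
            return 0
        input_data(data=self.batches[self.it], label=self.labels[self.it])
        self.it += 1
        return 1

    def reset(self):
        self.it = 0

batches = [cp.random.randn(32, 8) for _ in range(4)]
labels = [cp.random.randn(32) for _ in range(4)]
m = xgb.DeviceQuantileDMatrix(BatchIter(batches, labels))
```
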
+ from .data import dispatch_data_backend + handle, feature_names, feature_types = dispatch_data_backend( + data, missing=self.missing, + threads=self.nthread, + feature_names=feature_names, + feature_types=feature_types) + assert handle is not None + self.handle = handle - if not can_handle_meta: - self.set_info(label, weight, base_margin) + self.set_info(label=label, weight=weight, base_margin=base_margin) self.feature_names = feature_names self.feature_types = feature_types - def _get_data_handler(self, data, meta=None, meta_type=None): - '''Get data handler for this DMatrix class.''' - from .data import get_dmatrix_data_handler - handler = get_dmatrix_data_handler( - data, self.missing, self.nthread, self.silent, meta, meta_type) - return handler - - # pylint: disable=no-self-use - def _get_meta_handler(self, data, meta, meta_type): - from .data import get_dmatrix_meta_handler - handler = get_dmatrix_meta_handler( - data, meta, meta_type) - return handler - def __del__(self): if hasattr(self, "handle") and self.handle: _check_call(_LIB.XGDMatrixFree(self.handle)) @@ -497,7 +453,9 @@ def set_info(self, label=None, weight=None, base_margin=None, group=None, label_lower_bound=None, - label_upper_bound=None): + label_upper_bound=None, + feature_names=None, + feature_types=None): '''Set meta info for DMatrix.''' if label is not None: self.set_label(label) @@ -511,6 +469,10 @@ def set_info(self, self.set_float_info('label_lower_bound', label_lower_bound) if label_upper_bound is not None: self.set_float_info('label_upper_bound', label_upper_bound) + if feature_names is not None: + self.feature_names = feature_names + if feature_types is not None: + self.feature_types = feature_types def get_float_info(self, field): """Get float property from the DMatrix. @@ -565,17 +527,8 @@ def set_float_info(self, field, data): data: numpy array The array of data to be set """ - if isinstance(data, np.ndarray): - self.set_float_info_npy2d(field, data) - return - handler = self._get_data_handler(data, field, np.float32) - assert handler - data, _, _ = handler.transform(data) - c_data = c_array(ctypes.c_float, data) - _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle, - c_str(field), - c_data, - c_bst_ulong(len(data)))) + from .data import dispatch_meta_backend + dispatch_meta_backend(self, data, field, 'float') def set_float_info_npy2d(self, field, data): """Set float type property into the DMatrix @@ -589,13 +542,8 @@ def set_float_info_npy2d(self, field, data): data: numpy array The array of data to be set """ - data, _, _ = self._get_meta_handler( - data, field, np.float32).transform(data) - c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) - _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle, - c_str(field), - c_data, - c_bst_ulong(len(data)))) + from .data import dispatch_meta_backend + dispatch_meta_backend(self, data, field, 'float') def set_uint_info(self, field, data): """Set uint type property into the DMatrix. 
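
With the handler registry removed, every input now funnels through the two entry points in `xgboost.data`. A sketch of the resulting call flow, using only the public API (array contents are illustrative):

```python
import numpy as np
import xgboost as xgb

X = np.random.randn(100, 10)
y = np.random.randn(100)

# __init__ forwards to dispatch_data_backend(X, missing, threads, ...)
m = xgb.DMatrix(X, label=y)

# The meta setters forward to dispatch_meta_backend(m, data, field, dtype).
m.set_weight(np.ones(100))           # field='weight', dtype='float'
m.set_group(np.array([20, 30, 50]))  # field='group',  dtype='uint32'
```
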
@@ -608,27 +556,8 @@ def set_uint_info(self, field, data): data: numpy array The array of data to be set """ - data, _, _ = self._get_data_handler( - data, field, 'uint32').transform(data) - _check_call(_LIB.XGDMatrixSetUIntInfo(self.handle, - c_str(field), - c_array(ctypes.c_uint, data), - c_bst_ulong(len(data)))) - - def set_interface_info(self, field, data): - """Set info type property into DMatrix.""" - # If we are passed a dataframe, extract the series - if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): - if len(data.columns) != 1: - raise ValueError( - 'Expecting meta-info to contain a single column') - data = data[data.columns[0]] - - interface = bytes(json.dumps([data.__cuda_array_interface__], - indent=2), 'utf-8') - _check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle, - c_str(field), - interface)) + from .data import dispatch_meta_backend + dispatch_meta_backend(self, data, field, 'uint32') def save_binary(self, fname, silent=True): """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded @@ -653,10 +582,8 @@ def set_label(self, label): label: array like The label information to be set into DMatrix """ - if _has_cuda_array_interface(label): - self.set_interface_info('label', label) - else: - self.set_float_info('label', label) + from .data import dispatch_meta_backend + dispatch_meta_backend(self, label, 'label', 'float') def set_weight(self, weight): """Set weight of each instance. @@ -674,10 +601,8 @@ def set_weight(self, weight): sense to assign weights to individual data points. """ - if _has_cuda_array_interface(weight): - self.set_interface_info('weight', weight) - else: - self.set_float_info('weight', weight) + from .data import dispatch_meta_backend + dispatch_meta_backend(self, weight, 'weight', 'float') def set_base_margin(self, margin): """Set base margin of booster to start from. @@ -693,10 +618,8 @@ def set_base_margin(self, margin): Prediction margin of each datapoint """ - if _has_cuda_array_interface(margin): - self.set_interface_info('base_margin', margin) - else: - self.set_float_info('base_margin', margin) + from .data import dispatch_meta_backend + dispatch_meta_backend(self, margin, 'base_margin', 'float') def set_group(self, group): """Set group size of DMatrix (used for ranking). @@ -706,10 +629,8 @@ def set_group(self, group): group : array like Group size of each group """ - if _has_cuda_array_interface(group): - self.set_interface_info('group', group) - else: - self.set_uint_info('group', group) + from .data import dispatch_meta_backend + dispatch_meta_backend(self, group, 'group', 'uint32') def get_label(self): """Get the label of the DMatrix. @@ -830,7 +751,7 @@ def feature_names(self, feature_names): if len(feature_names) != len(set(feature_names)): raise ValueError('feature_names must be unique') - if len(feature_names) != self.num_col(): + if len(feature_names) != self.num_col() and self.num_col() != 0: msg = 'feature_names must have the same length as data' raise ValueError(msg) # prohibit to use symbols may affect to parse. e.g. 
[]<
@@ -935,28 +856,35 @@ class DeviceQuantileDMatrix(DMatrix):
 
     """
 
-    def __init__(self, data, label=None, weight=None, base_margin=None,
+    def __init__(self, data, label=None, weight=None,  # pylint: disable=W0231
+                 base_margin=None,
                  missing=None,
                  silent=False,
                  feature_names=None,
                  feature_types=None,
                  nthread=None, max_bin=256):
         self.max_bin = max_bin
+        self.missing = missing if missing is not None else np.nan
+        self.nthread = nthread if nthread is not None else 1
+
         if isinstance(data, ctypes.c_void_p):
             self.handle = data
             return
-        super().__init__(data, label=label, weight=weight,
-                         base_margin=base_margin,
-                         missing=missing,
-                         silent=silent,
-                         feature_names=feature_names,
-                         feature_types=feature_types,
-                         nthread=nthread)
-
-    def _get_data_handler(self, data, meta=None, meta_type=None):
-        from .data import get_device_quantile_dmatrix_data_handler
-        return get_device_quantile_dmatrix_data_handler(
-            data, self.max_bin, self.missing, self.nthread, self.silent)
+        from .data import init_device_quantile_dmatrix
+        handle, feature_names, feature_types = init_device_quantile_dmatrix(
+            data, missing=self.missing, threads=self.nthread,
+            max_bin=self.max_bin,
+            label=label, weight=weight,
+            base_margin=base_margin,
+            group=None,
+            label_lower_bound=None,
+            label_upper_bound=None,
+            feature_names=feature_names,
+            feature_types=feature_types)
+        self.handle = handle
+
+        self.feature_names = feature_names
+        self.feature_types = feature_types
 
     def _set_data_from_cuda_interface(self, data):
         '''Set data from CUDA array interface.'''
@@ -971,6 +899,7 @@ def _set_data_from_cuda_interface(self, data):
 
     def _set_data_from_cuda_columnar(self, data):
         '''Set data from CUDA columnar format.'''
+        from .data import _cudf_array_interfaces
         interfaces_str = _cudf_array_interfaces(data)
         _check_call(
             _LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
@@ -1592,6 +1521,7 @@ def reshape_output(predt, rows):
             rows = data.shape[0]
             return reshape_output(mem, rows)
         if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
+            from .data import _cudf_array_interfaces
             interfaces_str = _cudf_array_interfaces(data)
             _check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns(
                 self.handle,
diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index 955802ec857f..33064d4cd627 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -1,538 +1,642 @@
-# pylint: disable=too-many-arguments, no-self-use, too-many-instance-attributes
+# pylint: disable=too-many-arguments, too-many-branches
+# pylint: disable=too-many-return-statements, import-error
 '''Data dispatching for DMatrix.'''
 import ctypes
-import abc
 import json
 import warnings
 
 import numpy as np
 
-from .core import c_array, _LIB, _check_call, c_str, _cudf_array_interfaces
-from .core import DataIter
-from .compat import lazy_isinstance, STRING_TYPES, os_fspath, os_PathLike
+from .core import c_array, _LIB, _check_call, c_str
+from .core import DataIter, DeviceQuantileDMatrix, DMatrix
+from .compat import lazy_isinstance, os_fspath, os_PathLike
 
 c_bst_ulong = ctypes.c_uint64  # pylint: disable=invalid-name
 
 
-class DataHandler(abc.ABC):
-    '''Base class for various data handler.'''
-    def __init__(self, missing, nthread, silent, meta=None, meta_type=None):
-        self.missing = missing
-        self.nthread = nthread
-        self.silent = silent
+def _warn_unused_missing(data, missing):
+    if not (np.isnan(missing) or missing is None):
+        warnings.warn(
+            '`missing` is not used for current input data type:' +
+            str(type(data)))
+
+
+def _check_complex(data):
+    '''Test 
whether data is complex using `dtype` attribute.''' + complex_dtypes = (np.complex128, np.complex64, + np.cfloat, np.cdouble, np.clongdouble) + if hasattr(data, 'dtype') and data.dtype in complex_dtypes: + raise ValueError('Complex data not supported') + + +def _is_scipy_csr(data): + try: + import scipy + except ImportError: + scipy = None + return False + return isinstance(data, scipy.sparse.csr_matrix) + + +def _from_scipy_csr(data, missing, feature_names, feature_types): + '''Initialize data from a CSR matrix.''' + if len(data.indices) != len(data.data): + raise ValueError('length mismatch: {} vs {}'.format( + len(data.indices), len(data.data))) + _warn_unused_missing(data, missing) + handle = ctypes.c_void_p() + _check_call(_LIB.XGDMatrixCreateFromCSREx( + c_array(ctypes.c_size_t, data.indptr), + c_array(ctypes.c_uint, data.indices), + c_array(ctypes.c_float, data.data), + ctypes.c_size_t(len(data.indptr)), + ctypes.c_size_t(len(data.data)), + ctypes.c_size_t(data.shape[1]), + ctypes.byref(handle))) + return handle, feature_names, feature_types + + +def _is_scipy_csc(data): + try: + import scipy + except ImportError: + scipy = None + return False + return isinstance(data, scipy.sparse.csc_matrix) + + +def _from_scipy_csc(data, missing, feature_names, feature_types): + if len(data.indices) != len(data.data): + raise ValueError('length mismatch: {} vs {}'.format( + len(data.indices), len(data.data))) + _warn_unused_missing(data, missing) + handle = ctypes.c_void_p() + _check_call(_LIB.XGDMatrixCreateFromCSCEx( + c_array(ctypes.c_size_t, data.indptr), + c_array(ctypes.c_uint, data.indices), + c_array(ctypes.c_float, data.data), + ctypes.c_size_t(len(data.indptr)), + ctypes.c_size_t(len(data.data)), + ctypes.c_size_t(data.shape[0]), + ctypes.byref(handle))) + return handle, feature_names, feature_types + + +def _is_numpy_array(data): + return isinstance(data, (np.ndarray, np.matrix)) + + +def _maybe_np_slice(data, dtype): + '''Handle numpy slice. This can be removed if we use __array_interface__. + ''' + try: + if not data.flags.c_contiguous: + warnings.warn( + "Use subset (sliced data) of np.ndarray is not recommended " + + "because it will generate extra copies and increase " + + "memory consumption") + data = np.array(data, copy=True, dtype=dtype) + else: + data = np.array(data, copy=False, dtype=dtype) + except AttributeError: + data = np.array(data, copy=False, dtype=dtype) + return data + + +def _transform_np_array(data: np.ndarray): + if not isinstance(data, np.ndarray) and hasattr(data, '__array__'): + data = np.array(data, copy=False) + if len(data.shape) != 2: + raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ', + data.shape) + # flatten the array by rows and ensure it is float32. we try to avoid + # data copies if possible (reshape returns a view when possible and we + # explicitly tell np.array to try and avoid copying) + flatten = np.array(data.reshape(data.size), copy=False, + dtype=np.float32) + flatten = _maybe_np_slice(flatten, np.float32) + _check_complex(data) + return flatten + + +def _from_numpy_array(data, missing, nthread, feature_names, feature_types): + """Initialize data from a 2-D numpy matrix. + + If ``mat`` does not have ``order='C'`` (aka row-major) or is + not contiguous, a temporary copy will be made. + + If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will + be made. + + So there could be as many as two temporary data copies; be mindful of + input layout and type if memory use is a concern. 
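+
+    For instance, a C-contiguous ``float32`` input is usually flattened as a
+    view with no copy, while a ``float64`` or Fortran-ordered input is copied
+    during the conversion (an illustrative note, not an exhaustive rule).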
+ + """ + flatten = _transform_np_array(data) + handle = ctypes.c_void_p() + _check_call(_LIB.XGDMatrixCreateFromMat_omp( + flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + c_bst_ulong(data.shape[0]), + c_bst_ulong(data.shape[1]), + ctypes.c_float(missing), + ctypes.byref(handle), + ctypes.c_int(nthread))) + return handle, feature_names, feature_types + + +def _is_pandas_df(data): + try: + import pandas as pd + except ImportError: + return False + return isinstance(data, pd.DataFrame) + + +_pandas_dtype_mapper = { + 'int8': 'int', + 'int16': 'int', + 'int32': 'int', + 'int64': 'int', + 'uint8': 'int', + 'uint16': 'int', + 'uint32': 'int', + 'uint64': 'int', + 'float16': 'float', + 'float32': 'float', + 'float64': 'float', + 'bool': 'i' +} + + +def _transform_pandas_df(data, feature_names=None, feature_types=None, + meta=None, meta_type=None): + from pandas import MultiIndex, Int64Index + from pandas.api.types import is_sparse + data_dtypes = data.dtypes + if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) + for dtype in data_dtypes): + bad_fields = [ + str(data.columns[i]) for i, dtype in enumerate(data_dtypes) + if dtype.name not in _pandas_dtype_mapper + ] + + msg = """DataFrame.dtypes for data must be int, float or bool. + Did not expect the data types in fields """ + raise ValueError(msg + ', '.join(bad_fields)) + + if feature_names is None and meta is None: + if isinstance(data.columns, MultiIndex): + feature_names = [ + ' '.join([str(x) for x in i]) for i in data.columns + ] + elif isinstance(data.columns, Int64Index): + feature_names = list(map(str, data.columns)) + else: + feature_names = data.columns.format() + + if feature_types is None and meta is None: + feature_types = [] + for dtype in data_dtypes: + if is_sparse(dtype): + feature_types.append(_pandas_dtype_mapper[ + dtype.subtype.name]) + else: + feature_types.append(_pandas_dtype_mapper[dtype.name]) - self.meta = meta - self.meta_type = meta_type + if meta and len(data.columns) > 1: + raise ValueError( + 'DataFrame for {meta} cannot have multiple columns'.format( + meta=meta)) - def handle_meta(self, label=None, weight=None, base_margin=None, - group=None, - label_lower_bound=None, - label_upper_bound=None): - '''Handle meta data when the DMatrix type can not defer setting meta - data after construction. Example is `DeviceQuantileDMatrix` - which requires weight to be presented before digesting - data. 
+ dtype = meta_type if meta_type else 'float' + data = data.values.astype(dtype) - ''' - raise NotImplementedError() + return data, feature_names, feature_types - def _warn_unused_missing(self, data): - if not (np.isnan(np.nan) or None): - warnings.warn( - '`missing` is not used for current input data type:' + - str(type(data))) - - def check_complex(self, data): - '''Test whether data is complex using `dtype` attribute.''' - complex_dtypes = (np.complex128, np.complex64, - np.cfloat, np.cdouble, np.clongdouble) - if hasattr(data, 'dtype') and data.dtype in complex_dtypes: - raise ValueError('Complex data not supported') - - def transform(self, data): - '''Optional method for transforming data before being accepted by - other XGBoost API.''' - return data, None, None - @abc.abstractmethod - def handle_input(self, data, feature_names, feature_types): - '''Abstract method for handling different data input.''' - - -class DMatrixDataManager: - '''The registry class for various data handler.''' - def __init__(self): - self.__data_handlers = {} - self.__data_handlers_dly = [] - - def register_handler(self, module, name, handler): - '''Register a data handler handling specfic type of data.''' - self.__data_handlers['.'.join([module, name])] = handler - - def register_handler_opaque(self, func, handler): - '''Register a data handler that handles data with opaque type. - - Parameters - ---------- - func : callable - A function with a single parameter `data`. It should return True - if the handler can handle this data, otherwise returns False. - handler : xgboost.data.DataHandler - The handler class that is a subclass of `DataHandler`. - ''' - self.__data_handlers_dly.append((func, handler)) - - def get_handler(self, data): - '''Get a handler of `data`, returns None if handler not found.''' - module, name = type(data).__module__, type(data).__name__ - if '.'.join([module, name]) in self.__data_handlers.keys(): - handler = self.__data_handlers['.'.join([module, name])] - return handler - for f, handler in self.__data_handlers_dly: - if f(data): - return handler - return None - - -__dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name - - -def get_dmatrix_data_handler(data, missing, nthread, silent, - meta=None, meta_type=None): - '''Get a handler of `data` for DMatrix. - - .. versionadded:: 1.2.0 - - Parameters - ---------- - data : any - The input data. - missing : float - Same as `missing` for DMatrix. - nthread : int - Same as `nthread` for DMatrix. - silent : boolean - Same as `silent` for DMatrix. - meta : str - Field name of meta data, like `label`. Used only for getting handler - for meta info. - meta_type : str/np.dtype - Type of meta data. 
- - Returns - ------- - handler : DataHandler - ''' - handler = __dmatrix_registry.get_handler(data) - if handler is None: - return None - return handler(missing, nthread, silent, meta, meta_type) - - -def get_dmatrix_meta_handler(data, meta, meta_type): - '''Get handler for meta instead of data.''' - handler = __dmatrix_registry.get_handler(data) - if handler is None: - return None - return handler(None, 0, True, meta, meta_type) - - -class FileHandler(DataHandler): - '''Handler of path like input.''' - def handle_input(self, data, feature_names, feature_types): - self._warn_unused_missing(data) - handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)), - ctypes.c_int(self.silent), - ctypes.byref(handle))) - return handle, feature_names, feature_types - - -__dmatrix_registry.register_handler_opaque( - lambda data: isinstance(data, (STRING_TYPES, os_PathLike)), - FileHandler) - - -class CSRHandler(DataHandler): - '''Handler of `scipy.sparse.csr.csr_matrix`.''' - def handle_input(self, data, feature_names, feature_types): - '''Initialize data from a CSR matrix.''' - if len(data.indices) != len(data.data): - raise ValueError('length mismatch: {} vs {}'.format( - len(data.indices), len(data.data))) - self._warn_unused_missing(data) - handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixCreateFromCSREx( - c_array(ctypes.c_size_t, data.indptr), - c_array(ctypes.c_uint, data.indices), - c_array(ctypes.c_float, data.data), - ctypes.c_size_t(len(data.indptr)), - ctypes.c_size_t(len(data.data)), - ctypes.c_size_t(data.shape[1]), - ctypes.byref(handle))) - return handle, feature_names, feature_types - - -__dmatrix_registry.register_handler( - 'scipy.sparse.csr', 'csr_matrix', CSRHandler) - - -class CSCHandler(DataHandler): - '''Handler of `scipy.sparse.csc.csc_matrix`.''' - def handle_input(self, data, feature_names, feature_types): - if len(data.indices) != len(data.data): - raise ValueError('length mismatch: {} vs {}'.format( - len(data.indices), len(data.data))) - self._warn_unused_missing(data) - handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixCreateFromCSCEx( - c_array(ctypes.c_size_t, data.indptr), - c_array(ctypes.c_uint, data.indices), - c_array(ctypes.c_float, data.data), - ctypes.c_size_t(len(data.indptr)), - ctypes.c_size_t(len(data.data)), - ctypes.c_size_t(data.shape[0]), - ctypes.byref(handle))) - return handle, feature_names, feature_types +def _from_pandas_df(data, missing, nthread, feature_names, feature_types): + data, feature_names, feature_types = _transform_pandas_df( + data, feature_names, feature_types) + return _from_numpy_array(data, missing, nthread, feature_names, + feature_types) -__dmatrix_registry.register_handler( - 'scipy.sparse.csc', 'csc_matrix', CSCHandler) +def _is_pandas_series(data): + try: + import pandas as pd + except ImportError: + return False + return isinstance(data, pd.Series) -class NumpyHandler(DataHandler): - '''Handler of `numpy.ndarray`.''' - def _maybe_np_slice(self, data, dtype): - '''Handle numpy slice. This can be removed if we use __array_interface__. 
- ''' - try: - if not data.flags.c_contiguous: - warnings.warn( - "Use subset (sliced data) of np.ndarray is not recommended " + - "because it will generate extra copies and increase " + - "memory consumption") - data = np.array(data, copy=True, dtype=dtype) - else: - data = np.array(data, copy=False, dtype=dtype) - except AttributeError: - data = np.array(data, copy=False, dtype=dtype) - return data - - def transform(self, data): - return self._maybe_np_slice(data, self.meta_type), None, None - - def handle_input(self, data, feature_names, feature_types): - """Initialize data from a 2-D numpy matrix. - - If ``mat`` does not have ``order='C'`` (aka row-major) or is - not contiguous, a temporary copy will be made. - - If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will - be made. - - So there could be as many as two temporary data copies; be mindful of - input layout and type if memory use is a concern. - - """ - if not isinstance(data, np.ndarray) and hasattr(data, '__array__'): - data = np.array(data, copy=False) - if len(data.shape) != 2: - raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ', - data.shape) - # flatten the array by rows and ensure it is float32. we try to avoid - # data copies if possible (reshape returns a view when possible and we - # explicitly tell np.array to try and avoid copying) - flatten = np.array(data.reshape(data.size), copy=False, - dtype=np.float32) - flatten = self._maybe_np_slice(flatten, np.float32) - self.check_complex(data) - handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixCreateFromMat_omp( - flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - c_bst_ulong(data.shape[0]), - c_bst_ulong(data.shape[1]), - ctypes.c_float(self.missing), - ctypes.byref(handle), - ctypes.c_int(self.nthread))) - return handle, feature_names, feature_types - - -__dmatrix_registry.register_handler('numpy', 'ndarray', NumpyHandler) -__dmatrix_registry.register_handler('numpy', 'matrix', NumpyHandler) -__dmatrix_registry.register_handler_opaque( - lambda x: hasattr(x, '__array__'), NumpyHandler) - - -class ListHandler(NumpyHandler): - '''Handler of builtin list and tuple''' - def handle_input(self, data, feature_names, feature_types): - assert self.meta is None, 'List input data is not supported for X' - data = np.array(data) - return super().handle_input(data, feature_names, feature_types) - - -__dmatrix_registry.register_handler('builtins', 'list', NumpyHandler) -__dmatrix_registry.register_handler('builtins', 'tuple', NumpyHandler) - - -class PandasHandler(NumpyHandler): - '''Handler of data structures defined by `pandas`.''' - pandas_dtype_mapper = { - 'int8': 'int', - 'int16': 'int', - 'int32': 'int', - 'int64': 'int', - 'uint8': 'int', - 'uint16': 'int', - 'uint32': 'int', - 'uint64': 'int', - 'float16': 'float', - 'float32': 'float', - 'float64': 'float', - 'bool': 'i' - } - - def _maybe_pandas_data(self, data, feature_names, feature_types, - meta=None, meta_type=None): - """Extract internal data from pd.DataFrame for DMatrix data""" - if lazy_isinstance(data, 'pandas.core.series', 'Series'): - dtype = meta_type if meta_type else 'float' - return data.values.astype(dtype), feature_names, feature_types - - from pandas.api.types import is_sparse - from pandas import MultiIndex, Int64Index - - data_dtypes = data.dtypes - if not all(dtype.name in self.pandas_dtype_mapper or is_sparse(dtype) - for dtype in data_dtypes): - bad_fields = [ - str(data.columns[i]) for i, dtype in enumerate(data_dtypes) - if dtype.name not in 
self.pandas_dtype_mapper - ] +def _from_pandas_series(data, missing, nthread, feature_types, feature_names): + return _from_numpy_array(data.values.astype('float'), missing, nthread, + feature_names, feature_types) - msg = """DataFrame.dtypes for data must be int, float or bool. - Did not expect the data types in fields """ - raise ValueError(msg + ', '.join(bad_fields)) - - if feature_names is None and meta is None: - if isinstance(data.columns, MultiIndex): - feature_names = [ - ' '.join([str(x) for x in i]) for i in data.columns - ] - elif isinstance(data.columns, Int64Index): - feature_names = list(map(str, data.columns)) - else: - feature_names = data.columns.format() - - if feature_types is None and meta is None: - feature_types = [] - for dtype in data_dtypes: - if is_sparse(dtype): - feature_types.append(self.pandas_dtype_mapper[ - dtype.subtype.name]) - else: - feature_types.append(self.pandas_dtype_mapper[dtype.name]) - - if meta and len(data.columns) > 1: - raise ValueError( - 'DataFrame for {meta} cannot have multiple columns'.format( - meta=meta)) - dtype = meta_type if meta_type else 'float' - data = data.values.astype(dtype) +def _is_dt_df(data): + return lazy_isinstance(data, 'datatable', 'Frame') or \ + lazy_isinstance(data, 'datatable', 'DataTable') - return data, feature_names, feature_types - def transform(self, data): - return self._maybe_pandas_data(data, None, None, self.meta, - self.meta_type) +_dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'} +_dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'} - def handle_input(self, data, feature_names, feature_types): - data, feature_names, feature_types = self._maybe_pandas_data( - data, feature_names, feature_types, self.meta, self.meta_type) - return super().handle_input(data, feature_names, feature_types) +def _transform_dt_df(data, feature_names, feature_types, meta=None, + meta_type=None): + """Validate feature names and types if data table""" + if meta and data.shape[1] > 1: + raise ValueError( + 'DataTable for label or weight cannot have multiple columns') + if meta: + # below requires new dt version + # extract first column + data = data.to_numpy()[:, 0].astype(meta_type) + return data, None, None -__dmatrix_registry.register_handler( - 'pandas.core.frame', 'DataFrame', PandasHandler) -__dmatrix_registry.register_handler( - 'pandas.core.series', 'Series', PandasHandler) - + data_types_names = tuple(lt.name for lt in data.ltypes) + bad_fields = [data.names[i] + for i, type_name in enumerate(data_types_names) + if type_name not in _dt_type_mapper] + if bad_fields: + msg = """DataFrame.types for data must be int, float or bool. 
+ Did not expect the data types in fields """ + raise ValueError(msg + ', '.join(bad_fields)) -class DTHandler(DataHandler): - '''Handler of datatable.''' - dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'} - dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'} + if feature_names is None and meta is None: + feature_names = data.names - def _maybe_dt_data(self, data, feature_names, feature_types, - meta=None, meta_type=None): - """Validate feature names and types if data table""" - if meta and data.shape[1] > 1: + # always return stypes for dt ingestion + if feature_types is not None: raise ValueError( - 'DataTable for label or weight cannot have multiple columns') - if meta: - # below requires new dt version - # extract first column - data = data.to_numpy()[:, 0].astype(meta_type) - return data, None, None - - data_types_names = tuple(lt.name for lt in data.ltypes) - bad_fields = [data.names[i] - for i, type_name in enumerate(data_types_names) - if type_name not in self.dt_type_mapper] - if bad_fields: - msg = """DataFrame.types for data must be int, float or bool. - Did not expect the data types in fields """ - raise ValueError(msg + ', '.join(bad_fields)) - - if feature_names is None and meta is None: - feature_names = data.names - - # always return stypes for dt ingestion - if feature_types is not None: - raise ValueError( - 'DataTable has own feature types, cannot pass them in.') - feature_types = np.vectorize(self.dt_type_mapper2.get)( - data_types_names).tolist() + 'DataTable has own feature types, cannot pass them in.') + feature_types = np.vectorize(_dt_type_mapper2.get)( + data_types_names).tolist() - return data, feature_names, feature_types - - def transform(self, data): - return self._maybe_dt_data(data, None, None, self.meta, self.meta_type) + return data, feature_names, feature_types - def handle_input(self, data, feature_names, feature_types): - data, feature_names, feature_types = self._maybe_dt_data( - data, feature_names, feature_types, self.meta, self.meta_type) - ptrs = (ctypes.c_void_p * data.ncols)() - if hasattr(data, "internal") and hasattr(data.internal, "column"): - # datatable>0.8.0 - for icol in range(data.ncols): - col = data.internal.column(icol) - ptr = col.data_pointer - ptrs[icol] = ctypes.c_void_p(ptr) - else: - # datatable<=0.8.0 - from datatable.internal import \ - frame_column_data_r # pylint: disable=no-name-in-module,import-error - for icol in range(data.ncols): - ptrs[icol] = frame_column_data_r(data, icol) +def _from_dt_df(data, missing, nthread, feature_names, feature_types): + data, feature_names, feature_types = _transform_dt_df( + data, feature_names, feature_types, None, None) - # always return stypes for dt ingestion - feature_type_strings = (ctypes.c_char_p * data.ncols)() + ptrs = (ctypes.c_void_p * data.ncols)() + if hasattr(data, "internal") and hasattr(data.internal, "column"): + # datatable>0.8.0 for icol in range(data.ncols): - feature_type_strings[icol] = ctypes.c_char_p( - data.stypes[icol].name.encode('utf-8')) - - self._warn_unused_missing(data) - handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixCreateFromDT( - ptrs, feature_type_strings, - c_bst_ulong(data.shape[0]), - c_bst_ulong(data.shape[1]), - ctypes.byref(handle), - ctypes.c_int(self.nthread))) - return handle, feature_names, feature_types - - -__dmatrix_registry.register_handler('datatable', 'Frame', DTHandler) -__dmatrix_registry.register_handler('datatable', 'DataTable', DTHandler) - - -class CudaArrayInterfaceHandler(DataHandler): - 
'''Handler of data with `__cuda_array_interface__` (cupy.ndarray).''' - def handle_input(self, data, feature_names, feature_types): - """Initialize DMatrix from cupy ndarray.""" - interface = data.__cuda_array_interface__ - if 'mask' in interface: - interface['mask'] = interface['mask'].__cuda_array_interface__ - interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') - - handle = ctypes.c_void_p() - _check_call( - _LIB.XGDMatrixCreateFromArrayInterface( - interface_str, - ctypes.c_float(self.missing), - ctypes.c_int(self.nthread), - ctypes.byref(handle))) - return handle, feature_names, feature_types - - -__dmatrix_registry.register_handler('cupy.core.core', 'ndarray', - CudaArrayInterfaceHandler) - - -class CudaColumnarHandler(DataHandler): - '''Handler of CUDA based columnar data. (cudf.DataFrame)''' - def _maybe_cudf_dataframe(self, data, feature_names, feature_types): - """Extract internal data from cudf.DataFrame for DMatrix data.""" - if feature_names is None: - if lazy_isinstance(data, 'cudf.core.series', 'Series'): - feature_names = [data.name] - elif lazy_isinstance( - data.columns, 'cudf.core.multiindex', 'MultiIndex'): - feature_names = [ - ' '.join([str(x) for x in i]) - for i in data.columns - ] - else: - feature_names = data.columns.format() - if feature_types is None: - if lazy_isinstance(data, 'cudf.core.series', 'Series'): - dtypes = [data.dtype] - else: - dtypes = data.dtypes - feature_types = [PandasHandler.pandas_dtype_mapper[d.name] - for d in dtypes] - return data, feature_names, feature_types + col = data.internal.column(icol) + ptr = col.data_pointer + ptrs[icol] = ctypes.c_void_p(ptr) + else: + # datatable<=0.8.0 + from datatable.internal import \ + frame_column_data_r # pylint: disable=no-name-in-module + for icol in range(data.ncols): + ptrs[icol] = frame_column_data_r(data, icol) + + # always return stypes for dt ingestion + feature_type_strings = (ctypes.c_char_p * data.ncols)() + for icol in range(data.ncols): + feature_type_strings[icol] = ctypes.c_char_p( + data.stypes[icol].name.encode('utf-8')) + + _warn_unused_missing(data, missing) + handle = ctypes.c_void_p() + _check_call(_LIB.XGDMatrixCreateFromDT( + ptrs, feature_type_strings, + c_bst_ulong(data.shape[0]), + c_bst_ulong(data.shape[1]), + ctypes.byref(handle), + ctypes.c_int(nthread))) + return handle, feature_names, feature_types + + +def _is_cudf_df(data): + try: + import cudf + except ImportError: + return False + return isinstance(data, cudf.DataFrame) + + +def _cudf_array_interfaces(data): + '''Extract CuDF __cuda_array_interface__''' + interfaces = [] + if _is_cudf_ser(data): + interfaces.append(data.__cuda_array_interface__) + else: + for col in data: + interface = data[col].__cuda_array_interface__ + if 'mask' in interface: + interface['mask'] = interface['mask'].__cuda_array_interface__ + interfaces.append(interface) + interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8') + return interfaces_str + + +def _transform_cudf_df(data, feature_names, feature_types): + if feature_names is None: + if _is_cudf_ser(data): + feature_names = [data.name] + elif lazy_isinstance( + data.columns, 'cudf.core.multiindex', 'MultiIndex'): + feature_names = [ + ' '.join([str(x) for x in i]) + for i in data.columns + ] + else: + feature_names = data.columns.format() + if feature_types is None: + if _is_cudf_ser(data): + dtypes = [data.dtype] + else: + dtypes = data.dtypes + feature_types = [_pandas_dtype_mapper[d.name] + for d in dtypes] + return data, feature_names, feature_types + + +def 
_from_cudf_df(data, missing, nthread, feature_names, feature_types): + data, feature_names, feature_types = _transform_cudf_df( + data, feature_names, feature_types) + interfaces_str = _cudf_array_interfaces(data) + handle = ctypes.c_void_p() + _check_call( + _LIB.XGDMatrixCreateFromArrayInterfaceColumns( + interfaces_str, + ctypes.c_float(missing), + ctypes.c_int(nthread), + ctypes.byref(handle))) + return handle, feature_names, feature_types + + +def _is_cudf_ser(data): + try: + import cudf + except ImportError: + return False + return isinstance(data, cudf.Series) + + +def _is_cupy_array(data): + try: + import cupy + except ImportError: + return False + return isinstance(data, cupy.ndarray) + + +def _transform_cupy_array(data): + if not hasattr(data, '__cuda_array_interface__') and hasattr( + data, '__array__'): + import cupy # pylint: disable=import-error + data = cupy.array(data, copy=False) + return data + + +def _from_cupy_array(data, missing, nthread, feature_names, feature_types): + """Initialize DMatrix from cupy ndarray.""" + data = _transform_cupy_array(data) + interface = data.__cuda_array_interface__ + if 'mask' in interface: + interface['mask'] = interface['mask'].__cuda_array_interface__ + interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') + + handle = ctypes.c_void_p() + _check_call( + _LIB.XGDMatrixCreateFromArrayInterface( + interface_str, + ctypes.c_float(missing), + ctypes.c_int(nthread), + ctypes.byref(handle))) + return handle, feature_names, feature_types - def transform(self, data): - return self._maybe_cudf_dataframe(data, None, None) - - def handle_input(self, data, feature_names, feature_types): - """Initialize DMatrix from columnar memory format.""" - data, feature_names, feature_types = self._maybe_cudf_dataframe( - data, feature_names, feature_types) - interfaces_str = _cudf_array_interfaces(data) - handle = ctypes.c_void_p() - _check_call( - _LIB.XGDMatrixCreateFromArrayInterfaceColumns( - interfaces_str, - ctypes.c_float(self.missing), - ctypes.c_int(self.nthread), - ctypes.byref(handle))) - return handle, feature_names, feature_types - - -__dmatrix_registry.register_handler('cudf.core.dataframe', 'DataFrame', - CudaColumnarHandler) -__dmatrix_registry.register_handler('cudf.core.series', 'Series', - CudaColumnarHandler) - - -class DLPackHandler(CudaArrayInterfaceHandler): - '''Handler of `dlpack`.''' - def _maybe_dlpack_data(self, data, feature_names, feature_types): - from cupy import fromDlpack # pylint: disable=E0401 - data = fromDlpack(data) - return data, feature_names, feature_types - def transform(self, data): - return self._maybe_dlpack_data(data, None, None) +def _is_cupy_csr(data): + try: + import cupyx + except ImportError: + return False + return isinstance(data, cupyx.scipy.sparse.csr_matrix) - def handle_input(self, data, feature_names, feature_types): - data, feature_names, feature_types = self._maybe_dlpack_data( - data, feature_names, feature_types) - return super().handle_input( - data, feature_names, feature_types) + +def _is_cupy_csc(data): + try: + import cupyx + except ImportError: + return False + return isinstance(data, cupyx.scipy.sparse.csc_matrix) -__dmatrix_registry.register_handler_opaque( - lambda x: 'PyCapsule' in str(type(x)) and "dltensor" in str(x), - DLPackHandler) +def _is_dlpack(data): + return 'PyCapsule' in str(type(data)) and "dltensor" in str(data) -class SingleBatchInternalIter(DataIter): +def _transform_dlpack(data): + from cupy import fromDlpack # pylint: disable=E0401 + assert 'used_dltensor' not in 
str(data) + data = fromDlpack(data) + return data + + +def _from_dlpack(data, missing, nthread, feature_names, feature_types): + data = _transform_dlpack(data) + return _from_cupy_array(data, missing, nthread, feature_names, + feature_types) + + +def _is_uri(data): + return isinstance(data, (str, os_PathLike)) + + +def _from_uri(data, missing, feature_names, feature_types): + _warn_unused_missing(data, missing) + handle = ctypes.c_void_p() + _check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)), + ctypes.c_int(1), + ctypes.byref(handle))) + return handle, feature_names, feature_types + + +def _is_list(data): + return isinstance(data, list) + + +def _from_list(data, missing, feature_names, feature_types): + raise TypeError('List input data is not supported for data') + + +def _is_tuple(data): + return isinstance(data, tuple) + + +def _from_tuple(data, missing, feature_names, feature_types): + return _from_list(data, missing, feature_names, feature_types) + + +def _is_iter(data): + return isinstance(data, DataIter) + + +def _has_array_protocol(data): + return hasattr(data, '__array__') + + +def dispatch_data_backend(data, missing, threads, + feature_names, feature_types): + '''Dispatch data for DMatrix.''' + if _is_scipy_csr(data): + return _from_scipy_csr(data, missing, feature_names, feature_types) + if _is_scipy_csc(data): + return _from_scipy_csc(data, missing, feature_names, feature_types) + if _is_numpy_array(data): + return _from_numpy_array(data, missing, threads, feature_names, + feature_types) + if _is_uri(data): + return _from_uri(data, missing, feature_names, feature_types) + if _is_list(data): + return _from_list(data, missing, feature_names, feature_types) + if _is_tuple(data): + return _from_tuple(data, missing, feature_names, feature_types) + if _is_pandas_df(data): + return _from_pandas_df(data, missing, threads, + feature_names, feature_types) + if _is_pandas_series(data): + return _from_pandas_series(data, missing, threads, feature_names, + feature_types) + if _is_cudf_df(data): + return _from_cudf_df(data, missing, threads, feature_names, + feature_types) + if _is_cudf_ser(data): + return _from_cudf_df(data, missing, threads, feature_names, + feature_types) + if _is_cupy_array(data): + return _from_cupy_array(data, missing, threads, feature_names, + feature_types) + if _is_cupy_csr(data): + raise TypeError('cupyx CSR is not supported yet.') + if _is_cupy_csc(data): + raise TypeError('cupyx CSC is not supported yet.') + if _is_dlpack(data): + return _from_dlpack(data, missing, threads, feature_names, + feature_types) + if _is_dt_df(data): + _warn_unused_missing(data, missing) + return _from_dt_df(data, missing, threads, feature_names, + feature_types) + if _has_array_protocol(data): + pass + raise TypeError('Not supported type for data.' 
+ str(type(data))) + + +def _meta_from_numpy(data, field, dtype, handle): + data = _maybe_np_slice(data, dtype) + if dtype == 'uint32': + c_data = c_array(ctypes.c_uint32, data) + _check_call(_LIB.XGDMatrixSetUIntInfo(handle, + c_str(field), + c_array(ctypes.c_uint, data), + c_bst_ulong(len(data)))) + elif dtype == 'float': + c_data = c_array(ctypes.c_float, data) + _check_call(_LIB.XGDMatrixSetFloatInfo(handle, + c_str(field), + c_data, + c_bst_ulong(len(data)))) + else: + raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field) + + +def _meta_from_list(data, field, dtype, handle): + data = np.array(data) + _meta_from_numpy(data, field, dtype, handle) + + +def _meta_from_tuple(data, field, dtype, handle): + return _meta_from_list(data, field, dtype, handle) + + +def _meta_from_cudf_df(data, field, handle): + if len(data.columns) != 1: + raise ValueError( + 'Expecting meta-info to contain a single column') + data = data[data.columns[0]] + + interface = bytes(json.dumps([data.__cuda_array_interface__], + indent=2), 'utf-8') + _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, + c_str(field), + interface)) + + +def _meta_from_cudf_series(data, field, handle): + interface = bytes(json.dumps([data.__cuda_array_interface__], + indent=2), 'utf-8') + _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, + c_str(field), + interface)) + + +def _meta_from_cupy_array(data, field, handle): + data = _transform_cupy_array(data) + interface = bytes(json.dumps([data.__cuda_array_interface__], + indent=2), 'utf-8') + _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, + c_str(field), + interface)) + + +def _meta_from_dt(data, field, dtype, handle): + data, _, _ = _transform_dt_df(data, None, None) + _meta_from_numpy(data, field, dtype, handle) + + +def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): + '''Dispatch for meta info.''' + handle = matrix.handle + if data is None: + return + if _is_list(data): + _meta_from_list(data, name, dtype, handle) + return + if _is_tuple(data): + _meta_from_tuple(data, name, dtype, handle) + return + if _is_numpy_array(data): + _meta_from_numpy(data, name, dtype, handle) + return + if _is_pandas_df(data): + data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype) + _meta_from_numpy(data, name, dtype, handle) + return + if _is_pandas_series(data): + data = data.values.astype('float') + assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1 + _meta_from_numpy(data, name, dtype, handle) + return + if _is_dlpack(data): + data = _transform_dlpack(data) + _meta_from_cupy_array(data, name, handle) + return + if _is_cupy_array(data): + _meta_from_cupy_array(data, name, handle) + return + if _is_cudf_ser(data): + _meta_from_cudf_series(data, name, handle) + return + if _is_cudf_df(data): + _meta_from_cudf_df(data, name, handle) + return + if _is_dt_df(data): + _meta_from_dt(data, name, dtype, handle) + return + if _has_array_protocol(data): + pass + raise TypeError('Unsupported type for ' + name, str(type(data))) + + +class SingleBatchInternalIter(DataIter): # pylint: disable=R0902 '''An iterator for single batch data to help creating device DMatrix. Transforming input directly to histogram with normal single batch data API can not access weight for sketching. 
So this iterator acts as a staging @@ -540,7 +644,8 @@ class SingleBatchInternalIter(DataIter): ''' def __init__(self, data, label, weight, base_margin, group, - label_lower_bound, label_upper_bound): + label_lower_bound, label_upper_bound, + feature_names, feature_types): self.data = data self.label = label self.weight = weight @@ -548,6 +653,8 @@ def __init__(self, data, label, weight, base_margin, group, self.group = group self.label_lower_bound = label_lower_bound self.label_upper_bound = label_upper_bound + self.feature_names = feature_names + self.feature_types = feature_types self.it = 0 # pylint: disable=invalid-name super().__init__() @@ -559,146 +666,90 @@ def next(self, input_data): weight=self.weight, base_margin=self.base_margin, group=self.group, label_lower_bound=self.label_lower_bound, - label_upper_bound=self.label_upper_bound) + label_upper_bound=self.label_upper_bound, + feature_names=self.feature_names, + feature_types=self.feature_types) return 1 def reset(self): self.it = 0 -__device_quantile_dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name - - -class DeviceQuantileDMatrixDataHandler(DataHandler): # pylint: disable=abstract-method - '''Base class of data handler for `DeviceQuantileDMatrix`.''' - def __init__(self, max_bin, missing, nthread, silent, - meta=None, meta_type=None): - self.max_bin = max_bin - super().__init__(missing, nthread, silent, meta, meta_type) - - def handle_meta(self, label=None, weight=None, base_margin=None, - group=None, - label_lower_bound=None, - label_upper_bound=None): - self.label = label - self.weight = weight - self.base_margin = base_margin - self.group = group - self.label_lower_bound = label_lower_bound - self.label_upper_bound = label_upper_bound - - def handle_input(self, data, feature_names, feature_types): - if not isinstance(data, DataIter): - it = SingleBatchInternalIter( - data, self.label, self.weight, - self.base_margin, self.group, - self.label_lower_bound, self.label_upper_bound) - else: - it = data - reset_factory = ctypes.CFUNCTYPE(None, ctypes.c_void_p) - reset_callback = reset_factory(it.reset_wrapper) - next_factory = ctypes.CFUNCTYPE( - ctypes.c_int, - ctypes.c_void_p, - ) - next_callback = next_factory(it.next_wrapper) - handle = ctypes.c_void_p() - ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback( - None, - it.proxy.handle, - reset_callback, - next_callback, - ctypes.c_float(self.missing), - ctypes.c_int(self.nthread), - ctypes.c_int(self.max_bin), - ctypes.byref(handle) - ) - if it.exception: - raise it.exception - # delay check_call to throw intermediate exception first - _check_call(ret) - return handle, feature_names, feature_types - - -__device_quantile_dmatrix_registry.register_handler_opaque( - lambda x: isinstance(x, DataIter), - DeviceQuantileDMatrixDataHandler) - - -def get_device_quantile_dmatrix_data_handler( - data, max_bin, missing, nthread, silent): - '''Get data handler for `DeviceQuantileDMatrix`. Similar to - `get_dmatrix_data_handler`. - - .. versionadded:: 1.2.0 - - ''' - handler = __device_quantile_dmatrix_registry.get_handler( - data) - assert handler, 'Current data type ' + str(type(data)) +\ - ' is not supported for DeviceQuantileDMatrix' - return handler(max_bin, missing, nthread, silent) - - -class DeviceQuantileCudaArrayInterfaceHandler( - DeviceQuantileDMatrixDataHandler): - '''Handler of data with `__cuda_array_interface__`, for - `DeviceQuantileDMatrix`. 
- - ''' - def handle_input(self, data, feature_names, feature_types): - """Initialize DMatrix from cupy ndarray.""" - if not hasattr(data, '__cuda_array_interface__') and hasattr( - data, '__array__'): - import cupy # pylint: disable=import-error - data = cupy.array(data, copy=False) - return super().handle_input(data, feature_names, feature_types) - - -__device_quantile_dmatrix_registry.register_handler( - 'cupy.core.core', 'ndarray', DeviceQuantileCudaArrayInterfaceHandler) - - -class DeviceQuantileCudaColumnarHandler(DeviceQuantileDMatrixDataHandler, - CudaColumnarHandler): - '''Handler of CUDA based columnar data, for `DeviceQuantileDMatrix`.''' - def __init__(self, max_bin, missing, nthread, silent, - meta=None, meta_type=None): - super().__init__( - max_bin=max_bin, missing=missing, nthread=nthread, silent=silent, - meta=meta, meta_type=meta_type - ) - - def handle_input(self, data, feature_names, feature_types): - """Initialize Quantile Device DMatrix from columnar memory format.""" - data, feature_names, feature_types = self._maybe_cudf_dataframe( - data, feature_names, feature_types) - return super().handle_input(data, feature_names, feature_types) - - -__device_quantile_dmatrix_registry.register_handler( - 'cudf.core.dataframe', 'DataFrame', DeviceQuantileCudaColumnarHandler) -__device_quantile_dmatrix_registry.register_handler( - 'cudf.core.series', 'Series', DeviceQuantileCudaColumnarHandler) - - -class DeviceQuantileDLPackHandler(DeviceQuantileCudaArrayInterfaceHandler, - DLPackHandler): - '''Handler of `dlpack`, for `DeviceQuantileDMatrix`.''' - def __init__(self, max_bin, missing, nthread, silent, - meta=None, meta_type=None): - super().__init__( - max_bin=max_bin, missing=missing, nthread=nthread, silent=silent, - meta=meta, meta_type=meta_type - ) - - def handle_input(self, data, feature_names, feature_types): - data, feature_names, feature_types = self._maybe_dlpack_data( - data, feature_names, feature_types) - return super().handle_input( - data, feature_names, feature_types) - - -__device_quantile_dmatrix_registry.register_handler_opaque( - lambda x: 'PyCapsule' in str(type(x)) and "dltensor" in str(x), - DeviceQuantileDLPackHandler) +def init_device_quantile_dmatrix( + data, missing, max_bin, threads, feature_names, feature_types, **meta): + '''Constructor for DeviceQuantileDMatrix.''' + if not any([_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data), + _is_dlpack(data), _is_iter(data)]): + raise TypeError(str(type(data)) + + ' is not supported for DeviceQuantileDMatrix') + if _is_dlpack(data): + # We specialize for dlpack because cupy will take the memory from it so + # it can't be transformed twice. 
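+        # (A dlpack capsule can be consumed only once: cupy.fromDlpack renames
+        # it to 'used_dltensor' after the transfer, which is what the assert
+        # inside _transform_dlpack checks for.)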
+        data = _transform_dlpack(data)
+    if _is_iter(data):
+        it = data
+    else:
+        it = SingleBatchInternalIter(
+            data, **meta, feature_names=feature_names,
+            feature_types=feature_types)
+
+    reset_factory = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+    reset_callback = reset_factory(it.reset_wrapper)
+    next_factory = ctypes.CFUNCTYPE(
+        ctypes.c_int,
+        ctypes.c_void_p,
+    )
+    next_callback = next_factory(it.next_wrapper)
+    handle = ctypes.c_void_p()
+    ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
+        None,
+        it.proxy.handle,
+        reset_callback,
+        next_callback,
+        ctypes.c_float(missing),
+        ctypes.c_int(threads),
+        ctypes.c_int(max_bin),
+        ctypes.byref(handle)
+    )
+    if it.exception:
+        raise it.exception
+    # delay check_call to throw intermediate exception first
+    _check_call(ret)
+    matrix = DeviceQuantileDMatrix(handle)
+    feature_names = matrix.feature_names
+    feature_types = matrix.feature_types
+    matrix.handle = None
+    return handle, feature_names, feature_types
+
+
+def _device_quantile_transform(data, feature_names, feature_types):
+    if _is_cudf_df(data):
+        return _transform_cudf_df(data, feature_names, feature_types)
+    if _is_cudf_ser(data):
+        return _transform_cudf_df(data, feature_names, feature_types)
+    if _is_cupy_array(data):
+        data = _transform_cupy_array(data)
+        return data, feature_names, feature_types
+    if _is_dlpack(data):
+        return _transform_dlpack(data), feature_names, feature_types
+    raise TypeError('Value type is not supported for data iterator:' +
+                    str(type(data)))
+
+
+def dispatch_device_quantile_dmatrix_set_data(proxy, data):
+    '''Dispatch for DeviceQuantileDMatrix.'''
+    if _is_cudf_df(data):
+        proxy._set_data_from_cuda_columnar(data)  # pylint: disable=W0212
+        return
+    if _is_cudf_ser(data):
+        proxy._set_data_from_cuda_columnar(data)  # pylint: disable=W0212
+        return
+    if _is_cupy_array(data):
+        proxy._set_data_from_cuda_interface(data)  # pylint: disable=W0212
+        return
+    if _is_dlpack(data):
+        data = _transform_dlpack(data)
+        proxy._set_data_from_cuda_interface(data)  # pylint: disable=W0212
+        return
+    raise TypeError('Value type is not supported for data iterator:' +
+                    str(type(data)))
diff --git a/src/data/proxy_dmatrix.cu b/src/data/proxy_dmatrix.cu
index 088351021675..adad6f4c4d79 100644
--- a/src/data/proxy_dmatrix.cu
+++ b/src/data/proxy_dmatrix.cu
@@ -12,12 +12,16 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
   auto const& value = adapter->Value();
   this->batch_ = adapter;
   device_ = adapter->DeviceIdx();
+  this->Info().num_col_ = adapter->NumColumns();
+  this->Info().num_row_ = adapter->NumRows();
 }
 
 void DMatrixProxy::FromCudaArray(std::string interface_str) {
   std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
   this->batch_ = adapter;
   device_ = adapter->DeviceIdx();
+  this->Info().num_col_ = adapter->NumColumns();
+  this->Info().num_row_ = adapter->NumRows();
 }
 
 }  // namespace data
diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py
index 8b8bf85954f8..f0978a0afaf4 100644
--- a/tests/python-gpu/test_device_quantile_dmatrix.py
+++ b/tests/python-gpu/test_device_quantile_dmatrix.py
@@ -12,11 +12,12 @@ class TestDeviceQuantileDMatrix(unittest.TestCase):
     def test_dmatrix_numpy_init(self):
         data = np.random.randn(5, 5)
-        with pytest.raises(AssertionError, match='is not supported for DeviceQuantileDMatrix'):
-            dm = xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
+        with pytest.raises(TypeError,
+                           match='is not supported for DeviceQuantileDMatrix'):
+            
diff --git a/src/data/proxy_dmatrix.cu b/src/data/proxy_dmatrix.cu
index 088351021675..adad6f4c4d79 100644
--- a/src/data/proxy_dmatrix.cu
+++ b/src/data/proxy_dmatrix.cu
@@ -12,12 +12,16 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
   auto const& value = adapter->Value();
   this->batch_ = adapter;
   device_ = adapter->DeviceIdx();
+  this->Info().num_col_ = adapter->NumColumns();
+  this->Info().num_row_ = adapter->NumRows();
 }
 
 void DMatrixProxy::FromCudaArray(std::string interface_str) {
   std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
   this->batch_ = adapter;
   device_ = adapter->DeviceIdx();
+  this->Info().num_col_ = adapter->NumColumns();
+  this->Info().num_row_ = adapter->NumRows();
 }
 
 }  // namespace data
diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py
index 8b8bf85954f8..f0978a0afaf4 100644
--- a/tests/python-gpu/test_device_quantile_dmatrix.py
+++ b/tests/python-gpu/test_device_quantile_dmatrix.py
@@ -12,11 +12,12 @@ class TestDeviceQuantileDMatrix(unittest.TestCase):
     def test_dmatrix_numpy_init(self):
         data = np.random.randn(5, 5)
-        with pytest.raises(AssertionError, match='is not supported for DeviceQuantileDMatrix'):
-            dm = xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
+        with pytest.raises(TypeError,
+                           match='is not supported for DeviceQuantileDMatrix'):
+            xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_dmatrix_cupy_init(self):
         import cupy as cp
         data = cp.random.randn(5, 5)
-        dm = xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
+        xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py
index 0ba5931c9b56..661f2ed2178c 100644
--- a/tests/python-gpu/test_from_cudf.py
+++ b/tests/python-gpu/test_from_cudf.py
@@ -119,10 +119,10 @@ def _test_cudf_metainfo(DMatrixT):
     dmat.set_float_info('label', floats)
     dmat.set_float_info('base_margin', floats)
     dmat.set_uint_info('group', uints)
-    dmat_cudf.set_interface_info('weight', cudf_floats)
-    dmat_cudf.set_interface_info('label', cudf_floats)
-    dmat_cudf.set_interface_info('base_margin', cudf_floats)
-    dmat_cudf.set_interface_info('group', cudf_uints)
+    dmat_cudf.set_info(weight=cudf_floats)
+    dmat_cudf.set_info(label=cudf_floats)
+    dmat_cudf.set_info(base_margin=cudf_floats)
+    dmat_cudf.set_info(group=cudf_uints)
 
     # Test setting info with cudf DataFrame
     assert np.array_equal(dmat.get_float_info('weight'),
                           dmat_cudf.get_float_info('weight'))
@@ -132,10 +132,10 @@ def _test_cudf_metainfo(DMatrixT):
     assert np.array_equal(dmat.get_uint_info('group_ptr'),
                           dmat_cudf.get_uint_info('group_ptr'))
 
     # Test setting info with cudf Series
-    dmat_cudf.set_interface_info('weight', cudf_floats[cudf_floats.columns[0]])
-    dmat_cudf.set_interface_info('label', cudf_floats[cudf_floats.columns[0]])
-    dmat_cudf.set_interface_info('base_margin', cudf_floats[cudf_floats.columns[0]])
-    dmat_cudf.set_interface_info('group', cudf_uints[cudf_uints.columns[0]])
+    dmat_cudf.set_info(weight=cudf_floats[cudf_floats.columns[0]])
+    dmat_cudf.set_info(label=cudf_floats[cudf_floats.columns[0]])
+    dmat_cudf.set_info(base_margin=cudf_floats[cudf_floats.columns[0]])
+    dmat_cudf.set_info(group=cudf_uints[cudf_uints.columns[0]])
     assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
     assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
     assert np.array_equal(dmat.get_float_info('base_margin'),
diff --git a/tests/python-gpu/test_from_cupy.py b/tests/python-gpu/test_from_cupy.py
index 18f072f56120..cfe9cb8a2657 100644
--- a/tests/python-gpu/test_from_cupy.py
+++ b/tests/python-gpu/test_from_cupy.py
@@ -92,10 +92,10 @@ def _test_cupy_metainfo(DMatrixT):
     dmat.set_float_info('label', floats)
     dmat.set_float_info('base_margin', floats)
     dmat.set_uint_info('group', uints)
-    dmat_cupy.set_interface_info('weight', cupy_floats)
-    dmat_cupy.set_interface_info('label', cupy_floats)
-    dmat_cupy.set_interface_info('base_margin', cupy_floats)
-    dmat_cupy.set_interface_info('group', cupy_uints)
+    dmat_cupy.set_info(weight=cupy_floats)
+    dmat_cupy.set_info(label=cupy_floats)
+    dmat_cupy.set_info(base_margin=cupy_floats)
+    dmat_cupy.set_info(group=cupy_uints)
 
     # Test setting info with cupy
     assert np.array_equal(dmat.get_float_info('weight'),
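The test changes above reflect the removal of the special-cased set_interface_info: set_info now dispatches on the value type, so cudf and cupy inputs go through the same keyword interface as numpy. A small sketch, assuming a CUDA environment with cupy installed:

import cupy as cp
import numpy as np
import xgboost as xgb

dmat = xgb.DMatrix(cp.random.randn(16, 4))
# Device arrays are accepted directly; the type dispatch happens inside
# set_info rather than in a dedicated method on DMatrix.
dmat.set_info(label=cp.ones(16, dtype=np.float32),
              weight=cp.ones(16, dtype=np.float32))
print(dmat.get_float_info('weight'))  # retrieved as a host numpy array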
diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py
index e7d15cafec13..dad7ddc9db0c 100644
--- a/tests/python/test_basic.py
+++ b/tests/python/test_basic.py
@@ -1,17 +1,14 @@
 # -*- coding: utf-8 -*-
 import sys
 from contextlib import contextmanager
-try:
-    # python 2
-    from StringIO import StringIO
-except ImportError:
-    # python 3
-    from io import StringIO
+from io import StringIO
 import numpy as np
+import os
 import xgboost as xgb
 import unittest
 import json
 from pathlib import Path
+import tempfile
 
 dpath = 'demo/data/'
 rng = np.random.RandomState(1994)
@@ -66,16 +63,19 @@ def test_basic(self):
         # error must be smaller than 10%
         assert err < 0.1
 
-        # save dmatrix into binary buffer
-        dtest.save_binary('dtest.buffer')
-        # save model
-        bst.save_model('xgb.model')
-        # load model and data in
-        bst2 = xgb.Booster(model_file='xgb.model')
-        dtest2 = xgb.DMatrix('dtest.buffer')
-        preds2 = bst2.predict(dtest2)
-        # assert they are the same
-        assert np.sum(np.abs(preds2 - preds)) == 0
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dtest_path = os.path.join(tmpdir, 'dtest.dmatrix')
+            # save dmatrix into binary buffer
+            dtest.save_binary(dtest_path)
+            # save model
+            model_path = os.path.join(tmpdir, 'model.booster')
+            bst.save_model(model_path)
+            # load model and data in
+            bst2 = xgb.Booster(model_file=model_path)
+            dtest2 = xgb.DMatrix(dtest_path)
+            preds2 = bst2.predict(dtest2)
+            # assert they are the same
+            assert np.sum(np.abs(preds2 - preds)) == 0
 
     def test_record_results(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py
index 28f3536ede13..04f8c9510629 100644
--- a/tests/python/test_with_pandas.py
+++ b/tests/python/test_with_pandas.py
@@ -67,8 +67,7 @@ def test_pandas(self):
         # 0 1 1 0 0
         # 1 2 0 1 0
         # 2 3 0 0 1
-        pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
-        result, _, _ = pandas_handler._maybe_pandas_data(dummies, None, None)
+        result, _, _ = xgb.data._transform_pandas_df(dummies)
         exp = np.array([[1., 1., 0., 0.],
                         [2., 0., 1., 0.],
                         [3., 0., 0., 1.]])
@@ -129,18 +128,17 @@ def test_pandas_sparse(self):
     def test_pandas_label(self):
         # label must be a single column
         df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
-        pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
-        self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
+        self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
                           None, None, 'label', 'float')
 
         # label must be supported dtype
         df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
-        self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
+        self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
                           None, None, 'label', 'float')
 
         df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
-        result, _, _ = pandas_handler._maybe_pandas_data(df, None, None,
-                                                         'label', 'float')
+        result, _, _ = xgb.data._transform_pandas_df(df, None, None,
+                                                     'label', 'float')
         np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
                                                        dtype=float))
         dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
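Finally, the pandas tests now call the free function xgb.data._transform_pandas_df directly instead of instantiating the removed PandasHandler. A sketch of that entry point, assuming pandas is installed; the DMatrix constructor routes pandas input through the same helper internally:

import numpy as np
import pandas as pd
import xgboost as xgb

df = pd.DataFrame({'B': [1, 2, 3], 'A': ['X', 'Y', 'Z']})
dummies = pd.get_dummies(df)
# Returns the transformed values plus the derived feature names and types.
data, feature_names, feature_types = xgb.data._transform_pandas_df(dummies)
print(feature_names)  # ['B', 'A_X', 'A_Y', 'A_Z']
dm = xgb.DMatrix(dummies, label=np.arange(3))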