From 5c7ae629c6364e2d8267d51a97c6072c387ce679 Mon Sep 17 00:00:00 2001
From: fis
Date: Mon, 16 Aug 2021 10:18:07 +0800
Subject: [PATCH 01/17] Add type hints with generics.

More type hints.
More types.
Fully type the data handling.
Extract a typing module.
revert.
Remove 256
Remove any.
Fix dtypes.
---
 python-package/xgboost/core.py    | 306 +++++++++++++------------
 python-package/xgboost/data.py    | 357 +++++++++++++++++-------------
 python-package/xgboost/sklearn.py |  10 +-
 python-package/xgboost/typing.py  | 121 ++++++++++
 4 files changed, 484 insertions(+), 310 deletions(-)
 create mode 100644 python-package/xgboost/typing.py

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 3678f68362db..ab6053a25cfa 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -4,9 +4,8 @@
 """Core XGBoost Library."""
 # pylint: disable=no-name-in-module,import-error
 from collections.abc import Mapping
-from typing import List, Optional, Any, Union, Dict, TypeVar
-# pylint: enable=no-name-in-module,import-error
-from typing import Callable, Tuple, cast, Sequence
+from typing import List, Optional, Any, Union, Dict, TypeVar, Sequence, cast
+from typing import Callable, Tuple, Type, TextIO
 import ctypes
 import os
 import re
@@ -17,16 +16,19 @@
 from inspect import signature, Parameter
 
 import numpy as np
+
 import scipy.sparse
 
 from .compat import (STRING_TYPES, DataFrame, py_str,
                      PANDAS_INSTALLED, lazy_isinstance)
 from .libpath import find_lib_path
+from .typing import CuArrayLike, CuDFLike, NPArrayLike, DFLike, CSRLike
+from .typing import FeatureTypes, NativeInput, array_like, DTypeLike
 
 # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
 c_bst_ulong = ctypes.c_uint64
 
-# xgboost accepts some other possible types in practice due to historical reason, which is
-# lesser tested. For now we encourage users to pass a simple list of string.
+# xgboost accepts some other possible types in practice for historical reasons; these
+# are less tested. For now we encourage users to pass a simple list of strings.
 FeatNamesT = Optional[List[str]]
@@ -43,17 +45,11 @@ def from_pystr_to_cstr(data: Union[str, List[str]]):
     str or list of str
 
     """
-    if isinstance(data, str):
-        return bytes(data, "utf-8")
-    if isinstance(data, list):
-        pointers = (ctypes.c_char_p * len(data))()
-        data = [bytes(d, 'utf-8') for d in data]
-        pointers[:] = data
-        return pointers
-    raise TypeError()
+    assert isinstance(data, str)
+    return bytes(data, "utf-8")
 
 
-def from_cstr_to_pystr(data, length) -> List[str]:
+def from_cstr_to_pystr(data: ctypes.pointer, length: ctypes.c_uint64) -> List[str]:
     """Revert C pointer to Python str
 
     Parameters
@@ -93,7 +89,7 @@ def _convert_ntree_limit(
     return iteration_range
 
 
-def _expect(expectations, got):
+def _expect(expectations: Sequence[Type], got: Type) -> str:
     """Translate input error into string.
Parameters @@ -170,7 +166,7 @@ def _load_lib() -> ctypes.CDLL: Error message(s): {os_error_list} """) lib.XGBGetLastError.restype = ctypes.c_char_p - lib.callback = _get_log_callback_func() + setattr(lib, "callback", _get_log_callback_func()) if lib.XGBRegisterLogCallback(lib.callback) != 0: raise XGBoostError(lib.XGBGetLastError()) return lib @@ -211,8 +207,8 @@ def build_info() -> dict: return res -def _numpy2ctypes_type(dtype): - _NUMPY_TO_CTYPES_MAPPING = { +def _numpy2ctypes_type(dtype: DTypeLike) -> ctypes._SimpleCData: + _NUMPY_TO_CTYPES_MAPPING: Dict = { np.float32: ctypes.c_float, np.float64: ctypes.c_double, np.uint32: ctypes.c_uint, @@ -229,29 +225,30 @@ def _numpy2ctypes_type(dtype): return _NUMPY_TO_CTYPES_MAPPING[dtype] -def _cuda_array_interface(data) -> bytes: +def _cuda_array_interface(data: CuArrayLike) -> bytes: assert ( data.dtype.hasobject is False ), "Input data contains `object` dtype. Expecting numeric data." interface = data.__cuda_array_interface__ - if "mask" in interface: - interface["mask"] = interface["mask"].__cuda_array_interface__ + if "mask" in interface and interface["mask"] is not None: + mask: CuArrayLike = cast(CuArrayLike, interface["mask"]) + interface["mask"] = mask.__cuda_array_interface__ interface_str = bytes(json.dumps(interface), "utf-8") return interface_str -def ctypes2numpy(cptr, length, dtype) -> np.ndarray: +def ctypes2numpy(cptr: ctypes.pointer, length: int, dtype: DTypeLike) -> np.ndarray: """Convert a ctypes pointer array to a numpy array.""" ctype = _numpy2ctypes_type(dtype) if not isinstance(cptr, ctypes.POINTER(ctype)): raise RuntimeError(f"expected {ctype} pointer") res = np.zeros(length, dtype=dtype) - if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): + if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): # type:ignore raise RuntimeError("memmove failed") return res -def ctypes2cupy(cptr, length, dtype): +def ctypes2cupy(cptr: ctypes.pointer, length: int, dtype: DTypeLike) -> CuArrayLike: """Convert a ctypes pointer array to a cupy array.""" # pylint: disable=import-error import cupy @@ -277,36 +274,42 @@ def ctypes2cupy(cptr, length, dtype): return arr -def ctypes2buffer(cptr, length) -> bytearray: +def ctypes2buffer(cptr: ctypes.pointer, length: int) -> bytearray: """Convert ctypes pointer to buffer type.""" if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)): raise RuntimeError('expected char pointer') res = bytearray(length) rptr = (ctypes.c_char * length).from_buffer(res) - if not ctypes.memmove(rptr, cptr, length): + if not ctypes.memmove(rptr, cptr, length): # type:ignore raise RuntimeError('memmove failed') return res -def c_str(string): +def c_str(string: str) -> ctypes.c_char_p: """Convert a python string to cstring.""" return ctypes.c_char_p(string.encode('utf-8')) -def c_array(ctype, values): +def c_array( + ctype: Type, values: Union[NPArrayLike, List[ctypes.c_void_p], List[ctypes.c_char_p]] +) -> ctypes.pointer: """Convert a python string to c array.""" if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype): return (ctype * len(values)).from_buffer_copy(values) return (ctype * len(values))(*values) -def _prediction_output(shape, dims, predts, is_cuda): +def _prediction_output( + shape: ctypes.pointer, dims: ctypes.c_uint64, predts: ctypes.pointer, is_cuda: bool +) -> Union[NPArrayLike, CuArrayLike]: arr_shape: np.ndarray = ctypes2numpy(shape, dims.value, np.uint64) length = int(np.prod(arr_shape)) if is_cuda: - arr_predict = ctypes2cupy(predts, 
length, np.float32) + arr_predict: Union[np.ndarray, CuArrayLike] = ctypes2cupy( + predts, length, np.float32 + ) else: - arr_predict: np.ndarray = ctypes2numpy(predts, length, np.float32) + arr_predict = ctypes2numpy(predts, length, np.float32) arr_predict = arr_predict.reshape(arr_shape) return arr_predict @@ -401,7 +404,7 @@ def data_handle( data: Any, *, feature_names: FeatNamesT = None, - feature_types: Optional[List[str]] = None, + feature_types: FeatureTypes = None, **kwargs: Any, ) -> None: from .data import dispatch_proxy_set_data @@ -455,7 +458,7 @@ def next(self, input_data: Callable) -> int: # Nicolas Tresegnie # Sylvain Marie # License: BSD 3 clause -def _deprecate_positional_args(f): +def _deprecate_positional_args(f: Callable) -> Callable: """Decorator for methods that issues warnings for positional arguments Using the keyword-only argument syntax in pep 3102, arguments after the @@ -479,7 +482,7 @@ def _deprecate_positional_args(f): kwonly_args.append(name) @wraps(f) - def inner_f(*args, **kwargs): + def inner_f(*args: Any, **kwargs: Any) -> Any: extra_args = len(args) - len(all_args) if extra_args > 0: # ignore first 'self' argument for instance methods @@ -512,21 +515,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes @_deprecate_positional_args def __init__( self, - data, - label=None, + data: NativeInput, + label: Optional[array_like] = None, *, - weight=None, - base_margin=None, + weight: Optional[array_like] = None, + base_margin: Optional[array_like] = None, missing: Optional[float] = None, - silent=False, feature_names: FeatNamesT = None, - feature_types: Optional[List[str]] = None, + feature_types: FeatureTypes = None, nthread: Optional[int] = None, - group=None, - qid=None, - label_lower_bound=None, - label_upper_bound=None, - feature_weights=None, + group: Optional[array_like] = None, + qid: Optional[array_like] = None, + label_lower_bound: Optional[array_like] = None, + label_upper_bound: Optional[array_like] = None, + feature_weights: Optional[array_like] = None, enable_categorical: bool = False, ) -> None: """Parameters @@ -602,10 +604,12 @@ def __init__( from .data import dispatch_data_backend, _is_iter if _is_iter(data): + assert isinstance(data, DataIter) self._init_from_iter(data, enable_categorical) assert self.handle is not None return + assert not isinstance(data, DataIter) handle, feature_names, feature_types = dispatch_data_backend( data, missing=self.missing, @@ -640,7 +644,6 @@ def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None: "nthread": self.nthread, "cache_prefix": it.cache_prefix if it.cache_prefix else "", } - args = from_pystr_to_cstr(json.dumps(args)) handle = ctypes.c_void_p() # pylint: disable=protected-access reset_callback, next_callback = it._get_callbacks( @@ -651,7 +654,7 @@ def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None: it.proxy.handle, reset_callback, next_callback, - args, + from_pystr_to_cstr(json.dumps(args)), ctypes.byref(handle), ) # pylint: disable=protected-access @@ -669,16 +672,16 @@ def __del__(self) -> None: def set_info( self, *, - label=None, - weight=None, - base_margin=None, - group=None, - qid=None, - label_lower_bound=None, - label_upper_bound=None, + label: Optional[array_like] = None, + weight: Optional[array_like] = None, + base_margin: Optional[array_like] = None, + group: Optional[array_like] = None, + qid: Optional[array_like] = None, + label_lower_bound: Optional[array_like] = None, + label_upper_bound: Optional[array_like] = 
None, feature_names: FeatNamesT = None, - feature_types: Optional[List[str]] = None, - feature_weights=None + feature_types: FeatureTypes = None, + feature_weights: Optional[array_like] = None ) -> None: """Set meta info for DMatrix. See doc string for :py:obj:`xgboost.DMatrix`.""" from .data import dispatch_meta_backend @@ -747,7 +750,7 @@ def get_uint_info(self, field: str) -> np.ndarray: ctypes.byref(ret))) return ctypes2numpy(ret, length.value, np.uint32) - def set_float_info(self, field: str, data) -> None: + def set_float_info(self, field: str, data: array_like) -> None: """Set float type property into the DMatrix. Parameters @@ -755,13 +758,13 @@ def set_float_info(self, field: str, data) -> None: field: str The field name of the information - data: numpy array + data: The array of data to be set """ from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'float') - def set_float_info_npy2d(self, field: str, data) -> None: + def set_float_info_npy2d(self, field: str, data: array_like) -> None: """Set float type property into the DMatrix for numpy 2d array input @@ -770,13 +773,13 @@ def set_float_info_npy2d(self, field: str, data) -> None: field: str The field name of the information - data: numpy array + data: The array of data to be set """ from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'float') - def set_uint_info(self, field: str, data) -> None: + def set_uint_info(self, field: str, data: array_like) -> None: """Set uint type property into the DMatrix. Parameters @@ -790,7 +793,7 @@ def set_uint_info(self, field: str, data) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'uint32') - def save_binary(self, fname, silent=True) -> None: + def save_binary(self, fname: os.PathLike, silent: bool = True) -> None: """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded by providing the path to :py:func:`xgboost.DMatrix` as input. @@ -801,12 +804,12 @@ def save_binary(self, fname, silent=True) -> None: silent : bool (optional; default: True) If set, the output is suppressed. """ - fname = os.fspath(os.path.expanduser(fname)) + fname_str = os.fspath(os.path.expanduser(fname)) _check_call(_LIB.XGDMatrixSaveBinary(self.handle, - c_str(fname), + c_str(fname_str), ctypes.c_int(silent))) - def set_label(self, label) -> None: + def set_label(self, label: array_like) -> None: """Set label of dmatrix Parameters @@ -817,7 +820,7 @@ def set_label(self, label) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, label, 'label', 'float') - def set_weight(self, weight) -> None: + def set_weight(self, weight: array_like) -> None: """Set weight of each instance. Parameters @@ -836,7 +839,7 @@ def set_weight(self, weight) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, weight, 'weight', 'float') - def set_base_margin(self, margin) -> None: + def set_base_margin(self, margin: array_like) -> None: """Set base margin of booster to start from. This can be used to specify a prediction value of existing model to be @@ -853,7 +856,7 @@ def set_base_margin(self, margin) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, margin, 'base_margin', 'float') - def set_group(self, group) -> None: + def set_group(self, group: array_like) -> None: """Set group size of DMatrix (used for ranking). 
Parameters @@ -991,31 +994,27 @@ def feature_names(self, feature_names: FeatNamesT) -> None: """ if feature_names is not None: # validate feature name - try: - if not isinstance(feature_names, str): - feature_names = list(feature_names) - else: - feature_names = [feature_names] - except TypeError: - feature_names = [feature_names] - - if len(feature_names) != len(set(feature_names)): + if not isinstance(feature_names, str): + feature_names_list = list(feature_names) + else: + feature_names_list = [feature_names] + if len(feature_names_list) != len(set(feature_names_list)): raise ValueError('feature_names must be unique') - if len(feature_names) != self.num_col() and self.num_col() != 0: + if len(feature_names_list) != self.num_col() and self.num_col() != 0: msg = ("feature_names must have the same length as data, ", f"expected {self.num_col()}, got {len(feature_names)}") raise ValueError(msg) # prohibit to use symbols may affect to parse. e.g. []< if not all(isinstance(f, str) and not any(x in f for x in set(('[', ']', '<'))) - for f in feature_names): + for f in feature_names_list): raise ValueError('feature_names must be string, and may not contain [, ] or <') - c_feature_names = [bytes(f, encoding='utf-8') for f in feature_names] - c_feature_names = (ctypes.c_char_p * - len(c_feature_names))(*c_feature_names) + c_feature_names = [bytes(f, encoding='utf-8') for f in feature_names_list] + c_feature_n_ptrs = (ctypes.c_char_p * + len(c_feature_names))(*c_feature_names) _check_call(_LIB.XGDMatrixSetStrFeatureInfo( self.handle, c_str('feature_name'), - c_feature_names, + c_feature_n_ptrs, c_bst_ulong(len(feature_names)))) else: # reset feature_types also @@ -1074,14 +1073,13 @@ def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None: else: feature_types = [feature_types] except TypeError: - feature_types = [feature_types] - c_feature_types = [bytes(f, encoding='utf-8') - for f in feature_types] - c_feature_types = (ctypes.c_char_p * - len(c_feature_types))(*c_feature_types) + feature_types = [feature_types] # type: ignore + assert feature_types is not None + c_feature_types = [bytes(str(f), encoding='utf-8') for f in feature_types] + c_feature_t_ptrs = (ctypes.c_char_p * len(c_feature_types))(*c_feature_types) _check_call(_LIB.XGDMatrixSetStrFeatureInfo( self.handle, c_str('feature_type'), - c_feature_types, + c_feature_t_ptrs, c_bst_ulong(len(feature_types)))) if len(feature_types) != self.num_col() and self.num_col() != 0: @@ -1102,11 +1100,11 @@ class _ProxyDMatrix(DMatrix): """ - def __init__(self): # pylint: disable=super-init-not-called + def __init__(self) -> None: # pylint: disable=super-init-not-called self.handle = ctypes.c_void_p() _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle))) - def _set_data_from_cuda_interface(self, data) -> None: + def _set_data_from_cuda_interface(self, data: CuArrayLike) -> None: """Set data from CUDA array interface.""" interface = data.__cuda_array_interface__ interface_str = bytes(json.dumps(interface, indent=2), "utf-8") @@ -1114,14 +1112,14 @@ def _set_data_from_cuda_interface(self, data) -> None: _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str) ) - def _set_data_from_cuda_columnar(self, data, cat_codes: list) -> None: + def _set_data_from_cuda_columnar(self, data: CuDFLike) -> None: """Set data from CUDA columnar format.""" from .data import _cudf_array_interfaces interfaces_str = _cudf_array_interfaces(data, cat_codes) _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, 
interfaces_str)) - def _set_data_from_array(self, data: np.ndarray): + def _set_data_from_array(self, data: NPArrayLike) -> None: """Set data from numpy array.""" from .data import _array_interface @@ -1129,7 +1127,7 @@ def _set_data_from_array(self, data: np.ndarray): _LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data)) ) - def _set_data_from_csr(self, csr): + def _set_data_from_csr(self, csr: CSRLike) -> None: """Set data from scipy csr""" from .data import _array_interface @@ -1159,24 +1157,24 @@ class DeviceQuantileDMatrix(DMatrix): @_deprecate_positional_args def __init__( # pylint: disable=super-init-not-called self, - data, - label=None, + data: Union[CuArrayLike, CuDFLike], + label: Optional[array_like] = None, *, - weight=None, - base_margin=None, - missing=None, - silent=False, + weight: Optional[array_like] = None, + base_margin: Optional[array_like] = None, + missing: Optional[float] = None, + silent: bool = False, feature_names: FeatNamesT = None, - feature_types=None, + feature_types: FeatureTypes = None, nthread: Optional[int] = None, max_bin: int = 256, - group=None, - qid=None, - label_lower_bound=None, - label_upper_bound=None, - feature_weights=None, + group: Optional[array_like] = None, + qid: Optional[array_like] = None, + label_lower_bound: Optional[array_like] = None, + label_upper_bound: Optional[array_like] = None, + feature_weights: Optional[array_like] = None, enable_categorical: bool = False, - ): + ) -> None: self.max_bin = max_bin self.missing = missing if missing is not None else np.nan self.nthread = nthread if nthread is not None else 1 @@ -1207,7 +1205,12 @@ def __init__( # pylint: disable=super-init-not-called enable_categorical=enable_categorical, ) - def _init(self, data, enable_categorical: bool, **meta) -> None: + def _init( + self, + data: Union[CuArrayLike, CuDFLike], + enable_categorical: bool, + **meta: Union[Optional[array_like], Optional[Sequence]] + ) -> None: from .data import ( _is_dlpack, _transform_dlpack, @@ -1220,7 +1223,8 @@ def _init(self, data, enable_categorical: bool, **meta) -> None: # it can't be transformed twice. data = _transform_dlpack(data) if _is_iter(data): - it = data + assert isinstance(data, DataIter) + it: DataIter = data else: it = SingleBatchInternalIter(data=data, **meta) @@ -1311,7 +1315,7 @@ def __init__( if not isinstance(d, DMatrix): raise TypeError(f'invalid cache item: {type(d).__name__}', cache) - dmats = c_array(ctypes.c_void_p, [d.handle for d in cache]) + dmats = c_array(ctypes.c_void_p, [cast(ctypes.c_void_p, d.handle) for d in cache]) self.handle: Optional[ctypes.c_void_p] = ctypes.c_void_p() _check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)), ctypes.byref(self.handle))) @@ -1372,6 +1376,9 @@ def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> s raise ValueError('Constrained features are not a subset of ' 'training data feature names') + if self.feature_names is None: + raise ValueError("feature names is not set for Booster.") + return '(' + ','.join([str(value.get(feature_name, 0)) for feature_name in self.feature_names]) + ')' @@ -1506,7 +1513,7 @@ def save_config(self) -> str: .. 
versionadded:: 1.0.0 ''' - json_string = ctypes.c_char_p() + c_json_string = ctypes.c_char_p() length = c_bst_ulong() _check_call(_LIB.XGBoosterSaveJsonConfig( self.handle, @@ -1529,7 +1536,7 @@ def load_config(self, config: str) -> None: def __copy__(self) -> "Booster": return self.__deepcopy__(None) - def __deepcopy__(self, _) -> "Booster": + def __deepcopy__(self, _: Optional[Dict]) -> "Booster": '''Return a copy of booster.''' return Booster(model_file=self) @@ -1578,7 +1585,7 @@ def attributes(self) -> Dict[str, str]: ctypes.byref(length), ctypes.byref(sarr))) attr_names = from_cstr_to_pystr(sarr, length) - return {n: self.attr(n) for n in attr_names} + return {n: cast(str, self.attr(n)) for n in attr_names} def set_attr(self, **kwargs: Optional[str]) -> None: """Set the attribute of the Booster. @@ -1592,9 +1599,9 @@ def set_attr(self, **kwargs: Optional[str]) -> None: if value is not None: if not isinstance(value, STRING_TYPES): raise ValueError("Set Attr only accepts string values") - value = c_str(str(value)) + value_c = c_str(str(value)) _check_call(_LIB.XGBoosterSetAttr( - self.handle, c_str(key), value)) + self.handle, c_str(key), value_c)) def _get_feature_info(self, field: str) -> Optional[List[str]]: length = c_bst_ulong() @@ -1613,10 +1620,10 @@ def _set_feature_info(self, features: Optional[List[str]], field: str) -> None: if features is not None: assert isinstance(features, list) c_feature_info = [bytes(f, encoding="utf-8") for f in features] - c_feature_info = (ctypes.c_char_p * len(c_feature_info))(*c_feature_info) + c_feature_i_ptrs = (ctypes.c_char_p * len(c_feature_info))(*c_feature_info) _check_call( _LIB.XGBoosterSetStrFeatureInfo( - self.handle, c_str(field), c_feature_info, c_bst_ulong(len(features)) + self.handle, c_str(field), c_feature_i_ptrs, c_bst_ulong(len(features)) ) ) else: @@ -1661,10 +1668,11 @@ def set_param(self, params, value=None): value of the specified parameter, when params is str key """ if isinstance(params, Mapping): - params = params.items() - elif isinstance(params, STRING_TYPES) and value is not None: - params = [(params, value)] - for key, val in params: + params_items = params.items() + elif isinstance(params, str): + assert value is not None + params_items = {params: value}.items() + for key, val in params_items: if val is not None: _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val)))) @@ -1757,7 +1765,9 @@ def eval_set( raise TypeError(f"expected string, got {type(d[1]).__name__}") self._validate_features(d[0]) - dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) + dmats = c_array( + ctypes.c_void_p, [cast(ctypes.c_void_p, d[0].handle) for d in evals] + ) evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) msg = ctypes.c_char_p() _check_call( @@ -1911,6 +1921,7 @@ def predict( if validate_features: self._validate_features(data) iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range) + assert iteration_range is not None args = { "type": 0, "training": training, @@ -1945,11 +1956,11 @@ def assign_type(t: int) -> None: ctypes.byref(preds) ) ) - return _prediction_output(shape, dims, preds, False) + return cast(np.ndarray, _prediction_output(shape, dims, preds, False)) def inplace_predict( self, - data: Any, + data: array_like, iteration_range: Tuple[int, int] = (0, 0), predict_type: str = "value", missing: float = np.nan, @@ -2096,8 +2107,8 @@ def inplace_predict( ): from .data import _transform_cupy_array - data = _transform_cupy_array(data) - interface_str = 
_cuda_array_interface(data) + data = _transform_cupy_array(cast(CuArrayLike, data)) + interface_str = _cuda_array_interface(cast(CuArrayLike, data)) _check_call( _LIB.XGBoosterPredictFromCudaArray( self.handle, @@ -2222,12 +2233,15 @@ def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: else: raise TypeError('Unknown file type: ', fname) - if self.attr("best_iteration") is not None: - self.best_iteration = int(self.attr("best_iteration")) - if self.attr("best_score") is not None: - self.best_score = float(self.attr("best_score")) - if self.attr("best_ntree_limit") is not None: - self.best_ntree_limit = int(self.attr("best_ntree_limit")) + attr = self.attr("best_iteration") + if attr is not None: + self.best_iteration = int(attr) + attr = self.attr("best_score") + if attr is not None: + self.best_score = float(attr) + attr = self.attr("best_ntree_limit") + if attr is not None: + self.best_ntree_limit = int(attr) def num_boosted_rounds(self) -> int: '''Get number of boosted rounds. For gblinear this is reset to 0 after @@ -2391,7 +2405,7 @@ def get_score( ) ) features_arr = from_cstr_to_pystr(features, n_out_features) - scores_arr = _prediction_output(shape, out_dim, scores, False) + scores_arr = cast(np.ndarray, _prediction_output(shape, out_dim, scores, False)) results: Dict[str, Union[float, List[float]]] = {} if len(scores_arr.shape) > 1 and scores_arr.shape[1] > 1: @@ -2430,10 +2444,10 @@ def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame: node_ids = [] fids = [] splits = [] - categories: List[Optional[float]] = [] - y_directs = [] - n_directs = [] - missings = [] + categories: List[Optional[Union[List, float]]] = [] + y_directs: List[Union[float, str]] = [] + n_directs: List[Union[float, str]] = [] + missings: List[Union[float, str]] = [] gains = [] covers = [] @@ -2474,9 +2488,9 @@ def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame: # categorical parse = fid[0].split(":") cats = parse[1][1:-1] # strip the {} - cats = cats.split(",") + cats_l = cats.split(",") splits.append(float("NAN")) - categories.append(cats if cats else None) + categories.append(cats_l if cats_l else None) else: raise ValueError("Failed to parse model text dump.") stats = re.split('=|,', fid[1]) @@ -2513,8 +2527,13 @@ def _validate_features(self, data: DMatrix) -> None: return if self.feature_names is None: - self.feature_names = data.feature_names - self.feature_types = data.feature_types + self._set_feature_info(data.feature_names, "feature_name") + self._set_feature_info(data.feature_types, "feature_type") + + if self.feature_names is None: + # Neither has feature names + return + if data.feature_names is None and self.feature_names is not None: raise ValueError( "training data did not have the following fields: " + @@ -2538,8 +2557,7 @@ def _validate_features(self, data: DMatrix) -> None: raise ValueError(msg.format(self.feature_names, data.feature_names)) def get_split_value_histogram( - self, - feature: str, + self, feature: str, fmap: Union[os.PathLike, str] = '', bins: Optional[int] = None, as_pandas: bool = True @@ -2550,7 +2568,7 @@ def get_split_value_histogram( ---------- feature: str The name of the feature. - fmap: str or os.PathLike (optional) + fmap : The name of feature map file. bin: int, default None The maximum number of bins. 
diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 6fe2c56eef7d..428361cb4034 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,7 +5,7 @@ import json import warnings import os -from typing import Any, Tuple, Callable, Optional, List, Union +from typing import Any, Tuple, Callable, Union, Optional, List, cast import numpy as np @@ -14,8 +14,13 @@ from .core import DataIter, _ProxyDMatrix, DMatrix, FeatNamesT from .compat import lazy_isinstance, DataFrame +from .typing import array_like, FloatT, CSRLike, NPArrayLike, DFLike, NativeInput +from .typing import CuArrayLike, CuDFLike, FeatureTypes, DTypeLike + c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name +_CtorReturnT = Tuple[ctypes.c_void_p, Optional[List[str]], FeatureTypes] + CAT_T = "c" # meta info that can be a matrix instead of vector. @@ -24,14 +29,14 @@ _matrix_meta = {"base_margin", "label"} -def _warn_unused_missing(data, missing): +def _warn_unused_missing(data: NativeInput, missing: FloatT) -> None: if (missing is not None) and (not np.isnan(missing)): warnings.warn( '`missing` is not used for current input data type:' + str(type(data)), UserWarning) -def _check_complex(data): +def _check_complex(data: array_like) -> None: '''Test whether data is complex using `dtype` attribute.''' complex_dtypes = (np.complex128, np.complex64, np.cfloat, np.cdouble, np.clongdouble) @@ -39,12 +44,12 @@ def _check_complex(data): raise ValueError('Complex data not supported') -def _check_data_shape(data: Any) -> None: - if hasattr(data, "shape") and len(data.shape) != 2: +def _check_data_shape(data: NativeInput) -> None: + if hasattr(data, "shape") and len(cast(array_like, data).shape) != 2: raise ValueError("Please reshape the input data into 2-dimensional matrix.") -def _is_scipy_csr(data): +def _is_scipy_csr(data: NativeInput) -> bool: try: import scipy except ImportError: @@ -53,24 +58,24 @@ def _is_scipy_csr(data): return isinstance(data, scipy.sparse.csr_matrix) -def _array_interface(data: np.ndarray) -> bytes: +def _array_interface(data: NPArrayLike) -> bytes: assert ( data.dtype.hasobject is False ), "Input data contains `object` dtype. Expecting numeric data." 
interface = data.__array_interface__ if "mask" in interface: - interface["mask"] = interface["mask"].__array_interface__ + interface["mask"] = cast(NPArrayLike, interface["mask"]).__array_interface__ interface_str = bytes(json.dumps(interface), "utf-8") return interface_str def _from_scipy_csr( - data, - missing, - nthread, + data: CSRLike, + missing: FloatT, + nthread: int, feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): + feature_types: FeatureTypes, +) -> _CtorReturnT: """Initialize data from a CSR matrix.""" if len(data.indices) != len(data.data): raise ValueError( @@ -95,7 +100,7 @@ def _from_scipy_csr( return handle, feature_names, feature_types -def _is_scipy_csc(data): +def _is_scipy_csc(data: NativeInput) -> bool: try: import scipy except ImportError: @@ -105,11 +110,11 @@ def _is_scipy_csc(data): def _from_scipy_csc( - data, - missing, + data: CSRLike, + missing: FloatT, feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): + feature_types: FeatureTypes, +) -> _CtorReturnT: if len(data.indices) != len(data.data): raise ValueError(f"length mismatch: {len(data.indices)} vs {len(data.data)}") _warn_unused_missing(data, missing) @@ -125,7 +130,7 @@ def _from_scipy_csc( return handle, feature_names, feature_types -def _is_scipy_coo(data): +def _is_scipy_coo(data: NativeInput) -> bool: try: import scipy except ImportError: @@ -134,22 +139,26 @@ def _is_scipy_coo(data): return isinstance(data, scipy.sparse.coo_matrix) -def _is_numpy_array(data): +def _is_numpy_array(data: NativeInput) -> bool: return isinstance(data, (np.ndarray, np.matrix)) -def _ensure_np_dtype(data, dtype) -> Tuple[np.ndarray, np.dtype]: +def _ensure_np_dtype( + data: np.ndarray, dtype: DTypeLike +) -> Tuple[np.ndarray, DTypeLike]: if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]: data = data.astype(np.float32, copy=False) dtype = np.float32 return data, dtype -def _maybe_np_slice(data: np.ndarray, dtype) -> np.ndarray: +def _maybe_np_slice( + data: Union[List[int], np.ndarray], dtype: DTypeLike +) -> np.ndarray: '''Handle numpy slice. This can be removed if we use __array_interface__. ''' try: - if not data.flags.c_contiguous: + if not cast(np.ndarray, data).flags.c_contiguous: data = np.array(data, copy=True, dtype=dtype) else: data = np.array(data, copy=False, dtype=dtype) @@ -160,12 +169,12 @@ def _maybe_np_slice(data: np.ndarray, dtype) -> np.ndarray: def _from_numpy_array( - data, - missing, - nthread, + data: np.ndarray, + missing: FloatT, + nthread: int, feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): + feature_types: FeatureTypes, +) -> _CtorReturnT: """Initialize data from a 2-D numpy matrix. 
""" @@ -190,7 +199,7 @@ def _from_numpy_array( return handle, feature_names, feature_types -def _is_pandas_df(data): +def _is_pandas_df(data: NativeInput) -> bool: try: import pandas as pd except ImportError: @@ -198,7 +207,7 @@ def _is_pandas_df(data): return isinstance(data, pd.DataFrame) -def _is_modin_df(data): +def _is_modin_df(data: NativeInput) -> bool: try: import modin.pandas as pd except ImportError: @@ -242,15 +251,16 @@ def _invalid_dataframe_dtype(data: Any) -> None: def _transform_pandas_df( - data: DataFrame, + df: DFLike, enable_categorical: bool, - feature_names: FeatNamesT = None, - feature_types: Optional[List[str]] = None, - meta: Optional[str] = None, - meta_type: Optional[str] = None, -) -> Tuple[np.ndarray, FeatNamesT, Optional[List[str]]]: + feature_names: Optional[List[str]] = None, + feature_types: FeatureTypes = None, + meta: str = None, + meta_type: DTypeLike = None, +) -> Tuple[np.ndarray, Optional[List[str]], FeatureTypes]: import pandas as pd from pandas.api.types import is_sparse, is_categorical_dtype + data = cast(pd.DataFrame, df) if not all( dtype.name in _pandas_dtype_mapper @@ -271,14 +281,15 @@ def _transform_pandas_df( # handle feature types if feature_types is None and meta is None: - feature_types = [] + feature_types_lst = [] for i, dtype in enumerate(data.dtypes): if is_sparse(dtype): - feature_types.append(_pandas_dtype_mapper[dtype.subtype.name]) + feature_types_lst.append(_pandas_dtype_mapper[dtype.subtype.name]) elif is_categorical_dtype(dtype) and enable_categorical: - feature_types.append(CAT_T) + feature_types_lst.append(CAT_T) else: - feature_types.append(_pandas_dtype_mapper[dtype.name]) + feature_types_lst.append(_pandas_dtype_mapper[dtype.name]) + feature_types = feature_types_lst # handle category codes. 
transformed = pd.DataFrame() @@ -308,20 +319,22 @@ def _transform_pandas_df( def _from_pandas_df( - data: DataFrame, + data: DFLike, enable_categorical: bool, - missing: float, + missing: FloatT, nthread: int, feature_names: FeatNamesT, - feature_types: Optional[List[str]], -) -> Tuple[ctypes.c_void_p, FeatNamesT, Optional[List[str]]]: - data, feature_names, feature_types = _transform_pandas_df( + feature_types: FeatureTypes, +) -> _CtorReturnT: + arr, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types ) - return _from_numpy_array(data, missing, nthread, feature_names, feature_types) + return _from_numpy_array( + arr, missing, nthread, feature_names, feature_types + ) -def _is_pandas_series(data): +def _is_pandas_series(data: NativeInput) -> bool: try: import pandas as pd except ImportError: @@ -341,7 +354,7 @@ def _meta_from_pandas_series( _meta_from_numpy(data, name, dtype, handle) -def _is_modin_series(data): +def _is_modin_series(data: NativeInput) -> bool: try: import modin.pandas as pd except ImportError: @@ -355,8 +368,8 @@ def _from_pandas_series( nthread: int, enable_categorical: bool, feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): + feature_types: FeatureTypes, +) -> _CtorReturnT: from pandas.api.types import is_categorical_dtype if (data.dtype.name not in _pandas_dtype_mapper) and not ( @@ -365,6 +378,7 @@ def _from_pandas_series( _invalid_dataframe_dtype(data) if enable_categorical and is_categorical_dtype(data.dtype): data = data.cat.codes + return _from_numpy_array( data.values.reshape(data.shape[0], 1).astype("float"), missing, @@ -374,7 +388,7 @@ def _from_pandas_series( ) -def _is_dt_df(data): +def _is_dt_df(data: NativeInput) -> bool: return lazy_isinstance(data, 'datatable', 'Frame') or \ lazy_isinstance(data, 'datatable', 'DataTable') @@ -384,12 +398,12 @@ def _is_dt_df(data): def _transform_dt_df( - data, - feature_names: FeatNamesT, - feature_types: Optional[List[str]], - meta=None, - meta_type=None, -): + data: Any, + feature_names: Optional[List[str]], + feature_types: FeatureTypes, + meta: str = None, + meta_type: DTypeLike = None, +) -> Tuple[np.ndarray, Optional[List[str]], FeatureTypes]: """Validate feature names and types if data table""" if meta and data.shape[1] > 1: raise ValueError('DataTable for meta info cannot have multiple columns') @@ -423,13 +437,13 @@ def _transform_dt_df( def _from_dt_df( - data, - missing, - nthread, - feature_names: FeatNamesT, - feature_types: Optional[List[str]], + data: Any, + missing: FloatT, + nthread: int, + feature_names: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool, -) -> Tuple[ctypes.c_void_p, FeatNamesT, Optional[List[str]]]: +) -> _CtorReturnT: if enable_categorical: raise ValueError("categorical data in datatable is not supported yet.") data, feature_names, feature_types = _transform_dt_df( @@ -466,7 +480,7 @@ def _from_dt_df( return handle, feature_names, feature_types -def _is_cudf_df(data): +def _is_cudf_df(data: NativeInput) -> bool: try: import cudf except ImportError: @@ -474,7 +488,7 @@ def _is_cudf_df(data): return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) -def _cudf_array_interfaces(data, cat_codes: list) -> bytes: +def _cudf_array_interfaces(data: CuDFLike, cat_codes: list) -> bytes: """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of data and a list of array interfaces. 
The data is list of categorical codes that caller can safely ignore, but have to keep their reference alive until usage of array @@ -485,6 +499,7 @@ def _cudf_array_interfaces(data, cat_codes: list) -> bytes: from cudf.api.types import is_categorical_dtype except ImportError: from cudf.utils.dtypes import is_categorical_dtype + import cudf interfaces = [] if _is_cudf_ser(data): @@ -510,11 +525,12 @@ def _cudf_array_interfaces(data, cat_codes: list) -> bytes: def _transform_cudf_df( - data, + data: CuDFLike, feature_names: FeatNamesT, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool, -): +) -> Tuple[CuDFLike, Optional[List[str]], FeatureTypes]: + import cudf try: from cudf.api.types import is_categorical_dtype except ImportError: @@ -553,7 +569,7 @@ def _transform_cudf_df( feature_types = [] for dtype in dtypes: if is_categorical_dtype(dtype) and enable_categorical: - feature_types.append(CAT_T) + feature_types_lst.append(CAT_T) else: feature_types.append(_pandas_dtype_mapper[dtype.name]) @@ -574,13 +590,13 @@ def _transform_cudf_df( def _from_cudf_df( - data, - missing, - nthread, + data: CuDFLike, + missing: FloatT, + nthread: int, feature_names: FeatNamesT, feature_types: Optional[List[str]], enable_categorical: bool, -) -> Tuple[ctypes.c_void_p, Any, Any]: +) -> _CtorReturnT: data, cat_codes, feature_names, feature_types = _transform_cudf_df( data, feature_names, feature_types, enable_categorical ) @@ -597,7 +613,7 @@ def _from_cudf_df( return handle, feature_names, feature_types -def _is_cudf_ser(data): +def _is_cudf_ser(data: NativeInput) -> bool: try: import cudf except ImportError: @@ -605,7 +621,7 @@ def _is_cudf_ser(data): return isinstance(data, cudf.Series) -def _is_cupy_array(data): +def _is_cupy_array(data: NativeInput) -> bool: try: import cupy except ImportError: @@ -613,8 +629,9 @@ def _is_cupy_array(data): return isinstance(data, cupy.ndarray) -def _transform_cupy_array(data): +def _transform_cupy_array(arr: CuArrayLike) -> CuArrayLike: import cupy # pylint: disable=import-error + data = cast(cupy.ndarray, arr) if not hasattr(data, '__cuda_array_interface__') and hasattr( data, '__array__'): data = cupy.array(data, copy=False) @@ -624,12 +641,12 @@ def _transform_cupy_array(data): def _from_cupy_array( - data, - missing, - nthread, - feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): + data: CuArrayLike, + missing: FloatT, + nthread: int, + feature_names: Optional[List[str]], + feature_types: FeatureTypes, +) -> _CtorReturnT: """Initialize DMatrix from cupy ndarray.""" data = _transform_cupy_array(data) interface_str = _cuda_array_interface(data) @@ -643,7 +660,7 @@ def _from_cupy_array( return handle, feature_names, feature_types -def _is_cupy_csr(data): +def _is_cupy_csr(data: NativeInput) -> bool: try: import cupyx except ImportError: @@ -651,7 +668,7 @@ def _is_cupy_csr(data): return isinstance(data, cupyx.scipy.sparse.csr_matrix) -def _is_cupy_csc(data): +def _is_cupy_csc(data: NativeInput) -> bool: try: import cupyx except ImportError: @@ -659,11 +676,11 @@ def _is_cupy_csc(data): return isinstance(data, cupyx.scipy.sparse.csc_matrix) -def _is_dlpack(data): +def _is_dlpack(data: NativeInput) -> bool: return 'PyCapsule' in str(type(data)) and "dltensor" in str(data) -def _transform_dlpack(data): +def _transform_dlpack(data: Any) -> CuArrayLike: from cupy import fromDlpack # pylint: disable=E0401 assert 'used_dltensor' not in str(data) data = fromDlpack(data) @@ -671,27 +688,27 @@ def 
_transform_dlpack(data): def _from_dlpack( - data, - missing, - nthread, - feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): + data: Any, + missing: FloatT, + nthread: int, + feature_names: Optional[List[str]], + feature_types: FeatureTypes, +) -> _CtorReturnT: data = _transform_dlpack(data) return _from_cupy_array(data, missing, nthread, feature_names, feature_types) -def _is_uri(data): +def _is_uri(data: NativeInput) -> bool: return isinstance(data, (str, os.PathLike)) def _from_uri( - data, - missing, + data: Union[os.PathLike, str], + missing: FloatT, feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): + feature_types: FeatureTypes, +) -> _CtorReturnT: _warn_unused_missing(data, missing) handle = ctypes.c_void_p() data = os.fspath(os.path.expanduser(data)) @@ -701,45 +718,45 @@ def _from_uri( return handle, feature_names, feature_types -def _is_list(data): +def _is_list(data: NativeInput) -> bool: return isinstance(data, list) def _from_list( - data, - missing, - n_threads, + data: list, + missing: FloatT, + n_threads: int, feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): + feature_types: FeatureTypes, +) -> _CtorReturnT: array = np.array(data) - _check_data_shape(data) + _check_data_shape(array) return _from_numpy_array(array, missing, n_threads, feature_names, feature_types) -def _is_tuple(data): +def _is_tuple(data: Any) -> bool: return isinstance(data, tuple) def _from_tuple( - data, - missing, - n_threads, + data: tuple, + missing: FloatT, + n_threads: int, feature_names: FeatNamesT, - feature_types: Optional[List[str]], -): - return _from_list(data, missing, n_threads, feature_names, feature_types) + feature_types: FeatureTypes, +) -> _CtorReturnT: + return _from_list(cast(list, data), missing, n_threads, feature_names, feature_types) -def _is_iter(data): +def _is_iter(data: Any) -> bool: return isinstance(data, DataIter) -def _has_array_protocol(data): +def _has_array_protocol(data: Any) -> bool: return hasattr(data, '__array__') -def _convert_unknown_data(data): +def _convert_unknown_data(data: Any) -> Optional[Any]: warnings.warn( f'Unknown data type: {type(data)}, trying to convert it to csr_matrix', UserWarning @@ -758,47 +775,59 @@ def _convert_unknown_data(data): def dispatch_data_backend( - data, - missing, - threads, + data: NativeInput, + missing: FloatT, + threads: int, feature_names: FeatNamesT, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool = False, -): +) -> _CtorReturnT: '''Dispatch data for DMatrix.''' if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) if _is_scipy_csr(data): - return _from_scipy_csr(data, missing, threads, feature_names, feature_types) + return _from_scipy_csr( + cast(CSRLike, data), missing, threads, feature_names, feature_types + ) if _is_scipy_csc(data): - return _from_scipy_csc(data, missing, feature_names, feature_types) + return _from_scipy_csc(cast(CSRLike, data), missing, feature_names, feature_types) if _is_scipy_coo(data): + from scipy.sparse import coo_matrix return _from_scipy_csr( - data.tocsr(), missing, threads, feature_names, feature_types + cast(coo_matrix, data).tocsr(), missing, threads, feature_names, feature_types ) if _is_numpy_array(data): - return _from_numpy_array(data, missing, threads, feature_names, - feature_types) + return _from_numpy_array( + cast(np.ndarray, data), missing, threads, feature_names, feature_types + ) if _is_uri(data): - return _from_uri(data, missing, feature_names, 
feature_types) + return _from_uri( + cast(Union[os.PathLike, str], data), missing, feature_names, feature_types + ) if _is_list(data): - return _from_list(data, missing, threads, feature_names, feature_types) + return _from_list( + cast(list, data), missing, threads, feature_names, feature_types + ) if _is_tuple(data): - return _from_tuple(data, missing, threads, feature_names, feature_types) + return _from_tuple( + cast(tuple, data), missing, threads, feature_names, feature_types + ) if _is_pandas_df(data): - return _from_pandas_df(data, enable_categorical, missing, threads, + return _from_pandas_df(cast(DFLike, data), enable_categorical, missing, threads, feature_names, feature_types) if _is_pandas_series(data): return _from_pandas_series( data, missing, threads, enable_categorical, feature_names, feature_types ) if _is_cudf_df(data) or _is_cudf_ser(data): + df = cast(CuDFLike, data) return _from_cudf_df( - data, missing, threads, feature_names, feature_types, enable_categorical + df, missing, threads, feature_names, feature_types, enable_categorical ) if _is_cupy_array(data): - return _from_cupy_array(data, missing, threads, feature_names, - feature_types) + return _from_cupy_array( + cast(CuArrayLike, data), missing, threads, feature_names, feature_types + ) if _is_cupy_csr(data): raise TypeError('cupyx CSR is not supported yet.') if _is_cupy_csc(data): @@ -812,8 +841,10 @@ def dispatch_data_backend( data, missing, threads, feature_names, feature_types, enable_categorical ) if _is_modin_df(data): - return _from_pandas_df(data, enable_categorical, missing, threads, - feature_names, feature_types) + return _from_pandas_df( + cast(DFLike, data), enable_categorical, missing, threads, + feature_names, feature_types + ) if _is_modin_series(data): return _from_pandas_series( data, missing, threads, enable_categorical, feature_names, feature_types @@ -829,7 +860,7 @@ def dispatch_data_backend( raise TypeError('Not supported type for data.' 
+ str(type(data))) -def _to_data_type(dtype: str, name: str): +def _to_data_type(dtype: str, name: str) -> int: dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4} if dtype not in dtype_map: raise TypeError( @@ -838,7 +869,7 @@ def _to_data_type(dtype: str, name: str): return dtype_map[dtype] -def _validate_meta_shape(data: Any, name: str) -> None: +def _validate_meta_shape(data: array_like, name: str) -> None: if hasattr(data, "shape"): msg = f"Invalid shape: {data.shape} for {name}" if name in _matrix_meta: @@ -866,13 +897,17 @@ def _meta_from_numpy( _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface_str)) -def _meta_from_list(data, field, dtype, handle): - data = np.array(data) - _meta_from_numpy(data, field, dtype, handle) +def _meta_from_list( + data: list, field: str, dtype: DTypeLike, handle: ctypes.c_void_p +) -> None: + arr = np.asarray(data) + _meta_from_numpy(arr, field, dtype, handle) -def _meta_from_tuple(data, field, dtype, handle): - return _meta_from_list(data, field, dtype, handle) +def _meta_from_tuple( + data: tuple, field: str, dtype: DTypeLike, handle: ctypes.c_void_p +) -> None: + return _meta_from_list(cast(list, data), field, dtype, handle) def _meta_from_cudf_df(data, field: str, handle: ctypes.c_void_p) -> None: @@ -884,7 +919,7 @@ def _meta_from_cudf_df(data, field: str, handle: ctypes.c_void_p) -> None: _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface)) -def _meta_from_cudf_series(data, field, handle): +def _meta_from_cudf_series(data: CuDFLike, field: str, handle: ctypes.c_void_p) -> None: interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), 'utf-8') _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, @@ -892,7 +927,7 @@ def _meta_from_cudf_series(data, field, handle): interface)) -def _meta_from_cupy_array(data, field, handle): +def _meta_from_cupy_array(data: CuArrayLike, field: str, handle: ctypes.c_void_p) -> None: data = _transform_cupy_array(data) interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), 'utf-8') @@ -916,18 +951,19 @@ def dispatch_meta_backend( if data is None: return if _is_list(data): - _meta_from_list(data, name, dtype, handle) + _meta_from_list(cast(list, data), name, dtype, handle) return if _is_tuple(data): - _meta_from_tuple(data, name, dtype, handle) + _meta_from_tuple(cast(tuple, data), name, dtype, handle) return if _is_numpy_array(data): - _meta_from_numpy(data, name, dtype, handle) + _meta_from_numpy(cast(np.ndarray, data), name, dtype, handle) return if _is_pandas_df(data): - data, _, _ = _transform_pandas_df(data, False, meta=name, - meta_type=dtype) - _meta_from_numpy(data, name, dtype, handle) + arr, _, _ = _transform_pandas_df( + cast(DFLike, data), False, meta=name, meta_type=dtype + ) + _meta_from_numpy(arr, name, dtype, handle) return if _is_pandas_series(data): _meta_from_pandas_series(data, name, dtype, handle) @@ -937,13 +973,13 @@ def dispatch_meta_backend( _meta_from_cupy_array(data, name, handle) return if _is_cupy_array(data): - _meta_from_cupy_array(data, name, handle) + _meta_from_cupy_array(cast(CuArrayLike, data), name, handle) return if _is_cudf_ser(data): - _meta_from_cudf_series(data, name, handle) + _meta_from_cudf_series(cast(CuDFLike, data), name, handle) return if _is_cudf_df(data): - _meta_from_cudf_df(data, name, handle) + _meta_from_cudf_df(cast(CuDFLike, data), name, handle) return if _is_dt_df(data): _meta_from_dt(data, name, dtype, handle) @@ -953,9 +989,9 @@ def dispatch_meta_backend( 
_meta_from_numpy(data, name, dtype, handle) return if _is_modin_series(data): - data = data.values.astype('float') - assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1 - _meta_from_numpy(data, name, dtype, handle) + arr = cast(np.ndarray, cast(DFLike, data).values).astype('float') + assert len(arr.shape) == 1 or arr.shape[1] == 0 or arr.shape[1] == 1 + _meta_from_numpy(arr, name, dtype, handle) return if _has_array_protocol(data): array = np.asarray(data) @@ -992,10 +1028,10 @@ def _proxy_transform( feature_names: FeatNamesT, feature_types: Optional[List[str]], enable_categorical: bool, -): +) -> Tuple[array_like, Optional[List[str]], FeatureTypes]: if _is_cudf_df(data) or _is_cudf_ser(data): return _transform_cudf_df( - data, feature_names, feature_types, enable_categorical + cast(CuDFLike, data), feature_names, feature_types, enable_categorical ) if _is_cupy_array(data): data = _transform_cupy_array(data) @@ -1008,7 +1044,7 @@ def _proxy_transform( return data, None, feature_names, feature_types if _is_pandas_df(data): arr, feature_names, feature_types = _transform_pandas_df( - data, enable_categorical, feature_names, feature_types + cast(DFLike, data), enable_categorical, feature_names, feature_types ) return arr, None, feature_names, feature_types raise TypeError("Value type is not supported for data iterator:" + str(type(data))) @@ -1033,7 +1069,8 @@ def dispatch_proxy_set_data( proxy._set_data_from_cuda_columnar(data, cat_codes) return if _is_cupy_array(data): - proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212 + # pylint: disable=W0212 + proxy._set_data_from_cuda_interface(cast(CuArrayLike, data)) return if _is_dlpack(data): data = _transform_dlpack(data) @@ -1046,9 +1083,9 @@ def dispatch_proxy_set_data( raise err if _is_numpy_array(data): - proxy._set_data_from_array(data) # pylint: disable=W0212 + proxy._set_data_from_array(cast(NPArrayLike, data)) # pylint: disable=W0212 return if _is_scipy_csr(data): - proxy._set_data_from_csr(data) # pylint: disable=W0212 + proxy._set_data_from_csr(cast(CSRLike, data)) # pylint: disable=W0212 return raise err diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 7d12a657f3a1..6ed7edd7d1f1 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -4,8 +4,7 @@ import warnings import json import os -from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type, cast -from typing import Sequence +from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type, cast, Sequence import numpy as np from .core import Booster, DMatrix, XGBoostError @@ -14,6 +13,7 @@ from .training import train from .callback import TrainingCallback from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array +from .typing import array_like, CuDFLike # Do not use class names on scikit-learn directly. Re-define the classes on # .compat to guarantee the behavior without scikit-learn @@ -25,8 +25,6 @@ XGBoostLabelEncoder, ) -array_like = Any - class XGBRankerMixIn: # pylint: disable=too-few-public-methods """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base @@ -1049,7 +1047,7 @@ def predict( if _is_cupy_array(predts): import cupy # pylint: disable=import-error predts = cupy.asnumpy(predts) # ensure numpy array is used. 
-            return predts
+            return cast(np.ndarray, predts)
         except TypeError:
             # coo, csc, dt
             pass
@@ -1329,7 +1327,7 @@ def fit(
         if _is_cudf_df(y) or _is_cudf_ser(y):
             import cupy as cp  # pylint: disable=E0401
 
-            self.classes_ = cp.unique(y.values)
+            self.classes_ = cp.unique(cast(CuDFLike, y).values)
             self.n_classes_ = len(self.classes_)
             expected_classes = cp.arange(self.n_classes_)
         elif _is_cupy_array(y):
diff --git a/python-package/xgboost/typing.py b/python-package/xgboost/typing.py
new file mode 100644
index 000000000000..fa7e102e3f1f
--- /dev/null
+++ b/python-package/xgboost/typing.py
@@ -0,0 +1,121 @@
+from abc import abstractproperty
+import os
+from typing import Union, Dict, Iterable, List, Tuple, Optional
+from typing import SupportsIndex, Sized
+
+import numpy as np
+try:
+    from numpy import typing as npt
+    DTypeLike = npt.DTypeLike
+except ImportError:
+    DTypeLike = np.dtype  # type: ignore
+
+try:
+    from typing import Protocol
+except ImportError:
+    Protocol = object  # type: ignore
+
+
+# `Protocol` must come last in the bases so that the `Protocol = object`
+# fallback above still yields a valid MRO on Python versions without Protocol.
+class NPArrayLike(Sized, Protocol):
+    def __array__(self) -> np.ndarray:
+        ...
+
+    @abstractproperty
+    def __array_interface__(self) -> Dict[str, Union[str, int, Dict, "NPArrayLike"]]:
+        ...
+
+    @abstractproperty
+    def shape(self) -> Tuple[int, ...]:
+        ...
+
+    @abstractproperty
+    def dtype(self) -> np.dtype:
+        ...
+
+
+class CuArrayLike(Protocol):
+    @abstractproperty
+    def __cuda_array_interface__(self) -> Dict[str, Union[str, int, Dict, "CuArrayLike"]]:
+        ...
+
+    @abstractproperty
+    def shape(self) -> Tuple[int, ...]:
+        ...
+
+    @abstractproperty
+    def dtype(self) -> np.dtype:
+        ...
+
+    def reshape(self, *s: SupportsIndex) -> "CuArrayLike":
+        ...
+
+
+class CSRLike(Protocol):
+    @abstractproperty
+    def shape(self) -> Tuple[int, ...]:
+        ...
+
+    @abstractproperty
+    def dtype(self) -> np.dtype:
+        ...
+
+    indptr: NPArrayLike
+    indices: NPArrayLike
+    data: NPArrayLike
+
+
+class DFLike(Protocol):
+    @abstractproperty
+    def values(self) -> NPArrayLike:
+        ...
+
+    @abstractproperty
+    def shape(self) -> Tuple[int, ...]:
+        ...
+
+    @abstractproperty
+    def dtype(self) -> np.dtype:
+        ...
+
+    @abstractproperty
+    def dtypes(self) -> Tuple[np.dtype, ...]:
+        ...
+
+
+class CuDFLike(Iterable, Protocol):
+    @abstractproperty
+    def values(self) -> CuArrayLike:
+        ...
+
+    @abstractproperty
+    def __cuda_array_interface__(self) -> Dict[str, Union[str, int, Dict, "CuArrayLike"]]:
+        ...
+
+    @abstractproperty
+    def shape(self) -> Tuple[int, ...]:
+        ...
+
+    @abstractproperty
+    def name(self) -> str:
+        ...
+
+    @abstractproperty
+    def dtype(self) -> np.dtype:
+        ...
+
+    @abstractproperty
+    def dtypes(self) -> Tuple[np.dtype, ...]:
+        ...
+
+    def __getitem__(self, key: str) -> "CuDFLike":
+        ...
+
+
+FloatT = Union[float, np.float16, np.float32, np.float64]
+
+
+array_like = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike]
+NativeInput = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike, str, os.PathLike]
+
+
+FeatureTypes = Optional[Union[List[str], List[DTypeLike]]]

From de2d310b13314de3e9c943627a3b8e6770352618 Mon Sep 17 00:00:00 2001
From: fis
Date: Thu, 17 Feb 2022 04:21:16 +0800
Subject: [PATCH 02/17] some fixes.
--- python-package/setup.py | 5 ++-- python-package/xgboost/core.py | 43 +++++++++++++++++++++------------- python-package/xgboost/data.py | 4 ++-- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/python-package/setup.py b/python-package/setup.py index d6b3104798d8..00fd5013faf0 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -3,7 +3,6 @@ import shutil import subprocess import logging -import distutils from typing import Optional, List import sys from platform import system @@ -51,11 +50,11 @@ def lib_name() -> str: def copy_tree(src_dir: str, target_dir: str) -> None: '''Copy source tree into build directory.''' def clean_copy_tree(src: str, dst: str) -> None: - distutils.dir_util.copy_tree(src, dst) + shutil.copytree(src, dst) NEED_CLEAN_TREE.add(os.path.abspath(dst)) def clean_copy_file(src: str, dst: str) -> None: - distutils.file_util.copy_file(src, dst) + shutil.copy(src, dst) NEED_CLEAN_FILE.add(os.path.abspath(dst)) src = os.path.join(src_dir, 'src') diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index ab6053a25cfa..881ee24265e6 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -30,13 +30,14 @@ # xgboost accepts some other possible types in practice due to historical reason, which # is lesser tested. For now we encourage users to pass a simple list of string. FeatNamesT = Optional[List[str]] +Parameters = Union[List[Tuple[str, Any]], Dict[str, Any]] class XGBoostError(ValueError): """Error thrown by xgboost trainer.""" -def from_pystr_to_cstr(data: Union[str, List[str]]): +def from_pystr_to_cstr(data: Union[str, List[str]]) -> bytes: """Convert a Python str or list of Python str to C pointer Parameters @@ -320,8 +321,8 @@ class DataIter: # pylint: disable=too-many-instance-attributes Parameters ---------- cache_prefix: - Prefix to the cache files, only used in external memory. It can be either an URI - or a file path. + Prefix to the cache files, only used in external memory. It can be either an + URI or a file path. 
""" _T = TypeVar("_T") @@ -521,6 +522,7 @@ def __init__( weight: Optional[array_like] = None, base_margin: Optional[array_like] = None, missing: Optional[float] = None, + silent: bool = True, feature_names: FeatNamesT = None, feature_types: FeatureTypes = None, nthread: Optional[int] = None, @@ -1112,7 +1114,9 @@ def _set_data_from_cuda_interface(self, data: CuArrayLike) -> None: _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str) ) - def _set_data_from_cuda_columnar(self, data: CuDFLike) -> None: + def _set_data_from_cuda_columnar( + self, data: CuDFLike, cat_codes: List[CuDFLike] + ) -> None: """Set data from CUDA columnar format.""" from .data import _cudf_array_interfaces @@ -1365,7 +1369,8 @@ def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]: params = list(params.items()) for eval_metric in eval_metrics: params += [('eval_metric', eval_metric)] - return params + params_dict = {k: v for k, v in params} + return params_dict def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str: if isinstance(value, str): @@ -1518,9 +1523,9 @@ def save_config(self) -> str: _check_call(_LIB.XGBoosterSaveJsonConfig( self.handle, ctypes.byref(length), - ctypes.byref(json_string))) - assert json_string.value is not None - result = json_string.value.decode() # pylint: disable=no-member + ctypes.byref(c_json_string))) + assert c_json_string.value is not None + result = c_json_string.value.decode() # pylint: disable=no-member return result def load_config(self, config: str) -> None: @@ -1657,7 +1662,7 @@ def feature_names(self) -> Optional[List[str]]: def feature_names(self, features: FeatNamesT) -> None: self._set_feature_info(features, "feature_name") - def set_param(self, params, value=None): + def set_param(self, params: Union[str, Parameters], value: Any = None) -> None: """Set parameters into the Booster. Parameters @@ -1967,7 +1972,7 @@ def inplace_predict( validate_features: bool = True, base_margin: Any = None, strict_shape: bool = False - ): + ) -> Union[np.ndarray, CuArrayLike]: """Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction does not cache the prediction result. @@ -2260,7 +2265,13 @@ def num_features(self) -> int: _check_call(_LIB.XGBoosterGetNumFeature(self.handle, ctypes.byref(features))) return features.value - def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"): + def dump_model( + self, + fout: Union[os.PathLike, TextIO, str], + fmap: Union[str, os.PathLike] = '', + with_stats: bool = False, + dump_format: str = "text" + ) -> None: """Dump model into a text or JSON file. Unlike :py:meth:`save_model`, the output format is primarily used for visualization or interpretation, hence it's more human readable but cannot be loaded back to XGBoost. 
@@ -2595,10 +2606,10 @@ def get_split_value_histogram( bins = max(min(n_unique, bins) if bins is not None else n_unique, 1) nph = np.histogram(values, bins=bins) - nph = np.column_stack((nph[1][1:], nph[0])) - nph = nph[nph[:, 1] > 0] + nph_stacked = np.column_stack((nph[1][1:], nph[0])) + nph_stacked = nph_stacked[nph_stacked[:, 1] > 0] - if nph.size == 0: + if nph_stacked.size == 0: ft = self.feature_types fn = self.feature_names if fn is None: @@ -2616,11 +2627,11 @@ def get_split_value_histogram( ) if as_pandas and PANDAS_INSTALLED: - return DataFrame(nph, columns=['SplitValue', 'Count']) + return DataFrame(nph_stacked, columns=['SplitValue', 'Count']) if as_pandas and not PANDAS_INSTALLED: warnings.warn( "Returning histogram as ndarray" " (as_pandas == True, but pandas is not installed).", UserWarning ) - return nph + return nph_stacked diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 428361cb4034..a9bee531d80a 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -1026,9 +1026,9 @@ def reset(self) -> None: def _proxy_transform( data, feature_names: FeatNamesT, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool, -) -> Tuple[array_like, Optional[List[str]], FeatureTypes]: +) -> Tuple[array_like, Optional[list], Optional[List[str]], FeatureTypes]: if _is_cudf_df(data) or _is_cudf_ser(data): return _transform_cudf_df( cast(CuDFLike, data), feature_names, feature_types, enable_categorical From fb1ec3e52f60f877aee48bb967b92626261b296f Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 04:49:01 +0800 Subject: [PATCH 03/17] Work on data. --- python-package/xgboost/core.py | 32 +++++++++-------- python-package/xgboost/data.py | 59 ++++++++++++++++++-------------- python-package/xgboost/typing.py | 6 +++- 3 files changed, 56 insertions(+), 41 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 881ee24265e6..3b9b78fe3eec 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -69,11 +69,14 @@ def from_cstr_to_pystr(data: ctypes.pointer, length: ctypes.c_uint64) -> List[st return res +IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int]) + + def _convert_ntree_limit( booster: "Booster", ntree_limit: Optional[int], - iteration_range: Optional[Tuple[int, int]] -) -> Optional[Tuple[int, int]]: + iteration_range: IterRange +) -> IterRange: if ntree_limit is not None and ntree_limit != 0: warnings.warn( "ntree_limit is deprecated, use `iteration_range` or model " @@ -241,7 +244,7 @@ def _cuda_array_interface(data: CuArrayLike) -> bytes: def ctypes2numpy(cptr: ctypes.pointer, length: int, dtype: DTypeLike) -> np.ndarray: """Convert a ctypes pointer array to a numpy array.""" ctype = _numpy2ctypes_type(dtype) - if not isinstance(cptr, ctypes.POINTER(ctype)): + if not isinstance(cptr, ctypes.POINTER(ctype)): # type:ignore raise RuntimeError(f"expected {ctype} pointer") res = np.zeros(length, dtype=dtype) if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): # type:ignore @@ -265,7 +268,7 @@ def ctypes2cupy(cptr: ctypes.pointer, length: int, dtype: DTypeLike) -> CuArrayL # The owner field is just used to keep the memory alive with ref count. As # unowned's life time is scoped within this function we don't need that. 
unownd = UnownedMemory( - addr, length * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]), owner=None + addr, length * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]), owner=None # type:ignore ) memptr = MemoryPointer(unownd, 0) # pylint: disable=unexpected-keyword-arg @@ -302,7 +305,7 @@ def c_array( def _prediction_output( shape: ctypes.pointer, dims: ctypes.c_uint64, predts: ctypes.pointer, is_cuda: bool -) -> Union[NPArrayLike, CuArrayLike]: +) -> Union[np.ndarray, Any]: arr_shape: np.ndarray = ctypes2numpy(shape, dims.value, np.uint64) length = int(np.prod(arr_shape)) if is_cuda: @@ -637,7 +640,7 @@ def __init__( if feature_names is not None: self.feature_names = feature_names if feature_types is not None: - self.feature_types = feature_types + self.feature_types = feature_types # type:ignore def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None: it = iterator @@ -705,7 +708,7 @@ def set_info( if feature_names is not None: self.feature_names = feature_names if feature_types is not None: - self.feature_types = feature_types + self.feature_types = feature_types # type:ignore if feature_weights is not None: dispatch_meta_backend(matrix=self, data=feature_weights, name='feature_weights') @@ -1047,7 +1050,7 @@ def feature_types(self) -> Optional[List[str]]: return res @feature_types.setter - def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None: + def feature_types(self, feature_types: FeatureTypes) -> None: """Set feature types (column types). This is for displaying the results and categorical data support. See doc string @@ -1115,7 +1118,7 @@ def _set_data_from_cuda_interface(self, data: CuArrayLike) -> None: ) def _set_data_from_cuda_columnar( - self, data: CuDFLike, cat_codes: List[CuDFLike] + self, data: CuDFLike, cat_codes: Optional[List[CuDFLike]] ) -> None: """Set data from CUDA columnar format.""" from .data import _cudf_array_interfaces @@ -1360,7 +1363,7 @@ def __init__( else: self.booster = 'gbtree' - def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]: + def _configure_metrics(self, params: Union[Dict, List]) -> Dict: if isinstance(params, dict) and 'eval_metric' in params \ and isinstance(params['eval_metric'], list): params = dict((k, v) for k, v in params.items()) @@ -1412,7 +1415,7 @@ def _transform_interaction_constraints( "Constrained features are not a subset of training data feature names" ) from e - def _configure_constraints(self, params: Union[Dict, List]) -> Union[Dict, List]: + def _configure_constraints(self, params: Union[Dict, List]) -> Dict: if isinstance(params, dict): value = params.get("monotone_constraints") if value: @@ -1436,6 +1439,7 @@ def _configure_constraints(self, params: Union[Dict, List]) -> Union[Dict, List] params[idx] = (name, self._transform_monotone_constrains(value)) elif name == "interaction_constraints": params[idx] = (name, self._transform_interaction_constraints(value)) + params = {k: v for k, v in params} return params @@ -1463,7 +1467,7 @@ def __setstate__(self, state: Dict) -> None: handle = state['handle'] if handle is not None: buf = handle - dmats = c_array(ctypes.c_void_p, []) + dmats = c_array(ctypes.c_void_p, []) # type:ignore handle = ctypes.c_void_p() _check_call(_LIB.XGBoosterCreate( dmats, c_bst_ulong(0), ctypes.byref(handle))) @@ -2073,7 +2077,7 @@ def inplace_predict( else: enable_categorical = any(f == "c" for f in ft) if _is_pandas_df(data): - data, _, _ = _transform_pandas_df(data, enable_categorical) + data, _, _ = 
_transform_pandas_df(cast(DataFrame, data), enable_categorical) if isinstance(data, np.ndarray): from .data import _ensure_np_dtype @@ -2129,7 +2133,7 @@ def inplace_predict( if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): from .data import _cudf_array_interfaces, _transform_cudf_df data, cat_codes, _, _ = _transform_cudf_df( - data, None, None, enable_categorical + cast(CuDFLike, data), None, None, enable_categorical ) interfaces_str = _cudf_array_interfaces(data, cat_codes) _check_call( diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index a9bee531d80a..1d00afc7ba38 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,7 +5,7 @@ import json import warnings import os -from typing import Any, Tuple, Callable, Union, Optional, List, cast +from typing import Any, Tuple, Callable, Union, Optional, List, cast, Sequence import numpy as np @@ -343,15 +343,17 @@ def _is_pandas_series(data: NativeInput) -> bool: def _meta_from_pandas_series( - data, name: str, dtype: Optional[str], handle: ctypes.c_void_p + data: DFLike, name: str, dtype: DTypeLike, handle: ctypes.c_void_p ) -> None: """Help transform pandas series for meta data like labels""" - data = data.values.astype('float') + import pandas as pd + df = cast(pd.Series, data) + arr: np.ndarray = df.values.astype('float') from pandas.api.types import is_sparse - if is_sparse(data): - data = data.to_dense() - assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1 - _meta_from_numpy(data, name, dtype, handle) + if is_sparse(arr): + arr = arr.to_dense() + assert len(arr.shape) == 1 or arr.shape[1] == 0 or arr.shape[1] == 1 + _meta_from_numpy(arr, name, dtype, handle) def _is_modin_series(data: NativeInput) -> bool: @@ -363,8 +365,8 @@ def _is_modin_series(data: NativeInput) -> bool: def _from_pandas_series( - data, - missing: float, + data: DFLike, + missing: FloatT, nthread: int, enable_categorical: bool, feature_names: FeatNamesT, @@ -488,7 +490,7 @@ def _is_cudf_df(data: NativeInput) -> bool: return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) -def _cudf_array_interfaces(data: CuDFLike, cat_codes: list) -> bytes: +def _cudf_array_interfaces(data: CuDFLike, cat_codes: Optional[list]) -> bytes: """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of data and a list of array interfaces. 
The data is list of categorical codes that caller can safely ignore, but have to keep their reference alive until usage of array @@ -499,7 +501,6 @@ def _cudf_array_interfaces(data: CuDFLike, cat_codes: list) -> bytes: from cudf.api.types import is_categorical_dtype except ImportError: from cudf.utils.dtypes import is_categorical_dtype - import cudf interfaces = [] if _is_cudf_ser(data): @@ -529,15 +530,14 @@ def _transform_cudf_df( feature_names: FeatNamesT, feature_types: FeatureTypes, enable_categorical: bool, -) -> Tuple[CuDFLike, Optional[List[str]], FeatureTypes]: - import cudf +) -> Tuple[CuDFLike, Optional[list], Optional[List[str]], FeatureTypes]: try: from cudf.api.types import is_categorical_dtype except ImportError: from cudf.utils.dtypes import is_categorical_dtype if _is_cudf_ser(data): - dtypes = [data.dtype] + dtypes: Sequence = [data.dtype] else: dtypes = data.dtypes @@ -564,6 +564,7 @@ def _transform_cudf_df( else: feature_names = data.columns.format() + feature_types_lst = [] # handle feature types if feature_types is None: feature_types = [] @@ -886,7 +887,7 @@ def _validate_meta_shape(data: array_like, name: str) -> None: def _meta_from_numpy( data: np.ndarray, field: str, - dtype: Optional[Union[np.dtype, str]], + dtype: DTypeLike, handle: ctypes.c_void_p, ) -> None: data, dtype = _ensure_np_dtype(data, dtype) @@ -910,12 +911,14 @@ def _meta_from_tuple( return _meta_from_list(cast(list, data), field, dtype, handle) -def _meta_from_cudf_df(data, field: str, handle: ctypes.c_void_p) -> None: +def _meta_from_cudf_df(data: CuDFLike, field: str, handle: ctypes.c_void_p) -> None: + import cudf + df = cast(cudf.DataFrame, data) if field not in _matrix_meta: - _meta_from_cudf_series(data.iloc[:, 0], field, handle) + _meta_from_cudf_series(df.iloc[:, 0], field, handle) else: - data = data.values - interface = _cuda_array_interface(data) + arr = data.values + interface = _cuda_array_interface(arr) _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface)) @@ -936,14 +939,16 @@ def _meta_from_cupy_array(data: CuArrayLike, field: str, handle: ctypes.c_void_p interface)) -def _meta_from_dt(data, field: str, dtype, handle: ctypes.c_void_p): +def _meta_from_dt( + data: Any, field: str, dtype: DTypeLike, handle: ctypes.c_void_p +) -> None: data, _, _ = _transform_dt_df(data, None, None, field, dtype) _meta_from_numpy(data, field, dtype, handle) def dispatch_meta_backend( - matrix: DMatrix, data, name: str, dtype: Optional[Union[str, np.dtype]] = None -): + matrix: DMatrix, data: array_like, name: str, dtype: DTypeLike = None +) -> None: '''Dispatch for meta info.''' handle = matrix.handle assert handle is not None @@ -985,7 +990,9 @@ def dispatch_meta_backend( _meta_from_dt(data, name, dtype, handle) return if _is_modin_df(data): - data, _, _ = _transform_pandas_df(data, False, meta=name, meta_type=dtype) + data, _, _ = _transform_pandas_df( + cast(DFLike, data), False, meta=name, meta_type=dtype + ) _meta_from_numpy(data, name, dtype, handle) return if _is_modin_series(data): @@ -1024,7 +1031,7 @@ def reset(self) -> None: def _proxy_transform( - data, + data: array_like, feature_names: FeatNamesT, feature_types: FeatureTypes, enable_categorical: bool, @@ -1034,7 +1041,7 @@ def _proxy_transform( cast(CuDFLike, data), feature_names, feature_types, enable_categorical ) if _is_cupy_array(data): - data = _transform_cupy_array(data) + data = _transform_cupy_array(cast(CuArrayLike, data)) return data, None, feature_names, feature_types if _is_dlpack(data): return 
_transform_dlpack(data), None, feature_names, feature_types @@ -1053,7 +1060,7 @@ def _proxy_transform( def dispatch_proxy_set_data( proxy: _ProxyDMatrix, data: Any, - cat_codes: Optional[list], + cat_codes: Optional[List[CuDFLike]], allow_host: bool, ) -> None: """Dispatch for DeviceQuantileDMatrix.""" diff --git a/python-package/xgboost/typing.py b/python-package/xgboost/typing.py index fa7e102e3f1f..029d9e98fd71 100644 --- a/python-package/xgboost/typing.py +++ b/python-package/xgboost/typing.py @@ -107,6 +107,10 @@ def dtype(self) -> np.dtype: def dtypes(self) -> Tuple[np.dtype, ...]: ... + @abstractproperty + def columns(self) -> List[str]: + ... + def __getitem__(self, key: str) -> "CuDFLike": ... @@ -118,4 +122,4 @@ def __getitem__(self, key: str) -> "CuDFLike": NativeInput = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike, str, os.PathLike] -FeatureTypes = Optional[Union[List[str], List[DTypeLike]]] +FeatureTypes = Optional[Union[List[str], List[DTypeLike], str]] From 7f163cd9196a75f4728a7534fa2b787d4cef8b17 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 05:05:44 +0800 Subject: [PATCH 04/17] Pass. --- python-package/xgboost/data.py | 79 +++++++++++++++++--------------- python-package/xgboost/typing.py | 5 ++ 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 1d00afc7ba38..143e0d52242a 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -348,7 +348,7 @@ def _meta_from_pandas_series( """Help transform pandas series for meta data like labels""" import pandas as pd df = cast(pd.Series, data) - arr: np.ndarray = df.values.astype('float') + arr = df.values.astype('float') from pandas.api.types import is_sparse if is_sparse(arr): arr = arr.to_dense() @@ -372,17 +372,19 @@ def _from_pandas_series( feature_names: FeatNamesT, feature_types: FeatureTypes, ) -> _CtorReturnT: + import pandas as pd from pandas.api.types import is_categorical_dtype + series = cast(pd.Series, data) - if (data.dtype.name not in _pandas_dtype_mapper) and not ( - is_categorical_dtype(data.dtype) and enable_categorical + if (series.dtype.name not in _pandas_dtype_mapper) and not ( + is_categorical_dtype(series.dtype) and enable_categorical ): - _invalid_dataframe_dtype(data) - if enable_categorical and is_categorical_dtype(data.dtype): - data = data.cat.codes + _invalid_dataframe_dtype(series) + if enable_categorical and is_categorical_dtype(series.dtype): + series = series.cat.codes return _from_numpy_array( - data.values.reshape(data.shape[0], 1).astype("float"), + series.values.reshape(series.shape[0], 1).astype("float"), missing, nthread, feature_names, @@ -505,6 +507,7 @@ def _cudf_array_interfaces(data: CuDFLike, cat_codes: Optional[list]) -> bytes: interfaces = [] if _is_cudf_ser(data): if is_categorical_dtype(data.dtype): + assert cat_codes is not None interface = cat_codes[0].__cuda_array_interface__ else: interface = data.__cuda_array_interface__ @@ -514,6 +517,7 @@ def _cudf_array_interfaces(data: CuDFLike, cat_codes: Optional[list]) -> bytes: else: for i, col in enumerate(data): if is_categorical_dtype(data[col].dtype): + assert cat_codes is not None codes = cat_codes[i] interface = codes.__cuda_array_interface__ else: @@ -536,58 +540,61 @@ def _transform_cudf_df( except ImportError: from cudf.utils.dtypes import is_categorical_dtype - if _is_cudf_ser(data): - dtypes: Sequence = [data.dtype] + import cudf + df = cast(cudf.DataFrame, data) + + if _is_cudf_ser(df): + dtypes: 
Sequence = [df.dtype] else: - dtypes = data.dtypes + dtypes = df.dtypes if not all( dtype.name in _pandas_dtype_mapper or (is_categorical_dtype(dtype) and enable_categorical) for dtype in dtypes ): - _invalid_dataframe_dtype(data) + _invalid_dataframe_dtype(df) # handle feature names if feature_names is None: - if _is_cudf_ser(data): - feature_names = [data.name] - elif lazy_isinstance(data.columns, "cudf.core.multiindex", "MultiIndex"): - feature_names = [" ".join([str(x) for x in i]) for i in data.columns] + if _is_cudf_ser(df): + feature_names = [df.name] + elif lazy_isinstance(df.columns, "cudf.core.multiindex", "MultiIndex"): + feature_names = [" ".join([str(x) for x in i]) for i in df.columns] elif ( - lazy_isinstance(data.columns, "cudf.core.index", "RangeIndex") - or lazy_isinstance(data.columns, "cudf.core.index", "Int64Index") + lazy_isinstance(df.columns, "cudf.core.index", "RangeIndex") + or lazy_isinstance(df.columns, "cudf.core.index", "Int64Index") # Unique to cuDF, no equivalence in pandas 1.3.3 - or lazy_isinstance(data.columns, "cudf.core.index", "Int32Index") + or lazy_isinstance(df.columns, "cudf.core.index", "Int32Index") ): - feature_names = list(map(str, data.columns)) + feature_names = list(map(str, df.columns)) else: - feature_names = data.columns.format() + feature_names = df.columns.format() - feature_types_lst = [] # handle feature types if feature_types is None: - feature_types = [] + feature_types_lst = [] for dtype in dtypes: if is_categorical_dtype(dtype) and enable_categorical: feature_types_lst.append(CAT_T) else: - feature_types.append(_pandas_dtype_mapper[dtype.name]) + feature_types_lst.append(_pandas_dtype_mapper[dtype.name]) + feature_types = feature_types_lst # handle categorical data cat_codes = [] - if _is_cudf_ser(data): + if _is_cudf_ser(df): # unlike pandas, cuDF uses NA for missing data. 
- if is_categorical_dtype(data.dtype) and enable_categorical: - codes = data.cat.codes + if is_categorical_dtype(df.dtype) and enable_categorical: + codes = df.cat.codes cat_codes.append(codes) else: - for col in data: - if is_categorical_dtype(data[col].dtype) and enable_categorical: - codes = data[col].cat.codes + for col in df: + if is_categorical_dtype(df[col].dtype) and enable_categorical: + codes = df[col].cat.codes cat_codes.append(codes) - return data, cat_codes, feature_names, feature_types + return df, cat_codes, feature_names, feature_types def _from_cudf_df( @@ -595,13 +602,13 @@ def _from_cudf_df( missing: FloatT, nthread: int, feature_names: FeatNamesT, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool, ) -> _CtorReturnT: - data, cat_codes, feature_names, feature_types = _transform_cudf_df( + df, cat_codes, feature_names, feature_types = _transform_cudf_df( data, feature_names, feature_types, enable_categorical ) - interfaces_str = _cudf_array_interfaces(data, cat_codes) + interfaces_str = _cudf_array_interfaces(df, cat_codes) handle = ctypes.c_void_p() config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8") _check_call( @@ -818,7 +825,7 @@ def dispatch_data_backend( feature_names, feature_types) if _is_pandas_series(data): return _from_pandas_series( - data, missing, threads, enable_categorical, feature_names, feature_types + cast(DFLike, data), missing, threads, enable_categorical, feature_names, feature_types ) if _is_cudf_df(data) or _is_cudf_ser(data): df = cast(CuDFLike, data) @@ -848,7 +855,7 @@ def dispatch_data_backend( ) if _is_modin_series(data): return _from_pandas_series( - data, missing, threads, enable_categorical, feature_names, feature_types + cast(DFLike, data), missing, threads, enable_categorical, feature_names, feature_types ) if _has_array_protocol(data): array = np.asarray(data) @@ -971,7 +978,7 @@ def dispatch_meta_backend( _meta_from_numpy(arr, name, dtype, handle) return if _is_pandas_series(data): - _meta_from_pandas_series(data, name, dtype, handle) + _meta_from_pandas_series(cast(DFLike, data), name, dtype, handle) return if _is_dlpack(data): data = _transform_dlpack(data) diff --git a/python-package/xgboost/typing.py b/python-package/xgboost/typing.py index 029d9e98fd71..c1d423c80c15 100644 --- a/python-package/xgboost/typing.py +++ b/python-package/xgboost/typing.py @@ -64,6 +64,11 @@ def dtype(self) -> np.dtype: data: NPArrayLike +class IndexLike(Protocol): + def format(self) -> List[str]: + ... + + class DFLike(Protocol): @abstractproperty def values(self) -> NPArrayLike: From c577c68ffc90bcfb0c23f87bd1e8f9ce8aabc8f7 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 21:25:31 +0800 Subject: [PATCH 05/17] Fix. 
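
This reverts the dict conversion in _configure_metrics from patch 02: a
repeated `eval_metric` key is meaningful in the list form of parameters, and
collapsing the pairs into a dict keeps only the last metric. A small
illustration of what would be lost (values are illustrative):

    params = [
        ("objective", "binary:logistic"),
        ("eval_metric", "auc"),
        ("eval_metric", "logloss"),
    ]
    # dict() keeps only the last duplicate key, so the "auc" entry vanishes:
    assert dict(params) == {"objective": "binary:logistic", "eval_metric": "logloss"}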
--- python-package/xgboost/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 3b9b78fe3eec..26a0dc2c2734 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1363,7 +1363,7 @@ def __init__( else: self.booster = 'gbtree' - def _configure_metrics(self, params: Union[Dict, List]) -> Dict: + def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]: if isinstance(params, dict) and 'eval_metric' in params \ and isinstance(params['eval_metric'], list): params = dict((k, v) for k, v in params.items()) @@ -1372,8 +1372,7 @@ def _configure_metrics(self, params: Union[Dict, List]) -> Dict: params = list(params.items()) for eval_metric in eval_metrics: params += [('eval_metric', eval_metric)] - params_dict = {k: v for k, v in params} - return params_dict + return params def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str: if isinstance(value, str): From 636aa296a54ff1e9e97b5722ffad57ad4475fb4b Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 21:28:24 +0800 Subject: [PATCH 06/17] Naming. --- python-package/xgboost/core.py | 98 ++++++++++++++++---------------- python-package/xgboost/data.py | 42 +++++++------- python-package/xgboost/typing.py | 2 +- 3 files changed, 71 insertions(+), 71 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 26a0dc2c2734..e531c4cd5b5a 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -23,13 +23,13 @@ lazy_isinstance) from .libpath import find_lib_path from .typing import CuArrayLike, CuDFLike, NPArrayLike, DFLike, CSRLike -from .typing import FeatureTypes, NativeInput, array_like, DTypeLike +from .typing import FeatureTypes, NativeInput, ArrayLike, DTypeLike # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h c_bst_ulong = ctypes.c_uint64 # xgboost accepts some other possible types in practice due to historical reason, which # is lesser tested. For now we encourage users to pass a simple list of string. 
-FeatNamesT = Optional[List[str]] +FeatureNames = Optional[List[str]] Parameters = Union[List[Tuple[str, Any]], Dict[str, Any]] @@ -407,7 +407,7 @@ def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument def data_handle( data: Any, *, - feature_names: FeatNamesT = None, + feature_names: FeatureNames = None, feature_types: FeatureTypes = None, **kwargs: Any, ) -> None: @@ -520,20 +520,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes def __init__( self, data: NativeInput, - label: Optional[array_like] = None, + label: Optional[ArrayLike] = None, *, - weight: Optional[array_like] = None, - base_margin: Optional[array_like] = None, + weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, missing: Optional[float] = None, silent: bool = True, - feature_names: FeatNamesT = None, + feature_names: FeatureNames = None, feature_types: FeatureTypes = None, nthread: Optional[int] = None, - group: Optional[array_like] = None, - qid: Optional[array_like] = None, - label_lower_bound: Optional[array_like] = None, - label_upper_bound: Optional[array_like] = None, - feature_weights: Optional[array_like] = None, + group: Optional[ArrayLike] = None, + qid: Optional[ArrayLike] = None, + label_lower_bound: Optional[ArrayLike] = None, + label_upper_bound: Optional[ArrayLike] = None, + feature_weights: Optional[ArrayLike] = None, enable_categorical: bool = False, ) -> None: """Parameters @@ -545,9 +545,9 @@ def __init__( libsvm format txt file, csv file (by specifying uri parameter 'path_to_csv?format=csv'), or binary file that xgboost can read from. - label : array_like + label : Label of the training data. - weight : array_like + weight : ArrayLike Weight for each instance. .. note:: For ranking task, weights are per-group. @@ -557,7 +557,7 @@ def __init__( ordering of data points within each group, so it doesn't make sense to assign weights to individual data points. - base_margin: array_like + base_margin: ArrayLike Base margin used for boosting from existing model. missing : float, optional Value in the input data which needs to be present as a missing @@ -574,15 +574,15 @@ def __init__( nthread : integer, optional Number of threads to use for loading data when parallelization is applicable. If -1, uses maximum threads available on the system. - group : array_like + group : ArrayLike Group size for all ranking group. - qid : array_like + qid : ArrayLike Query ID for data samples, used for ranking. - label_lower_bound : array_like + label_lower_bound : ArrayLike Lower bound for survival training. - label_upper_bound : array_like + label_upper_bound : ArrayLike Upper bound for survival training. - feature_weights : array_like, optional + feature_weights : ArrayLike, optional Set feature weights for column sampling. 
enable_categorical: boolean, optional @@ -677,16 +677,16 @@ def __del__(self) -> None: def set_info( self, *, - label: Optional[array_like] = None, - weight: Optional[array_like] = None, - base_margin: Optional[array_like] = None, - group: Optional[array_like] = None, - qid: Optional[array_like] = None, - label_lower_bound: Optional[array_like] = None, - label_upper_bound: Optional[array_like] = None, - feature_names: FeatNamesT = None, + label: Optional[ArrayLike] = None, + weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, + group: Optional[ArrayLike] = None, + qid: Optional[ArrayLike] = None, + label_lower_bound: Optional[ArrayLike] = None, + label_upper_bound: Optional[ArrayLike] = None, + feature_names: FeatureNames = None, feature_types: FeatureTypes = None, - feature_weights: Optional[array_like] = None + feature_weights: Optional[ArrayLike] = None ) -> None: """Set meta info for DMatrix. See doc string for :py:obj:`xgboost.DMatrix`.""" from .data import dispatch_meta_backend @@ -755,7 +755,7 @@ def get_uint_info(self, field: str) -> np.ndarray: ctypes.byref(ret))) return ctypes2numpy(ret, length.value, np.uint32) - def set_float_info(self, field: str, data: array_like) -> None: + def set_float_info(self, field: str, data: ArrayLike) -> None: """Set float type property into the DMatrix. Parameters @@ -769,7 +769,7 @@ def set_float_info(self, field: str, data: array_like) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'float') - def set_float_info_npy2d(self, field: str, data: array_like) -> None: + def set_float_info_npy2d(self, field: str, data: ArrayLike) -> None: """Set float type property into the DMatrix for numpy 2d array input @@ -784,7 +784,7 @@ def set_float_info_npy2d(self, field: str, data: array_like) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'float') - def set_uint_info(self, field: str, data: array_like) -> None: + def set_uint_info(self, field: str, data: ArrayLike) -> None: """Set uint type property into the DMatrix. Parameters @@ -814,7 +814,7 @@ def save_binary(self, fname: os.PathLike, silent: bool = True) -> None: c_str(fname_str), ctypes.c_int(silent))) - def set_label(self, label: array_like) -> None: + def set_label(self, label: ArrayLike) -> None: """Set label of dmatrix Parameters @@ -825,7 +825,7 @@ def set_label(self, label: array_like) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, label, 'label', 'float') - def set_weight(self, weight: array_like) -> None: + def set_weight(self, weight: ArrayLike) -> None: """Set weight of each instance. Parameters @@ -844,7 +844,7 @@ def set_weight(self, weight: array_like) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, weight, 'weight', 'float') - def set_base_margin(self, margin: array_like) -> None: + def set_base_margin(self, margin: ArrayLike) -> None: """Set base margin of booster to start from. This can be used to specify a prediction value of existing model to be @@ -861,7 +861,7 @@ def set_base_margin(self, margin: array_like) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, margin, 'base_margin', 'float') - def set_group(self, group: array_like) -> None: + def set_group(self, group: ArrayLike) -> None: """Set group size of DMatrix (used for ranking). 
Parameters @@ -989,7 +989,7 @@ def feature_names(self) -> Optional[List[str]]: return feature_names @feature_names.setter - def feature_names(self, feature_names: FeatNamesT) -> None: + def feature_names(self, feature_names: FeatureNames) -> None: """Set feature names (column labels). Parameters @@ -1165,21 +1165,21 @@ class DeviceQuantileDMatrix(DMatrix): def __init__( # pylint: disable=super-init-not-called self, data: Union[CuArrayLike, CuDFLike], - label: Optional[array_like] = None, + label: Optional[ArrayLike] = None, *, - weight: Optional[array_like] = None, - base_margin: Optional[array_like] = None, + weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, missing: Optional[float] = None, silent: bool = False, - feature_names: FeatNamesT = None, + feature_names: FeatureNames = None, feature_types: FeatureTypes = None, nthread: Optional[int] = None, max_bin: int = 256, - group: Optional[array_like] = None, - qid: Optional[array_like] = None, - label_lower_bound: Optional[array_like] = None, - label_upper_bound: Optional[array_like] = None, - feature_weights: Optional[array_like] = None, + group: Optional[ArrayLike] = None, + qid: Optional[ArrayLike] = None, + label_lower_bound: Optional[ArrayLike] = None, + label_upper_bound: Optional[ArrayLike] = None, + feature_weights: Optional[ArrayLike] = None, enable_categorical: bool = False, ) -> None: self.max_bin = max_bin @@ -1216,7 +1216,7 @@ def _init( self, data: Union[CuArrayLike, CuDFLike], enable_categorical: bool, - **meta: Union[Optional[array_like], Optional[Sequence]] + **meta: Union[Optional[ArrayLike], Optional[Sequence]] ) -> None: from .data import ( _is_dlpack, @@ -1662,7 +1662,7 @@ def feature_names(self) -> Optional[List[str]]: return self._get_feature_info("feature_name") @feature_names.setter - def feature_names(self, features: FeatNamesT) -> None: + def feature_names(self, features: FeatureNames) -> None: self._set_feature_info(features, "feature_name") def set_param(self, params: Union[str, Parameters], value: Any = None) -> None: @@ -1968,7 +1968,7 @@ def assign_type(t: int) -> None: def inplace_predict( self, - data: array_like, + data: ArrayLike, iteration_range: Tuple[int, int] = (0, 0), predict_type: str = "value", missing: float = np.nan, diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 143e0d52242a..b83843c9fa5d 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -11,10 +11,10 @@ from .core import c_array, _LIB, _check_call, c_str from .core import _cuda_array_interface -from .core import DataIter, _ProxyDMatrix, DMatrix, FeatNamesT -from .compat import lazy_isinstance, DataFrame +from .core import DataIter, _ProxyDMatrix, DMatrix, FeatureNames +from .compat import lazy_isinstance -from .typing import array_like, FloatT, CSRLike, NPArrayLike, DFLike, NativeInput +from .typing import ArrayLike, FloatT, CSRLike, NPArrayLike, DFLike, NativeInput from .typing import CuArrayLike, CuDFLike, FeatureTypes, DTypeLike c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name @@ -36,7 +36,7 @@ def _warn_unused_missing(data: NativeInput, missing: FloatT) -> None: str(type(data)), UserWarning) -def _check_complex(data: array_like) -> None: +def _check_complex(data: ArrayLike) -> None: '''Test whether data is complex using `dtype` attribute.''' complex_dtypes = (np.complex128, np.complex64, np.cfloat, np.cdouble, np.clongdouble) @@ -45,7 +45,7 @@ def _check_complex(data: array_like) -> None: def _check_data_shape(data: NativeInput) 
-> None: - if hasattr(data, "shape") and len(cast(array_like, data).shape) != 2: + if hasattr(data, "shape") and len(cast(ArrayLike, data).shape) != 2: raise ValueError("Please reshape the input data into 2-dimensional matrix.") @@ -73,7 +73,7 @@ def _from_scipy_csr( data: CSRLike, missing: FloatT, nthread: int, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: """Initialize data from a CSR matrix.""" @@ -112,7 +112,7 @@ def _is_scipy_csc(data: NativeInput) -> bool: def _from_scipy_csc( data: CSRLike, missing: FloatT, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: if len(data.indices) != len(data.data): @@ -172,7 +172,7 @@ def _from_numpy_array( data: np.ndarray, missing: FloatT, nthread: int, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: """Initialize data from a 2-D numpy matrix. @@ -323,7 +323,7 @@ def _from_pandas_df( enable_categorical: bool, missing: FloatT, nthread: int, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: arr, feature_names, feature_types = _transform_pandas_df( @@ -369,7 +369,7 @@ def _from_pandas_series( missing: FloatT, nthread: int, enable_categorical: bool, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: import pandas as pd @@ -531,7 +531,7 @@ def _cudf_array_interfaces(data: CuDFLike, cat_codes: Optional[list]) -> bytes: def _transform_cudf_df( data: CuDFLike, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, enable_categorical: bool, ) -> Tuple[CuDFLike, Optional[list], Optional[List[str]], FeatureTypes]: @@ -601,7 +601,7 @@ def _from_cudf_df( data: CuDFLike, missing: FloatT, nthread: int, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, enable_categorical: bool, ) -> _CtorReturnT: @@ -714,7 +714,7 @@ def _is_uri(data: NativeInput) -> bool: def _from_uri( data: Union[os.PathLike, str], missing: FloatT, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: _warn_unused_missing(data, missing) @@ -734,7 +734,7 @@ def _from_list( data: list, missing: FloatT, n_threads: int, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: array = np.array(data) @@ -750,7 +750,7 @@ def _from_tuple( data: tuple, missing: FloatT, n_threads: int, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: return _from_list(cast(list, data), missing, n_threads, feature_names, feature_types) @@ -786,7 +786,7 @@ def dispatch_data_backend( data: NativeInput, missing: FloatT, threads: int, - feature_names: FeatNamesT, + feature_names: FeatureNames, feature_types: FeatureTypes, enable_categorical: bool = False, ) -> _CtorReturnT: @@ -877,7 +877,7 @@ def _to_data_type(dtype: str, name: str) -> int: return dtype_map[dtype] -def _validate_meta_shape(data: array_like, name: str) -> None: +def _validate_meta_shape(data: ArrayLike, name: str) -> None: if hasattr(data, "shape"): msg = f"Invalid shape: {data.shape} for {name}" if name in _matrix_meta: @@ -954,7 +954,7 @@ def _meta_from_dt( def dispatch_meta_backend( - matrix: DMatrix, data: array_like, name: str, dtype: DTypeLike = None + matrix: DMatrix, data: ArrayLike, name: str, dtype: DTypeLike = None ) -> None: 
'''Dispatch for meta info.''' handle = matrix.handle @@ -1038,11 +1038,11 @@ def reset(self) -> None: def _proxy_transform( - data: array_like, - feature_names: FeatNamesT, + data: ArrayLike, + feature_names: FeatureNames, feature_types: FeatureTypes, enable_categorical: bool, -) -> Tuple[array_like, Optional[list], Optional[List[str]], FeatureTypes]: +) -> Tuple[ArrayLike, Optional[list], Optional[List[str]], FeatureTypes]: if _is_cudf_df(data) or _is_cudf_ser(data): return _transform_cudf_df( cast(CuDFLike, data), feature_names, feature_types, enable_categorical diff --git a/python-package/xgboost/typing.py b/python-package/xgboost/typing.py index c1d423c80c15..8db0c27ac0af 100644 --- a/python-package/xgboost/typing.py +++ b/python-package/xgboost/typing.py @@ -123,7 +123,7 @@ def __getitem__(self, key: str) -> "CuDFLike": FloatT = Union[float, np.float16, np.float32, np.float64] -array_like = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike] +ArrayLike = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike] NativeInput = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike, str, os.PathLike] From 4545bc4b08a7530c0051bbe4857e60d053ddd67b Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 21:28:54 +0800 Subject: [PATCH 07/17] black. --- python-package/xgboost/data.py | 295 +++++++++++++++++++-------------- 1 file changed, 169 insertions(+), 126 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index b83843c9fa5d..b76985b54757 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -1,6 +1,6 @@ # pylint: disable=too-many-arguments, too-many-branches, too-many-lines # pylint: disable=too-many-return-statements, import-error -'''Data dispatching for DMatrix.''' +"""Data dispatching for DMatrix.""" import ctypes import json import warnings @@ -17,7 +17,7 @@ from .typing import ArrayLike, FloatT, CSRLike, NPArrayLike, DFLike, NativeInput from .typing import CuArrayLike, CuDFLike, FeatureTypes, DTypeLike -c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name +c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name _CtorReturnT = Tuple[ctypes.c_void_p, Optional[List[str]], FeatureTypes] @@ -32,16 +32,22 @@ def _warn_unused_missing(data: NativeInput, missing: FloatT) -> None: if (missing is not None) and (not np.isnan(missing)): warnings.warn( - '`missing` is not used for current input data type:' + - str(type(data)), UserWarning) + "`missing` is not used for current input data type:" + str(type(data)), + UserWarning, + ) def _check_complex(data: ArrayLike) -> None: - '''Test whether data is complex using `dtype` attribute.''' - complex_dtypes = (np.complex128, np.complex64, - np.cfloat, np.cdouble, np.clongdouble) - if hasattr(data, 'dtype') and data.dtype in complex_dtypes: - raise ValueError('Complex data not supported') + """Test whether data is complex using `dtype` attribute.""" + complex_dtypes = ( + np.complex128, + np.complex64, + np.cfloat, + np.cdouble, + np.clongdouble, + ) + if hasattr(data, "dtype") and data.dtype in complex_dtypes: + raise ValueError("Complex data not supported") def _check_data_shape(data: NativeInput) -> None: @@ -78,9 +84,7 @@ def _from_scipy_csr( ) -> _CtorReturnT: """Initialize data from a CSR matrix.""" if len(data.indices) != len(data.data): - raise ValueError( - f"length mismatch: {len(data.indices)} vs {len(data.data)}" - ) + raise ValueError(f"length mismatch: {len(data.indices)} vs {len(data.data)}") handle = ctypes.c_void_p() args = { "missing": 
float(missing), @@ -119,14 +123,17 @@ def _from_scipy_csc( raise ValueError(f"length mismatch: {len(data.indices)} vs {len(data.data)}") _warn_unused_missing(data, missing) handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixCreateFromCSCEx( - c_array(ctypes.c_size_t, data.indptr), - c_array(ctypes.c_uint, data.indices), - c_array(ctypes.c_float, data.data), - ctypes.c_size_t(len(data.indptr)), - ctypes.c_size_t(len(data.data)), - ctypes.c_size_t(data.shape[0]), - ctypes.byref(handle))) + _check_call( + _LIB.XGDMatrixCreateFromCSCEx( + c_array(ctypes.c_size_t, data.indptr), + c_array(ctypes.c_uint, data.indices), + c_array(ctypes.c_float, data.data), + ctypes.c_size_t(len(data.indptr)), + ctypes.c_size_t(len(data.data)), + ctypes.c_size_t(data.shape[0]), + ctypes.byref(handle), + ) + ) return handle, feature_names, feature_types @@ -152,11 +159,8 @@ def _ensure_np_dtype( return data, dtype -def _maybe_np_slice( - data: Union[List[int], np.ndarray], dtype: DTypeLike -) -> np.ndarray: - '''Handle numpy slice. This can be removed if we use __array_interface__. - ''' +def _maybe_np_slice(data: Union[List[int], np.ndarray], dtype: DTypeLike) -> np.ndarray: + """Handle numpy slice. This can be removed if we use __array_interface__.""" try: if not cast(np.ndarray, data).flags.c_contiguous: data = np.array(data, copy=True, dtype=dtype) @@ -175,13 +179,9 @@ def _from_numpy_array( feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: - """Initialize data from a 2-D numpy matrix. - - """ + """Initialize data from a 2-D numpy matrix.""" if len(data.shape) != 2: - raise ValueError( - "Expecting 2 dimensional numpy.ndarray, got: ", data.shape - ) + raise ValueError("Expecting 2 dimensional numpy.ndarray, got: ", data.shape) data, _ = _ensure_np_dtype(data, data.dtype) handle = ctypes.c_void_p() args = { @@ -216,18 +216,18 @@ def _is_modin_df(data: NativeInput) -> bool: _pandas_dtype_mapper = { - 'int8': 'int', - 'int16': 'int', - 'int32': 'int', - 'int64': 'int', - 'uint8': 'int', - 'uint16': 'int', - 'uint32': 'int', - 'uint64': 'int', - 'float16': 'float', - 'float32': 'float', - 'float64': 'float', - 'bool': 'i', + "int8": "int", + "int16": "int", + "int32": "int", + "int64": "int", + "uint8": "int", + "uint16": "int", + "uint32": "int", + "uint64": "int", + "float16": "float", + "float32": "float", + "float64": "float", + "bool": "i", } @@ -244,9 +244,12 @@ def _invalid_dataframe_dtype(data: Any) -> None: else: err = "" - msg = """DataFrame.dtypes for data must be int, float, bool or category. When + msg = ( + """DataFrame.dtypes for data must be int, float, bool or category. 
When categorical type is supplied, DMatrix parameter `enable_categorical` must -be set to `True`.""" + err +be set to `True`.""" + + err + ) raise ValueError(msg) @@ -260,6 +263,7 @@ def _transform_pandas_df( ) -> Tuple[np.ndarray, Optional[List[str]], FeatureTypes]: import pandas as pd from pandas.api.types import is_sparse, is_categorical_dtype + data = cast(pd.DataFrame, df) if not all( @@ -329,9 +333,7 @@ def _from_pandas_df( arr, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types ) - return _from_numpy_array( - arr, missing, nthread, feature_names, feature_types - ) + return _from_numpy_array(arr, missing, nthread, feature_names, feature_types) def _is_pandas_series(data: NativeInput) -> bool: @@ -347,9 +349,11 @@ def _meta_from_pandas_series( ) -> None: """Help transform pandas series for meta data like labels""" import pandas as pd + df = cast(pd.Series, data) - arr = df.values.astype('float') + arr = df.values.astype("float") from pandas.api.types import is_sparse + if is_sparse(arr): arr = arr.to_dense() assert len(arr.shape) == 1 or arr.shape[1] == 0 or arr.shape[1] == 1 @@ -374,6 +378,7 @@ def _from_pandas_series( ) -> _CtorReturnT: import pandas as pd from pandas.api.types import is_categorical_dtype + series = cast(pd.Series, data) if (series.dtype.name not in _pandas_dtype_mapper) and not ( @@ -393,12 +398,13 @@ def _from_pandas_series( def _is_dt_df(data: NativeInput) -> bool: - return lazy_isinstance(data, 'datatable', 'Frame') or \ - lazy_isinstance(data, 'datatable', 'DataTable') + return lazy_isinstance(data, "datatable", "Frame") or lazy_isinstance( + data, "datatable", "DataTable" + ) -_dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'} -_dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'} +_dt_type_mapper = {"bool": "bool", "int": "int", "real": "float"} +_dt_type_mapper2 = {"bool": "i", "int": "int", "real": "float"} def _transform_dt_df( @@ -410,7 +416,7 @@ def _transform_dt_df( ) -> Tuple[np.ndarray, Optional[List[str]], FeatureTypes]: """Validate feature names and types if data table""" if meta and data.shape[1] > 1: - raise ValueError('DataTable for meta info cannot have multiple columns') + raise ValueError("DataTable for meta info cannot have multiple columns") if meta: meta_type = "float" if meta_type is None else meta_type # below requires new dt version @@ -419,23 +425,23 @@ def _transform_dt_df( return data, None, None data_types_names = tuple(lt.name for lt in data.ltypes) - bad_fields = [data.names[i] - for i, type_name in enumerate(data_types_names) - if type_name not in _dt_type_mapper] + bad_fields = [ + data.names[i] + for i, type_name in enumerate(data_types_names) + if type_name not in _dt_type_mapper + ] if bad_fields: msg = """DataFrame.types for data must be int, float or bool. 
Did not expect the data types in fields """ - raise ValueError(msg + ', '.join(bad_fields)) + raise ValueError(msg + ", ".join(bad_fields)) if feature_names is None and meta is None: feature_names = data.names # always return stypes for dt ingestion if feature_types is not None: - raise ValueError( - 'DataTable has own feature types, cannot pass them in.') - feature_types = np.vectorize(_dt_type_mapper2.get)( - data_types_names).tolist() + raise ValueError("DataTable has own feature types, cannot pass them in.") + feature_types = np.vectorize(_dt_type_mapper2.get)(data_types_names).tolist() return data, feature_names, feature_types @@ -451,7 +457,8 @@ def _from_dt_df( if enable_categorical: raise ValueError("categorical data in datatable is not supported yet.") data, feature_names, feature_types = _transform_dt_df( - data, feature_names, feature_types, None, None) + data, feature_names, feature_types, None, None + ) ptrs = (ctypes.c_void_p * data.ncols)() if hasattr(data, "internal") and hasattr(data.internal, "column"): @@ -462,8 +469,10 @@ def _from_dt_df( ptrs[icol] = ctypes.c_void_p(ptr) else: # datatable<=0.8.0 - from datatable.internal import \ - frame_column_data_r # pylint: disable=no-name-in-module + from datatable.internal import ( + frame_column_data_r, + ) # pylint: disable=no-name-in-module + for icol in range(data.ncols): ptrs[icol] = frame_column_data_r(data, icol) @@ -471,16 +480,21 @@ def _from_dt_df( feature_type_strings = (ctypes.c_char_p * data.ncols)() for icol in range(data.ncols): feature_type_strings[icol] = ctypes.c_char_p( - data.stypes[icol].name.encode('utf-8')) + data.stypes[icol].name.encode("utf-8") + ) _warn_unused_missing(data, missing) handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixCreateFromDT( - ptrs, feature_type_strings, - c_bst_ulong(data.shape[0]), - c_bst_ulong(data.shape[1]), - ctypes.byref(handle), - ctypes.c_int(nthread))) + _check_call( + _LIB.XGDMatrixCreateFromDT( + ptrs, + feature_type_strings, + c_bst_ulong(data.shape[0]), + c_bst_ulong(data.shape[1]), + ctypes.byref(handle), + ctypes.c_int(nthread), + ) + ) return handle, feature_names, feature_types @@ -489,7 +503,7 @@ def _is_cudf_df(data: NativeInput) -> bool: import cudf except ImportError: return False - return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) + return hasattr(cudf, "DataFrame") and isinstance(data, cudf.DataFrame) def _cudf_array_interfaces(data: CuDFLike, cat_codes: Optional[list]) -> bytes: @@ -541,6 +555,7 @@ def _transform_cudf_df( from cudf.utils.dtypes import is_categorical_dtype import cudf + df = cast(cudf.DataFrame, data) if _is_cudf_ser(df): @@ -639,9 +654,9 @@ def _is_cupy_array(data: NativeInput) -> bool: def _transform_cupy_array(arr: CuArrayLike) -> CuArrayLike: import cupy # pylint: disable=import-error + data = cast(cupy.ndarray, arr) - if not hasattr(data, '__cuda_array_interface__') and hasattr( - data, '__array__'): + if not hasattr(data, "__cuda_array_interface__") and hasattr(data, "__array__"): data = cupy.array(data, copy=False) if data.dtype.hasobject or data.dtype in [cupy.float16, cupy.bool_]: data = data.astype(cupy.float32, copy=False) @@ -662,9 +677,9 @@ def _from_cupy_array( config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8") _check_call( _LIB.XGDMatrixCreateFromCudaArrayInterface( - interface_str, - config, - ctypes.byref(handle))) + interface_str, config, ctypes.byref(handle) + ) + ) return handle, feature_names, feature_types @@ -685,12 +700,13 @@ def _is_cupy_csc(data: NativeInput) -> 
bool: def _is_dlpack(data: NativeInput) -> bool: - return 'PyCapsule' in str(type(data)) and "dltensor" in str(data) + return "PyCapsule" in str(type(data)) and "dltensor" in str(data) def _transform_dlpack(data: Any) -> CuArrayLike: from cupy import fromDlpack # pylint: disable=E0401 - assert 'used_dltensor' not in str(data) + + assert "used_dltensor" not in str(data) data = fromDlpack(data) return data @@ -703,8 +719,7 @@ def _from_dlpack( feature_types: FeatureTypes, ) -> _CtorReturnT: data = _transform_dlpack(data) - return _from_cupy_array(data, missing, nthread, feature_names, - feature_types) + return _from_cupy_array(data, missing, nthread, feature_names, feature_types) def _is_uri(data: NativeInput) -> bool: @@ -720,9 +735,9 @@ def _from_uri( _warn_unused_missing(data, missing) handle = ctypes.c_void_p() data = os.fspath(os.path.expanduser(data)) - _check_call(_LIB.XGDMatrixCreateFromFile(c_str(data), - ctypes.c_int(1), - ctypes.byref(handle))) + _check_call( + _LIB.XGDMatrixCreateFromFile(c_str(data), ctypes.c_int(1), ctypes.byref(handle)) + ) return handle, feature_names, feature_types @@ -753,7 +768,9 @@ def _from_tuple( feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: - return _from_list(cast(list, data), missing, n_threads, feature_names, feature_types) + return _from_list( + cast(list, data), missing, n_threads, feature_names, feature_types + ) def _is_iter(data: Any) -> bool: @@ -761,13 +778,13 @@ def _is_iter(data: Any) -> bool: def _has_array_protocol(data: Any) -> bool: - return hasattr(data, '__array__') + return hasattr(data, "__array__") def _convert_unknown_data(data: Any) -> Optional[Any]: warnings.warn( - f'Unknown data type: {type(data)}, trying to convert it to csr_matrix', - UserWarning + f"Unknown data type: {type(data)}, trying to convert it to csr_matrix", + UserWarning, ) try: import scipy @@ -776,7 +793,7 @@ def _convert_unknown_data(data: Any) -> Optional[Any]: try: data = scipy.sparse.csr_matrix(data) - except Exception: # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except return None return data @@ -790,7 +807,7 @@ def dispatch_data_backend( feature_types: FeatureTypes, enable_categorical: bool = False, ) -> _CtorReturnT: - '''Dispatch data for DMatrix.''' + """Dispatch data for DMatrix.""" if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) if _is_scipy_csr(data): @@ -798,11 +815,18 @@ def dispatch_data_backend( cast(CSRLike, data), missing, threads, feature_names, feature_types ) if _is_scipy_csc(data): - return _from_scipy_csc(cast(CSRLike, data), missing, feature_names, feature_types) + return _from_scipy_csc( + cast(CSRLike, data), missing, feature_names, feature_types + ) if _is_scipy_coo(data): from scipy.sparse import coo_matrix + return _from_scipy_csr( - cast(coo_matrix, data).tocsr(), missing, threads, feature_names, feature_types + cast(coo_matrix, data).tocsr(), + missing, + threads, + feature_names, + feature_types, ) if _is_numpy_array(data): return _from_numpy_array( @@ -821,11 +845,22 @@ def dispatch_data_backend( cast(tuple, data), missing, threads, feature_names, feature_types ) if _is_pandas_df(data): - return _from_pandas_df(cast(DFLike, data), enable_categorical, missing, threads, - feature_names, feature_types) + return _from_pandas_df( + cast(DFLike, data), + enable_categorical, + missing, + threads, + feature_names, + feature_types, + ) if _is_pandas_series(data): return _from_pandas_series( - cast(DFLike, data), missing, threads, 
enable_categorical, feature_names, feature_types + cast(DFLike, data), + missing, + threads, + enable_categorical, + feature_names, + feature_types, ) if _is_cudf_df(data) or _is_cudf_ser(data): df = cast(CuDFLike, data) @@ -837,12 +872,11 @@ def dispatch_data_backend( cast(CuArrayLike, data), missing, threads, feature_names, feature_types ) if _is_cupy_csr(data): - raise TypeError('cupyx CSR is not supported yet.') + raise TypeError("cupyx CSR is not supported yet.") if _is_cupy_csc(data): - raise TypeError('cupyx CSC is not supported yet.') + raise TypeError("cupyx CSC is not supported yet.") if _is_dlpack(data): - return _from_dlpack(data, missing, threads, feature_names, - feature_types) + return _from_dlpack(data, missing, threads, feature_names, feature_types) if _is_dt_df(data): _warn_unused_missing(data, missing) return _from_dt_df( @@ -850,12 +884,21 @@ def dispatch_data_backend( ) if _is_modin_df(data): return _from_pandas_df( - cast(DFLike, data), enable_categorical, missing, threads, - feature_names, feature_types + cast(DFLike, data), + enable_categorical, + missing, + threads, + feature_names, + feature_types, ) if _is_modin_series(data): return _from_pandas_series( - cast(DFLike, data), missing, threads, enable_categorical, feature_names, feature_types + cast(DFLike, data), + missing, + threads, + enable_categorical, + feature_names, + feature_types, ) if _has_array_protocol(data): array = np.asarray(data) @@ -863,17 +906,19 @@ def dispatch_data_backend( converted = _convert_unknown_data(data) if converted is not None: - return _from_scipy_csr(converted, missing, threads, feature_names, feature_types) + return _from_scipy_csr( + converted, missing, threads, feature_names, feature_types + ) - raise TypeError('Not supported type for data.' + str(type(data))) + raise TypeError("Not supported type for data." + str(type(data))) def _to_data_type(dtype: str, name: str) -> int: - dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4} + dtype_map = {"float32": 1, "float64": 2, "uint32": 3, "uint64": 4} if dtype not in dtype_map: raise TypeError( - f'Expecting float32, float64, uint32, uint64, got {dtype} ' + - f'for {name}.') + f"Expecting float32, float64, uint32, uint64, got {dtype} " + f"for {name}." 
+ ) return dtype_map[dtype] @@ -920,6 +965,7 @@ def _meta_from_tuple( def _meta_from_cudf_df(data: CuDFLike, field: str, handle: ctypes.c_void_p) -> None: import cudf + df = cast(cudf.DataFrame, data) if field not in _matrix_meta: _meta_from_cudf_series(df.iloc[:, 0], field, handle) @@ -930,20 +976,16 @@ def _meta_from_cudf_df(data: CuDFLike, field: str, handle: ctypes.c_void_p) -> N def _meta_from_cudf_series(data: CuDFLike, field: str, handle: ctypes.c_void_p) -> None: - interface = bytes(json.dumps([data.__cuda_array_interface__], - indent=2), 'utf-8') - _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, - c_str(field), - interface)) + interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), "utf-8") + _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface)) -def _meta_from_cupy_array(data: CuArrayLike, field: str, handle: ctypes.c_void_p) -> None: +def _meta_from_cupy_array( + data: CuArrayLike, field: str, handle: ctypes.c_void_p +) -> None: data = _transform_cupy_array(data) - interface = bytes(json.dumps([data.__cuda_array_interface__], - indent=2), 'utf-8') - _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, - c_str(field), - interface)) + interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), "utf-8") + _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface)) def _meta_from_dt( @@ -956,7 +998,7 @@ def _meta_from_dt( def dispatch_meta_backend( matrix: DMatrix, data: ArrayLike, name: str, dtype: DTypeLike = None ) -> None: - '''Dispatch for meta info.''' + """Dispatch for meta info.""" handle = matrix.handle assert handle is not None _validate_meta_shape(data, name) @@ -1003,7 +1045,7 @@ def dispatch_meta_backend( _meta_from_numpy(data, name, dtype, handle) return if _is_modin_series(data): - arr = cast(np.ndarray, cast(DFLike, data).values).astype('float') + arr = cast(np.ndarray, cast(DFLike, data).values).astype("float") assert len(arr.shape) == 1 or arr.shape[1] == 0 or arr.shape[1] == 1 _meta_from_numpy(arr, name, dtype, handle) return @@ -1011,19 +1053,20 @@ def dispatch_meta_backend( array = np.asarray(data) _meta_from_numpy(array, name, dtype, handle) return - raise TypeError('Unsupported type for ' + name, str(type(data))) + raise TypeError("Unsupported type for " + name, str(type(data))) class SingleBatchInternalIter(DataIter): # pylint: disable=R0902 - '''An iterator for single batch data to help creating device DMatrix. + """An iterator for single batch data to help creating device DMatrix. Transforming input directly to histogram with normal single batch data API can not access weight for sketching. So this iterator acts as a staging area for meta info. - ''' + """ + def __init__(self, **kwargs: Any): self.kwargs = kwargs - self.it = 0 # pylint: disable=invalid-name + self.it = 0 # pylint: disable=invalid-name super().__init__() def next(self, input_data: Callable) -> int: @@ -1097,7 +1140,7 @@ def dispatch_proxy_set_data( raise err if _is_numpy_array(data): - proxy._set_data_from_array(cast(NPArrayLike, data)) # pylint: disable=W0212 + proxy._set_data_from_array(cast(NPArrayLike, data)) # pylint: disable=W0212 return if _is_scipy_csr(data): proxy._set_data_from_csr(cast(CSRLike, data)) # pylint: disable=W0212 From 1bea924dc0505c4cd40240f743d3c68a6bcea8ce Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 21:29:04 +0800 Subject: [PATCH 08/17] comment. 
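For context: `_CtorReturnT` is the triple every `_from_*` constructor in
data.py hands back to the DMatrix initializer -- the C handle plus the
(possibly inferred) feature names and feature types. A minimal sketch of a
conforming constructor, using a simplified alias and illustrative names
that are not part of this patch:

    import ctypes
    from typing import Any, List, Optional, Tuple

    FeatureNames = Optional[List[str]]   # simplified stand-ins
    FeatureTypes = Optional[List[str]]
    _CtorReturnT = Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]

    def _from_toy(
        data: Any, feature_names: FeatureNames, feature_types: FeatureTypes
    ) -> _CtorReturnT:
        # A real constructor would call an XGDMatrixCreate* C function
        # here to fill the handle.
        handle = ctypes.c_void_p()
        return handle, feature_names, feature_types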
--- python-package/xgboost/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index b76985b54757..8b10d777fa10 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -19,13 +19,12 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name +# return of DMatrix constructor. _CtorReturnT = Tuple[ctypes.c_void_p, Optional[List[str]], FeatureTypes] CAT_T = "c" # meta info that can be a matrix instead of vector. -# For now it's base_margin for multi-class, but it can be extended to label once we have -# multi-output. _matrix_meta = {"base_margin", "label"} From 9ec2447d2da13c40247e5ae734dd4b56a6a5d597 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 21:38:19 +0800 Subject: [PATCH 09/17] Cleanup. --- python-package/xgboost/core.py | 44 ++++++------- python-package/xgboost/dask.py | 22 +++---- python-package/xgboost/data.py | 15 +++-- python-package/xgboost/sklearn.py | 104 +++++++++++++++--------------- python-package/xgboost/typing.py | 4 +- 5 files changed, 96 insertions(+), 93 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index e531c4cd5b5a..d5d6218b5a6e 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -22,14 +22,16 @@ from .compat import (STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED, lazy_isinstance) from .libpath import find_lib_path -from .typing import CuArrayLike, CuDFLike, NPArrayLike, DFLike, CSRLike -from .typing import FeatureTypes, NativeInput, ArrayLike, DTypeLike +from .typing import ( + CuArrayLike, CuDFLike, NPArrayLike, DFLike, CSRLike, + PathLike, NativeInput, ArrayLike, DTypeLike +) +from .typing import FeatureNames, FeatureTypes # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h c_bst_ulong = ctypes.c_uint64 # xgboost accepts some other possible types in practice due to historical reason, which # is lesser tested. For now we encourage users to pass a simple list of string. -FeatureNames = Optional[List[str]] Parameters = Union[List[Tuple[str, Any]], Dict[str, Any]] @@ -798,15 +800,15 @@ def set_uint_info(self, field: str, data: ArrayLike) -> None: from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'uint32') - def save_binary(self, fname: os.PathLike, silent: bool = True) -> None: + def save_binary(self, fname: PathLike, silent: bool = True) -> None: """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded by providing the path to :py:func:`xgboost.DMatrix` as input. Parameters ---------- - fname : string or os.PathLike + fname : Name of the output buffer file. - silent : bool (optional; default: True) + silent : If set, the output is suppressed. """ fname_str = os.fspath(os.path.expanduser(fname)) @@ -1314,7 +1316,7 @@ def __init__( Parameters for boosters. cache : list List of cache items. - model_file : string/os.PathLike/Booster/bytearray + model_file : Path to the model file if it's string or PathLike. """ cache = cache if cache is not None else [] @@ -2152,7 +2154,7 @@ def inplace_predict( "Data type:" + str(type(data)) + " not supported by inplace prediction." ) - def save_model(self, fname: Union[str, os.PathLike]) -> None: + def save_model(self, fname: PathLike) -> None: """Save the model to a file. 
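A round-trip sketch (the file name is illustrative):

.. code-block:: python

    bst.save_model("model.json")
    bst2 = xgboost.Booster()
    bst2.load_model("model.json")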
The model is saved in an XGBoost internal format which is universal among the @@ -2169,7 +2171,7 @@ def save_model(self, fname: Union[str, os.PathLike]) -> None: Parameters ---------- - fname : string or os.PathLike + fname : Output file name """ @@ -2271,7 +2273,7 @@ def num_features(self) -> int: def dump_model( self, fout: Union[os.PathLike, TextIO, str], - fmap: Union[str, os.PathLike] = '', + fmap: PathLike = '', with_stats: bool = False, dump_format: str = "text" ) -> None: @@ -2281,13 +2283,13 @@ def dump_model( Parameters ---------- - fout : string or os.PathLike + fout : Output file name. - fmap : string or os.PathLike, optional + fmap : Name of the file containing feature map names. - with_stats : bool, optional + with_stats : Controls whether the split statistics are output. - dump_format : string, optional + dump_format : Format of model dump file. Can be 'text' or 'json'. """ if isinstance(fout, (STRING_TYPES, os.PathLike)): @@ -2314,7 +2316,7 @@ def dump_model( def get_dump( self, - fmap: Union[str, os.PathLike] = "", + fmap: PathLike = "", with_stats: bool = False, dump_format: str = "text" ) -> List[str]: @@ -2344,9 +2346,7 @@ def get_dump( res = from_cstr_to_pystr(sarr, length) return res - def get_fscore( - self, fmap: Union[str, os.PathLike] = "" - ) -> Dict[str, Union[float, List[float]]]: + def get_fscore(self, fmap: PathLike = "") -> Dict[str, Union[float, List[float]]]: """Get feature importance of each feature. .. note:: Zero-importance features will not be included @@ -2363,7 +2363,7 @@ def get_fscore( return self.get_score(fmap, importance_type='weight') def get_score( - self, fmap: Union[str, os.PathLike] = '', importance_type: str = 'weight' + self, fmap: PathLike = '', importance_type: str = 'weight' ) -> Dict[str, Union[float, List[float]]]: """Get feature importance of each feature. For tree model Importance type can be defined as: @@ -2431,7 +2431,7 @@ def get_score( return results # pylint: disable=too-many-statements - def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame: + def trees_to_dataframe(self, fmap: PathLike = '') -> DataFrame: """Parse a boosted tree model text dump into a pandas DataFrame structure. This feature is only defined when the decision tree model is chosen as base @@ -2440,7 +2440,7 @@ def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame: Parameters ---------- - fmap: str or os.PathLike (optional) + fmap : The name of feature map file. 
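For instance, a sketch assuming a trained tree booster:

.. code-block:: python

    df = bst.trees_to_dataframe()
    print(df[["Tree", "Feature", "Split", "Gain"]].head())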
""" # pylint: disable=too-many-locals @@ -2572,7 +2572,7 @@ def _validate_features(self, data: DMatrix) -> None: def get_split_value_histogram( self, feature: str, - fmap: Union[os.PathLike, str] = '', + fmap: PathLike = '', bins: Optional[int] = None, as_pandas: bool = True ) -> Union[np.ndarray, DataFrame]: diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 67265c3e41bc..4d5fb5aa39d3 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -57,7 +57,7 @@ from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter from .core import Objective, Metric from .core import _deprecate_positional_args -from .data import FeatNamesT +from .typing import FeatureNames, FeatureTypes from .training import train as worker_train from .tracker import RabitTracker, get_host_ip from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase @@ -326,7 +326,7 @@ def __init__( base_margin: Optional[_DaskCollection] = None, missing: float = None, silent: bool = False, # pylint: disable=unused-argument - feature_names: FeatNamesT = None, + feature_names: FeatureNames = None, feature_types: Optional[List[str]] = None, group: Optional[_DaskCollection] = None, qid: Optional[_DaskCollection] = None, @@ -602,8 +602,8 @@ def __init__( qid: Optional[List[Any]] = None, label_lower_bound: Optional[List[Any]] = None, label_upper_bound: Optional[List[Any]] = None, - feature_names: FeatNamesT = None, - feature_types: Optional[Union[Any, List[Any]]] = None, + feature_names: FeatureNames = None, + feature_types: FeatureTypes = None, ) -> None: self._data = data self._label = label @@ -645,7 +645,7 @@ def next(self, input_data: Callable) -> int: if self._iter == len(self._data): # Return 0 when there's no more batch. 
return 0 - feature_names: FeatNamesT = None + feature_names: FeatureNames = None if self._feature_names: feature_names = self._feature_names else: @@ -696,8 +696,8 @@ def __init__( base_margin: Optional[_DaskCollection] = None, missing: float = None, silent: bool = False, # disable=unused-argument - feature_names: FeatNamesT = None, - feature_types: Optional[Union[Any, List[Any]]] = None, + feature_names: FeatureNames = None, + feature_types: FeatureTypes = None, max_bin: int = 256, group: Optional[_DaskCollection] = None, qid: Optional[_DaskCollection] = None, @@ -733,8 +733,8 @@ def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]: def _create_device_quantile_dmatrix( - feature_names: FeatNamesT, - feature_types: Optional[Union[Any, List[Any]]], + feature_names: FeatureNames, + feature_types: FeatureTypes, feature_weights: Optional[Any], missing: float, nthread: int, @@ -774,8 +774,8 @@ def _create_device_quantile_dmatrix( def _create_dmatrix( - feature_names: FeatNamesT, - feature_types: Optional[Union[Any, List[Any]]], + feature_names: FeatureNames, + feature_types: FeatureTypes, feature_weights: Optional[Any], missing: float, nthread: int, diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 8b10d777fa10..607ebf7eae93 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -16,6 +16,7 @@ from .typing import ArrayLike, FloatT, CSRLike, NPArrayLike, DFLike, NativeInput from .typing import CuArrayLike, CuDFLike, FeatureTypes, DTypeLike +from .typing import PathLike c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name @@ -255,7 +256,7 @@ def _invalid_dataframe_dtype(data: Any) -> None: def _transform_pandas_df( df: DFLike, enable_categorical: bool, - feature_names: Optional[List[str]] = None, + feature_names: FeatureNames = None, feature_types: FeatureTypes = None, meta: str = None, meta_type: DTypeLike = None, @@ -408,7 +409,7 @@ def _is_dt_df(data: NativeInput) -> bool: def _transform_dt_df( data: Any, - feature_names: Optional[List[str]], + feature_names: FeatureNames, feature_types: FeatureTypes, meta: str = None, meta_type: DTypeLike = None, @@ -449,7 +450,7 @@ def _from_dt_df( data: Any, missing: FloatT, nthread: int, - feature_names: Optional[List[str]], + feature_names: FeatureNames, feature_types: FeatureTypes, enable_categorical: bool, ) -> _CtorReturnT: @@ -666,7 +667,7 @@ def _from_cupy_array( data: CuArrayLike, missing: FloatT, nthread: int, - feature_names: Optional[List[str]], + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: """Initialize DMatrix from cupy ndarray.""" @@ -714,7 +715,7 @@ def _from_dlpack( data: Any, missing: FloatT, nthread: int, - feature_names: Optional[List[str]], + feature_names: FeatureNames, feature_types: FeatureTypes, ) -> _CtorReturnT: data = _transform_dlpack(data) @@ -726,7 +727,7 @@ def _is_uri(data: NativeInput) -> bool: def _from_uri( - data: Union[os.PathLike, str], + data: PathLike, missing: FloatT, feature_names: FeatureNames, feature_types: FeatureTypes, @@ -833,7 +834,7 @@ def dispatch_data_backend( ) if _is_uri(data): return _from_uri( - cast(Union[os.PathLike, str], data), missing, feature_names, feature_types + cast(PathLike, data), missing, feature_names, feature_types ) if _is_list(data): return _from_list( diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 6ed7edd7d1f1..3eee2439f218 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -13,7 +13,7 @@ from 
.training import train from .callback import TrainingCallback from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array -from .typing import array_like, CuDFLike +from .typing import ArrayLike, CuDFLike # Do not use class names on scikit-learn directly. Re-define the classes on # .compat to guarantee the behavior without scikit-learn @@ -851,19 +851,19 @@ def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> Non @_deprecate_positional_args def fit( self, - X: array_like, - y: array_like, + X: ArrayLike, + y: ArrayLike, *, - sample_weight: Optional[array_like] = None, - base_margin: Optional[array_like] = None, - eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + sample_weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, + eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = True, xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None, - sample_weight_eval_set: Optional[Sequence[array_like]] = None, - base_margin_eval_set: Optional[Sequence[array_like]] = None, - feature_weights: Optional[array_like] = None, + sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, + base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, + feature_weights: Optional[ArrayLike] = None, callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBModel": # pylint: disable=invalid-name,attribute-defined-outside-init @@ -990,11 +990,11 @@ def _get_iteration_range( def predict( self, - X: array_like, + X: ArrayLike, output_margin: bool = False, ntree_limit: Optional[int] = None, validate_features: bool = True, - base_margin: Optional[array_like] = None, + base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> np.ndarray: """Predict with `X`. 
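A sketch of typical use, assuming a fitted estimator and in-memory arrays:

.. code-block:: python

    reg = xgboost.XGBRegressor(n_estimators=8)
    reg.fit(X_train, y_train)
    predt = reg.predict(X_test)                       # shape (n_samples,)
    margin = reg.predict(X_test, output_margin=True)  # raw scores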
If the model is trained with early stopping, then `best_iteration` @@ -1066,7 +1066,7 @@ def predict( ) def apply( - self, X: array_like, + self, X: ArrayLike, ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None ) -> np.ndarray: @@ -1306,19 +1306,19 @@ def __init__( @_deprecate_positional_args def fit( self, - X: array_like, - y: array_like, + X: ArrayLike, + y: ArrayLike, *, - sample_weight: Optional[array_like] = None, - base_margin: Optional[array_like] = None, - eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + sample_weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, + eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = True, xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[Sequence[array_like]] = None, - base_margin_eval_set: Optional[Sequence[array_like]] = None, - feature_weights: Optional[array_like] = None, + sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, + base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, + feature_weights: Optional[ArrayLike] = None, callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBClassifier": # pylint: disable = attribute-defined-outside-init,too-many-statements @@ -1414,11 +1414,11 @@ def fit( def predict( self, - X: array_like, + X: ArrayLike, output_margin: bool = False, ntree_limit: Optional[int] = None, validate_features: bool = True, - base_margin: Optional[array_like] = None, + base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> np.ndarray: class_probs = super().predict( @@ -1453,10 +1453,10 @@ def predict( def predict_proba( self, - X: array_like, + X: ArrayLike, ntree_limit: Optional[int] = None, validate_features: bool = True, - base_margin: Optional[array_like] = None, + base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> np.ndarray: """ Predict the probability of each `X` example being of a given class. 
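A shape sketch, assuming ``y_train`` contains ``k`` distinct classes:

.. code-block:: python

    clf = xgboost.XGBClassifier(n_estimators=8)
    clf.fit(X_train, y_train)
    proba = clf.predict_proba(X_test)   # shape (n_samples, k)
    assert proba.shape[1] == clf.n_classes_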
@@ -1547,19 +1547,19 @@ def get_num_boosting_rounds(self) -> int: @_deprecate_positional_args def fit( self, - X: array_like, - y: array_like, + X: ArrayLike, + y: ArrayLike, *, - sample_weight: Optional[array_like] = None, - base_margin: Optional[array_like] = None, - eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + sample_weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, + eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = True, xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[Sequence[array_like]] = None, - base_margin_eval_set: Optional[Sequence[array_like]] = None, - feature_weights: Optional[array_like] = None, + sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, + base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, + feature_weights: Optional[ArrayLike] = None, callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBRFClassifier": args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} @@ -1619,19 +1619,19 @@ def get_num_boosting_rounds(self) -> int: @_deprecate_positional_args def fit( self, - X: array_like, - y: array_like, + X: ArrayLike, + y: ArrayLike, *, - sample_weight: Optional[array_like] = None, - base_margin: Optional[array_like] = None, - eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + sample_weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, + eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = True, xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[Sequence[array_like]] = None, - base_margin_eval_set: Optional[Sequence[array_like]] = None, - feature_weights: Optional[array_like] = None, + sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, + base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, + feature_weights: Optional[ArrayLike] = None, callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBRFRegressor": args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} @@ -1694,23 +1694,23 @@ def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any): @_deprecate_positional_args def fit( self, - X: array_like, - y: array_like, + X: ArrayLike, + y: ArrayLike, *, - group: Optional[array_like] = None, - qid: Optional[array_like] = None, - sample_weight: Optional[array_like] = None, - base_margin: Optional[array_like] = None, - eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, - eval_group: Optional[Sequence[array_like]] = None, - eval_qid: Optional[Sequence[array_like]] = None, + group: Optional[ArrayLike] = None, + qid: Optional[ArrayLike] = None, + sample_weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, + eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, + eval_group: Optional[Sequence[ArrayLike]] = None, + eval_qid: Optional[Sequence[ArrayLike]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = False, xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[Sequence[array_like]] = None, - 
base_margin_eval_set: Optional[Sequence[array_like]] = None, - feature_weights: Optional[array_like] = None, + sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, + base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, + feature_weights: Optional[ArrayLike] = None, callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBRanker": # pylint: disable = attribute-defined-outside-init,arguments-differ diff --git a/python-package/xgboost/typing.py b/python-package/xgboost/typing.py index 8db0c27ac0af..941ac8b7f97a 100644 --- a/python-package/xgboost/typing.py +++ b/python-package/xgboost/typing.py @@ -126,5 +126,7 @@ def __getitem__(self, key: str) -> "CuDFLike": ArrayLike = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike] NativeInput = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike, str, os.PathLike] - +FeatureNames = Optional[List[str]] FeatureTypes = Optional[Union[List[str], List[DTypeLike], str]] + +PathLike = Union[str, os.PathLike] From 94e95b3df47536a738a30aca699cd531ef3b27bf Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 21:39:17 +0800 Subject: [PATCH 10/17] Cleanup. --- python-package/xgboost/data.py | 4 +--- python-package/xgboost/typing.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 607ebf7eae93..9d34bbb42dc2 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -833,9 +833,7 @@ def dispatch_data_backend( cast(np.ndarray, data), missing, threads, feature_names, feature_types ) if _is_uri(data): - return _from_uri( - cast(PathLike, data), missing, feature_names, feature_types - ) + return _from_uri(cast(PathLike, data), missing, feature_names, feature_types) if _is_list(data): return _from_list( cast(list, data), missing, threads, feature_names, feature_types diff --git a/python-package/xgboost/typing.py b/python-package/xgboost/typing.py index 941ac8b7f97a..1c63431d5b20 100644 --- a/python-package/xgboost/typing.py +++ b/python-package/xgboost/typing.py @@ -4,16 +4,18 @@ from typing import SupportsIndex, Sized import numpy as np + try: from numpy import typing as npt + DTypeLike = npt.DTypeLike except ImportError: - DTypeLike = np.dtype # type: ignore + DTypeLike = Union[np.dtype, str] # type: ignore try: from typing import Protocol except ImportError: - Protocol = object # type: ignore + Protocol = object # type: ignore class NPArrayLike(Protocol, Sized): @@ -35,7 +37,9 @@ def dtype(self) -> np.dtype: class CuArrayLike(Protocol): @abstractproperty - def __cuda_array_interface__(self) -> Dict[str, Union[str, int, Dict, "CuArrayLike"]]: + def __cuda_array_interface__( + self, + ) -> Dict[str, Union[str, int, Dict, "CuArrayLike"]]: ... @abstractproperty @@ -93,7 +97,9 @@ def values(self) -> CuArrayLike: ... @abstractproperty - def __cuda_array_interface__(self) -> Dict[str, Union[str, int, Dict, "CuArrayLike"]]: + def __cuda_array_interface__( + self, + ) -> Dict[str, Union[str, int, Dict, "CuArrayLike"]]: ... 
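    # Because these are typing Protocols, conformance is structural: any
    # object exposing the listed members matches, so neither cudf nor cupy
    # has to be imported just to annotate against them.  An illustrative
    # sketch, not part of this patch:
    #
    #     def to_interface(arr: "CuArrayLike") -> bytes:
    #         return bytes(json.dumps(arr.__cuda_array_interface__), "utf-8")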
@abstractproperty @@ -124,7 +130,9 @@ def __getitem__(self, key: str) -> "CuDFLike": ArrayLike = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike] -NativeInput = Union[NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike, str, os.PathLike] +NativeInput = Union[ + NPArrayLike, DFLike, CuArrayLike, CuDFLike, CSRLike, str, os.PathLike +] FeatureNames = Optional[List[str]] FeatureTypes = Optional[Union[List[str], List[DTypeLike], str]] From 3bb4e14f6d413d81c3585963048f6ff6c7bef0b4 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 21:57:54 +0800 Subject: [PATCH 11/17] Compat and config. --- .github/workflows/main.yml | 2 +- python-package/xgboost/compat.py | 38 ++++++++++++++++------------ python-package/xgboost/config.py | 24 ++++++++++++------ python-package/xgboost/core.py | 1 + python-package/xgboost/plotting.py | 40 ++++++++++++++++++++++-------- 5 files changed, 70 insertions(+), 35 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d7ec12c78cd0..ebbb629c8b2d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -159,7 +159,7 @@ jobs: architecture: 'x64' - name: Install Python packages run: | - python -m pip install wheel setuptools mypy pandas dask[complete] distributed + python -m pip install wheel setuptools mypy pandas dask[complete] distributed types-setuptools - name: Run mypy run: | make mypy diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py index 256a77adf463..d92d659995d7 100644 --- a/python-package/xgboost/compat.py +++ b/python-package/xgboost/compat.py @@ -1,7 +1,7 @@ # coding: utf-8 # pylint: disable= invalid-name, unused-import """For compatibility and optional dependencies.""" -from typing import Any +from typing import Any, List, Optional import sys import types import importlib.util @@ -14,20 +14,20 @@ STRING_TYPES = (str,) -def py_str(x): +def py_str(x: bytes) -> str: """convert c string back to python string""" return x.decode('utf-8') -def lazy_isinstance(instance, module, name): +def lazy_isinstance(instance: Any, module: str, name: str) -> bool: """Use string representation to identify a type.""" # Notice, we use .__class__ as opposed to type() in order # to support object proxies such as weakref.proxy cls = instance.__class__ - module = cls.__module__ == module - name = cls.__name__ == name - return module and name + is_module = cls.__module__ == module + is_name = cls.__name__ == name + return is_module and is_name # pandas @@ -40,7 +40,7 @@ def lazy_isinstance(instance, module, name): except ImportError: MultiIndex = object - DataFrame: Any = object + DataFrame: Any = object # type: ignore Series = object pandas_concat = None PANDAS_INSTALLED = False @@ -67,7 +67,7 @@ def lazy_isinstance(instance, module, name): class XGBoostLabelEncoder(LabelEncoder): '''Label encoder with JSON serialization methods.''' - def to_json(self): + def to_json(self) -> dict: '''Returns a JSON compatible dictionary''' meta = {} for k, v in self.__dict__.items(): @@ -77,7 +77,7 @@ def to_json(self): meta[k] = v return meta - def from_json(self, doc): + def from_json(self, doc: dict) -> None: # pylint: disable=attribute-defined-outside-init '''Load the encoder back from a JSON compatible dict.''' meta = {} @@ -97,7 +97,7 @@ def from_json(self, doc): XGBKFold = None XGBStratifiedKFold = None - XGBoostLabelEncoder = None + XGBoostLabelEncoder = None # type: ignore # dask @@ -116,7 +116,7 @@ def from_json(self, doc): SCIPY_INSTALLED = True except ImportError: scipy_sparse = False - 
scipy_csr: Any = object + scipy_csr: Any = object # type: ignore SCIPY_INSTALLED = False @@ -139,15 +139,21 @@ class LazyLoader(types.ModuleType): """Lazily import a module, mainly to avoid pulling in large dependencies. """ - def __init__(self, local_name, parent_module_globals, name, warning=None): + def __init__( + self, + local_name: str, + parent_module_globals: dict, + name: str, + warning: bool = None + ) -> None: self._local_name = local_name self._parent_module_globals = parent_module_globals self._warning = warning - self.module = None + self.module: Optional[types.ModuleType] = None super().__init__(name) - def _load(self): + def _load(self) -> types.ModuleType: """Load the module and insert it into the parent's globals.""" # Import the target module and insert it into the parent's namespace module = importlib.import_module(self.__name__) @@ -166,12 +172,12 @@ def _load(self): return module - def __getattr__(self, item): + def __getattr__(self, item: Any) -> Any: if not self.module: self.module = self._load() return getattr(self.module, item) - def __dir__(self): + def __dir__(self) -> List[str]: if not self.module: self.module = self._load() return dir(self.module) diff --git a/python-package/xgboost/config.py b/python-package/xgboost/config.py index 427ea4ea3915..2f9f7849164d 100644 --- a/python-package/xgboost/config.py +++ b/python-package/xgboost/config.py @@ -4,12 +4,19 @@ import json from contextlib import contextmanager from functools import wraps +from typing import Optional, Callable, Any, Dict, Generator from .core import _LIB, _check_call, c_str, py_str -def config_doc(*, header=None, extra_note=None, parameters=None, returns=None, - see_also=None): +def config_doc( + *, + header: Optional[str] = None, + extra_note: Optional[str] = None, + parameters: Optional[str] = None, + returns: Optional[str] = None, + see_also: Optional[str] = None +) -> Callable: """Decorator to format docstring for config functions. 
Parameters @@ -64,17 +71,17 @@ def config_doc(*, header=None, extra_note=None, parameters=None, returns=None, assert xgb.get_config()['verbosity'] == 2 # old value restored """ - def none_to_str(value): + def none_to_str(value: Optional[str]) -> str: return '' if value is None else value - def config_doc_decorator(func): + def config_doc_decorator(func: Callable) -> Callable: func.__doc__ = (doc_template.format(header=none_to_str(header), extra_note=none_to_str(extra_note)) + none_to_str(parameters) + none_to_str(returns) + none_to_str(common_example) + none_to_str(see_also)) @wraps(func) - def wrap(*args, **kwargs): + def wrap(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) return wrap return config_doc_decorator @@ -89,7 +96,7 @@ def wrap(*args, **kwargs): new_config: Dict[str, Any] Keyword arguments representing the parameters and their values """) -def set_config(**new_config): +def set_config(**new_config: Any) -> None: config = json.dumps(new_config) _check_call(_LIB.XGBSetGlobalConfig(c_str(config))) @@ -103,9 +110,10 @@ def set_config(**new_config): args: Dict[str, Any] The list of global parameters and their values """) -def get_config(): +def get_config() -> Dict[str, Any]: config_str = ctypes.c_char_p() _check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str))) + assert config_str.value is not None config = json.loads(py_str(config_str.value)) return config @@ -132,7 +140,7 @@ def get_config(): set_config: Set global XGBoost configuration get_config: Get current values of the global configuration """) -def config_context(**new_config): +def config_context(**new_config: Any) -> Generator: old_config = get_config().copy() set_config(**new_config) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index d5d6218b5a6e..d4d1ada8302d 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1578,6 +1578,7 @@ def attr(self, key: str) -> Optional[str]: _check_call(_LIB.XGBoosterGetAttr( self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success))) if success.value != 0: + assert ret.value is not None return py_str(ret.value) return None diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index 75159d10434e..9ac98f4053db 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -1,19 +1,24 @@ # pylint: disable=too-many-locals, too-many-arguments, invalid-name, # pylint: disable=too-many-branches -# coding: utf-8 """Plotting Library.""" from io import BytesIO import json +from typing import Union + import numpy as np + from .core import Booster from .sklearn import XGBModel +from .typing import PathLike -def plot_importance(booster, ax=None, height=0.2, - xlim=None, ylim=None, title='Feature importance', - xlabel='F score', ylabel='Features', fmap='', - importance_type='weight', max_num_features=None, - grid=True, show_values=True, **kwargs): +def plot_importance( + booster: Union[Booster, XGBModel], ax=None, height=0.2, + xlim=None, ylim=None, title='Feature importance', + xlabel='F score', ylabel='Features', fmap='', + importance_type='weight', max_num_features=None, + grid=True, show_values=True, **kwargs +): """Plot importance based on fitted trees. 
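A quick sketch, assuming a trained booster and that matplotlib is
installed:

.. code-block:: python

    import matplotlib.pyplot as plt

    xgboost.plot_importance(bst, importance_type="gain", max_num_features=10)
    plt.show()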
Parameters @@ -120,9 +125,17 @@ def plot_importance(booster, ax=None, height=0.2, return ax -def to_graphviz(booster, fmap='', num_trees=0, rankdir=None, - yes_color=None, no_color=None, - condition_node_params=None, leaf_node_params=None, **kwargs): +def to_graphviz( + booster: Union[Booster, XGBModel], + fmap: PathLike = "", + num_trees: int = 0, + rankdir=None, + yes_color=None, + no_color=None, + condition_node_params=None, + leaf_node_params=None, + **kwargs +): """Convert specified tree to graphviz instance. IPython can automatically plot the returned graphiz instance. Otherwise, you should call .render() method of the returned graphiz instance. @@ -212,7 +225,14 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir=None, return g -def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs): +def plot_tree( + booster: Union[Booster, XGBModel], + fmap: PathLike = "", + num_trees=0, + rankdir=None, + ax=None, + **kwargs +): """Plot specified tree. Parameters From a54a70d2a893bb7b3629ccba8fe1b1da9d88750a Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 22:02:31 +0800 Subject: [PATCH 12/17] Add to makefile. --- Makefile | 3 +++ python-package/xgboost/sklearn.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 1d86c2ed0b9c..acb1a0c74a62 100644 --- a/Makefile +++ b/Makefile @@ -92,6 +92,9 @@ endif mypy: cd python-package; \ mypy ./xgboost/dask.py && \ + mypy ./xgboost/config.py && \ + mypy ./xgboost/compat.py && \ + mypy ./xgboost/data.py && \ mypy ./xgboost/rabit.py && \ mypy ./xgboost/tracker.py && \ mypy ./xgboost/sklearn.py && \ diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 3eee2439f218..242088ce2471 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -695,7 +695,7 @@ def _get_type(self) -> str: return self._estimator_type # pylint: disable=no-member def save_model(self, fname: Union[str, os.PathLike]) -> None: - meta = {} + meta: Dict[str, Any] = {} for k, v in self.__dict__.items(): if k == '_le': meta['_le'] = self._le.to_json() From 917c074d9ad649952824bda8cfc88316e9b832d9 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 23:20:37 +0800 Subject: [PATCH 13/17] doc string. --- python-package/xgboost/core.py | 33 ++++++++++++++++-------------- python-package/xgboost/sklearn.py | 34 +++++++++++++++---------------- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index d4d1ada8302d..32d216a24bd9 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -540,16 +540,19 @@ def __init__( ) -> None: """Parameters ---------- - data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/ - dt.Frame/cudf.DataFrame/cupy.array/dlpack - Data source of DMatrix. + data : + Data source of DMatrix. It can be one of the : + + os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/ + cudf.DataFrame/cupy.array/dlpack + When data is string or os.PathLike type, it represents the path libsvm format txt file, csv file (by specifying uri parameter 'path_to_csv?format=csv'), or binary file that xgboost can read from. label : Label of the training data. - weight : ArrayLike + weight : Weight for each instance. .. note:: For ranking task, weights are per-group. @@ -559,34 +562,34 @@ def __init__( ordering of data points within each group, so it doesn't make sense to assign weights to individual data points. 
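For example, with two query groups of sizes 3 and 2 (a sketch; the
values are illustrative):

.. code-block:: python

    dtrain = xgboost.DMatrix(X, label=y, group=[3, 2], weight=[0.5, 2.0])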
- base_margin: ArrayLike + base_margin : Base margin used for boosting from existing model. - missing : float, optional + missing : Value in the input data which needs to be present as a missing value. If None, defaults to np.nan. - silent : boolean, optional + silent : Whether print messages during construction - feature_names : list, optional + feature_names : Set names for features. feature_types : Set types for features. When `enable_categorical` is set to `True`, string "c" represents categorical data type. - nthread : integer, optional + nthread : Number of threads to use for loading data when parallelization is applicable. If -1, uses maximum threads available on the system. - group : ArrayLike + group : Group size for all ranking group. - qid : ArrayLike + qid : Query ID for data samples, used for ranking. - label_lower_bound : ArrayLike + label_lower_bound : Lower bound for survival training. - label_upper_bound : ArrayLike + label_upper_bound : Upper bound for survival training. - feature_weights : ArrayLike, optional + feature_weights : Set feature weights for column sampling. - enable_categorical: boolean, optional + enable_categorical : .. versionadded:: 1.3.0 diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 242088ce2471..9ace4a674555 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -186,7 +186,7 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: allowed to interact with each other. See :doc:`tutorial ` for more information importance_type: Optional[str] - The feature importance type for the feature_importances\\_ property: + The feature importance type for the :py:attr:`feature_importances_` property: * For tree model, it's either "gain", "weight", "cover", "total_gain" or "total_cover". @@ -228,27 +228,27 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: Metric used for monitoring the training result and early stopping. It can be a string or list of strings as names of predefined metric in XGBoost (See - doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other - user defined metric that looks like `sklearn.metrics`. + :doc:`Parameter `), one of the metrics in :py:mod:`sklearn.metrics`, + or any other user defined metric that looks like :py:mod:`sklearn.metrics`. - If custom objective is also provided, then custom metric should implement the - corresponding reverse link function. + - If custom objective is also provided, then custom metric should implement the + corresponding reverse link function. - Unlike the `scoring` parameter commonly used in scikit-learn, when a callable - object is provided, it's assumed to be a cost function and by default XGBoost will - minimize the result during early stopping. + - Unlike the `scoring` parameter commonly used in scikit-learn, when a callable + object is provided, it's assumed to be a cost function and by default XGBoost + will minimize the result during early stopping. - For advanced usage on Early stopping like directly choosing to maximize instead of - minimize, see :py:obj:`xgboost.callback.EarlyStopping`. + - For advanced usage on Early stopping like directly choosing to maximize + instead of minimize, see :py:obj:`xgboost.callback.EarlyStopping`. - See :doc:`Custom Objective and Evaluation Metric ` - for more. + - See :doc:`Custom Objective and Evaluation Metric + ` for more. .. note:: - This parameter replaces `eval_metric` in :py:meth:`fit` method. 
The old one - receives un-transformed prediction regardless of whether custom objective is - being used. + This parameter replaces `eval_metric` in :py:meth:`fit` method. The old + one receives un-transformed prediction regardless of whether custom + objective is being used. .. code-block:: python @@ -1075,8 +1075,8 @@ def apply( Parameters ---------- - X : array_like, shape=[n_samples, n_features] - Input features matrix. + X : + Input features matrix, shape=[n_samples, n_features] iteration_range : See :py:meth:`predict`. From 30f90570095f9fb41a477baea1f8453a95380201 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 23:58:19 +0800 Subject: [PATCH 14/17] Doc. --- doc/conf.py | 6 +++++- doc/python/python_api.rst | 12 ++++++++++++ python-package/xgboost/sklearn.py | 15 ++++++++------- python-package/xgboost/typing.py | 1 + 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 53b2ba503915..e05374ce9c34 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # documentation build configuration file, created by # sphinx-quickstart on Thu Jul 23 19:40:08 2015. @@ -11,6 +10,7 @@ # # All configuration values have a default; values that are commented out # serve to show the default. +from __future__ import annotations from subprocess import call from sh.contrib import git import urllib.request @@ -98,6 +98,10 @@ } autodoc_typehints = "description" +autodoc_typehints_format = "short" +autodoc_preserve_defaults = True + +autodoc_type_aliases = {"ArrayLike": "xgboost.typing.ArrayLike"} graphviz_output_format = 'png' plot_formats = [('svg', 300), ('png', 100), ('hires.png', 300)] diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst index 9f077edbc0df..1da7b7d3984c 100644 --- a/doc/python/python_api.rst +++ b/doc/python/python_api.rst @@ -147,3 +147,15 @@ Dask API :members: :inherited-members: :show-inheritance: + +Types +----- +.. automodule:: xgboost.typing + +.. autoclass:: xgboost.typing.ArrayLike + +.. autoclass:: NativeInput + +.. autoclass:: FeatureNames + +.. autoclass:: FeatureTypes diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 9ace4a674555..90b105a4fdd1 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1,5 +1,6 @@ # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines """Scikit-Learn Wrapper interface for XGBoost.""" +from __future__ import annotations import copy import warnings import json @@ -876,7 +877,7 @@ def fit( Parameters ---------- X : - Feature matrix + | Feature matrix y : Labels sample_weight : @@ -888,11 +889,11 @@ def fit( metrics will be computed. Validation metrics will help us track the performance of the model. - eval_metric : str, list of str, or callable, optional + eval_metric : .. deprecated:: 1.6.0 Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead. - early_stopping_rounds : int + early_stopping_rounds : .. deprecated:: 1.6.0 Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params` instead. @@ -1465,14 +1466,14 @@ def predict_proba( Parameters ---------- - X : array_like + X : Feature matrix. - ntree_limit : int + ntree_limit : Deprecated, use `iteration_range` instead. - validate_features : bool + validate_features : When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. 
- base_margin : array_like + base_margin : Margin added to prediction. iteration_range : Specifies which layer of trees are used in prediction. For example, if a diff --git a/python-package/xgboost/typing.py b/python-package/xgboost/typing.py index 1c63431d5b20..129c63ad51af 100644 --- a/python-package/xgboost/typing.py +++ b/python-package/xgboost/typing.py @@ -1,3 +1,4 @@ +from __future__ import annotations from abc import abstractproperty import os from typing import Union, Dict, Iterable, List, Tuple, Optional From 4fc2fe347e699e6dac9b0e59f519b8741c97f654 Mon Sep 17 00:00:00 2001 From: fis Date: Sun, 20 Feb 2022 02:44:55 +0800 Subject: [PATCH 15/17] callbacks. --- python-package/xgboost/callback.py | 210 +++++++++++++++++------------ python-package/xgboost/core.py | 49 ++++++- python-package/xgboost/training.py | 47 +------ 3 files changed, 174 insertions(+), 132 deletions(-) diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 901724a67d00..a5b330693c6f 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -6,58 +6,63 @@ """ -from abc import ABC import collections import os import pickle -from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast -from typing import Sequence +from abc import ABC +from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, cast + import numpy from . import rabit -from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees -from .compat import STRING_TYPES - +from .core import ( + Booster, + DMatrix, + XGBoostError, + _get_booster_layer_trees, + _PackedBooster, +) _Score = Union[float, Tuple[float, float]] _ScoreList = Union[List[float], List[Tuple[float, float]]] +_Model = TypeVar("_Model", Booster, _PackedBooster) + + # pylint: disable=unused-argument class TrainingCallback(ABC): - '''Interface for training callback. + """Interface for training callback. .. versionadded:: 1.3.0 - ''' + """ EvalsLog = Dict[str, Dict[str, _ScoreList]] def __init__(self) -> None: pass - def before_training(self, model): - '''Run before training starts.''' + def before_training(self, model: _Model) -> _Model: + """Run before training starts.""" return model - def after_training(self, model): - '''Run after training is finished.''' + def after_training(self, model: _Model) -> _Model: + """Run after training is finished.""" return model - def before_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool: - '''Run before each iteration. Return True when training should stop.''' + def before_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool: + """Run before each iteration. Return True when training should stop.""" return False - def after_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool: - '''Run after each iteration. Return True when training should stop.''' + def after_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool: + """Run after each iteration. Return True when training should stop.""" return False def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]: # pylint: disable=invalid-name, too-many-locals - """Aggregate cross-validation results. 
- - """ + """Aggregate cross-validation results.""" cvmap: Dict[Tuple[int, str], List[float]] = {} idx = rlist[0].split()[0] for line in rlist: @@ -66,7 +71,7 @@ def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]: for metric_idx, it in enumerate(arr[1:]): if not isinstance(it, str): it = it.decode() - k, v = it.split(':') + k, v = it.split(":") if (metric_idx, k) not in cvmap: cvmap[(metric_idx, k)] = [] cvmap[(metric_idx, k)].append(float(v)) @@ -74,7 +79,7 @@ def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]: results = [] for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]): as_arr = numpy.array(s) - if not isinstance(msg, STRING_TYPES): + if not isinstance(msg, str): msg = msg.decode() mean, std = numpy.mean(as_arr), numpy.std(as_arr) results.extend([(name, mean, std)]) @@ -86,29 +91,30 @@ def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]: def _allreduce_metric(score: _ART) -> _ART: - '''Helper function for computing customized metric in distributed + """Helper function for computing customized metric in distributed environment. Not strictly correct as many functions don't use mean value as final result. - ''' + """ world = rabit.get_world_size() assert world != 0 if world == 1: return score if isinstance(score, tuple): # has mean and stdv raise ValueError( - 'xgboost.cv function should not be used in distributed environment.') + "xgboost.cv function should not be used in distributed environment." + ) arr = numpy.array([score]) arr = rabit.allreduce(arr, rabit.Op.SUM) / world return arr[0] class CallbackContainer: - '''A special internal callback for invoking a list of other callbacks. + """A special internal callback for invoking a list of other callbacks. .. versionadded:: 1.3.0 - ''' + """ EvalsLog = TrainingCallback.EvalsLog @@ -117,13 +123,15 @@ def __init__( callbacks: Sequence[TrainingCallback], metric: Callable = None, output_margin: bool = True, - is_cv: bool = False + is_cv: bool = False, ) -> None: self.callbacks = set(callbacks) if metric is not None: - msg = 'metric must be callable object for monitoring. For ' + \ - 'builtin metrics, passing them in training parameter' + \ - ' will invoke monitor automatically.' + msg = ( + "metric must be callable object for monitoring. For " + + "builtin metrics, passing them in training parameter" + + " will invoke monitor automatically." 
+ ) assert callable(metric), msg self.metric = metric self.history: TrainingCallback.EvalsLog = collections.OrderedDict() @@ -133,32 +141,35 @@ def __init__( if self.is_cv: self.aggregated_cv = None - def before_training(self, model): - '''Function called before training.''' + def before_training(self, model: _Model) -> _Model: + """Function called before training.""" for c in self.callbacks: model = c.before_training(model=model) - msg = 'before_training should return the model' + msg = "before_training should return the model" if self.is_cv: + assert isinstance(model, _PackedBooster) assert isinstance(model.cvfolds, list), msg else: assert isinstance(model, Booster), msg return model - def after_training(self, model): - '''Function called after training.''' + def after_training(self, model: _Model) -> _Model: + """Function called after training.""" for c in self.callbacks: model = c.after_training(model=model) - msg = 'after_training should return the model' + msg = "after_training should return the model" if self.is_cv: + assert isinstance(model, _PackedBooster) assert isinstance(model.cvfolds, list), msg else: assert isinstance(model, Booster), msg if not self.is_cv: + assert not isinstance(model, _PackedBooster) num_parallel_tree, _ = _get_booster_layer_trees(model) - if model.attr('best_score') is not None: - model.best_score = float(cast(str, model.attr('best_score'))) - model.best_iteration = int(cast(str, model.attr('best_iteration'))) + if model.attr("best_score") is not None: + model.best_score = float(cast(str, model.attr("best_score"))) + model.best_iteration = int(cast(str, model.attr("best_iteration"))) # num_class is handled internally model.set_attr( best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree) @@ -175,16 +186,21 @@ def after_training(self, model): return model def before_iteration( - self, model, epoch: int, dtrain: DMatrix, evals: List[Tuple[DMatrix, str]] + self, + model: _Model, + epoch: int, + dtrain: DMatrix, + evals: List[Tuple[DMatrix, str]], ) -> bool: - '''Function called before training iteration.''' - return any(c.before_iteration(model, epoch, self.history) - for c in self.callbacks) + """Function called before training iteration.""" + return any( + c.before_iteration(model, epoch, self.history) for c in self.callbacks + ) def _update_history( self, score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]], - epoch: int + epoch: int, ) -> None: for d in score: name: str = d[0] @@ -194,9 +210,9 @@ def _update_history( x: _Score = (s, std) else: x = s - splited_names = name.split('-') + splited_names = name.split("-") data_name = splited_names[0] - metric_name = '-'.join(splited_names[1:]) + metric_name = "-".join(splited_names[1:]) x = _allreduce_metric(x) if data_name not in self.history: self.history[data_name] = collections.OrderedDict() @@ -213,30 +229,31 @@ def _update_history( def after_iteration( self, - model, + model: _Model, epoch: int, dtrain: DMatrix, evals: Optional[List[Tuple[DMatrix, str]]], ) -> bool: - '''Function called after training iteration.''' + """Function called after training iteration.""" if self.is_cv: + assert isinstance(model, _PackedBooster) scores = model.eval(epoch, self.metric, self._output_margin) scores = _aggcv(scores) self.aggregated_cv = scores self._update_history(scores, epoch) else: + assert not isinstance(model, _PackedBooster) evals = [] if evals is None else evals for _, name in evals: - assert name.find('-') == -1, 'Dataset name should not contain `-`' + assert name.find("-") == 
-1, "Dataset name should not contain `-`" score: str = model.eval_set(evals, epoch, self.metric, self._output_margin) splited = score.split()[1:] # into datasets # split up `test-error:0.1234` - metric_score_str = [tuple(s.split(':')) for s in splited] + metric_score_str = [tuple(s.split(":")) for s in splited] # convert to float metric_score = [(n, float(s)) for n, s in metric_score_str] self._update_history(metric_score, epoch) - ret = any(c.after_iteration(model, epoch, self.history) - for c in self.callbacks) + ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks) return ret @@ -269,7 +286,7 @@ def __init__( super().__init__() def after_iteration( - self, model, epoch: int, evals_log: TrainingCallback.EvalsLog + self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog ) -> bool: model.set_param("learning_rate", self.learning_rates(epoch)) return False @@ -313,6 +330,7 @@ class EarlyStopping(TrainingCallback): X, y = load_digits(return_X_y=True) clf.fit(X, y, eval_set=[(X, y)], callbacks=[es]) """ + def __init__( self, rounds: int, @@ -320,7 +338,7 @@ def __init__( data_name: Optional[str] = None, maximize: Optional[bool] = None, save_best: Optional[bool] = False, - min_delta: float = 0.0 + min_delta: float = 0.0, ) -> None: self.data = data_name self.metric_name = metric_name @@ -337,12 +355,12 @@ def __init__( self.starting_round: int = 0 super().__init__() - def before_training(self, model): + def before_training(self, model: _Model) -> _Model: self.starting_round = model.num_boosted_rounds() return model def _update_rounds( - self, score: _Score, name: str, metric: str, model, epoch: int + self, score: _Score, name: str, metric: str, model: _Model, epoch: int ) -> bool: def get_s(x: _Score) -> float: """get score if it's cross validation history.""" @@ -359,9 +377,17 @@ def minimize(new: _Score, best: _Score) -> bool: if self.maximize is None: # Just to be compatibility with old behavior before 1.3. We should let # user to decide. - maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@', - 'aucpr@', 'map@', 'ndcg@') - if metric != 'mape' and any(metric.startswith(x) for x in maximize_metrics): + maximize_metrics = ( + "auc", + "aucpr", + "map", + "ndcg", + "auc@", + "aucpr@", + "map@", + "ndcg@", + ) + if metric != "mape" and any(metric.startswith(x) for x in maximize_metrics): self.maximize = True else: self.maximize = False @@ -396,18 +422,19 @@ def minimize(new: _Score, best: _Score) -> bool: return True return False - def after_iteration(self, model, epoch: int, - evals_log: TrainingCallback.EvalsLog) -> bool: + def after_iteration( + self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog + ) -> bool: epoch += self.starting_round # training continuation - msg = 'Must have at least 1 validation dataset for early stopping.' + msg = "Must have at least 1 validation dataset for early stopping." assert len(evals_log.keys()) >= 1, msg - data_name = '' + data_name = "" if self.data: for d, _ in evals_log.items(): if d == self.data: data_name = d if not data_name: - raise ValueError('No dataset named:', self.data) + raise ValueError("No dataset named:", self.data) else: # Use the last one as default. 
@@ -269,7 +286,7 @@ def __init__(
         super().__init__()

     def after_iteration(
-        self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
     ) -> bool:
         model.set_param("learning_rate", self.learning_rates(epoch))
         return False
@@ -313,6 +330,7 @@ class EarlyStopping(TrainingCallback):
         X, y = load_digits(return_X_y=True)
         clf.fit(X, y, eval_set=[(X, y)], callbacks=[es])
     """
+
     def __init__(
         self,
         rounds: int,
@@ -320,7 +338,7 @@ def __init__(
         data_name: Optional[str] = None,
         maximize: Optional[bool] = None,
         save_best: Optional[bool] = False,
-        min_delta: float = 0.0
+        min_delta: float = 0.0,
     ) -> None:
         self.data = data_name
         self.metric_name = metric_name
@@ -337,12 +355,12 @@ def __init__(
         self.starting_round: int = 0
         super().__init__()

-    def before_training(self, model):
+    def before_training(self, model: _Model) -> _Model:
         self.starting_round = model.num_boosted_rounds()
         return model

     def _update_rounds(
-        self, score: _Score, name: str, metric: str, model, epoch: int
+        self, score: _Score, name: str, metric: str, model: _Model, epoch: int
     ) -> bool:
         def get_s(x: _Score) -> float:
             """get score if it's cross validation history."""
@@ -359,9 +377,17 @@ def minimize(new: _Score, best: _Score) -> bool:
         if self.maximize is None:
             # Just to be compatible with the old behavior before 1.3.  We should let
             # the user decide.
-            maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
-                                'aucpr@', 'map@', 'ndcg@')
-            if metric != 'mape' and any(metric.startswith(x) for x in maximize_metrics):
+            maximize_metrics = (
+                "auc",
+                "aucpr",
+                "map",
+                "ndcg",
+                "auc@",
+                "aucpr@",
+                "map@",
+                "ndcg@",
+            )
+            if metric != "mape" and any(metric.startswith(x) for x in maximize_metrics):
                 self.maximize = True
             else:
                 self.maximize = False
@@ -396,18 +422,19 @@ def minimize(new: _Score, best: _Score) -> bool:
             return True
         return False

-    def after_iteration(self, model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
         epoch += self.starting_round  # training continuation
-        msg = 'Must have at least 1 validation dataset for early stopping.'
+        msg = "Must have at least 1 validation dataset for early stopping."
         assert len(evals_log.keys()) >= 1, msg
-        data_name = ''
+        data_name = ""
         if self.data:
             for d, _ in evals_log.items():
                 if d == self.data:
                     data_name = d
             if not data_name:
-                raise ValueError('No dataset named:', self.data)
+                raise ValueError("No dataset named:", self.data)
         else:
             # Use the last one as default.
             data_name = list(evals_log.keys())[-1]
@@ -424,10 +451,12 @@ def after_iteration(self, model, epoch: int,
             score = data_log[metric_name][-1]
         return self._update_rounds(score, data_name, metric_name, model, epoch)

-    def after_training(self, model):
+    def after_training(self, model: _Model) -> _Model:
         try:
-            if self.save_best:
-                model = model[: int(model.attr("best_iteration")) + 1]
+            if self.save_best and isinstance(model, Booster):
+                best_iteration = model.attr("best_iteration")
+                assert best_iteration is not None
+                model = model[: int(best_iteration) + 1]
         except XGBoostError as e:
             raise XGBoostError(
                 "`save_best` is not applicable to current booster"
@@ -436,7 +465,7 @@ def after_training(self, model):


 class EvaluationMonitor(TrainingCallback):
-    '''Print the evaluation result at each iteration.
+    """Print the evaluation result at each iteration.

     .. versionadded:: 1.3.0

@@ -451,7 +480,8 @@ class EvaluationMonitor(TrainingCallback):
         How many epochs between printing.
     show_stdv :
         Used in cv to show standard deviation.  Users should not specify it.
-    '''
+    """
+
     def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
         self.printer_rank = rank
         self.show_stdv = show_stdv
@@ -470,12 +500,13 @@ def _fmt_metric(
             msg = f"\t{data + '-' + metric}:{score:.5f}"
         return msg

-    def after_iteration(self, model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
         if not evals_log:
             return False

-        msg: str = f'[{epoch}]'
+        msg: str = f"[{epoch}]"
         if rabit.get_rank() == self.printer_rank:
             for data, metric in evals_log.items():
                 for metric_name, log in metric.items():
@@ -486,7 +517,7 @@ def after_iteration(self, model, epoch: int,
                     else:
                         score = log[-1]
                     msg += self._fmt_metric(data, metric_name, score, stdv)
-            msg += '\n'
+            msg += "\n"

             if (epoch % self.period) == 0 or self.period == 1:
                 rabit.tracker_print(msg)
@@ -496,14 +527,14 @@ def after_iteration(self, model, epoch: int,
             self._latest = msg
         return False

-    def after_training(self, model):
+    def after_training(self, model: _Model) -> _Model:
         if rabit.get_rank() == self.printer_rank and self._latest is not None:
             rabit.tracker_print(self._latest)
         return model


 class TrainingCheckPoint(TrainingCallback):
-    '''Checkpointing operation.
+    """Checkpointing operation.

     .. versionadded:: 1.3.0

@@ -516,19 +547,20 @@ class TrainingCheckPoint(TrainingCallback):
         pattern of output model file.  Models will be saved as name_0.json, name_1.json,
         name_2.json ....
     as_pickle :
-        When set to True, all training parameters will be saved in pickle format, instead
-        of saving only the model.
+        When set to True, all training parameters will be saved in pickle format,
+        instead of saving only the model.
     iterations :
         Interval of checkpointing.  Checkpointing is slow so setting a larger number can
         reduce the performance hit.
-    '''
+    """
+
     def __init__(
         self,
         directory: Union[str, os.PathLike],
-        name: str = 'model',
+        name: str = "model",
         as_pickle: bool = False,
-        iterations: int = 100
+        iterations: int = 100,
     ) -> None:
         self._path = os.fspath(directory)
         self._name = name
@@ -537,15 +569,23 @@ def __init__(
         self._epoch = 0
         super().__init__()

-    def after_iteration(self, model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
+        if not isinstance(model, Booster):
+            raise ValueError("Checkpointing is not supported for cv.")
         if self._epoch == self._iterations:
-            path = os.path.join(self._path, self._name + '_' + str(epoch) +
-                                ('.pkl' if self._as_pickle else '.json'))
+            path = os.path.join(
+                self._path,
+                self._name
+                + "_"
+                + str(epoch)
+                + (".pkl" if self._as_pickle else ".json"),
+            )
             self._epoch = 0
             if rabit.get_rank() == 0:
                 if self._as_pickle:
-                    with open(path, 'wb') as fd:
+                    with open(path, "wb") as fd:
                         pickle.dump(model, fd)
                 else:
                     model.save_model(path)
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 32d216a24bd9..2f88c845946c 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -5,7 +5,7 @@
 # pylint: disable=no-name-in-module,import-error
 from collections.abc import Mapping
 from typing import List, Optional, Any, Union, Dict, TypeVar, Sequence, cast
-from typing import Callable, Tuple, Type, TextIO, cast
+from typing import Callable, Tuple, Type, TextIO
 import ctypes
 import os
 import re
@@ -23,8 +23,8 @@
                      lazy_isinstance)
 from .libpath import find_lib_path
 from .typing import (
-    CuArrayLike, CuDFLike, NPArrayLike, DFLike, CSRLike,
-    PathLike, NativeInput, ArrayLike, DTypeLike
+    CuArrayLike, CuDFLike, NPArrayLike, CSRLike, PathLike, NativeInput, ArrayLike,
+    DTypeLike
 )
 from .typing import FeatureNames, FeatureTypes
@@ -2642,3 +2642,46 @@ def get_split_value_histogram(
             UserWarning
         )
     return nph_stacked
+
+
+class _PackedBooster:
+    def __init__(self, cvfolds) -> None:
+        self.cvfolds = cvfolds
+
+    def update(self, iteration, obj):
+        '''Iterate through folds for update'''
+        for fold in self.cvfolds:
+            fold.update(iteration, obj)
+
+    def eval(self, iteration, feval, output_margin):
+        '''Iterate through folds for eval'''
+        result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
+        return result
+
+    def set_attr(self, **kwargs):
+        '''Iterate through folds for setting attributes'''
+        for f in self.cvfolds:
+            f.bst.set_attr(**kwargs)
+
+    def attr(self, key):
+        '''Redirect to booster attr.'''
+        return self.cvfolds[0].bst.attr(key)
+
+    def set_param(self, params, value=None):
+        """Iterate through folds for set_param"""
+        for f in self.cvfolds:
+            f.bst.set_param(params, value)
+
+    def num_boosted_rounds(self):
+        '''Number of boosted rounds.'''
+        return self.cvfolds[0].num_boosted_rounds()
+
+    @property
+    def best_iteration(self):
+        '''Get best_iteration'''
+        return int(self.cvfolds[0].bst.attr("best_iteration"))
+
+    @property
+    def best_score(self):
+        """Get best_score."""
+        return float(self.cvfolds[0].bst.attr("best_score"))
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index 9066124a4cd6..cf57a22040c6 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -8,7 +8,9 @@
 from typing import Optional, Dict, Any, Union, Tuple, Sequence

 import numpy as np
-from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
+from .core import (
+    Booster, DMatrix, XGBoostError, _deprecate_positional_args, _PackedBooster
+)
 from .core import Metric, Objective
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import callback
@@ -207,49 +209,6 @@ def eval(self, iteration, feval, output_margin):
         return self.bst.eval_set(self.watchlist, iteration, feval, output_margin)


-class _PackedBooster:
-    def __init__(self, cvfolds) -> None:
-        self.cvfolds = cvfolds
-
-    def update(self, iteration, obj):
-        '''Iterate through folds for update'''
-        for fold in self.cvfolds:
-            fold.update(iteration, obj)
-
-    def eval(self, iteration, feval, output_margin):
-        '''Iterate through folds for eval'''
-        result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
-        return result
-
-    def set_attr(self, **kwargs):
-        '''Iterate through folds for setting attributes'''
-        for f in self.cvfolds:
-            f.bst.set_attr(**kwargs)
-
-    def attr(self, key):
-        '''Redirect to booster attr.'''
-        return self.cvfolds[0].bst.attr(key)
-
-    def set_param(self, params, value=None):
-        """Iterate through folds for set_param"""
-        for f in self.cvfolds:
-            f.bst.set_param(params, value)
-
-    def num_boosted_rounds(self):
-        '''Number of boosted rounds.'''
-        return self.cvfolds[0].num_boosted_rounds()
-
-    @property
-    def best_iteration(self):
-        '''Get best_iteration'''
-        return int(self.cvfolds[0].bst.attr("best_iteration"))
-
-    @property
-    def best_score(self):
-        """Get best_score."""
-        return float(self.cvfolds[0].bst.attr("best_score"))
-
-
 def groups_to_rows(groups, boundaries):
     """
     Given group row boundaries, convert group indexes to row indexes

From a595f982dc53b6da68a5189c8d768e02c7a38f14 Mon Sep 17 00:00:00 2001
From: fis
Date: Sun, 20 Feb 2022 03:24:39 +0800
Subject: [PATCH 16/17] Complete.

---
 .github/workflows/main.yml         |   5 +-
 Makefile                           |  20 -----
 python-package/xgboost/callback.py |  10 +--
 python-package/xgboost/core.py     |  53 ++++++++++---
 python-package/xgboost/plotting.py |  47 +++++++-----
 python-package/xgboost/training.py | 119 +++++++++++++++++------------
 6 files changed, 146 insertions(+), 108 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index ebbb629c8b2d..b28e771c5b8d 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -155,14 +155,15 @@ jobs:
           submodules: 'true'
       - uses: actions/setup-python@v2
         with:
-          python-version: '3.7'
+          python-version: '3.8'
           architecture: 'x64'
       - name: Install Python packages
         run: |
          python -m pip install wheel setuptools mypy pandas dask[complete] distributed types-setuptools
       - name: Run mypy
         run: |
-          make mypy
+          cd python-package
+          mypy .

   doxygen:
     runs-on: ubuntu-latest
diff --git a/Makefile b/Makefile
index acb1a0c74a62..8bc136283e09 100644
--- a/Makefile
+++ b/Makefile
@@ -86,26 +86,6 @@ cover: check
 	)
 endif

-
-# dask is required to pass, others are not
-# If any of the dask tests failed, contributor won't see the other error.
-mypy:
-	cd python-package; \
-	mypy ./xgboost/dask.py && \
-	mypy ./xgboost/config.py && \
-	mypy ./xgboost/compat.py && \
-	mypy ./xgboost/data.py && \
-	mypy ./xgboost/rabit.py && \
-	mypy ./xgboost/tracker.py && \
-	mypy ./xgboost/sklearn.py && \
-	mypy ../demo/guide-python/external_memory.py && \
-	mypy ../demo/guide-python/categorical.py && \
-	mypy ../demo/guide-python/cat_in_the_dat.py && \
-	mypy ../tests/python-gpu/test_gpu_with_dask.py && \
-	mypy ../tests/python/test_data_iterator.py && \
-	mypy ../tests/python-gpu/test_gpu_data_iterator.py || exit 1; \
-	mypy . || true ;
-
 clean:
 	$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o #xgboost
 	$(RM) -rf build_tests *.gcov tests/cpp/xgboost_test
diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py
index a5b330693c6f..0f576a7906e2 100644
--- a/python-package/xgboost/callback.py
+++ b/python-package/xgboost/callback.py
@@ -139,7 +139,7 @@ def __init__(
         self.is_cv = is_cv

         if self.is_cv:
-            self.aggregated_cv = None
+            self.aggregated_cv: Optional[List[Tuple[str, float, float]]] = None

     def before_training(self, model: _Model) -> _Model:
         """Function called before training."""
@@ -190,7 +190,7 @@ def before_iteration(
         model: _Model,
         epoch: int,
         dtrain: DMatrix,
-        evals: List[Tuple[DMatrix, str]],
+        evals: Optional[List[Tuple[DMatrix, str]]],
     ) -> bool:
         """Function called before training iteration."""
         return any(
@@ -238,9 +238,9 @@ def after_iteration(
         if self.is_cv:
             assert isinstance(model, _PackedBooster)
             scores = model.eval(epoch, self.metric, self._output_margin)
-            scores = _aggcv(scores)
-            self.aggregated_cv = scores
-            self._update_history(scores, epoch)
+            scores_cv = _aggcv(scores)
+            self.aggregated_cv = scores_cv
+            self._update_history(scores_cv, epoch)
         else:
             assert not isinstance(model, _PackedBooster)
             evals = [] if evals is None else evals
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 2f88c845946c..6de98ccea08d 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -1307,7 +1307,7 @@ class Booster:

     def __init__(
         self,
-        params: Optional[Dict] = None,
+        params: Parameters = None,
         cache: Optional[Sequence[DMatrix]] = None,
         model_file: Optional[Union["Booster", bytearray, os.PathLike, str]] = None
     ) -> None:
@@ -2644,44 +2644,73 @@ def get_split_value_histogram(
     return nph_stacked


+class CVPack:
+    """Auxiliary data structure to hold one fold of CV."""
+    def __init__(self, dtrain: DMatrix, dtest: DMatrix, param: Parameters) -> None:
+        """Initialize the CVPack."""
+        self.dtrain = dtrain
+        self.dtest = dtest
+        self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
+        self.bst = Booster(param, [dtrain, dtest])
+
+    def __getattr__(self, name: str) -> Any:
+        def _inner(*args: Any, **kwargs: Any) -> Any:
+            return getattr(self.bst, name)(*args, **kwargs)
+        return _inner
+
+    def update(self, iteration: int, fobj: Optional[Objective]) -> None:
+        """Update the boosters for one iteration."""
+        self.bst.update(self.dtrain, iteration, fobj)
+
+    def eval(self, iteration: int, feval: Optional[Metric], output_margin: bool) -> str:
+        """Evaluate the CVPack for one iteration."""
+        return self.bst.eval_set(self.watchlist, iteration, feval, output_margin)
+
+
 class _PackedBooster:
-    def __init__(self, cvfolds) -> None:
+    def __init__(self, cvfolds: List[CVPack]) -> None:
         self.cvfolds = cvfolds

-    def update(self, iteration, obj):
+    def update(self, iteration: int, obj: Optional[Objective]) -> None:
         '''Iterate through folds for update'''
         for fold in self.cvfolds:
             fold.update(iteration, obj)

-    def eval(self, iteration, feval, output_margin):
+    def eval(
+        self, iteration: int, feval: Optional[Metric], output_margin: bool
+    ) -> List[str]:
         '''Iterate through folds for eval'''
         result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
         return result

-    def set_attr(self, **kwargs):
+    def set_attr(self, **kwargs: Any) -> None:
         '''Iterate through folds for setting attributes'''
         for f in self.cvfolds:
             f.bst.set_attr(**kwargs)

-    def attr(self, key):
+    def attr(self, key: str) -> Any:
         '''Redirect to booster attr.'''
         return self.cvfolds[0].bst.attr(key)

-    def set_param(self, params, value=None):
+    def set_param(self, params: Union[str, Parameters], value: Any = None) -> None:
         """Iterate through folds for set_param"""
         for f in self.cvfolds:
             f.bst.set_param(params, value)

-    def num_boosted_rounds(self):
+    def num_boosted_rounds(self) -> int:
         '''Number of boosted rounds.'''
         return self.cvfolds[0].num_boosted_rounds()

     @property
-    def best_iteration(self):
+    def best_iteration(self) -> int:
         '''Get best_iteration'''
-        return int(self.cvfolds[0].bst.attr("best_iteration"))
+        best_iteration = self.cvfolds[0].bst.attr("best_iteration")
+        assert best_iteration is not None
+        return int(best_iteration)

     @property
-    def best_score(self):
+    def best_score(self) -> float:
         """Get best_score."""
-        return float(self.cvfolds[0].bst.attr("best_score"))
+        best_score = self.cvfolds[0].bst.attr("best_score")
+        assert best_score is not None
+        return float(best_score)
diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py
index 9ac98f4053db..3cc35456cf54 100644
--- a/python-package/xgboost/plotting.py
+++ b/python-package/xgboost/plotting.py
@@ -3,7 +3,7 @@
 """Plotting Library."""
 from io import BytesIO
 import json
-from typing import Union
+from typing import Union, Optional, Any, Dict

 import numpy as np
@@ -13,12 +13,21 @@

 def plot_importance(
-    booster: Union[Booster, XGBModel], ax=None, height=0.2,
-    xlim=None, ylim=None, title='Feature importance',
-    xlabel='F score', ylabel='Features', fmap='',
-    importance_type='weight', max_num_features=None,
-    grid=True, show_values=True, **kwargs
-):
+    booster: Union[Booster, XGBModel],
+    ax: Any = None,
+    height: float = 0.2,
+    xlim: Optional[tuple] = None,
+    ylim: Optional[tuple] = None,
+    title: str = 'Feature importance',
+    xlabel: str = 'F score',
+    ylabel: str = 'Features',
+    fmap: PathLike = '',
+    importance_type: str = 'weight',
+    max_num_features: Optional[int] = None,
+    grid: bool = True,
+    show_values: bool = True,
+    **kwargs: Any
+) -> Any:
     """Plot importance based on fitted trees.

     Parameters
@@ -129,13 +138,13 @@ def to_graphviz(
     booster: Union[Booster, XGBModel],
     fmap: PathLike = "",
     num_trees: int = 0,
-    rankdir=None,
-    yes_color=None,
-    no_color=None,
-    condition_node_params=None,
-    leaf_node_params=None,
-    **kwargs
-):
+    rankdir: Optional[str] = None,
+    yes_color: Optional[str] = None,
+    no_color: Optional[str] = None,
+    condition_node_params: Optional[Dict[str, str]] = None,
+    leaf_node_params: Optional[Dict[str, str]] = None,
+    **kwargs: Any
+) -> Any:
     """Convert specified tree to graphviz instance. IPython can automatically plot
     the returned graphviz instance. Otherwise, you should call .render() method
     of the returned graphviz instance.
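
A quick usage sketch for the two typed helpers above (toy random data; assumes matplotlib and the graphviz Python package are installed, and "tree0" is only a hypothetical output name):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(100, 5)
    y = np.random.randint(2, size=100)
    booster = xgb.train(
        {"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=10
    )

    ax = xgb.plot_importance(booster, max_num_features=3)  # returns a matplotlib Axes
    graph = xgb.to_graphviz(booster, num_trees=0, rankdir="LR")
    graph.render("tree0")  # writes tree0.pdf through graphviz

Both helpers take `Union[Booster, XGBModel]`, so a fitted scikit-learn estimator works in place of the raw `booster` as well.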
@@ -228,11 +237,11 @@ def plot_tree(
     booster: Union[Booster, XGBModel],
     fmap: PathLike = "",
-    num_trees=0,
-    rankdir=None,
-    ax=None,
-    **kwargs
-):
+    num_trees: int = 0,
+    rankdir: Optional[str] = None,
+    ax: Any = None,
+    **kwargs: Any
+) -> Any:
     """Plot specified tree.

     Parameters
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index cf57a22040c6..b8b6dadf3a75 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -5,12 +5,13 @@
 import copy
 import os
 import warnings
-from typing import Optional, Dict, Any, Union, Tuple, Sequence
+from typing import Optional, Dict, Any, Union, Tuple, Sequence, List, Callable

 import numpy as np
 from .core import (
-    Booster, DMatrix, XGBoostError, _deprecate_positional_args, _PackedBooster
+    Booster, DMatrix, XGBoostError, _deprecate_positional_args, _PackedBooster, CVPack
 )
+from .core import Parameters
 from .core import Metric, Objective
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import callback
@@ -49,7 +50,7 @@ def _configure_custom_metric(

 @_deprecate_positional_args
 def train(
-    params: Dict[str, Any],
+    params: Parameters,
     dtrain: DMatrix,
     num_boost_round: int = 10,
     *,
@@ -186,30 +187,7 @@ def train(
     return bst.copy()


-class CVPack:
-    """"Auxiliary datastruct to hold one fold of CV."""
-    def __init__(self, dtrain, dtest, param):
-        """"Initialize the CVPack"""
-        self.dtrain = dtrain
-        self.dtest = dtest
-        self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
-        self.bst = Booster(param, [dtrain, dtest])
-
-    def __getattr__(self, name):
-        def _inner(*args, **kwargs):
-            return getattr(self.bst, name)(*args, **kwargs)
-        return _inner
-
-    def update(self, iteration, fobj):
-        """"Update the boosters for one iteration"""
-        self.bst.update(self.dtrain, iteration, fobj)
-
-    def eval(self, iteration, feval, output_margin):
-        """"Evaluate the CVPack for one iteration."""
-        return self.bst.eval_set(self.watchlist, iteration, feval, output_margin)
-
-
-def groups_to_rows(groups, boundaries):
+def groups_to_rows(groups: np.ndarray, boundaries: np.ndarray) -> np.ndarray:
     """
     Given group row boundaries, convert group indexes to row indexes
     :param groups: list of groups for testing
@@ -219,7 +197,14 @@ def groups_to_rows(groups, boundaries):
     return np.concatenate([np.arange(boundaries[g], boundaries[g+1]) for g in groups])


-def mkgroupfold(dall, nfold, param, evals=(), fpreproc=None, shuffle=True):
+def mkgroupfold(
+    dall: DMatrix,
+    nfold: int,
+    param: Parameters,
+    evals: Union[str, List[str]],
+    fpreproc: Optional[Callable] = None,
+    shuffle: bool = True
+) -> List[CVPack]:
     """
     Make n folds for cross-validation maintaining groups
     :return: cross-validation folds
@@ -235,11 +220,16 @@ def mkgroupfold(
     # list by fold of test group indexes
     out_group_idset = np.array_split(idx, nfold)
     # list by fold of train group indexes
-    in_group_idset = [np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
-                      for k in range(nfold)]
+    in_group_idset: List[np.ndarray] = [
+        np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
+        for k in range(nfold)
+    ]
     # from the group indexes, convert them to row indexes
-    in_idset = [groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset]
-    out_idset = [groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset]
+    in_idset = [
+        groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset
+    ]
+    out_idset = [
+        groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset
+    ]

     # build the folds by taking the appropriate slices
     ret = []
@@ -259,18 +249,29 @@ def mkgroupfold(
     return ret


-def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
-            folds=None, shuffle=True):
+def mknfold(
+    dall: DMatrix,
+    nfold: int,
+    param: Parameters,
+    seed: int,
+    evals: Optional[Union[str, List[str]]] = None,
+    fpreproc: Optional[Callable] = None,
+    stratified: bool = False,
+    folds: Any = None,
+    shuffle: bool = True
+) -> List[CVPack]:
     """
     Make an n-fold list of CVPack from random indices.
     """
-    evals = list(evals)
+    evals = list(evals if evals is not None else [])
     np.random.seed(seed)

     if stratified is False and folds is None:
         # Do standard k-fold cross validation. Automatically determine the folds.
         if len(dall.get_uint_info('group_ptr')) > 1:
-            return mkgroupfold(dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle)
+            return mkgroupfold(
+                dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle
+            )

         if shuffle is True:
             idx = np.random.permutation(dall.num_row())
@@ -313,11 +314,27 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
     return ret


-def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
-       metrics=(), obj: Optional[Objective] = None,
-       feval=None, maximize=None, early_stopping_rounds=None,
-       fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
-       seed=0, callbacks=None, shuffle=True, custom_metric: Optional[Metric] = None):
+def cv(
+    params: Parameters,
+    dtrain: DMatrix,
+    num_boost_round: int = 10,
+    nfold: int = 3,
+    stratified: bool = False,
+    folds: Any = None,
+    metrics: Optional[Union[str, List[str]]] = None,
+    obj: Optional[Objective] = None,
+    feval: Optional[Metric] = None,
+    maximize: Optional[bool] = None,
+    early_stopping_rounds: Optional[int] = None,
+    fpreproc: Optional[Callable] = None,
+    as_pandas: bool = True,
+    verbose_eval: Optional[Union[bool, int]] = None,
+    show_stdv: bool = True,
+    seed: int = 0,
+    callbacks: Optional[List[callback.TrainingCallback]] = None,
+    shuffle: bool = True,
+    custom_metric: Optional[Metric] = None
+) -> Dict[str, List[float]]:
     # pylint: disable = invalid-name
     """Cross-validation with given parameters.
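
Under the new annotations, `metrics` accepts a single name or a list of names, and with `as_pandas=False` the result is the plain `Dict[str, List[float]]` promised by the return type. A small sketch with toy data:

    import numpy as np
    import xgboost as xgb

    dtrain = xgb.DMatrix(np.random.rand(200, 4), label=np.random.randint(2, size=200))
    res = xgb.cv(
        {"objective": "binary:logistic", "max_depth": 2},
        dtrain,
        num_boost_round=5,
        nfold=3,
        metrics=["logloss"],
        as_pandas=False,
        seed=0,
    )
    print(res["test-logloss-mean"])  # one aggregated value per boosting round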
@@ -419,9 +436,10 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None

     params.pop("eval_metric", None)

-    results = {}
-    cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc,
-                      stratified, folds, shuffle)
+    results: Dict[str, List[float]] = {}
+    cvfolds = mknfold(
+        dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds, shuffle
+    )

     metric_fn = _configure_custom_metric(feval, custom_metric)

@@ -438,7 +456,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
         callbacks.append(
             callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
         )
-    callbacks = callback.CallbackContainer(
+    cb_container = callback.CallbackContainer(
         callbacks,
         metric=metric_fn,
         is_cv=True,
@@ -446,15 +464,16 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
     )

     booster = _PackedBooster(cvfolds)
-    callbacks.before_training(booster)
+    cb_container.before_training(booster)

     for i in range(num_boost_round):
-        if callbacks.before_iteration(booster, i, dtrain, None):
+        if cb_container.before_iteration(booster, i, dtrain, None):
             break
         booster.update(i, obj)

-        should_break = callbacks.after_iteration(booster, i, dtrain, None)
-        res = callbacks.aggregated_cv
+        should_break = cb_container.after_iteration(booster, i, dtrain, None)
+        res = cb_container.aggregated_cv
+        assert res is not None
         for key, mean, std in res:
             if key + '-mean' not in results:
                 results[key + '-mean'] = []
@@ -474,6 +493,6 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
     except ImportError:
         pass

-    callbacks.after_training(booster)
+    cb_container.after_training(booster)

     return results

From 9fdc69397c999fff2130b53c776da92673cda23e Mon Sep 17 00:00:00 2001
From: fis
Date: Sun, 20 Feb 2022 03:30:27 +0800
Subject: [PATCH 17/17] Any.

---
 python-package/xgboost/sklearn.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 90b105a4fdd1..cf9746f65254 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -997,7 +997,7 @@ def predict(
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
         iteration_range: Optional[Tuple[int, int]] = None,
-    ) -> np.ndarray:
+    ) -> Any:
         """Predict with `X`.  If the model is trained with early stopping, then `best_iteration`
         is used automatically.  For tree models, when data is on GPU, like cupy array or
         cuDF dataframe and `predictor` is not specified, the prediction is run on GPU
@@ -1070,7 +1070,7 @@ def apply(
         self, X: ArrayLike,
         ntree_limit: int = 0,
         iteration_range: Optional[Tuple[int, int]] = None
-    ) -> np.ndarray:
+    ) -> Any:
         """Return the predicted leaf every tree for each sample. If the model is trained
         with early stopping, then `best_iteration` is used automatically.
@@ -1421,7 +1421,7 @@ def predict(
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
         iteration_range: Optional[Tuple[int, int]] = None,
-    ) -> np.ndarray:
+    ) -> Any:
         class_probs = super().predict(
             X=X,
             output_margin=output_margin,
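
PATCH 17 widens the scikit-learn wrappers' prediction return annotations from `np.ndarray` to `Any`, presumably because the wrapped `Booster.predict` path cannot promise a NumPy array for every input and predictor combination (GPU inputs such as cupy being one case). The plain NumPy path behaves as before; a toy sketch using only public API:

    import numpy as np
    from xgboost import XGBClassifier

    X = np.random.rand(50, 4)
    y = np.random.randint(2, size=50)
    clf = XGBClassifier(n_estimators=5).fit(X, y)

    preds = clf.predict(X)  # statically Any, at runtime an np.ndarray here
    assert isinstance(preds, np.ndarray)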