Cleanup some pylint errors. (#7667)
* Cleanup some pylint errors.

* Cleanup pylint errors in rabit modules.
* Make data iter an abstract class and cleanup private access.
* Cleanup no-self-use for booster.
trivialfis authored Feb 19, 2022
1 parent b76c5d5 commit f08c5dc
Showing 5 changed files with 118 additions and 104 deletions.
62 changes: 31 additions & 31 deletions python-package/xgboost/core.py
@@ -1,11 +1,9 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-lines, too-many-locals, no-self-use
# pylint: disable=too-many-lines, too-many-locals
"""Core XGBoost Library."""
# pylint: disable=no-name-in-module,import-error
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import List, Optional, Any, Union, Dict, TypeVar
# pylint: enable=no-name-in-module,import-error
from typing import Callable, Tuple, cast, Sequence
import ctypes
import os
@@ -123,9 +121,8 @@ def _log_callback(msg: bytes) -> None:

def _get_log_callback_func() -> Callable:
"""Wrap log_callback() method in ctypes callback type"""
# pylint: disable=invalid-name
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
return CALLBACK(_log_callback)
c_callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
return c_callback(_log_callback)


def _load_lib() -> ctypes.CDLL:
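For context on the pattern in this hunk: ctypes.CFUNCTYPE builds a C function-pointer type, and instantiating it with a Python callable produces an object that C code can invoke. A minimal standalone sketch, separate from the commit (names here are illustrative):

import ctypes

# A C function-pointer type: returns void, takes a single char* argument.
LogFunc = ctypes.CFUNCTYPE(None, ctypes.c_char_p)

def demo_log(msg: bytes) -> None:
    print(msg.decode("utf-8"))

# Wrapping the callable yields an object invokable from C
# (or directly from Python, which is handy for testing).
c_callback = LogFunc(demo_log)
c_callback(b"hello")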
@@ -311,7 +308,7 @@ def _prediction_output(shape, dims, predts, is_cuda):
return arr_predict


class DataIter: # pylint: disable=too-many-instance-attributes
class DataIter(ABC): # pylint: disable=too-many-instance-attributes
"""The interface for user defined data iterator.
Parameters
@@ -333,9 +330,10 @@ def __init__(self, cache_prefix: Optional[str] = None) -> None:
# Stage data in Python until reset or next is called to avoid data being free.
self._temporary_data: Optional[Tuple[Any, Any]] = None

def _get_callbacks(
def get_callbacks(
self, allow_host: bool, enable_categorical: bool
) -> Tuple[Callable, Callable]:
"""Get callback functions for iterating in C."""
assert hasattr(self, "cache_prefix"), "__init__ is not called."
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
self._reset_wrapper
@@ -369,7 +367,8 @@ def _handle_exception(self, fn: Callable, dft_ret: _T) -> _T:
self._exception = e.with_traceback(tb)
return dft_ret

def _reraise(self) -> None:
def reraise(self) -> None:
"""Reraise the exception thrown during iteration."""
self._temporary_data = None
if self._exception is not None:
# pylint 2.7.0 believes `self._exception` can be None even with `assert
@@ -424,10 +423,12 @@ def data_handle(
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(data_handle), 0)

@abstractmethod
def reset(self) -> None:
"""Reset the data iterator. Prototype for user defined function."""
raise NotImplementedError()

@abstractmethod
def next(self, input_data: Callable) -> int:
"""Set the next batch of data.
@@ -642,8 +643,7 @@ def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
}
args = from_pystr_to_cstr(json.dumps(args))
handle = ctypes.c_void_p()
# pylint: disable=protected-access
reset_callback, next_callback = it._get_callbacks(
reset_callback, next_callback = it.get_callbacks(
True, enable_categorical
)
ret = _LIB.XGDMatrixCreateFromCallback(
Expand All @@ -654,8 +654,7 @@ def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
args,
ctypes.byref(handle),
)
# pylint: disable=protected-access
it._reraise()
it.reraise()
# delay check_call to throw intermediate exception first
_check_call(ret)
self.handle = handle
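Taken together with the @abstractmethod hunk above, user code now subclasses DataIter and must implement next and reset, while the formerly private hooks are invoked as the public get_callbacks and reraise. A minimal hypothetical subclass, assuming in-memory (X, y) batches (not part of the commit; the class name is invented):

from xgboost.core import DataIter

class InMemoryIter(DataIter):  # hypothetical example subclass
    def __init__(self, batches):
        self._batches = batches  # a list of (X, y) array pairs
        self._pos = 0
        super().__init__()

    def next(self, input_data) -> int:
        # Feed one batch through the input_data callback; 0 signals exhaustion.
        if self._pos == len(self._batches):
            return 0
        X, y = self._batches[self._pos]
        input_data(data=X, label=y)
        self._pos += 1
        return 1

    def reset(self) -> None:
        # Rewind to the first batch for the next iteration pass.
        self._pos = 0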
@@ -1225,8 +1224,7 @@ def _init(self, data, enable_categorical: bool, **meta) -> None:
it = SingleBatchInternalIter(data=data, **meta)

handle = ctypes.c_void_p()
# pylint: disable=protected-access
reset_callback, next_callback = it._get_callbacks(False, enable_categorical)
reset_callback, next_callback = it.get_callbacks(False, enable_categorical)
if it.cache_prefix is not None:
raise ValueError(
"DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix "
Expand All @@ -1242,8 +1240,7 @@ def _init(self, data, enable_categorical: bool, **meta) -> None:
ctypes.c_int(self.max_bin),
ctypes.byref(handle),
)
# pylint: disable=protected-access
it._reraise()
it.reraise()
# delay check_call to throw intermediate exception first
_check_call(ret)
self.handle = handle
@@ -1281,6 +1278,21 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
return num_parallel_tree, num_groups


def _configure_metrics(params: Union[Dict, List]) -> Union[Dict, List]:
if (
isinstance(params, dict)
and "eval_metric" in params
and isinstance(params["eval_metric"], list)
):
params = dict((k, v) for k, v in params.items())
eval_metrics = params["eval_metric"]
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [("eval_metric", eval_metric)]
return params
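
Traced on a small input, the relocated helper flattens a dict whose eval_metric value is a list into a list of pairs, so the eval_metric key can repeat (a sketch, not part of the diff):

params = {"eta": 0.1, "eval_metric": ["auc", "logloss"]}
_configure_metrics(params)
# -> [("eta", 0.1), ("eval_metric", "auc"), ("eval_metric", "logloss")]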


class Booster:
# pylint: disable=too-many-public-methods
"""A Booster of XGBoost.
@@ -1339,7 +1351,7 @@ def __init__(
raise TypeError('Unknown type:', model_file)

params = params or {}
params = self._configure_metrics(params.copy())
params = _configure_metrics(params.copy())
params = self._configure_constraints(params)
if isinstance(params, list):
params.append(('validate_parameters', True))
@@ -1352,17 +1364,6 @@
else:
self.booster = 'gbtree'

def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
if isinstance(params, dict) and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
return params

def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str:
if isinstance(value, str):
return value
@@ -1395,7 +1396,6 @@ def _transform_interaction_constraints(
)
return s + "]"
except KeyError as e:
# pylint: disable=raise-missing-from
raise ValueError(
"Constrained features are not a subset of training data feature names"
) from e
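
The # pylint: disable=raise-missing-from suppression could be dropped because the raise already chains the original exception with from e; the pattern in isolation:

try:
    {"f0": 0}["f1"]
except KeyError as e:
    # Chaining with `from e` preserves the original traceback and
    # satisfies pylint's raise-missing-from check.
    raise ValueError("feature is not in the training data") from e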
4 changes: 2 additions & 2 deletions python-package/xgboost/dask.py
@@ -171,15 +171,15 @@ def _try_start_tracker(
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_context = RabitTracker(
hostIP=get_host_ip(host_ip),
host_ip=get_host_ip(host_ip),
n_workers=n_workers,
port=port,
use_logger=False,
)
else:
assert isinstance(addrs[0], str) or addrs[0] is None
rabit_context = RabitTracker(
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
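
For callers constructing the tracker directly, the renamed keyword looks like this (values illustrative; RabitTracker comes from the package's tracker module, as imported by dask.py):

from xgboost.tracker import RabitTracker

tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2, use_logger=False)
envs = tracker.worker_envs()  # rabit environment variables for the workers
tracker.start(2)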
22 changes: 11 additions & 11 deletions python-package/xgboost/rabit.py
@@ -1,7 +1,6 @@
# coding: utf-8
# pylint: disable= invalid-name
"""Distributed XGBoost Rabit related API."""
import ctypes
from enum import IntEnum, unique
import pickle
from typing import Any, TypeVar, Callable, Optional, cast, List, Union

@@ -98,7 +97,7 @@ def get_processor_name() -> bytes:
return buf.value


T = TypeVar("T")
T = TypeVar("T") # pylint:disable=invalid-name


def broadcast(data: T, root: int) -> T:
@@ -152,34 +151,35 @@ def broadcast(data: T, root: int) -> T:
}


class Op: # pylint: disable=too-few-public-methods
@unique
class Op(IntEnum):
'''Supported operations for rabit.'''
MAX = 0
MIN = 1
SUM = 2
OR = 3


def allreduce(
data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
def allreduce( # pylint:disable=invalid-name
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
"""Perform allreduce, return the result.
Parameters
----------
data: numpy array
data :
Input data.
op: int
op :
Reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun: function
prepare_fun :
Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to initialize the data
If the result of Allreduce can be recovered directly,
then prepare_fun will NOT be called
Returns
-------
result : array_like
result :
The result of allreduce, have same shape as data
Notes
@@ -196,7 +196,7 @@ def allreduce(
if prepare_fun is None:
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
op, None, None))
int(op), None, None))
else:
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

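Since Op is now an IntEnum, the int(op) in the call above is a no-op for enum members while still tolerating plain integers. A hedged usage sketch; allreduce is only meaningful inside an initialized rabit world (a single process forms a trivial world of size one):

import numpy as np
from xgboost import rabit

rabit.init()
values = np.asarray([1.0, 2.0, 3.0])
total = rabit.allreduce(values, rabit.Op.SUM)  # int(Op.SUM) == 2
rabit.finalize()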