diff --git a/python-package/xgboost/spark/__init__.py b/python-package/xgboost/spark/__init__.py
index c0b347eefb30..7c18eeba46b5 100644
--- a/python-package/xgboost/spark/__init__.py
+++ b/python-package/xgboost/spark/__init__.py
@@ -1,4 +1,3 @@
-# type: ignore
 """PySpark XGBoost integration interface
 """
 
diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py
index fc6bbbc9d3c3..2fe113ad4631 100644
--- a/python-package/xgboost/spark/estimator.py
+++ b/python-package/xgboost/spark/estimator.py
@@ -1,11 +1,12 @@
-# type: ignore
 """Xgboost pyspark integration submodule for estimator API."""
 # pylint: disable=too-many-ancestors
+from typing import Any, Type
+
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
 
 from xgboost import XGBClassifier, XGBRanker, XGBRegressor
 
-from .core import (
+from .core import (  # type: ignore
     SparkXGBClassifierModel,
     SparkXGBRankerModel,
     SparkXGBRegressorModel,
@@ -95,19 +96,19 @@ class SparkXGBRegressor(_SparkXGBEstimator):
 
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         self.setParams(**kwargs)
 
     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBRegressor]:
         return XGBRegressor
 
     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]:
         return SparkXGBRegressorModel
 
-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if self.isDefined(self.qid_col):
             raise ValueError(
@@ -209,7 +210,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
 
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
         # but in pyspark we will automatically set objective param depending on
@@ -219,14 +220,14 @@ def __init__(self, **kwargs):
         self.setParams(**kwargs)
 
     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBClassifier]:
         return XGBClassifier
 
     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]:
         return SparkXGBClassifierModel
 
-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if self.isDefined(self.qid_col):
             raise ValueError(
@@ -342,19 +343,19 @@ class SparkXGBRanker(_SparkXGBEstimator):
     >>> model.transform(df_test).show()
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         self.setParams(**kwargs)
 
     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBRanker]:
         return XGBRanker
 
     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]:
         return SparkXGBRankerModel
 
-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if not self.isDefined(self.qid_col):
             raise ValueError(
diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py
index 2053b43fce87..77cfcd137d99 100644
--- a/python-package/xgboost/spark/params.py
+++ b/python-package/xgboost/spark/params.py
@@ -1,4 +1,3 @@
-# type: ignore
 """Xgboost pyspark integration submodule for params."""
 # pylint: disable=too-few-public-methods
 from pyspark.ml.param import TypeConverters
@@ -12,7 +11,7 @@ class HasArbitraryParamsDict(Params):
     input.
     """
 
-    arbitrary_params_dict = Param(
+    arbitrary_params_dict: Param[dict] = Param(
         Params._dummy(),
         "arbitrary_params_dict",
         "arbitrary_params_dict This parameter holds all of the additional parameters which are "
@@ -31,6 +30,7 @@ class HasBaseMarginCol(Params):
         Params._dummy(),
         "base_margin_col",
         "This stores the name for the column of the base margin",
+        typeConverter=TypeConverters.toString,
     )
 
 
@@ -47,7 +47,7 @@ class HasFeaturesCols(Params):
         typeConverter=TypeConverters.toListString,
     )
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._setDefault(features_cols=[])
 
@@ -69,7 +69,7 @@ class HasEnableSparseDataOptim(Params):
         typeConverter=TypeConverters.toBoolean,
     )
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._setDefault(enable_sparse_data_optim=False)
 
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index 79c040f2701f..36705459ae4f 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -1,29 +1,30 @@
-# type: ignore
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
 import json
 import logging
 import sys
 from threading import Thread
+from typing import Any, Callable, Dict, List, Set, Type
 
 import pyspark
+from pyspark import BarrierTaskContext, SparkContext
 from pyspark.sql.session import SparkSession
 
 from xgboost.tracker import RabitTracker
 from xgboost import collective
 
 
-def get_class_name(cls):
-    """
-    Return the class name.
-    """
+def get_class_name(cls: Type) -> str:
+    """Return the class name."""
     return f"{cls.__module__}.{cls.__name__}"
 
 
-def _get_default_params_from_func(func, unsupported_set):
-    """
-    Returns a dictionary of parameters and their default value of function fn.
-    Only the parameters with a default value will be included.
+def _get_default_params_from_func(
+    func: Callable, unsupported_set: Set[str]
+) -> Dict[str, Any]:
+    """Returns a dictionary of parameters and their default value of function fn. Only
+    the parameters with a default value will be included.
+
     """
     sig = inspect.signature(func)
     filtered_params_dict = {}
@@ -38,27 +39,26 @@ def _get_default_params_from_func(func, unsupported_set):
 
 
 class CommunicatorContext:
-    """
-    A context controlling collective communicator initialization and finalization.
-    This isn't specificially necessary (note Part 3), but it is more understandable coding-wise.
+    """A context controlling collective communicator initialization and finalization.
+    This isn't specificially necessary (note Part 3), but it is more understandable
+    coding-wise.
+
     """
 
-    def __init__(self, context, **args):
+    def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
         self.args = args
         self.args["DMLC_TASK_ID"] = str(context.partitionId())
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         collective.init(**self.args)
 
-    def __exit__(self, *args):
+    def __exit__(self, *args: Any) -> None:
         collective.finalize()
 
 
-def _start_tracker(context, n_workers):
-    """
-    Start Rabit tracker with n_workers
-    """
-    env = {"DMLC_NUM_WORKER": n_workers}
+def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+    """Start Rabit tracker with n_workers"""
+    env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
     host = _get_host_ip(context)
     rabit_context = RabitTracker(host_ip=host, n_workers=n_workers)
     env.update(rabit_context.worker_envs())
@@ -69,27 +69,20 @@ def _start_tracker(context, n_workers):
     return env
 
 
-def _get_rabit_args(context, n_workers):
-    """
-    Get rabit context arguments to send to each worker.
-    """
-    # pylint: disable=consider-using-f-string
+def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+    """Get rabit context arguments to send to each worker."""
     env = _start_tracker(context, n_workers)
     return env
 
 
-def _get_host_ip(context):
-    """
-    Gets the hostIP for Spark. This essentially gets the IP of the first worker.
-    """
+def _get_host_ip(context: BarrierTaskContext) -> str:
+    """Gets the hostIP for Spark. This essentially gets the IP of the first worker."""
     task_ip_list = [info.address.split(":")[0] for info in context.getTaskInfos()]
     return task_ip_list[0]
 
 
-def _get_args_from_message_list(messages):
-    """
-    A function to send/recieve messages in barrier context mode
-    """
+def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
+    """A function to send/recieve messages in barrier context mode"""
     output = ""
     for message in messages:
         if message != "":
@@ -98,8 +91,11 @@ def _get_args_from_message_list(messages):
     return json.loads(output)
 
 
-def _get_spark_session():
-    """Get or create spark session. Note: This function can only be invoked from driver side."""
+def _get_spark_session() -> SparkSession:
+    """Get or create spark session. Note: This function can only be invoked from driver
+    side.
+
+    """
     if pyspark.TaskContext.get() is not None:
         # This is a safety check.
         raise RuntimeError(
@@ -108,7 +104,7 @@ def _get_spark_session():
     return SparkSession.builder.getOrCreate()
 
 
-def get_logger(name, level="INFO"):
+def get_logger(name: str, level: str = "INFO") -> logging.Logger:
     """Gets a logger by name, or creates and configures it for the first time."""
     logger = logging.getLogger(name)
     logger.setLevel(level)
@@ -119,7 +115,7 @@ def get_logger(name, level="INFO"):
     return logger
 
 
-def _get_max_num_concurrent_tasks(spark_context):
+def _get_max_num_concurrent_tasks(spark_context: SparkContext) -> int:
     """Gets the current max number of concurrent tasks."""
     # pylint: disable=protected-access
     # spark 3.1 and above has a different API for fetching max concurrent tasks
@@ -130,13 +126,13 @@ def _get_max_num_concurrent_tasks(spark_context):
     return spark_context._jsc.sc().maxNumConcurrentTasks()
 
 
-def _is_local(spark_context) -> bool:
+def _is_local(spark_context: SparkContext) -> bool:
     """Whether it is Spark local mode"""
     # pylint: disable=protected-access
     return spark_context._jsc.sc().isLocal()
 
 
-def _get_gpu_id(task_context) -> int:
+def _get_gpu_id(task_context: BarrierTaskContext) -> int:
     """Get the gpu id from the task resources"""
     if task_context is None:
         # This is a safety check.
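
Not part of the patch: a minimal usage sketch of the estimator API annotated above, assuming a local PySpark session with xgboost installed. The toy DataFrame, column names, and parameter values are illustrative only; they mirror the doctest examples already present in estimator.py.

    from pyspark.ml.linalg import Vectors
    from pyspark.sql import SparkSession

    from xgboost.spark import SparkXGBClassifier

    spark = SparkSession.builder.master("local[2]").getOrCreate()

    # Toy data; "features" and "label" are the estimator's default column names.
    df_train = spark.createDataFrame(
        [
            (Vectors.dense(1.0, 2.0, 3.0), 0),
            (Vectors.dense(4.0, 5.0, 6.0), 1),
        ],
        ["features", "label"],
    )

    # __init__ accepts **kwargs: Any and forwards them to setParams().
    clf = SparkXGBClassifier(num_workers=1, max_depth=3)
    model = clf.fit(df_train)  # returns a SparkXGBClassifierModel
    model.transform(df_train).show()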