
[pyspark] Add type hint to basic utilities. #8375

Merged: 4 commits, Oct 25, 2022
python-package/xgboost/spark/__init__.py (1 change: 0 additions, 1 deletion)
@@ -1,4 +1,3 @@
-# type: ignore
 """PySpark XGBoost integration interface
 """

python-package/xgboost/spark/estimator.py (29 changes: 15 additions, 14 deletions)
@@ -1,11 +1,12 @@
-# type: ignore
 """Xgboost pyspark integration submodule for estimator API."""
 # pylint: disable=too-many-ancestors
+from typing import Any, Type
+
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
 
 from xgboost import XGBClassifier, XGBRanker, XGBRegressor
 
-from .core import (
+from .core import ( # type: ignore
     SparkXGBClassifierModel,
     SparkXGBRankerModel,
     SparkXGBRegressorModel,
@@ -95,19 +96,19 @@ class SparkXGBRegressor(_SparkXGBEstimator):

"""

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         self.setParams(**kwargs)

     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBRegressor]:
         return XGBRegressor

     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]:
         return SparkXGBRegressorModel

-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if self.isDefined(self.qid_col):
             raise ValueError(
@@ -209,7 +210,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):

"""

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
         # but in pyspark we will automatically set objective param depending on
@@ -219,14 +220,14 @@ def __init__(self, **kwargs):
         self.setParams(**kwargs)

     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBClassifier]:
         return XGBClassifier

     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]:
         return SparkXGBClassifierModel

-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if self.isDefined(self.qid_col):
             raise ValueError(
@@ -342,19 +343,19 @@ class SparkXGBRanker(_SparkXGBEstimator):
>>> model.transform(df_test).show()
"""

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         self.setParams(**kwargs)

     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBRanker]:
         return XGBRanker

     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]:
         return SparkXGBRankerModel

-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if not self.isDefined(self.qid_col):
             raise ValueError(
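With these annotations, `**kwargs` stays free-form (`Any`) while `_xgb_cls` and `_pyspark_model_cls` gain precise return types. A minimal usage sketch, assuming a local Spark session and a toy DataFrame (the column names are the estimator defaults, not part of this PR):

```python
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Toy data: the default "features" (vector) and "label" columns.
df_train = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)],
    ["features", "label"],
)

# Keyword arguments are typed as Any; setParams() still validates them at runtime.
clf = SparkXGBClassifier(n_estimators=10, num_workers=1)
model = clf.fit(df_train)
model.transform(df_train).show()
```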
python-package/xgboost/spark/params.py (8 changes: 4 additions, 4 deletions)
@@ -1,4 +1,3 @@
-# type: ignore
 """Xgboost pyspark integration submodule for params."""
 # pylint: disable=too-few-public-methods
 from pyspark.ml.param import TypeConverters
@@ -12,7 +11,7 @@ class HasArbitraryParamsDict(Params):
input.
"""

-    arbitrary_params_dict = Param(
+    arbitrary_params_dict: Param[dict] = Param(
         Params._dummy(),
         "arbitrary_params_dict",
         "arbitrary_params_dict This parameter holds all of the additional parameters which are "
@@ -31,6 +30,7 @@ class HasBaseMarginCol(Params):
         Params._dummy(),
         "base_margin_col",
         "This stores the name for the column of the base margin",
+        typeConverter=TypeConverters.toString,
     )


@@ -47,7 +47,7 @@ class HasFeaturesCols(Params):
typeConverter=TypeConverters.toListString,
)

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._setDefault(features_cols=[])

@@ -69,7 +69,7 @@ class HasEnableSparseDataOptim(Params):
typeConverter=TypeConverters.toBoolean,
)

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._setDefault(enable_sparse_data_optim=False)

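The `Param[dict]` annotation above follows the generic `Param[T]` pattern from pyspark's own type stubs. A sketch of a custom mixin in the same style (the class and param name here are hypothetical, purely for illustration):

```python
from pyspark.ml.param import Param, Params, TypeConverters


class HasMyThresholdCol(Params):
    """Hypothetical mixin holding a typed string Param, mirroring HasBaseMarginCol."""

    my_threshold_col: Param[str] = Param(
        Params._dummy(),
        "my_threshold_col",
        "Name of the column holding per-row thresholds",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super().__init__()
        self._setDefault(my_threshold_col="threshold")
```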
python-package/xgboost/spark/utils.py (74 changes: 35 additions, 39 deletions)
@@ -1,29 +1,30 @@
-# type: ignore
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
 import json
 import logging
 import sys
 from threading import Thread
+from typing import Any, Callable, Dict, List, Set, Type
 
 import pyspark
 from pyspark import BarrierTaskContext, SparkContext
+from pyspark.sql.session import SparkSession
 from xgboost.tracker import RabitTracker
 
 from xgboost import collective


-def get_class_name(cls):
-    """
-    Return the class name.
-    """
+def get_class_name(cls: Type) -> str:
+    """Return the class name."""
     return f"{cls.__module__}.{cls.__name__}"


-def _get_default_params_from_func(func, unsupported_set):
-    """
-    Returns a dictionary of parameters and their default value of function fn.
-    Only the parameters with a default value will be included.
+def _get_default_params_from_func(
+    func: Callable, unsupported_set: Set[str]
+) -> Dict[str, Any]:
+    """Returns a dictionary of parameters and their default values for function
+    ``func``. Only the parameters with a default value will be included.
+
     """
     sig = inspect.signature(func)
     filtered_params_dict = {}
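To see what this helper computes, here is a standalone sketch of the same idea (not the library code; `_example` is a hypothetical function):

```python
import inspect
from typing import Any, Callable, Dict, Set


def defaults_of(func: Callable, unsupported: Set[str]) -> Dict[str, Any]:
    """Collect keyword defaults from a signature, skipping unsupported names."""
    out: Dict[str, Any] = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is not param.empty and name not in unsupported:
            out[name] = param.default
    return out


def _example(a, b=1, *, c="x", kernel=None):  # hypothetical function
    pass


print(defaults_of(_example, {"kernel"}))  # {'b': 1, 'c': 'x'}
```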
@@ -38,27 +39,26 @@ def _get_default_params_from_func(func, unsupported_set):


 class CommunicatorContext:
-    """
-    A context controlling collective communicator initialization and finalization.
-    This isn't specificially necessary (note Part 3), but it is more understandable coding-wise.
+    """A context controlling collective communicator initialization and finalization.
+    This isn't specifically necessary (note Part 3), but it is more understandable
+    coding-wise.
+
     """

-    def __init__(self, context, **args):
+    def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
         self.args = args
         self.args["DMLC_TASK_ID"] = str(context.partitionId())
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         collective.init(**self.args)
 
-    def __exit__(self, *args):
+    def __exit__(self, *args: Any) -> None:
         collective.finalize()
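For orientation, the context manager is used inside a barrier-mode task roughly like this (a sketch; `_train_partition` and `rabit_args` are illustrative names, with `rabit_args` produced by the tracker helpers below):

```python
from pyspark import BarrierTaskContext


def _train_partition(rabit_args):  # hypothetical worker function
    context = BarrierTaskContext.get()
    # init()/finalize() of the collective are tied to the with-block;
    # CommunicatorContext is the class defined above.
    with CommunicatorContext(context, **rabit_args):
        ...  # per-worker training happens here while the communicator is live
```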


-def _start_tracker(context, n_workers):
-    """
-    Start Rabit tracker with n_workers
-    """
-    env = {"DMLC_NUM_WORKER": n_workers}
+def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+    """Start Rabit tracker with n_workers"""
+    env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
host = _get_host_ip(context)
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers)
env.update(rabit_context.worker_envs())
@@ -69,27 +69,20 @@ def _start_tracker(context, n_workers):
return env


-def _get_rabit_args(context, n_workers):
-    """
-    Get rabit context arguments to send to each worker.
-    """
-    # pylint: disable=consider-using-f-string
+def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+    """Get rabit context arguments to send to each worker."""
     env = _start_tracker(context, n_workers)
     return env


-def _get_host_ip(context):
-    """
-    Gets the hostIP for Spark. This essentially gets the IP of the first worker.
-    """
+def _get_host_ip(context: BarrierTaskContext) -> str:
+    """Gets the hostIP for Spark. This essentially gets the IP of the first worker."""
     task_ip_list = [info.address.split(":")[0] for info in context.getTaskInfos()]
     return task_ip_list[0]


-def _get_args_from_message_list(messages):
-    """
-    A function to send/recieve messages in barrier context mode
-    """
+def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
+    """A function to send/receive messages in barrier context mode"""
     output = ""
     for message in messages:
         if message != "":
@@ -98,8 +91,11 @@ def _get_args_from_message_list(messages):
     return json.loads(output)
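Taken together, the tracker helpers work roughly like this in barrier mode: partition 0 starts the tracker and serializes its env, then every partition exchanges messages through pyspark's `BarrierTaskContext.allGather`. A sketch of that flow (`_exchange_rabit_args` is a hypothetical glue function, for illustration):

```python
import json

from pyspark import BarrierTaskContext


def _exchange_rabit_args(n_workers):  # hypothetical, for illustration
    context = BarrierTaskContext.get()
    if context.partitionId() == 0:
        env = _get_rabit_args(context, n_workers)  # starts the tracker
        message = json.dumps(env)
    else:
        message = ""
    # Every worker receives the full message list; only one entry is non-empty.
    messages = context.allGather(message)
    return _get_args_from_message_list(messages)
```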


-def _get_spark_session():
-    """Get or create spark session. Note: This function can only be invoked from driver side."""
+def _get_spark_session() -> SparkSession:
+    """Get or create spark session. Note: This function can only be invoked from driver
+    side.
+
+    """
if pyspark.TaskContext.get() is not None:
# This is a safety check.
raise RuntimeError(
Expand All @@ -108,7 +104,7 @@ def _get_spark_session():
return SparkSession.builder.getOrCreate()


-def get_logger(name, level="INFO"):
+def get_logger(name: str, level: str = "INFO") -> logging.Logger:
"""Gets a logger by name, or creates and configures it for the first time."""
logger = logging.getLogger(name)
logger.setLevel(level)
Expand All @@ -119,7 +115,7 @@ def get_logger(name, level="INFO"):
return logger


-def _get_max_num_concurrent_tasks(spark_context):
+def _get_max_num_concurrent_tasks(spark_context: SparkContext) -> int:
"""Gets the current max number of concurrent tasks."""
# pylint: disable=protected-access
# spark 3.1 and above has a different API for fetching max concurrent tasks
@@ -130,13 +126,13 @@ def _get_max_num_concurrent_tasks(spark_context):
return spark_context._jsc.sc().maxNumConcurrentTasks()


-def _is_local(spark_context) -> bool:
+def _is_local(spark_context: SparkContext) -> bool:
"""Whether it is Spark local mode"""
# pylint: disable=protected-access
return spark_context._jsc.sc().isLocal()


-def _get_gpu_id(task_context) -> int:
+def _get_gpu_id(task_context: BarrierTaskContext) -> int:
Contributor: Can we change it to TaskContext instead of BarrierTaskContext?

Member Author: Changed.

Member Author (@trivialfis, Oct 21, 2022): Reverted; getTaskInfos is exclusive to the barrier task context.

Contributor: Yeah, only _get_host_ip needs BarrierTaskContext; for the others, TaskContext is ok.

Member Author (@trivialfis): Since _get_host_ip is called by the others, unless we want to perform a downcast we need to pass the type down.
"""Get the gpu id from the task resources"""
if task_context is None:
# This is a safety check.
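The thread above pins down why the annotation stays `BarrierTaskContext`: `getTaskInfos()` is not available on plain `TaskContext`, so typing `_get_host_ip` (and the helpers that call it) with `TaskContext` would require a downcast. A small sketch of the distinction (illustrative only; it must run inside a barrier stage):

```python
from pyspark import BarrierTaskContext, TaskContext


def describe_context() -> None:
    """Runs inside a task; shows which API surface each context type offers."""
    context = TaskContext.get()
    # Available on both TaskContext and BarrierTaskContext:
    print(context.partitionId())

    barrier = BarrierTaskContext.get()
    # getTaskInfos() is exclusive to BarrierTaskContext, hence the annotation
    # on _get_host_ip and the helpers that call it.
    print([info.address for info in barrier.getTaskInfos()])
```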