
Commit c8102f1

avoid stage-level scheduling on non spark standalone and localcluster mode

wbo4958 committed Oct 11, 2023
1 parent a265236
Showing 2 changed files with 18 additions and 8 deletions.
python-package/xgboost/spark/core.py (11 additions, 8 deletions)
@@ -44,7 +44,6 @@
     MLWritable,
     MLWriter,
 )
-from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests
 from pyspark.sql import Column, DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
@@ -89,6 +88,7 @@
     _get_rabit_args,
     _get_spark_session,
     _is_local,
+    _is_standalone_or_localcluster,
     deserialize_booster,
     deserialize_xgb_model,
     get_class_name,
@@ -367,7 +367,7 @@ def _validate_gpu_params(self) -> None:
                         " on GPU."
                     )
 
-                if ss.version < "3.4.0":
+                if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)):
                     # We will enable stage-level scheduling in spark 3.4.0+ which doesn't
                     # require spark.task.resource.gpu.amount to be set explicitly
                     gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
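
Note: on Spark releases older than 3.4.0, and after this change on any cluster manager other than standalone or local-cluster, the task-level GPU amount still has to be set by the user. A minimal, illustrative session setup (the conf values here are examples, not defaults from this codebase):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Executor-level GPU request for GPU training.
    .config("spark.executor.resource.gpu.amount", "1")
    # Task-level GPU request; required whenever the check above does not
    # enable stage-level scheduling.
    .config("spark.task.resource.gpu.amount", "1")
    .getOrCreate()
)
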
@@ -914,8 +914,11 @@ def _skip_stage_level_scheduling(self) -> bool:
             )
             return True
 
-        if _is_local(sc):
-            # Local mode doesn't support stage-level scheduling
+        if not _is_standalone_or_localcluster(sc):
+            self.logger.warning(
+                "Stage-level scheduling in xgboost requires spark standalone or "
+                "local-cluster mode"
+            )
             return True
 
         executor_cores = sc.getConf().get("spark.executor.cores")
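
Taken together with the earlier hunk, stage-level scheduling is now attempted only on Spark 3.4.0+ with a standalone or local-cluster master; everywhere else the method warns and returns True. A condensed sketch of that gate (an assumed simplification for illustration; the real method performs further checks on executor cores and GPUs, as the next hunk suggests):

def skip_stage_level_scheduling_sketch(spark_version: str, master: str) -> bool:
    # Stage-level scheduling needs Spark 3.4.0 or later ...
    if spark_version < "3.4.0":
        return True
    # ... and, after this commit, a standalone or local-cluster master.
    if not (master.startswith("spark://") or master.startswith("local-cluster")):
        return True
    # Remaining executor-core/GPU checks omitted in this sketch.
    return False
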
@@ -991,13 +994,13 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
             # ETL gpu tasks running alongside training tasks to avoid OOM
             spark_plugins = ss.conf.get("spark.plugins", " ")
             assert spark_plugins is not None
-            spark_rapids_sql_enabled = ss.conf.get(
-                "spark.rapids.sql.enabled", "true"
-            ).lower()
+            spark_rapids_sql_enabled = ss.conf.get("spark.rapids.sql.enabled", "true")
+            assert spark_rapids_sql_enabled is not None
 
             task_cores = (
                 int(executor_cores)
                 if "com.nvidia.spark.SQLPlugin" in spark_plugins
-                and "true" == spark_rapids_sql_enabled
+                and "true" == spark_rapids_sql_enabled.lower()
                 else (int(executor_cores) // 2) + 1
             )
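
The task_cores computation decides how many CPU slots the training task requests per executor. Worked through with spark.executor.cores = 12: if the RAPIDS SQLPlugin is loaded and spark.rapids.sql.enabled is "true", task_cores = 12, so the training task claims the whole executor and no ETL GPU task can run alongside it and trigger an OOM; otherwise task_cores = 12 // 2 + 1 = 7, more than half the cores, so at most one training task lands on each executor. A sketch of how such a value is typically turned into a stage-level resource profile with the pyspark.resource API (an assumption based on the import removed in the first hunk; this part of the method is not shown in the diff):

from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests

def with_training_profile(rdd, task_cores: int):
    # Ask for task_cores CPUs and a full GPU per training task.
    treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", 1.0)
    rp = ResourceProfileBuilder().require(treqs).build  # .build is a property
    return rdd.withResources(rp)
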
python-package/xgboost/spark/utils.py (7 additions, 0 deletions)
@@ -129,6 +129,13 @@ def _is_local(spark_context: SparkContext) -> bool:
     return spark_context._jsc.sc().isLocal()
 
 
+def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool:
+    master = spark_context.getConf().get("spark.master")
+    return master is not None and (
+        master.startswith("spark://") or master.startswith("local-cluster")
+    )
+
+
 def _get_gpu_id(task_context: TaskContext) -> int:
     """Get the gpu id from the task resources"""
     if task_context is None:
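
The new helper only inspects the spark.master URL. A minimal illustration of the masters its predicate accepts (accepts below restates the check without a SparkContext and is purely for demonstration; only _is_standalone_or_localcluster exists in the codebase):

def accepts(master: str) -> bool:
    # Same string test as _is_standalone_or_localcluster.
    return master.startswith("spark://") or master.startswith("local-cluster")

assert accepts("spark://master-host:7077")     # standalone -> True
assert accepts("local-cluster[2, 1, 1024]")    # local-cluster test mode -> True
assert not accepts("local[4]")                 # plain local mode -> False
assert not accepts("yarn")                     # YARN -> False
assert not accepts("k8s://https://host:6443")  # Kubernetes -> False
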
