diff --git a/tests/python-gpu/test_gpu_spark/discover_gpu.sh b/tests/python-gpu/test_gpu_spark/discover_gpu.sh index 42dd0551784d..fc2c71741de2 100755 --- a/tests/python-gpu/test_gpu_spark/discover_gpu.sh +++ b/tests/python-gpu/test_gpu_spark/discover_gpu.sh @@ -1,3 +1,16 @@ #!/bin/bash -echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}" +# This script is made only for running XGBoost tests on the official CI, where we have access +# to a 4-GPU cluster. The discovery command is for running tests on a local machine, where +# the driver and the GPU worker might be the same machine for ease of development. + +if ! command -v nvidia-smi &> /dev/null +then + # nvidia-smi is not available: default to 4 GPUs + echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}" + exit +else + # See https://github.com/apache/spark/blob/master/examples/src/main/scripts/getGpusResources.sh -- the sed script joins the per-line GPU indices with "," so they form the JSON address list + ADDRS=`nvidia-smi --query-gpu=index --format=csv,noheader | sed -e ':a' -e 'N' -e'$!ba' -e 's/\n/","/g'` + echo {\"name\": \"gpu\", \"addresses\":[\"$ADDRS\"]} +fi diff --git a/tests/python-gpu/test_gpu_spark/test_gpu_spark.py b/tests/python-gpu/test_gpu_spark/test_gpu_spark.py index ce5b9d8c8d42..bcae96dc5dc8 100644 --- a/tests/python-gpu/test_gpu_spark/test_gpu_spark.py +++ b/tests/python-gpu/test_gpu_spark/test_gpu_spark.py @@ -1,4 +1,6 @@ +import json import logging +import subprocess import sys import pytest @@ -7,7 +9,7 @@ sys.path.append("tests/python") import testing as tm -if tm.no_dask()["condition"]: +if tm.no_spark()["condition"]: pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) if sys.platform.startswith("win"): pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True) @@ -18,8 +20,20 @@ from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor gpu_discovery_script_path = "tests/python-gpu/test_gpu_spark/discover_gpu.sh" -executor_gpu_amount = 4 -executor_cores = 4 + + +def get_devices(): + """Get GPU addresses. Works only if driver and worker are on the same machine.""" + 
completed = subprocess.run(gpu_discovery_script_path, stdout=subprocess.PIPE) + assert completed.returncode == 0, "Failed to execute discovery script." + msg = completed.stdout.decode("utf-8") + result = json.loads(msg) + addresses = result["addresses"] + return addresses + + +executor_gpu_amount = len(get_devices()) +executor_cores = executor_gpu_amount num_workers = executor_gpu_amount