Skip to content
This repository has been archived by the owner on Aug 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #37 from trivialfis/backport-dask-config
Browse files Browse the repository at this point in the history
[backport] [dask] Add scheduler address to dask config. (dmlc#7581)
  • Loading branch information
ajschmidt8 authored Jan 26, 2022
2 parents 28edcb4 + e731d17 commit 1190f78
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 22 deletions.
26 changes: 26 additions & 0 deletions doc/tutorials/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,32 @@ interface, including callback functions, custom evaluation metric and objective:
)
.. _tracker-ip:

***************
Tracker Host IP
***************

.. versionadded:: 1.6.0

In some environments XGBoost might fail to resolve the IP address of the scheduler; a
typical symptom is the user receiving an ``OSError: [Errno 99] Cannot assign requested
address`` error during training. A quick workaround is to specify the address
explicitly. To do that, the dask config is used:

.. code-block:: python

    import dask
    from distributed import Client
    from xgboost import dask as dxgb
    # let xgboost know the scheduler address
    dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})

    with Client(scheduler_file="sched.json") as client:
        reg = dxgb.DaskXGBRegressor()

XGBoost will read the configuration before training.

*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
*****************************************************************************
Expand Down
103 changes: 88 additions & 15 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,21 @@
The implementation is heavily influenced by dask_xgboost:
https://github.com/dask/dask-xgboost
Optional dask configuration
===========================
- **xgboost.scheduler_address**: Specify the scheduler address, see :ref:`tracker-ip`.
.. versionadded:: 1.6.0
.. code-block:: python
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
"""
import platform
import logging
import socket
from contextlib import contextmanager
from collections import defaultdict
from collections.abc import Sequence
Expand Down Expand Up @@ -136,17 +148,37 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
return MultiLock


def _start_tracker(n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers}
host = get_host_ip('auto')
rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False)
env.update(rabit_context.slave_envs())
def _try_start_tracker(
    n_workers: int, addrs: List[Optional[str]]
) -> Dict[str, Union[int, str]]:
    """Start the Rabit tracker on the first usable address in ``addrs``.

    Parameters
    ----------
    n_workers :
        Number of workers the tracker should expect.
    addrs :
        Candidate host addresses in order of preference.  The first entry is
        passed through ``get_host_ip`` and used to create the tracker; the
        remaining entries are fallbacks.

    Returns
    -------
    Environment for the workers: ``DMLC_NUM_WORKER`` plus whatever the tracker's
    ``slave_envs()`` reports.

    Raises
    ------
    socket.error
        If the failure is anything other than "address not available", or if
        there is no further candidate address to fall back to.
    """
    # Local import so this function does not depend on the module's import block.
    import errno

    env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
    try:
        rabit_context = RabitTracker(
            hostIP=get_host_ip(addrs[0]), nslave=n_workers, use_logger=False
        )
        env.update(rabit_context.slave_envs())
        rabit_context.start(n_workers)
        # Join the tracker in a daemon thread so the caller is not blocked.
        thread = Thread(target=rabit_context.join)
        thread.daemon = True
        thread.start()
    except socket.error as err:
        # Use the symbolic constant instead of the literal 99: the numeric value
        # of EADDRNOTAVAIL is 99 on Linux but differs on other platforms, where
        # a hard-coded 99 would prevent the fallback from ever being tried.
        if len(addrs) < 2 or err.errno != errno.EADDRNOTAVAIL:
            raise
        LOGGER.warning(
            "Failed to bind address '%s', trying to use '%s' instead.",
            str(addrs[0]),
            str(addrs[1]),
        )
        # Recurse with the failed candidate dropped.
        env = _try_start_tracker(n_workers, addrs[1:])

    return env


rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
thread.daemon = True
thread.start()
def _start_tracker(
    n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[str]
) -> Dict[str, Union[int, str]]:
    """Start Rabit tracker, trying the user address first and then dask's."""
    # Order matters: the user-provided address takes precedence over the one
    # obtained from dask.
    candidates = [addr_from_user, addr_from_dask]
    return _try_start_tracker(n_workers, candidates)


Expand Down Expand Up @@ -174,6 +206,7 @@ def __init__(self, args: List[bytes]) -> None:

def __enter__(self) -> None:
    """Initialise rabit with the stored arguments when entering the context."""
    rabit.init(self.args)
    # Sanity check: after initialisation rabit must report distributed mode.
    assert rabit.is_distributed()
    LOGGER.debug('-------------- rabit say hello ------------------')

def __exit__(self, *args: List) -> None:
Expand Down Expand Up @@ -800,12 +833,43 @@ def _dmatrix_from_list_of_parts(
return _create_dmatrix(**kwargs)


async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[bytes]:
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
env = await client.run_on_scheduler(_start_tracker, n_workers)
async def _get_rabit_args(
    n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
) -> List[bytes]:
    """Get rabit context arguments from data distribution in DaskDMatrix.

    There are 3 possible different addresses for the tracker:

    1. Provided by user via dask.config
    2. Guessed by xgboost `get_host_ip` function
    3. From dask scheduler

    We try 1 and 3 if 1 is available, otherwise 2 and 3.
    """
    known_keys = ("scheduler_address",)

    # Pick up the user configuration, rejecting any key we do not recognise.
    user_addr: Optional[str] = None
    if dconfig is not None:
        unknown = [key for key in dconfig if key not in known_keys]
        if unknown:
            raise ValueError(f"Unknown configuration: {unknown[0]}")
        user_addr = dconfig.get("scheduler_address", None)

    # Try address from dask scheduler, this might not work, see
    # https://github.com/dask/dask-xgboost/pull/40
    try:
        dask_addr = distributed.comm.get_address_host(
            client.scheduler.address
        ).strip("/:")
    except Exception:  # pylint: disable=broad-except
        dask_addr = None

    tracker_env = await client.run_on_scheduler(
        _start_tracker, n_workers, dask_addr, user_addr
    )
    return [f"{key}={value}".encode() for key, value in tracker_env.items()]


def _get_dask_config() -> Optional[Dict[str, Any]]:
    """Return the ``xgboost`` entry of the dask configuration, or None if unset."""
    xgboost_section = dask.config.get("xgboost", default=None)
    return xgboost_section

# train and predict methods are supposed to be "functional", which meets the
# dask paradigm. But as a side effect, the `evals_result` in single-node API
# is no longer supported since it mutates the input parameter, and it's not
Expand All @@ -832,6 +896,7 @@ def _get_workers_from_data(
async def _train_async(
client: "distributed.Client",
global_config: Dict[str, Any],
dconfig: Optional[Dict[str, Any]],
params: Dict[str, Any],
dtrain: DaskDMatrix,
num_boost_round: int,
Expand All @@ -844,7 +909,7 @@ async def _train_async(
callbacks: Optional[List[TrainingCallback]],
) -> Optional[TrainReturnT]:
workers = _get_workers_from_data(dtrain, evals)
_rabit_args = await _get_rabit_args(len(workers), client)
_rabit_args = await _get_rabit_args(len(workers), dconfig, client)

if params.get("booster", None) == "gblinear":
raise NotImplementedError(
Expand Down Expand Up @@ -978,7 +1043,12 @@ def train( # pylint: disable=unused-argument
_assert_dask_support()
client = _xgb_get_client(client)
args = locals()
return client.sync(_train_async, global_config=config.get_config(), **args)
return client.sync(
_train_async,
global_config=config.get_config(),
dconfig=_get_dask_config(),
**args,
)


def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
Expand Down Expand Up @@ -1676,6 +1746,7 @@ async def _fit_async(
asynchronous=True,
client=self.client,
global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
Expand Down Expand Up @@ -1778,6 +1849,7 @@ async def _fit_async(
asynchronous=True,
client=self.client,
global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
Expand Down Expand Up @@ -1963,6 +2035,7 @@ async def _fit_async(
asynchronous=True,
client=self.client,
global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
Expand Down
5 changes: 3 additions & 2 deletions python-package/xgboost/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ def __init__(
self._use_logger = use_logger
logging.info('start listen on %s:%d', hostIP, self.port)

def __del__(self):
self.sock.close()
def __del__(self) -> None:
    """Close the tracker socket if one was ever attached to this instance."""
    # The attribute may be missing, e.g. if construction did not get as far as
    # setting it; in that case there is nothing to release.
    if not hasattr(self, "sock"):
        return
    self.sock.close()

@staticmethod
def get_neighbor(rank, nslave):
Expand Down
4 changes: 2 additions & 2 deletions tests/python-gpu/test_gpu_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None
m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)

workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)

def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with dxgb.RabitContext(rabit_args):
Expand Down Expand Up @@ -532,7 +532,7 @@ def runit(

with Client(local_cuda_cluster) as client:
workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
rabit_args = client.sync(dxgb._get_rabit_args, workers, None, client)
futures = client.map(runit,
workers,
pure=False,
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def run_rabit_ops(client, n_workers):
from xgboost import rabit

workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), client)
rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
assert not rabit.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask
Expand Down
12 changes: 10 additions & 2 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)

from distributed import LocalCluster, Client
import dask
import dask.dataframe as dd
import dask.array as da
from xgboost.dask import DaskDMatrix
Expand Down Expand Up @@ -1052,6 +1053,10 @@ def after_iteration(
os.remove(before_fname)
os.remove(after_fname)

with dask.config.set({'xgboost.foo': "bar"}):
with pytest.raises(ValueError):
xgb.dask.train(client, {}, dtrain, num_boost_round=4)

def run_updater_test(
self,
client: "Client",
Expand Down Expand Up @@ -1147,7 +1152,8 @@ def runit(
with Client(cluster) as client:
workers = _get_client_workers(client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), client)
xgb.dask._get_rabit_args, len(workers), None, client
)
futures = client.map(runit,
workers,
pure=False,
Expand Down Expand Up @@ -1269,7 +1275,9 @@ def test_no_duplicated_partition(self) -> None:
n_partitions = X.npartitions
m = xgb.dask.DaskDMatrix(client, X, y)
workers = _get_client_workers(client)
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
n_workers = len(workers)

def worker_fn(worker_addr: str, data_ref: Dict) -> None:
Expand Down

0 comments on commit 1190f78

Please sign in to comment.