From 50d701d3f906a46b7420b6c1d0187b89cc53c3da Mon Sep 17 00:00:00 2001
From: fis
Date: Thu, 20 Jan 2022 16:17:43 +0800
Subject: [PATCH] Bring back an old configuration.

---
 python-package/xgboost/dask.py    | 45 ++++++++++++++++++++++++++-----
 python-package/xgboost/tracker.py |  3 ++-
 2 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 73fe8ae07d3b..0bca8d97df56 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -32,6 +32,7 @@
 import platform
 import logging
 import collections
+import socket
 from contextlib import contextmanager
 from collections import defaultdict
 from threading import Thread
@@ -151,11 +152,25 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
     return MultiLock
 
 
-def _start_tracker(n_workers: int, host_ip: Optional[str]) -> Dict[str, Any]:
-    """Start Rabit tracker """
-    env: Dict[str, Union[int, str]] = {'DMLC_NUM_WORKER': n_workers}
-    host = get_host_ip(host_ip)
-    rabit_context = RabitTracker(hostIP=host, n_workers=n_workers, use_logger=False)
+def _start_tracker(
+    n_workers: int, addr_from_dask: str, addr_from_user: Optional[str]
+) -> Dict[str, Any]:
+    """Start Rabit tracker"""
+    env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
+    try:
+        rabit_context = RabitTracker(
+            hostIP=get_host_ip(addr_from_user), n_workers=n_workers, use_logger=False
+        )
+    except socket.error as e:
+        if e.errno != 99:  # not a bind error
+            raise
+        LOGGER.warning(
+            f"Failed to bind address: {get_host_ip(addr_from_user)}, try"
+            f" {addr_from_dask} instead."
+        )
+        rabit_context = RabitTracker(
+            hostIP=addr_from_dask, n_workers=n_workers, use_logger=False
+        )
     env.update(rabit_context.worker_envs())
 
     rabit_context.start(n_workers)
@@ -823,8 +838,16 @@ def _dmatrix_from_list_of_parts(
 async def _get_rabit_args(
     n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
 ) -> List[bytes]:
-    """Get rabit context arguments from data distribution in DaskDMatrix."""
+    """Get rabit context arguments from data distribution in DaskDMatrix.
+
+    """
+    # There are 3 possible different addresses:
+    # 1. Provided by user via dask.config
+    # 2. Guessed by xgboost `get_host_ip` function
+    # 3. From dask scheduler
+    # We try 1 and 3 if 1 is available, otherwise 2 and 3.
     valid_config = ["scheduler_address"]
+    # See if user config is available
     if dconfig is not None:
         for k in dconfig:
             if k not in valid_config:
@@ -832,8 +855,16 @@ async def _get_rabit_args(
         host_ip: Optional[str] = dconfig.get("scheduler_address", None)
     else:
         host_ip = None
+    # Try address from dask scheduler, this might not work, see
+    # https://github.com/dask/dask-xgboost/pull/40
+    try:
+        sched_addr = distributed.comm.get_address_host(client.scheduler.address)
+        sched_addr = sched_addr.strip("/:")
+    except Exception:  # pylint: disable=broad-except
+        sched_addr = None
+
+    env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip)
 
-    env = await client.run_on_scheduler(_start_tracker, n_workers, host_ip)
     rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
     return rabit_args
 
diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py
index f2fc90302c0e..9e040d05b2d3 100644
--- a/python-package/xgboost/tracker.py
+++ b/python-package/xgboost/tracker.py
@@ -192,7 +192,8 @@ def __init__(
         logging.info('start listen on %s:%d', hostIP, self.port)
 
     def __del__(self) -> None:
-        self.sock.close()
+        if hasattr(self, "sock"):
+            self.sock.close()
 
     @staticmethod
     def get_neighbor(rank: int, n_workers: int) -> List[int]: