Skip to content

Commit

Permalink
Bring back an old configuration.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 20, 2022
1 parent 137bd95 commit 50d701d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
45 changes: 38 additions & 7 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import platform
import logging
import collections
import socket
from contextlib import contextmanager
from collections import defaultdict
from threading import Thread
Expand Down Expand Up @@ -151,11 +152,25 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
return MultiLock


def _start_tracker(n_workers: int, host_ip: Optional[str]) -> Dict[str, Any]:
"""Start Rabit tracker """
env: Dict[str, Union[int, str]] = {'DMLC_NUM_WORKER': n_workers}
host = get_host_ip(host_ip)
rabit_context = RabitTracker(hostIP=host, n_workers=n_workers, use_logger=False)
def _start_tracker(
n_workers: int, addr_from_dask: str, addr_from_user: Optional[str]
) -> Dict[str, Any]:
"""Start Rabit tracker"""
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
try:
rabit_context = RabitTracker(
hostIP=get_host_ip(addr_from_user), n_workers=n_workers, use_logger=False
)
except socket.error as e:
if e.errno != 99: # not a bind error
raise
LOGGER.warning(
f"Failed to bind address: {get_host_ip(addr_from_user)}, try"
f" {addr_from_dask} instead."
)
rabit_context = RabitTracker(
hostIP=addr_from_dask, n_workers=n_workers, use_logger=False
)
env.update(rabit_context.worker_envs())

rabit_context.start(n_workers)
Expand Down Expand Up @@ -823,17 +838,33 @@ def _dmatrix_from_list_of_parts(
async def _get_rabit_args(
n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
) -> List[bytes]:
"""Get rabit context arguments from data distribution in DaskDMatrix."""
"""Get rabit context arguments from data distribution in DaskDMatrix.
"""
# There are 3 possible sources for the tracker address:
# 1. Provided by the user via dask.config
# 2. Guessed by xgboost's `get_host_ip` function
# 3. Obtained from the dask scheduler
# If 1 is available we try 1 first and fall back to 3 when binding fails;
# otherwise we try 2 first and fall back to 3.
valid_config = ["scheduler_address"]
# See if user config is available
if dconfig is not None:
for k in dconfig:
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
host_ip: Optional[str] = dconfig.get("scheduler_address", None)
else:
host_ip = None
# Try the address reported by the dask scheduler. This might not work; see
# https://github.com/dask/dask-xgboost/pull/40
try:
sched_addr = distributed.comm.get_address_host(client.scheduler.address)
sched_addr = sched_addr.strip("/:")
except Exception: # pylint: disable=broad-except
sched_addr = None

env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip)

env = await client.run_on_scheduler(_start_tracker, n_workers, host_ip)
rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
return rabit_args

Expand Down
3 changes: 2 additions & 1 deletion python-package/xgboost/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ def __init__(
logging.info('start listen on %s:%d', hostIP, self.port)

def __del__(self) -> None:
self.sock.close()
if hasattr(self, "sock"):
self.sock.close()

@staticmethod
def get_neighbor(rank: int, n_workers: int) -> List[int]:
Expand Down

0 comments on commit 50d701d

Please sign in to comment.