[dask] Add scheduler address to dask config. #7581

Merged
11 commits merged on Jan 21, 2022
26 changes: 26 additions & 0 deletions doc/tutorials/dask.rst
@@ -475,6 +475,32 @@ interface, including callback functions, custom evaluation metric and objective:
)


.. _tracker-ip:

***************
Tracker Host IP
***************

.. versionadded:: 1.6.0

In some environments XGBoost might fail to resolve the IP address of the scheduler; a
typical symptom is the user receiving an ``OSError: [Errno 99] Cannot assign requested
address`` error during training. A quick workaround is to specify the address
explicitly through the dask config:

.. code-block:: python

import dask
from distributed import Client
from xgboost import dask as dxgb
# let xgboost know the scheduler address
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})

with Client(scheduler_file="sched.json") as client:
reg = dxgb.DaskXGBRegressor()

XGBoost will read the configuration before training.
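For reference, a slightly fuller sketch of this workflow (the random data and the
scheduler file below are placeholders, not part of this change):

.. code-block:: python

    import dask
    import dask.array as da
    from distributed import Client
    from xgboost import dask as dxgb

    # let xgboost know the scheduler address before training starts
    dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})

    with Client(scheduler_file="sched.json") as client:
        # placeholder training data; any dask collection works here
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=(100,))
        reg = dxgb.DaskXGBRegressor(n_estimators=10)
        reg.fit(X, y)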

*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
*****************************************************************************
113 changes: 95 additions & 18 deletions python-package/xgboost/dask.py
@@ -3,8 +3,12 @@
# pylint: disable=too-many-lines, fixme
# pylint: disable=too-few-public-methods
# pylint: disable=import-error
"""Dask extensions for distributed training. See :doc:`Distributed XGBoost with Dask
</tutorials/dask>` for simple tutorial. Also xgboost/demo/dask for some examples.
"""
Dask extensions for distributed training
----------------------------------------

See :doc:`Distributed XGBoost with Dask </tutorials/dask>` for a simple tutorial. Also
:doc:`/python/dask-examples/index` for some examples.

There are two sets of APIs in this module, one is the functional API including
``train`` and ``predict`` methods. Another is the stateful Scikit-Learn wrapper
@@ -13,10 +17,22 @@
The implementation is heavily influenced by dask_xgboost:
https://github.com/dask/dask-xgboost

Optional dask configuration
===========================

- **xgboost.scheduler_address**: Specify the scheduler address, see :ref:`tracker-ip`.

.. versionadded:: 1.6.0

.. code-block:: python

dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})

"""
import platform
import logging
import collections
import socket
from contextlib import contextmanager
from collections import defaultdict
from threading import Thread
@@ -136,17 +152,37 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
return MultiLock


def _start_tracker(n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker """
env: Dict[str, Union[int, str]] = {'DMLC_NUM_WORKER': n_workers}
host = get_host_ip('auto')
rabit_context = RabitTracker(hostIP=host, n_workers=n_workers, use_logger=False)
env.update(rabit_context.worker_envs())
def _try_start_tracker(
n_workers: int, addrs: List[Optional[str]]
) -> Dict[str, Union[int, str]]:
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
try:
rabit_context = RabitTracker(
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
thread.daemon = True
thread.start()
except socket.error as e:
if len(addrs) < 2 or e.errno != 99:
raise
LOGGER.warning(
"Failed to bind address '%s', trying to use '%s' instead.",
str(addrs[0]),
str(addrs[1]),
)
env = _try_start_tracker(n_workers, addrs[1:])

rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
thread.daemon = True
thread.start()
return env


def _start_tracker(
n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[str]
) -> Dict[str, Any]:
"""Start Rabit tracker, recurse to try different addresses."""
env = _try_start_tracker(n_workers, [get_host_ip(addr_from_user), addr_from_dask])
return env


@@ -174,6 +210,7 @@ def __init__(self, args: List[bytes]) -> None:

def __enter__(self) -> None:
rabit.init(self.args)
assert rabit.is_distributed()
LOGGER.debug('-------------- rabit say hello ------------------')

def __exit__(self, *args: List) -> None:
@@ -805,12 +842,43 @@ def _dmatrix_from_list_of_parts(
return _create_dmatrix(**kwargs)


async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[bytes]:
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
env = await client.run_on_scheduler(_start_tracker, n_workers)
async def _get_rabit_args(
n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
) -> List[bytes]:
"""Get rabit context arguments from data distribution in DaskDMatrix.

"""
# There are 3 possible different addresses:
# 1. Provided by user via dask.config
# 2. Guessed by xgboost `get_host_ip` function
# 3. From dask scheduler
# We try 1 and 3 if 1 is available, otherwise 2 and 3.
valid_config = ["scheduler_address"]

Review comment:

When extending the Dask config system we typically create a new YAML with default
values in and extend the Dask config with it. See how we do it in dask-ctl.

Reply from @trivialfis (Member, Author):

Thank you for sharing, let me look into that.
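For context, a minimal sketch (hypothetical file name and layout) of the YAML-defaults
pattern suggested above, following
https://docs.dask.org/en/stable/configuration.html#downstream-libraries:

.. code-block:: python

    # hypothetical module executed at import time, e.g. xgboost/dask_config.py
    import os

    import yaml
    import dask.config

    # load packaged defaults, e.g. {"xgboost": {"scheduler_address": None}}
    fn = os.path.join(os.path.dirname(__file__), "xgboost.yaml")
    with open(fn) as f:
        defaults = yaml.safe_load(f)

    # register the defaults so dask.config.get("xgboost.scheduler_address")
    # resolves even when the user has not set anything explicitly
    dask.config.update_defaults(defaults)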

# See if user config is available
if dconfig is not None:
for k in dconfig:
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
host_ip: Optional[str] = dconfig.get("scheduler_address", None)

Review comment:

I'm curious why this dconfig object is being passed around instead of accessing
dask.config.get("xgboost.scheduler_address") directly.

Reply from @trivialfis (Member, Author), Jan 20, 2022:

I couldn't access it on the worker side for some reason (key not found). I read the
doc at https://docs.dask.org/en/stable/configuration.html, and it's not quite clear to
me how to make sure everything is synced.

I also read https://docs.dask.org/en/stable/configuration.html#downstream-libraries.
In XGBoost, dask is loaded lazily, so having dask in init might be less desirable due
to the size of the dependency. So I went on and read this:

    However, downstream libraries may choose alternative solutions, such as isolating
    their configuration within their library, rather than using the global dask.config
    system. All functions in the dask.config module also work with parameters, and do
    not need to mutate global state.

So far this parameter-passing approach seems to be the easiest way to get one
parameter in.

Reviewer reply:

That makes sense. Generally the Dask config should be available on the
scheduler/workers, but that can be deployment dependent.

(A short sketch of this parameter-passing approach follows the ``_get_dask_config``
helper below.)

else:
host_ip = None
# Try address from dask scheduler, this might not work, see
# https://github.com/dask/dask-xgboost/pull/40
try:
sched_addr = distributed.comm.get_address_host(client.scheduler.address)
sched_addr = sched_addr.strip("/:")
except Exception: # pylint: disable=broad-except
sched_addr = None

env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip)

rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
return rabit_args


def _get_dask_config() -> Optional[Dict[str, Any]]:
return dask.config.get("xgboost", default=None)
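As a follow-up to the review thread above, here is a minimal sketch of the
parameter-passing approach; the helper name ``read_xgboost_dask_config`` is
hypothetical and only illustrates the idea of reading the ``xgboost`` namespace once
on the client, validating it, and handing it down as a plain argument rather than
re-reading global state on the workers:

.. code-block:: python

    from typing import Any, Dict, Optional

    import dask

    VALID_KEYS = {"scheduler_address"}

    def read_xgboost_dask_config() -> Optional[Dict[str, Any]]:
        # read the whole "xgboost" namespace from the local dask config
        cfg = dask.config.get("xgboost", default=None)
        if cfg is None:
            return None
        unknown = set(cfg) - VALID_KEYS
        if unknown:
            raise ValueError(f"Unknown configuration: {sorted(unknown)}")
        return dict(cfg)

    # usage: the config is read on the client and then passed down explicitly,
    # e.g. as the dconfig argument of _train_async
    with dask.config.set({"xgboost.scheduler_address": "192.0.0.100"}):
        cfg = read_xgboost_dask_config()  # {'scheduler_address': '192.0.0.100'}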

# train and predict methods are supposed to be "functional", which meets the
# dask paradigm. But as a side effect, the `evals_result` in single-node API
# is no longer supported since it mutates the input parameter, and it's not
@@ -837,6 +905,7 @@ def _get_workers_from_data(
async def _train_async(
client: "distributed.Client",
global_config: Dict[str, Any],
dconfig: Optional[Dict[str, Any]],
params: Dict[str, Any],
dtrain: DaskDMatrix,
num_boost_round: int,
@@ -850,7 +919,7 @@ async def _train_async(
custom_metric: Optional[Metric],
) -> Optional[TrainReturnT]:
workers = _get_workers_from_data(dtrain, evals)
_rabit_args = await _get_rabit_args(len(workers), client)
_rabit_args = await _get_rabit_args(len(workers), dconfig, client)

if params.get("booster", None) == "gblinear":
raise NotImplementedError(
@@ -948,7 +1017,7 @@ def dispatched_train(


@_deprecate_positional_args
def train( # pylint: disable=unused-argument
def train( # pylint: disable=unused-argument
client: "distributed.Client",
params: Dict[str, Any],
dtrain: DaskDMatrix,
@@ -995,7 +1064,12 @@ def train( # pylint: disable=unused-argument
_assert_dask_support()
client = _xgb_get_client(client)
args = locals()
return client.sync(_train_async, global_config=config.get_config(), **args)
return client.sync(
_train_async,
global_config=config.get_config(),
dconfig=_get_dask_config(),
**args,
)


def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
@@ -1693,6 +1767,7 @@ async def _fit_async(
asynchronous=True,
client=self.client,
global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
@@ -1796,6 +1871,7 @@ async def _fit_async(
asynchronous=True,
client=self.client,
global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
@@ -1987,6 +2063,7 @@ async def _fit_async(
asynchronous=True,
client=self.client,
global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
3 changes: 2 additions & 1 deletion python-package/xgboost/tracker.py
@@ -192,7 +192,8 @@ def __init__(
logging.info('start listen on %s:%d', hostIP, self.port)

def __del__(self) -> None:
self.sock.close()
if hasattr(self, "sock"):
self.sock.close()

@staticmethod
def get_neighbor(rank: int, n_workers: int) -> List[int]:
4 changes: 2 additions & 2 deletions tests/python-gpu/test_gpu_with_dask.py
@@ -371,7 +371,7 @@ def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None
m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)

workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)

def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with dxgb.RabitContext(rabit_args):
@@ -473,7 +473,7 @@ def runit(

with Client(local_cuda_cluster) as client:
workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
rabit_args = client.sync(dxgb._get_rabit_args, workers, None, client)
futures = client.map(runit,
workers,
pure=False,
2 changes: 1 addition & 1 deletion tests/python/test_tracker.py
@@ -28,7 +28,7 @@ def run_rabit_ops(client, n_workers):
from xgboost import rabit

workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), client)
rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
assert not rabit.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask
12 changes: 10 additions & 2 deletions tests/python/test_with_dask.py
@@ -30,6 +30,7 @@
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)

from distributed import LocalCluster, Client
import dask
import dask.dataframe as dd
import dask.array as da
from xgboost.dask import DaskDMatrix
@@ -1216,6 +1217,10 @@ def after_iteration(
os.remove(before_fname)
os.remove(after_fname)

with dask.config.set({'xgboost.foo': "bar"}):
with pytest.raises(ValueError):
xgb.dask.train(client, {}, dtrain, num_boost_round=4)

def run_updater_test(
self,
client: "Client",
@@ -1315,7 +1320,8 @@ def runit(
with Client(cluster) as client:
workers = _get_client_workers(client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), client)
xgb.dask._get_rabit_args, len(workers), None, client
)
futures = client.map(runit,
workers,
pure=False,
@@ -1443,7 +1449,9 @@ def test_no_duplicated_partition(self) -> None:
n_partitions = X.npartitions
m = xgb.dask.DaskDMatrix(client, X, y)
workers = _get_client_workers(client)
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
n_workers = len(workers)

def worker_fn(worker_addr: str, data_ref: Dict) -> None: