diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 2cc5b264eb3c..2ae1a35ebd28 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -170,16 +170,43 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
     return out
 
 
-def _worker_map_has_duplicates(worker_map: Dict[str, int]) -> bool:
-    """Check if there are any duplicate IP-port pairs in a ``worker_map``."""
+def _possibly_fix_worker_map_duplicates(worker_map: Dict[str, int], client: Client) -> Dict[str, int]:
+    """Fix any duplicate IP-port pairs in a ``worker_map``."""
+    worker_map = deepcopy(worker_map)
+    workers_that_need_new_ports = []
     host_to_port = defaultdict(set)
     for worker, port in worker_map.items():
         host = urlparse(worker).hostname
         if port in host_to_port[host]:
-            return True
+            workers_that_need_new_ports.append(worker)
         else:
             host_to_port[host].add(port)
-    return False
+
+    # if any duplicates were found, search for new ports one by one
+    for worker in workers_that_need_new_ports:
+        _log_info(f"Searching for a LightGBM training port for worker '{worker}'")
+        host = urlparse(worker).hostname
+        retries_remaining = 100
+        while retries_remaining > 0:
+            retries_remaining -= 1
+            new_port = client.submit(
+                _find_random_open_port,
+                workers=[worker],
+                allow_other_workers=False,
+                pure=False
+            ).result()
+            if new_port not in host_to_port[host]:
+                worker_map[worker] = new_port
+                host_to_port[host].add(new_port)
+                break
+        else:
+            # all retries exhausted without finding a unique port: fail loudly
+            # instead of silently keeping the duplicate port
+            raise RuntimeError(
+                f"Failed to find an open LightGBM training port for worker '{worker}'"
+            )
+
+    return worker_map
 
 
 def _train(
@@ -379,21 +406,18 @@ def _train(
             }
         else:
             _log_info("Finding random open ports for workers")
+            # this approach with client.run() is faster than searching for ports
+            # serially, but can produce duplicates sometimes. Try the fast approach one
+            # time, then pass it through a function that will use a slower but more reliable
+            # approach if duplicates are found.
             worker_address_to_port = client.run(
                 _find_random_open_port,
                 workers=list(worker_addresses)
             )
-            # handle the case where _find_random_open_port() produces duplicates
-            retries_left = 10
-            while _worker_map_has_duplicates(worker_address_to_port) and retries_left > 0:
-                retries_left -= 1
-                _log_warning(
-                    "Searching for random ports generated duplicates. Trying again (will try %i more times after this)." % retries_left
-                )
-                worker_address_to_port = client.run(
-                    _find_random_open_port,
-                    workers=list(worker_addresses)
-                )
+            worker_address_to_port = _possibly_fix_worker_map_duplicates(
+                worker_map=worker_address_to_port,
+                client=client
+            )
 
         machines = ','.join([
             '%s:%d' % (urlparse(worker_address).hostname, port)
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index c2cd44554f0c..bcb2061681d1 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -377,25 +377,33 @@ def test_find_random_open_port(client):
     client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 
 
-def test_worker_map_has_duplicates():
-    map_with_duplicates = worker_map = {
-        'tcp://127.0.0.1:8786': 123,
-        'tcp://127.0.0.1:8788': 123,
-        'tcp://10.1.1.2:15001': 123
-    }
-    assert lgb.dask._worker_map_has_duplicates(map_with_duplicates)
-
+def test_possibly_fix_worker_map(capsys, client):
+    client.wait_for_workers(2)
+    worker_addresses = list(client.scheduler_info()["workers"].keys())
+    retry_msg = 'Searching for a LightGBM training port for worker'
+
+    # should handle worker maps without any duplicates
     map_without_duplicates = {
-        'tcp://127.0.0.1:8786': 12405,
-        'tcp://10.1.1.2:15001': 12405
+        worker_address: 12400 + i
+        for i, worker_address in enumerate(worker_addresses)
     }
-    assert lgb.dask._worker_map_has_duplicates(map_without_duplicates) is False
+    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
+        client=client,
+        worker_map=map_without_duplicates
+    )
+    assert patched_map == map_without_duplicates
+    assert retry_msg not in capsys.readouterr().out
 
-    localcluster_map_without_duplicates = {
-        'tcp://127.0.0.1:708': 12405,
-        'tcp://127.0.0.1:312': 12405,
+    # should handle worker maps with duplicates
+    map_with_duplicates = {
+        worker_address: 12400
+        for i, worker_address in enumerate(worker_addresses)
     }
-    assert lgb.dask._worker_map_has_duplicates(map_without_duplicates) is False
+    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
+        client=client,
+        worker_map=map_with_duplicates
+    )
+    assert retry_msg in capsys.readouterr().out
 
 
 def test_training_does_not_fail_on_port_conflicts(client):