Skip to content

Commit

Permalink
more reliable ports check
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Mar 30, 2021
1 parent 671619c commit 05303c8
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 29 deletions.
48 changes: 33 additions & 15 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,37 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
return out


def _possibly_fix_worker_map_duplicates(worker_map: Dict[str, int], client: Client) -> Dict[str, int]:
    """Fix any duplicate IP-port pairs in a ``worker_map``.

    ``client.run(_find_random_open_port)`` can assign the same port to two
    workers that live on the same host. This function detects those collisions
    and, for each affected worker, asks that specific worker (via
    ``client.submit``) for a fresh random open port until a port unique to its
    host is found.

    Parameters
    ----------
    worker_map : Dict[str, int]
        Mapping of worker address (e.g. ``'tcp://127.0.0.1:8786'``) to the
        port LightGBM should use for training on that worker.
    client : distributed.Client
        Dask client used to run the port search on the affected workers.

    Returns
    -------
    worker_map : Dict[str, int]
        A new mapping (the input is not mutated) in which every port is unique
        per host.
    """
    # operate on a copy so the caller's mapping is left untouched
    worker_map = deepcopy(worker_map)
    workers_that_need_new_ports = []
    host_to_port = defaultdict(set)
    for worker, port in worker_map.items():
        host = urlparse(worker).hostname
        if port in host_to_port[host]:
            workers_that_need_new_ports.append(worker)
        else:
            host_to_port[host].add(port)

    # if any duplicates were found, search for new ports one by one
    for worker in workers_that_need_new_ports:
        _log_info(f"Searching for a LightGBM training port for worker '{worker}'")
        host = urlparse(worker).hostname
        retries_remaining = 100
        while retries_remaining > 0:
            retries_remaining -= 1
            # pure=False forces re-execution (a pure task would be cached and
            # keep returning the same port); pinning to this worker ensures the
            # port is actually open on the right machine
            new_port = client.submit(
                _find_random_open_port,
                workers=[worker],
                allow_other_workers=False,
                pure=False
            ).result()
            if new_port not in host_to_port[host]:
                worker_map[worker] = new_port
                host_to_port[host].add(new_port)
                break
        # NOTE(review): if all 100 retries collide, the duplicate port is kept
        # silently — consider raising an error here; confirm intended behavior.

    return worker_map


def _train(
Expand Down Expand Up @@ -379,21 +400,18 @@ def _train(
}
else:
_log_info("Finding random open ports for workers")
# this approach with client.run() is faster than searching for ports
# serially, but can produce duplicates sometimes. Try the fast approach one
# time, then pass it through a function that will use a slower but more reliable
# approach if duplicates are found.
worker_address_to_port = client.run(
_find_random_open_port,
workers=list(worker_addresses)
)
# handle the case where _find_random_open_port() produces duplicates
retries_left = 10
while _worker_map_has_duplicates(worker_address_to_port) and retries_left > 0:
retries_left -= 1
_log_warning(
"Searching for random ports generated duplicates. Trying again (will try %i more times after this)." % retries_left
)
worker_address_to_port = client.run(
_find_random_open_port,
workers=list(worker_addresses)
)
worker_address_to_port = _possibly_fix_worker_map_duplicates(
worker_map=worker_address_to_port,
client=client
)

machines = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
Expand Down
37 changes: 23 additions & 14 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,25 +377,34 @@ def test_find_random_open_port(client):
client.close(timeout=CLIENT_CLOSE_TIMEOUT)


def test_possibly_fix_worker_map(capsys, client):
    """Check that _possibly_fix_worker_map_duplicates() only searches for new ports when needed."""
    client.wait_for_workers(2)
    worker_addresses = list(client.scheduler_info()["workers"].keys())

    retry_msg = 'Searching for a LightGBM training port for worker'

    # should handle worker maps without any duplicates (no retry expected)
    map_without_duplicates = {
        worker_address: 12400 + i
        for i, worker_address in enumerate(worker_addresses)
    }
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        client=client,
        worker_map=map_without_duplicates
    )
    assert patched_map == map_without_duplicates
    assert retry_msg not in capsys.readouterr().out

    # should handle worker maps with duplicates (every worker gets the same port)
    map_with_duplicates = {
        worker_address: 12400
        for worker_address in worker_addresses
    }
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        client=client,
        worker_map=map_with_duplicates
    )
    # the patched map should have resolved the collisions to unique ports
    assert len(set(patched_map.values())) == len(worker_addresses)
    assert retry_msg in capsys.readouterr().out


def test_training_does_not_fail_on_port_conflicts(client):
Expand Down

0 comments on commit 05303c8

Please sign in to comment.