Skip to content

Commit

Permalink
Add alternative worker commands, config options (#20)
Browse files Browse the repository at this point in the history
* add alternative worker commands, config options

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add support for JSON worker args

* poll the scheduler process as a health check

* extra health checks for dask workers

* Handle dask not being on the path

* Revert sys.executable change

* Fix when worker args are not specified

* Clean up and add --cuda flag

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jacob Tomlinson <jtomlinson@nvidia.com>
  • Loading branch information
3 people authored Nov 9, 2023
1 parent 1ad69c9 commit 2e42701
Showing 1 changed file with 47 additions and 5 deletions.
52 changes: 47 additions & 5 deletions dask_databricks/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import socket
Expand All @@ -20,7 +21,16 @@ def main():


@main.command()
def run():
@click.option('--worker-command', help='Custom worker command')
@click.option('--worker-args', help='Additional worker arguments')
@click.option(
"--cuda",
is_flag=True,
show_default=True,
default=False,
help="Use `dask cuda worker` from the dask-cuda package when starting workers.",
)
def run(worker_command, worker_args, cuda):
"""Run Dask processes on a Databricks cluster."""
log = get_logger()

Expand All @@ -38,20 +48,52 @@ def run():

if DB_IS_DRIVER == "TRUE":
log.info("This node is the Dask scheduler.")
subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
scheduler_process = subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
time.sleep(5) # give the scheduler time to start
if scheduler_process.poll() is not None:
log.error("Scheduler process has exited prematurely.")
sys.exit(1)
else:
# Scheduler port on the driver node; every worker connects to this same port
worker_port = 8786
log.info("This node is a Dask worker.")
log.info(f"Connecting to Dask scheduler at {DB_DRIVER_IP}:8786")
log.info(f"Connecting to Dask scheduler at {DB_DRIVER_IP}:{worker_port}")
while True:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((DB_DRIVER_IP, 8786))
sock.connect((DB_DRIVER_IP, worker_port))
sock.close()
break
except ConnectionRefusedError:
log.info("Scheduler not available yet. Waiting...")
time.sleep(1)
subprocess.Popen(["dask", "worker", f"tcp://{DB_DRIVER_IP}:8786"])

# Construct the worker command
if worker_command:
worker_command = worker_command.split()
elif cuda:
worker_command = ["dask", "cuda", "worker"]
else:
worker_command = ["dask", "worker"]

if worker_args:
try:
# Try to decode the JSON-encoded worker_args
worker_args_list = json.loads(worker_args)
if not isinstance(worker_args_list, list):
raise ValueError("The JSON-encoded worker_args must be a list.")
except json.JSONDecodeError:
# If decoding as JSON fails, split worker_args by spaces
worker_args_list = worker_args.split()

worker_command.extend(worker_args_list)
worker_command.append(f"tcp://{DB_DRIVER_IP}:{worker_port}")

worker_process = subprocess.Popen(worker_command)
time.sleep(5) # give the worker time to start
if worker_process.poll() is not None:
log.error("Worker process has exited prematurely.")
sys.exit(1)


if __name__ == "__main__":
Expand Down

0 comments on commit 2e42701

Please sign in to comment.