Skip to content

Commit

Permalink
[WIP][Serve] Enable launching multiple external LB on controller.
Browse files Browse the repository at this point in the history
  • Loading branch information
cblmemo committed Nov 14, 2024
1 parent 1c04aef commit fb83d39
Show file tree
Hide file tree
Showing 10 changed files with 434 additions and 63 deletions.
31 changes: 31 additions & 0 deletions examples/serve/external-lb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SkyServe YAML to run multiple Load Balancers in different region.

name: multi-lb

service:
readiness_probe:
path: /health
initial_delay_seconds: 20
replicas: 2
external_load_balancers:
- resources:
# cloud: aws
# region: us-east-2
cloud: gcp
region: us-east1
load_balancing_policy: round_robin
- resources:
# cloud: aws
# region: ap-northeast-1
cloud: gcp
region: asia-northeast1
load_balancing_policy: round_robin

resources:
cloud: aws
ports: 8080
cpus: 2+

workdir: examples/serve/http_server

run: python3 server.py
2 changes: 2 additions & 0 deletions sky/serve/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Constants used for SkyServe."""

CONTROLLER_TEMPLATE = 'sky-serve-controller.yaml.j2'
EXTERNAL_LB_TEMPLATE = 'sky-serve-external-load-balancer.yaml.j2'

SKYSERVE_METADATA_DIR = '~/.sky/serve'

Expand Down Expand Up @@ -79,6 +80,7 @@
# Default port range start for controller and load balancer. Ports will be
# automatically generated from this start port.
CONTROLLER_PORT_START = 20001
CONTROLLER_PORT_RANGE = '20001-20020'
LOAD_BALANCER_PORT_START = 30001
LOAD_BALANCER_PORT_RANGE = '30001-30020'

Expand Down
45 changes: 38 additions & 7 deletions sky/serve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,21 @@ def up(
vars_to_fill,
output_path=controller_file.name)
controller_task = task_lib.Task.from_yaml(controller_file.name)
# TODO(tian): Currently we exposed the controller port to the public
# network, for external load balancer to access. We should implement
# encrypted communication between controller and load balancer, and
# not expose the controller to the public network.
assert task.service is not None
ports_to_open_in_controller = (serve_constants.CONTROLLER_PORT_RANGE
if task.service.external_load_balancers
is not None else
serve_constants.LOAD_BALANCER_PORT_RANGE)
# TODO(tian): Probably run another sky.launch after we get the load
# balancer port from the controller? So we don't need to open so many
# ports here. Or, we should have a nginx traffic control to refuse
# any connection to the unregistered ports.
controller_resources = {
r.copy(ports=[serve_constants.LOAD_BALANCER_PORT_RANGE])
r.copy(ports=[ports_to_open_in_controller])
for r in controller_resources
}
controller_task.set_resources(controller_resources)
Expand Down Expand Up @@ -267,12 +276,18 @@ def up(
'Failed to spin up the service. Please '
'check the logs above for more details.') from None
else:
lb_port = serve_utils.load_service_initialization_result(
lb_port_payload)
endpoint = backend_utils.get_endpoints(
controller_handle.cluster_name, lb_port,
skip_status_check=True).get(lb_port)
assert endpoint is not None, 'Did not get endpoint for controller.'
if task.service.external_load_balancers is None:
lb_port = serve_utils.load_service_initialization_result(
lb_port_payload)
endpoint = backend_utils.get_endpoints(
controller_handle.cluster_name,
lb_port,
skip_status_check=True).get(lb_port)
assert endpoint is not None, (
'Did not get endpoint for controller.')
else:
endpoint = (
'Please query with sky serve status for the endpoint.')

sky_logging.print(
f'{fore.CYAN}Service name: '
Expand Down Expand Up @@ -320,6 +335,7 @@ def update(
task: sky.Task to update.
service_name: Name of the service.
"""
# TODO(tian): Implement update of external LBs.
_validate_service_task(task)
# Always apply the policy again here, even though it might have been applied
# in the CLI. This is to ensure that we apply the policy to the final DAG
Expand Down Expand Up @@ -585,6 +601,8 @@ def status(
'requested_resources_str': (str) str representation of
requested resources,
'replica_info': (List[Dict[str, Any]]) replica information,
'external_lb_info': (Dict[str, Any]) external load balancer
information,
}
Each entry in replica_info has the following fields:
Expand All @@ -600,6 +618,17 @@ def status(
'handle': (ResourceHandle) handle of the replica cluster,
}
Each entry in external_lb_info has the following fields:
.. code-block:: python
{
'lb_id': (int) index of the external load balancer,
'cluster_name': (str) cluster name of the external load balancer,
'port': (int) port of the external load balancer,
'endpoint': (str) endpoint of the external load balancer,
}
For possible service statuses and replica statuses, please refer to
sky.cli.serve_status.
Expand Down Expand Up @@ -695,6 +724,8 @@ def tail_logs(
sky.exceptions.ClusterNotUpError: the sky serve controller is not up.
ValueError: arguments not valid, or failed to tail the logs.
"""
# TODO(tian): Support tail logs for external load balancer. It should be
# similar to tail replica logs.
if isinstance(target, str):
target = serve_utils.ServiceComponent(target)
if not isinstance(target, serve_utils.ServiceComponent):
Expand Down
67 changes: 67 additions & 0 deletions sky/serve/serve_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from sky.serve import constants
from sky.utils import db_utils
from sky import exceptions

if typing.TYPE_CHECKING:
from sky.serve import replica_managers
Expand Down Expand Up @@ -58,6 +59,13 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None:
service_name TEXT,
spec BLOB,
PRIMARY KEY (service_name, version))""")
cursor.execute("""\
CREATE TABLE IF NOT EXISTS external_load_balancers (
lb_id INTEGER,
service_name TEXT,
cluster_name TEXT,
port INTEGER,
PRIMARY KEY (service_name, lb_id))""")
conn.commit()


Expand Down Expand Up @@ -538,3 +546,62 @@ def delete_all_versions(service_name: str) -> None:
"""\
DELETE FROM version_specs
WHERE service_name=(?)""", (service_name,))


# === External Load Balancer functions ===
# TODO(tian): Add a status column.
def add_external_load_balancer(service_name: str, lb_id: int, cluster_name: str,
port: int) -> None:
"""Adds an external load balancer to the database."""
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT INTO external_load_balancers
(service_name, lb_id, cluster_name, port)
VALUES (?, ?, ?, ?)""", (service_name, lb_id, cluster_name, port))


def _get_external_load_balancer_from_row(row) -> Dict[str, Any]:
from sky import core # pylint: disable=import-outside-toplevel

# TODO(tian): Temporary workaround to avoid circular import.
# This should be fixed.
lb_id, cluster_name, port = row[:3]
try:
endpoint = core.endpoints(cluster_name, port)[port]
except exceptions.ClusterNotUpError:
# TODO(tian): Currently, when this cluster is not in the UP status,
# the endpoint query will raise an cluster is not up error. We should
# implement a status for external lbs as well and returns a '-' when
# it is still provisioning.
endpoint = '-'
return {
'lb_id': lb_id,
'cluster_name': cluster_name,
'port': port,
'endpoint': endpoint,
}


def get_external_load_balancers(service_name: str) -> List[Dict[str, Any]]:
"""Gets all external load balancers of a service."""
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT lb_id, cluster_name, port FROM external_load_balancers
WHERE service_name=(?)""", (service_name,)).fetchall()
external_load_balancers = []
for row in rows:
external_load_balancers.append(
_get_external_load_balancer_from_row(row))
return external_load_balancers


def remove_external_load_balancer(service_name: str, lb_id: int) -> None:
"""Removes an external load balancer from the database."""
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
DELETE FROM external_load_balancers
WHERE service_name=(?)
AND lb_id=(?)""", (service_name, lb_id))
84 changes: 77 additions & 7 deletions sky/serve/serve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
# Max number of replicas to show in `sky serve status` by default.
# If user wants to see all replicas, use `sky serve status --all`.
_REPLICA_TRUNC_NUM = 10
# Similar to _REPLICA_TRUNC_NUM, but for external load balancers.
_EXTERNAL_LB_TRUNC_NUM = 10


class ServiceComponent(enum.Enum):
Expand Down Expand Up @@ -224,6 +226,13 @@ def generate_remote_load_balancer_log_file_name(service_name: str) -> str:
return os.path.join(dir_name, 'load_balancer.log')


def generate_remote_external_load_balancer_log_file_name(
service_name: str, lb_id: int) -> str:
dir_name = generate_remote_service_dir_name(service_name)
# Don't expand here since it is used for remote machine.
return os.path.join(dir_name, f'external_load_balancer_{lb_id}.log')


def generate_replica_launch_log_file_name(service_name: str,
replica_id: int) -> str:
dir_name = generate_remote_service_dir_name(service_name)
Expand Down Expand Up @@ -354,7 +363,8 @@ def terminate_replica(service_name: str, replica_id: int, purge: bool) -> str:

def _get_service_status(
service_name: str,
with_replica_info: bool = True) -> Optional[Dict[str, Any]]:
with_replica_info: bool = True,
with_external_lb_info: bool = True) -> Optional[Dict[str, Any]]:
"""Get the status dict of the service.
Args:
Expand All @@ -373,6 +383,9 @@ def _get_service_status(
info.to_info_dict(with_handle=True)
for info in serve_state.get_replica_infos(service_name)
]
if with_external_lb_info:
record['external_lb_info'] = serve_state.get_external_load_balancers(
service_name)
return record


Expand Down Expand Up @@ -457,7 +470,8 @@ def terminate_services(service_names: Optional[List[str]], purge: bool) -> str:
messages = []
for service_name in service_names:
service_status = _get_service_status(service_name,
with_replica_info=False)
with_replica_info=False,
with_external_lb_info=False)
if (service_status is not None and service_status['status']
== serve_state.ServiceStatus.SHUTTING_DOWN):
# Already scheduled to be terminated.
Expand Down Expand Up @@ -810,10 +824,14 @@ def format_service_table(service_records: List[Dict[str, Any]],
service_table = log_utils.create_table(service_columns)

replica_infos = []
external_lb_infos = []
for record in service_records:
for replica in record['replica_info']:
replica['service_name'] = record['name']
replica_infos.append(replica)
for external_lb in record['external_lb_info']:
external_lb['service_name'] = record['name']
external_lb_infos.append(external_lb)

service_name = record['name']
version = ','.join(
Expand All @@ -824,7 +842,12 @@ def format_service_table(service_records: List[Dict[str, Any]],
service_status = record['status']
status_str = service_status.colored_str()
replicas = _get_replicas(record)
endpoint = get_endpoint(record)
if record['external_lb_info']:
# Don't show endpoint for services with external load balancers.
# TODO(tian): Add automatic DNS record creation and show domain here
endpoint = '-'
else:
endpoint = get_endpoint(record)
policy = record['policy']
requested_resources_str = record['requested_resources_str']

Expand All @@ -841,10 +864,20 @@ def format_service_table(service_records: List[Dict[str, Any]],
service_table.add_row(service_values)

replica_table = _format_replica_table(replica_infos, show_all)
return (f'{service_table}\n'
f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Service Replicas{colorama.Style.RESET_ALL}\n'
f'{replica_table}')

final_table = (f'{service_table}\n'
f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Service Replicas{colorama.Style.RESET_ALL}\n'
f'{replica_table}')

if external_lb_infos:
external_lb_table = _format_external_lb_table(external_lb_infos,
show_all)
final_table += (f'\n\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'External Load Balancers{colorama.Style.RESET_ALL}\n'
f'{external_lb_table}')

return final_table


def _format_replica_table(replica_records: List[Dict[str, Any]],
Expand Down Expand Up @@ -905,6 +938,43 @@ def _format_replica_table(replica_records: List[Dict[str, Any]],
return f'{replica_table}{truncate_hint}'


def _format_external_lb_table(external_lb_records: List[Dict[str, Any]],
show_all: bool) -> str:
if not external_lb_records:
return 'No existing external load balancers.'

external_lb_columns = ['SERVICE_NAME', 'ID', 'ENDPOINT']
if show_all:
external_lb_columns.append('PORT')
external_lb_columns.append('CLUSTER_NAME')
external_lb_table = log_utils.create_table(external_lb_columns)

truncate_hint = ''
if not show_all:
if len(external_lb_records) > _EXTERNAL_LB_TRUNC_NUM:
truncate_hint = (
'\n... (use --all to show all external load balancers)')
external_lb_records = external_lb_records[:_EXTERNAL_LB_TRUNC_NUM]

for record in external_lb_records:
service_name = record['service_name']
external_lb_id = record['lb_id']
endpoint = record['endpoint']
port = record['port']
cluster_name = record['cluster_name']

external_lb_values = [
service_name,
external_lb_id,
endpoint,
]
if show_all:
external_lb_values.extend([port, cluster_name])
external_lb_table.add_row(external_lb_values)

return f'{external_lb_table}{truncate_hint}'


# =========================== CodeGen for Sky Serve ===========================


Expand Down
Loading

0 comments on commit fb83d39

Please sign in to comment.