diff --git a/examples/serve/external-lb.yaml b/examples/serve/external-lb.yaml new file mode 100644 index 00000000000..e96d5e0fda0 --- /dev/null +++ b/examples/serve/external-lb.yaml @@ -0,0 +1,31 @@ +# SkyServe YAML to run multiple Load Balancers in different regions. + +name: multi-lb + +service: + readiness_probe: + path: /health + initial_delay_seconds: 20 + replicas: 2 + external_load_balancers: + - resources: + # cloud: aws + # region: us-east-2 + cloud: gcp + region: us-east1 + load_balancing_policy: round_robin + - resources: + # cloud: aws + # region: ap-northeast-1 + cloud: gcp + region: asia-northeast1 + load_balancing_policy: round_robin + +resources: + cloud: aws + ports: 8080 + cpus: 2+ + +workdir: examples/serve/http_server + +run: python3 server.py diff --git a/sky/serve/constants.py b/sky/serve/constants.py index 3974293190e..813aa0d6d0e 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -1,6 +1,7 @@ """Constants used for SkyServe.""" CONTROLLER_TEMPLATE = 'sky-serve-controller.yaml.j2' +EXTERNAL_LB_TEMPLATE = 'sky-serve-external-load-balancer.yaml.j2' SKYSERVE_METADATA_DIR = '~/.sky/serve' @@ -79,6 +80,7 @@ # Default port range start for controller and load balancer. Ports will be # automatically generated from this start port. CONTROLLER_PORT_START = 20001 +CONTROLLER_PORT_RANGE = '20001-20020' LOAD_BALANCER_PORT_START = 30001 LOAD_BALANCER_PORT_RANGE = '30001-30020' diff --git a/sky/serve/core.py b/sky/serve/core.py index abf9bfbc719..a006500679f 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -174,12 +174,21 @@ def up( vars_to_fill, output_path=controller_file.name) controller_task = task_lib.Task.from_yaml(controller_file.name) + # TODO(tian): Currently we expose the controller port to the public + # network, for external load balancer to access. We should implement + # encrypted communication between controller and load balancer, and + # not expose the controller to the public network. 
+ assert task.service is not None + ports_to_open_in_controller = (serve_constants.CONTROLLER_PORT_RANGE + if task.service.external_load_balancers + is not None else + serve_constants.LOAD_BALANCER_PORT_RANGE) # TODO(tian): Probably run another sky.launch after we get the load # balancer port from the controller? So we don't need to open so many # ports here. Or, we should have a nginx traffic control to refuse # any connection to the unregistered ports. controller_resources = { - r.copy(ports=[serve_constants.LOAD_BALANCER_PORT_RANGE]) + r.copy(ports=[ports_to_open_in_controller]) for r in controller_resources } controller_task.set_resources(controller_resources) @@ -267,12 +276,18 @@ def up( 'Failed to spin up the service. Please ' 'check the logs above for more details.') from None else: - lb_port = serve_utils.load_service_initialization_result( - lb_port_payload) - endpoint = backend_utils.get_endpoints( - controller_handle.cluster_name, lb_port, - skip_status_check=True).get(lb_port) - assert endpoint is not None, 'Did not get endpoint for controller.' + if task.service.external_load_balancers is None: + lb_port = serve_utils.load_service_initialization_result( + lb_port_payload) + endpoint = backend_utils.get_endpoints( + controller_handle.cluster_name, + lb_port, + skip_status_check=True).get(lb_port) + assert endpoint is not None, ( + 'Did not get endpoint for controller.') + else: + endpoint = ( + 'Please query with sky serve status for the endpoint.') sky_logging.print( f'{fore.CYAN}Service name: ' @@ -320,6 +335,7 @@ def update( task: sky.Task to update. service_name: Name of the service. """ + # TODO(tian): Implement update of external LBs. _validate_service_task(task) # Always apply the policy again here, even though it might have been applied # in the CLI. 
This is to ensure that we apply the policy to the final DAG @@ -585,6 +601,8 @@ def status( 'requested_resources_str': (str) str representation of requested resources, 'replica_info': (List[Dict[str, Any]]) replica information, + 'external_lb_info': (Dict[str, Any]) external load balancer + information, } Each entry in replica_info has the following fields: @@ -600,6 +618,17 @@ def status( 'handle': (ResourceHandle) handle of the replica cluster, } + Each entry in external_lb_info has the following fields: + + .. code-block:: python + + { + 'lb_id': (int) index of the external load balancer, + 'cluster_name': (str) cluster name of the external load balancer, + 'port': (int) port of the external load balancer, + 'endpoint': (str) endpoint of the external load balancer, + } + For possible service statuses and replica statuses, please refer to sky.cli.serve_status. @@ -695,6 +724,8 @@ def tail_logs( sky.exceptions.ClusterNotUpError: the sky serve controller is not up. ValueError: arguments not valid, or failed to tail the logs. """ + # TODO(tian): Support tail logs for external load balancer. It should be + # similar to tail replica logs. 
if isinstance(target, str): target = serve_utils.ServiceComponent(target) if not isinstance(target, serve_utils.ServiceComponent): diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index 333e0138fb4..ed8c06e5459 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -12,6 +12,7 @@ from sky.serve import constants from sky.utils import db_utils +from sky import exceptions if typing.TYPE_CHECKING: from sky.serve import replica_managers @@ -58,6 +59,13 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None: service_name TEXT, spec BLOB, PRIMARY KEY (service_name, version))""") + cursor.execute("""\ + CREATE TABLE IF NOT EXISTS external_load_balancers ( + lb_id INTEGER, + service_name TEXT, + cluster_name TEXT, + port INTEGER, + PRIMARY KEY (service_name, lb_id))""") conn.commit() @@ -538,3 +546,62 @@ def delete_all_versions(service_name: str) -> None: """\ DELETE FROM version_specs WHERE service_name=(?)""", (service_name,)) + + +# === External Load Balancer functions === +# TODO(tian): Add a status column. +def add_external_load_balancer(service_name: str, lb_id: int, cluster_name: str, + port: int) -> None: + """Adds an external load balancer to the database.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + cursor.execute( + """\ + INSERT INTO external_load_balancers + (service_name, lb_id, cluster_name, port) + VALUES (?, ?, ?, ?)""", (service_name, lb_id, cluster_name, port)) + + +def _get_external_load_balancer_from_row(row) -> Dict[str, Any]: + from sky import core # pylint: disable=import-outside-toplevel + + # TODO(tian): Temporary workaround to avoid circular import. + # This should be fixed. + lb_id, cluster_name, port = row[:3] + try: + endpoint = core.endpoints(cluster_name, port)[port] + except exceptions.ClusterNotUpError: + # TODO(tian): Currently, when this cluster is not in the UP status, + # the endpoint query will raise a cluster-is-not-up error. 
We should + # implement a status for external lbs as well and returns a '-' when + # it is still provisioning. + endpoint = '-' + return { + 'lb_id': lb_id, + 'cluster_name': cluster_name, + 'port': port, + 'endpoint': endpoint, + } + + +def get_external_load_balancers(service_name: str) -> List[Dict[str, Any]]: + """Gets all external load balancers of a service.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + rows = cursor.execute( + """\ + SELECT lb_id, cluster_name, port FROM external_load_balancers + WHERE service_name=(?)""", (service_name,)).fetchall() + external_load_balancers = [] + for row in rows: + external_load_balancers.append( + _get_external_load_balancer_from_row(row)) + return external_load_balancers + + +def remove_external_load_balancer(service_name: str, lb_id: int) -> None: + """Removes an external load balancer from the database.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + cursor.execute( + """\ + DELETE FROM external_load_balancers + WHERE service_name=(?) + AND lb_id=(?)""", (service_name, lb_id)) diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 3be41cc1593..bd500432467 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -56,6 +56,8 @@ # Max number of replicas to show in `sky serve status` by default. # If user wants to see all replicas, use `sky serve status --all`. _REPLICA_TRUNC_NUM = 10 +# Similar to _REPLICA_TRUNC_NUM, but for external load balancers. +_EXTERNAL_LB_TRUNC_NUM = 10 class ServiceComponent(enum.Enum): @@ -224,6 +226,13 @@ def generate_remote_load_balancer_log_file_name(service_name: str) -> str: return os.path.join(dir_name, 'load_balancer.log') +def generate_remote_external_load_balancer_log_file_name( + service_name: str, lb_id: int) -> str: + dir_name = generate_remote_service_dir_name(service_name) + # Don't expand here since it is used for remote machine. 
+ return os.path.join(dir_name, f'external_load_balancer_{lb_id}.log') + + def generate_replica_launch_log_file_name(service_name: str, replica_id: int) -> str: dir_name = generate_remote_service_dir_name(service_name) @@ -354,7 +363,8 @@ def terminate_replica(service_name: str, replica_id: int, purge: bool) -> str: def _get_service_status( service_name: str, - with_replica_info: bool = True) -> Optional[Dict[str, Any]]: + with_replica_info: bool = True, + with_external_lb_info: bool = True) -> Optional[Dict[str, Any]]: """Get the status dict of the service. Args: @@ -373,6 +383,9 @@ def _get_service_status( info.to_info_dict(with_handle=True) for info in serve_state.get_replica_infos(service_name) ] + if with_external_lb_info: + record['external_lb_info'] = serve_state.get_external_load_balancers( + service_name) return record @@ -457,7 +470,8 @@ def terminate_services(service_names: Optional[List[str]], purge: bool) -> str: messages = [] for service_name in service_names: service_status = _get_service_status(service_name, - with_replica_info=False) + with_replica_info=False, + with_external_lb_info=False) if (service_status is not None and service_status['status'] == serve_state.ServiceStatus.SHUTTING_DOWN): # Already scheduled to be terminated. 
@@ -810,10 +824,14 @@ def format_service_table(service_records: List[Dict[str, Any]], service_table = log_utils.create_table(service_columns) replica_infos = [] + external_lb_infos = [] for record in service_records: for replica in record['replica_info']: replica['service_name'] = record['name'] replica_infos.append(replica) + for external_lb in record['external_lb_info']: + external_lb['service_name'] = record['name'] + external_lb_infos.append(external_lb) service_name = record['name'] version = ','.join( @@ -824,7 +842,12 @@ def format_service_table(service_records: List[Dict[str, Any]], service_status = record['status'] status_str = service_status.colored_str() replicas = _get_replicas(record) - endpoint = get_endpoint(record) + if record['external_lb_info']: + # Don't show endpoint for services with external load balancers. + # TODO(tian): Add automatic DNS record creation and show domain here + endpoint = '-' + else: + endpoint = get_endpoint(record) policy = record['policy'] requested_resources_str = record['requested_resources_str'] @@ -841,10 +864,20 @@ def format_service_table(service_records: List[Dict[str, Any]], service_table.add_row(service_values) replica_table = _format_replica_table(replica_infos, show_all) - return (f'{service_table}\n' - f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' - f'Service Replicas{colorama.Style.RESET_ALL}\n' - f'{replica_table}') + + final_table = (f'{service_table}\n' + f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' + f'Service Replicas{colorama.Style.RESET_ALL}\n' + f'{replica_table}') + + if external_lb_infos: + external_lb_table = _format_external_lb_table(external_lb_infos, + show_all) + final_table += (f'\n\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' + f'External Load Balancers{colorama.Style.RESET_ALL}\n' + f'{external_lb_table}') + + return final_table def _format_replica_table(replica_records: List[Dict[str, Any]], @@ -905,6 +938,43 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], return 
f'{replica_table}{truncate_hint}' +def _format_external_lb_table(external_lb_records: List[Dict[str, Any]], + show_all: bool) -> str: + if not external_lb_records: + return 'No existing external load balancers.' + + external_lb_columns = ['SERVICE_NAME', 'ID', 'ENDPOINT'] + if show_all: + external_lb_columns.append('PORT') + external_lb_columns.append('CLUSTER_NAME') + external_lb_table = log_utils.create_table(external_lb_columns) + + truncate_hint = '' + if not show_all: + if len(external_lb_records) > _EXTERNAL_LB_TRUNC_NUM: + truncate_hint = ( + '\n... (use --all to show all external load balancers)') + external_lb_records = external_lb_records[:_EXTERNAL_LB_TRUNC_NUM] + + for record in external_lb_records: + service_name = record['service_name'] + external_lb_id = record['lb_id'] + endpoint = record['endpoint'] + port = record['port'] + cluster_name = record['cluster_name'] + + external_lb_values = [ + service_name, + external_lb_id, + endpoint, + ] + if show_all: + external_lb_values.extend([port, cluster_name]) + external_lb_table.add_row(external_lb_values) + + return f'{external_lb_table}{truncate_hint}' + + # =========================== CodeGen for Sky Serve =========================== diff --git a/sky/serve/service.py b/sky/serve/service.py index 0a1c7f34766..999d82cb3b8 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -7,25 +7,30 @@ import os import pathlib import shutil +import subprocess +import tempfile import time import traceback -from typing import Dict +from typing import Any, Dict import filelock from sky import authentication from sky import exceptions +from sky import resources as resources_lib from sky import sky_logging from sky import task as task_lib from sky.backends import backend_utils from sky.backends import cloud_vm_ray_backend from sky.serve import constants +from sky.skylet import constants as skylet_constants from sky.serve import controller from sky.serve import load_balancer from sky.serve import replica_managers 
from sky.serve import serve_state from sky.serve import serve_utils from sky.utils import common_utils +from sky.utils import controller_utils from sky.utils import subprocess_utils from sky.utils import ux_utils @@ -89,6 +94,8 @@ def _cleanup(service_name: str) -> bool: replica_infos = serve_state.get_replica_infos(service_name) info2proc: Dict[replica_managers.ReplicaInfo, multiprocessing.Process] = dict() + external_lbs = serve_state.get_external_load_balancers(service_name) + lbid2proc: Dict[int, multiprocessing.Process] = dict() for info in replica_infos: p = multiprocessing.Process(target=replica_managers.terminate_cluster, args=(info.cluster_name,)) @@ -101,6 +108,14 @@ def _cleanup(service_name: str) -> bool: replica_managers.ProcessStatus.RUNNING) serve_state.add_or_update_replica(service_name, info.replica_id, info) logger.info(f'Terminating replica {info.replica_id} ...') + for external_lb_record in external_lbs: + lb_cluster_name = external_lb_record['cluster_name'] + lb_id = external_lb_record['lb_id'] + p = multiprocessing.Process(target=replica_managers.terminate_cluster, + args=(lb_cluster_name,)) + p.start() + lbid2proc[lb_id] = p + logger.info(f'Terminating external load balancer {lb_cluster_name} ...') for info, p in info2proc.items(): p.join() if p.exitcode == 0: @@ -114,6 +129,15 @@ def _cleanup(service_name: str) -> bool: info) failed = True logger.error(f'Replica {info.replica_id} failed to terminate.') + for lb_id, p in lbid2proc.items(): + p.join() + if p.exitcode == 0: + serve_state.remove_external_load_balancer(service_name, lb_id) + logger.info( + f'External load balancer {lb_id} terminated successfully.') + else: + failed = True + logger.error(f'External load balancer {lb_id} failed to terminate.') versions = serve_state.get_service_versions(service_name) serve_state.remove_service_versions(service_name) @@ -130,6 +154,50 @@ def cleanup_version_storage(version: int) -> bool: return failed +def _get_external_lb_cluster_name(service_name: 
str, lb_id: int) -> str: + return f'sky-{service_name}-lb-{lb_id}' + + +def _start_external_load_balancer(service_name: str, controller_addr: str, + lb_id: int, lb_port: int, lb_policy: str, + lb_resources: Dict[str, Any]) -> None: + # TODO(tian): Hack. We should figure out the optimal resources. + if 'cpus' not in lb_resources: + lb_resources['cpus'] = '2+' + # Already checked in service spec validation. + assert 'ports' not in lb_resources + lb_resources['ports'] = [lb_port] + lbr = resources_lib.Resources.from_yaml_config(lb_resources) + lb_cluster_name = _get_external_lb_cluster_name(service_name, lb_id) + # TODO(tian): Set delete=False to debug. Remove this on production. + with tempfile.NamedTemporaryFile(prefix=lb_cluster_name, + mode='w', + delete=False) as f: + vars_to_fill = { + 'load_balancer_port': lb_port, + 'controller_addr': controller_addr, + 'load_balancing_policy': lb_policy, + 'sky_activate_python_env': skylet_constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV, + 'lb_envs': controller_utils.sky_managed_cluster_envs(), + } + common_utils.fill_template(constants.EXTERNAL_LB_TEMPLATE, + vars_to_fill, + output_path=f.name) + lb_task = task_lib.Task.from_yaml(f.name) + lb_task.set_resources(lbr) + serve_state.add_external_load_balancer(service_name, lb_id, + lb_cluster_name, lb_port) + # TODO(tian): Temporary solution for circular import. We should move + # the import to the top of the file. + import sky # pylint: disable=import-outside-toplevel + sky.launch( + task=lb_task, + stream_logs=True, + cluster_name=lb_cluster_name, + retry_until_up=True, + ) + + def _start(service_name: str, tmp_task_yaml: str, job_id: int): """Starts the service.""" # Generate ssh key pair to avoid race condition when multiple sky.launch @@ -177,12 +245,7 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int): service_name, constants.INITIAL_VERSION) shutil.copy(tmp_task_yaml, task_yaml) - # Generate load balancer log file name. 
- load_balancer_log_file = os.path.expanduser( - serve_utils.generate_remote_load_balancer_log_file_name(service_name)) - controller_process = None - load_balancer_process = None try: with filelock.FileLock( os.path.expanduser(constants.PORT_SELECTION_FILE_LOCK_PATH)): @@ -202,6 +265,12 @@ def _get_host(): # ('::1', 20001, 0, 0): cannot assign requested address return '127.0.0.1' + def _get_external_host(): + assert service_spec.external_load_balancers is not None + # TODO(tian): Use a more robust way to get the host. + return subprocess.check_output( + 'curl ifconfig.me', shell=True).decode('utf-8').strip() + controller_host = _get_host() # Start the controller. @@ -215,25 +284,55 @@ def _get_host(): # TODO(tian): Support HTTPS. controller_addr = f'http://{controller_host}:{controller_port}' - - load_balancer_port = common_utils.find_free_port( - constants.LOAD_BALANCER_PORT_START) - - # Extract the load balancing policy from the service spec - policy_name = service_spec.load_balancing_policy - - # Start the load balancer. - # TODO(tian): Probably we could enable multiple ports specified in - # service spec and we could start multiple load balancers. - # After that, we will have a mapping from replica port to endpoint. - load_balancer_process = multiprocessing.Process( - target=ux_utils.RedirectOutputForProcess( - load_balancer.run_load_balancer, - load_balancer_log_file).run, - args=(controller_addr, load_balancer_port, policy_name)) - load_balancer_process.start() - serve_state.set_service_load_balancer_port(service_name, - load_balancer_port) + load_balancer_processes = [] + + if service_spec.external_load_balancers is None: + # Generate load balancer log file name. 
+ load_balancer_log_file = os.path.expanduser( + serve_utils.generate_remote_load_balancer_log_file_name( + service_name)) + + load_balancer_port = common_utils.find_free_port( + constants.LOAD_BALANCER_PORT_START) + + # Extract the load balancing policy from the service spec + policy_name = service_spec.load_balancing_policy + + # Start the load balancer. + # TODO(tian): Probably we could enable multiple ports specified + # in service spec and we could start multiple load balancers. + # After that, we need a mapping from replica port to endpoint. + load_balancer_process = multiprocessing.Process( + target=ux_utils.RedirectOutputForProcess( + load_balancer.run_load_balancer, + load_balancer_log_file).run, + args=(controller_addr, load_balancer_port, policy_name)) + load_balancer_process.start() + load_balancer_processes.append(load_balancer_process) + serve_state.set_service_load_balancer_port( + service_name, load_balancer_port) + else: + lb_port = 8000 + for lb_id, lb_config in enumerate( + service_spec.external_load_balancers): + # Generate load balancer log file name. + load_balancer_log_file = os.path.expanduser( + serve_utils. + generate_remote_external_load_balancer_log_file_name( + service_name, lb_id)) + lb_policy = lb_config['load_balancing_policy'] + lb_resources = lb_config['resources'] + controller_external_addr = ( + f'http://{_get_external_host()}:{controller_port}') + lb_process = multiprocessing.Process( + target=ux_utils.RedirectOutputForProcess( + _start_external_load_balancer, + load_balancer_log_file).run, + args=(service_name, controller_external_addr, lb_id, + lb_port, lb_policy, lb_resources)) + lb_process.start() + load_balancer_processes.append(lb_process) + serve_state.set_service_load_balancer_port(service_name, -1) while True: _handle_signal(service_name) @@ -245,7 +344,7 @@ def _get_host(): # Kill load balancer process first since it will raise errors if failed # to connect to the controller. Then the controller process. 
process_to_kill = [ - proc for proc in [load_balancer_process, controller_process] + proc for proc in [*load_balancer_processes, controller_process] if proc is not None ] subprocess_utils.kill_children_processes( diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index 000eed139f1..8c888eb3dc2 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -2,11 +2,11 @@ import json import os import textwrap -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml -from sky import serve +from sky import resources as resources_lib from sky.serve import constants from sky.utils import common_utils from sky.utils import schemas @@ -31,6 +31,7 @@ def __init__( upscale_delay_seconds: Optional[int] = None, downscale_delay_seconds: Optional[int] = None, load_balancing_policy: Optional[str] = None, + external_load_balancers: Optional[List[Dict[str, Any]]] = None, ) -> None: if max_replicas is not None and max_replicas < min_replicas: with ux_utils.print_exception_no_traceback(): @@ -57,13 +58,27 @@ def __init__( raise ValueError('readiness_path must start with a slash (/). ' f'Got: {readiness_path}') - # Add the check for unknown load balancing policies + # Use load_balancing_policy as fallback for external_load_balancers if (load_balancing_policy is not None and - load_balancing_policy not in serve.LB_POLICIES): - with ux_utils.print_exception_no_traceback(): - raise ValueError( - f'Unknown load balancing policy: {load_balancing_policy}. 
' - f'Available policies: {list(serve.LB_POLICIES.keys())}') + external_load_balancers is not None): + for lb_config in external_load_balancers: + if lb_config.get('load_balancing_policy') is None: + lb_config['load_balancing_policy'] = load_balancing_policy + + if external_load_balancers is not None: + for lb_config in external_load_balancers: + r = lb_config.get('resources') + if r is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('`resources` must be set for ' + 'external_load_balancers.') + if 'ports' in r: + with ux_utils.print_exception_no_traceback(): + raise ValueError('`ports` must not be set for ' + 'external_load_balancers.') + # Validate resources + resources_lib.Resources.from_yaml_config(r) + self._readiness_path: str = readiness_path self._initial_delay_seconds: int = initial_delay_seconds self._readiness_timeout_seconds: int = readiness_timeout_seconds @@ -79,6 +94,8 @@ def __init__( self._upscale_delay_seconds: Optional[int] = upscale_delay_seconds self._downscale_delay_seconds: Optional[int] = downscale_delay_seconds self._load_balancing_policy: Optional[str] = load_balancing_policy + self._external_load_balancers: Optional[List[Dict[str, Any]]] = ( + external_load_balancers) self._use_ondemand_fallback: bool = ( self.dynamic_ondemand_fallback is not None and @@ -162,6 +179,8 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec': service_config['load_balancing_policy'] = config.get( 'load_balancing_policy', None) + service_config['external_load_balancers'] = config.get( + 'external_load_balancers', None) return SkyServiceSpec(**service_config) @staticmethod @@ -219,6 +238,8 @@ def add_if_not_none(section, key, value, no_empty: bool = False): self.downscale_delay_seconds) add_if_not_none('load_balancing_policy', None, self._load_balancing_policy) + add_if_not_none('external_load_balancers', None, + self._external_load_balancers) return config def probe_str(self): @@ -329,3 +350,7 @@ def 
use_ondemand_fallback(self) -> bool: @property def load_balancing_policy(self) -> Optional[str]: return self._load_balancing_policy + + @property + def external_load_balancers(self) -> Optional[List[Dict[str, Any]]]: + return self._external_load_balancers diff --git a/sky/templates/sky-serve-external-load-balancer.yaml.j2 b/sky/templates/sky-serve-external-load-balancer.yaml.j2 new file mode 100644 index 00000000000..0c196d6ad3e --- /dev/null +++ b/sky/templates/sky-serve-external-load-balancer.yaml.j2 @@ -0,0 +1,23 @@ +# The template for the sky serve load balancers + +name: load-balancer + +setup: | + {{ sky_activate_python_env }} + # Install serve dependencies. + # TODO(tian): Gather those into serve constants. + pip list | grep uvicorn > /dev/null 2>&1 || pip install uvicorn > /dev/null 2>&1 + pip list | grep fastapi > /dev/null 2>&1 || pip install fastapi > /dev/null 2>&1 + pip list | grep httpx > /dev/null 2>&1 || pip install httpx > /dev/null 2>&1 + +run: | + {{ sky_activate_python_env }} + python -u -m sky.serve.load_balancer \ + --controller-addr {{controller_addr}} \ + --load-balancer-port {{load_balancer_port}} \ + --load-balancing-policy {{load_balancing_policy}} + +envs: +{%- for env_name, env_value in lb_envs.items() %} + {{env_name}}: {{env_value}} +{%- endfor %} diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 0ab2fd7e117..be0db2ddaa9 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -391,6 +391,24 @@ def download_and_stream_latest_job_log( return log_file +# TODO(tian): Maybe move this to other places? +def sky_managed_cluster_envs() -> Dict[str, str]: + env_vars: Dict[str, str] = { + env.env_key: str(int(env.get())) for env in env_options.Options + } + env_vars.update({ + # Should not use $USER here, as that env var can be empty when + # running in a container. 
+ constants.USER_ENV_VAR: getpass.getuser(), + constants.USER_ID_ENV_VAR: common_utils.get_user_hash(), + # Skip cloud identity check to avoid the overhead. + env_options.Options.SKIP_CLOUD_IDENTITY_CHECK.env_key: '1', + # Disable minimize logging to get more details on the controller. + env_options.Options.MINIMIZE_LOGGING.env_key: '0', + }) + return env_vars + + def shared_controller_vars_to_fill( controller: Controllers, remote_user_config_path: str, local_user_config: Dict[str, Any]) -> Dict[str, str]: @@ -425,19 +443,7 @@ def shared_controller_vars_to_fill( 'sky_python_cmd': constants.SKY_PYTHON_CMD, 'local_user_config_path': local_user_config_path, } - env_vars: Dict[str, str] = { - env.env_key: str(int(env.get())) for env in env_options.Options - } - env_vars.update({ - # Should not use $USER here, as that env var can be empty when - # running in a container. - constants.USER_ENV_VAR: getpass.getuser(), - constants.USER_ID_ENV_VAR: common_utils.get_user_hash(), - # Skip cloud identity check to avoid the overhead. - env_options.Options.SKIP_CLOUD_IDENTITY_CHECK.env_key: '1', - # Disable minimize logging to get more details on the controller. - env_options.Options.MINIMIZE_LOGGING.env_key: '0', - }) + env_vars = sky_managed_cluster_envs() if skypilot_config.loaded(): # Only set the SKYPILOT_CONFIG env var if the user has a config file. env_vars[ diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 4d5cc809013..1a2b535b9d1 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -311,6 +311,7 @@ def get_service_schema(): # To avoid circular imports, only import when needed. 
# pylint: disable=import-outside-toplevel from sky.serve import load_balancing_policies + lb_policy_choices = list(load_balancing_policies.LB_POLICIES.keys()) return { '$schema': 'https://json-schema.org/draft/2020-12/schema', 'type': 'object', @@ -385,10 +386,26 @@ def get_service_schema(): 'replicas': { 'type': 'integer', }, + 'external_load_balancers': { + 'type': 'array', + 'items': { + 'type': 'object', + 'required': ['resources', 'load_balancing_policy'], + 'additionalProperties': False, + 'properties': { + 'resources': { + 'type': 'object', + }, + 'load_balancing_policy': { + 'type': 'string', + 'case_insensitive_enum': lb_policy_choices, + }, + } + } + }, 'load_balancing_policy': { 'type': 'string', - 'case_insensitive_enum': list( - load_balancing_policies.LB_POLICIES.keys()) + 'case_insensitive_enum': lb_policy_choices, }, } }