diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 39045962a78..002de0ab77e 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -465,20 +465,17 @@ def get_controller_resources( if handle is not None: controller_resources_to_use = handle.launched_resources - if controller_resources_to_use.cloud is not None: - return {controller_resources_to_use} + # If the controller and replicas are from the same cloud (and region/zone), + # it should provide better connectivity. We will let the controller choose + # from the clouds (and regions/zones) of the resources if the user does not + # specify the cloud (and region/zone) for the controller. - # If the controller and replicas are from the same cloud, it should - # provide better connectivity. We will let the controller choose from - # the clouds of the resources if the controller does not exist. - # TODO(tian): Consider respecting the regions/zones specified for the - # resources as well. - requested_clouds: Set['clouds.Cloud'] = set() + requested_clouds_with_region_zone: Dict[str, Dict[Optional[str], + Set[Optional[str]]]] = {} for resource in task_resources: - # cloud is an object and will not be able to be distinguished by set. - # Here we manually check if the cloud is in the set. if resource.cloud is not None: - if not clouds.cloud_in_iterable(resource.cloud, requested_clouds): + cloud_name = str(resource.cloud) + if cloud_name not in requested_clouds_with_region_zone: try: resource.cloud.check_features_are_supported( resources.Resources(), @@ -486,7 +483,26 @@ def get_controller_resources( except exceptions.NotSupportedError: # Skip the cloud if it does not support hosting controllers. continue - requested_clouds.add(resource.cloud) + requested_clouds_with_region_zone[cloud_name] = {} + if resource.region is None: + # If one of the resource.region is None, this could represent + # that the user is unsure about which region the resource is + # hosted in. In this case, we allow any region for this cloud. + requested_clouds_with_region_zone[cloud_name] = {None: {None}} + elif None not in requested_clouds_with_region_zone[cloud_name]: + if resource.region not in requested_clouds_with_region_zone[ + cloud_name]: + requested_clouds_with_region_zone[cloud_name][ + resource.region] = set() + # If one of the resource.zone is None, allow any zone in the + # region. + if resource.zone is None: + requested_clouds_with_region_zone[cloud_name][ + resource.region] = {None} + elif None not in requested_clouds_with_region_zone[cloud_name][ + resource.region]: + requested_clouds_with_region_zone[cloud_name][ + resource.region].add(resource.zone) else: # if one of the resource.cloud is None, this could represent user # does not know which cloud is best for the specified resources. @@ -496,14 +512,49 @@ def get_controller_resources( # - cloud: runpod # accelerators: A40 # In this case, we allow the controller to be launched on any cloud. - requested_clouds.clear() + requested_clouds_with_region_zone.clear() break - if not requested_clouds: + + # Extract filtering criteria from the controller resources specified by the + # user. + controller_cloud = str( + controller_resources_to_use.cloud + ) if controller_resources_to_use.cloud is not None else None + controller_region = controller_resources_to_use.region + controller_zone = controller_resources_to_use.zone + + # Filter clouds if controller_resources_to_use.cloud is specified. + filtered_clouds = ({controller_cloud} if controller_cloud is not None else + requested_clouds_with_region_zone.keys()) + + # Filter regions and zones and construct the result. + result: Set[resources.Resources] = set() + for cloud_name in filtered_clouds: + regions = requested_clouds_with_region_zone.get(cloud_name, + {None: {None}}) + + # Filter regions if controller_resources_to_use.region is specified. + filtered_regions = ({controller_region} if controller_region is not None + else regions.keys()) + + for region in filtered_regions: + zones = regions.get(region, {None}) + + # Filter zones if controller_resources_to_use.zone is specified. + filtered_zones = ({controller_zone} + if controller_zone is not None else zones) + + # Create combinations of cloud, region, and zone. + for zone in filtered_zones: + resource_copy = controller_resources_to_use.copy( + cloud=clouds.CLOUD_REGISTRY.from_str(cloud_name), + region=region, + zone=zone) + result.add(resource_copy) + + if not result: return {controller_resources_to_use} - return { - controller_resources_to_use.copy(cloud=controller_cloud) - for controller_cloud in requested_clouds - } + return result def _setup_proxy_command_on_controller( diff --git a/tests/unit_tests/test_controller_utils.py b/tests/unit_tests/test_controller_utils.py index 7465f648385..f41c7413bc1 100644 --- a/tests/unit_tests/test_controller_utils.py +++ b/tests/unit_tests/test_controller_utils.py @@ -1,5 +1,5 @@ """Test the controller_utils module.""" -from typing import Any, Dict +from typing import Any, Dict, Optional, Set, Tuple import pytest @@ -65,6 +65,24 @@ def get_custom_controller_resources(keys, default): controller_resources_config, k, v) +def _check_controller_resources( + controller_resources: Set[sky.Resources], + expected_combinations: Set[Tuple[Optional[str], Optional[str], + Optional[str]]], + default_controller_resources: Dict[str, Any]) -> None: + """Helper function to check that the controller resources match the + expected combinations.""" + for r in controller_resources: + config = r.to_yaml_config() + cloud = config.pop('cloud') + region = config.pop('region', None) + zone = config.pop('zone', None) + assert (cloud, region, zone) in expected_combinations + expected_combinations.remove((cloud, region, zone)) + assert config == default_controller_resources, config + assert not expected_combinations + + @pytest.mark.parametrize(('controller_type', 'default_controller_resources'), [ ('jobs', managed_job_constants.CONTROLLER_RESOURCES), ('serve', serve_constants.CONTROLLER_RESOURCES), @@ -79,17 +97,12 @@ def test_get_controller_resources_with_task_resources( # could host controllers. Return a set, each item has # one cloud specified plus the default resources. all_clouds = {sky.AWS(), sky.GCP(), sky.Azure()} - all_cloud_names = {str(c) for c in all_clouds} + expected_combinations = {(str(c), None, None) for c in all_clouds} controller_resources = controller_utils.get_controller_resources( controller=controller_utils.Controllers.from_type(controller_type), task_resources=[sky.Resources(cloud=c) for c in all_clouds]) - for r in controller_resources: - config = r.to_yaml_config() - cloud = config.pop('cloud') - assert cloud in all_cloud_names - all_cloud_names.remove(cloud) - assert config == default_controller_resources, config - assert not all_cloud_names + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) # 2. All resources has cloud specified. Some of them # could NOT host controllers. Return a set, only @@ -113,19 +126,14 @@ def _could_host_controllers(cloud: sky.clouds.Cloud) -> bool: return False return True - all_cloud_names_expected = { - str(c) for c in all_clouds if _could_host_controllers(c) + expected_combinations = { + (str(c), None, None) for c in all_clouds if _could_host_controllers(c) } controller_resources = controller_utils.get_controller_resources( controller=controller_utils.Controllers.from_type(controller_type), task_resources=[sky.Resources(cloud=c) for c in all_clouds]) - for r in controller_resources: - config = r.to_yaml_config() - cloud = config.pop('cloud') - assert cloud in all_cloud_names_expected - all_cloud_names_expected.remove(cloud) - assert config == default_controller_resources, config - assert not all_cloud_names_expected + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) # 3. Some resources does not have cloud specified. # Return the default resources. @@ -138,3 +146,73 @@ def _could_host_controllers(cloud: sky.clouds.Cloud) -> bool: assert len(controller_resources) == 1 config = list(controller_resources)[0].to_yaml_config() assert config == default_controller_resources, config + + # 4. All resources have clouds, regions, and zones specified. + # Return a set of controller resources for all combinations of clouds, + # regions, and zones. Each combination should contain the default resources + # along with the cloud, region, and zone. + all_cloud_regions_zones = [ + sky.Resources(cloud=sky.AWS(), region='us-east-1', zone='us-east-1a'), + sky.Resources(cloud=sky.AWS(), region='ap-south-1', zone='ap-south-1b'), + sky.Resources(cloud=sky.GCP(), + region='us-central1', + zone='us-central1-a'), + sky.Resources(cloud=sky.GCP(), + region='europe-west1', + zone='europe-west1-b') + ] + expected_combinations = {('AWS', 'us-east-1', 'us-east-1a'), + ('AWS', 'ap-south-1', 'ap-south-1b'), + ('GCP', 'us-central1', 'us-central1-a'), + ('GCP', 'europe-west1', 'europe-west1-b')} + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=all_cloud_regions_zones) + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) + + # 5. Clouds and regions are specified, but zones are partially specified. + # Return a set containing combinations where the zone is None when not all + # zones are specified in the input for the given region. The default + # resources should be returned along with the cloud and region, and the + # zone (if specified). + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=[ + sky.Resources(cloud=sky.AWS(), region='us-west-2'), + sky.Resources(cloud=sky.AWS(), + region='us-west-2', + zone='us-west-2b'), + sky.Resources(cloud=sky.GCP(), + region='us-central1', + zone='us-central1-a') + ]) + expected_combinations = {('AWS', 'us-west-2', None), + ('GCP', 'us-central1', 'us-central1-a')} + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) + + # 6. Mixed case: Some resources have clouds and regions or zones, others do + # not. For clouds where regions or zones are not specified in the input, + # return None for those fields. The default resources should be returned + # along with the cloud, region (if specified), and zone (if specified). + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=[ + sky.Resources(cloud=sky.GCP(), region='europe-west1'), + sky.Resources(cloud=sky.GCP()), + sky.Resources(cloud=sky.AWS(), + region='eu-north-1', + zone='eu-north-1a'), + sky.Resources(cloud=sky.AWS(), region='eu-north-1'), + sky.Resources(cloud=sky.AWS(), region='ap-south-1'), + sky.Resources(cloud=sky.Azure()), + ]) + expected_combinations = { + ('AWS', 'eu-north-1', None), + ('AWS', 'ap-south-1', None), + ('GCP', None, None), + ('Azure', None, None), + } + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources)