diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst index 996e4f6c96b..8e1ef5e2bba 100644 --- a/docs/source/getting-started/installation.rst +++ b/docs/source/getting-started/installation.rst @@ -58,6 +58,8 @@ To get the **AWS access key** required by :code:`aws configure`, please go to th $ # Configure your AWS credentials $ aws configure +Note: If you are using AWS IAM Identity Center (AWS SSO), you will need :code:`awscli>=1.27.10`; run :code:`pip install "awscli>=1.27.10"` to upgrade. See `here <https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html>`_ for instructions on how to configure AWS SSO. + **GCP** .. code-block:: console @@ -181,4 +183,4 @@ If you experience any issues after installation, you can use the :code:`--uninst $ sky --uninstall-shell-completion auto $ # sky --uninstall-shell-completion zsh $ # sky --uninstall-shell-completion bash - $ # sky --uninstall-shell-completion fish \ No newline at end of file + $ # sky --uninstall-shell-completion fish diff --git a/examples/using_file_mounts.yaml b/examples/using_file_mounts.yaml index 2906c231a0d..b6a961ab15b 100644 --- a/examples/using_file_mounts.yaml +++ b/examples/using_file_mounts.yaml @@ -70,6 +70,8 @@ file_mounts: /s3-data-test: s3://fah-public-data-covid19-cryptic-pockets/human/il6/PROJ14534/RUN999/CLONE0/results0 /s3-data-file: s3://fah-public-data-covid19-cryptic-pockets/human/il6/PROJ14534/RUN999/CLONE0/results0/frame0.xtc + # Test access to private bucket + # /my-bucket: s3://sky-detectron2-outputs # /test-my-gcs: gs://cloud-storage-test-zhwu-2 # If a source path points to a "directory", its contents will be recursively diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 0d54d9dc2f8..c9a85847fd7 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1783,7 +1783,7 @@ def _update_cluster_status_no_lock( backend.set_autostop(handle, -1, stream_logs=False) except (Exception, SystemExit) as e: # pylint: disable=broad-except logger.debug( - f'Failed to reset autostop. Due to {common_utils.class_fullname(e.__class__)}: {e}' + f'Failed to reset autostop. Reason: {common_utils.format_exception(e)}' ) global_user_state.set_cluster_autostop_value(handle.cluster_name, -1, @@ -1829,11 +1829,11 @@ def _update_cluster_status( Raises: exceptions.ClusterOwnerIdentityMismatchError: if the current user is not the - same as the user who created the cluster. + same as the user who created the cluster. exceptions.CloudUserIdentityError: if we fail to get the current user - identity. + identity. exceptions.ClusterStatusFetchingError: the cluster status cannot be - fetched from the cloud provider. + fetched from the cloud provider. 
""" if not acquire_per_cluster_status_lock: return _update_cluster_status_no_lock(cluster_name) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index c99326a6eec..c5fe3338ee6 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -996,6 +996,7 @@ def _retry_region_zones(self, dryrun: bool, stream_logs: bool, cluster_name: str, + cloud_user_identity: Optional[str], cluster_exists: bool = False): """The provision retry loop.""" style = colorama.Style @@ -1062,10 +1063,8 @@ def _retry_region_zones(self, global_user_state.add_or_update_cluster(cluster_name, cluster_handle=handle, ready=False) - cloud = handle.launched_resources.cloud - cloud_user_id = cloud.get_current_user_identity() global_user_state.set_owner_identity_for_cluster( - cluster_name, cloud_user_id) + cluster_name, cloud_user_identity) tpu_name = config_dict.get('tpu_name') if tpu_name is not None: @@ -1493,24 +1492,23 @@ def provision_with_retries( style = colorama.Style # Retrying launchable resources. - provision_failed = True - while provision_failed: - provision_failed = False + while True: try: - try: - # Recheck cluster name as the 'except:' block below may - # change the cloud assignment. - backend_utils.check_cluster_name_is_valid( - cluster_name, to_provision.cloud) - except exceptions.InvalidClusterNameError as e: - # Let failover below handle this (i.e., block this cloud). - raise exceptions.ResourcesUnavailableError(str(e)) from e + # Recheck cluster name as the 'except:' block below may + # change the cloud assignment. + backend_utils.check_cluster_name_is_valid( + cluster_name, to_provision.cloud) + if dryrun: + cloud_user = None + else: + cloud_user = to_provision.cloud.get_current_user_identity() config_dict = self._retry_region_zones( to_provision, num_nodes, dryrun=dryrun, stream_logs=stream_logs, cluster_name=cluster_name, + cloud_user_identity=cloud_user, cluster_exists=cluster_exists) if dryrun: return @@ -1525,39 +1523,43 @@ def provision_with_retries( 'optimize_target=sky.OptimizeTarget.COST)') raise e - logger.warning(e) - provision_failed = True - logger.warning( - f'\n{style.BRIGHT}Provision failed for {num_nodes}x ' - f'{to_provision}. Trying other launchable resources ' - f'(if any).{style.RESET_ALL}') - if not cluster_exists: - # Add failed resources to the blocklist, only when it - # is in fallback mode. - self._blocked_launchable_resources.add(to_provision) - else: - logger.info( - 'Retrying provisioning with requested resources ' - f'{task.num_nodes}x {task.resources}') - # Retry with the current, potentially "smaller" resources: - # to_provision == the current new resources (e.g., V100:1), - # which may be "smaller" than the original (V100:8). - # num_nodes is not part of a Resources so must be updated - # separately. - num_nodes = task.num_nodes - cluster_exists = False - - # Set to None so that sky.optimize() will assign a new one - # (otherwise will skip re-optimizing this task). - # TODO: set all remaining tasks' best_resources to None. - task.best_resources = None - self._dag = sky.optimize(self._dag, - minimize=self._optimize_target, - blocked_launchable_resources=self. - _blocked_launchable_resources) - to_provision = task.best_resources - assert task in self._dag.tasks, 'Internal logic error.' 
- assert to_provision is not None, task + logger.warning(common_utils.format_exception(e)) + except (exceptions.CloudUserIdentityError, + exceptions.InvalidClusterNameError) as e: + # Let failover below handle this (i.e., block this cloud). + logger.warning(common_utils.format_exception(e)) + else: + # Provisioning succeeded. + break + logger.warning(f'\n{style.BRIGHT}Provision failed for {num_nodes}x ' + f'{to_provision}. Trying other launchable resources ' + f'(if any).{style.RESET_ALL}') + if not cluster_exists: + # Add failed resources to the blocklist, only when it + # is in fallback mode. + self._blocked_launchable_resources.add(to_provision) + else: + logger.info('Retrying provisioning with requested resources ' + f'{task.num_nodes}x {task.resources}') + # Retry with the current, potentially "smaller" resources: + # to_provision == the current new resources (e.g., V100:1), + # which may be "smaller" than the original (V100:8). + # num_nodes is not part of a Resources so must be updated + # separately. + num_nodes = task.num_nodes + cluster_exists = False + + # Set to None so that sky.optimize() will assign a new one + # (otherwise will skip re-optimizing this task). + # TODO: set all remaining tasks' best_resources to None. + task.best_resources = None + self._dag = sky.optimize( + self._dag, + minimize=self._optimize_target, + blocked_launchable_resources=self._blocked_launchable_resources) + to_provision = task.best_resources + assert task in self._dag.tasks, 'Internal logic error.' + assert to_provision is not None, task return config_dict diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index cc41915ccd9..afa57e4a0e8 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -19,22 +19,26 @@ # renaming to avoid shadowing variables from sky import resources as resources_lib -# Minimum set of files under ~/.aws that grant AWS access. +# This local file (under ~/.aws/) will be uploaded to remote nodes (any +# cloud), if all of the following conditions hold: +# - the current user identity is not using AWS SSO +# - this file exists +# It has the following purposes: +# - make all nodes (any cloud) able to access private S3 buckets +# - make some remote nodes able to launch new nodes on AWS (i.e., make the +# AWS head node able to launch AWS workers, or the any-cloud spot +# controller able to launch spot clusters on AWS). +# +# If we detect the current user identity is AWS SSO, we will not upload this +# file to any remote nodes (any cloud). Instead, a SkyPilot IAM role is +# assigned to both the AWS head and workers. +# TODO(skypilot): This also means we leave open a bug for AWS SSO users that +# use multiple clouds. The non-AWS nodes will have neither the credential +# file nor the ability to understand AWS IAM. _CREDENTIAL_FILES = [ 'credentials', ] - -def _run_output(cmd): - proc = subprocess.run(cmd, - shell=True, - check=True, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE) - return proc.stdout.decode('ascii') - - -# TODO(zhwu): Move the default AMI size to the catalog instead. 
DEFAULT_AMI_GB = 45 @@ -45,21 +49,28 @@ class AWS(clouds.Cloud): _REPR = 'AWS' _regions: List[clouds.Region] = [] + _INDENT_PREFIX = ' ' _STATIC_CREDENTIAL_HELP_STR = ( 'Run the following commands:' - '\n $ pip install boto3' - '\n $ aws configure' - '\n For more info: ' + f'\n{_INDENT_PREFIX} $ pip install boto3' + f'\n{_INDENT_PREFIX} $ aws configure' + f'\n{_INDENT_PREFIX}For more info: ' 'https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html' # pylint: disable=line-too-long ) - _SSO_CREDENTIAL_HELP_STR = ( - 'Run the following commands (must use aws v2 CLI):' - '\n $ aws configure sso' - '\n $ aws sso login --profile <profile_name>' - '\n For more info: ' - 'https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html' # pylint: disable=line-too-long - ) + @classmethod + def _sso_credentials_help_str(cls, expired: bool = False) -> str: + help_str = 'Run the following commands (must use aws v2 CLI):' + if not expired: + help_str += f'\n{cls._INDENT_PREFIX} $ aws configure sso' + help_str += ( + f'\n{cls._INDENT_PREFIX} $ aws sso login --profile <profile_name>' + f'\n{cls._INDENT_PREFIX}For more info: ' + 'https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html' # pylint: disable=line-too-long + ) + return help_str + + _MAX_AWSCLI_MAJOR_VERSION = 1 #### Regions/Zones #### @@ -100,7 +111,7 @@ def region_zones_provision_loop( *, instance_type: Optional[str] = None, accelerators: Optional[Dict[str, int]] = None, - use_spot: bool, + use_spot: bool = False, ) -> Iterator[Tuple[clouds.Region, List[clouds.Zone]]]: # AWS provisioner can handle batched requests, so yield all zones under # each region. @@ -139,28 +150,28 @@ def get_default_ami(cls, region_name: str, instance_type: str) -> str: @classmethod def _get_image_id( cls, - image_id: Optional[Dict[str, str]], + image_id: Optional[Dict[Optional[str], str]], region_name: str, - ) -> str: + ) -> Optional[str]: if image_id is None: return None if None in image_id: - image_id = image_id[None] + image_id_str = image_id[None] else: assert region_name in image_id, image_id - image_id = image_id[region_name] - if image_id.startswith('skypilot:'): - image_id = service_catalog.get_image_id_from_tag(image_id, - region_name, - clouds='aws') - if image_id is None: + image_id_str = image_id[region_name] + if image_id_str.startswith('skypilot:'): + image_id_str = service_catalog.get_image_id_from_tag(image_id_str, + region_name, + clouds='aws') + if image_id_str is None: # Raise ResourcesUnavailableError to make sure the failover # in CloudVMRayBackend will be correctly triggered. # TODO(zhwu): This is an information leakage to the cloud # implementor; we need to find a better way to handle this. 
raise exceptions.ResourcesUnavailableError( f'No image found for region {region_name}') - return image_id + return image_id_str def get_image_size(self, image_id: str, region: Optional[str]) -> float: if image_id.startswith('skypilot:'): @@ -260,7 +271,7 @@ def get_vcpus_from_instance_type( def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', region: Optional['clouds.Region'], - zones: Optional[List['clouds.Zone']]) -> Dict[str, str]: + zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]: if region is None: assert zones is None, ( 'Set either both or neither for: region, zones.') @@ -271,7 +282,7 @@ def make_deploy_resources_variables( 'Set either both or neither for: region, zones.') region_name = region.name - zones = [zone.name for zone in zones] + zone_names = [zone.name for zone in zones] r = resources # r.accelerators is cleared but .instance_type encodes the info. @@ -290,13 +301,13 @@ 'custom_resources': custom_resources, 'use_spot': r.use_spot, 'region': region_name, - 'zones': ','.join(zones), + 'zones': ','.join(zone_names), 'image_id': image_id, } def get_feasible_launchable_resources(self, resources: 'resources_lib.Resources'): - fuzzy_candidate_list = [] + fuzzy_candidate_list: List[str] = [] if resources.instance_type is not None: assert resources.is_launchable(), resources # Treat Resources(AWS, p3.2x, V100) as Resources(AWS, p3.2x). @@ -342,24 +353,20 @@ def check_credentials(self) -> Tuple[bool, Optional[str]]: except ImportError: raise ImportError('Fail to import dependencies for AWS.' 'Try pip install "skypilot[aws]"') from None - # This file is required because it will be synced to remote VMs for - # `aws` to access private storage buckets. - # `aws configure list` does not guarantee this file exists. - if not os.path.isfile(os.path.expanduser('~/.aws/credentials')): - return (False, '~/.aws/credentials does not exist. ' + - self._STATIC_CREDENTIAL_HELP_STR) # Checks if the AWS CLI is installed properly - try: - _run_output('aws configure list') - except subprocess.CalledProcessError: + proc = subprocess.run('aws --version', + shell=True, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + if proc.returncode != 0: return False, ( - 'AWS CLI is not installed properly.' - # TODO(zhwu): Change the installation hint to from PyPI. - ' Run the following commands in the SkyPilot codebase:' - '\n $ pip install .[aws]' - '\n Credentials may also need to be set. ' + - self._STATIC_CREDENTIAL_HELP_STR) + 'AWS CLI is not installed properly. ' + 'Run the following commands:' + f'\n{self._INDENT_PREFIX} $ pip install "skypilot[aws]"' + f'\n{self._INDENT_PREFIX}Credentials may also need to be set. ' + f'{self._STATIC_CREDENTIAL_HELP_STR}') # Checks if AWS credentials 1) exist and 2) are valid. # https://stackoverflow.com/questions/53548737/verify-aws-credentials-with-boto3 @@ -368,9 +375,43 @@ def check_credentials(self) -> Tuple[bool, Optional[str]]: except exceptions.CloudUserIdentityError as e: return False, str(e) + static_credential_exists = os.path.isfile( + os.path.expanduser('~/.aws/credentials')) + hints = None + if self._is_current_identity_sso(): + hints = 'AWS SSO is set.
' + if static_credential_exists: + hints += ( + ' To ensure multiple clouds work correctly, please use SkyPilot ' + 'with static credentials (e.g., ~/.aws/credentials) by unsetting ' + 'the AWS_PROFILE environment variable.') + else: + hints += ( + ' It will work if you use AWS only, but will cause problems ' + 'if you want to use multiple clouds. To set up static credentials, ' + 'try: aws configure') + + else: + # This file is required by the VMs launched on other clouds to + # access private S3 buckets and resources like EC2. + # `get_current_user_identity` does not guarantee this file exists. + if not static_credential_exists: + return (False, '~/.aws/credentials does not exist. ' + + self._STATIC_CREDENTIAL_HELP_STR) + # Fetch the AWS availability zones mapping from ID to name. from sky.clouds.service_catalog import aws_catalog # pylint: disable=import-outside-toplevel,unused-import - return True, None + return True, hints + + def _is_current_identity_sso(self) -> bool: + proc = subprocess.run('aws configure list', + shell=True, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + if proc.returncode != 0: + return False + return 'sso' in proc.stdout.decode().split() def get_current_user_identity(self) -> Optional[str]: """Returns the identity of the user on this cloud.""" @@ -397,6 +438,24 @@ def get_current_user_identity(self) -> Optional[str]: 'Failed to access AWS services with credentials. ' 'Make sure that the access and secret keys are correct.' f' {self._STATIC_CREDENTIAL_HELP_STR}') from None + except aws.botocore_exceptions().InvalidConfigError as e: + import awscli + from packaging import version + awscli_version = version.parse(awscli.__version__) + if (awscli_version < version.parse('1.27.10') and + 'configured to use SSO' in str(e)): + with ux_utils.print_exception_no_traceback(): + raise exceptions.CloudUserIdentityError( + 'awscli is too old to use SSO. Run the following command to upgrade:' + f'\n{self._INDENT_PREFIX} $ pip install "awscli>=1.27.10"' + f'\n{self._INDENT_PREFIX}You may need to log into SSO again after ' + f'upgrading. {self._sso_credentials_help_str()}' + ) from None + with ux_utils.print_exception_no_traceback(): + raise exceptions.CloudUserIdentityError( + f'Invalid AWS configuration.\n' + f' Reason: {common_utils.format_exception(e, use_bracket=True)}.' + ) from None except aws.botocore_exceptions().TokenRetrievalError: # This is raised when the access token is expired, which mainly # happens when the user is using temporary credentials or SSO @@ -404,19 +463,39 @@ with ux_utils.print_exception_no_traceback(): raise exceptions.CloudUserIdentityError( 'AWS access token is expired.' - f' {self._SSO_CREDENTIAL_HELP_STR}') from None + f' {self._sso_credentials_help_str(expired=True)}' + ) from None except Exception as e: # pylint: disable=broad-except with ux_utils.print_exception_no_traceback(): raise exceptions.CloudUserIdentityError( f'Failed to get AWS user.\n' - f' Reason: [{common_utils.class_fullname(e.__class__)}] {e}.' + f' Reason: {common_utils.format_exception(e, use_bracket=True)}.' ) from None return user_id def get_credential_file_mounts(self) -> Dict[str, str]: + # TODO(skypilot): ~/.aws/credentials is required for users using multiple clouds. + # If this file does not exist, users can launch on AWS via AWS SSO and assign + # an IAM role to the cluster.
+ # However, if users launch clusters in a non-AWS cloud, those clusters do not + # understand AWS IAM roles, so they will not be able to access private AWS EC2 + # resources and S3 buckets. + + # The file should not be uploaded if the user is using SSO, as the credential + # file can be from a different account and will make autostop/autodown/spot + # controller misbehave. + + # TODO(zhwu/zongheng): We can also avoid uploading the credential file for the + # cluster launched on AWS even if the user is using static credentials. We need + # to define a mechanism to find out the cloud provider of the cluster to be + # launched in this function and make sure the cluster will not be used for + # launching clusters in other clouds, e.g., the spot controller. + if self._is_current_identity_sso(): + return {} return { f'~/.aws/{filename}': f'~/.aws/{filename}' for filename in _CREDENTIAL_FILES + if os.path.exists(os.path.expanduser(f'~/.aws/{filename}')) } def instance_type_exists(self, instance_type): diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index 39d93b70b2c..1f49aee34e6 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -7,6 +7,7 @@ from sky import clouds from sky import exceptions +from sky import sky_logging from sky.adaptors import azure from sky.clouds import service_catalog from sky.utils import common_utils @@ -15,6 +16,8 @@ if typing.TYPE_CHECKING: from sky import resources +logger = sky_logging.init_logger(__name__) + # Minimum set of files under ~/.azure that grant Azure access. _CREDENTIAL_FILES = [ 'azureProfile.json', @@ -23,6 +26,8 @@ 'msal_token_cache.json', ] +_MAX_IDENTITY_FETCH_RETRY = 10 + def _run_output(cmd): proc = subprocess.run(cmd, @@ -141,7 +146,7 @@ def region_zones_provision_loop( *, instance_type: Optional[str] = None, accelerators: Optional[Dict[str, int]] = None, - use_spot: bool, + use_spot: bool = False, ) -> Iterator[Tuple[clouds.Region, List[clouds.Zone]]]: del accelerators # unused @@ -180,15 +185,13 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources.Resources', region: Optional['clouds.Region'], - zones: Optional[List['clouds.Zone']]) -> Dict[str, str]: + zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]: if region is None: assert zones is None, ( 'Set either both or neither for: region, zones.') region = self._get_default_region() region_name = region.name - # Azure does not support specific zones. - zones = [] r = resources assert not r.use_spot, \ @@ -208,7 +211,8 @@ 'custom_resources': custom_resources, 'use_spot': r.use_spot, 'region': region_name, - 'zones': zones, + # Azure does not support specific zones. + 'zones': None, **image_config } @@ -275,11 +279,8 @@ def check_credentials(self) -> Tuple[bool, Optional[str]]: except subprocess.CalledProcessError: return False, ( # TODO(zhwu): Change the installation hint to from PyPI. - 'Azure CLI returned error. Run the following commands in the SkyPilot codebase:' - '\n $ pip install skypilot[azure] # if installed from ' - 'PyPI' - '\n Or:' - '\n $ pip install .[azure] # if installed from source' + 'Azure CLI returned error. Run the following commands:' + '\n $ pip install "skypilot[azure]"' '\n Credentials may also need to be set.' + help_str) # If Azure is properly logged in, this will return the account email # address + subscription ID. 
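To make the new SSO handling concrete, the pieces added to sky/clouds/aws.py above (detection via `aws configure list`, plus skipping the credential upload for SSO identities) condense to roughly the following standalone sketch. The module-level function form and the final print are illustrative only; in the diff these are methods on the AWS cloud class:

import os
import subprocess
from typing import Dict

_CREDENTIAL_FILES = ['credentials']


def is_current_identity_sso() -> bool:
    # `aws configure list` reports the credential type in its output; an SSO
    # profile shows 'sso' there. A failing CLI is treated as "not SSO".
    proc = subprocess.run('aws configure list',
                          shell=True,
                          check=False,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE)
    if proc.returncode != 0:
        return False
    return 'sso' in proc.stdout.decode().split()


def get_credential_file_mounts() -> Dict[str, str]:
    # For SSO identities, upload nothing: remote AWS nodes rely on the
    # SkyPilot IAM role instead, and a ~/.aws/credentials file from a
    # different (static) account would make autostop/autodown/the spot
    # controller misbehave.
    if is_current_identity_sso():
        return {}
    return {
        f'~/.aws/{name}': f'~/.aws/{name}'
        for name in _CREDENTIAL_FILES
        if os.path.exists(os.path.expanduser(f'~/.aws/{name}'))
    }


if __name__ == '__main__':
    print(get_credential_file_mounts())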
@@ -311,21 +312,38 @@ def accelerator_in_region_or_zone(self, def get_current_user_identity(self) -> Optional[str]: """Returns the cloud user identity.""" # This returns the user's email address + [subscription_id]. + retry_cnt = 0 + while True: + retry_cnt += 1 + try: + import knack # pylint: disable=import-outside-toplevel + account_email = azure.get_current_account_user() + break + except (FileNotFoundError, knack.util.CLIError) as e: + error = exceptions.CloudUserIdentityError( + 'Failed to get activated Azure account.\n' + ' Reason: ' + f'{common_utils.format_exception(e, use_bracket=True)}') + if retry_cnt <= _MAX_IDENTITY_FETCH_RETRY: + logger.debug(f'{error}.\nRetrying...') + continue + with ux_utils.print_exception_no_traceback(): + raise error from None + except Exception as e: # pylint: disable=broad-except + with ux_utils.print_exception_no_traceback(): + raise exceptions.CloudUserIdentityError( + 'Failed to get Azure user identity with unknown ' + f'exception.\n' + ' Reason: ' + f'{common_utils.format_exception(e, use_bracket=True)}' + ) from e try: - import knack # pylint: disable=import-outside-toplevel - account_email = azure.get_current_account_user() - except (FileNotFoundError, knack.util.CLIError): - with ux_utils.print_exception_no_traceback(): - raise exceptions.CloudUserIdentityError( - 'Failed to get activated Azure account.') from None - except Exception as e: # pylint: disable=broad-except + project_id = self.get_project_id() + except (ModuleNotFoundError, RuntimeError) as e: with ux_utils.print_exception_no_traceback(): raise exceptions.CloudUserIdentityError( - 'Failed to get Azure user identity with unknown ' - f'exception.\n' - f' Reason: [{common_utils.class_fullname(e.__class__)}] ' - f'{e}') from e - return f'{account_email} [subscription_id={self.get_project_id()}]' + 'Failed to get Azure project ID.') from e + return f'{account_email} [subscription_id={project_id}]' @classmethod def get_project_id(cls, dryrun: bool = False) -> str: diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index b61dab7dd1f..1a0341b0992 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -1,7 +1,7 @@ """Interfaces: clouds, regions, and zones.""" import collections import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Type from sky.clouds import service_catalog from sky.utils import ux_utils @@ -40,7 +40,7 @@ def from_str(self, name: Optional[str]) -> Optional['Cloud']: f'{list(self.keys())}') return self.get(name.lower()) - def register(self, cloud_cls: 'Cloud') -> None: + def register(self, cloud_cls: Type['Cloud']) -> Type['Cloud']: name = cloud_cls.__name__.lower() assert name not in self, f'{name} already registered' self[name] = cloud_cls() @@ -67,7 +67,7 @@ def region_zones_provision_loop( *, instance_type: Optional[str] = None, accelerators: Optional[Dict[str, int]] = None, - use_spot: Optional[bool] = False, + use_spot: bool = False, ) -> Iterator[Tuple[Region, List[Zone]]]: """Loops over (region, zones) to retry for provisioning. @@ -127,7 +127,7 @@ def make_deploy_resources_variables( resources: 'resources.Resources', region: Optional['Region'], zones: Optional[List['Zone']], - ) -> Dict[str, str]: + ) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to cloud-specific resource variables. 
These variables are used to fill the node type section (instance type, diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 7145d989412..fda24b42855 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -136,7 +136,7 @@ def region_zones_provision_loop( *, instance_type: Optional[str] = None, accelerators: Optional[Dict[str, int]] = None, - use_spot: Optional[bool] = False, + use_spot: bool = False, ) -> Iterator[Tuple[clouds.Region, List[clouds.Zone]]]: # GCP provisioner currently takes 1 zone per request. if accelerators is None: @@ -239,7 +239,7 @@ def _get_default_region(cls) -> clouds.Region: def make_deploy_resources_variables( self, resources: 'resources.Resources', region: Optional['clouds.Region'], - zones: Optional[List['clouds.Zone']]) -> Dict[str, str]: + zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]: if region is None: assert zones is None, ( 'Set either both or neither for: region, zones.') @@ -250,7 +250,7 @@ def make_deploy_resources_variables( 'Set either both or neither for: region, zones.') region_name = region.name - zones = [zones[0].name] + zone_name = zones[0].name # gcloud compute images list \ # --project deeplearning-platform-release \ @@ -265,7 +265,7 @@ def make_deploy_resources_variables( resources_vars = { 'instance_type': r.instance_type, 'region': region_name, - 'zones': ','.join(zones), + 'zones': zone_name, 'gpu': None, 'gpu_count': None, 'tpu': None, @@ -384,7 +384,7 @@ def check_credentials(self) -> Tuple[bool, Optional[str]]: """Checks if the user has access credentials to this cloud.""" try: # pylint: disable=import-outside-toplevel,unused-import - from google import auth + from google import auth # type: ignore # Check google-api-python-client installation. import googleapiclient @@ -510,8 +510,9 @@ def get_current_user_identity(self) -> Optional[str]: raise exceptions.CloudUserIdentityError( f'Failed to get GCP user identity with unknown ' f'exception.\n' - f' Reason: [{common_utils.class_fullname(e.__class__)}] ' - f'{e}') from e + ' Reason: ' + f'{common_utils.format_exception(e, use_bracket=True)}' + ) from e if not account: with ux_utils.print_exception_no_traceback(): raise exceptions.CloudUserIdentityError( @@ -519,7 +520,16 @@ def get_current_user_identity(self) -> Optional[str]: 'auth list --filter=status:ACTIVE ' '--format="value(account)"` and ensure it correctly ' 'returns the current user.') - return f'{account} [project_id={self.get_project_id()}]' + try: + return f'{account} [project_id={self.get_project_id()}]' + except Exception as e: # pylint: disable=broad-except + with ux_utils.print_exception_no_traceback(): + raise exceptions.CloudUserIdentityError( + f'Failed to get GCP user identity with unknown ' + f'exception.\n' + ' Reason: ' + f'{common_utils.format_exception(e, use_bracket=True)}' + ) from e def instance_type_exists(self, instance_type): return service_catalog.instance_type_exists(instance_type, 'gcp') @@ -548,7 +558,8 @@ def need_cleanup_after_preemption(self, def get_project_id(cls, dryrun: bool = False) -> str: if dryrun: return 'dryrun-project-id' - from google import auth # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from google import auth # type: ignore _, project_id = auth.default() return project_id diff --git a/sky/clouds/local.py b/sky/clouds/local.py index a49349c2985..fe91d3b97c0 100644 --- a/sky/clouds/local.py +++ b/sky/clouds/local.py @@ -48,7 +48,7 @@ def region_zones_provision_loop( *, instance_type: Optional[str] = None, accelerators: 
Optional[Dict[str, int]] = None, - use_spot: bool, + use_spot: bool = False, ) -> Iterator[Tuple[clouds.Region, List[clouds.Zone]]]: del instance_type del use_spot @@ -103,7 +103,7 @@ def get_accelerators_from_instance_type( def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', region: Optional['clouds.Region'], - zones: Optional[List['clouds.Zone']]) -> Dict[str, str]: + zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]: return {} def get_feasible_launchable_resources(self, diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 8b960afef65..ea184e3ac46 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -101,6 +101,7 @@ def parse_readme(readme: str) -> str: # packages dependencies are changed. extras_require = { 'aws': [ + # awscli>=1.27.10 is required for SSO support. 'awscli', 'boto3', # 'Crypto' module used in authentication.py for AWS. diff --git a/sky/skylet/providers/aws/__init__.py b/sky/skylet/providers/aws/__init__.py index e71ef477e46..5e592213448 100644 --- a/sky/skylet/providers/aws/__init__.py +++ b/sky/skylet/providers/aws/__init__.py @@ -1,2 +1,2 @@ """AWS node provider""" -from sky.skylet.providers.aws.node_provider import AWSNodeProvider +from sky.skylet.providers.aws.node_provider import AWSNodeProvider, AWSNodeProviderV2 diff --git a/sky/skylet/providers/aws/config.py b/sky/skylet/providers/aws/config.py index 6daf8c951ee..876198e07a2 100644 --- a/sky/skylet/providers/aws/config.py +++ b/sky/skylet/providers/aws/config.py @@ -32,6 +32,11 @@ DEFAULT_RAY_IAM_ROLE = RAY + "-v1" SECURITY_GROUP_TEMPLATE = RAY + "-{}" +SKYPILOT = "skypilot" +DEFAULT_SKYPILOT_INSTANCE_PROFILE = SKYPILOT + "-v1" +DEFAULT_SKYPILOT_IAM_ROLE = SKYPILOT + "-v1" + + # V61.0 has CUDA 11.2 DEFAULT_AMI_NAME = "AWS Deep Learning AMI (Ubuntu 18.04) V61.0" @@ -62,6 +67,9 @@ def key_pair(i, region, key_name): If key_name is not None, key_pair will be named after key_name. Returns the ith default (aws_key_pair_name, key_pair_path). """ + # SkyPilot: we don't use this, as we explicitly set the key already. + # For backwards compatibility, we'll just return the key pair with + # the previous name. if i == 0: key_pair_name = "{}_{}".format(RAY, region) if key_name is None else key_name return ( @@ -215,7 +223,7 @@ def print_info( cli_logger.newline() -def bootstrap_aws(config): +def bootstrap_aws(config, skypilot_iam_role: bool = False): # create a copy of the input config to modify config = copy.deepcopy(config) @@ -235,7 +243,9 @@ def bootstrap_aws(config): # The head node needs to have an IAM role that allows it to create further # EC2 instances. - config = _configure_iam_role(config) + # If skypilot_iam_role is True, we use our own IAM role for both head and + # workers. + config = _configure_iam_role(config, skypilot_iam_role=skypilot_iam_role) # Configure SSH access, using an existing key pair if possible. 
config = _configure_key_pair(config) @@ -257,17 +267,28 @@ def bootstrap_aws(config): return config -def _configure_iam_role(config): +def _configure_iam_role(config, skypilot_iam_role: bool): + default_instance_profile = DEFAULT_RAY_INSTANCE_PROFILE + default_iam_role = DEFAULT_RAY_IAM_ROLE + if skypilot_iam_role: + default_instance_profile = DEFAULT_SKYPILOT_INSTANCE_PROFILE + default_iam_role = DEFAULT_SKYPILOT_IAM_ROLE + head_node_type = config["head_node_type"] head_node_config = config["available_node_types"][head_node_type]["node_config"] if "IamInstanceProfile" in head_node_config: _set_config_info(head_instance_profile_src="config") + if skypilot_iam_role: + # SkyPilot: let the workers use the same role as the head node, so that they + # can access private S3 buckets. + for node_type in config["available_node_types"].values(): + node_type["node_config"]["IamInstanceProfile"] = head_node_config['IamInstanceProfile'] return config _set_config_info(head_instance_profile_src="default") instance_profile_name = cwh.resolve_instance_profile_name( config["provider"], - DEFAULT_RAY_INSTANCE_PROFILE, + default_instance_profile, ) profile = _get_instance_profile(instance_profile_name, config) @@ -287,7 +308,7 @@ def _configure_iam_role(config): assert profile is not None, "Failed to create instance profile" if not profile.roles: - role_name = cwh.resolve_iam_role_name(config["provider"], DEFAULT_RAY_IAM_ROLE) + role_name = cwh.resolve_iam_role_name(config["provider"], default_iam_role) role = _get_role(role_name, config) if role is None: cli_logger.verbose( @@ -312,6 +333,7 @@ def _configure_iam_role(config): "arn:aws:iam::aws:policy/AmazonS3FullAccess", ], ) + iam.create_role( RoleName=role_name, AssumeRolePolicyDocument=json.dumps(policy_doc) @@ -326,11 +348,42 @@ def _configure_iam_role(config): for policy_arn in attach_policy_arns: role.attach_policy(PolicyArn=policy_arn) + # SkyPilot: "PassRole" is required by the head node to pass the role to + # the workers, so we can access S3 buckets on the workers. "Resource" + # is to limit the role to only able to pass itself to the workers. + skypilot_pass_role_policy_doc = { + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "iam:GetRole", + "iam:PassRole", + ], + "Resource": role.arn, + }, + { + "Effect": "Allow", + "Action": "iam:GetInstanceProfile", + "Resource": profile.arn, + } + ] + } + if skypilot_iam_role: + role.Policy("SkyPilotPassRolePolicy").put( + PolicyDocument=json.dumps(skypilot_pass_role_policy_doc) + ) + profile.add_role(RoleName=role.name) time.sleep(15) # wait for propagation - # Add IAM role to "head_node" field so that it is applied only to - # the head node -- not to workers with the same node type as the head. - config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn} + if skypilot_iam_role: + # SkyPilot: let the workers use the same role as the head node, so that they + # can access private S3 buckets. + for node_type in config["available_node_types"].values(): + node_type["node_config"]["IamInstanceProfile"] = {"Arn": profile.arn} + else: + # Add IAM role to "head_node" field so that it is applied only to + # the head node -- not to workers with the same node type as the head. 
+ config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn} return config diff --git a/sky/skylet/providers/aws/node_provider.py b/sky/skylet/providers/aws/node_provider.py index 7203ab70f84..3397ba58cec 100644 --- a/sky/skylet/providers/aws/node_provider.py +++ b/sky/skylet/providers/aws/node_provider.py @@ -97,6 +97,10 @@ def list_ec2_instances( class AWSNodeProvider(NodeProvider): + """Deprecated for SkyPilot and kept for backward compatibility. + + The cluster launch template has been updated to use AWSNodeProviderV2. + """ max_terminate_nodes = 1000 def __init__(self, provider_config, cluster_name): @@ -662,3 +666,21 @@ def fillout_available_node_types_resources( + "." ) return cluster_config + + +class AWSNodeProviderV2(AWSNodeProvider): + """Same as V1, except head and workers use a SkyPilot IAM role. + + The new version of the AWS node provider supports AWS SSO + (see #1489), by using a new IAM role with different permissions + than the original ray-autoscaler-v1 for both the head node and + worker nodes. + + We did not overwrite the original AWSNodeProvider class to avoid + breaking existing clusters. Otherwise, the existing clusters will + have a new launch_hash and will have new node(s) launched, causing + the existing nodes to leak. + """ + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_aws(cluster_config, skypilot_iam_role=True) diff --git a/sky/spot/controller.py b/sky/spot/controller.py index 40be1d2933c..3937f768dd8 100644 --- a/sky/spot/controller.py +++ b/sky/spot/controller.py @@ -179,7 +179,7 @@ def run(self): except (Exception, SystemExit) as e: # pylint: disable=broad-except logger.error(traceback.format_exc()) logger.error('Unexpected error occurred: ' - f'{common_utils.class_fullname(e.__class__)}: {e}') + f'{common_utils.format_exception(e)}') finally: self._strategy_executor.terminate_cluster() job_status = spot_state.get_status(self._job_id) diff --git a/sky/spot/spot_utils.py b/sky/spot/spot_utils.py index cf596da49da..f6630de0272 100644 --- a/sky/spot/spot_utils.py +++ b/sky/spot/spot_utils.py @@ -543,7 +543,7 @@ def is_spot_controller_up( f'Failed to get the status of the spot controller. ' 'It is not fatal, but spot commands/calls may hang or return stale ' 'information, when the controller is not up.\n' - f' Details: [{common_utils.class_fullname(e.__class__)}]{e}') + f' Details: {common_utils.format_exception(e, use_bracket=True)}') record = global_user_state.get_cluster_from_name(SPOT_CONTROLLER_NAME) controller_status, handle = None, None if record is not None: diff --git a/sky/templates/aws-ray.yml.j2 b/sky/templates/aws-ray.yml.j2 index 7be6c4081d0..4bce0c58878 100644 --- a/sky/templates/aws-ray.yml.j2 +++ b/sky/templates/aws-ray.yml.j2 @@ -7,7 +7,7 @@ idle_timeout_minutes: 60 provider: type: external - module: sky.skylet.providers.aws.AWSNodeProvider + module: sky.skylet.providers.aws.AWSNodeProviderV2 region: {{region}} availability_zone: {{zones}} # Keep (otherwise cannot reuse when re-provisioning). diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index ae13aba92e5..dc35ed761cb 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -312,6 +312,20 @@ def class_fullname(cls): return f'{cls.__module__}.{cls.__name__}' +def format_exception(e: Exception, use_bracket: bool = False) -> str: + """Format an exception to a string. + + Args: + e: The exception to format. + + Returns: + A string that represents the exception. 
+ """ + if use_bracket: + return f'[{class_fullname(e.__class__)}]: {e}' + return f'{class_fullname(e.__class__)}: {e}' + + def remove_color(s: str): """Remove color from a string. diff --git a/tests/backward_comaptibility_tests.sh b/tests/backward_comaptibility_tests.sh index 4e0ee9b24d2..84d8f554e85 100755 --- a/tests/backward_comaptibility_tests.sh +++ b/tests/backward_comaptibility_tests.sh @@ -1,3 +1,7 @@ +# This script is used to test backward compatibility of skypilot. +# To run this script, you need to uninstall the skypilot and ray in the base +# conda environment, and run it in the base conda environment. + #!/bin/bash set -ev @@ -19,6 +23,7 @@ conda install -c conda-forge google-cloud-sdk -y rm -r ~/.sky/wheels || true cd ../sky-master git pull origin master +pip uninstall -y skypilot pip install -e ".[all]" cd - diff --git a/tests/mypy_files.txt b/tests/mypy_files.txt index 1fb32220f49..a435c206691 100644 --- a/tests/mypy_files.txt +++ b/tests/mypy_files.txt @@ -1 +1,4 @@ sky/data/storage.py +sky/clouds +--exclude +sky/clouds/service_catalog diff --git a/tests/run_smoke_tests.sh b/tests/run_smoke_tests.sh index 1738d7e97d0..3110be313c6 100755 --- a/tests/run_smoke_tests.sh +++ b/tests/run_smoke_tests.sh @@ -7,16 +7,22 @@ # # Re-run a failed test # bash tests/run_smoke_tests.sh test_azure_start_stop # +# # Run slow tests +# bash tests/run_smoke_tests.sh --runslow +# +# # Run SSO tests +# bash tests/run_smoke_tests.sh --sso test=${1:-""} if [ -z "$test" ] then test_spec=tests/test_smoke.py +elif [[ "$test" == "--*" ]] +then + [[ "$test" == "--runslow" ]] || echo "Unknown option: $test" + test_spec="$test tests/test_smoke.py" else test_spec=tests/test_smoke.py::"${test}" fi pytest -s -n 16 -q --tb=short --disable-warnings "$test_spec" - -# To run all tests including the slow ones, add the --runslow flag: -# pytest --runslow -s -n 16 -q --tb=short --disable-warnings tests/test_smoke.py diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 4e0491839fc..68d15dc6523 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -563,7 +563,7 @@ def test_tpu_vm(): f'sky logs {name} 1', # Ensure the job finished. f'sky logs {name} 1 --status', # Ensure the job succeeded. f'sky stop -y {name}', - f'sky status --refresh | grep {name} | grep STOPPED', # Ensure the cluster is STOPPED. + f's=$(sky status --refresh); printf "$s"; echo; echo; printf "$s" | grep {name} | grep STOPPED', # Ensure the cluster is STOPPED. # Use retry: guard against transient errors observed for # just-stopped TPU VMs (#962). f'sky start --retry-until-up -y {name}', @@ -684,11 +684,11 @@ def test_autostop(): # Ensure the cluster is not stopped early. 'sleep 45', - f'sky status --refresh | grep {name} | grep UP', + f's=$(sky status --refresh); printf "$s"; echo; echo; printf "$s" | grep {name} | grep UP', # Ensure the cluster is STOPPED. - 'sleep 90', - f'sky status --refresh | grep {name} | grep STOPPED', + 'sleep 100', + f's=$(sky status --refresh); printf "$s"; echo; echo; printf "$s" | grep {name} | grep STOPPED', # Ensure the cluster is UP and the autostop setting is reset ('-'). f'sky start -y {name}', @@ -704,9 +704,9 @@ def test_autostop(): f'sky autostop -y {name} --cancel', f'sky autostop -y {name} -i 1', # Should restart the timer. 
'sleep 45', - f'sky status --refresh | grep {name} | grep UP', - 'sleep 90', - f'sky status --refresh | grep {name} | grep STOPPED', + f's=$(sky status --refresh); printf "$s"; echo; echo; printf "$s" | grep {name} | grep UP', + 'sleep 100', + f's=$(sky status --refresh); printf "$s"; echo; echo; printf "$s" | grep {name} | grep STOPPED', # Test restarting the idleness timer via exec: f'sky start -y {name}', @@ -715,9 +715,9 @@ def test_autostop(): 'sleep 45', # Almost reached the threshold. f'sky exec {name} echo hi', # Should restart the timer. 'sleep 45', - f'sky status --refresh | grep {name} | grep UP', + f's=$(sky status --refresh); printf "$s"; echo; echo; printf "$s" | grep {name} | grep UP', 'sleep 90', - f'sky status --refresh | grep {name} | grep STOPPED', + f's=$(sky status --refresh); printf "$s"; echo; echo; printf "$s" | grep {name} | grep STOPPED', ], f'sky down -y {name}', timeout=20 * 60, @@ -737,7 +737,7 @@ def test_autodown(): f'sky status | grep {name} | grep "1m (down)"', # Ensure the cluster is not terminated early. 'sleep 45', - f'sky status --refresh | grep {name} | grep UP', + f's=$(sky status --refresh); printf "$s"; echo; echo; printf "$s" | grep {name} | grep UP', # Ensure the cluster is terminated. 'sleep 200', f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
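Stepping back, the core control-flow change in provision_with_retries (sky/backends/cloud_vm_ray_backend.py) is replacing the provision_failed flag with a while True / try / except / else: break loop, so that CloudUserIdentityError and InvalidClusterNameError share the same failover path as ResourcesUnavailableError. A minimal runnable sketch of that shape, using illustrative stand-in names rather than SkyPilot's real API:

from typing import Dict, List


class ProvisionError(Exception):
    # Stand-in for the provisioning/identity/cluster-name errors that all
    # funnel into the same retry path in the diff.
    pass


def try_provision(candidate: str) -> Dict[str, str]:
    # Illustrative failure model: every candidate except the last-resort one
    # raises, mimicking clouds/regions that are out of capacity.
    if candidate != 'cpu-fallback':
        raise ProvisionError(f'{candidate} unavailable')
    return {'instance': candidate}


def provision_with_retries(candidates: List[str]) -> Dict[str, str]:
    blocked = set()
    to_provision = candidates[0]
    while True:
        try:
            config = try_provision(to_provision)
        except ProvisionError as e:
            # All retryable errors land here, then fall through to the
            # shared block-and-reoptimize logic below.
            print(f'Provision failed: {e}')
        else:
            break  # Provisioning succeeded.
        blocked.add(to_provision)
        remaining = [c for c in candidates if c not in blocked]
        if not remaining:
            raise ProvisionError('No launchable resources left.')
        to_provision = remaining[0]  # "Re-optimize" over unblocked choices.
    return config


print(provision_with_retries(['v100-8', 'v100-1', 'cpu-fallback']))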