diff --git a/test/e2e/common/aws.py b/test/e2e/common/aws.py index 6200fa68fb..fe271e525f 100644 --- a/test/e2e/common/aws.py +++ b/test/e2e/common/aws.py @@ -29,4 +29,31 @@ def duplicate_s3_contents(source_bucket: object, destination_bucket: object): destination_bucket.copy({ "Bucket": source_object.bucket_name, "Key": source_object.key, - }, source_object.key) \ No newline at end of file + }, source_object.key) + +def copy_s3_object(bucket_name: str, copy_source: object, key: str): + """ + Copy an S3 object. Check the following API documentation for input format of the arguments + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Bucket.copy + """ + region = get_aws_region() + bucket = boto3.resource("s3", region_name=region).Bucket(bucket_name) + bucket.copy(copy_source, key) + +def delete_s3_object(bucket_name: str, key: str): + """ + Delete an S3 object. Check the following API documentation for input format of the arguments + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects + """ + region = get_aws_region() + bucket = boto3.resource("s3", region_name=region).Bucket(bucket_name) + bucket.delete_objects( + Delete={ + "Objects": [ + { + "Key": key, + }, + ], + "Quiet": False, + }, + ) diff --git a/test/e2e/common/k8s.py b/test/e2e/common/k8s.py index 1c9645e03d..0a213ec0c2 100644 --- a/test/e2e/common/k8s.py +++ b/test/e2e/common/k8s.py @@ -23,6 +23,8 @@ from kubernetes.client.api_client import ApiClient from kubernetes.client.rest import ApiException +from common.resources import load_resource_file + _k8s_api_client = None @@ -53,6 +55,78 @@ def to_short_resource_string(self): def to_long_resource_string(self): return f"{self.plural}.{self.version}.{self.group}/{self._printable_namespace}:{self.name}" +def load_resource(service_name: str, + spec_file: str, + replacements: object): + """ + Load a yaml spec to memory from root_test_path/{service}/resources and replace 
the values in replacement dict + + :param service_name: name of service + :param spec_file: Name of the spec file under resources directory of the service + :param replacements: A dictionary of values to be replaced + + :return: spec as json object + """ + spec = load_resource_file( + service_name, spec_file, additional_replacements=replacements + ) + logging.debug(f"loaded spec: {spec}") + return spec + +def create_reference(crd_group: str, + crd_version: str, + resource_plural: str, + resource_name: str, + namespace: str): + """ + Create an instance of CustomResourceReference based on the parameters + + :param crd_group: CRD Group + :param crd_version: CRD version + :param resource_plural: resource plural + :param resource_name: name of resource to be created in cluster + :param namespace: namespace in which resource should be created + + :return: an instance of CustomResourceReference + """ + reference = CustomResourceReference( + crd_group, crd_version, resource_plural, resource_name, namespace=namespace + ) + return reference + +def create_resource(reference: CustomResourceReference, + spec: object): + """ + Create a resource from the reference and wait to be consumed by controller + + :param reference: instance of CustomResourceReference which needs to be created + :param spec: spec of the resource corresponding to the reference + + :return: resource if it was created successfully, otherwise None + """ + resource = create_custom_resource(reference, spec) + resource = wait_resource_consumed_by_controller(reference) + return resource + +def load_and_create_resource(service_name: str, + crd_group: str, + crd_version: str, + resource_plural: str, + resource_name: str, + spec_file_name: str, + replacements: object, + namespace: str = "default"): + """ + Helper method to encapsulate the common methods used to create a resource. + Load a spec file from disk, create an instance of CustomResourceReference and resource in K8s cluster. 
+ See respective methods for parameter definitions and return types + + :returns: an instance of CustomResourceReference, spec loaded from disk, resource created from the reference + """ + spec = load_resource(service_name, spec_file_name, replacements) + reference = create_reference(crd_group, crd_version, resource_plural, resource_name, namespace) + resource = create_resource(reference, spec) + return reference, spec, resource def _get_k8s_api_client() -> ApiClient: global _k8s_api_client @@ -167,22 +241,6 @@ def wait_resource_consumed_by_controller( f"Wait for resource {reference} to be consumed by controller timed out") return None -def _get_terminal_condition(resource: object) -> Union[None, bool]: - """Get the .status.ACK.Terminal boolean from a given resource. - - Returns: - None or bool: None if the status field doesn't exist, otherwise the - field value cast to a boolean (default False). - """ - if 'status' not in resource or 'conditions' not in resource['status']: - return None - - conditions: Dict = resource['status']['conditions'] - if 'ACK' not in conditions or 'Terminal' not in conditions['ACK']: - return None - - terminal: Dict = conditions['ACK']['Terminal'] - return bool(terminal.get('status', False)) def get_resource_arn(resource: object) -> Union[None, str]: """Get the .status.ackResourceMetadata.arn value from a given resource. 
@@ -253,54 +311,71 @@ def wait_on_condition(reference: CustomResourceReference, logging.error(f"Resource {reference} does not exist") return False + desired_condition = None for i in range(wait_periods): logging.debug(f"Waiting on condition {condition_name} to reach {desired_condition_status} for resource {reference} ({i})") - resource = get_resource(reference) - if 'conditions' not in resource['status']: - logging.error(f"Resource {reference} does not have a .status.conditions field.") - return False - - desired_condition = None - for condition in resource['status']['conditions']: - if condition['type'] == condition_name: - desired_condition = condition - - if not desired_condition: - logging.error(f"Resource {reference} does not have a condition of type {condition_name}.") - return False - else: - if desired_condition['status'] == desired_condition_status: - logging.info(f"Condition {condition_name} has status {desired_condition_status}, continuing...") - return True + desired_condition = get_resource_condition(reference, condition_name) + if desired_condition is not None and desired_condition['status'] == desired_condition_status: + logging.info(f"Condition {condition_name} has status {desired_condition_status}, continuing...") + return True sleep(period_length) - logging.error(f"Wait for condition {condition_name} to reach status {desired_condition_status} timed out") + if not desired_condition: + logging.error(f"Resource {reference} does not have a condition of type {condition_name}.") + else: + logging.error(f"Wait for condition {condition_name} to reach status {desired_condition_status} timed out") return False -def is_resource_in_terminal_condition( - reference: CustomResourceReference, expected_substring: str): +def get_resource_condition(reference: CustomResourceReference, condition_name: str): + """ + Returns the required condition from .status.conditions + + Precondition: + resource must exist in the cluster + + Returns: + condition json if it exists. 
None otherwise + """ if not get_resource_exists(reference): logging.error(f"Resource {reference} does not exist") - return False + return None resource = get_resource(reference) - terminal_status = _get_terminal_condition(resource) - # Ensure the status existed - if terminal_status is None: - logging.error(f"Expected .ACK.Terminal to exist in {reference}") - return False + if 'status' not in resource or 'conditions' not in resource['status']: + logging.error(f"Resource {reference} does not have a .status.conditions field.") + return None - if not terminal_status: - logging.error( - f"Expected terminal condition for resource {reference} to be true") - return False + for condition in resource['status']['conditions']: + if condition['type'] == condition_name: + return condition + + return None + +def assert_condition_state_message(reference: CustomResourceReference, + condition_name: str, + desired_condition_status: str, + desired_condition_message: Union[None, str]): + """ + Helper method to check the state and message of a condition on resource. + Caller can pass None for desired_condition_message if expected message is nil - terminal_message = terminal.get('message', None) - if terminal_message != expected_substring: - logging.error(f"Resource {reference} has terminal condition set True, but with a different message than expected." 
- f" Expected '{expected_substring}', found '{terminal_message}'") + Returns: + bool: True if condition exists and both the status and message match the desired values + """ + condition = get_resource_condition(reference, condition_name) + # Ensure the status existed + if condition is None: + logging.error(f"Resource {reference} does not have a condition of type {condition_name}") return False - return True + current_condition_status = condition.get('status', None) + current_condition_message = condition.get('message', None) + if current_condition_status == desired_condition_status and current_condition_message == desired_condition_message: + logging.info(f"Condition {condition_name} has status {desired_condition_status} and message {desired_condition_message}, continuing...") + return True + + logging.error(f"Resource {reference} has {condition_name} set {current_condition_status}, expected {desired_condition_status}; with message" + f" {current_condition_message}, expected {desired_condition_message}") + return False diff --git a/test/e2e/sagemaker/__init__.py b/test/e2e/sagemaker/__init__.py index cb265c431c..bae0c10385 100644 --- a/test/e2e/sagemaker/__init__.py +++ b/test/e2e/sagemaker/__init__.py @@ -4,7 +4,7 @@ # not use this file except in compliance with the License. A copy of the # License is located at # -# http://aws.amazon.com/apache2.0/ +# http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either @@ -12,14 +12,37 @@ # permissions and limitations under the License. 
import pytest +import logging +from common import k8s SERVICE_NAME = "sagemaker" CRD_GROUP = "sagemaker.services.k8s.aws" CRD_VERSION = "v1alpha1" -CONFIG_RESOURCE_PLURAL = 'endpointconfigs' -MODEL_RESOURCE_PLURAL = 'models' -ENDPOINT_RESOURCE_PLURAL = 'endpoints' +CONFIG_RESOURCE_PLURAL = "endpointconfigs" +MODEL_RESOURCE_PLURAL = "models" +ENDPOINT_RESOURCE_PLURAL = "endpoints" # PyTest marker for the current service service_marker = pytest.mark.service(arg=SERVICE_NAME) + + +def create_sagemaker_resource( + resource_plural, resource_name, spec_file, replacements, namespace="default" +): + """ + Wrapper around k8s.load_and_create_resource to create a SageMaker resource + """ + + reference, spec, resource = k8s.load_and_create_resource( + SERVICE_NAME, + CRD_GROUP, + CRD_VERSION, + resource_plural, + resource_name, + spec_file, + replacements, + namespace, + ) + + return reference, spec, resource diff --git a/test/e2e/sagemaker/resources/endpoint_config_multi_variant.yaml b/test/e2e/sagemaker/resources/endpoint_config_multi_variant.yaml new file mode 100644 index 0000000000..764fcbcffc --- /dev/null +++ b/test/e2e/sagemaker/resources/endpoint_config_multi_variant.yaml @@ -0,0 +1,18 @@ +apiVersion: sagemaker.services.k8s.aws/v1alpha1 +kind: EndpointConfig +metadata: + name: $CONFIG_NAME +spec: + endpointConfigName: $CONFIG_NAME + productionVariants: + - variantName: variant-1 + modelName: $MODEL_NAME + initialInstanceCount: 1 + # This is the smallest instance type which will support scaling + instanceType: ml.c5.large + initialVariantWeight: 1 + - variantName: variant-2 + modelName: $MODEL_NAME + initialInstanceCount: 1 + instanceType: ml.c5.large + initialVariantWeight: 1 \ No newline at end of file diff --git a/test/e2e/sagemaker/resources/endpoint_config_single_variant.yaml b/test/e2e/sagemaker/resources/endpoint_config_single_variant.yaml index 65ed4426e0..9d1540d044 100644 --- a/test/e2e/sagemaker/resources/endpoint_config_single_variant.yaml +++ 
b/test/e2e/sagemaker/resources/endpoint_config_single_variant.yaml @@ -7,7 +7,8 @@ spec: productionVariants: - variantName: variant-1 modelName: $MODEL_NAME - initialInstanceCount: 1 + # instanceCount is 2 to test retainAllVariantProperties + initialInstanceCount: 2 # This is the smallest instance type which will support scaling instanceType: ml.c5.large initialVariantWeight: 1 diff --git a/test/e2e/sagemaker/resources/xgboost_model.yaml b/test/e2e/sagemaker/resources/xgboost_model.yaml index 3cc58c93be..4ec4190fab 100644 --- a/test/e2e/sagemaker/resources/xgboost_model.yaml +++ b/test/e2e/sagemaker/resources/xgboost_model.yaml @@ -8,7 +8,7 @@ spec: containerHostname: xgboost modelDataURL: s3://$SAGEMAKER_DATA_BUCKET/sagemaker/model/xgboost-mnist-model.tar.gz image: $XGBOOST_IMAGE_URI - executionRoleARN: $SAGEMAKER_EXECUTION_ROLE_ARN - tags: - - key: key - value: value \ No newline at end of file + environment: + my_var: my_value + my_var2: my_value2 + executionRoleARN: $SAGEMAKER_EXECUTION_ROLE_ARN \ No newline at end of file diff --git a/test/e2e/sagemaker/resources/xgboost_model_with_model_location.yaml b/test/e2e/sagemaker/resources/xgboost_model_with_model_location.yaml new file mode 100644 index 0000000000..46eac0a5ef --- /dev/null +++ b/test/e2e/sagemaker/resources/xgboost_model_with_model_location.yaml @@ -0,0 +1,16 @@ +apiVersion: sagemaker.services.k8s.aws/v1alpha1 +kind: Model +metadata: + name: $MODEL_NAME +spec: + modelName: $MODEL_NAME + containers: + - containerHostname: xgboost + modelDataURL: $MODEL_LOCATION + image: $XGBOOST_IMAGE_URI + imageConfig: + repositoryAccessMode: Platform + environment: + my_var: my_value + my_var2: my_value2 + executionRoleARN: $SAGEMAKER_EXECUTION_ROLE_ARN \ No newline at end of file diff --git a/test/e2e/sagemaker/resources/xgboost_trainingjob.yaml b/test/e2e/sagemaker/resources/xgboost_trainingjob.yaml index f3e02d6c89..ab9eaf35c0 100644 --- a/test/e2e/sagemaker/resources/xgboost_trainingjob.yaml +++ 
b/test/e2e/sagemaker/resources/xgboost_trainingjob.yaml @@ -41,7 +41,4 @@ spec: s3URI: s3://$SAGEMAKER_DATA_BUCKET/sagemaker/training/validation s3DataDistributionType: FullyReplicated contentType: text/csv - compressionType: None - tags: - - key: key - value: value \ No newline at end of file + compressionType: None \ No newline at end of file diff --git a/test/e2e/sagemaker/tests/test_endpoint.py b/test/e2e/sagemaker/tests/test_endpoint.py index 2e95bbb36e..5466094969 100644 --- a/test/e2e/sagemaker/tests/test_endpoint.py +++ b/test/e2e/sagemaker/tests/test_endpoint.py @@ -20,16 +20,15 @@ from typing import Dict from sagemaker import ( - SERVICE_NAME, service_marker, - CRD_GROUP, - CRD_VERSION, CONFIG_RESOURCE_PLURAL, MODEL_RESOURCE_PLURAL, ENDPOINT_RESOURCE_PLURAL, + create_sagemaker_resource, ) from sagemaker.replacement_values import REPLACEMENT_VALUES -from common.resources import load_resource_file, random_suffix_name +from common.aws import copy_s3_object, delete_s3_object +from common.resources import random_suffix_name from common import k8s @@ -39,91 +38,149 @@ def sagemaker_client(): @pytest.fixture(scope="module") -def single_variant_xgboost_endpoint(): - endpoint_resource_name = random_suffix_name("single-variant-endpoint", 32) - config1_resource_name = endpoint_resource_name + "-config" - model_resource_name = config1_resource_name + "-model" +def name_suffix(): + return random_suffix_name("xgboost-endpoint", 32) + +@pytest.fixture(scope="module") +def single_container_model(name_suffix): + model_resource_name = name_suffix + "-model" replacements = REPLACEMENT_VALUES.copy() - replacements["ENDPOINT_NAME"] = endpoint_resource_name - replacements["CONFIG_NAME"] = config1_resource_name replacements["MODEL_NAME"] = model_resource_name - model = load_resource_file( - SERVICE_NAME, "xgboost_model", additional_replacements=replacements + model_reference, model_spec, model_resource = create_sagemaker_resource( + resource_plural=MODEL_RESOURCE_PLURAL, + 
resource_name=model_resource_name, + spec_file="xgboost_model", + replacements=replacements, ) - logging.debug(model) + assert model_resource is not None + + yield (model_reference, model_resource) + + k8s.delete_custom_resource(model_reference) + + +@pytest.fixture(scope="module") +def multi_variant_config(name_suffix, single_container_model): + config_resource_name = name_suffix + "-multi-variant-config" + (_, model_resource) = single_container_model + model_resource_name = model_resource["spec"].get("modelName", None) + + replacements = REPLACEMENT_VALUES.copy() + replacements["CONFIG_NAME"] = config_resource_name + replacements["MODEL_NAME"] = model_resource_name - config = load_resource_file( - SERVICE_NAME, - "endpoint_config_single_variant", - additional_replacements=replacements, + config_reference, config_spec, config_resource = create_sagemaker_resource( + resource_plural=CONFIG_RESOURCE_PLURAL, + resource_name=config_resource_name, + spec_file="endpoint_config_multi_variant", + replacements=replacements, ) - logging.debug(config) + assert config_resource is not None + + yield (config_reference, config_resource) + + k8s.delete_custom_resource(config_reference) + + +@pytest.fixture(scope="module") +def single_variant_config(name_suffix, single_container_model): + config_resource_name = name_suffix + "-single-variant-config" + (_, model_resource) = single_container_model + model_resource_name = model_resource["spec"].get("modelName", None) + + replacements = REPLACEMENT_VALUES.copy() + replacements["CONFIG_NAME"] = config_resource_name + replacements["MODEL_NAME"] = model_resource_name - endpoint_spec = load_resource_file( - SERVICE_NAME, "endpoint_base", additional_replacements=replacements + config_reference, config_spec, config_resource = create_sagemaker_resource( + resource_plural=CONFIG_RESOURCE_PLURAL, + resource_name=config_resource_name, + spec_file="endpoint_config_single_variant", + replacements=replacements, ) - logging.debug(endpoint_spec) - - 
# Create the k8s resources - model_reference = k8s.CustomResourceReference( - CRD_GROUP, - CRD_VERSION, - MODEL_RESOURCE_PLURAL, - model_resource_name, - namespace="default", + assert config_resource is not None + + yield (config_reference, config_resource) + + k8s.delete_custom_resource(config_reference) + + +@pytest.fixture(scope="module") +def xgboost_endpoint(name_suffix, single_variant_config): + endpoint_resource_name = name_suffix + (_, config_resource) = single_variant_config + config_resource_name = config_resource["spec"].get("endpointConfigName", None) + + replacements = REPLACEMENT_VALUES.copy() + replacements["ENDPOINT_NAME"] = endpoint_resource_name + replacements["CONFIG_NAME"] = config_resource_name + + reference, spec, resource = create_sagemaker_resource( + resource_plural=ENDPOINT_RESOURCE_PLURAL, + resource_name=endpoint_resource_name, + spec_file="endpoint_base", + replacements=replacements, ) - model_resource = k8s.create_custom_resource(model_reference, model) - model_resource = k8s.wait_resource_consumed_by_controller(model_reference) - assert model_resource is not None - config1_reference = k8s.CustomResourceReference( - CRD_GROUP, - CRD_VERSION, - CONFIG_RESOURCE_PLURAL, - config1_resource_name, - namespace="default", + assert resource is not None + + yield (reference, resource, spec) + + # Delete the k8s resource if not already deleted by tests + if k8s.get_resource_exists(reference): + k8s.delete_custom_resource(reference) + + +@pytest.fixture(scope="module") +def faulty_config(name_suffix, single_container_model): + replacements = REPLACEMENT_VALUES.copy() + + # copy model data to a temp S3 location and delete it after model is created on SageMaker + model_bucket = replacements["SAGEMAKER_DATA_BUCKET"] + copy_source = { + "Bucket": model_bucket, + "Key": "sagemaker/model/xgboost-mnist-model.tar.gz", + } + model_destination_key = "sagemaker/model/delete/xgboost-mnist-model.tar.gz" + copy_s3_object(model_bucket, copy_source, 
model_destination_key) + + model_resource_name = name_suffix + "faulty-model" + replacements["MODEL_NAME"] = model_resource_name + replacements["MODEL_LOCATION"] = f"s3://{model_bucket}/{model_destination_key}" + model_reference, model_spec, model_resource = create_sagemaker_resource( + resource_plural=MODEL_RESOURCE_PLURAL, + resource_name=model_resource_name, + spec_file="xgboost_model_with_model_location", + replacements=replacements, ) - config1_resource = k8s.create_custom_resource(config1_reference, config) - config1_resource = k8s.wait_resource_consumed_by_controller(config1_reference) - assert config1_resource is not None - - config2_resource_name = random_suffix_name("2-single-variant-endpoint", 32) - config["metadata"]["name"] = config["spec"][ - "endpointConfigName" - ] = config2_resource_name - logging.debug(config) - config2_reference = k8s.CustomResourceReference( - CRD_GROUP, - CRD_VERSION, - CONFIG_RESOURCE_PLURAL, - config2_resource_name, - namespace="default", + assert model_resource is not None + model_resource = k8s.get_resource(model_reference) + assert ( + "ackResourceMetadata" in model_resource["status"] + and "arn" in model_resource["status"]["ackResourceMetadata"] ) - config2_resource = k8s.create_custom_resource(config2_reference, config) - config2_resource = k8s.wait_resource_consumed_by_controller(config2_reference) - assert config2_resource is not None - - endpoint_reference = k8s.CustomResourceReference( - CRD_GROUP, - CRD_VERSION, - ENDPOINT_RESOURCE_PLURAL, - endpoint_resource_name, - namespace="default", + delete_s3_object(model_bucket, model_destination_key) + + config_resource_name = name_suffix + "-faulty-config" + (_, model_resource) = single_container_model + model_resource_name = model_resource["spec"].get("modelName", None) + + replacements["CONFIG_NAME"] = config_resource_name + + config_reference, config_spec, config_resource = create_sagemaker_resource( + resource_plural=CONFIG_RESOURCE_PLURAL, + 
resource_name=config_resource_name, + spec_file="endpoint_config_multi_variant", + replacements=replacements, ) - endpoint_resource = k8s.create_custom_resource(endpoint_reference, endpoint_spec) - endpoint_resource = k8s.wait_resource_consumed_by_controller(endpoint_reference) - assert endpoint_resource is not None + assert config_resource is not None - yield (endpoint_reference, endpoint_resource, endpoint_spec, config2_resource_name) + yield (config_reference, config_resource) - # Delete the k8s resource if not already deleted by tests - for cr in (model_reference, config1_reference, config2_reference, endpoint_reference): - try: - k8s.delete_custom_resource(cr) - except: - pass + k8s.delete_custom_resource(model_reference) + k8s.delete_custom_resource(config_reference) @service_marker @@ -153,7 +210,7 @@ def _wait_resource_endpoint_status( self, reference: k8s.CustomResourceReference, expected_status: str, - wait_periods: int = 18, + wait_periods: int = 30, ): resource_status = None for _ in range(wait_periods): @@ -175,7 +232,7 @@ def _wait_sagemaker_endpoint_status( sagemaker_client, endpoint_name, expected_status: str, - wait_periods: int = 18, + wait_periods: int = 60, ): actual_status = None for _ in range(wait_periods): @@ -199,20 +256,16 @@ def _assert_endpoint_status_in_sync( self._wait_sagemaker_endpoint_status( sagemaker_client, endpoint_name, expected_status ) - == self._wait_resource_endpoint_status(reference, expected_status) + == self._wait_resource_endpoint_status(reference, expected_status, 2) == expected_status ) - def test_create_endpoint(self, single_variant_xgboost_endpoint): - assert k8s.get_resource_exists(single_variant_xgboost_endpoint[0]) + def create_endpoint_test(self, sagemaker_client, xgboost_endpoint): + (reference, resource, _) = xgboost_endpoint + assert k8s.get_resource_exists(reference) - def test_endpoint_has_correct_arn_and_status( - self, sagemaker_client, single_variant_xgboost_endpoint - ): - (reference, _, _, _) = 
single_variant_xgboost_endpoint - resource = k8s.get_resource(reference) + # endpoint has correct arn and status endpoint_name = resource["spec"].get("endpointName", None) - assert endpoint_name is not None assert ( @@ -222,41 +275,166 @@ def test_endpoint_has_correct_arn_and_status( ] ) + # endpoint transitions Creating -> InService state self._assert_endpoint_status_in_sync( sagemaker_client, endpoint_name, reference, self.status_creating ) + assert k8s.wait_on_condition(reference, "ACK.ResourceSynced", "False") + self._assert_endpoint_status_in_sync( sagemaker_client, endpoint_name, reference, self.status_inservice ) + assert k8s.wait_on_condition(reference, "ACK.ResourceSynced", "True") - def test_update_endpoint(self, sagemaker_client, single_variant_xgboost_endpoint): - ( - reference, - resource, - endpoint_spec, - config2_resource_name, - ) = single_variant_xgboost_endpoint - endpoint_spec["spec"]["endpointConfigName"] = config2_resource_name - resource = k8s.patch_custom_resource(reference, endpoint_spec) - resource = k8s.wait_resource_consumed_by_controller(reference) - assert resource is not None + def update_endpoint_failed_test( + self, sagemaker_client, single_variant_config, faulty_config, xgboost_endpoint + ): + (endpoint_reference, _, endpoint_spec) = xgboost_endpoint + (_, faulty_config_resource) = faulty_config + faulty_config_name = faulty_config_resource["spec"].get( + "endpointConfigName", None + ) + endpoint_spec["spec"]["endpointConfigName"] = faulty_config_name + endpoint_resource = k8s.patch_custom_resource(endpoint_reference, endpoint_spec) + endpoint_resource = k8s.wait_resource_consumed_by_controller(endpoint_reference) + assert endpoint_resource is not None + # endpoint transitions Updating -> InService state self._assert_endpoint_status_in_sync( - sagemaker_client, reference.name, reference, self.status_udpating + sagemaker_client, + endpoint_reference.name, + endpoint_reference, + self.status_udpating, ) + assert 
k8s.wait_on_condition(endpoint_reference, "ACK.ResourceSynced", "False") + endpoint_resource = k8s.get_resource(endpoint_reference) + assert ( + endpoint_resource["status"].get("lastEndpointConfigNameForUpdate", None) + == faulty_config_name + ) + self._assert_endpoint_status_in_sync( - sagemaker_client, reference.name, reference, self.status_inservice + sagemaker_client, + endpoint_reference.name, + endpoint_reference, + self.status_inservice, ) - def test_delete_endpoint(self, sagemaker_client, single_variant_xgboost_endpoint): - (reference, _, _, _) = single_variant_xgboost_endpoint - resource = k8s.get_resource(reference) + assert k8s.wait_on_condition(endpoint_reference, "ACK.ResourceSynced", "True") + assert k8s.assert_condition_state_message( + endpoint_reference, + "ACK.Terminal", + "True", + "Unable to update Endpoint. Check FailureReason", + ) + + endpoint_resource = k8s.get_resource(endpoint_reference) + assert endpoint_resource["status"].get("failureReason", None) is not None + + # additional check: endpoint using old endpoint config + (_, old_config_resource) = single_variant_config + current_config_name = endpoint_resource["status"].get( + "latestEndpointConfigName" + ) + assert ( + current_config_name is not None + and current_config_name + == old_config_resource["spec"].get("endpointConfigName", None) + ) + + def update_endpoint_successful_test( + self, sagemaker_client, multi_variant_config, xgboost_endpoint + ): + (endpoint_reference, endpoint_resource, endpoint_spec) = xgboost_endpoint + + endpoint_name = endpoint_resource["spec"].get("endpointName", None) + production_variants = self._describe_sagemaker_endpoint( + sagemaker_client, endpoint_name + )["ProductionVariants"] + old_variant_instance_count = production_variants[0]["CurrentInstanceCount"] + old_variant_name = production_variants[0]["VariantName"] + + (_, new_config_resource) = multi_variant_config + new_config_name = new_config_resource["spec"].get("endpointConfigName", None) + 
endpoint_spec["spec"]["endpointConfigName"] = new_config_name + endpoint_resource = k8s.patch_custom_resource(endpoint_reference, endpoint_spec) + endpoint_resource = k8s.wait_resource_consumed_by_controller(endpoint_reference) + assert endpoint_resource is not None + + # endpoint transitions Updating -> InService state + self._assert_endpoint_status_in_sync( + sagemaker_client, + endpoint_reference.name, + endpoint_reference, + self.status_udpating, + ) + + assert k8s.wait_on_condition(endpoint_reference, "ACK.ResourceSynced", "False") + assert k8s.assert_condition_state_message( + endpoint_reference, "ACK.Terminal", "False", None + ) + endpoint_resource = k8s.get_resource(endpoint_reference) + assert ( + endpoint_resource["status"].get("lastEndpointConfigNameForUpdate", None) + == new_config_name + ) + + self._assert_endpoint_status_in_sync( + sagemaker_client, + endpoint_reference.name, + endpoint_reference, + self.status_inservice, + ) + assert k8s.wait_on_condition(endpoint_reference, "ACK.ResourceSynced", "True") + assert k8s.assert_condition_state_message( + endpoint_reference, "ACK.Terminal", "False", None + ) + endpoint_resource = k8s.get_resource(endpoint_reference) + assert endpoint_resource["status"].get("failureReason", None) is None + + # RetainAllVariantProperties - variant properties were retained + is a multi-variant endpoint + new_production_variants = self._describe_sagemaker_endpoint( + sagemaker_client, endpoint_name + )["ProductionVariants"] + assert len(new_production_variants) > 1 + new_variant_instance_count = None + for variant in new_production_variants: + if variant["VariantName"] == old_variant_name: + new_variant_instance_count = variant["CurrentInstanceCount"] + + assert new_variant_instance_count == old_variant_instance_count + + def delete_endpoint_test(self, sagemaker_client, xgboost_endpoint): + (reference, resource, _) = xgboost_endpoint endpoint_name = resource["spec"].get("endpointName", None) - # Delete the k8s resource. 
_, deleted = k8s.delete_custom_resource(reference) assert deleted is True + # resource is removed from management from controller side if call to deleteEndpoint succeeds. + # Sagemaker also removes a 'Deleting' endpoint pretty quickly, but there might be a lag + # If we see errors in this part of test, can add a loop in future or consider changing controller + # to wait for SageMaker + time.sleep(10) assert ( self._describe_sagemaker_endpoint(sagemaker_client, endpoint_name) is None ) + + def test_driver( + self, + sagemaker_client, + single_variant_config, + faulty_config, + multi_variant_config, + xgboost_endpoint, + ): + self.create_endpoint_test(sagemaker_client, xgboost_endpoint) + self.update_endpoint_failed_test( + sagemaker_client, single_variant_config, faulty_config, xgboost_endpoint + ) + # Note: the test has been intentionally ordered to run a successful update after a failed update + # check that controller updates the endpoint, removes the terminal condition and clears the failure reason + self.update_endpoint_successful_test( + sagemaker_client, multi_variant_config, xgboost_endpoint + ) + self.delete_endpoint_test(sagemaker_client, xgboost_endpoint) diff --git a/test/e2e/sagemaker/tests/test_endpoint_config.py b/test/e2e/sagemaker/tests/test_endpoint_config.py index 8e6e2cab2d..c1155e5340 100644 --- a/test/e2e/sagemaker/tests/test_endpoint_config.py +++ b/test/e2e/sagemaker/tests/test_endpoint_config.py @@ -19,15 +19,13 @@ from typing import Dict from sagemaker import ( - SERVICE_NAME, service_marker, - CRD_GROUP, - CRD_VERSION, CONFIG_RESOURCE_PLURAL, MODEL_RESOURCE_PLURAL, + create_sagemaker_resource, ) from sagemaker.replacement_values import REPLACEMENT_VALUES -from common.resources import load_resource_file, random_suffix_name +from common.resources import random_suffix_name from common import k8s @@ -45,49 +43,28 @@ def single_variant_config(): replacements["CONFIG_NAME"] = config_resource_name replacements["MODEL_NAME"] = model_resource_name 
- model = load_resource_file( - SERVICE_NAME, "xgboost_model", additional_replacements=replacements + model_reference, model_spec, model_resource = create_sagemaker_resource( + resource_plural=MODEL_RESOURCE_PLURAL, + resource_name=model_resource_name, + spec_file="xgboost_model", + replacements=replacements, ) - logging.debug(model) - - config = load_resource_file( - SERVICE_NAME, - "endpoint_config_single_variant", - additional_replacements=replacements, - ) - logging.debug(config) - - # Create the k8s resources - model_reference = k8s.CustomResourceReference( - CRD_GROUP, - CRD_VERSION, - MODEL_RESOURCE_PLURAL, - model_resource_name, - namespace="default", - ) - model_resource = k8s.create_custom_resource(model_reference, model) - model_resource = k8s.wait_resource_consumed_by_controller(model_reference) assert model_resource is not None - config_reference = k8s.CustomResourceReference( - CRD_GROUP, - CRD_VERSION, - CONFIG_RESOURCE_PLURAL, - config_resource_name, - namespace="default", + config_reference, config_spec, config_resource = create_sagemaker_resource( + resource_plural=CONFIG_RESOURCE_PLURAL, + resource_name=config_resource_name, + spec_file="endpoint_config_single_variant", + replacements=replacements, ) - config_resource = k8s.create_custom_resource(config_reference, config) - config_resource = k8s.wait_resource_consumed_by_controller(config_reference) assert config_resource is not None yield (config_reference, config_resource) + k8s.delete_custom_resource(model_reference) # Delete the k8s resource if not already deleted by tests - try: - k8s.delete_custom_resource(model_reference) + if k8s.get_resource_exists(config_reference): k8s.delete_custom_resource(config_reference) - except: - pass @service_marker diff --git a/test/e2e/sagemaker/tests/test_model.py b/test/e2e/sagemaker/tests/test_model.py index be7d79c6da..adf66abb28 100644 --- a/test/e2e/sagemaker/tests/test_model.py +++ b/test/e2e/sagemaker/tests/test_model.py @@ -16,13 +16,15 @@ import 
boto3 import pytest import logging -import time from typing import Dict -from sagemaker import SERVICE_NAME, service_marker, CRD_GROUP, CRD_VERSION -from sagemaker import MODEL_RESOURCE_PLURAL as RESOURCE_PLURAL +from sagemaker import ( + service_marker, + create_sagemaker_resource, + MODEL_RESOURCE_PLURAL, +) from sagemaker.replacement_values import REPLACEMENT_VALUES -from common.resources import load_resource_file, random_suffix_name +from common.resources import random_suffix_name from common import k8s @@ -38,27 +40,19 @@ def xgboost_model(): replacements = REPLACEMENT_VALUES.copy() replacements["MODEL_NAME"] = resource_name - model = load_resource_file( - SERVICE_NAME, "xgboost_model", additional_replacements=replacements + reference, spec, resource = create_sagemaker_resource( + resource_plural=MODEL_RESOURCE_PLURAL, + resource_name=resource_name, + spec_file="xgboost_model", + replacements=replacements, ) - logging.debug(model) - - # Create the k8s resource - reference = k8s.CustomResourceReference( - CRD_GROUP, CRD_VERSION, RESOURCE_PLURAL, resource_name, namespace="default" - ) - resource = k8s.create_custom_resource(reference, model) - resource = k8s.wait_resource_consumed_by_controller(reference) - assert resource is not None yield (reference, resource) # Delete the k8s resource if not already deleted by tests - try: + if k8s.get_resource_exists(reference): k8s.delete_custom_resource(reference) - except: - pass @service_marker diff --git a/test/e2e/sagemaker/tests/test_processingjob.py b/test/e2e/sagemaker/tests/test_processingjob.py index 0470be529d..daee444842 100644 --- a/test/e2e/sagemaker/tests/test_processingjob.py +++ b/test/e2e/sagemaker/tests/test_processingjob.py @@ -17,11 +17,13 @@ import pytest import logging from typing import Dict -import time -from sagemaker import SERVICE_NAME, service_marker, CRD_GROUP, CRD_VERSION +from sagemaker import ( + service_marker, + create_sagemaker_resource, +) from sagemaker.replacement_values import 
REPLACEMENT_VALUES -from common.resources import load_resource_file, random_suffix_name +from common.resources import random_suffix_name from common import k8s RESOURCE_PLURAL = "processingjobs" @@ -39,27 +41,19 @@ def kmeans_processing_job(): replacements = REPLACEMENT_VALUES.copy() replacements["PROCESSING_JOB_NAME"] = resource_name - processing_job = load_resource_file( - SERVICE_NAME, "kmeans_processingjob", additional_replacements=replacements + reference, spec, resource = create_sagemaker_resource( + resource_plural=RESOURCE_PLURAL, + resource_name=resource_name, + spec_file="kmeans_processingjob", + replacements=replacements, ) - logging.debug(processing_job) - - # Create the k8s resource - reference = k8s.CustomResourceReference( - CRD_GROUP, CRD_VERSION, RESOURCE_PLURAL, resource_name, namespace="default" - ) - resource = k8s.create_custom_resource(reference, processing_job) - resource = k8s.wait_resource_consumed_by_controller(reference) - assert resource is not None yield (reference, resource) # Delete the k8s resource if not already deleted by tests - try: + if k8s.get_resource_exists(reference): k8s.delete_custom_resource(reference) - except: - pass @service_marker diff --git a/test/e2e/sagemaker/tests/test_trainingjob.py b/test/e2e/sagemaker/tests/test_trainingjob.py index 5069eef653..88abf5330d 100644 --- a/test/e2e/sagemaker/tests/test_trainingjob.py +++ b/test/e2e/sagemaker/tests/test_trainingjob.py @@ -19,7 +19,10 @@ from typing import Dict import time -from sagemaker import SERVICE_NAME, service_marker, CRD_GROUP, CRD_VERSION +from sagemaker import ( + service_marker, + create_sagemaker_resource, +) from sagemaker.replacement_values import REPLACEMENT_VALUES from common.resources import load_resource_file, random_suffix_name from common import k8s @@ -39,27 +42,19 @@ def xgboost_trainingjob(): replacements = REPLACEMENT_VALUES.copy() replacements["TRAINING_JOB_NAME"] = resource_name - trainingjob = load_resource_file( - SERVICE_NAME, 
"xgboost_trainingjob", additional_replacements=replacements + reference, spec, resource = create_sagemaker_resource( + resource_plural=RESOURCE_PLURAL, + resource_name=resource_name, + spec_file="xgboost_trainingjob", + replacements=replacements, ) - logging.debug(trainingjob) - - # Create the k8s resource - reference = k8s.CustomResourceReference( - CRD_GROUP, CRD_VERSION, RESOURCE_PLURAL, resource_name, namespace="default" - ) - resource = k8s.create_custom_resource(reference, trainingjob) - resource = k8s.wait_resource_consumed_by_controller(reference) - assert resource is not None yield (reference, resource) # Delete the k8s resource if not already deleted by tests - try: + if k8s.get_resource_exists(reference): k8s.delete_custom_resource(reference) - except: - pass @service_marker diff --git a/test/e2e/sagemaker/tests/test_transformjob.py b/test/e2e/sagemaker/tests/test_transformjob.py index 6c73a1be57..cf087a5950 100644 --- a/test/e2e/sagemaker/tests/test_transformjob.py +++ b/test/e2e/sagemaker/tests/test_transformjob.py @@ -4,7 +4,7 @@ # not use this file except in compliance with the License. A copy of the # License is located at # -# http://aws.amazon.com/apache2.0/ +# http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. 
This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either @@ -17,77 +17,69 @@ import pytest import logging from typing import Dict -import time -from sagemaker import SERVICE_NAME, service_marker, CRD_GROUP, CRD_VERSION +from sagemaker import ( + service_marker, + create_sagemaker_resource, +) from sagemaker.replacement_values import REPLACEMENT_VALUES -from common.resources import load_resource_file, random_suffix_name +from common.resources import random_suffix_name from common import k8s from common.aws import get_aws_region from sagemaker.bootstrap_resources import get_bootstrap_resources -RESOURCE_PLURAL = 'transformjobs' +RESOURCE_PLURAL = "transformjobs" @pytest.fixture(scope="module") def sagemaker_client(): - return boto3.client('sagemaker') + return boto3.client("sagemaker") @pytest.fixture(scope="module") def xgboost_transformjob(sagemaker_client): - #Create model using boto3 for TransformJob - transform_model_file = "s3://{d}/sagemaker/batch/model.tar.gz".format( - d=get_bootstrap_resources().DataBucketName) + # Create model using boto3 for TransformJob + transform_model_file = f"s3://{get_bootstrap_resources().DataBucketName}/sagemaker/batch/model.tar.gz" model_name = random_suffix_name("xgboost-model", 32) - create_response = sagemaker_client.create_model(ModelName=model_name, - PrimaryContainer={ - 'Image': REPLACEMENT_VALUES["XGBOOST_IMAGE_URI"], - 'ModelDataUrl': transform_model_file, - 'Environment': {} - }, - ExecutionRoleArn=REPLACEMENT_VALUES["SAGEMAKER_EXECUTION_ROLE_ARN"] - ) + create_response = sagemaker_client.create_model( + ModelName=model_name, + PrimaryContainer={ + "Image": REPLACEMENT_VALUES["XGBOOST_IMAGE_URI"], + "ModelDataUrl": transform_model_file, + "Environment": {}, + }, + ExecutionRoleArn=REPLACEMENT_VALUES["SAGEMAKER_EXECUTION_ROLE_ARN"], + ) logging.debug(create_response) - #Check if the model is created successfully + # Check if the model is created successfully 
describe_model_response = sagemaker_client.describe_model(ModelName=model_name) assert describe_model_response["ModelName"] is not None - + resource_name = random_suffix_name("xgboost-transformjob", 32) - #Use the model created above + # Use the model created above replacements = REPLACEMENT_VALUES.copy() replacements["MODEL_NAME"] = model_name replacements["TRANSFORM_JOB_NAME"] = resource_name - transformjob = load_resource_file( - SERVICE_NAME, "xgboost_transformjob", additional_replacements=replacements) - logging.debug(transformjob) - - # Create the k8s resource - reference = k8s.CustomResourceReference( - CRD_GROUP, CRD_VERSION, RESOURCE_PLURAL, resource_name, namespace="default") - resource = k8s.create_custom_resource(reference, transformjob) - resource = k8s.wait_resource_consumed_by_controller(reference) - + reference, spec, resource = create_sagemaker_resource( + resource_plural=RESOURCE_PLURAL, + resource_name=resource_name, + spec_file="xgboost_transformjob", + replacements=replacements, + ) assert resource is not None - yield (reference, resource) + yield (reference, resource) - try: - # Delete the k8s resource if not already deleted by tests - k8s.delete_custom_resource(reference) - except: - pass - - try: - # Delete the model created - sagemaker_client.delete_model(ModelName=model_name) - except: - pass + # Delete the model created + sagemaker_client.delete_model(ModelName=model_name) + # Delete the k8s resource if not already deleted by tests + if k8s.get_resource_exists(reference): + k8s.delete_custom_resource(reference) @service_marker @@ -95,31 +87,41 @@ def xgboost_transformjob(sagemaker_client): class TestTransformJob: def _get_created_transformjob_status_list(self): return ["InProgress"] - + def _get_stopped_transformjob_status_list(self): return ["Stopped", "Stopping", "Completed"] def _get_resource_transformjob_arn(self, resource: Dict): - assert 'ackResourceMetadata' in resource['status'] and \ - 'arn' in 
resource['status']['ackResourceMetadata'] - return resource['status']['ackResourceMetadata']['arn'] + assert ( + "ackResourceMetadata" in resource["status"] + and "arn" in resource["status"]["ackResourceMetadata"] + ) + return resource["status"]["ackResourceMetadata"]["arn"] def _get_sagemaker_transformjob_arn(self, sagemaker_client, transformjob_name: str): try: - transformjob = sagemaker_client.describe_transform_job(TransformJobName=transformjob_name) - return transformjob['TransformJobArn'] + transformjob = sagemaker_client.describe_transform_job( + TransformJobName=transformjob_name + ) + return transformjob["TransformJobArn"] except BaseException: logging.error( - f"SageMaker could not find a transformJob with the name {transformjob_name}") + f"SageMaker could not find a transformJob with the name {transformjob_name}" + ) return None - def _get_sagemaker_transformjob_status(self, sagemaker_client, transformjob_name: str): + def _get_sagemaker_transformjob_status( + self, sagemaker_client, transformjob_name: str + ): try: - transformjob = sagemaker_client.describe_transform_job(TransformJobName=transformjob_name) - return transformjob['TransformJobStatus'] + transformjob = sagemaker_client.describe_transform_job( + TransformJobName=transformjob_name + ) + return transformjob["TransformJobStatus"] except BaseException: logging.error( - f"SageMaker could not find a transformJob with the name {transformjob_name}") + f"SageMaker could not find a transformJob with the name {transformjob_name}" + ) return None def test_create_transformjob(self, xgboost_transformjob): @@ -128,27 +130,32 @@ def test_create_transformjob(self, xgboost_transformjob): def test_transformjob_has_correct_arn(self, sagemaker_client, xgboost_transformjob): (reference, resource) = xgboost_transformjob - transformjob_name = resource['spec'].get('transformJobName', None) + transformjob_name = resource["spec"].get("transformJobName", None) assert transformjob_name is not None 
resource_transformjob_arn = self._get_resource_transformjob_arn(resource) - assert (self._get_sagemaker_transformjob_arn( - sagemaker_client, transformjob_name)) == resource_transformjob_arn + assert ( + self._get_sagemaker_transformjob_arn(sagemaker_client, transformjob_name) + ) == resource_transformjob_arn - def test_transformjob_has_created_status(self, sagemaker_client, xgboost_transformjob): + def test_transformjob_has_created_status( + self, sagemaker_client, xgboost_transformjob + ): (reference, resource) = xgboost_transformjob - transformjob_name = resource['spec'].get('transformJobName', None) + transformjob_name = resource["spec"].get("transformJobName", None) assert transformjob_name is not None - assert (self._get_sagemaker_transformjob_status( - sagemaker_client, transformjob_name)) in self._get_created_transformjob_status_list() - + assert ( + self._get_sagemaker_transformjob_status(sagemaker_client, transformjob_name) + ) in self._get_created_transformjob_status_list() - def test_transformjob_has_stopped_status(self, sagemaker_client, xgboost_transformjob): + def test_transformjob_has_stopped_status( + self, sagemaker_client, xgboost_transformjob + ): (reference, resource) = xgboost_transformjob - transformjob_name = resource['spec'].get('transformJobName', None) + transformjob_name = resource["spec"].get("transformJobName", None) assert transformjob_name is not None @@ -156,5 +163,6 @@ def test_transformjob_has_stopped_status(self, sagemaker_client, xgboost_transfo _, deleted = k8s.delete_custom_resource(reference) assert deleted is True - assert (self._get_sagemaker_transformjob_status( - sagemaker_client, transformjob_name)) in self._get_stopped_transformjob_status_list() + assert ( + self._get_sagemaker_transformjob_status(sagemaker_client, transformjob_name) + ) in self._get_stopped_transformjob_status_list()