diff --git a/src/sagemaker/serve/builder/djl_builder.py b/src/sagemaker/serve/builder/djl_builder.py
index 29bdc75b35..09197f727f 100644
--- a/src/sagemaker/serve/builder/djl_builder.py
+++ b/src/sagemaker/serve/builder/djl_builder.py
@@ -46,7 +46,7 @@
 )
 from sagemaker.serve.model_server.djl_serving.prepare import (
     prepare_for_djl_serving,
-    _create_dir_structure,
+    _create_dir_structure
 )
 from sagemaker.serve.utils.predictors import DjlLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer, _DjlEngine
diff --git a/src/sagemaker/serve/builder/hf_dlc_builder.py b/src/sagemaker/serve/builder/hf_dlc_builder.py
new file mode 100644
index 0000000000..6198b9041b
--- /dev/null
+++ b/src/sagemaker/serve/builder/hf_dlc_builder.py
@@ -0,0 +1,274 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""HuggingFace DLC specific model builder"""
+from __future__ import absolute_import
+import logging
+from abc import ABC, abstractmethod
+from typing import Type
+
+from packaging.version import Version
+
+from sagemaker import image_uris
+from sagemaker.model import Model
+from sagemaker.serve.utils.local_hardware import (
+    _get_nb_instance,
+    _get_ram_usage_mb,
+)
+from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
+from sagemaker.huggingface import HuggingFaceModel
+from sagemaker.serve.model_server.hf_dlc.prepare import (
+    _create_dir_structure,
+)
+from sagemaker.serve.utils.predictors import HfDLCLocalModePredictor
+from sagemaker.serve.utils.types import ModelServer
+from sagemaker.serve.mode.function_pointers import Mode
+from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
+from sagemaker.base_predictor import PredictorBase
+from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
+
+logger = logging.getLogger(__name__)
+DEFAULT_TIMEOUT = 1800
+
+
+class HuggingFaceDLC(ABC):
+    """HuggingFace DLC build logic for ModelBuilder()"""
+
+    def __init__(self):
+        self.model = None
+        self.serve_settings = None
+        self.sagemaker_session = None
+        self.model_path = None
+        self.dependencies = None
+        self.modes = None
+        self.mode = None
+        self.model_server = None
+        self.image_uri = None
+        self._original_deploy = None
+        self.hf_model_config = None
+        self._default_data_type = None
+        self.pysdk_model = None
+        self.env_vars = None
+        self.nb_instance_type = None
+        self.ram_usage_model_load = None
+        self.secret_key = None
+        self.role_arn = None
+        self.py_version = None
+        self.tensorflow_version = None
+        self.pytorch_version = None
+
+    @abstractmethod
+    def _prepare_for_mode(self):
+        """Abstract method"""
+
+    def _create_hf_dlc_model(self) -> Type[Model]:
+        """Initialize the model after resolving its serving image
+
+        1. Get the model metadata to decide the framework
+        2. Get the supported HuggingFace versions
+        3. Create the model
+        4. Fetch the image
+
+        Returns:
+            pysdk_model: Corresponding model instance
+        """
+        hf_model_md = get_huggingface_model_metadata(
+            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+        )
+        if hf_model_md is None:
+            raise ValueError("Could not fetch HuggingFace metadata for model %s" % self.model)
+
+        hf_config = image_uris.config_for_framework("huggingface").get("inference")
+        config = hf_config["versions"]
+        base_hf_version = sorted(config.keys(), key=lambda v: Version(v))[0]
+        tags = hf_model_md.get("tags") or []
+
+        if "pytorch" in tags:
+            self.pytorch_version = self._get_supported_version(
+                hf_config, base_hf_version, "pytorch"
+            )
+            self.py_version = config[base_hf_version][
+                "pytorch" + self.pytorch_version
+            ].get("py_versions")[-1]
+            pysdk_model = HuggingFaceModel(
+                env=self.env_vars,
+                role=self.role_arn,
+                sagemaker_session=self.sagemaker_session,
+                py_version=self.py_version,
+                transformers_version=base_hf_version,
+                pytorch_version=self.pytorch_version,
+            )
+        elif "keras" in tags or "tensorflow" in tags:
+            self.tensorflow_version = self._get_supported_version(
+                hf_config, base_hf_version, "tensorflow"
+            )
+            self.py_version = config[base_hf_version][
+                "tensorflow" + self.tensorflow_version
+            ].get("py_versions")[-1]
+            pysdk_model = HuggingFaceModel(
+                env=self.env_vars,
+                role=self.role_arn,
+                sagemaker_session=self.sagemaker_session,
+                py_version=self.py_version,
+                transformers_version=base_hf_version,
+                tensorflow_version=self.tensorflow_version,
+            )
+        else:
+            raise ValueError(
+                "Expected a PyTorch or TensorFlow model: could not determine "
+                "the framework from model tags %s" % tags
+            )
+
+        if self.mode == Mode.LOCAL_CONTAINER:
+            self.image_uri = pysdk_model.serving_image_uri(
+                self.sagemaker_session.boto_region_name, "local"
+            )
+        else:
+            self.image_uri = pysdk_model.serving_image_uri(
+                self.sagemaker_session.boto_region_name, self.instance_type
+            )
+
+        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)
+
+        self._original_deploy = pysdk_model.deploy
+        pysdk_model.deploy = self._hf_dlc_model_builder_deploy_wrapper
+        return pysdk_model
+
+    @_capture_telemetry("hf_dlc.deploy")
+    def _hf_dlc_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
+        """Return a predictor depending on local or SageMaker endpoint mode
+
+        Returns:
+            HfDLCLocalModePredictor: During local mode deployment
+        """
+        timeout = kwargs.get("model_data_download_timeout")
+        if timeout:
+            self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)})
+
+        if "mode" in kwargs and kwargs.get("mode") != self.mode:
+            overwrite_mode = kwargs.get("mode")
+            # mode overwritten by customer during model.deploy()
+            logger.warning(
+                "Deploying in %s Mode, overriding existing configurations set for %s mode",
+                overwrite_mode,
+                self.mode,
+            )
+
+            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
+                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
+            elif overwrite_mode == Mode.LOCAL_CONTAINER:
+                self._prepare_for_mode()
+                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
+            else:
+                raise ValueError("Mode %s is not supported!" % overwrite_mode)
+
+        self._set_instance(kwargs)
+
+        serializer = self.schema_builder.input_serializer
+        deserializer = self.schema_builder._output_deserializer
+        if self.mode == Mode.LOCAL_CONTAINER:
+            predictor = HfDLCLocalModePredictor(
+                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
+            )
+
+            ram_usage_before = _get_ram_usage_mb()
+            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
+                self.image_uri,
+                timeout if timeout else DEFAULT_TIMEOUT,
+                None,
+                predictor,
+                self.pysdk_model.env,
+                jumpstart=False,
+            )
+            ram_usage_after = _get_ram_usage_mb()
+            self.ram_usage_model_load = max(ram_usage_after - ram_usage_before, 0)
+
+            return predictor
+
+        if "mode" in kwargs:
+            del kwargs["mode"]
+        if "role" in kwargs:
+            self.pysdk_model.role = kwargs.get("role")
+            del kwargs["role"]
+
+        # set model_data to the uncompressed s3 dict
+        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
+        self.env_vars.update(env_vars)
+        self.pysdk_model.env.update(self.env_vars)
+
+        if "endpoint_logging" not in kwargs:
+            kwargs["endpoint_logging"] = True
+
+        if "initial_instance_count" not in kwargs:
+            kwargs.update({"initial_instance_count": 1})
+
+        predictor = self._original_deploy(*args, **kwargs)
+
+        predictor.serializer = serializer
+        predictor.deserializer = deserializer
+        return predictor
+
+    def _build_for_hugging_face_dlc(self):
+        """Build the model for deployment on a HuggingFace DLC
+
+        Returns:
+            pysdk_model: Corresponding model instance
+        """
+        self.nb_instance_type = _get_nb_instance()
+
+        _create_dir_structure(self.model_path)
+        if not hasattr(self, "pysdk_model"):
+            self.env_vars.update({"HF_MODEL_ID": self.model})
+
+        logger.info("Environment variables: %s", self.env_vars)
+
+        # TODO: Move to a helper function
+        if "HF_API_TOKEN" in self.env_vars:
+            self.hf_model_config = _get_model_config_properties_from_hf(
+                self.model, self.env_vars.get("HF_API_TOKEN")
+            )
+        else:
+            self.hf_model_config = _get_model_config_properties_from_hf(
+                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+            )
+
+        self.pysdk_model = self._create_hf_dlc_model()
+
+        if self.mode == Mode.LOCAL_CONTAINER:
+            self._prepare_for_mode()
+
+        return self.pysdk_model
+
+    def _set_instance(self, kwargs):
+        """Set the instance type from the detected notebook instance or the configured one
+
+        Mutates the deploy kwargs in place so the chosen instance type is
+        forwarded to the underlying deploy call.
+        """
+        if self.mode == Mode.SAGEMAKER_ENDPOINT and "instance_type" not in kwargs:
+            if self.nb_instance_type:
+                kwargs.update({"instance_type": self.nb_instance_type})
+            elif self.instance_type:
+                kwargs.update({"instance_type": self.instance_type})
+            else:
+                raise ValueError(
+                    "Instance type must be provided when deploying to SageMaker Endpoint mode."
+                )
+            logger.info("Setting instance type to %s", kwargs.get("instance_type"))
+
+    def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
+        """Use the HuggingFace image config to pick the earliest supported framework version"""
+        version_config = hf_config.get("versions").get(hugging_face_version)
+        versions_to_return = []
+        for key in list(version_config.keys()):
+            if key.startswith(base_fw):
+                base_fw_version = key[len(base_fw):]
+                if len(hugging_face_version.split(".")) == 2:
+                    base_fw_version = ".".join(base_fw_version.split(".")[:-1])
+                versions_to_return.append(base_fw_version)
+        return sorted(versions_to_return, key=Version)[0]
+
+    def _build_for_hf_dlc(self):
+        """Method that triggers the model build
+
+        Returns:
+            pysdk_model: Corresponding model instance
+        """
+        self.secret_key = None
+        self.model_server = ModelServer.HuggingFaceDLC
+        self.pysdk_model = self._build_for_hugging_face_dlc()
+        return self.pysdk_model
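Reviewer note — a minimal usage sketch of the new builder path (illustrative, not part of the diff). The model id and schema payloads are hypothetical stand-ins, and the SchemaBuilder import path is assumed from the existing sagemaker.serve package layout; any hub model whose pipeline_tag is not "text-generation" should route to the HF DLC builder:

    from sagemaker.serve.builder.model_builder import ModelBuilder
    from sagemaker.serve.builder.schema_builder import SchemaBuilder
    from sagemaker.serve.mode.function_pointers import Mode

    # hypothetical fill-mask payloads for the schema
    schema = SchemaBuilder(
        sample_input={"inputs": "Paris is the capital of [MASK]."},
        sample_output=[{"token_str": "france", "score": 0.97}],
    )
    builder = ModelBuilder(
        model="bert-base-uncased",  # hypothetical non-text-generation model id
        schema_builder=schema,
        mode=Mode.LOCAL_CONTAINER,
    )
    model = builder.build()  # routes to _build_for_hf_dlc()
    predictor = model.deploy(model_data_download_timeout=1800)
    print(predictor.predict({"inputs": "Paris is the capital of [MASK]."}))
    predictor.delete_predictor()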
diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py
index 6e3be072bb..3835b5d914 100644
--- a/src/sagemaker/serve/builder/model_builder.py
+++ b/src/sagemaker/serve/builder/model_builder.py
@@ -34,6 +34,7 @@
 from sagemaker.serve.builder.djl_builder import DJL
 from sagemaker.serve.builder.tgi_builder import TGI
 from sagemaker.serve.builder.jumpstart_builder import JumpStart
+from sagemaker.serve.builder.hf_dlc_builder import HuggingFaceDLC
 from sagemaker.predictor import Predictor
 from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
 from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -53,6 +54,7 @@
 from sagemaker.serve.validations.check_image_and_hardware_type import (
     validate_image_uri_and_hardware,
 )
+from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
 
 logger = logging.getLogger(__name__)
 
@@ -60,12 +62,13 @@
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
+    ModelServer.HuggingFaceDLC,
 }
 
 
 # pylint: disable=attribute-defined-outside-init
 @dataclass
-class ModelBuilder(Triton, DJL, JumpStart, TGI):
+class ModelBuilder(Triton, DJL, JumpStart, TGI, HuggingFaceDLC):
     """Class that builds a deployable model.
 
     Args:
@@ -125,7 +128,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI):
             in order for model builder to build the artifacts correctly (according
             to the model server). Possible values for this argument are
             ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
-            ``TRITON``, and ``TGI``.
+            ``TRITON``, ``TGI``, and ``HuggingFaceDLC``.
     """
@@ -535,7 +538,7 @@ def wrapper(*args, **kwargs):
         return wrapper
 
     # Model Builder is a class to build the model for deployment.
-    # It supports three modes of deployment
+    # It supports two modes of deployment
     # 1/ SageMaker Endpoint
     # 2/ Local launch with container
     def build(
@@ -577,12 +580,19 @@ def build(
         )
 
         self.serve_settings = self._get_serve_setting()
+
         if isinstance(self.model, str):
             if self._is_jumpstart_model_id():
                 return self._build_for_jumpstart()
             if self._is_djl():
                 return self._build_for_djl()
-            return self._build_for_tgi()
+
+            hf_model_md = get_huggingface_model_metadata(
+                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+            )
+            if hf_model_md and hf_model_md.get("pipeline_tag") == "text-generation":
+                return self._build_for_tgi()
+            return self._build_for_hf_dlc()
 
         self._build_validations()
diff --git a/src/sagemaker/serve/mode/local_container_mode.py b/src/sagemaker/serve/mode/local_container_mode.py
index 77f39d56bb..bce7dbc506 100644
--- a/src/sagemaker/serve/mode/local_container_mode.py
+++ b/src/sagemaker/serve/mode/local_container_mode.py
@@ -19,6 +19,7 @@
 from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
 from sagemaker.serve.model_server.triton.server import LocalTritonServer
 from sagemaker.serve.model_server.tgi.server import LocalTgiServing
+from sagemaker.serve.model_server.hf_dlc.server import LocalHFDLCServing
 from sagemaker.session import Session
 
 logger = logging.getLogger(__name__)
@@ -31,7 +32,9 @@
 )
 
 
-class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing):
+class LocalContainerMode(
+    LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalHFDLCServing
+):
     """A class that holds methods to deploy model to a container in local environment"""
 
     def __init__(
@@ -128,6 +129,15 @@ def create_server(
                 jumpstart=jumpstart,
             )
             self._ping_container = self._tgi_deep_ping
+        elif self.model_server == ModelServer.HuggingFaceDLC:
+            self._start_hf_dlc_serving(
+                client=self.client,
+                image=image,
+                model_path=model_path if model_path else self.model_path,
+                secret_key=secret_key,
+                env_vars=env_vars if env_vars else self.env_vars,
+            )
+            self._ping_container = self._hf_dlc_deep_ping
 
         # allow some time for container to be ready
         time.sleep(10)
diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
index 5db70d3d34..ed85becc9d 100644
--- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
+++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -12,12 +12,13 @@
 from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe
 from sagemaker.serve.model_server.djl_serving.server import SageMakerDjlServing
 from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
+from sagemaker.serve.model_server.hf_dlc.server import SageMakerHFDLCServing
 
 logger = logging.getLogger(__name__)
 
 
 class SageMakerEndpointMode(
-    SageMakerTorchServe, SageMakerTritonServer, SageMakerDjlServing, SageMakerTgiServing
+    SageMakerTorchServe, SageMakerTritonServer, SageMakerDjlServing, SageMakerTgiServing, SageMakerHFDLCServing
 ):
     """Holds the required method to deploy a model to a SageMaker Endpoint"""
 
@@ -93,4 +94,12 @@ def prepare(
                 jumpstart=jumpstart,
             )
 
+        if self.model_server == ModelServer.HuggingFaceDLC:
+            return self._upload_hf_dlc_artifacts(
+                model_path=model_path,
+                sagemaker_session=sagemaker_session,
+                s3_model_data_url=s3_model_data_url,
+                image=image,
+            )
+
         raise ValueError("%s model server is not supported" % self.model_server)
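For context on the SAGEMAKER_ENDPOINT branch above: _upload_hf_dlc_artifacts (added below in hf_dlc/server.py) returns an uncompressed model-data dict rather than a model.tar.gz. Illustrative shape, with a made-up bucket and prefix:

    # Illustrative return value assigned to pysdk_model.model_data
    model_data = {
        "S3DataSource": {
            "CompressionType": "None",
            "S3DataType": "S3Prefix",
            "S3Uri": "s3://my-bucket/hf-dlc-code-prefix/code/",
        }
    }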
diff --git a/src/sagemaker/serve/model_server/hf_dlc/__init__.py b/src/sagemaker/serve/model_server/hf_dlc/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/sagemaker/serve/model_server/hf_dlc/prepare.py b/src/sagemaker/serve/model_server/hf_dlc/prepare.py
new file mode 100644
index 0000000000..438d3b56c5
--- /dev/null
+++ b/src/sagemaker/serve/model_server/hf_dlc/prepare.py
@@ -0,0 +1,38 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Shared resources for prepare step of model deployment"""
+
+from __future__ import absolute_import
+import logging
+from pathlib import Path
+
+from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
+
+logger = logging.getLogger(__name__)
+
+
+def _create_dir_structure(model_path: str) -> tuple:
+    """Create the expected model directory structure for the HF DLC server"""
+    model_path = Path(model_path)
+    if not model_path.exists():
+        model_path.mkdir(parents=True)
+    elif not model_path.is_dir():
+        raise ValueError("model_dir is not a valid directory")
+
+    code_dir = model_path.joinpath("code")
+    code_dir.mkdir(exist_ok=True, parents=True)
+
+    _check_disk_space(model_path)
+    _check_docker_disk_usage()
+
+    return model_path, code_dir
diff --git a/src/sagemaker/serve/model_server/hf_dlc/server.py b/src/sagemaker/serve/model_server/hf_dlc/server.py
new file mode 100644
index 0000000000..6151ea580a
--- /dev/null
+++ b/src/sagemaker/serve/model_server/hf_dlc/server.py
@@ -0,0 +1,127 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Module for Local HF DLC Serving"""
+from __future__ import absolute_import
+
+import logging
+from pathlib import Path
+
+import requests
+
+from sagemaker import Session, fw_utils
+from sagemaker.serve.utils.exceptions import LocalModelInvocationException
+from sagemaker.base_predictor import PredictorBase
+from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
+from sagemaker.s3 import S3Uploader
+from sagemaker.local.utils import get_docker_host
+
+MODE_DIR_BINDING = "/opt/ml/model/"
+_DEFAULT_ENV_VARS = {}
+
+logger = logging.getLogger(__name__)
+
+
+class LocalHFDLCServing:
+    """Local HF DLC server instance"""
+
+    def _start_hf_dlc_serving(
+        self,
+        client: object,
+        image: str,
+        model_path: str,
+        secret_key: str,
+        env_vars: dict,
+    ):
+        """Start the HF DLC container on the local docker client"""
+        self.container = client.containers.run(
+            image,
+            "serve",
+            network_mode="host",
+            detach=True,
+            auto_remove=True,
+            volumes={
+                Path(model_path).joinpath("code"): {
+                    "bind": MODE_DIR_BINDING,
+                    "mode": "rw",
+                },
+            },
+            environment=_update_env_vars(env_vars),
+        )
+
+    def _invoke_hf_dlc_serving(self, request: object, content_type: str, accept: str):
+        """Invoke the local HF DLC server with the serialized request"""
+        try:
+            response = requests.post(
+                f"http://{get_docker_host()}:8080/invocations",
+                data=request,
+                headers={"Content-Type": content_type, "Accept": accept},
+                timeout=600,
+            )
+            response.raise_for_status()
+            return response.content
+        except Exception as e:
+            raise Exception("Unable to send request to the local container server") from e
+
+    def _hf_dlc_deep_ping(self, predictor: PredictorBase):
+        """Ping the local container with the sample input to verify it is serving"""
+        response = None
+        try:
+            response = predictor.predict(self.schema_builder.sample_input)
+            return True, response
+        # pylint: disable=broad-except
+        except Exception as e:
+            if "422 Client Error: Unprocessable Entity for url" in str(e):
+                raise LocalModelInvocationException(str(e))
+            return False, response
+
+
+class SageMakerHFDLCServing:
+    """SageMaker endpoint HF DLC server"""
+
+    def _upload_hf_dlc_artifacts(
+        self,
+        model_path: str,
+        sagemaker_session: Session,
+        s3_model_data_url: str = None,
+        image: str = None,
+        env_vars: dict = None,
+    ):
+        """Upload the model artifacts uncompressed to S3 and return the model data config"""
+        if s3_model_data_url:
+            bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
+        else:
+            bucket, key_prefix = None, None
+
+        code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)
+
+        bucket, code_key_prefix = determine_bucket_and_prefix(
+            bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
+        )
+
+        code_dir = Path(model_path).joinpath("code")
+
+        s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code")
+
+        logger.debug("Uploading HuggingFace DLC Model Resources uncompressed to: %s", s3_location)
+
+        model_data_url = S3Uploader.upload(
+            str(code_dir),
+            s3_location,
+            None,
+            sagemaker_session,
+        )
+
+        model_data = {
+            "S3DataSource": {
+                "CompressionType": "None",
+                "S3DataType": "S3Prefix",
+                "S3Uri": model_data_url + "/",
+            }
+        }
+        return model_data, _update_env_vars(env_vars)
+
+
+def _update_env_vars(env_vars: dict) -> dict:
+    """Merge the default HF DLC environment variables with user-provided overrides"""
+    updated_env_vars = {}
+    updated_env_vars.update(_DEFAULT_ENV_VARS)
+    if env_vars:
+        updated_env_vars.update(env_vars)
+    return updated_env_vars
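Reviewer sketch of what _hf_dlc_deep_ping amounts to once the container is up: a plain POST against the DLC's /invocations route on port 8080. Host, payload, and content types here are illustrative:

    import requests

    resp = requests.post(
        "http://localhost:8080/invocations",
        data=b'{"inputs": "Paris is the capital of [MASK]."}',
        headers={"Content-Type": "application/json", "Accept": "application/json"},
        timeout=600,
    )
    resp.raise_for_status()  # 422 here maps to LocalModelInvocationException
    print(resp.content)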
diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py
index e717b56235..a7d0553a5a 100644
--- a/src/sagemaker/serve/utils/predictors.py
+++ b/src/sagemaker/serve/utils/predictors.py
@@ -165,6 +165,49 @@ def delete_predictor(self):
         self._mode_obj.destroy_server()
 
 
+class HfDLCLocalModePredictor(PredictorBase):
+    """Lightweight HF DLC predictor for local deployment in LOCAL_CONTAINER mode"""
+
+    def __init__(
+        self,
+        mode_obj: Type[LocalContainerMode],
+        serializer=JSONSerializer(),
+        deserializer=JSONDeserializer(),
+    ):
+        self._mode_obj = mode_obj
+        self.serializer = serializer
+        self.deserializer = deserializer
+
+    def predict(self, data):
+        """Serialize the payload, invoke the local container, and deserialize the response"""
+        return [
+            self.deserializer.deserialize(
+                io.BytesIO(
+                    self._mode_obj._invoke_hf_dlc_serving(
+                        self.serializer.serialize(data),
+                        self.content_type,
+                        self.deserializer.ACCEPT[0],
+                    )
+                ),
+                self.content_type,
+            )
+        ]
+
+    @property
+    def content_type(self):
+        """The MIME type of the data sent to the inference endpoint."""
+        return self.serializer.CONTENT_TYPE
+
+    @property
+    def accept(self):
+        """The content type(s) that are expected from the inference endpoint."""
+        return self.deserializer.ACCEPT
+
+    def delete_predictor(self):
+        """Shut down and remove the container created in LOCAL_CONTAINER mode"""
+        self._mode_obj.destroy_server()
+
+
 def _get_local_mode_predictor(
     model_server: ModelServer,
     mode_obj: Type[LocalContainerMode],
diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py
index cb57a9f0a7..eee86b7f3a 100644
--- a/src/sagemaker/serve/utils/telemetry_logger.py
+++ b/src/sagemaker/serve/utils/telemetry_logger.py
@@ -45,6 +45,7 @@
     str(ModelServer.DJL_SERVING): 4,
     str(ModelServer.TRITON): 5,
     str(ModelServer.TGI): 6,
+    str(ModelServer.HuggingFaceDLC): 7,
 }
diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py
index b8657a19ca..9b6c45cd8e 100644
--- a/src/sagemaker/serve/utils/types.py
+++ b/src/sagemaker/serve/utils/types.py
@@ -17,6 +17,7 @@ def __str__(self):
     DJL_SERVING = 4
     TRITON = 5
     TGI = 6
+    HuggingFaceDLC = 7
 
 
 class _DjlEngine(Enum):
diff --git a/tests/unit/sagemaker/serve/builder/test_hf_dlc_builder.py b/tests/unit/sagemaker/serve/builder/test_hf_dlc_builder.py
new file mode 100644
index 0000000000..1822acf7e5
--- /dev/null
+++ b/tests/unit/sagemaker/serve/builder/test_hf_dlc_builder.py
@@ -0,0 +1,86 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+from unittest.mock import MagicMock, patch
+
+import unittest
+from sagemaker.serve.builder.model_builder import ModelBuilder
+from sagemaker.serve.mode.function_pointers import Mode
+from sagemaker.serve import ModelServer
+
+from sagemaker.serve.utils.predictors import HfDLCLocalModePredictor
+
+# flan-t5 is a text2text-generation model, so ModelBuilder routes it to the
+# HF DLC path rather than TGI
+mock_model_id = "google/flan-t5-xxl"
+mock_prompt = "Hello, I'm a language model,"
+mock_response = "Hello, I'm a language model, and I'm here to help you with your English."
+mock_sample_input = {"inputs": mock_prompt, "parameters": {}}
+mock_sample_output = [{"generated_text": mock_response}]
+
+mock_schema_builder = MagicMock()
+mock_schema_builder.sample_input = mock_sample_input
+mock_schema_builder.sample_output = mock_sample_output
+
+mock_schema_builder_invalid = MagicMock()
+mock_schema_builder_invalid.sample_input = {"invalid": "format"}
+mock_schema_builder_invalid.sample_output = mock_sample_output
+
+
+class TestHFDlcBuilder(unittest.TestCase):
+
+    @patch("sagemaker.serve.builder.hf_dlc_builder._capture_telemetry", side_effect=None)
+    @patch("sagemaker.serve.builder.hf_dlc_builder._get_ram_usage_mb", return_value=1024)
+    @patch("sagemaker.serve.builder.hf_dlc_builder._get_nb_instance", return_value="ml.g5.24xlarge")
+    def test_build_deploy_for_hf_dlc_local_container(
+        self,
+        mock_get_nb_instance,
+        mock_get_ram_usage_mb,
+        mock_telemetry,
+    ):
+        builder = ModelBuilder(
+            model=mock_model_id,
+            schema_builder=mock_schema_builder,
+            mode=Mode.LOCAL_CONTAINER,
+            model_server=ModelServer.HuggingFaceDLC,
+        )
+
+        builder._prepare_for_mode = MagicMock()
+        builder._prepare_for_mode.side_effect = None
+
+        model = builder.build()
+        builder.serve_settings.telemetry_opt_out = True
+
+        assert builder.nb_instance_type == "ml.g5.24xlarge"
+
+        builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
+        predictor = model.deploy(model_data_download_timeout=1800)
+
+        assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
+        assert isinstance(predictor, HfDLCLocalModePredictor)
+
+        builder._original_deploy = MagicMock()
+        builder._prepare_for_mode.return_value = (None, {})
+        model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
+        assert "HF_MODEL_ID" in model.env
+
+        with self.assertRaises(ValueError) as _:
+            model.deploy(mode=Mode.IN_PROCESS)
diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py
index 21a250ff7b..6e95d5641c 100644
--- a/tests/unit/sagemaker/serve/builder/test_model_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -47,6 +47,7 @@
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
+    ModelServer.HuggingFaceDLC,
 }
 
 mock_session = MagicMock()
diff --git a/tests/unit/sagemaker/serve/model_server/hf_dlc/test_hf_dlc_prepare.py b/tests/unit/sagemaker/serve/model_server/hf_dlc/test_hf_dlc_prepare.py
new file mode 100644
index 0000000000..eebdb0b283
--- /dev/null
+++ b/tests/unit/sagemaker/serve/model_server/hf_dlc/test_hf_dlc_prepare.py
@@ -0,0 +1,55 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest import TestCase
+from unittest.mock import Mock, patch
+
+from sagemaker.serve.model_server.hf_dlc.prepare import (
+    _create_dir_structure,
+)
+
+
+class HFDLCPrepareTests(TestCase):
+    @patch("sagemaker.serve.model_server.hf_dlc.prepare._check_disk_space")
+    @patch("sagemaker.serve.model_server.hf_dlc.prepare._check_docker_disk_usage")
+    @patch("sagemaker.serve.model_server.hf_dlc.prepare.Path")
+    def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_disk_space):
+        mock_model_path = Mock()
+        mock_model_path.exists.return_value = False
+        mock_code_dir = Mock()
+        mock_model_path.joinpath.return_value = mock_code_dir
+        mock_path.return_value = mock_model_path
+
+        ret_model_path, ret_code_dir = _create_dir_structure(mock_model_path)
+
+        mock_model_path.mkdir.assert_called_once_with(parents=True)
+        mock_model_path.joinpath.assert_called_once_with("code")
+        mock_code_dir.mkdir.assert_called_once_with(exist_ok=True, parents=True)
+        mock_disk_space.assert_called_once_with(mock_model_path)
+        mock_disk_usage.assert_called_once()
+
+        self.assertEqual(ret_model_path, mock_model_path)
+        self.assertEqual(ret_code_dir, mock_code_dir)
+
+    @patch("sagemaker.serve.model_server.hf_dlc.prepare.Path")
+    def test_create_dir_structure_invalid_path(self, mock_path):
+        mock_model_path = Mock()
+        mock_model_path.exists.return_value = True
+        mock_model_path.is_dir.return_value = False
+        mock_path.return_value = mock_model_path
+
+        with self.assertRaises(ValueError) as context:
+            _create_dir_structure(mock_model_path)
+
+        self.assertEqual("model_dir is not a valid directory", str(context.exception))
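A final reviewer note on _get_supported_version in hf_dlc_builder.py: framework version strings must be sorted with packaging.version.Version, since lexicographic ordering misranks multi-digit components. Minimal demonstration on a made-up config fragment:

    from packaging.version import Version

    version_config = {"pytorch1.7.1": {}, "pytorch1.13.1": {}, "tensorflow2.4.1": {}}
    candidates = [k[len("pytorch"):] for k in version_config if k.startswith("pytorch")]

    print(sorted(candidates)[0])               # "1.13.1" -- lexicographic, wrong
    print(sorted(candidates, key=Version)[0])  # "1.7.1"  -- version-aware ordering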