feat: Introduce HF Transformers to ModelBuilder
samruds committed Jan 30, 2024
1 parent 086c946 commit 734ec60
Showing 15 changed files with 663 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/sagemaker/serve/builder/djl_builder.py
@@ -46,7 +46,7 @@
 )
 from sagemaker.serve.model_server.djl_serving.prepare import (
     prepare_for_djl_serving,
-    _create_dir_structure,
+    _create_dir_structure
 )
 from sagemaker.serve.utils.predictors import DjlLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer, _DjlEngine
274 changes: 274 additions & 0 deletions src/sagemaker/serve/builder/hf_dlc_builder.py
@@ -0,0 +1,274 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""HuggingFace DLC specific model builder"""
from __future__ import absolute_import
import logging
from packaging.version import Version
from typing import Type
from abc import ABC, abstractmethod

from sagemaker.model import Model
from sagemaker import Session, image_uris
from sagemaker.serve.utils.local_hardware import (
    _get_nb_instance,
    _get_ram_usage_mb,
    _get_gpu_info,
    _get_gpu_info_fallback,
)
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
from sagemaker.serve.model_server.hf_dlc.prepare import (
    _create_dir_structure,
)
from sagemaker.serve.utils.predictors import HfDLCLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.base_predictor import PredictorBase
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 1800


class HuggingFaceDLC(ABC):
    """HuggingFace DLC build logic for ModelBuilder()"""

    def __init__(self):
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._original_deploy = None
        self.hf_model_config = None
        self._default_data_type = None
        self.pysdk_model = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.role_arn = None
        self.py_version = None
        self.tensorflow_version = None
        self.pytorch_version = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Abstract method"""

    def _create_hf_dlc_model(self) -> Type[Model]:
        """Initializes the model after fetching the image.

        1. Get the metadata for deciding the framework
        2. Get the supported HuggingFace versions
        3. Create the model
        4. Fetch the image

        Returns:
            pysdk_model: Corresponding model instance
        """
        hf_model_md = get_huggingface_model_metadata(
            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
        )
        hf_config = image_uris.config_for_framework("huggingface").get("inference")
        config = hf_config["versions"]
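        # Use the lowest available transformers version in the config as the base.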
        base_hf_version = sorted(config.keys(), key=lambda v: Version(v))[0]

        if hf_model_md is None:
            raise ValueError("Could not fetch HF metadata")

        if "pytorch" in hf_model_md.get("tags"):
            self.pytorch_version = self._get_supported_version(
                hf_config, base_hf_version, "pytorch"
            )
            self.py_version = config[base_hf_version]["pytorch" + self.pytorch_version].get(
                "py_versions"
            )[-1]
            pysdk_model = HuggingFaceModel(
                env=self.env_vars,
                role=self.role_arn,
                sagemaker_session=self.sagemaker_session,
                py_version=self.py_version,
                transformers_version=base_hf_version,
                pytorch_version=self.pytorch_version,
            )
        elif "keras" in hf_model_md.get("tags") or "tensorflow" in hf_model_md.get("tags"):
            self.tensorflow_version = self._get_supported_version(
                hf_config, base_hf_version, "tensorflow"
            )
            self.py_version = config[base_hf_version]["tensorflow" + self.tensorflow_version].get(
                "py_versions"
            )[-1]
            pysdk_model = HuggingFaceModel(
                env=self.env_vars,
                role=self.role_arn,
                sagemaker_session=self.sagemaker_session,
                py_version=self.py_version,
                transformers_version=base_hf_version,
                tensorflow_version=self.tensorflow_version,
            )
        else:
            raise ValueError("Unsupported model: no PyTorch or TensorFlow tag in HF metadata")

        if self.mode == Mode.LOCAL_CONTAINER:
            self.image_uri = pysdk_model.serving_image_uri(
                self.sagemaker_session.boto_region_name, "local"
            )
        else:
            self.image_uri = pysdk_model.serving_image_uri(
                self.sagemaker_session.boto_region_name, self.instance_type
            )

        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

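        # Swap in the builder's deploy wrapper so ModelBuilder controls deployment routing.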
        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._hf_dlc_model_builder_deploy_wrapper
        return pysdk_model

@_capture_telemetry("hf_dlc.deploy")
def _hf_dlc_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
"""Returns predictor depending on local or sagemaker endpoint mode
Returns:
HfDLCLocalModePredictor: During local mode deployment
"""
timeout = kwargs.get("model_data_download_timeout")
if timeout:
self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

if "mode" in kwargs and kwargs.get("mode") != self.mode:
overwrite_mode = kwargs.get("mode")
# mode overwritten by customer during model.deploy()
logger.warning(
"Deploying in %s Mode, overriding existing configurations set for %s mode",
overwrite_mode,
self.mode,
)

if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
elif overwrite_mode == Mode.LOCAL_CONTAINER:
self._prepare_for_mode()
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
else:
raise ValueError("Mode %s is not supported!" % overwrite_mode)

self._set_instance()

serializer = self.schema_builder.input_serializer
deserializer = self.schema_builder._output_deserializer
if self.mode == Mode.LOCAL_CONTAINER:
timeout = kwargs.get("model_data_download_timeout")

predictor = HfDLCLocalModePredictor(
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
)

ram_usage_before = _get_ram_usage_mb()
self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
self.image_uri,
timeout if timeout else DEFAULT_TIMEOUT,
None,
predictor,
self.pysdk_model.env,
jumpstart=False,
)

ram_usage_after = _get_ram_usage_mb()
self.ram_usage_model_load = max(ram_usage_after - ram_usage_before, 0)

return predictor

if "mode" in kwargs:
del kwargs["mode"]
if "role" in kwargs:
self.pysdk_model.role = kwargs.get("role")
del kwargs["role"]

# set model_data to uncompressed s3 dict
self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
self.env_vars.update(env_vars)
self.pysdk_model.env.update(self.env_vars)

if "endpoint_logging" not in kwargs:
kwargs["endpoint_logging"] = True

if "initial_instance_count" not in kwargs:
kwargs.update({"initial_instance_count": 1})

predictor = self._original_deploy(*args, **kwargs)

predictor.serializer = serializer
predictor.deserializer = deserializer
return predictor

    def _build_for_hugging_face_dlc(self):
        """Builds the model for HuggingFace DLC deployment.

        Returns:
            pysdk_model: The built model instance
        """
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})

        logger.info(self.env_vars)

        # TODO: Move to a helper function
        if "HF_API_TOKEN" in self.env_vars:
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HF_API_TOKEN")
            )
        else:
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_hf_dlc_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _set_instance(self, **kwargs):
        """Set the instance type from the detected notebook type or the provided instance type."""
        if self.mode == Mode.SAGEMAKER_ENDPOINT:
            if self.nb_instance_type and "instance_type" not in kwargs:
                kwargs.update({"instance_type": self.nb_instance_type})
            elif self.instance_type and "instance_type" not in kwargs:
                kwargs.update({"instance_type": self.instance_type})
            else:
                raise ValueError(
                    "Instance type must be provided when deploying to SageMaker Endpoint mode."
                )
            logger.info("Setting instance type to %s", kwargs.get("instance_type"))
        return

    def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
        """Uses the HuggingFace JSON config to pick supported framework versions."""
        version_config = hf_config.get("versions").get(hugging_face_version)
        versions_to_return = list()
        for key in list(version_config.keys()):
            if key.startswith(base_fw):
                base_fw_version = key[len(base_fw):]
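                # When the HF version is major.minor only, drop the framework
                # patch component so the returned versions stay comparable.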
if len(hugging_face_version.split(".")) == 2:
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
versions_to_return.append(base_fw_version)
return sorted(versions_to_return)[0]

    def _build_for_hf_dlc(self):
        """Method that triggers the model build.

        Returns:
            pysdk_model: The built PySDK model
        """
        self.secret_key = None
        self.model_server = ModelServer.HuggingFaceDLC
        self.pysdk_model = self._build_for_hugging_face_dlc()
        return self.pysdk_model
18 changes: 14 additions & 4 deletions src/sagemaker/serve/builder/model_builder.py
@@ -34,6 +34,7 @@
 from sagemaker.serve.builder.djl_builder import DJL
 from sagemaker.serve.builder.tgi_builder import TGI
 from sagemaker.serve.builder.jumpstart_builder import JumpStart
+from sagemaker.serve.builder.hf_dlc_builder import HuggingFaceDLC
 from sagemaker.predictor import Predictor
 from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
 from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -53,19 +54,21 @@
 from sagemaker.serve.validations.check_image_and_hardware_type import (
     validate_image_uri_and_hardware,
 )
+from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

 logger = logging.getLogger(__name__)

 supported_model_server = {
     ModelServer.TORCHSERVE,
     ModelServer.TRITON,
     ModelServer.DJL_SERVING,
+    ModelServer.HuggingFaceDLC,
 }


 # pylint: disable=attribute-defined-outside-init
 @dataclass
-class ModelBuilder(Triton, DJL, JumpStart, TGI):
+class ModelBuilder(Triton, DJL, JumpStart, TGI, HuggingFaceDLC):
     """Class that builds a deployable model.

     Args:
@@ -125,7 +128,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI):
             in order for model builder to build the artifacts correctly (according
             to the model server). Possible values for this argument are
             ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
-            ``TRITON``, and ``TGI``.
+            ``TRITON``, ``TGI``, and ``HuggingFaceDLC``.
     """

@@ -535,7 +538,7 @@ def wrapper(*args, **kwargs):
         return wrapper

     # Model Builder is a class to build the model for deployment.
-    # It supports three modes of deployment
+    # It supports two modes of deployment
     # 1/ SageMaker Endpoint
     # 2/ Local launch with container
     def build(
@@ -577,12 +580,19 @@ def build(
         )

         self.serve_settings = self._get_serve_setting()
+
+        hf_model_md = get_huggingface_model_metadata(self.model,
+                                                     self.env_vars.get("HUGGING_FACE_HUB_TOKEN"))
+
         if isinstance(self.model, str):
             if self._is_jumpstart_model_id():
                 return self._build_for_jumpstart()
             if self._is_djl():
                 return self._build_for_djl()
-            return self._build_for_tgi()
+            if hf_model_md.get("pipeline_tag") == "text-generation":
+                return self._build_for_tgi()
+            else:
+                return self._build_for_hf_dlc()

         self._build_validations()
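
With the routing above, a plain HuggingFace model ID whose pipeline_tag is not text-generation now falls through to the new HF DLC path instead of TGI. A hedged usage sketch (the model ID and sample payloads are illustrative placeholders; the builder is assumed to be configured with a role and region as usual):

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode

# "bert-base-uncased" carries pipeline_tag "fill-mask" on the HF Hub, so
# build() should route to _build_for_hf_dlc() rather than _build_for_tgi().
model_builder = ModelBuilder(
    model="bert-base-uncased",
    schema_builder=SchemaBuilder(
        sample_input="The capital of France is [MASK].",
        sample_output=[{"token_str": "paris", "score": 0.99}],
    ),
    mode=Mode.LOCAL_CONTAINER,
)
model = model_builder.build()  # wrapped HuggingFaceModel with deploy() patched
predictor = model.deploy()     # served locally through LocalHFDLCServing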
12 changes: 11 additions & 1 deletion src/sagemaker/serve/mode/local_container_mode.py
@@ -19,6 +19,7 @@
 from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
 from sagemaker.serve.model_server.triton.server import LocalTritonServer
 from sagemaker.serve.model_server.tgi.server import LocalTgiServing
+from sagemaker.serve.model_server.hf_dlc.server import LocalHFDLCServing
 from sagemaker.session import Session

 logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@
 )


-class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing):
+class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalHFDLCServing):
     """A class that holds methods to deploy model to a container in local environment"""

     def __init__(
@@ -128,6 +129,15 @@ def create_server(
                 jumpstart=jumpstart,
             )
             self._ping_container = self._tgi_deep_ping
+        elif self.model_server == ModelServer.HuggingFaceDLC:
+            self._start_hf_dlc_serving(
+                client=self.client,
+                image=image,
+                model_path=model_path if model_path else self.model_path,
+                secret_key=secret_key,
+                env_vars=env_vars if env_vars else self.env_vars,
+            )
+            self._ping_container = self._hf_dlc_deep_ping

         # allow some time for container to be ready
         time.sleep(10)
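
The create_server change above adds one more elif to a growing dispatch. A table-driven sketch of the same shape, for illustration only (the HF DLC handler names come from this diff; the TGI pair is assumed to mirror them, and TGI's extra jumpstart argument is omitted):

from sagemaker.serve.utils.types import ModelServer

def _dispatch_create_server(self, image, env_vars, model_path=None, secret_key=None):
    # Map each model server to its (start_serving, deep_ping) pair.
    handlers = {
        ModelServer.HuggingFaceDLC: (self._start_hf_dlc_serving, self._hf_dlc_deep_ping),
        ModelServer.TGI: (self._start_tgi_serving, self._tgi_deep_ping),  # assumed names
    }
    start_serving, deep_ping = handlers[self.model_server]
    start_serving(
        client=self.client,
        image=image,
        model_path=model_path if model_path else self.model_path,
        secret_key=secret_key,
        env_vars=env_vars if env_vars else self.env_vars,
    )
    self._ping_container = deep_ping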