Skip to content

Commit

Permalink
Merge branch 'master' into ci-health-checks
Browse files Browse the repository at this point in the history
  • Loading branch information
benieric authored Mar 12, 2024
2 parents 92bf4b0 + 07e1b92 commit 1c3ae1d
Show file tree
Hide file tree
Showing 11 changed files with 317 additions and 41 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/codebuild-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Workflow: for each pull request, assume an AWS role and run the code style/doc
# checks and the unit tests as AWS CodeBuild builds.
name: PR Checks
on:
# NOTE(review): pull_request_target is presumably used (instead of pull_request) so
# that the role secret is available for PRs from forks. The builds below check out
# the PR's own code ('pr/<number>'), so confirm the trust model is intended.
pull_request_target:

# One concurrency group per workflow and pull request number (falling back to the
# head ref); a newer run cancels a run still in progress in the same group.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }}
cancel-in-progress: true

permissions:
id-token: write # This is required for requesting the JWT

jobs:
# Job 1: code style and documentation checks, run as a single CodeBuild build.
codestyle-doc-tests:
runs-on: ubuntu-latest
steps:
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
aws-region: us-west-2
# 10800 seconds = 3 hours of credential lifetime.
role-duration-seconds: 10800
- name: Run Codestyle & Doc Tests
uses: aws-actions/aws-codebuild-run-build@v1
with:
project-name: sagemaker-python-sdk-ci-codestyle-doc-tests
# Build the source version named after the pull request number.
source-version-override: 'pr/${{ github.event.pull_request.number }}'
# Job 2: unit tests, one CodeBuild build per entry in the python-version matrix.
unit-tests:
runs-on: ubuntu-latest
strategy:
# Let the remaining matrix entries finish even if one of them fails.
fail-fast: false
matrix:
python-version: ["py38", "py39", "py310"]
steps:
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
aws-region: us-west-2
# 10800 seconds = 3 hours of credential lifetime.
role-duration-seconds: 10800
- name: Run Unit Tests
uses: aws-actions/aws-codebuild-run-build@v1
with:
project-name: sagemaker-python-sdk-ci-unit-tests
source-version-override: 'pr/${{ github.event.pull_request.number }}'
# Names of step environment variables to hand to the CodeBuild build;
# PY_VERSION is set from the matrix entry in `env` below.
env-vars-for-codebuild: |
PY_VERSION
env:
PY_VERSION: ${{ matrix.python-version }}
3 changes: 2 additions & 1 deletion src/sagemaker/huggingface/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] =
Returns:
dict: The model metadata retrieved with the HuggingFace API
"""

if not model_id:
raise ValueError("Model ID is empty. Please provide a valid Model ID.")
hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
hf_model_metadata_json = None
try:
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,8 +766,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:

def _script_mode_env_vars(self):
"""Returns a mapping of environment variables for script mode execution"""
script_name = None
dir_name = None
script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "")
dir_name = self.env.get(DIR_PARAM_NAME.upper(), "")
if self.uploaded_code:
script_name = self.uploaded_code.script_name
if self.repacked_model_data or self.enable_network_isolation():
Expand All @@ -783,8 +783,8 @@ def _script_mode_env_vars(self):
else "file://" + self.source_dir
)
return {
SCRIPT_PARAM_NAME.upper(): script_name or str(),
DIR_PARAM_NAME.upper(): dir_name or str(),
SCRIPT_PARAM_NAME.upper(): script_name,
DIR_PARAM_NAME.upper(): dir_name,
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def main(sys_args=None):
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")

RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
client_sagemaker_pysdk_version
)

user = getpass.getuser()
if user != "root":
Expand All @@ -89,6 +86,10 @@ def main(sys_args=None):
client_python_version, conda_env, dependency_settings
)

RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
client_sagemaker_pysdk_version
)

exit_code = SUCCESS_EXIT_CODE
except Exception as e: # pylint: disable=broad-except
logger.exception("Error encountered while bootstrapping runtime environment: %s", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import dataclasses
import json

import sagemaker


class _UTCFormatter(logging.Formatter):
"""Class that overrides the default local time provider in log formatter."""
Expand Down Expand Up @@ -330,6 +328,7 @@ def _current_python_version(self):

def _current_sagemaker_pysdk_version(self):
    """Return the ``__version__`` of the sagemaker package importable at call time.

    Returns:
        The value of ``sagemaker.__version__`` in the running interpreter.
    """
    # The package is imported here, at call time, rather than at module import time.
    import sagemaker as installed_sdk

    return installed_sdk.__version__

Expand Down Expand Up @@ -366,10 +365,10 @@ def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
):
logger.warning(
"Inconsistent sagemaker versions found: "
"sagemaker pysdk version found in the container is "
"sagemaker python sdk version found in the container is "
"'%s' which does not match the '%s' on the local client. "
"Please make sure that the python version used in the training container "
"is the same as the local python version in case of unexpected behaviors.",
"Please make sure that the sagemaker version used in the training container "
"is the same as the local sagemaker version in case of unexpected behaviors.",
job_sagemaker_pysdk_version,
client_sagemaker_pysdk_version,
)
Expand Down
19 changes: 14 additions & 5 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
into a stream. All translations between the server and the client are handled
automatically with the specified input and output.
model (Optional[Union[object, str]]): Model object (with ``predict`` method to perform
inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or
``inference_spec`` is required for the model builder to build the artifact.
inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec``
is required for the model builder to build the artifact.
inference_spec (InferenceSpec): The inference spec file with your customized
``invoke`` and ``load`` functions.
image_uri (Optional[str]): The container image uri (which is derived from a
Expand All @@ -145,6 +145,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
to the model server). Possible values for this argument are
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
``TRITON``, and ``TGI``.
model_metadata (Optional[Dict[str, Any]]): Dictionary used to override the HuggingFace
model metadata. Currently ``HF_TASK`` is overridable.
"""

model_path: Optional[str] = field(
Expand Down Expand Up @@ -241,6 +243,10 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
model_server: Optional[ModelServer] = field(
default=None, metadata={"help": "Define the model server to deploy to."}
)
model_metadata: Optional[Dict[str, Any]] = field(
default=None,
metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"},
)

def _build_validations(self):
"""Placeholder docstring"""
Expand Down Expand Up @@ -616,6 +622,9 @@ def build( # pylint: disable=R0911
self._is_custom_image_uri = self.image_uri is not None

if isinstance(self.model, str):
model_task = None
if self.model_metadata:
model_task = self.model_metadata.get("HF_TASK")
if self._is_jumpstart_model_id():
return self._build_for_jumpstart()
if self._is_djl(): # pylint: disable=R1705
Expand All @@ -625,10 +634,10 @@ def build( # pylint: disable=R0911
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)

model_task = hf_model_md.get("pipeline_tag")
if self.schema_builder is None and model_task:
if model_task is None:
model_task = hf_model_md.get("pipeline_tag")
if self.schema_builder is None and model_task is not None:
self._schema_builder_init(model_task)

if model_task == "text-generation": # pylint: disable=R1705
return self._build_for_tgi()
elif self._can_fit_on_single_gpu():
Expand Down
40 changes: 20 additions & 20 deletions src/sagemaker/serve/schema/task.json
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
{
"fill-mask": {
"sample_inputs": {
"sample_inputs": {
"properties": {
"inputs": "Paris is the [MASK] of France.",
"parameters": {}
}
},
"sample_outputs": {
},
"sample_outputs": {
"properties": [
{
"sequence": "Paris is the capital of France.",
"score": 0.7
}
]
}
},
},
"question-answering": {
"sample_inputs": {
"sample_inputs": {
"properties": {
"context": "I have a German Shepherd dog, named Coco.",
"question": "What is my dog's breed?"
}
},
"sample_outputs": {
},
"sample_outputs": {
"properties": [
{
"answer": "German Shepherd",
Expand All @@ -32,36 +32,36 @@
}
]
}
},
},
"text-classification": {
"sample_inputs": {
"sample_inputs": {
"properties": {
"inputs": "Where is the capital of France?, Paris is the capital of France.",
"parameters": {}
}
},
"sample_outputs": {
},
"sample_outputs": {
"properties": [
{
"label": "entailment",
"score": 0.997
}
]
}
},
"text-generation": {
"sample_inputs": {
},
"text-generation": {
"sample_inputs": {
"properties": {
"inputs": "Hello, I'm a language model",
"parameters": {}
}
},
"sample_outputs": {
},
"sample_outputs": {
"properties": [
{
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
}
{
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
}
]
}
}
}
}
66 changes: 66 additions & 0 deletions tests/integ/sagemaker/serve/test_schema_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,69 @@ def test_model_builder_negative_path(sagemaker_session):
match="Error Message: Schema builder for text-to-image could not be found.",
):
model_builder.build(sagemaker_session=sagemaker_session)


@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing Schema Builder Simplification feature",
)
@pytest.mark.parametrize(
    "model_id, task_provided",
    [
        ("bert-base-uncased", "fill-mask"),
        ("bert-large-uncased-whole-word-masking-finetuned-squad", "question-answering"),
    ],
)
def test_model_builder_happy_path_with_task_provided(
    model_id, task_provided, sagemaker_session, gpu_instance_type
):
    """Build with an explicit ``HF_TASK`` override, then deploy and predict.

    Checks that the schema builder created during ``build`` carries the local
    sample input/output registered for ``task_provided``, then deploys the model
    to an endpoint, sends the sample input, and always cleans up the resources.
    """
    hf_task_override = {"HF_TASK": task_provided}
    model_builder = ModelBuilder(model=model_id, model_metadata=hf_task_override)
    model = model_builder.build(sagemaker_session=sagemaker_session)

    assert model is not None
    schema_builder = model_builder.schema_builder
    assert schema_builder is not None

    # The schema must match the samples stored locally for the overriding task.
    inputs, outputs = task.retrieve_local_schemas(task_provided)
    assert model_builder.schema_builder.sample_input == inputs
    assert model_builder.schema_builder.sample_output == outputs

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        caught_ex = None
        try:
            iam = sagemaker_session.boto_session.client("iam")
            execution_role = iam.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(
                role=execution_role,
                instance_count=1,
                instance_type=gpu_instance_type,
            )

            response = predictor.predict(inputs)
            assert response is not None
        except Exception as err:
            # Remember the failure so it is reported only after cleanup has run.
            caught_ex = err
        finally:
            cleanup_model_resources(
                sagemaker_session=model_builder.sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
            if caught_ex:
                logger.exception(caught_ex)
                assert (
                    False
                ), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"

def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
    """Expect ``build`` to raise ``TaskNotFoundException`` for an unknown ``HF_TASK``."""
    expected_message = "Error Message: Schema builder for invalid-task could not be found."
    hf_task_override = {"HF_TASK": "invalid-task"}
    model_builder = ModelBuilder(model="bert-base-uncased", model_metadata=hf_task_override)

    with pytest.raises(TaskNotFoundException, match=expected_message):
        model_builder.build(sagemaker_session=sagemaker_session)
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_main_failure_remote_job_with_root_user(

change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
validate_sagemaker.assert_not_called()
run_pre_exec_script.assert_not_called()
bootstrap_runtime.assert_called()
write_failure.assert_called_with(str(runtime_err))
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_main_failure_pipeline_step_with_root_user(

change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
validate_sagemaker.assert_not_called()
run_pre_exec_script.assert_not_called()
bootstrap_runtime.assert_called()
write_failure.assert_called_with(str(runtime_err))
Expand Down
Loading

0 comments on commit 1c3ae1d

Please sign in to comment.