Skip to content

Commit

Permalink
fix: add warning message for job-prefixed pipeline steps when no job …
Browse files Browse the repository at this point in the history
…name is provided
  • Loading branch information
svia3 committed Jan 16, 2024
1 parent d083396 commit 5d1daa5
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
10 changes: 10 additions & 0 deletions src/sagemaker/workflow/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@
"if desired."
)

JOB_KEY_NONE_WARN_MSG_TEMPLATE = (
"Invalid input: use_custom_job_prefix flag is set but the name field [{}] has not been "
"specified. Please refer to the AWS Docs to identify which field should be set to enable the "
"custom-prefixing feature for jobs created via a pipeline execution. "
"https://docs.aws.amazon.com/sagemaker/latest/dg/"
"build-and-manage-access.html#build-and-manage-step-permissions-prefix"
)

if TYPE_CHECKING:
from sagemaker.workflow.step_collections import StepCollection

Expand Down Expand Up @@ -458,6 +466,8 @@ def trim_request_dict(request_dict, job_key, config):
request_dict.pop(job_key, None) # safely return null in case of KeyError
else:
if job_key in request_dict:
if request_dict[job_key] is None or len(request_dict[job_key]) == 0:
raise ValueError(JOB_KEY_NONE_WARN_MSG_TEMPLATE.format(job_key))
request_dict[job_key] = base_from_name(request_dict[job_key]) # trim timestamp

return request_dict
Expand Down
30 changes: 28 additions & 2 deletions tests/unit/sagemaker/workflow/test_model_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,6 @@ def test_pass_in_wrong_type_of_retry_policies(pipeline_session, model):


def test_register_model_step_using_custom_model_package_name(pipeline_session):

custom_model_prefix = "custom-model-package-prefix"
model = Model(
name="MyModel",
Expand Down Expand Up @@ -1139,7 +1138,6 @@ def test_register_model_step_using_custom_model_package_name(pipeline_session):


def test_create_model_step_using_custom_model_name(pipeline_session):

custom_model_prefix = "custom-model-prefix"
model = Model(
name=custom_model_prefix,
Expand Down Expand Up @@ -1171,3 +1169,31 @@ def test_create_model_step_using_custom_model_name(pipeline_session):
steps = json.loads(pipeline.definition())["Steps"]
assert len(steps) == 1
assert "ModelName" not in steps[0]["Arguments"]


def test_create_model_step_using_custom_model_name_set_to_none(pipeline_session):
# Name of the model not specified, will resolve to None.
model = Model(
image_uri="my-image",
sagemaker_session=pipeline_session,
model_data="s3://",
role=ROLE,
)
step_create_model = ModelStep(name="MyModelStep", step_args=model.create())

# 1. Toggle on custom-prefixing model package name popped
config = PipelineDefinitionConfig(use_custom_job_prefix=True)

with pytest.raises(ValueError) as error:
pipeline = Pipeline(
name="MyPipeline",
steps=[step_create_model],
sagemaker_session=pipeline_session,
pipeline_definition_config=config,
)
pipeline.definition()

assert (
"Invalid input: use_custom_job_prefix flag is set but the name field "
"[ModelName] has not been specified." in str(error.value)
)

0 comments on commit 5d1daa5

Please sign in to comment.