Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add warning message for job-prefixed pipeline steps when no job name is provided #4371

Merged
merged 1 commit into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/sagemaker/workflow/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@
"if desired."
)

JOB_KEY_NONE_WARN_MSG_TEMPLATE = (
"Invalid input: use_custom_job_prefix flag is set but the name field [{}] has not been "
"specified. Please refer to the AWS Docs to identify which field should be set to enable the "
"custom-prefixing feature for jobs created via a pipeline execution. "
"https://docs.aws.amazon.com/sagemaker/latest/dg/"
"build-and-manage-access.html#build-and-manage-step-permissions-prefix"
)

if TYPE_CHECKING:
from sagemaker.workflow.step_collections import StepCollection

Expand Down Expand Up @@ -458,6 +466,8 @@ def trim_request_dict(request_dict, job_key, config):
request_dict.pop(job_key, None) # safely return null in case of KeyError
else:
if job_key in request_dict:
if request_dict[job_key] is None or len(request_dict[job_key]) == 0:
raise ValueError(JOB_KEY_NONE_WARN_MSG_TEMPLATE.format(job_key))
request_dict[job_key] = base_from_name(request_dict[job_key]) # trim timestamp

return request_dict
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/sagemaker/workflow/test_model_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,3 +1328,31 @@ def _validate_repack_job_non_configurable_args(
assert model_archive.expr == expected_model_archive
else:
assert model_archive == expected_model_archive


def test_create_model_step_using_custom_model_name_set_to_none(pipeline_session):
# Name of the model not specified, will resolve to None.
model = Model(
image_uri="my-image",
sagemaker_session=pipeline_session,
model_data="s3://",
role=ROLE,
)
step_create_model = ModelStep(name="MyModelStep", step_args=model.create())

# 1. Toggle on custom-prefixing model name set to None.
config = PipelineDefinitionConfig(use_custom_job_prefix=True)

with pytest.raises(ValueError) as error:
pipeline = Pipeline(
name="MyPipeline",
steps=[step_create_model],
sagemaker_session=pipeline_session,
pipeline_definition_config=config,
)
pipeline.definition()

assert (
"Invalid input: use_custom_job_prefix flag is set but the name field "
"[ModelName] has not been specified." in str(error.value)
)