Skip to content

Commit

Permalink
feat: Supporting tbac in load_run (aws#4039)
Browse files Browse the repository at this point in the history
  • Loading branch information
ananth102 authored Dec 26, 2023
1 parent 0011f65 commit c797f2d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 30 deletions.
9 changes: 8 additions & 1 deletion src/sagemaker/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def __init__(
)

if not _TrialComponent._trial_component_is_associated_to_trial(
self._trial_component.trial_component_name, self._trial.trial_name, sagemaker_session
self._trial_component.trial_component_name,
self._trial.trial_name,
sagemaker_session,
):
self._trial.add_trial_component(self._trial_component)

Expand Down Expand Up @@ -781,6 +783,7 @@ def load_run(
sagemaker_session: Optional["Session"] = None,
artifact_bucket: Optional[str] = None,
artifact_prefix: Optional[str] = None,
tags: Optional[List[Dict[str, str]]] = None,
) -> Run:
"""Load an existing run.
Expand Down Expand Up @@ -849,6 +852,8 @@ def load_run(
will be used.
artifact_prefix (str): The S3 key prefix used to generate the S3 path
to upload the artifact to (default: "trial-component-artifacts").
tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
e.g. to create an experiment, a run group, etc. (default: None).
Returns:
Run: The loaded Run object.
Expand All @@ -870,6 +875,7 @@ def load_run(
sagemaker_session=sagemaker_session or _utils.default_session(),
artifact_bucket=artifact_bucket,
artifact_prefix=artifact_prefix,
tags=tags,
)
elif _RunContext.get_current_run():
run_instance = _RunContext.get_current_run()
Expand All @@ -889,6 +895,7 @@ def load_run(
sagemaker_session=sagemaker_session or _utils.default_session(),
artifact_bucket=artifact_bucket,
artifact_prefix=artifact_prefix,
tags=tags,
)
else:
raise RuntimeError(
Expand Down
88 changes: 59 additions & 29 deletions tests/unit/sagemaker/experiments/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
TEST_RUN_DISPLAY_NAME,
TEST_ARTIFACT_BUCKET,
TEST_ARTIFACT_PREFIX,
TEST_TAGS,
)


Expand Down Expand Up @@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):


@pytest.mark.parametrize(
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix", "expected_tags"),
[
({}, None, _DEFAULT_ARTIFACT_PREFIX),
({}, None, _DEFAULT_ARTIFACT_PREFIX, None),
(
{
"artifact_bucket": TEST_ARTIFACT_BUCKET,
"artifact_prefix": TEST_ARTIFACT_PREFIX,
"tags": TEST_TAGS,
},
TEST_ARTIFACT_BUCKET,
TEST_ARTIFACT_PREFIX,
TEST_TAGS,
),
],
)
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
@patch(
"sagemaker.experiments.run.Experiment._load_or_create",
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
)
@patch(
"sagemaker.experiments.run._Trial._load_or_create",
MagicMock(side_effect=mock_trial_load_or_create_func),
Expand All @@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job(
kwargs,
expected_artifact_bucket,
expected_artifact_prefix,
expected_tags,
):
client = sagemaker_session.sagemaker_client
job_name = "my-train-job"
Expand All @@ -213,26 +213,32 @@ def test_run_load_no_run_name_and_in_train_job(
{
"TrialComponent": {
"Parents": [
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
{
"ExperimentName": TEST_EXP_NAME,
"TrialName": exp_config[TRIAL_NAME],
}
],
"TrialComponentName": expected_tc_name,
}
}
]
}
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
assert run_obj._in_load
assert not run_obj._inside_init_context
assert run_obj._inside_load_context
assert run_obj.run_name == TEST_RUN_NAME
assert run_obj._trial_component.trial_component_name == expected_tc_name
assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
assert run_obj._trial
assert run_obj.experiment_name == TEST_EXP_NAME
assert run_obj._experiment
assert run_obj.experiment_config == exp_config
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
expmock = MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME, tags=expected_tags))
with patch("sagemaker.experiments.run.Experiment._load_or_create", expmock):
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
assert run_obj._in_load
assert not run_obj._inside_init_context
assert run_obj._inside_load_context
assert run_obj.run_name == TEST_RUN_NAME
assert run_obj._trial_component.trial_component_name == expected_tc_name
assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
assert run_obj._trial
assert run_obj.experiment_name == TEST_EXP_NAME
assert run_obj._experiment
assert run_obj.experiment_config == exp_config
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
assert run_obj._experiment.tags == expected_tags

client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
run_obj._trial.add_trial_component.assert_not_called()
Expand Down Expand Up @@ -265,7 +271,9 @@ def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session):
assert run_obj == run


def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session):
def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(
sagemaker_session,
):
with pytest.raises(RuntimeError) as err:
with load_run(sagemaker_session=sagemaker_session):
pass
Expand Down Expand Up @@ -388,7 +396,10 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
{
"TrialComponent": {
"Parents": [
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
{
"ExperimentName": TEST_EXP_NAME,
"TrialName": exp_config[TRIAL_NAME],
}
],
"TrialComponentName": expected_tc_name,
}
Expand Down Expand Up @@ -442,7 +453,10 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
{
"TrialComponent": {
"Parents": [
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
{
"ExperimentName": TEST_EXP_NAME,
"TrialName": exp_config[TRIAL_NAME],
}
],
"TrialComponentName": expected_tc_name,
}
Expand Down Expand Up @@ -589,7 +603,10 @@ def test_log_output_artifact_outside_run_context(run_obj):


def test_log_output_artifact(run_obj):
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
run_obj._artifact_uploader.upload_artifact.return_value = (
"s3uri_value",
"etag_value",
)
with run_obj:
run_obj.log_file("foo.txt", "name", "whizz/bang")
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
Expand All @@ -608,7 +625,10 @@ def test_log_input_artifact_outside_run_context(run_obj):


def test_log_input_artifact(run_obj):
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
run_obj._artifact_uploader.upload_artifact.return_value = (
"s3uri_value",
"etag_value",
)
with run_obj:
run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
Expand Down Expand Up @@ -653,7 +673,10 @@ def test_log_multiple_input_artifacts(run_obj):
"etag_value" + str(index),
)
run_obj.log_file(
file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False
file_path,
"name" + str(index),
"whizz/bang" + str(index),
is_output=False,
)
run_obj._artifact_uploader.upload_artifact.assert_called_with(
file_path, extra_args=None
Expand Down Expand Up @@ -757,7 +780,12 @@ def test_log_precision_recall_invalid_input(run_obj):
with run_obj:
with pytest.raises(ValueError) as error:
run_obj.log_precision_recall(
y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False
y_true,
y_scores,
0,
title="TestPrecisionRecall",
no_skill=no_skill,
is_output=False,
)
assert "Lengths mismatch between true labels and predicted probabilities" in str(error)

Expand Down Expand Up @@ -905,7 +933,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
display_name="C" + str(i),
source_arn="D" + str(i),
status=TrialComponentStatus(
primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
primary_status=_TrialComponentStatusType.InProgress.value,
message="E" + str(i),
),
start_time=start_time + datetime.timedelta(hours=i),
end_time=end_time + datetime.timedelta(hours=i),
Expand All @@ -925,7 +954,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
display_name="C" + str(i),
source_arn="D" + str(i),
status=TrialComponentStatus(
primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
primary_status=_TrialComponentStatusType.InProgress.value,
message="E" + str(i),
),
start_time=start_time + datetime.timedelta(hours=i),
end_time=end_time + datetime.timedelta(hours=i),
Expand Down

0 comments on commit c797f2d

Please sign in to comment.