diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py index fdd8f9cdbc..8fbffe3667 100644 --- a/src/sagemaker/experiments/run.py +++ b/src/sagemaker/experiments/run.py @@ -210,7 +210,9 @@ def __init__( ) if not _TrialComponent._trial_component_is_associated_to_trial( - self._trial_component.trial_component_name, self._trial.trial_name, sagemaker_session + self._trial_component.trial_component_name, + self._trial.trial_name, + sagemaker_session, ): self._trial.add_trial_component(self._trial_component) @@ -781,6 +783,7 @@ def load_run( sagemaker_session: Optional["Session"] = None, artifact_bucket: Optional[str] = None, artifact_prefix: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, ) -> Run: """Load an existing run. @@ -849,6 +852,8 @@ def load_run( will be used. artifact_prefix (str): The S3 key prefix used to generate the S3 path to upload the artifact to (default: "trial-component-artifacts"). + tags (List[Dict[str, str]]): A list of tags to be used for all create calls, + e.g. to create an experiment, a run group, etc. (default: None). Returns: Run: The loaded Run object. @@ -870,6 +875,7 @@ def load_run( sagemaker_session=sagemaker_session or _utils.default_session(), artifact_bucket=artifact_bucket, artifact_prefix=artifact_prefix, + tags=tags, ) elif _RunContext.get_current_run(): run_instance = _RunContext.get_current_run() @@ -889,6 +895,7 @@ def load_run( sagemaker_session=sagemaker_session or _utils.default_session(), artifact_bucket=artifact_bucket, artifact_prefix=artifact_prefix, + tags=tags, ) else: raise RuntimeError( diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py index 68326a19af..2bebbe3d9c 100644 --- a/tests/unit/sagemaker/experiments/test_run.py +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -55,6 +55,7 @@ TEST_RUN_DISPLAY_NAME, TEST_ARTIFACT_BUCKET, TEST_ARTIFACT_PREFIX, + TEST_TAGS, ) @@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session): @pytest.mark.parametrize( - ("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"), + ("kwargs", "expected_artifact_bucket", "expected_artifact_prefix", "expected_tags"), [ - ({}, None, _DEFAULT_ARTIFACT_PREFIX), + ({}, None, _DEFAULT_ARTIFACT_PREFIX, None), ( { "artifact_bucket": TEST_ARTIFACT_BUCKET, "artifact_prefix": TEST_ARTIFACT_PREFIX, + "tags": TEST_TAGS, }, TEST_ARTIFACT_BUCKET, TEST_ARTIFACT_PREFIX, + TEST_TAGS, ), ], ) @patch.object(_TrialComponent, "save", MagicMock(return_value=None)) -@patch( - "sagemaker.experiments.run.Experiment._load_or_create", - MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), -) @patch( "sagemaker.experiments.run._Trial._load_or_create", MagicMock(side_effect=mock_trial_load_or_create_func), @@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job( kwargs, expected_artifact_bucket, expected_artifact_prefix, + expected_tags, ): client = sagemaker_session.sagemaker_client job_name = "my-train-job" @@ -213,26 +213,32 @@ def test_run_load_no_run_name_and_in_train_job( { "TrialComponent": { "Parents": [ - {"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]} + { + "ExperimentName": TEST_EXP_NAME, + "TrialName": exp_config[TRIAL_NAME], + } ], "TrialComponentName": expected_tc_name, } } ] } - with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj: - assert run_obj._in_load - assert not run_obj._inside_init_context - assert run_obj._inside_load_context - assert run_obj.run_name == TEST_RUN_NAME - assert run_obj._trial_component.trial_component_name == expected_tc_name - assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) - assert run_obj._trial - assert run_obj.experiment_name == TEST_EXP_NAME - assert run_obj._experiment - assert run_obj.experiment_config == exp_config - assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket - assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix + expmock = MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME, tags=expected_tags)) + with patch("sagemaker.experiments.run.Experiment._load_or_create", expmock): + with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj: + assert run_obj._in_load + assert not run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj.run_name == TEST_RUN_NAME + assert run_obj._trial_component.trial_component_name == expected_tc_name + assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) + assert run_obj._trial + assert run_obj.experiment_name == TEST_EXP_NAME + assert run_obj._experiment + assert run_obj.experiment_config == exp_config + assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket + assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix + assert run_obj._experiment.tags == expected_tags client.describe_training_job.assert_called_once_with(TrainingJobName=job_name) run_obj._trial.add_trial_component.assert_not_called() @@ -265,7 +271,9 @@ def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session): assert run_obj == run -def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session): +def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context( + sagemaker_session, +): with pytest.raises(RuntimeError) as err: with load_run(sagemaker_session=sagemaker_session): pass @@ -388,7 +396,10 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session): { "TrialComponent": { "Parents": [ - {"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]} + { + "ExperimentName": TEST_EXP_NAME, + "TrialName": exp_config[TRIAL_NAME], + } ], "TrialComponentName": expected_tc_name, } @@ -442,7 +453,10 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session): { "TrialComponent": { "Parents": [ - {"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]} + { + "ExperimentName": TEST_EXP_NAME, + "TrialName": exp_config[TRIAL_NAME], + } ], "TrialComponentName": expected_tc_name, } @@ -589,7 +603,10 @@ def test_log_output_artifact_outside_run_context(run_obj): def test_log_output_artifact(run_obj): - run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) with run_obj: run_obj.log_file("foo.txt", "name", "whizz/bang") run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None) @@ -608,7 +625,10 @@ def test_log_input_artifact_outside_run_context(run_obj): def test_log_input_artifact(run_obj): - run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) with run_obj: run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None) @@ -653,7 +673,10 @@ def test_log_multiple_input_artifacts(run_obj): "etag_value" + str(index), ) run_obj.log_file( - file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False + file_path, + "name" + str(index), + "whizz/bang" + str(index), + is_output=False, ) run_obj._artifact_uploader.upload_artifact.assert_called_with( file_path, extra_args=None @@ -757,7 +780,12 @@ def test_log_precision_recall_invalid_input(run_obj): with run_obj: with pytest.raises(ValueError) as error: run_obj.log_precision_recall( - y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False + y_true, + y_scores, + 0, + title="TestPrecisionRecall", + no_skill=no_skill, + is_output=False, ) assert "Lengths mismatch between true labels and predicted probabilities" in str(error) @@ -905,7 +933,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses display_name="C" + str(i), source_arn="D" + str(i), status=TrialComponentStatus( - primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + primary_status=_TrialComponentStatusType.InProgress.value, + message="E" + str(i), ), start_time=start_time + datetime.timedelta(hours=i), end_time=end_time + datetime.timedelta(hours=i), @@ -925,7 +954,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses display_name="C" + str(i), source_arn="D" + str(i), status=TrialComponentStatus( - primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + primary_status=_TrialComponentStatusType.InProgress.value, + message="E" + str(i), ), start_time=start_time + datetime.timedelta(hours=i), end_time=end_time + datetime.timedelta(hours=i),