diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 22a15c0570..60f28d3b0c 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -430,6 +430,7 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name): output_data_config["S3OutputPath"], job_name, self.sagemaker_session, + prefix="output", ) _delete_tree(model_artifacts) diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 298c95acb6..16375de7d4 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -53,7 +53,7 @@ def copy_directory_structure(destination_directory, relative_path): os.makedirs(destination_directory, relative_path) -def move_to_destination(source, destination, job_name, sagemaker_session): +def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""): """Move source to destination. Can handle uploading to S3. @@ -64,6 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session): job_name (str): SageMaker job name. sagemaker_session (sagemaker.Session): a sagemaker_session to interact with S3 if needed + prefix (str, optional): the directory on S3 used to save files, default + to the root of ``destination`` Returns: (str): destination URI @@ -75,7 +77,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session): final_uri = destination elif parsed_uri.scheme == "s3": bucket = parsed_uri.netloc - path = s3.s3_path_join(parsed_uri.path, job_name) + path = s3.s3_path_join(parsed_uri.path, job_name, prefix) final_uri = s3.s3_path_join("s3://", bucket, path) sagemaker_session.upload_data(source, bucket, path) else: diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index 2db8c83351..39b9e2b392 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -66,6 +66,18 @@ def test_move_to_destination_s3(recursive_copy): sms.upload_data.assert_called_with("/tmp/data", "bucket", "job") +@patch("shutil.rmtree", Mock()) +def test_move_to_destination_s3_with_prefix(): + sms = Mock( + settings=SessionSettings(), + ) + uri = sagemaker.local.utils.move_to_destination( + "/tmp/data", "s3://bucket/path", "job", sms, "foo_prefix" + ) + sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job/foo_prefix") + assert uri == "s3://bucket/path/job/foo_prefix" + + def test_move_to_destination_illegal_destination(): with pytest.raises(ValueError): sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None)