Skip to content

Commit

Permalink
change: update model path in local mode (aws#4296)
Browse files Browse the repository at this point in the history
* Update model path in local mode

* Add test
  • Loading branch information
trungleduc authored Dec 22, 2023
1 parent 1d6ba0e commit 9643e97
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name):
output_data_config["S3OutputPath"],
job_name,
self.sagemaker_session,
prefix="output",
)

_delete_tree(model_artifacts)
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def copy_directory_structure(destination_directory, relative_path):
os.makedirs(destination_directory, relative_path)


def move_to_destination(source, destination, job_name, sagemaker_session):
def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
"""Move source to destination.
Can handle uploading to S3.
Expand All @@ -64,6 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
job_name (str): SageMaker job name.
sagemaker_session (sagemaker.Session): a sagemaker_session to interact
with S3 if needed
prefix (str, optional): the directory on S3 used to save files, default
to the root of ``destination``
Returns:
(str): destination URI
Expand All @@ -75,7 +77,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
final_uri = destination
elif parsed_uri.scheme == "s3":
bucket = parsed_uri.netloc
path = s3.s3_path_join(parsed_uri.path, job_name)
path = s3.s3_path_join(parsed_uri.path, job_name, prefix)
final_uri = s3.s3_path_join("s3://", bucket, path)
sagemaker_session.upload_data(source, bucket, path)
else:
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/sagemaker/local/test_local_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ def test_move_to_destination_s3(recursive_copy):
sms.upload_data.assert_called_with("/tmp/data", "bucket", "job")


@patch("shutil.rmtree", Mock())
def test_move_to_destination_s3_with_prefix():
sms = Mock(
settings=SessionSettings(),
)
uri = sagemaker.local.utils.move_to_destination(
"/tmp/data", "s3://bucket/path", "job", sms, "foo_prefix"
)
sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job/foo_prefix")
assert uri == "s3://bucket/path/job/foo_prefix"


def test_move_to_destination_illegal_destination():
with pytest.raises(ValueError):
sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None)
Expand Down

0 comments on commit 9643e97

Please sign in to comment.