Skip to content

Commit

Permalink
change: Upgrade smp to version 2.2 (aws#4479)
Browse files: browse the repository at this point in the history
* upgrading smp to version 2.2

* fixing linting issue

* fixing syntax error with multiline if statement

* upgrading smp to version 2.2

* fixing linting issue

* fixing syntax error with multiline if statement

* fixing formatting

---------

Co-authored-by: Andrew Tian <tinandr@amazon.com>
  • Loading branch information
2 people authored and jiapinw committed Jun 25, 2024
1 parent c63c268 commit 90aa33b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 4 deletions.
10 changes: 9 additions & 1 deletion src/sagemaker/fw_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -141,6 +141,7 @@
"2.0.1",
"2.1.0",
"2.1.2",
"2.2.0",
],
}

Expand All @@ -160,7 +161,14 @@
]


TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.2"]
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
"1.13.1",
"2.0.0",
"2.0.1",
"2.1.0",
"2.1.2",
"2.2.0",
]

TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
Expand Down
28 changes: 27 additions & 1 deletion src/sagemaker/image_uri_config/pytorch-smp.json
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,8 @@
],
"version_aliases": {
"2.0": "2.0.1",
"2.1": "2.1.2"
"2.1": "2.1.2",
"2.2": "2.2.0"
},
"versions": {
"2.0.1": {
Expand Down Expand Up @@ -57,6 +58,31 @@
"us-west-2": "658645717510"
},
"repository": "smdistributed-modelparallel"
},
"2.2.0": {
"py_versions": [
"py310"
],
"registries": {
"ap-northeast-1": "658645717510",
"ap-northeast-2": "658645717510",
"ap-northeast-3": "658645717510",
"ap-south-1": "658645717510",
"ap-southeast-1": "658645717510",
"ap-southeast-2": "658645717510",
"ca-central-1": "658645717510",
"eu-central-1": "658645717510",
"eu-north-1": "658645717510",
"eu-west-1": "658645717510",
"eu-west-2": "658645717510",
"eu-west-3": "658645717510",
"sa-east-1": "658645717510",
"us-east-1": "658645717510",
"us-east-2": "658645717510",
"us-west-1": "658645717510",
"us-west-2": "658645717510"
},
"repository": "smdistributed-modelparallel"
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/image_uris.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -678,7 +678,11 @@ def get_training_image_uri(
if "modelparallel" in distribution["smdistributed"]:
if distribution["smdistributed"]["modelparallel"].get("enabled", True):
framework = "pytorch-smp"
if "p5" in instance_type or "2.1" in framework_version:
if (
"p5" in instance_type
or "2.1" in framework_version
or "2.2" in framework_version
):
container_version = "cu121"
else:
container_version = "cu118"
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/image_uris/test_smp_v2.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_smp_v2(load_config):
for region in ACCOUNTS.keys():
for instance_type in CONTAINER_VERSIONS.keys():
cuda_vers = CONTAINER_VERSIONS[instance_type]
if "2.1" in version:
if "2.1" in version or "2.2" in version:
cuda_vers = "cu121"

uri = image_uris.get_training_image_uri(
Expand Down

0 comments on commit 90aa33b

Please sign in to comment.