Skip to content

Commit

Permalink
bad rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
bencrabtree committed Mar 13, 2024
1 parent da1b642 commit 709bedc
Show file tree
Hide file tree
Showing 13 changed files with 1 addition and 38 deletions.
1 change: 0 additions & 1 deletion src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ def get_model_specs(
hub_arn: Optional[str] = None,
s3_client: Optional[boto3.client] = None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
) -> JumpStartModelSpecs:
"""Returns model specs from JumpStart models cache.
Expand Down
1 change: 0 additions & 1 deletion src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
DescribeHubContentsResponse,
HubType,
HubContentType,
HubDataType,
)
from sagemaker.jumpstart.curated_hub import utils as hub_utils
from sagemaker.jumpstart.enums import JumpStartModelType
Expand Down
1 change: 0 additions & 1 deletion src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,6 @@ def _validate_model_id_and_get_type_hook():
model_version=model_version,
hub_arn=hub_arn,
model_type=self.model_type,
hub_arn=hub_arn,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
role=role,
Expand Down
2 changes: 0 additions & 2 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def get_init_kwargs(
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: Optional[bool] = None,
tolerate_deprecated_model: Optional[bool] = None,
region: Optional[str] = None,
Expand Down Expand Up @@ -141,7 +140,6 @@ def get_init_kwargs(
model_version=model_version,
hub_arn=hub_arn,
model_type=model_type,
hub_arn=hub_arn,
role=role,
region=region,
instance_count=instance_count,
Expand Down
4 changes: 0 additions & 4 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,6 @@ def get_deploy_kwargs(
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
initial_instance_count: Optional[int] = None,
instance_type: Optional[str] = None,
Expand Down Expand Up @@ -585,7 +584,6 @@ def get_deploy_kwargs(
model_version=model_version,
hub_arn=hub_arn,
model_type=model_type,
hub_arn=hub_arn,
region=region,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
Expand Down Expand Up @@ -726,7 +724,6 @@ def get_init_kwargs(
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: Optional[bool] = None,
tolerate_deprecated_model: Optional[bool] = None,
instance_type: Optional[str] = None,
Expand Down Expand Up @@ -760,7 +757,6 @@ def get_init_kwargs(
model_version=model_version,
hub_arn=hub_arn,
model_type=model_type,
hub_arn=hub_arn,
instance_type=instance_type,
region=region,
image_uri=image_uri,
Expand Down
13 changes: 1 addition & 12 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from sagemaker.session import Session
from sagemaker.utils import get_instance_type_family, format_tags, Tags
from sagemaker.enums import EndpointType
from sagemaker.model_metrics import ModelMetrics
Expand Down Expand Up @@ -1290,7 +1291,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"instance_type",
"tolerate_vulnerable_model",
"tolerate_deprecated_model",
Expand Down Expand Up @@ -1323,7 +1323,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"tolerate_vulnerable_model",
"tolerate_deprecated_model",
"region",
Expand All @@ -1337,7 +1336,6 @@ def __init__(
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
instance_type: Optional[str] = None,
image_uri: Optional[Union[str, Any]] = None,
Expand Down Expand Up @@ -1369,7 +1367,6 @@ def __init__(
self.model_version = model_version
self.hub_arn = hub_arn
self.model_type = model_type
self.hub_arn = hub_arn
self.instance_type = instance_type
self.region = region
self.image_uri = image_uri
Expand Down Expand Up @@ -1404,7 +1401,6 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"initial_instance_count",
"instance_type",
"region",
Expand Down Expand Up @@ -1436,7 +1432,6 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
SERIALIZATION_EXCLUSION_SET = {
"model_id",
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"region",
Expand Down Expand Up @@ -1485,7 +1480,6 @@ def __init__(
self.model_version = model_version
self.hub_arn = hub_arn
self.model_type = model_type
self.hub_arn = hub_arn
self.initial_instance_count = initial_instance_count
self.instance_type = instance_type
self.region = region
Expand Down Expand Up @@ -1522,7 +1516,6 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"instance_type",
"instance_count",
"region",
Expand Down Expand Up @@ -1584,7 +1577,6 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
}

def __init__(
Expand Down Expand Up @@ -1714,7 +1706,6 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"region",
"inputs",
"wait",
Expand All @@ -1731,7 +1722,6 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"region",
"tolerate_deprecated_model",
"tolerate_vulnerable_model",
Expand Down Expand Up @@ -1760,7 +1750,6 @@ def __init__(
self.model_version = model_version
self.hub_arn = hub_arn
self.model_type = model_type
self.hub_arn = hub_arn
self.region = region
self.inputs = inputs
self.wait = wait
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,6 @@ def add_options_to_hyperparameter(*largs, **kwargs):
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)

patched_get_model_specs.reset_mock()
Expand Down Expand Up @@ -517,7 +516,6 @@ def test_jumpstart_validate_all_hyperparameters(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)

patched_get_model_specs.reset_mock()
Expand Down
4 changes: 0 additions & 4 deletions tests/unit/sagemaker/image_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def test_jumpstart_common_image_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -79,7 +78,6 @@ def test_jumpstart_common_image_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -102,7 +100,6 @@ def test_jumpstart_common_image_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -125,7 +122,6 @@ def test_jumpstart_common_image_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand Down
1 change: 0 additions & 1 deletion tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from sagemaker.jumpstart.types import HubArnExtractedInfo
from sagemaker.jumpstart.curated_hub import utils
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo


def test_get_info_from_hub_resource_arn():
Expand Down
1 change: 0 additions & 1 deletion tests/unit/sagemaker/jumpstart/test_notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,5 +751,4 @@ def test_get_model_url(
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
4 changes: 0 additions & 4 deletions tests/unit/sagemaker/model_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def test_jumpstart_common_model_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -74,7 +73,6 @@ def test_jumpstart_common_model_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -95,7 +93,6 @@ def test_jumpstart_common_model_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -116,7 +113,6 @@ def test_jumpstart_common_model_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def test_jumpstart_resource_requirements(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_get_model_specs.reset_mock()

Expand Down
4 changes: 0 additions & 4 deletions tests/unit/sagemaker/script_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_jumpstart_common_script_uri(
model_id="pytorch-ic-mobilenet-v2",
version="*",
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
Expand All @@ -74,7 +73,6 @@ def test_jumpstart_common_script_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -95,7 +93,6 @@ def test_jumpstart_common_script_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -116,7 +113,6 @@ def test_jumpstart_common_script_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand Down

0 comments on commit 709bedc

Please sign in to comment.