diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 85a041379a..9df744531e 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -262,6 +262,15 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin return sorted(list(scripts)) +def _is_valid_version(version: str) -> bool: + """Checks if the version is convertable to Version class.""" + try: + Version(version) + return True + except Exception: # pylint: disable=broad-except + return False + + def list_jumpstart_models( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: Optional[str] = None, @@ -304,7 +313,8 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin ): if model_id not in model_id_version_dict: model_id_version_dict[model_id] = list() - model_id_version_dict[model_id].append(Version(version)) + model_version = Version(version) if _is_valid_version(version) else version + model_id_version_dict[model_id].append(model_version) if not list_versions: return sorted(list(model_id_version_dict.keys())) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index d7c4eb4921..1ea08724b9 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7577,6 +7577,20 @@ "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", "search_keywords": ["Text2Text", "Generation"], }, + { + "model_id": "ai21-paraphrase", + "version": "v1.00-rc2-not-valid-version", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "nc-soft-model-1", + "version": "v3.0-not-valid-version!", + "min_version": "2.0.0", + "spec_key": "proprietary-models/nc-soft-model-1/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, ] BASE_PROPRIETARY_SPEC = { diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 059cd7ccad..862d2b4174 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -25,6 +25,7 @@ list_jumpstart_models, list_jumpstart_scripts, list_jumpstart_tasks, + _is_valid_version, ) @@ -185,6 +186,13 @@ def test_list_jumpstart_frameworks( patched_get_model_specs.assert_not_called() +def test_is_valid_version(): + valid_version_strs = ["1.0", "1.0.0", "2012.4", "1!1.0", "1.dev0", "1.2.3+abc.dev1"] + invalid_version_strs = ["1.1.053_m", "invalid version", "v1-1.0-v2", "@"] + assert all(_is_valid_version(v) for v in valid_version_strs) + assert not any(_is_valid_version(v) for v in invalid_version_strs) + + class ListJumpStartModels(TestCase): @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -626,6 +634,7 @@ def test_list_jumpstart_proprietary_models( "ai21-paraphrase", "ai21-summarization", "lighton-mini-instruct40b", + "nc-soft-model-1", ] all_open_weight_model_ids = [