Skip to content

Commit

Permalink
feat: private util for model eula key
Browse files Browse the repository at this point in the history
  • Loading branch information
evakravi committed Dec 29, 2023
1 parent 7f7fa94 commit e2daefc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/sagemaker/jumpstart/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,29 @@ def get_model_url(
s3_client=sagemaker_session.s3_client,
)
return model_specs.url


def _get_model_eula_key(
model_id: str,
model_version: str,
region: str = JUMPSTART_DEFAULT_REGION_NAME,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Retrieve S3 key for EULA text for gated models, or None for non-gated models.
Args:
model_id (str): The model ID for which to retrieve the EULA S3 key.
model_version (str): The model version for which to retrieve the EULA S3 key.
region (str): Optional. The region from which to retrieve metadata.
(Default: JUMPSTART_DEFAULT_REGION_NAME)
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
to retrieve the EULA S3 key.
"""

model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
region=region,
model_id=model_id,
version=model_version,
s3_client=sagemaker_session.s3_client,
)
return model_specs.hosting_eula_key
34 changes: 34 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
get_header_from_base_header,
get_prototype_manifest,
get_prototype_model_spec,
get_special_model_spec,
)
from sagemaker.jumpstart.notebook_utils import (
_generate_jumpstart_model_versions,
_get_model_eula_key,
get_model_url,
list_jumpstart_frameworks,
list_jumpstart_models,
Expand Down Expand Up @@ -698,3 +700,35 @@ def test_get_model_url(
region=region,
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
)


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test__get_model_eula_key(
patched_get_model_specs: Mock,
):

patched_get_model_specs.side_effect = get_special_model_spec

model_id, version = "gated_llama_neuron_model", "*"
assert "fmhMetadata/eula/llamaEula.txt" == _get_model_eula_key(model_id, version)

model_id, version = "variant-model", "1.0.0"
assert None == _get_model_eula_key(model_id, version)

region = "fake-region"

patched_get_model_specs.reset_mock()
patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_special_model_spec(
*largs,
region="us-west-2",
**{key: value for key, value in kwargs.items() if key != "region"},
)

_get_model_eula_key(model_id, version, region=region)

patched_get_model_specs.assert_called_once_with(
model_id=model_id,
version=version,
region=region,
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
)

0 comments on commit e2daefc

Please sign in to comment.