Skip to content

Commit

Permalink
Refactored method
Browse files Browse the repository at this point in the history
  • Loading branch information
dewan-c committed Mar 25, 2024
1 parent 3065bfa commit 3a951e4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
11 changes: 6 additions & 5 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,8 @@ class HubModelDocument(JumpStartDataHolderType):

_non_serializable_slots = ["_region"]

TASK_REGEX_IN_STUDIO_DESCRIPTION = r"\| Task: \| (.+?)\|"

def __init__(
self,
region: str,
Expand Down Expand Up @@ -1917,9 +1919,9 @@ def from_manifest(self, studio_manifest_entry: Dict[str, Any]):
self.datatype = studio_manifest_entry["dataType"]
if studio_manifest_entry.get("license"):
self.license = studio_manifest_entry["license"]

Check warning on line 1921 in src/sagemaker/jumpstart/types.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/jumpstart/types.py#L1916-L1921

Added lines #L1916 - L1921 were not covered by tests
task_pattern = r"\| Task: \| (.+?)\|"

task_value = self._extract_task_value(

Check warning on line 1923 in src/sagemaker/jumpstart/types.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/jumpstart/types.py#L1923

Added line #L1923 was not covered by tests
studio_manifest_entry.get("description"), task_pattern
studio_manifest_entry.get("description")
)
if task_value:
self.task = task_value

Check warning on line 1927 in src/sagemaker/jumpstart/types.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/jumpstart/types.py#L1926-L1927

Added lines #L1926 - L1927 were not covered by tests
Expand Down Expand Up @@ -2090,16 +2092,15 @@ def from_specs(self, model_specs: JumpStartModelSpecs, studio_specs: Dict[str, A
"disable_output_compression"
)

def _extract_task_value(self, input_string: Optional[str], pattern: str) -> Optional[str]:
def _extract_task_value(self, input_string: Optional[str]) -> Optional[str]:
"""Returns value of Task field from Studio manifest's description field.
Args:
input_string (Optional[str]): The value of description field.
pattern (str): The regex pattern to use for searching.
"""
if not input_string:
return None
match = re.search(pattern, input_string)
match = re.search(self.TASK_REGEX_IN_STUDIO_DESCRIPTION, input_string)
if match:
return match.group(1).strip()
else:
Expand Down
12 changes: 4 additions & 8 deletions tests/unit/sagemaker/jumpstart/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,9 +967,8 @@ def test_extract_task_value_with_match():
json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region
)
input_string = "| | |\n|---|---|\n||\n| Task: | Text to image|\n| Fine-tunable: | No|\n| Source: | Stability AI"
pattern = r"\| Task: \| (.+?)\|"
expected_output = "Text to image"
assert gemma_model_document._extract_task_value(input_string, pattern) == expected_output
assert gemma_model_document._extract_task_value(input_string) == expected_output


def test_extract_task_value_without_match():
Expand All @@ -978,8 +977,7 @@ def test_extract_task_value_without_match():
json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region
)
input_string = "| | |\n|---|---|\n||\n| Fine-tunable: | No|\n| Source: | Stability AI"
pattern = r"\| Task: \| (.+?)\|"
assert gemma_model_document._extract_task_value(input_string, pattern) is None
assert gemma_model_document._extract_task_value(input_string) is None


def test_extract_task_value_with_none_input():
Expand All @@ -988,8 +986,7 @@ def test_extract_task_value_with_none_input():
json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region
)
input_string = None
pattern = r"\| Task: \| (.+?)\|"
assert gemma_model_document._extract_task_value(input_string, pattern) is None
assert gemma_model_document._extract_task_value(input_string) is None


def test_extract_task_value_with_empty_string():
Expand All @@ -998,8 +995,7 @@ def test_extract_task_value_with_empty_string():
json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region
)
input_string = ""
pattern = r"\| Task: \| (.+?)\|"
assert gemma_model_document._extract_task_value(input_string, pattern) is None
assert gemma_model_document._extract_task_value(input_string) is None


def test_hub_content_document_from_json_obj():
Expand Down

0 comments on commit 3a951e4

Please sign in to comment.