Skip to content

Commit

Permalink
feat: parallelize notebook search utils, add new operators (#4342)
Browse files Browse the repository at this point in the history
* feat: parallelize notebook search utils

* chore: raise exception in notebook utils if thread has error

* chore: improve variable name

* fix: not passing region to get jumpstart bucket

* chore: add sagemaker session to notebook utils

* chore: address PR comments

* feat: add support for includes, begins with, ends with

* fix: pylint

* feat: private util for model eula key

* fix: unit tests, use verify_model_region_and_return_specs in notebook utils

* Revert "feat: private util for model eula key"

This reverts commit e2daefc.

* chore: add search keywords to header
  • Loading branch information
evakravi authored Jan 10, 2024
1 parent ae50026 commit 80b3e08
Show file tree
Hide file tree
Showing 5 changed files with 464 additions and 264 deletions.
164 changes: 132 additions & 32 deletions src/sagemaker/jumpstart/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import
from ast import literal_eval
from enum import Enum
from typing import Dict, List, Union, Any
from typing import Dict, List, Optional, Union, Any

from sagemaker.jumpstart.types import JumpStartDataHolderType

Expand All @@ -38,6 +38,10 @@ class FilterOperators(str, Enum):
NOT_EQUALS = "not_equals"
IN = "in"
NOT_IN = "not_in"
INCLUDES = "includes"
NOT_INCLUDES = "not_includes"
BEGINS_WITH = "begins_with"
ENDS_WITH = "ends_with"


class SpecialSupportedFilterKeys(str, Enum):
Expand All @@ -52,6 +56,10 @@ class SpecialSupportedFilterKeys(str, Enum):
FilterOperators.NOT_EQUALS: ["!==", "!=", "not equals", "is not"],
FilterOperators.IN: ["in"],
FilterOperators.NOT_IN: ["not in"],
FilterOperators.INCLUDES: ["includes", "contains"],
FilterOperators.NOT_INCLUDES: ["not includes", "not contains"],
FilterOperators.BEGINS_WITH: ["begins with", "starts with"],
FilterOperators.ENDS_WITH: ["ends with"],
}


Expand All @@ -62,7 +70,19 @@ class SpecialSupportedFilterKeys(str, Enum):
)

ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = (
list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]))
list(
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH])
)
+ list(
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH])
)
+ list(
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES])
)
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]))
+ list(
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS])
)
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]))
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]))
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]))
Expand Down Expand Up @@ -428,9 +448,96 @@ def parse_filter_string(filter_string: str) -> ModelFilter:
raise ValueError(f"Cannot parse filter string: {filter_string}")


def _negate_boolean(boolean: BooleanValues) -> BooleanValues:
    """Returns the logical complement of a boolean expression value.

    ``BooleanValues.TRUE`` becomes ``BooleanValues.FALSE`` and vice versa.
    Any other value is handed back unchanged.
    """
    # Default to the input so that values other than TRUE/FALSE pass through.
    flipped = boolean
    if boolean == BooleanValues.TRUE:
        flipped = BooleanValues.FALSE
    elif boolean == BooleanValues.FALSE:
        flipped = BooleanValues.TRUE
    return flipped


def _evaluate_filter_expression_equals(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for equals.

    The comparison is made on the string form of both operands. When the
    cached value is a ``bool``, both sides are lower-cased first so that a
    filter value such as ``"True"`` matches a cached ``True``.

    Returns ``BooleanValues.FALSE`` when ``cached_model_value`` is ``None``.
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    expected = model_filter.value
    actual = cached_model_value
    if isinstance(actual, bool):
        # str(True) is "True"; normalise case on both sides before comparing.
        actual = str(actual).lower()
        expected = model_filter.value.lower()

    return BooleanValues.TRUE if str(expected) == str(actual) else BooleanValues.FALSE


def _evaluate_filter_expression_in(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for string/list in.

    ``model_filter.value`` is first parsed as a Python literal. If parsing
    fails, the raw value itself is used as the container, so a plain string
    filter value results in a substring test.

    Returns ``BooleanValues.TRUE`` only when ``cached_model_value`` is a member
    of that container. Returns ``BooleanValues.FALSE`` when the cached value is
    ``None`` or a list, when the parsed literal is not iterable, or when the
    membership test cannot be performed for the operand types involved.
    """
    # A list-valued cached field never matches; checking this up front gives
    # the same result as checking it after parsing, without the parsing work.
    if cached_model_value is None or isinstance(cached_model_value, list):
        return BooleanValues.FALSE

    container = model_filter.value
    try:
        container = literal_eval(container)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a Python literal: deliberately fall back to the raw filter value.
        pass
    else:
        try:
            iter(container)
        except TypeError:
            # Parsed to a scalar (e.g. ``5``); nothing can be "in" it.
            return BooleanValues.FALSE

    try:
        is_member = cached_model_value in container
    except TypeError:
        # E.g. a non-string cached value tested against a string container, or
        # an unhashable cached value tested against a set: treat as "not in"
        # rather than crashing the whole evaluation.
        return BooleanValues.FALSE

    return BooleanValues.TRUE if is_member else BooleanValues.FALSE


def _evaluate_filter_expression_includes(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for string includes.

    The filter value is converted to ``str`` and tested with ``in`` against the
    cached value: a substring test for strings, element membership for lists
    and key membership for dicts.

    Returns ``BooleanValues.FALSE`` when ``cached_model_value`` is ``None`` or
    is a type that does not support membership tests (e.g. ``int``, ``float``,
    ``bool``).
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    filter_value = str(model_filter.value)
    try:
        is_included = filter_value in cached_model_value
    except TypeError:
        # Scalars such as int/float/bool are not containers; they cannot
        # include anything, so report a non-match instead of raising.
        return BooleanValues.FALSE

    return BooleanValues.TRUE if is_included else BooleanValues.FALSE


def _evaluate_filter_expression_begins_with(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for string begins with.

    Returns ``BooleanValues.TRUE`` when ``cached_model_value`` is a string that
    starts with the string form of ``model_filter.value``.

    Returns ``BooleanValues.FALSE`` when the cached value is ``None`` or is not
    a string, since a prefix test is only defined for strings here.
    """
    # Non-string values (None, numbers, bools, dicts, lists) have no
    # ``startswith``; treat them as a non-match rather than raising.
    if not isinstance(cached_model_value, str):
        return BooleanValues.FALSE

    prefix = str(model_filter.value)
    if cached_model_value.startswith(prefix):
        return BooleanValues.TRUE
    return BooleanValues.FALSE


def _evaluate_filter_expression_ends_with(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for string ends with.

    Returns ``BooleanValues.TRUE`` when ``cached_model_value`` is a string that
    ends with the string form of ``model_filter.value``.

    Returns ``BooleanValues.FALSE`` when the cached value is ``None`` or is not
    a string, since a suffix test is only defined for strings here.
    """
    # Non-string values (None, numbers, bools, dicts, lists) have no
    # ``endswith``; treat them as a non-match rather than raising.
    if not isinstance(cached_model_value, str):
        return BooleanValues.FALSE

    suffix = str(model_filter.value)
    if cached_model_value.endswith(suffix):
        return BooleanValues.TRUE
    return BooleanValues.FALSE


def evaluate_filter_expression( # pylint: disable=too-many-return-statements
model_filter: ModelFilter,
cached_model_value: Union[str, bool, int, float, Dict[str, Any], List[Any]],
cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
"""Evaluates model filter with cached model spec value, returns boolean.
Expand All @@ -440,36 +547,29 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements
evaluate the filter.
"""
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]:
model_filter_value = model_filter.value
if isinstance(cached_model_value, bool):
cached_model_value = str(cached_model_value).lower()
model_filter_value = model_filter.value.lower()
if str(model_filter_value) == str(cached_model_value):
return BooleanValues.TRUE
return BooleanValues.FALSE
return _evaluate_filter_expression_equals(model_filter, cached_model_value)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]:
if isinstance(cached_model_value, bool):
cached_model_value = str(cached_model_value).lower()
model_filter.value = model_filter.value.lower()
if str(model_filter.value) == str(cached_model_value):
return BooleanValues.FALSE
return BooleanValues.TRUE
return _negate_boolean(_evaluate_filter_expression_equals(model_filter, cached_model_value))

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]:
py_obj = literal_eval(model_filter.value)
try:
iter(py_obj)
except TypeError:
return BooleanValues.FALSE
if cached_model_value in py_obj:
return BooleanValues.TRUE
return BooleanValues.FALSE
return _evaluate_filter_expression_in(model_filter, cached_model_value)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]:
py_obj = literal_eval(model_filter.value)
try:
iter(py_obj)
except TypeError:
return BooleanValues.TRUE
if cached_model_value in py_obj:
return BooleanValues.FALSE
return BooleanValues.TRUE
return _negate_boolean(_evaluate_filter_expression_in(model_filter, cached_model_value))

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]:
return _evaluate_filter_expression_includes(model_filter, cached_model_value)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES]:
return _negate_boolean(
_evaluate_filter_expression_includes(model_filter, cached_model_value)
)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH]:
return _evaluate_filter_expression_begins_with(model_filter, cached_model_value)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH]:
return _evaluate_filter_expression_ends_with(model_filter, cached_model_value)

raise RuntimeError(f"Bad operator: {model_filter.operator}")
Loading

0 comments on commit 80b3e08

Please sign in to comment.