diff --git a/eppo_client/client.py b/eppo_client/client.py index c1cde48..f266641 100644 --- a/eppo_client/client.py +++ b/eppo_client/client.py @@ -256,18 +256,20 @@ def get_bandit_action( Args: flag_key (str): The feature flag key that contains the bandit as one of the variations. subject_key (str): The key identifying the subject. - subject_context (ActionContexts | ActionAttributes): The subject context. + subject_context (Union[ContextAttributes, Attributes]): The subject context. If supplying an ActionAttributes, it gets converted to an ActionContexts instance - actions (ActionContexts | ActionAttributes): The dictionary that maps action keys + actions (Union[ActionContexts, ActionAttributes]): The dictionary that maps action keys to their context of actions with their contexts. If supplying an ActionAttributes, it gets converted to an ActionContexts instance. - default (str): The default variation to use if the subject is not part of the bandit. + default (str): The default variation to use if an error is encountered retrieving the + assigned variation. Returns: BanditResult: The result containing either the bandit action if the subject is part of the bandit, or the assignment if they are not. The BanditResult includes: - variation (str): The assignment key indicating the subject's variation. - - action (str): The key of the selected action if the subject is part of the bandit. + - action (Optional[str]): The key of the selected action if the subject was assigned one + by the bandit. Example: result = client.get_bandit_action( @@ -286,102 +288,116 @@ def get_bandit_action( }, "default" ) - if result.action is None: - do_variation(result.variation) + if result.action: + do_action(result.variation) else: - do_action(result.action) + do_status_quo() """ + variation = default + action = None try: - return self.get_bandit_action_detail( + subject_attributes = convert_context_attributes_to_attributes( + subject_context + ) + + # first, get experiment assignment + variation = self.get_string_assignment( flag_key, subject_key, - subject_context, - actions, + subject_attributes, default, ) + + if variation in self.get_bandit_keys(): + # next, if assigned a bandit, get the selected action + action = self.evaluate_bandit_action( + flag_key, + variation, # for now, we assume the variation value is always equal to the bandit key + subject_key, + subject_context, + actions, + ) except Exception as e: if self.__is_graceful_mode: logger.error("[Eppo SDK] Error getting bandit action: " + str(e)) - return BanditResult(default, None) - raise e + else: + raise e + + return BanditResult(variation, action) - def get_bandit_action_detail( + def evaluate_bandit_action( self, flag_key: str, + bandit_key: str, subject_key: str, subject_context: Union[ContextAttributes, Attributes], actions: Union[ActionContexts, ActionAttributes], - default: str, - ) -> BanditResult: - subject_attributes = convert_subject_context_to_attributes(subject_context) - action_contexts = convert_actions_to_action_contexts(actions) + ) -> Union[str, None]: + # if no actions are given--a valid use case--return the variation with no action + if len(actions) == 0: + return None - # get experiment assignment - # ignoring type because Dict[str, str] satisfies Dict[str, str | ...] but mypy does not understand - variation = self.get_string_assignment( - flag_key, - subject_key, - subject_attributes.categorical_attributes - | subject_attributes.numeric_attributes, # type: ignore - default, - ) - - # if the variation is not the bandit key, then the subject is not allocated in the bandit - if variation not in self.get_bandit_keys(): - return BanditResult(variation, None) - - # for now, assume that the variation is equal to the bandit key - bandit_data = self.__config_requestor.get_bandit_model(variation) + bandit_data = self.__config_requestor.get_bandit_model(bandit_key) if not bandit_data: logger.warning( f"[Eppo SDK] No assigned action. Bandit not found for flag: {flag_key}" ) - return BanditResult(variation, None) + return None + + subject_context_attributes = convert_attributes_to_context_attributes( + subject_context + ) + action_contexts = convert_actions_to_action_contexts(actions) evaluation = self.__bandit_evaluator.evaluate_bandit( flag_key, subject_key, - subject_attributes, + subject_context_attributes, action_contexts, bandit_data.bandit_model_data, ) # log bandit action - bandit_event = { - "flagKey": flag_key, - "banditKey": bandit_data.bandit_key, - "subject": subject_key, - "action": evaluation.action_key if evaluation else None, - "actionProbability": evaluation.action_weight if evaluation else None, - "optimalityGap": evaluation.optimality_gap if evaluation else None, - "modelVersion": bandit_data.bandit_model_version if evaluation else None, - "timestamp": datetime.datetime.utcnow().isoformat(), - "subjectNumericAttributes": ( - subject_attributes.numeric_attributes - if evaluation.subject_attributes - else {} - ), - "subjectCategoricalAttributes": ( - subject_attributes.categorical_attributes - if evaluation.subject_attributes - else {} - ), - "actionNumericAttributes": ( - evaluation.action_attributes.numeric_attributes - if evaluation.action_attributes - else {} - ), - "actionCategoricalAttributes": ( - evaluation.action_attributes.categorical_attributes - if evaluation.action_attributes - else {} - ), - "metaData": {"sdkLanguage": "python", "sdkVersion": __version__}, - } - self.__assignment_logger.log_bandit_action(bandit_event) + try: + bandit_event = { + "flagKey": flag_key, + "banditKey": bandit_data.bandit_key, + "subject": subject_key, + "action": evaluation.action_key if evaluation else None, + "actionProbability": evaluation.action_weight if evaluation else None, + "optimalityGap": evaluation.optimality_gap if evaluation else None, + "modelVersion": ( + bandit_data.bandit_model_version if evaluation else None + ), + "timestamp": datetime.datetime.utcnow().isoformat(), + "subjectNumericAttributes": ( + subject_context_attributes.numeric_attributes + if evaluation.subject_attributes + else {} + ), + "subjectCategoricalAttributes": ( + subject_context_attributes.categorical_attributes + if evaluation.subject_attributes + else {} + ), + "actionNumericAttributes": ( + evaluation.action_attributes.numeric_attributes + if evaluation.action_attributes + else {} + ), + "actionCategoricalAttributes": ( + evaluation.action_attributes.categorical_attributes + if evaluation.action_attributes + else {} + ), + "metaData": {"sdkLanguage": "python", "sdkVersion": __version__}, + } + self.__assignment_logger.log_bandit_action(bandit_event) + except Exception as e: + logger.warn("[Eppo SDK] Error logging bandit event: " + str(e)) - return BanditResult(variation, evaluation.action_key if evaluation else None) + return evaluation.action_key def get_flag_keys(self): """ @@ -406,6 +422,9 @@ def get_bandit_keys(self): """ return self.__config_requestor.get_bandit_keys() + def set_is_graceful_mode(self, is_graceful_mode: bool): + self.__is_graceful_mode = is_graceful_mode + def is_initialized(self): """ Returns True if the client has successfully initialized @@ -443,18 +462,33 @@ def check_value_type_match( return False -def convert_subject_context_to_attributes( +def convert_context_attributes_to_attributes( + subject_context: Union[ContextAttributes, Attributes] +) -> Attributes: + if isinstance(subject_context, dict): + return subject_context + + # ignoring type because Dict[str, str] satisfies Dict[str, str | ...] but mypy does not understand + return subject_context.numeric_attributes | subject_context.categorical_attributes # type: ignore + + +def convert_attributes_to_context_attributes( subject_context: Union[ContextAttributes, Attributes] ) -> ContextAttributes: if isinstance(subject_context, dict): return ContextAttributes.from_dict(subject_context) - return subject_context + + stringified_categorical_attributes = { + key: str(value) for key, value in subject_context.categorical_attributes.items() + } + + return ContextAttributes( + numeric_attributes=subject_context.numeric_attributes, + categorical_attributes=stringified_categorical_attributes, + ) def convert_actions_to_action_contexts( actions: Union[ActionContexts, ActionAttributes] ) -> ActionContexts: - return { - k: ContextAttributes.from_dict(v) if isinstance(v, dict) else v - for k, v in actions.items() - } + return {k: convert_attributes_to_context_attributes(v) for k, v in actions.items()} diff --git a/eppo_client/version.py b/eppo_client/version.py index 91b5628..99950aa 100644 --- a/eppo_client/version.py +++ b/eppo_client/version.py @@ -1,4 +1,4 @@ # Note to developers: When ready to bump to 4.0, please change # the `POLL_INTERVAL_SECONDS` constant in `eppo_client/constants.py` # to 30 seconds to match the behavior of the other server SDKs. -__version__ = "3.5.0" +__version__ = "3.5.1" diff --git a/test/client_bandit_test.py b/test/client_bandit_test.py index 64cb7c6..0674f96 100644 --- a/test/client_bandit_test.py +++ b/test/client_bandit_test.py @@ -6,7 +6,8 @@ import os from time import sleep from typing import Dict, List -from eppo_client.bandit import BanditResult, ContextAttributes +from unittest.mock import patch +from eppo_client.bandit import BanditEvaluator, BanditResult, ContextAttributes import httpretty # type: ignore import pytest @@ -81,15 +82,24 @@ def init_fixture(): httpretty.reset() +@pytest.fixture(autouse=True) +def clear_event_arrays(): + # Reset graceful mode to off + get_instance().set_is_graceful_mode(False) + # Clear captured logger events + mock_assignment_logger.assignment_events.clear() + mock_assignment_logger.bandit_events.clear() + + def test_is_initialized(): client = get_instance() assert client.is_initialized(), "Client should be initialized" -def test_get_bandit_action_bandit_does_not_exist(): +def test_get_bandit_action_flag_not_exist(): client = get_instance() result = client.get_bandit_action( - "nonexistent_bandit", + "nonexistent_flag", "subject_key", DEFAULT_SUBJECT_ATTRIBUTES, {}, @@ -98,12 +108,52 @@ def test_get_bandit_action_bandit_does_not_exist(): assert result == BanditResult("default_variation", None) -def test_get_bandit_action_flag_without_bandit(): +def test_get_bandit_action_flag_has_no_bandit(): client = get_instance() result = client.get_bandit_action( - "a_flag", "subject_key", DEFAULT_SUBJECT_ATTRIBUTES, {}, "default_variation" + "non_bandit_flag", + "subject_key", + DEFAULT_SUBJECT_ATTRIBUTES, + {}, + "default_variation", ) - assert result == BanditResult("default_variation", None) + assert result == BanditResult("control", None) + + +@patch.object( + BanditEvaluator, "evaluate_bandit", side_effect=Exception("Mocked Exception") +) +def test_get_bandit_action_bandit_error(mock_bandit_evaluator): + client = get_instance() + client.set_is_graceful_mode(True) + actions = { + "adidas": ContextAttributes( + numeric_attributes={"discount": 0.1}, + categorical_attributes={"from": "germany"}, + ), + "nike": ContextAttributes( + numeric_attributes={"discount": 0.2}, categorical_attributes={"from": "usa"} + ), + } + + result = client.get_bandit_action( + "banner_bandit_flag_uk_only", + "alice", + DEFAULT_SUBJECT_ATTRIBUTES, + actions, + "default_variation", + ) + assert result.variation == "banner_bandit" + assert result.action is None + + # testing assignment logger + assignment_log_statement = mock_assignment_logger.assignment_events[-1] + assert assignment_log_statement["featureFlag"] == "banner_bandit_flag_uk_only" + assert assignment_log_statement["variation"] == "banner_bandit" + assert assignment_log_statement["subject"] == "alice" + + # testing bandit logger + assert len(mock_assignment_logger.bandit_events) == 0 def test_get_bandit_action_with_subject_attributes(): @@ -163,6 +213,35 @@ def test_get_bandit_action_with_subject_attributes(): ) +@patch.object( + MockAssignmentLogger, "log_bandit_action", side_effect=Exception("Mocked Exception") +) +def test_get_bandit_action_bandit_logger_error(patched_mock_assignment_logger): + client = get_instance() + actions = { + "adidas": ContextAttributes( + numeric_attributes={"discount": 0.1}, + categorical_attributes={"from": "germany"}, + ), + "nike": ContextAttributes( + numeric_attributes={"discount": 0.2}, categorical_attributes={"from": "usa"} + ), + } + result = client.get_bandit_action( + "banner_bandit_flag_uk_only", + "alice", + DEFAULT_SUBJECT_ATTRIBUTES, + actions, + "default_variation", + ) + assert result.variation == "banner_bandit" + assert result.action in ["adidas", "nike"] + + # assignment should have still been logged + assert len(mock_assignment_logger.assignment_events) == 1 + assert len(mock_assignment_logger.bandit_events) == 0 + + @pytest.mark.parametrize("test_case", test_data) def test_bandit_generic_test_cases(test_case): client = get_instance()