Skip to content
This repository has been archived by the owner on Nov 8, 2024. It is now read-only.

Commit

Permalink
Address Aaron's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
schmit committed May 30, 2024
1 parent 741fe31 commit 53d6a5b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 15 deletions.
27 changes: 21 additions & 6 deletions eppo_client/bandit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
import logging
from typing import Dict, List, Optional, Tuple

from eppo_client.models import (
BanditCategoricalAttributeCoefficient,
BanditCoefficients,
Expand All @@ -9,6 +11,13 @@
from eppo_client.sharders import Sharder


logger = logging.getLogger(__name__)


class BanditEvaluationError(Exception):
    """Error raised when a bandit evaluation cannot produce a result.

    Raised by action selection when no action ends up being selected for
    a given flag key and subject key.
    """


@dataclass
class Attributes:
numeric_attributes: Dict[str, float]
Expand Down Expand Up @@ -72,6 +81,9 @@ class BanditResult:
variation: str
action: Optional[str]

def to_string(self) -> str:
return coalesce(self.action, self.variation)


def null_evaluation(
flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float
Expand Down Expand Up @@ -176,16 +188,19 @@ def weigh_actions(
]

# remaining weight goes to best action
remaining_weight = 1.0 - sum(weight for _, weight in weights)
remaining_weight = max(0.0, 1.0 - sum(weight for _, weight in weights))
weights.append((best_action, remaining_weight))
return weights

def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str]:
# deterministic ordering
sorted_action_weights = sorted(
action_weights,
key=lambda t: self.sharder.get_shard(
f"{flag_key}-{subject_key}-{t[0]}", self.total_shards
key=lambda t: (
self.sharder.get_shard(
f"{flag_key}-{subject_key}-{t[0]}", self.total_shards
),
t[0], # tie-break using action name
),
)

Expand All @@ -200,9 +215,9 @@ def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str
return idx, action_key

# If no action is selected, raise an error (this commit replaces the earlier fallback of returning the last action)
action_index = len(sorted_action_weights) - 1
action_key = sorted_action_weights[action_index][0]
return action_index, action_key
raise BanditEvaluationError(
f"[Eppo SDK] No action selected for {flag_key} {subject_key}"
)


def score_action(
Expand Down
30 changes: 26 additions & 4 deletions eppo_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,28 @@ def get_bandit_action(
- variation (str): The assignment key indicating the subject's variation.
- action (str): The key of the selected action if the subject is part of the bandit.
"""
try:
return self.get_bandit_action_detail(
flag_key,
subject_key,
subject_attributes,
actions_with_contexts,
default,
)
except Exception as e:
if self.__is_graceful_mode:
logger.error("[Eppo SDK] Error getting bandit action: " + str(e))
return BanditResult(default, None)
raise e

def get_bandit_action_detail(
self,
flag_key: str,
subject_key: str,
subject_attributes: Attributes,
actions_with_contexts: List[ActionContext],
default: str,
) -> BanditResult:
# get experiment assignment
# ignoring type because Dict[str, str] satisfies Dict[str, str | ...] but mypy does not understand
variation = self.get_string_assignment(
Expand Down Expand Up @@ -292,22 +314,22 @@ def get_bandit_action(
"subjectNumericAttributes": (
subject_attributes.numeric_attributes
if evaluation.subject_attributes
else None
else {}
),
"subjectCategoricalAttributes": (
subject_attributes.categorical_attributes
if evaluation.subject_attributes
else None
else {}
),
"actionNumericAttributes": (
evaluation.action_attributes.numeric_attributes
if evaluation.action_attributes
else None
else {}
),
"actionCategoricalAttributes": (
evaluation.action_attributes.categorical_attributes
if evaluation.action_attributes
else None
else {}
),
"metaData": {"sdkLanguage": "python", "sdkVersion": __version__},
}
Expand Down
1 change: 0 additions & 1 deletion eppo_client/configuration_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def store_bandits(self, bandit_data) -> Dict[str, BanditData]:
config["banditKey"]: BanditData(**config)
for config in cast(dict, bandit_data.get("bandits", []))
}
print(bandit_configs)
self.__bandit_config_store.set_configurations(bandit_configs)
return bandit_configs

Expand Down
4 changes: 0 additions & 4 deletions test/client_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
test_case_dict = json.load(test_case_json)
test_data.append(test_case_dict)

print(test_data)

MOCK_BASE_URL = "http://localhost:4001/api"

Expand Down Expand Up @@ -69,8 +68,6 @@ def init_fixture():
)
)
sleep(0.1) # wait for initialization
print(client.get_flag_keys())
print(client.get_bandit_keys())
yield
client._shutdown()
httpretty.disable()
Expand All @@ -91,7 +88,6 @@ def test_get_bandit_action_bandit_does_not_exist():
[],
"default_variation",
)
print(result)
assert result == BanditResult("default_variation", None)


Expand Down

0 comments on commit 53d6a5b

Please sign in to comment.