Skip to content
This repository has been archived by the owner on Nov 8, 2024. It is now read-only.

Commit

Permalink
[bandits] rename Attributes and ContextAttributes (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
schmit authored Jun 14, 2024
1 parent 00fb6ae commit 102bd6f
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 68 deletions.
33 changes: 17 additions & 16 deletions eppo_client/bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from eppo_client.rules import to_string
from eppo_client.sharders import Sharder
from eppo_client.types import AttributesDict
from eppo_client.types import Attributes


logger = logging.getLogger(__name__)
Expand All @@ -21,7 +21,7 @@ class BanditEvaluationError(Exception):


@dataclass
class Attributes:
class ContextAttributes:
numeric_attributes: Dict[str, float]
categorical_attributes: Dict[str, str]

Expand All @@ -31,22 +31,23 @@ def empty(cls):
Create an empty Attributes instance with no numeric or categorical attributes.
Returns:
Attributes: An instance of the Attributes class with empty dictionaries
ContextAttributes: An instance of the ContextAttributes class with empty dictionaries
for numeric and categorical attributes.
"""
return cls({}, {})

@classmethod
def from_dict(cls, attributes: AttributesDict):
def from_dict(cls, attributes: Attributes):
"""
Create an Attributes instance from a dictionary of attributes.
Create an ContextAttributes instance from a dictionary of attributes.
Args:
attributes (Dict[str, Union[float, int, bool, str]]): A dictionary where keys are attribute names
and values are attribute values which can be of type float, int, bool, or str.
and values are attribute values which can be of type float, int, bool, or str.
Returns:
Attributes: An instance of the Attributes class with numeric and categorical attributes separated.
ContextAttributes: An instance of the ContextAttributes class
with numeric and categorical attributes separated.
"""
numeric_attributes = {
key: float(value)
Expand All @@ -61,17 +62,17 @@ def from_dict(cls, attributes: AttributesDict):
return cls(numeric_attributes, categorical_attributes)


ActionContexts = Dict[str, Attributes]
ActionContextsDict = Dict[str, AttributesDict]
ActionContexts = Dict[str, ContextAttributes]
ActionAttributes = Dict[str, Attributes]


@dataclass
class BanditEvaluation:
flag_key: str
subject_key: str
subject_attributes: Attributes
subject_attributes: ContextAttributes
action_key: Optional[str]
action_attributes: Optional[Attributes]
action_attributes: Optional[ContextAttributes]
action_score: float
action_weight: float
gamma: float
Expand All @@ -88,7 +89,7 @@ def to_string(self) -> str:


def null_evaluation(
flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float
flag_key: str, subject_key: str, subject_attributes: ContextAttributes, gamma: float
):
return BanditEvaluation(
flag_key, subject_key, subject_attributes, None, None, 0.0, 0.0, gamma, 0.0
Expand All @@ -104,7 +105,7 @@ def evaluate_bandit(
self,
flag_key: str,
subject_key: str,
subject_attributes: Attributes,
subject_attributes: ContextAttributes,
actions: ActionContexts,
bandit_model: BanditModelData,
) -> BanditEvaluation:
Expand Down Expand Up @@ -138,7 +139,7 @@ def evaluate_bandit(

def score_actions(
self,
subject_attributes: Attributes,
subject_attributes: ContextAttributes,
actions: ActionContexts,
bandit_model: BanditModelData,
) -> Dict[str, float]:
Expand Down Expand Up @@ -209,8 +210,8 @@ def select_action(self, flag_key, subject_key, action_weights) -> str:


def score_action(
subject_attributes: Attributes,
action_attributes: Attributes,
subject_attributes: ContextAttributes,
action_attributes: ContextAttributes,
coefficients: BanditCoefficients,
) -> float:
score = coefficients.intercept
Expand Down
55 changes: 29 additions & 26 deletions eppo_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from typing import Any, Dict, Optional, Union
from eppo_client.assignment_logger import AssignmentLogger
from eppo_client.bandit import (
ActionContextsDict,
ActionAttributes,
BanditEvaluator,
BanditResult,
Attributes,
ContextAttributes,
ActionContexts,
)
from eppo_client.configuration_requestor import (
Expand All @@ -17,7 +17,7 @@
from eppo_client.models import VariationType
from eppo_client.poller import Poller
from eppo_client.sharders import MD5Sharder
from eppo_client.types import AttributesDict, ValueType
from eppo_client.types import Attributes, ValueType
from eppo_client.validation import validate_not_blank
from eppo_client.eval import FlagEvaluation, Evaluator, none_result
from eppo_client.version import __version__
Expand Down Expand Up @@ -49,7 +49,7 @@ def get_string_assignment(
self,
flag_key: str,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
default: str,
) -> str:
return self.get_assignment_variation(
Expand All @@ -64,7 +64,7 @@ def get_integer_assignment(
self,
flag_key: str,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
default: int,
) -> int:
return self.get_assignment_variation(
Expand All @@ -79,7 +79,7 @@ def get_numeric_assignment(
self,
flag_key: str,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
default: float,
) -> float:
# convert to float in case we get an int
Expand All @@ -97,7 +97,7 @@ def get_boolean_assignment(
self,
flag_key: str,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
default: bool,
) -> bool:
return self.get_assignment_variation(
Expand All @@ -112,7 +112,7 @@ def get_json_assignment(
self,
flag_key: str,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
default: Dict[Any, Any],
) -> Dict[Any, Any]:
json_value = self.get_assignment_variation(
Expand All @@ -131,7 +131,7 @@ def get_assignment_variation(
self,
flag_key: str,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
default: Optional[ValueType],
expected_variation_type: VariationType,
):
Expand All @@ -155,7 +155,7 @@ def get_assignment_detail(
self,
flag_key: str,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
expected_variation_type: VariationType,
) -> FlagEvaluation:
"""Maps a subject to a variation for a given flag
Expand Down Expand Up @@ -231,8 +231,8 @@ def get_bandit_action(
self,
flag_key: str,
subject_key: str,
subject_context: Union[Attributes, AttributesDict],
actions: Union[ActionContexts, ActionContextsDict],
subject_context: Union[ContextAttributes, Attributes],
actions: Union[ActionContexts, ActionAttributes],
default: str,
) -> BanditResult:
"""
Expand All @@ -250,11 +250,11 @@ def get_bandit_action(
Args:
flag_key (str): The feature flag key that contains the bandit as one of the variations.
subject_key (str): The key identifying the subject.
subject_context (Attributes | AttributesDict): The subject context.
If supplying an AttributesDict, it gets converted to an Attributes instance
actions (ActionContexts | ActionContextsDict): The dictionary that maps action keys
subject_context (ActionContexts | ActionAttributes): The subject context.
If supplying an ActionAttributes, it gets converted to an ActionContexts instance
actions (ActionContexts | ActionAttributes): The dictionary that maps action keys
to their context of actions with their contexts.
If supplying an AttributesDict, it gets converted to an Attributes instance.
If supplying an ActionAttributes, it gets converted to an ActionContexts instance.
default (str): The default variation to use if the subject is not part of the bandit.
Returns:
Expand All @@ -267,13 +267,16 @@ def get_bandit_action(
result = client.get_bandit_action(
"flag_key",
"subject_key",
Attributes(
ContextAttributes(
numeric_attributes={"age": 25},
categorical_attributes={"country": "USA"}),
{
"action1": Attributes(numeric_attributes={"price": 10.0}, categorical_attributes={"category": "A"}),
"action1": ContextAttributes(
numeric_attributes={"price": 10.0},
categorical_attributes={"category": "A"}
),
"action2": {"price": 10.0, "category": "B"}
"action3": Attributes.empty(),
"action3": ContextAttributes.empty(),
},
"default"
)
Expand All @@ -300,8 +303,8 @@ def get_bandit_action_detail(
self,
flag_key: str,
subject_key: str,
subject_context: Union[Attributes, AttributesDict],
actions: Union[ActionContexts, ActionContextsDict],
subject_context: Union[ContextAttributes, Attributes],
actions: Union[ActionContexts, ActionAttributes],
default: str,
) -> BanditResult:
subject_attributes = convert_subject_context_to_attributes(subject_context)
Expand Down Expand Up @@ -428,17 +431,17 @@ def check_value_type_match(


def convert_subject_context_to_attributes(
subject_context: Union[Attributes, AttributesDict]
) -> Attributes:
subject_context: Union[ContextAttributes, Attributes]
) -> ContextAttributes:
if isinstance(subject_context, dict):
return Attributes.from_dict(subject_context)
return ContextAttributes.from_dict(subject_context)
return subject_context


def convert_actions_to_action_contexts(
actions: Union[ActionContexts, ActionContextsDict]
actions: Union[ActionContexts, ActionAttributes]
) -> ActionContexts:
return {
k: Attributes.from_dict(v) if isinstance(v, dict) else v
k: ContextAttributes.from_dict(v) if isinstance(v, dict) else v
for k, v in actions.items()
}
8 changes: 4 additions & 4 deletions eppo_client/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from dataclasses import dataclass
import datetime

from eppo_client.types import AttributesDict
from eppo_client.types import Attributes


@dataclass
class FlagEvaluation:
flag_key: str
variation_type: VariationType
subject_key: str
subject_attributes: AttributesDict
subject_attributes: Attributes
allocation_key: Optional[str]
variation: Optional[Variation]
extra_logging: Dict[str, str]
Expand All @@ -28,7 +28,7 @@ def evaluate_flag(
self,
flag: Flag,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
) -> FlagEvaluation:
if not flag.enabled:
return none_result(
Expand Down Expand Up @@ -93,7 +93,7 @@ def none_result(
flag_key: str,
variation_type: VariationType,
subject_key: str,
subject_attributes: AttributesDict,
subject_attributes: Attributes,
) -> FlagEvaluation:
return FlagEvaluation(
flag_key=flag_key,
Expand Down
8 changes: 3 additions & 5 deletions eppo_client/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import semver

from eppo_client.models import SdkBaseModel
from eppo_client.types import AttributeType, ConditionValueType, AttributesDict
from eppo_client.types import AttributeType, ConditionValueType, Attributes


class OperatorType(Enum):
Expand All @@ -32,16 +32,14 @@ class Rule(SdkBaseModel):
conditions: List[Condition]


def matches_rule(rule: Rule, subject_attributes: AttributesDict) -> bool:
def matches_rule(rule: Rule, subject_attributes: Attributes) -> bool:
return all(
evaluate_condition(condition, subject_attributes)
for condition in rule.conditions
)


def evaluate_condition(
condition: Condition, subject_attributes: AttributesDict
) -> bool:
def evaluate_condition(condition: Condition, subject_attributes: Attributes) -> bool:
subject_value = subject_attributes.get(condition.attribute, None)
if condition.operator == OperatorType.IS_NULL:
if condition.value:
Expand Down
2 changes: 1 addition & 1 deletion eppo_client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
ValueType = Union[str, int, float, bool]
AttributeType = Union[str, int, float, bool, None]
ConditionValueType = Union[AttributeType, List[AttributeType]]
AttributesDict = Dict[str, AttributeType]
Attributes = Dict[str, AttributeType]
Action = str
2 changes: 1 addition & 1 deletion eppo_client/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.0"
__version__ = "3.2.1"
6 changes: 3 additions & 3 deletions example/03_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ async def bandit(name: str, country: str, age: int):
bandit_result = client.get_bandit_action(
"shoe-bandit",
name,
eppo_client.bandit.Attributes(
eppo_client.bandit.ContextAttributes(
numeric_attributes={"age": age}, categorical_attributes={"country": country}
),
{
"nike": eppo_client.bandit.Attributes(
"nike": eppo_client.bandit.ContextAttributes(
numeric_attributes={"brand_affinity": 2.3},
categorical_attributes={"aspect_ratio": "16:9"},
),
"adidas": eppo_client.bandit.Attributes(
"adidas": eppo_client.bandit.ContextAttributes(
numeric_attributes={"brand_affinity": 0.2},
categorical_attributes={"aspect_ratio": "16:9"},
),
Expand Down
Loading

0 comments on commit 102bd6f

Please sign in to comment.