From abfec666408373fb4770b41ebe3530515d869c7f Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Wed, 15 May 2024 16:48:33 -0700 Subject: [PATCH 01/10] wip --- eppo_client/bandit.py | 161 +++++++++++++++++++++++++++++++++ eppo_client/models.py | 55 +++++++++++- eppo_client/types.py | 1 + test/bandit_test.py | 201 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 417 insertions(+), 1 deletion(-) create mode 100644 eppo_client/bandit.py create mode 100644 test/bandit_test.py diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py new file mode 100644 index 0000000..27affd9 --- /dev/null +++ b/eppo_client/bandit.py @@ -0,0 +1,161 @@ +from dataclasses import dataclass +from typing import Dict, List, Tuple +from eppo_client.models import ( + ActionContext, + Attributes, + BanditCategoricalAttributeCoefficient, + BanditCoefficients, + BanditModelData, + BanditNumericAttributeCoefficient, +) +from eppo_client.sharders import Sharder + + +@dataclass +class BanditEvaluator: + sharder: Sharder + total_shards: int = 10_000 + + def evaluate_bandit( + self, + flag_key: str, + subject_key: str, + subject_attributes: Attributes, + action_attributes: Attributes, + actions_with_contexts: List[ActionContext], + bandit_model: BanditModelData, + ): + action_scores = self.score_actions( + subject_attributes, action_attributes, actions_with_contexts, bandit_model + ) + + action_weights = self.weight_actions( + action_scores, + bandit_model.gamma, + bandit_model.action_probability_floor, + ) + + selected_action = self.select_action(flag_key, subject_key, action_weights) + return selected_action + + def score_actions( + self, + subject_attributes: Attributes, + action_attributes: Attributes, + actions_with_contexts: List[ActionContext], + bandit_model: BanditModelData, + ): + return [ + ( + action_context.action_key, + ( + score_action( + subject_attributes, + action_attributes, + bandit_model.coefficients[action_context.action_key], + ) + if action_context.action_key in bandit_model.coefficients + else bandit_model.default_action_score + ), + ) + for action_context in actions_with_contexts + ] + + def weight_actions(self, action_scores, gamma, probability_floor): + number_of_actions = len(action_scores) + best_action, best_score = max(action_scores, key=lambda t: t[1]) + + # adjust probability floor for number of actions to control the sum + min_probability = probability_floor / number_of_actions + + # weight all but the best action + weights = [ + ( + action_key, + max( + min_probability, + 1.0 / (number_of_actions + gamma * (best_score - score)), + ), + ) + for action_key, score in action_scores + if action_key != best_action + ] + + # remaining weight goes to best action + remaining_weight = 1.0 - sum(weight for _, weight in weights) + weights.append((best_action, remaining_weight)) + return weights + + def select_action(self, flag_key, subject_key, action_weights) -> Tuple[str, float]: + # deterministic ordering + sorted_action_weights = sorted( + action_weights, + key=lambda t: self.sharder.get_shard( + f"{flag_key}-{subject_key}-{t[0]}", self.total_shards + ), + ) + + # select action based on weights + shard = self.sharder.get_shard(f"{flag_key}-{subject_key}", self.total_shards) + cumulative_weight = 0.0 + shard_value = shard / self.total_shards + + for action_key, weight in sorted_action_weights: + cumulative_weight += weight + if cumulative_weight > shard_value: + return action_key, weight + + # If no action is selected, return the last action (fallback) + return sorted_action_weights[-1] + + +def score_action( + subject_attributes: Attributes, + action_attributes: Attributes, + coefficients: BanditCoefficients, +): + score = coefficients.intercept + score += score_numeric_attributes( + subject_attributes.numeric_attributes, coefficients.subject_numeric_coefficients + ) + score += score_categorical_attributes( + subject_attributes.categorical_attributes, + coefficients.subject_categorical_coefficients, + ) + score += score_numeric_attributes( + action_attributes.numeric_attributes, coefficients.action_numeric_coefficients + ) + score += score_numeric_attributes( + action_attributes.numeric_attributes, coefficients.action_numeric_coefficients + ) + return score + + +def score_numeric_attributes( + coefficients: List[BanditNumericAttributeCoefficient], + attributes: Dict[str, float], +) -> float: + score = 0.0 + for coefficient in coefficients: + if coefficient.attribute_key in attributes: + score += coefficient.coefficient * attributes[coefficient.attribute_key] + else: + score += coefficient.missing_value_coefficient + + return score + + +def score_categorical_attributes( + coefficients: List[BanditCategoricalAttributeCoefficient], + attributes: Dict[str, str], +) -> float: + score = 0.0 + for coefficient in coefficients: + if coefficient.attribute_key in attributes: + score += coefficient.value_coefficients.get( + attributes[coefficient.attribute_key], + coefficient.missing_value_coefficient, + ) + else: + score += coefficient.missing_value_coefficient + return score diff --git a/eppo_client/models.py b/eppo_client/models.py index 601bc86..e119753 100644 --- a/eppo_client/models.py +++ b/eppo_client/models.py @@ -4,7 +4,7 @@ from eppo_client.base_model import SdkBaseModel from eppo_client.rules import Rule -from eppo_client.types import ValueType +from eppo_client.types import Action, ValueType class VariationType(Enum): @@ -52,3 +52,56 @@ class Flag(SdkBaseModel): variations: Dict[str, Variation] allocations: List[Allocation] total_shards: int = 10_000 + + +class Attributes(SdkBaseModel): + numeric_attributes: Dict[str, float] + categorical_attributes: Dict[str, str] + + +class ActionContext(SdkBaseModel): + action_key: str + attributes: Attributes + + +class BanditVariation(SdkBaseModel): + key: str + flag_key: str + variation_key: str + variation_value: str + + +class BanditNumericAttributeCoefficient(SdkBaseModel): + attribute_key: str + coefficient: float + missing_value_coefficient: float + + +class BanditCategoricalAttributeCoefficient(SdkBaseModel): + attribute_key: str + missing_value_coefficient: float + value_coefficients: Dict[str, float] + + +class BanditCoefficients(SdkBaseModel): + action_key: str + intercept: float + subject_numeric_coefficients: List[BanditNumericAttributeCoefficient] + subject_categorical_coefficients: List[BanditCategoricalAttributeCoefficient] + action_numeric_coefficients: List[BanditNumericAttributeCoefficient] + action_categorical_coefficients: List[BanditCategoricalAttributeCoefficient] + + +class BanditModelData(SdkBaseModel): + gamma: float + default_action_score: float + action_probability_floor: float + coefficients: Dict[Action, BanditCoefficients] + + +class BanditData(SdkBaseModel): + bandit_key: str + model_name: str + updated_at: datetime + model_version: str + model_data: BanditModelData diff --git a/eppo_client/types.py b/eppo_client/types.py index dcc4a76..07f126a 100644 --- a/eppo_client/types.py +++ b/eppo_client/types.py @@ -4,3 +4,4 @@ AttributeType = Union[str, int, float, bool] ConditionValueType = Union[AttributeType, List[AttributeType]] SubjectAttributes = Dict[str, AttributeType] +Action = str diff --git a/test/bandit_test.py b/test/bandit_test.py new file mode 100644 index 0000000..3fb56cb --- /dev/null +++ b/test/bandit_test.py @@ -0,0 +1,201 @@ +import pytest + +from eppo_client.bandit import ( + score_numeric_attributes, + score_categorical_attributes, + weight_actions, +) +from eppo_client.models import ( + BanditNumericAttributeCoefficient, + BanditCategoricalAttributeCoefficient, +) + + +def test_score_numeric_attributes_all_present(): + coefficients = [ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=2.0, missing_value_coefficient=0.5 + ), + BanditNumericAttributeCoefficient( + attribute_key="height", coefficient=1.5, missing_value_coefficient=0.3 + ), + ] + attributes = {"age": 30, "height": 170} + expected_score = 30 * 2.0 + 170 * 1.5 + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_numeric_attributes_some_missing(): + coefficients = [ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=2.0, missing_value_coefficient=0.5 + ), + BanditNumericAttributeCoefficient( + attribute_key="height", coefficient=1.5, missing_value_coefficient=0.3 + ), + ] + attributes = {"age": 30} + expected_score = 30 * 2.0 + 0.3 # height is missing, use missing_value_coefficient + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_numeric_attributes_all_missing(): + coefficients = [ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=2.0, missing_value_coefficient=0.5 + ), + BanditNumericAttributeCoefficient( + attribute_key="height", coefficient=1.5, missing_value_coefficient=0.3 + ), + ] + attributes = {} + expected_score = 0.5 + 0.3 # both are missing, use missing_value_coefficients + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_numeric_attributes_empty_coefficients(): + coefficients = [] + attributes = {"age": 30, "height": 170} + expected_score = 0.0 # no coefficients to apply + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_numeric_attributes_negative_coefficients(): + coefficients = [ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=-2.0, missing_value_coefficient=0.5 + ), + BanditNumericAttributeCoefficient( + attribute_key="height", coefficient=-1.5, missing_value_coefficient=0.3 + ), + ] + attributes = {"age": 30, "height": 170} + expected_score = 30 * -2.0 + 170 * -1.5 + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_some_missing(): + coefficients = [ + BanditCategoricalAttributeCoefficient( + attribute_key="color", + missing_value_coefficient=0.2, + value_coefficients={"red": 1.0, "blue": 0.5}, + ), + BanditCategoricalAttributeCoefficient( + attribute_key="size", + missing_value_coefficient=0.3, + value_coefficients={"large": 2.0, "small": 1.0}, + ), + ] + attributes = {"color": "red"} + expected_score = 1.0 + 0.3 # size is missing, use missing_value_coefficient + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_all_missing(): + coefficients = [ + BanditCategoricalAttributeCoefficient( + attribute_key="color", + missing_value_coefficient=0.2, + value_coefficients={"red": 1.0, "blue": 0.5}, + ), + BanditCategoricalAttributeCoefficient( + attribute_key="size", + missing_value_coefficient=0.3, + value_coefficients={"large": 2.0, "small": 1.0}, + ), + ] + attributes = {} + expected_score = 0.2 + 0.3 # both are missing, use missing_value_coefficients + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_empty_coefficients(): + coefficients = [] + attributes = {"color": "red", "size": "large"} + expected_score = 0.0 # no coefficients to apply + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_negative_coefficients(): + coefficients = [ + BanditCategoricalAttributeCoefficient( + attribute_key="color", + missing_value_coefficient=0.2, + value_coefficients={"red": -1.0, "blue": -0.5}, + ), + BanditCategoricalAttributeCoefficient( + attribute_key="size", + missing_value_coefficient=0.3, + value_coefficients={"large": -2.0, "small": -1.0}, + ), + ] + attributes = {"color": "red", "size": "large"} + expected_score = -1.0 + -2.0 + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_mixed_coefficients(): + coefficients = [ + BanditCategoricalAttributeCoefficient( + attribute_key="color", + missing_value_coefficient=0.2, + value_coefficients={"red": 1.0, "blue": -0.5}, + ), + BanditCategoricalAttributeCoefficient( + attribute_key="size", + missing_value_coefficient=0.3, + value_coefficients={"large": -2.0, "small": 1.0}, + ), + ] + attributes = {"color": "blue", "size": "small"} + expected_score = -0.5 + 1.0 + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_weight_actions_single_action(): + action_scores = [("action1", 1.0)] + gamma = 0.1 + probability_floor = 0.1 + expected_weights = [("action1", 1.0)] + assert weight_actions(action_scores, gamma, probability_floor) == expected_weights + + +def test_weight_actions_multiple_actions(): + action_scores = [("action1", 1.0), ("action2", 0.5)] + gamma = 0.1 + probability_floor = 0.1 + weights = weight_actions(action_scores, gamma, probability_floor) + assert len(weights) == 2 + assert any(action == "action1" and weight > 0.5 for action, weight in weights) + assert any(action == "action2" and weight <= 0.5 for action, weight in weights) + + +def test_weight_actions_probability_floor(): + action_scores = [("action1", 1.0), ("action2", 0.5), ("action3", 0.2)] + gamma = 0.1 + probability_floor = 0.3 + weights = weight_actions(action_scores, gamma, probability_floor) + assert len(weights) == 3 + for action, weight in weights: + assert weight >= 0.1 + + +def test_weight_actions_gamma_effect(): + action_scores = [("action1", 1.0), ("action2", 0.5)] + gamma = 1.0 + probability_floor = 0.1 + weights = weight_actions(action_scores, gamma, probability_floor) + assert len(weights) == 2 + assert any(action == "action1" and weight > 0.5 for action, weight in weights) + assert any(action == "action2" and weight <= 0.5 for action, weight in weights) + + +def test_weight_actions_all_equal_scores(): + action_scores = [("action1", 1.0), ("action2", 1.0), ("action3", 1.0)] + gamma = 0.1 + probability_floor = 0.1 + weights = weight_actions(action_scores, gamma, probability_floor) + assert len(weights) == 3 + for _, weight in weights: + assert weight == pytest.approx(1.0 / 3, rel=1e-2) From edaa51395e9909cf2ea4ae61988d867264ffdf71 Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Wed, 22 May 2024 17:07:16 -0700 Subject: [PATCH 02/10] wip --- eppo_client/assignment_logger.py | 3 ++ eppo_client/bandit.py | 47 ++++++++++++++----- eppo_client/client.py | 63 +++++++++++++++++++++++++- eppo_client/configuration_requestor.py | 18 ++++++-- 4 files changed, 113 insertions(+), 18 deletions(-) diff --git a/eppo_client/assignment_logger.py b/eppo_client/assignment_logger.py index 309b26a..cc17e76 100644 --- a/eppo_client/assignment_logger.py +++ b/eppo_client/assignment_logger.py @@ -8,3 +8,6 @@ class AssignmentLogger(BaseModel): def log_assignment(self, assignment_event: Dict): pass + + def log_bandit_action(self, bandit_event: Dict): + pass diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py index 27affd9..afbec37 100644 --- a/eppo_client/bandit.py +++ b/eppo_client/bandit.py @@ -11,6 +11,18 @@ from eppo_client.sharders import Sharder +@dataclass +class BanditEvaluation: + flag_key: str + subject_key: str + subject_attributes: Attributes + action_key: str + action_attributes: Attributes + action_score: float + action_weight: float + gamma: float + + @dataclass class BanditEvaluator: sharder: Sharder @@ -21,12 +33,11 @@ def evaluate_bandit( flag_key: str, subject_key: str, subject_attributes: Attributes, - action_attributes: Attributes, actions_with_contexts: List[ActionContext], bandit_model: BanditModelData, ): action_scores = self.score_actions( - subject_attributes, action_attributes, actions_with_contexts, bandit_model + subject_attributes, actions_with_contexts, bandit_model ) action_weights = self.weight_actions( @@ -35,23 +46,33 @@ def evaluate_bandit( bandit_model.action_probability_floor, ) - selected_action = self.select_action(flag_key, subject_key, action_weights) - return selected_action + selected_idx, selected_action = self.select_action( + flag_key, subject_key, action_weights + ) + return BanditEvaluation( + flag_key, + subject_key, + subject_attributes, + selected_action, + actions_with_contexts[selected_idx].attributes, + action_scores[selected_idx][1], + action_weights[selected_idx][1], + bandit_model.gamma, + ) def score_actions( self, subject_attributes: Attributes, - action_attributes: Attributes, actions_with_contexts: List[ActionContext], bandit_model: BanditModelData, - ): + ) -> List[Tuple[str, float]]: return [ ( action_context.action_key, ( score_action( subject_attributes, - action_attributes, + action_context.attributes, bandit_model.coefficients[action_context.action_key], ) if action_context.action_key in bandit_model.coefficients @@ -61,7 +82,9 @@ def score_actions( for action_context in actions_with_contexts ] - def weight_actions(self, action_scores, gamma, probability_floor): + def weight_actions( + self, action_scores, gamma, probability_floor + ) -> List[Tuple[str, float]]: number_of_actions = len(action_scores) best_action, best_score = max(action_scores, key=lambda t: t[1]) @@ -100,20 +123,20 @@ def select_action(self, flag_key, subject_key, action_weights) -> Tuple[str, flo cumulative_weight = 0.0 shard_value = shard / self.total_shards - for action_key, weight in sorted_action_weights: + for idx, (action_key, weight) in enumerate(sorted_action_weights): cumulative_weight += weight if cumulative_weight > shard_value: - return action_key, weight + return idx, action_key # If no action is selected, return the last action (fallback) - return sorted_action_weights[-1] + return (len(sorted_action_weights) - 1, sorted_action_weights[-1][0]) def score_action( subject_attributes: Attributes, action_attributes: Attributes, coefficients: BanditCoefficients, -): +) -> float: score = coefficients.intercept score += score_numeric_attributes( subject_attributes.numeric_attributes, coefficients.subject_numeric_coefficients diff --git a/eppo_client/client.py b/eppo_client/client.py index 2200bc6..557c36f 100644 --- a/eppo_client/client.py +++ b/eppo_client/client.py @@ -1,13 +1,14 @@ import datetime import logging import json -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from eppo_client.assignment_logger import AssignmentLogger +from eppo_client.bandit import BanditEvaluator from eppo_client.configuration_requestor import ( ExperimentConfigurationRequestor, ) from eppo_client.constants import POLL_INTERVAL_MILLIS, POLL_JITTER_MILLIS -from eppo_client.models import VariationType +from eppo_client.models import ActionContext, Attributes, VariationType from eppo_client.poller import Poller from eppo_client.sharders import MD5Sharder from eppo_client.types import SubjectAttributes, ValueType @@ -36,6 +37,7 @@ def __init__( ) self.__poller.start() self.__evaluator = Evaluator(sharder=MD5Sharder()) + self.__bandit_evaluator = BanditEvaluator(sharder=MD5Sharder()) def get_string_assignment( self, @@ -219,6 +221,63 @@ def get_assignment_detail( logger.error("[Eppo SDK] Error logging assignment event: " + str(e)) return result + def get_bandit_action( + self, + bandit_key: str, + subject_key: str, + subject_attributes: Attributes, + actions_with_contexts: List[ActionContext], + ) -> Optional[str]: + + # get experiment assignment + assignment = self.get_string_assignment( + bandit_key, subject_key, subject_attributes.categorical_attributes, None + ) + + # if the assignment is not the bandit key, then the subject is not allocated in the bandit + if assignment != bandit_key: + return None + + bandit_data = self.__config_requestor.get_bandit_model(bandit_key) + + if not bandit_data: + logger.warning( + "[Eppo SDK] No assigned action. Bandit not found: " + bandit_key + ) + return None + + action = self.__bandit_evaluator.get_bandit_action( + bandit_key, + subject_key, + subject_attributes, + actions_with_contexts, + bandit_data.model_data, + ) + + # log bandit action + bandit_event = { + "banditKey": bandit_key, + "subject": subject_key, + "action": action.action_key if action else None, + "actionProbability": action.weight if action else None, + "modelVersion": bandit_data.model_version if action else None, + "timestamp": datetime.datetime.utcnow().isoformat(), + "subjectNumericAttributes": ( + subject_attributes.numeric_attributes if action else None + ), + "subjectCategoricalAttributes": ( + subject_attributes.categorical_attributes if action else None + ), + "actionNumericAttributes": (action.numeric_attributes if action else None), + "actionCategoricalAttributes": ( + action.categorical_attributes if action else None + ), + "metaData": {"sdkLanguage": "python", "sdkVersion": __version__}, + } + self.__assignment_logger.log_bandit_action(bandit_event) + + return action.action_key if action else None + def get_flag_keys(self): """ Returns a list of all flag keys that have been initialized. diff --git a/eppo_client/configuration_requestor.py b/eppo_client/configuration_requestor.py index 9b1d5dd..ba4b2d4 100644 --- a/eppo_client/configuration_requestor.py +++ b/eppo_client/configuration_requestor.py @@ -2,7 +2,7 @@ from typing import Dict, Optional, cast from eppo_client.configuration_store import ConfigurationStore from eppo_client.http_client import HttpClient -from eppo_client.models import Flag +from eppo_client.models import BanditData, Flag logger = logging.getLogger(__name__) @@ -14,10 +14,12 @@ class ExperimentConfigurationRequestor: def __init__( self, http_client: HttpClient, - config_store: ConfigurationStore[Flag], + flag_config_store: ConfigurationStore[Flag], + bandit_config_store: ConfigurationStore[BanditData], ): self.__http_client = http_client - self.__config_store = config_store + self.__flag_config_store = flag_config_store + self.__bandit_config_store = bandit_config_store self.__is_initialized = False def get_configuration(self, flag_key: str) -> Optional[Flag]: @@ -25,8 +27,16 @@ def get_configuration(self, flag_key: str) -> Optional[Flag]: raise ValueError("Unauthorized: please check your API key") return self.__config_store.get_configuration(flag_key) + def get_bandit_model(self, bandit_key: str) -> Optional[BanditData]: + if self.__http_client.is_unauthorized(): + raise ValueError("Unauthorized: please check your API key") + return self.__bandit_config_store.get_configuration(bandit_key) + def get_flag_keys(self): - return self.__config_store.get_keys() + return self.__flag_config_store.get_keys() + + def get_bandit_keys(self): + return self.__bandit_config_store.get_keys() def fetch_and_store_configurations(self) -> Dict[str, Flag]: try: From a5837ebeb619eeaa788c51ff4f8005e82de54aa2 Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Sat, 25 May 2024 16:41:53 -0700 Subject: [PATCH 03/10] wip --- eppo_client/__init__.py | 9 +- eppo_client/bandit.py | 86 +++++++++++++--- eppo_client/client.py | 83 +++++++++++----- eppo_client/configuration_requestor.py | 44 +++++++-- eppo_client/configuration_store.py | 2 +- eppo_client/models.py | 15 +-- test/bandit_test.py | 131 +++++++++++++++++++++++-- test/client_bandit_test.py | 115 ++++++++++++++++++++++ test/configuration_store_test.py | 8 ++ 9 files changed, 430 insertions(+), 63 deletions(-) create mode 100644 test/client_bandit_test.py diff --git a/eppo_client/__init__.py b/eppo_client/__init__.py index 884ae50..91ee517 100644 --- a/eppo_client/__init__.py +++ b/eppo_client/__init__.py @@ -6,7 +6,7 @@ ) from eppo_client.configuration_store import ConfigurationStore from eppo_client.http_client import HttpClient, SdkParams -from eppo_client.models import Flag +from eppo_client.models import BanditData, Flag from eppo_client.read_write_lock import ReadWriteLock from eppo_client.version import __version__ @@ -30,9 +30,12 @@ def init(config: Config) -> EppoClient: apiKey=config.api_key, sdkName="python", sdkVersion=__version__ ) http_client = HttpClient(base_url=config.base_url, sdk_params=sdk_params) - config_store: ConfigurationStore[Flag] = ConfigurationStore() + flag_config_store: ConfigurationStore[Flag] = ConfigurationStore() + bandit_config_store: ConfigurationStore[BanditData] = ConfigurationStore() config_requestor = ExperimentConfigurationRequestor( - http_client=http_client, config_store=config_store + http_client=http_client, + flag_config_store=flag_config_store, + bandit_config_store=bandit_config_store, ) assignment_logger = config.assignment_logger is_graceful_mode = config.is_graceful_mode diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py index afbec37..34623ff 100644 --- a/eppo_client/bandit.py +++ b/eppo_client/bandit.py @@ -1,8 +1,6 @@ from dataclasses import dataclass -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from eppo_client.models import ( - ActionContext, - Attributes, BanditCategoricalAttributeCoefficient, BanditCoefficients, BanditModelData, @@ -11,18 +9,74 @@ from eppo_client.sharders import Sharder +@dataclass +class Attributes: + numeric_attributes: Dict[str, float] + categorical_attributes: Dict[str, str] + + +@dataclass +class ActionContext: + action_key: str + attributes: Attributes + + @classmethod + def create( + cls, + action_key: str, + numeric_attributes: Dict[str, float], + categorical_attributes: Dict[str, str], + ): + return cls( + action_key, + Attributes( + numeric_attributes=numeric_attributes, + categorical_attributes=categorical_attributes, + ), + ) + + @property + def numeric_attributes(self): + return self.attributes.numeric_attributes + + @property + def categorical_attributes(self): + return self.attributes.categorical_attributes + + @dataclass class BanditEvaluation: flag_key: str subject_key: str subject_attributes: Attributes - action_key: str - action_attributes: Attributes + action_key: Optional[str] + action_attributes: Optional[Attributes] action_score: float action_weight: float gamma: float +@dataclass +class BanditResult: + action: str + assignment: str + + +def null_evaluation( + flag_key: str, subject_key: str, subject_attributes: Attributes, gamma +): + return BanditEvaluation( + flag_key, + subject_key, + subject_attributes, + None, + None, + 0.0, + 0.0, + gamma, + ) + + @dataclass class BanditEvaluator: sharder: Sharder @@ -35,7 +89,13 @@ def evaluate_bandit( subject_attributes: Attributes, actions_with_contexts: List[ActionContext], bandit_model: BanditModelData, - ): + ) -> BanditEvaluation: + # handle the edge case that there are no actions + if not actions_with_contexts: + return null_evaluation( + flag_key, subject_key, subject_attributes, bandit_model.gamma + ) + action_scores = self.score_actions( subject_attributes, actions_with_contexts, bandit_model ) @@ -139,17 +199,20 @@ def score_action( ) -> float: score = coefficients.intercept score += score_numeric_attributes( - subject_attributes.numeric_attributes, coefficients.subject_numeric_coefficients + coefficients.subject_numeric_coefficients, + subject_attributes.numeric_attributes, ) score += score_categorical_attributes( - subject_attributes.categorical_attributes, coefficients.subject_categorical_coefficients, + subject_attributes.categorical_attributes, ) score += score_numeric_attributes( - action_attributes.numeric_attributes, coefficients.action_numeric_coefficients + coefficients.action_numeric_coefficients, + action_attributes.numeric_attributes, ) - score += score_numeric_attributes( - action_attributes.numeric_attributes, coefficients.action_numeric_coefficients + score += score_categorical_attributes( + coefficients.action_categorical_coefficients, + action_attributes.categorical_attributes, ) return score @@ -159,6 +222,7 @@ def score_numeric_attributes( attributes: Dict[str, float], ) -> float: score = 0.0 + print(coefficients) for coefficient in coefficients: if coefficient.attribute_key in attributes: score += coefficient.coefficient * attributes[coefficient.attribute_key] diff --git a/eppo_client/client.py b/eppo_client/client.py index 557c36f..29a2612 100644 --- a/eppo_client/client.py +++ b/eppo_client/client.py @@ -3,12 +3,12 @@ import json from typing import Any, Dict, List, Optional from eppo_client.assignment_logger import AssignmentLogger -from eppo_client.bandit import BanditEvaluator +from eppo_client.bandit import BanditEvaluator, BanditResult, ActionContext, Attributes from eppo_client.configuration_requestor import ( ExperimentConfigurationRequestor, ) from eppo_client.constants import POLL_INTERVAL_MILLIS, POLL_JITTER_MILLIS -from eppo_client.models import ActionContext, Attributes, VariationType +from eppo_client.models import VariationType from eppo_client.poller import Poller from eppo_client.sharders import MD5Sharder from eppo_client.types import SubjectAttributes, ValueType @@ -223,31 +223,58 @@ def get_assignment_detail( def get_bandit_action( self, - bandit_key: str, + flag_key: str, subject_key: str, subject_attributes: Attributes, actions_with_contexts: List[ActionContext], - ) -> Optional[str]: - + default: str, + ) -> BanditResult: + """ + Determines the bandit action for a given subject based on the provided bandit key and subject attributes. + + This method performs the following steps: + 1. Retrieves the experiment assignment for the given bandit key and subject. + 2. Checks if the assignment matches the bandit key. If not, it means the subject is not allocated in the bandit, + and the method returns a BanditResult with the assignment. + 3. If the subject is part of the bandit, it fetches the bandit model data. + 4. Evaluates the bandit action using the bandit evaluator. + 5. Logs the bandit action event. + 6. Returns the BanditResult containing the selected action key and the assignment. + + Args: + flag_key (str): The feature flag key that contains the bandit as one of the variations. + subject_key (str): The key identifying the subject. + subject_attributes (Attributes): The attributes of the subject. + actions_with_contexts (List[ActionContext]): The list of actions with their contexts. + + Returns: + BanditResult: The result containing either the bandit action if the subject is part of the bandit, + or the assignment if they are not. The BanditResult includes: + - action (str): The key of the selected action if the subject is part of the bandit. + - assignment (str): The assignment key indicating the subject's variation. + """ # get experiment assignment assignment = self.get_string_assignment( - bandit_key, subject_key, subject_attributes.categorical_attributes, None + flag_key, subject_key, subject_attributes.categorical_attributes, default ) + print(flag_key, assignment) + # if the assignment is not the bandit key, then the subject is not allocated in the bandit - if assignment != bandit_key: - return None + if assignment not in self.get_bandit_keys(): + return BanditResult(None, assignment) - bandit_data = self.__config_requestor.get_bandit_model(bandit_key) + # for now, assume that the assignment is equal to the bandit key + bandit_data = self.__config_requestor.get_bandit_model(assignment) if not bandit_data: logger.warning( - "[Eppo SDK] No assigned action. Bandit not found: " + bandit_key + f"[Eppo SDK] No assigned action. Bandit not found for flag: {flag_key}" ) - return None + return BanditResult(None, assignment) - action = self.__bandit_evaluator.get_bandit_action( - bandit_key, + evaluation = self.__bandit_evaluator.evaluate_bandit( + flag_key, subject_key, subject_attributes, actions_with_contexts, @@ -256,27 +283,32 @@ def get_bandit_action( # log bandit action bandit_event = { - "banditKey": bandit_key, + "flagKey": flag_key, + "banditKey": bandit_data.bandit_key, "subject": subject_key, - "action": action.action_key if action else None, - "actionProbability": action.weight if action else None, - "modelVersion": bandit_data.model_version if action else None, + "action": evaluation.action_key if evaluation else None, + "actionProbability": evaluation.action_weight if evaluation else None, + "modelVersion": bandit_data.model_version if evaluation else None, "timestamp": datetime.datetime.utcnow().isoformat(), "subjectNumericAttributes": ( - subject_attributes.numeric_attributes if action else None + subject_attributes.numeric_attributes if evaluation else None ), "subjectCategoricalAttributes": ( - subject_attributes.categorical_attributes if action else None + subject_attributes.categorical_attributes if evaluation else None + ), + "actionNumericAttributes": ( + evaluation.action_attributes.numeric_attributes if evaluation else None ), - "actionNumericAttributes": (action.numeric_attributes if action else None), "actionCategoricalAttributes": ( - action.categorical_attributes if action else None + evaluation.action_attributes.categorical_attributes + if evaluation + else None ), "metaData": {"sdkLanguage": "python", "sdkVersion": __version__}, } self.__assignment_logger.log_bandit_action(bandit_event) - return action.action_key if action else None + return BanditResult(evaluation.action_key if evaluation else None, assignment) def get_flag_keys(self): """ @@ -287,6 +319,13 @@ def get_flag_keys(self): """ return self.__config_requestor.get_flag_keys() + def get_bandit_keys(self): + """ + Returns a list of all bandit keys that have been initialized. + This can be useful to debug the initialization process. + """ + return self.__config_requestor.get_bandit_keys() + def is_initialized(self): """ Returns True if the client has successfully initialized diff --git a/eppo_client/configuration_requestor.py b/eppo_client/configuration_requestor.py index ba4b2d4..3dfc170 100644 --- a/eppo_client/configuration_requestor.py +++ b/eppo_client/configuration_requestor.py @@ -8,6 +8,7 @@ UFC_ENDPOINT = "/flag-config/v1/config" +BANDIT_ENDPOINT = "/flag-config/v1/bandits" class ExperimentConfigurationRequestor: @@ -25,7 +26,7 @@ def __init__( def get_configuration(self, flag_key: str) -> Optional[Flag]: if self.__http_client.is_unauthorized(): raise ValueError("Unauthorized: please check your API key") - return self.__config_store.get_configuration(flag_key) + return self.__flag_config_store.get_configuration(flag_key) def get_bandit_model(self, bandit_key: str) -> Optional[BanditData]: if self.__http_client.is_unauthorized(): @@ -38,18 +39,41 @@ def get_flag_keys(self): def get_bandit_keys(self): return self.__bandit_config_store.get_keys() - def fetch_and_store_configurations(self) -> Dict[str, Flag]: + def fetch_flags(self): + return self.__http_client.get(UFC_ENDPOINT) + + def fetch_bandits(self): + return self.__http_client.get(BANDIT_ENDPOINT) + + def store_flags(self, flag_data) -> Dict[str, Flag]: + flag_config_dict = cast(dict, flag_data.get("flags", {})) + flag_configs = {key: Flag(**config) for key, config in flag_config_dict.items()} + self.__flag_config_store.set_configurations(flag_configs) + return flag_configs + + def store_bandits(self, bandit_data) -> Dict[str, BanditData]: + print(bandit_data) + bandit_configs = { + config["banditKey"]: BanditData(**config) + for config in cast(dict, bandit_data.get("bandits", [])) + } + print(bandit_configs) + self.__bandit_config_store.set_configurations(bandit_configs) + return bandit_configs + + def fetch_and_store_configurations(self): + print("fetch and store...") try: - configs_dict = cast( - dict, self.__http_client.get(UFC_ENDPOINT).get("flags", {}) - ) - configs = {key: Flag(**config) for key, config in configs_dict.items()} - self.__config_store.set_configurations(configs) + flag_data = self.fetch_flags() + self.store_flags(flag_data) + + if flag_data.get("bandits", {}): + bandit_data = self.fetch_bandits() + self.store_bandits(bandit_data) self.__is_initialized = True - return configs except Exception as e: - logger.error("Error retrieving flag configurations: " + str(e)) - return {} + logger.error("Error retrieving configurations: " + str(e)) + print("... done") def is_initialized(self): return self.__is_initialized diff --git a/eppo_client/configuration_store.py b/eppo_client/configuration_store.py index bdd57eb..87b7be5 100644 --- a/eppo_client/configuration_store.py +++ b/eppo_client/configuration_store.py @@ -20,4 +20,4 @@ def set_configurations(self, configs: Dict[str, T]): def get_keys(self): with self.__lock.reader(): - return list(self.__cache.keys()) + return set(self.__cache.keys()) diff --git a/eppo_client/models.py b/eppo_client/models.py index e119753..ea4e75d 100644 --- a/eppo_client/models.py +++ b/eppo_client/models.py @@ -54,16 +54,6 @@ class Flag(SdkBaseModel): total_shards: int = 10_000 -class Attributes(SdkBaseModel): - numeric_attributes: Dict[str, float] - categorical_attributes: Dict[str, str] - - -class ActionContext(SdkBaseModel): - action_key: str - attributes: Attributes - - class BanditVariation(SdkBaseModel): key: str flag_key: str @@ -77,6 +67,11 @@ class BanditNumericAttributeCoefficient(SdkBaseModel): missing_value_coefficient: float +class ValueCoefficient(SdkBaseModel): + value: str + coefficient: float + + class BanditCategoricalAttributeCoefficient(SdkBaseModel): attribute_key: str missing_value_coefficient: float diff --git a/test/bandit_test.py b/test/bandit_test.py index 3fb56cb..5c5e822 100644 --- a/test/bandit_test.py +++ b/test/bandit_test.py @@ -1,14 +1,21 @@ import pytest +from eppo_client.sharders import MD5Sharder, DeterministicSharder + from eppo_client.bandit import ( score_numeric_attributes, score_categorical_attributes, - weight_actions, + BanditEvaluator, ) from eppo_client.models import ( + BanditCoefficients, + BanditModelData, BanditNumericAttributeCoefficient, BanditCategoricalAttributeCoefficient, ) +from eppo_client.bandit import BanditEvaluator, ActionContext, Attributes + +bandit_evaluator = BanditEvaluator(MD5Sharder(), 10_000) def test_score_numeric_attributes_all_present(): @@ -158,14 +165,17 @@ def test_weight_actions_single_action(): gamma = 0.1 probability_floor = 0.1 expected_weights = [("action1", 1.0)] - assert weight_actions(action_scores, gamma, probability_floor) == expected_weights + assert ( + bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) + == expected_weights + ) def test_weight_actions_multiple_actions(): action_scores = [("action1", 1.0), ("action2", 0.5)] gamma = 0.1 probability_floor = 0.1 - weights = weight_actions(action_scores, gamma, probability_floor) + weights = bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) assert len(weights) == 2 assert any(action == "action1" and weight > 0.5 for action, weight in weights) assert any(action == "action2" and weight <= 0.5 for action, weight in weights) @@ -175,7 +185,7 @@ def test_weight_actions_probability_floor(): action_scores = [("action1", 1.0), ("action2", 0.5), ("action3", 0.2)] gamma = 0.1 probability_floor = 0.3 - weights = weight_actions(action_scores, gamma, probability_floor) + weights = bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) assert len(weights) == 3 for action, weight in weights: assert weight >= 0.1 @@ -185,7 +195,7 @@ def test_weight_actions_gamma_effect(): action_scores = [("action1", 1.0), ("action2", 0.5)] gamma = 1.0 probability_floor = 0.1 - weights = weight_actions(action_scores, gamma, probability_floor) + weights = bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) assert len(weights) == 2 assert any(action == "action1" and weight > 0.5 for action, weight in weights) assert any(action == "action2" and weight <= 0.5 for action, weight in weights) @@ -195,7 +205,116 @@ def test_weight_actions_all_equal_scores(): action_scores = [("action1", 1.0), ("action2", 1.0), ("action3", 1.0)] gamma = 0.1 probability_floor = 0.1 - weights = weight_actions(action_scores, gamma, probability_floor) + weights = bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) assert len(weights) == 3 for _, weight in weights: assert weight == pytest.approx(1.0 / 3, rel=1e-2) + + +def test_evaluate_bandit(): + # Mock data + flag_key = "test_flag" + subject_key = "test_subject" + subject_attributes = Attributes( + numeric_attributes={"age": 25.0}, categorical_attributes={"location": "US"} + ) + action_contexts = [ + ActionContext( + action_key="action1", + attributes=Attributes( + numeric_attributes={"price": 10.0}, + categorical_attributes={"category": "A"}, + ), + ), + ActionContext( + action_key="action2", + attributes=Attributes( + numeric_attributes={"price": 20.0}, + categorical_attributes={"category": "B"}, + ), + ), + ] + coefficients = { + "action1": BanditCoefficients( + action_key="action1", + intercept=0.5, + subject_numeric_coefficients=[ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=0.1, missing_value_coefficient=0.0 + ) + ], + subject_categorical_coefficients=[ + BanditCategoricalAttributeCoefficient( + attribute_key="location", + missing_value_coefficient=0.0, + value_coefficients={"US": 0.2}, + ) + ], + action_numeric_coefficients=[ + BanditNumericAttributeCoefficient( + attribute_key="price", + coefficient=0.05, + missing_value_coefficient=0.0, + ) + ], + action_categorical_coefficients=[ + BanditCategoricalAttributeCoefficient( + attribute_key="category", + missing_value_coefficient=0.0, + value_coefficients={"A": 0.3}, + ) + ], + ), + "action2": BanditCoefficients( + action_key="action2", + intercept=0.3, + subject_numeric_coefficients=[ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=0.1, missing_value_coefficient=0.0 + ) + ], + subject_categorical_coefficients=[ + BanditCategoricalAttributeCoefficient( + attribute_key="location", + missing_value_coefficient=0.0, + value_coefficients={"US": 0.2}, + ) + ], + action_numeric_coefficients=[ + BanditNumericAttributeCoefficient( + attribute_key="price", + coefficient=0.05, + missing_value_coefficient=0.0, + ) + ], + action_categorical_coefficients=[ + BanditCategoricalAttributeCoefficient( + attribute_key="category", + missing_value_coefficient=0.0, + value_coefficients={"B": 0.3}, + ) + ], + ), + } + bandit_model = BanditModelData( + gamma=0.1, + default_action_score=0.0, + action_probability_floor=0.1, + coefficients=coefficients, + ) + + evaluator = BanditEvaluator(sharder=DeterministicSharder({})) + + # Evaluate bandit + evaluation = evaluator.evaluate_bandit( + flag_key, subject_key, subject_attributes, action_contexts, bandit_model + ) + + # Assertions + assert evaluation.flag_key == flag_key + assert evaluation.subject_key == subject_key + assert evaluation.subject_attributes == subject_attributes + assert evaluation.action_key == "action1" + assert evaluation.gamma == bandit_model.gamma + assert evaluation.action_score == 4.0 + assert pytest.approx(evaluation.action_weight, rel=1e-2) == 0.4926 diff --git a/test/client_bandit_test.py b/test/client_bandit_test.py new file mode 100644 index 0000000..f50c93e --- /dev/null +++ b/test/client_bandit_test.py @@ -0,0 +1,115 @@ +# Note: contains tests for client.py related to bandits to avoid +# making client_test.py too long. + + +import json +import os +from time import sleep +from typing import Dict +from eppo_client.bandit import BanditResult, ActionContext, Attributes + +import httpretty # type: ignore +import pytest + +from eppo_client.assignment_logger import AssignmentLogger +from eppo_client.configuration_requestor import BANDIT_ENDPOINT, UFC_ENDPOINT +from eppo_client import init, get_instance +from eppo_client.config import Config + +TEST_DIR = "test/test-data/ufc/bandit-tests" +FLAG_CONFIG_FILE = "test/test-data/ufc/bandit-flags-v1.json" +BANDIT_CONFIG_FILE = "test/test-data/ufc/bandit-models-v1.json" +test_data = [] +for file_name in [file for file in os.listdir(TEST_DIR)]: + with open("{}/{}".format(TEST_DIR, file_name)) as test_case_json: + test_case_dict = json.load(test_case_json) + test_data.append(test_case_dict) + +MOCK_BASE_URL = "http://localhost:4001/api" + +DEFAULT_SUBJECT_ATTRIBUTES = Attributes( + numeric_attributes={"age": 30}, categorical_attributes={"country": "UK"} +) + + +class MockAssignmentLogger(AssignmentLogger): + def log_assignment(self, assignment_event: Dict): + print(f"Assignment Event: {assignment_event}") + + def log_bandit_action(self, bandit_event: Dict): + print(f"Bandit Event: {bandit_event}") + + +@pytest.fixture(scope="session", autouse=True) +def init_fixture(): + httpretty.enable() + with open(FLAG_CONFIG_FILE) as mock_ufc_response: + ufc_json = json.load(mock_ufc_response) + + with open(BANDIT_CONFIG_FILE) as mock_bandit_response: + bandit_json = json.load(mock_bandit_response) + + httpretty.register_uri( + httpretty.GET, + MOCK_BASE_URL + UFC_ENDPOINT, + body=json.dumps(ufc_json), + ) + httpretty.register_uri( + httpretty.GET, + MOCK_BASE_URL + BANDIT_ENDPOINT, + body=json.dumps(bandit_json), + ) + client = init( + Config( + base_url=MOCK_BASE_URL, + api_key="dummy", + assignment_logger=AssignmentLogger(), + ) + ) + sleep(0.1) # wait for initialization + print(client.get_flag_keys()) + print(client.get_bandit_keys()) + yield + client._shutdown() + httpretty.disable() + httpretty.reset() + + +def test_is_initialized(): + client = get_instance() + assert client.is_initialized(), "Client should be initialized" + + +def test_get_bandit_action_bandit_does_not_exist(): + client = get_instance() + result = client.get_bandit_action( + "nonexistent_bandit", + "subject_key", + DEFAULT_SUBJECT_ATTRIBUTES, + [], + "default_variation", + ) + print(result) + assert result == BanditResult(None, "default_variation") + + +def test_get_bandit_action_flag_without_bandit(): + client = get_instance() + result = client.get_bandit_action( + "a_flag", "subject_key", DEFAULT_SUBJECT_ATTRIBUTES, [], "default_variation" + ) + assert result == BanditResult(None, "default_variation") + + +def test_get_bandit_action_with_subject_attributes(): + # tests that allocation filtering based on subject attributes works correctly + client = get_instance() + result = client.get_bandit_action( + "banner_bandit_flag_uk_only", + "subject_key", + DEFAULT_SUBJECT_ATTRIBUTES, + [ActionContext.create("adidas", {}, {}), ActionContext.create("nike", {}, {})], + "default_variation", + ) + assert result.assignment == "banner_bandit" + assert result.action in ["adidas", "nike"] diff --git a/test/configuration_store_test.py b/test/configuration_store_test.py index 72eef3f..3bda8d0 100644 --- a/test/configuration_store_test.py +++ b/test/configuration_store_test.py @@ -26,6 +26,14 @@ def test_get_configuration_known_key(): assert store.get_configuration("flag") == mock_flag +def test_get_keys(): + store.set_configurations({"flag1": mock_flag, "flag2": mock_flag}) + keys = store.get_keys() + assert "flag1" in keys + assert "flag2" in keys + assert len(keys) == 2 + + def test_evicts_old_entries_when_max_size_exceeded(): store.set_configurations({"item_to_be_evicted": mock_flag}) assert store.get_configuration("item_to_be_evicted") == mock_flag From d4b9d98bac5268447470e360f8a24b01549f31c5 Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Tue, 28 May 2024 16:00:41 -0700 Subject: [PATCH 04/10] more tests --- eppo_client/bandit.py | 31 ++++++++++++++++++++---- eppo_client/client.py | 18 +++++++------- test/client_bandit_test.py | 48 +++++++++++++++++++++++++++++++++++--- 3 files changed, 79 insertions(+), 18 deletions(-) diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py index 34623ff..26adc73 100644 --- a/eppo_client/bandit.py +++ b/eppo_client/bandit.py @@ -27,6 +27,17 @@ def create( numeric_attributes: Dict[str, float], categorical_attributes: Dict[str, str], ): + """ + Create an instance of ActionContext. + + Args: + action_key (str): The key representing the action. + numeric_attributes (Dict[str, float]): A dictionary of numeric attributes. + categorical_attributes (Dict[str, str]): A dictionary of categorical attributes. + + Returns: + ActionContext: An instance of ActionContext with the provided action key and attributes. + """ return cls( action_key, Attributes( @@ -58,12 +69,12 @@ class BanditEvaluation: @dataclass class BanditResult: + variation: str action: str - assignment: str def null_evaluation( - flag_key: str, subject_key: str, subject_attributes: Attributes, gamma + flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float ): return BanditEvaluation( flag_key, @@ -217,15 +228,22 @@ def score_action( return score +def coalesce(value, default=0): + return value if value is not None else default + + def score_numeric_attributes( coefficients: List[BanditNumericAttributeCoefficient], attributes: Dict[str, float], ) -> float: score = 0.0 - print(coefficients) for coefficient in coefficients: + print(coefficient, attributes) if coefficient.attribute_key in attributes: - score += coefficient.coefficient * attributes[coefficient.attribute_key] + score += coefficient.coefficient * coalesce( + attributes[coefficient.attribute_key], + coefficient.missing_value_coefficient, + ) else: score += coefficient.missing_value_coefficient @@ -240,7 +258,10 @@ def score_categorical_attributes( for coefficient in coefficients: if coefficient.attribute_key in attributes: score += coefficient.value_coefficients.get( - attributes[coefficient.attribute_key], + coalesce( + attributes[coefficient.attribute_key], + coefficient.missing_value_coefficient, + ), coefficient.missing_value_coefficient, ) else: diff --git a/eppo_client/client.py b/eppo_client/client.py index 29a2612..9208745 100644 --- a/eppo_client/client.py +++ b/eppo_client/client.py @@ -254,24 +254,22 @@ def get_bandit_action( - assignment (str): The assignment key indicating the subject's variation. """ # get experiment assignment - assignment = self.get_string_assignment( + variation = self.get_string_assignment( flag_key, subject_key, subject_attributes.categorical_attributes, default ) - print(flag_key, assignment) + # if the variation is not the bandit key, then the subject is not allocated in the bandit + if variation not in self.get_bandit_keys(): + return BanditResult(variation, None) - # if the assignment is not the bandit key, then the subject is not allocated in the bandit - if assignment not in self.get_bandit_keys(): - return BanditResult(None, assignment) - - # for now, assume that the assignment is equal to the bandit key - bandit_data = self.__config_requestor.get_bandit_model(assignment) + # for now, assume that the variation is equal to the bandit key + bandit_data = self.__config_requestor.get_bandit_model(variation) if not bandit_data: logger.warning( f"[Eppo SDK] No assigned action. Bandit not found for flag: {flag_key}" ) - return BanditResult(None, assignment) + return BanditResult(variation, None) evaluation = self.__bandit_evaluator.evaluate_bandit( flag_key, @@ -308,7 +306,7 @@ def get_bandit_action( } self.__assignment_logger.log_bandit_action(bandit_event) - return BanditResult(evaluation.action_key if evaluation else None, assignment) + return BanditResult(variation, evaluation.action_key if evaluation else None) def get_flag_keys(self): """ diff --git a/test/client_bandit_test.py b/test/client_bandit_test.py index f50c93e..ce98584 100644 --- a/test/client_bandit_test.py +++ b/test/client_bandit_test.py @@ -25,6 +25,8 @@ test_case_dict = json.load(test_case_json) test_data.append(test_case_dict) +print(test_data) + MOCK_BASE_URL = "http://localhost:4001/api" DEFAULT_SUBJECT_ATTRIBUTES = Attributes( @@ -90,7 +92,7 @@ def test_get_bandit_action_bandit_does_not_exist(): "default_variation", ) print(result) - assert result == BanditResult(None, "default_variation") + assert result == BanditResult("default_variation", None) def test_get_bandit_action_flag_without_bandit(): @@ -98,7 +100,7 @@ def test_get_bandit_action_flag_without_bandit(): result = client.get_bandit_action( "a_flag", "subject_key", DEFAULT_SUBJECT_ATTRIBUTES, [], "default_variation" ) - assert result == BanditResult(None, "default_variation") + assert result == BanditResult("default_variation", None) def test_get_bandit_action_with_subject_attributes(): @@ -111,5 +113,45 @@ def test_get_bandit_action_with_subject_attributes(): [ActionContext.create("adidas", {}, {}), ActionContext.create("nike", {}, {})], "default_variation", ) - assert result.assignment == "banner_bandit" + assert result.variation == "banner_bandit" assert result.action in ["adidas", "nike"] + + +@pytest.mark.parametrize("test_case", test_data) +def test_bandit_generic_test_cases(test_case): + client = get_instance() + + flag = test_case["flag"] + default_value = test_case["defaultValue"] + + for subject in test_case["subjects"]: + result = client.get_bandit_action( + flag, + subject["subjectKey"], + Attributes( + numeric_attributes=subject["subjectAttributes"]["numeric_attributes"], + categorical_attributes=subject["subjectAttributes"][ + "categorical_attributes" + ], + ), + [ + ActionContext.create( + action["actionKey"], + action["numericAttributes"], + action["categoricalAttributes"], + ) + for action in subject["actions"] + ], + default_value, + ) + + expected_result = BanditResult( + subject["assignment"]["variation"], subject["assignment"]["action"] + ) + + assert ( + result.variation == expected_result.variation + ), f"Flag {flag} failed for subject {subject['subjectKey']}: expected assignment {expected_result.variation}, got {result.variation}" + assert ( + result.action == expected_result.action + ), f"Flag {flag} failed for subject {subject['subjectKey']}: expected action {expected_result.action}, got {result.action}" From c7be8fca09430a7794cae304a6e30ab856afc639 Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Tue, 28 May 2024 16:15:24 -0700 Subject: [PATCH 05/10] :broom: --- eppo_client/bandit.py | 8 +++++--- eppo_client/client.py | 17 ++++++++++++----- test/bandit_test.py | 3 ++- test/client_bandit_test.py | 14 ++++++++------ 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py index 26adc73..b554d16 100644 --- a/eppo_client/bandit.py +++ b/eppo_client/bandit.py @@ -70,7 +70,7 @@ class BanditEvaluation: @dataclass class BanditResult: variation: str - action: str + action: Optional[str] def null_evaluation( @@ -180,7 +180,7 @@ def weight_actions( weights.append((best_action, remaining_weight)) return weights - def select_action(self, flag_key, subject_key, action_weights) -> Tuple[str, float]: + def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str]: # deterministic ordering sorted_action_weights = sorted( action_weights, @@ -200,7 +200,9 @@ def select_action(self, flag_key, subject_key, action_weights) -> Tuple[str, flo return idx, action_key # If no action is selected, return the last action (fallback) - return (len(sorted_action_weights) - 1, sorted_action_weights[-1][0]) + action_index = len(sorted_action_weights) - 1 + action_key = sorted_action_weights[action_index][0] + return action_index, action_key def score_action( diff --git a/eppo_client/client.py b/eppo_client/client.py index 9208745..64736a7 100644 --- a/eppo_client/client.py +++ b/eppo_client/client.py @@ -254,8 +254,9 @@ def get_bandit_action( - assignment (str): The assignment key indicating the subject's variation. """ # get experiment assignment + # ignoring type because Dict[str, str] satisfies Dict[str, str | ...] but mypy does not understand variation = self.get_string_assignment( - flag_key, subject_key, subject_attributes.categorical_attributes, default + flag_key, subject_key, subject_attributes.categorical_attributes, default # type: ignore ) # if the variation is not the bandit key, then the subject is not allocated in the bandit @@ -289,17 +290,23 @@ def get_bandit_action( "modelVersion": bandit_data.model_version if evaluation else None, "timestamp": datetime.datetime.utcnow().isoformat(), "subjectNumericAttributes": ( - subject_attributes.numeric_attributes if evaluation else None + subject_attributes.numeric_attributes + if evaluation.subject_attributes + else None ), "subjectCategoricalAttributes": ( - subject_attributes.categorical_attributes if evaluation else None + subject_attributes.categorical_attributes + if evaluation.subject_attributes + else None ), "actionNumericAttributes": ( - evaluation.action_attributes.numeric_attributes if evaluation else None + evaluation.action_attributes.numeric_attributes + if evaluation.action_attributes + else None ), "actionCategoricalAttributes": ( evaluation.action_attributes.categorical_attributes - if evaluation + if evaluation.action_attributes else None ), "metaData": {"sdkLanguage": "python", "sdkVersion": __version__}, diff --git a/test/bandit_test.py b/test/bandit_test.py index 5c5e822..959433c 100644 --- a/test/bandit_test.py +++ b/test/bandit_test.py @@ -3,6 +3,8 @@ from eppo_client.sharders import MD5Sharder, DeterministicSharder from eppo_client.bandit import ( + ActionContext, + Attributes, score_numeric_attributes, score_categorical_attributes, BanditEvaluator, @@ -13,7 +15,6 @@ BanditNumericAttributeCoefficient, BanditCategoricalAttributeCoefficient, ) -from eppo_client.bandit import BanditEvaluator, ActionContext, Attributes bandit_evaluator = BanditEvaluator(MD5Sharder(), 10_000) diff --git a/test/client_bandit_test.py b/test/client_bandit_test.py index ce98584..ae19636 100644 --- a/test/client_bandit_test.py +++ b/test/client_bandit_test.py @@ -149,9 +149,11 @@ def test_bandit_generic_test_cases(test_case): subject["assignment"]["variation"], subject["assignment"]["action"] ) - assert ( - result.variation == expected_result.variation - ), f"Flag {flag} failed for subject {subject['subjectKey']}: expected assignment {expected_result.variation}, got {result.variation}" - assert ( - result.action == expected_result.action - ), f"Flag {flag} failed for subject {subject['subjectKey']}: expected action {expected_result.action}, got {result.action}" + assert result.variation == expected_result.variation, ( + f"Flag {flag} failed for subject {subject['subjectKey']}:" + f"expected assignment {expected_result.variation}, got {result.variation}" + ) + assert result.action == expected_result.action, ( + f"Flag {flag} failed for subject {subject['subjectKey']}:" + f"expected action {expected_result.action}, got {result.action}" + ) From 11c18f9208e645f1d4b3cd1407b35f208d48d244 Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Tue, 28 May 2024 16:15:42 -0700 Subject: [PATCH 06/10] bump version --- eppo_client/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eppo_client/version.py b/eppo_client/version.py index 8d1c862..f5f41e5 100644 --- a/eppo_client/version.py +++ b/eppo_client/version.py @@ -1 +1 @@ -__version__ = "3.0.3" +__version__ = "3.1.0" From ccc4671d980fcbb3a857e9323c9c49126853e7ba Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Wed, 29 May 2024 13:43:59 -0700 Subject: [PATCH 07/10] address Giorgio's comments --- eppo_client/bandit.py | 20 +++++------ eppo_client/client.py | 2 +- eppo_client/configuration_requestor.py | 3 -- test/bandit_test.py | 46 ++++++++++++++++---------- 4 files changed, 38 insertions(+), 33 deletions(-) diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py index b554d16..bd7bae6 100644 --- a/eppo_client/bandit.py +++ b/eppo_client/bandit.py @@ -111,7 +111,7 @@ def evaluate_bandit( subject_attributes, actions_with_contexts, bandit_model ) - action_weights = self.weight_actions( + action_weights = self.weigh_actions( action_scores, bandit_model.gamma, bandit_model.action_probability_floor, @@ -153,7 +153,7 @@ def score_actions( for action_context in actions_with_contexts ] - def weight_actions( + def weigh_actions( self, action_scores, gamma, probability_floor ) -> List[Tuple[str, float]]: number_of_actions = len(action_scores) @@ -240,12 +240,11 @@ def score_numeric_attributes( ) -> float: score = 0.0 for coefficient in coefficients: - print(coefficient, attributes) - if coefficient.attribute_key in attributes: - score += coefficient.coefficient * coalesce( - attributes[coefficient.attribute_key], - coefficient.missing_value_coefficient, - ) + if ( + coefficient.attribute_key in attributes + and attributes[coefficient.attribute_key] is not None + ): + score += coefficient.coefficient * attributes[coefficient.attribute_key] else: score += coefficient.missing_value_coefficient @@ -260,10 +259,7 @@ def score_categorical_attributes( for coefficient in coefficients: if coefficient.attribute_key in attributes: score += coefficient.value_coefficients.get( - coalesce( - attributes[coefficient.attribute_key], - coefficient.missing_value_coefficient, - ), + attributes[coefficient.attribute_key], coefficient.missing_value_coefficient, ) else: diff --git a/eppo_client/client.py b/eppo_client/client.py index 64736a7..f80ba9f 100644 --- a/eppo_client/client.py +++ b/eppo_client/client.py @@ -250,8 +250,8 @@ def get_bandit_action( Returns: BanditResult: The result containing either the bandit action if the subject is part of the bandit, or the assignment if they are not. The BanditResult includes: + - variation (str): The assignment key indicating the subject's variation. - action (str): The key of the selected action if the subject is part of the bandit. - - assignment (str): The assignment key indicating the subject's variation. """ # get experiment assignment # ignoring type because Dict[str, str] satisfies Dict[str, str | ...] but mypy does not understand diff --git a/eppo_client/configuration_requestor.py b/eppo_client/configuration_requestor.py index 3dfc170..3211475 100644 --- a/eppo_client/configuration_requestor.py +++ b/eppo_client/configuration_requestor.py @@ -52,7 +52,6 @@ def store_flags(self, flag_data) -> Dict[str, Flag]: return flag_configs def store_bandits(self, bandit_data) -> Dict[str, BanditData]: - print(bandit_data) bandit_configs = { config["banditKey"]: BanditData(**config) for config in cast(dict, bandit_data.get("bandits", [])) @@ -62,7 +61,6 @@ def store_bandits(self, bandit_data) -> Dict[str, BanditData]: return bandit_configs def fetch_and_store_configurations(self): - print("fetch and store...") try: flag_data = self.fetch_flags() self.store_flags(flag_data) @@ -73,7 +71,6 @@ def fetch_and_store_configurations(self): self.__is_initialized = True except Exception as e: logger.error("Error retrieving configurations: " + str(e)) - print("... done") def is_initialized(self): return self.__is_initialized diff --git a/test/bandit_test.py b/test/bandit_test.py index 959433c..d34887a 100644 --- a/test/bandit_test.py +++ b/test/bandit_test.py @@ -161,52 +161,64 @@ def test_score_categorical_attributes_mixed_coefficients(): assert score_categorical_attributes(coefficients, attributes) == expected_score -def test_weight_actions_single_action(): +def test_weigh_actions_single_action(): action_scores = [("action1", 1.0)] gamma = 0.1 probability_floor = 0.1 expected_weights = [("action1", 1.0)] assert ( - bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) + bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) == expected_weights ) -def test_weight_actions_multiple_actions(): +def test_weigh_actions_multiple_actions(): action_scores = [("action1", 1.0), ("action2", 0.5)] - gamma = 0.1 + gamma = 10 probability_floor = 0.1 - weights = bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) + weights = bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) assert len(weights) == 2 - assert any(action == "action1" and weight > 0.5 for action, weight in weights) - assert any(action == "action2" and weight <= 0.5 for action, weight in weights) + action_1_weight = next(weight for action, weight in weights if action == "action1") + assert action_1_weight == pytest.approx(6 / 7, rel=1e-6) + action_2_weight = next(weight for action, weight in weights if action == "action2") + assert action_2_weight == pytest.approx(1 / 7, rel=1e-6) def test_weight_actions_probability_floor(): action_scores = [("action1", 1.0), ("action2", 0.5), ("action3", 0.2)] - gamma = 0.1 + gamma = 10 probability_floor = 0.3 - weights = bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) + weights = bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) assert len(weights) == 3 - for action, weight in weights: - assert weight >= 0.1 + for _, weight in weights: + assert weight == pytest.approx(0.1, rel=1e-6) or weight > 0.1 def test_weight_actions_gamma_effect(): action_scores = [("action1", 1.0), ("action2", 0.5)] - gamma = 1.0 + small_gamma = 1.0 + large_gamma = 10.0 probability_floor = 0.1 - weights = bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) - assert len(weights) == 2 - assert any(action == "action1" and weight > 0.5 for action, weight in weights) - assert any(action == "action2" and weight <= 0.5 for action, weight in weights) + weights_small_gamma = bandit_evaluator.weigh_actions( + action_scores, small_gamma, probability_floor + ) + weights_large_gamma = bandit_evaluator.weigh_actions( + action_scores, large_gamma, probability_floor + ) + + assert next( + weight for action, weight in weights_small_gamma if action == "action1" + ) < next(weight for action, weight in weights_large_gamma if action == "action1") + assert next( + weight for action, weight in weights_small_gamma if action == "action2" + ) > next(weight for action, weight in weights_large_gamma if action == "action2") def test_weight_actions_all_equal_scores(): action_scores = [("action1", 1.0), ("action2", 1.0), ("action3", 1.0)] gamma = 0.1 probability_floor = 0.1 - weights = bandit_evaluator.weight_actions(action_scores, gamma, probability_floor) + weights = bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) assert len(weights) == 3 for _, weight in weights: assert weight == pytest.approx(1.0 / 3, rel=1e-2) From 064a018979c1782b9f2033efc3c620974f316bfb Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Wed, 29 May 2024 13:46:27 -0700 Subject: [PATCH 08/10] add comment to probability floor test --- test/bandit_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/bandit_test.py b/test/bandit_test.py index d34887a..a69516d 100644 --- a/test/bandit_test.py +++ b/test/bandit_test.py @@ -190,6 +190,8 @@ def test_weight_actions_probability_floor(): probability_floor = 0.3 weights = bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) assert len(weights) == 3 + + # note probability floor is normalized by number of actions: 0.3/3 = 0.1 for _, weight in weights: assert weight == pytest.approx(0.1, rel=1e-6) or weight > 0.1 From 99307986f1f0093b88027d6dc022a14d6c5309cf Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Wed, 29 May 2024 16:38:37 -0700 Subject: [PATCH 09/10] :broom: --- eppo_client/configuration_requestor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eppo_client/configuration_requestor.py b/eppo_client/configuration_requestor.py index 3211475..4dedc91 100644 --- a/eppo_client/configuration_requestor.py +++ b/eppo_client/configuration_requestor.py @@ -56,7 +56,6 @@ def store_bandits(self, bandit_data) -> Dict[str, BanditData]: config["banditKey"]: BanditData(**config) for config in cast(dict, bandit_data.get("bandits", [])) } - print(bandit_configs) self.__bandit_config_store.set_configurations(bandit_configs) return bandit_configs @@ -67,6 +66,7 @@ def fetch_and_store_configurations(self): if flag_data.get("bandits", {}): bandit_data = self.fetch_bandits() + print(bandit_data) self.store_bandits(bandit_data) self.__is_initialized = True except Exception as e: From 0e3d6b534d75b9e58f5bc2993939aef381ec4374 Mon Sep 17 00:00:00 2001 From: Sven Schmit Date: Wed, 29 May 2024 22:09:34 -0700 Subject: [PATCH 10/10] address Aaron's comments --- eppo_client/bandit.py | 27 +++++++++++++++++++++------ eppo_client/client.py | 30 ++++++++++++++++++++++++++---- test/client_bandit_test.py | 4 ---- 3 files changed, 47 insertions(+), 14 deletions(-) diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py index bd7bae6..581a3bf 100644 --- a/eppo_client/bandit.py +++ b/eppo_client/bandit.py @@ -1,5 +1,7 @@ from dataclasses import dataclass +import logging from typing import Dict, List, Optional, Tuple + from eppo_client.models import ( BanditCategoricalAttributeCoefficient, BanditCoefficients, @@ -9,6 +11,13 @@ from eppo_client.sharders import Sharder +logger = logging.getLogger(__name__) + + +class BanditEvaluationError(Exception): + pass + + @dataclass class Attributes: numeric_attributes: Dict[str, float] @@ -72,6 +81,9 @@ class BanditResult: variation: str action: Optional[str] + def to_string(self) -> str: + return coalesce(self.action, self.variation) + def null_evaluation( flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float @@ -176,7 +188,7 @@ def weigh_actions( ] # remaining weight goes to best action - remaining_weight = 1.0 - sum(weight for _, weight in weights) + remaining_weight = max(0.0, 1.0 - sum(weight for _, weight in weights)) weights.append((best_action, remaining_weight)) return weights @@ -184,8 +196,11 @@ def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str # deterministic ordering sorted_action_weights = sorted( action_weights, - key=lambda t: self.sharder.get_shard( - f"{flag_key}-{subject_key}-{t[0]}", self.total_shards + key=lambda t: ( + self.sharder.get_shard( + f"{flag_key}-{subject_key}-{t[0]}", self.total_shards + ), + t[0], # tie-break using action name ), ) @@ -200,9 +215,9 @@ def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str return idx, action_key # If no action is selected, return the last action (fallback) - action_index = len(sorted_action_weights) - 1 - action_key = sorted_action_weights[action_index][0] - return action_index, action_key + raise BanditEvaluationError( + f"[Eppo SDK] No action selected for {flag_key} {subject_key}" + ) def score_action( diff --git a/eppo_client/client.py b/eppo_client/client.py index f80ba9f..b4f5790 100644 --- a/eppo_client/client.py +++ b/eppo_client/client.py @@ -253,6 +253,28 @@ def get_bandit_action( - variation (str): The assignment key indicating the subject's variation. - action (str): The key of the selected action if the subject is part of the bandit. """ + try: + return self.get_bandit_action_detail( + flag_key, + subject_key, + subject_attributes, + actions_with_contexts, + default, + ) + except Exception as e: + if self.__is_graceful_mode: + logger.error("[Eppo SDK] Error getting bandit action: " + str(e)) + return BanditResult(default, None) + raise e + + def get_bandit_action_detail( + self, + flag_key: str, + subject_key: str, + subject_attributes: Attributes, + actions_with_contexts: List[ActionContext], + default: str, + ) -> BanditResult: # get experiment assignment # ignoring type because Dict[str, str] satisfies Dict[str, str | ...] but mypy does not understand variation = self.get_string_assignment( @@ -292,22 +314,22 @@ def get_bandit_action( "subjectNumericAttributes": ( subject_attributes.numeric_attributes if evaluation.subject_attributes - else None + else {} ), "subjectCategoricalAttributes": ( subject_attributes.categorical_attributes if evaluation.subject_attributes - else None + else {} ), "actionNumericAttributes": ( evaluation.action_attributes.numeric_attributes if evaluation.action_attributes - else None + else {} ), "actionCategoricalAttributes": ( evaluation.action_attributes.categorical_attributes if evaluation.action_attributes - else None + else {} ), "metaData": {"sdkLanguage": "python", "sdkVersion": __version__}, } diff --git a/test/client_bandit_test.py b/test/client_bandit_test.py index ae19636..d13038d 100644 --- a/test/client_bandit_test.py +++ b/test/client_bandit_test.py @@ -25,7 +25,6 @@ test_case_dict = json.load(test_case_json) test_data.append(test_case_dict) -print(test_data) MOCK_BASE_URL = "http://localhost:4001/api" @@ -69,8 +68,6 @@ def init_fixture(): ) ) sleep(0.1) # wait for initialization - print(client.get_flag_keys()) - print(client.get_bandit_keys()) yield client._shutdown() httpretty.disable() @@ -91,7 +88,6 @@ def test_get_bandit_action_bandit_does_not_exist(): [], "default_variation", ) - print(result) assert result == BanditResult("default_variation", None)