diff --git a/eppo_client/__init__.py b/eppo_client/__init__.py index 884ae50..91ee517 100644 --- a/eppo_client/__init__.py +++ b/eppo_client/__init__.py @@ -6,7 +6,7 @@ ) from eppo_client.configuration_store import ConfigurationStore from eppo_client.http_client import HttpClient, SdkParams -from eppo_client.models import Flag +from eppo_client.models import BanditData, Flag from eppo_client.read_write_lock import ReadWriteLock from eppo_client.version import __version__ @@ -30,9 +30,12 @@ def init(config: Config) -> EppoClient: apiKey=config.api_key, sdkName="python", sdkVersion=__version__ ) http_client = HttpClient(base_url=config.base_url, sdk_params=sdk_params) - config_store: ConfigurationStore[Flag] = ConfigurationStore() + flag_config_store: ConfigurationStore[Flag] = ConfigurationStore() + bandit_config_store: ConfigurationStore[BanditData] = ConfigurationStore() config_requestor = ExperimentConfigurationRequestor( - http_client=http_client, config_store=config_store + http_client=http_client, + flag_config_store=flag_config_store, + bandit_config_store=bandit_config_store, ) assignment_logger = config.assignment_logger is_graceful_mode = config.is_graceful_mode diff --git a/eppo_client/assignment_logger.py b/eppo_client/assignment_logger.py index 309b26a..cc17e76 100644 --- a/eppo_client/assignment_logger.py +++ b/eppo_client/assignment_logger.py @@ -8,3 +8,6 @@ class AssignmentLogger(BaseModel): def log_assignment(self, assignment_event: Dict): pass + + def log_bandit_action(self, bandit_event: Dict): + pass diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py new file mode 100644 index 0000000..581a3bf --- /dev/null +++ b/eppo_client/bandit.py @@ -0,0 +1,282 @@ +from dataclasses import dataclass +import logging +from typing import Dict, List, Optional, Tuple + +from eppo_client.models import ( + BanditCategoricalAttributeCoefficient, + BanditCoefficients, + BanditModelData, + BanditNumericAttributeCoefficient, +) +from eppo_client.sharders import Sharder + + +logger = logging.getLogger(__name__) + + +class BanditEvaluationError(Exception): + pass + + +@dataclass +class Attributes: + numeric_attributes: Dict[str, float] + categorical_attributes: Dict[str, str] + + +@dataclass +class ActionContext: + action_key: str + attributes: Attributes + + @classmethod + def create( + cls, + action_key: str, + numeric_attributes: Dict[str, float], + categorical_attributes: Dict[str, str], + ): + """ + Create an instance of ActionContext. + + Args: + action_key (str): The key representing the action. + numeric_attributes (Dict[str, float]): A dictionary of numeric attributes. + categorical_attributes (Dict[str, str]): A dictionary of categorical attributes. + + Returns: + ActionContext: An instance of ActionContext with the provided action key and attributes. + """ + return cls( + action_key, + Attributes( + numeric_attributes=numeric_attributes, + categorical_attributes=categorical_attributes, + ), + ) + + @property + def numeric_attributes(self): + return self.attributes.numeric_attributes + + @property + def categorical_attributes(self): + return self.attributes.categorical_attributes + + +@dataclass +class BanditEvaluation: + flag_key: str + subject_key: str + subject_attributes: Attributes + action_key: Optional[str] + action_attributes: Optional[Attributes] + action_score: float + action_weight: float + gamma: float + + +@dataclass +class BanditResult: + variation: str + action: Optional[str] + + def to_string(self) -> str: + return coalesce(self.action, self.variation) + + +def null_evaluation( + flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float +): + return BanditEvaluation( + flag_key, + subject_key, + subject_attributes, + None, + None, + 0.0, + 0.0, + gamma, + ) + + +@dataclass +class BanditEvaluator: + sharder: Sharder + total_shards: int = 10_000 + + def evaluate_bandit( + self, + flag_key: str, + subject_key: str, + subject_attributes: Attributes, + actions_with_contexts: List[ActionContext], + bandit_model: BanditModelData, + ) -> BanditEvaluation: + # handle the edge case that there are no actions + if not actions_with_contexts: + return null_evaluation( + flag_key, subject_key, subject_attributes, bandit_model.gamma + ) + + action_scores = self.score_actions( + subject_attributes, actions_with_contexts, bandit_model + ) + + action_weights = self.weigh_actions( + action_scores, + bandit_model.gamma, + bandit_model.action_probability_floor, + ) + + selected_idx, selected_action = self.select_action( + flag_key, subject_key, action_weights + ) + return BanditEvaluation( + flag_key, + subject_key, + subject_attributes, + selected_action, + actions_with_contexts[selected_idx].attributes, + action_scores[selected_idx][1], + action_weights[selected_idx][1], + bandit_model.gamma, + ) + + def score_actions( + self, + subject_attributes: Attributes, + actions_with_contexts: List[ActionContext], + bandit_model: BanditModelData, + ) -> List[Tuple[str, float]]: + return [ + ( + action_context.action_key, + ( + score_action( + subject_attributes, + action_context.attributes, + bandit_model.coefficients[action_context.action_key], + ) + if action_context.action_key in bandit_model.coefficients + else bandit_model.default_action_score + ), + ) + for action_context in actions_with_contexts + ] + + def weigh_actions( + self, action_scores, gamma, probability_floor + ) -> List[Tuple[str, float]]: + number_of_actions = len(action_scores) + best_action, best_score = max(action_scores, key=lambda t: t[1]) + + # adjust probability floor for number of actions to control the sum + min_probability = probability_floor / number_of_actions + + # weight all but the best action + weights = [ + ( + action_key, + max( + min_probability, + 1.0 / (number_of_actions + gamma * (best_score - score)), + ), + ) + for action_key, score in action_scores + if action_key != best_action + ] + + # remaining weight goes to best action + remaining_weight = max(0.0, 1.0 - sum(weight for _, weight in weights)) + weights.append((best_action, remaining_weight)) + return weights + + def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str]: + # deterministic ordering + sorted_action_weights = sorted( + action_weights, + key=lambda t: ( + self.sharder.get_shard( + f"{flag_key}-{subject_key}-{t[0]}", self.total_shards + ), + t[0], # tie-break using action name + ), + ) + + # select action based on weights + shard = self.sharder.get_shard(f"{flag_key}-{subject_key}", self.total_shards) + cumulative_weight = 0.0 + shard_value = shard / self.total_shards + + for idx, (action_key, weight) in enumerate(sorted_action_weights): + cumulative_weight += weight + if cumulative_weight > shard_value: + return idx, action_key + + # If no action is selected, return the last action (fallback) + raise BanditEvaluationError( + f"[Eppo SDK] No action selected for {flag_key} {subject_key}" + ) + + +def score_action( + subject_attributes: Attributes, + action_attributes: Attributes, + coefficients: BanditCoefficients, +) -> float: + score = coefficients.intercept + score += score_numeric_attributes( + coefficients.subject_numeric_coefficients, + subject_attributes.numeric_attributes, + ) + score += score_categorical_attributes( + coefficients.subject_categorical_coefficients, + subject_attributes.categorical_attributes, + ) + score += score_numeric_attributes( + coefficients.action_numeric_coefficients, + action_attributes.numeric_attributes, + ) + score += score_categorical_attributes( + coefficients.action_categorical_coefficients, + action_attributes.categorical_attributes, + ) + return score + + +def coalesce(value, default=0): + return value if value is not None else default + + +def score_numeric_attributes( + coefficients: List[BanditNumericAttributeCoefficient], + attributes: Dict[str, float], +) -> float: + score = 0.0 + for coefficient in coefficients: + if ( + coefficient.attribute_key in attributes + and attributes[coefficient.attribute_key] is not None + ): + score += coefficient.coefficient * attributes[coefficient.attribute_key] + else: + score += coefficient.missing_value_coefficient + + return score + + +def score_categorical_attributes( + coefficients: List[BanditCategoricalAttributeCoefficient], + attributes: Dict[str, str], +) -> float: + score = 0.0 + for coefficient in coefficients: + if coefficient.attribute_key in attributes: + score += coefficient.value_coefficients.get( + attributes[coefficient.attribute_key], + coefficient.missing_value_coefficient, + ) + else: + score += coefficient.missing_value_coefficient + return score diff --git a/eppo_client/client.py b/eppo_client/client.py index 2200bc6..b4f5790 100644 --- a/eppo_client/client.py +++ b/eppo_client/client.py @@ -1,8 +1,9 @@ import datetime import logging import json -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from eppo_client.assignment_logger import AssignmentLogger +from eppo_client.bandit import BanditEvaluator, BanditResult, ActionContext, Attributes from eppo_client.configuration_requestor import ( ExperimentConfigurationRequestor, ) @@ -36,6 +37,7 @@ def __init__( ) self.__poller.start() self.__evaluator = Evaluator(sharder=MD5Sharder()) + self.__bandit_evaluator = BanditEvaluator(sharder=MD5Sharder()) def get_string_assignment( self, @@ -219,6 +221,122 @@ def get_assignment_detail( logger.error("[Eppo SDK] Error logging assignment event: " + str(e)) return result + def get_bandit_action( + self, + flag_key: str, + subject_key: str, + subject_attributes: Attributes, + actions_with_contexts: List[ActionContext], + default: str, + ) -> BanditResult: + """ + Determines the bandit action for a given subject based on the provided bandit key and subject attributes. + + This method performs the following steps: + 1. Retrieves the experiment assignment for the given bandit key and subject. + 2. Checks if the assignment matches the bandit key. If not, it means the subject is not allocated in the bandit, + and the method returns a BanditResult with the assignment. + 3. If the subject is part of the bandit, it fetches the bandit model data. + 4. Evaluates the bandit action using the bandit evaluator. + 5. Logs the bandit action event. + 6. Returns the BanditResult containing the selected action key and the assignment. + + Args: + flag_key (str): The feature flag key that contains the bandit as one of the variations. + subject_key (str): The key identifying the subject. + subject_attributes (Attributes): The attributes of the subject. + actions_with_contexts (List[ActionContext]): The list of actions with their contexts. + + Returns: + BanditResult: The result containing either the bandit action if the subject is part of the bandit, + or the assignment if they are not. The BanditResult includes: + - variation (str): The assignment key indicating the subject's variation. + - action (str): The key of the selected action if the subject is part of the bandit. + """ + try: + return self.get_bandit_action_detail( + flag_key, + subject_key, + subject_attributes, + actions_with_contexts, + default, + ) + except Exception as e: + if self.__is_graceful_mode: + logger.error("[Eppo SDK] Error getting bandit action: " + str(e)) + return BanditResult(default, None) + raise e + + def get_bandit_action_detail( + self, + flag_key: str, + subject_key: str, + subject_attributes: Attributes, + actions_with_contexts: List[ActionContext], + default: str, + ) -> BanditResult: + # get experiment assignment + # ignoring type because Dict[str, str] satisfies Dict[str, str | ...] but mypy does not understand + variation = self.get_string_assignment( + flag_key, subject_key, subject_attributes.categorical_attributes, default # type: ignore + ) + + # if the variation is not the bandit key, then the subject is not allocated in the bandit + if variation not in self.get_bandit_keys(): + return BanditResult(variation, None) + + # for now, assume that the variation is equal to the bandit key + bandit_data = self.__config_requestor.get_bandit_model(variation) + + if not bandit_data: + logger.warning( + f"[Eppo SDK] No assigned action. Bandit not found for flag: {flag_key}" + ) + return BanditResult(variation, None) + + evaluation = self.__bandit_evaluator.evaluate_bandit( + flag_key, + subject_key, + subject_attributes, + actions_with_contexts, + bandit_data.model_data, + ) + + # log bandit action + bandit_event = { + "flagKey": flag_key, + "banditKey": bandit_data.bandit_key, + "subject": subject_key, + "action": evaluation.action_key if evaluation else None, + "actionProbability": evaluation.action_weight if evaluation else None, + "modelVersion": bandit_data.model_version if evaluation else None, + "timestamp": datetime.datetime.utcnow().isoformat(), + "subjectNumericAttributes": ( + subject_attributes.numeric_attributes + if evaluation.subject_attributes + else {} + ), + "subjectCategoricalAttributes": ( + subject_attributes.categorical_attributes + if evaluation.subject_attributes + else {} + ), + "actionNumericAttributes": ( + evaluation.action_attributes.numeric_attributes + if evaluation.action_attributes + else {} + ), + "actionCategoricalAttributes": ( + evaluation.action_attributes.categorical_attributes + if evaluation.action_attributes + else {} + ), + "metaData": {"sdkLanguage": "python", "sdkVersion": __version__}, + } + self.__assignment_logger.log_bandit_action(bandit_event) + + return BanditResult(variation, evaluation.action_key if evaluation else None) + def get_flag_keys(self): """ Returns a list of all flag keys that have been initialized. @@ -228,6 +346,13 @@ def get_flag_keys(self): """ return self.__config_requestor.get_flag_keys() + def get_bandit_keys(self): + """ + Returns a list of all bandit keys that have been initialized. + This can be useful to debug the initialization process. + """ + return self.__config_requestor.get_bandit_keys() + def is_initialized(self): """ Returns True if the client has successfully initialized diff --git a/eppo_client/configuration_requestor.py b/eppo_client/configuration_requestor.py index 9b1d5dd..4dedc91 100644 --- a/eppo_client/configuration_requestor.py +++ b/eppo_client/configuration_requestor.py @@ -2,44 +2,75 @@ from typing import Dict, Optional, cast from eppo_client.configuration_store import ConfigurationStore from eppo_client.http_client import HttpClient -from eppo_client.models import Flag +from eppo_client.models import BanditData, Flag logger = logging.getLogger(__name__) UFC_ENDPOINT = "/flag-config/v1/config" +BANDIT_ENDPOINT = "/flag-config/v1/bandits" class ExperimentConfigurationRequestor: def __init__( self, http_client: HttpClient, - config_store: ConfigurationStore[Flag], + flag_config_store: ConfigurationStore[Flag], + bandit_config_store: ConfigurationStore[BanditData], ): self.__http_client = http_client - self.__config_store = config_store + self.__flag_config_store = flag_config_store + self.__bandit_config_store = bandit_config_store self.__is_initialized = False def get_configuration(self, flag_key: str) -> Optional[Flag]: if self.__http_client.is_unauthorized(): raise ValueError("Unauthorized: please check your API key") - return self.__config_store.get_configuration(flag_key) + return self.__flag_config_store.get_configuration(flag_key) + + def get_bandit_model(self, bandit_key: str) -> Optional[BanditData]: + if self.__http_client.is_unauthorized(): + raise ValueError("Unauthorized: please check your API key") + return self.__bandit_config_store.get_configuration(bandit_key) def get_flag_keys(self): - return self.__config_store.get_keys() + return self.__flag_config_store.get_keys() + + def get_bandit_keys(self): + return self.__bandit_config_store.get_keys() + + def fetch_flags(self): + return self.__http_client.get(UFC_ENDPOINT) - def fetch_and_store_configurations(self) -> Dict[str, Flag]: + def fetch_bandits(self): + return self.__http_client.get(BANDIT_ENDPOINT) + + def store_flags(self, flag_data) -> Dict[str, Flag]: + flag_config_dict = cast(dict, flag_data.get("flags", {})) + flag_configs = {key: Flag(**config) for key, config in flag_config_dict.items()} + self.__flag_config_store.set_configurations(flag_configs) + return flag_configs + + def store_bandits(self, bandit_data) -> Dict[str, BanditData]: + bandit_configs = { + config["banditKey"]: BanditData(**config) + for config in cast(dict, bandit_data.get("bandits", [])) + } + self.__bandit_config_store.set_configurations(bandit_configs) + return bandit_configs + + def fetch_and_store_configurations(self): try: - configs_dict = cast( - dict, self.__http_client.get(UFC_ENDPOINT).get("flags", {}) - ) - configs = {key: Flag(**config) for key, config in configs_dict.items()} - self.__config_store.set_configurations(configs) + flag_data = self.fetch_flags() + self.store_flags(flag_data) + + if flag_data.get("bandits", {}): + bandit_data = self.fetch_bandits() + print(bandit_data) + self.store_bandits(bandit_data) self.__is_initialized = True - return configs except Exception as e: - logger.error("Error retrieving flag configurations: " + str(e)) - return {} + logger.error("Error retrieving configurations: " + str(e)) def is_initialized(self): return self.__is_initialized diff --git a/eppo_client/configuration_store.py b/eppo_client/configuration_store.py index bdd57eb..87b7be5 100644 --- a/eppo_client/configuration_store.py +++ b/eppo_client/configuration_store.py @@ -20,4 +20,4 @@ def set_configurations(self, configs: Dict[str, T]): def get_keys(self): with self.__lock.reader(): - return list(self.__cache.keys()) + return set(self.__cache.keys()) diff --git a/eppo_client/models.py b/eppo_client/models.py index 601bc86..ea4e75d 100644 --- a/eppo_client/models.py +++ b/eppo_client/models.py @@ -4,7 +4,7 @@ from eppo_client.base_model import SdkBaseModel from eppo_client.rules import Rule -from eppo_client.types import ValueType +from eppo_client.types import Action, ValueType class VariationType(Enum): @@ -52,3 +52,51 @@ class Flag(SdkBaseModel): variations: Dict[str, Variation] allocations: List[Allocation] total_shards: int = 10_000 + + +class BanditVariation(SdkBaseModel): + key: str + flag_key: str + variation_key: str + variation_value: str + + +class BanditNumericAttributeCoefficient(SdkBaseModel): + attribute_key: str + coefficient: float + missing_value_coefficient: float + + +class ValueCoefficient(SdkBaseModel): + value: str + coefficient: float + + +class BanditCategoricalAttributeCoefficient(SdkBaseModel): + attribute_key: str + missing_value_coefficient: float + value_coefficients: Dict[str, float] + + +class BanditCoefficients(SdkBaseModel): + action_key: str + intercept: float + subject_numeric_coefficients: List[BanditNumericAttributeCoefficient] + subject_categorical_coefficients: List[BanditCategoricalAttributeCoefficient] + action_numeric_coefficients: List[BanditNumericAttributeCoefficient] + action_categorical_coefficients: List[BanditCategoricalAttributeCoefficient] + + +class BanditModelData(SdkBaseModel): + gamma: float + default_action_score: float + action_probability_floor: float + coefficients: Dict[Action, BanditCoefficients] + + +class BanditData(SdkBaseModel): + bandit_key: str + model_name: str + updated_at: datetime + model_version: str + model_data: BanditModelData diff --git a/eppo_client/types.py b/eppo_client/types.py index dcc4a76..07f126a 100644 --- a/eppo_client/types.py +++ b/eppo_client/types.py @@ -4,3 +4,4 @@ AttributeType = Union[str, int, float, bool] ConditionValueType = Union[AttributeType, List[AttributeType]] SubjectAttributes = Dict[str, AttributeType] +Action = str diff --git a/eppo_client/version.py b/eppo_client/version.py index 8d1c862..f5f41e5 100644 --- a/eppo_client/version.py +++ b/eppo_client/version.py @@ -1 +1 @@ -__version__ = "3.0.3" +__version__ = "3.1.0" diff --git a/test/bandit_test.py b/test/bandit_test.py new file mode 100644 index 0000000..a69516d --- /dev/null +++ b/test/bandit_test.py @@ -0,0 +1,335 @@ +import pytest + +from eppo_client.sharders import MD5Sharder, DeterministicSharder + +from eppo_client.bandit import ( + ActionContext, + Attributes, + score_numeric_attributes, + score_categorical_attributes, + BanditEvaluator, +) +from eppo_client.models import ( + BanditCoefficients, + BanditModelData, + BanditNumericAttributeCoefficient, + BanditCategoricalAttributeCoefficient, +) + +bandit_evaluator = BanditEvaluator(MD5Sharder(), 10_000) + + +def test_score_numeric_attributes_all_present(): + coefficients = [ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=2.0, missing_value_coefficient=0.5 + ), + BanditNumericAttributeCoefficient( + attribute_key="height", coefficient=1.5, missing_value_coefficient=0.3 + ), + ] + attributes = {"age": 30, "height": 170} + expected_score = 30 * 2.0 + 170 * 1.5 + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_numeric_attributes_some_missing(): + coefficients = [ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=2.0, missing_value_coefficient=0.5 + ), + BanditNumericAttributeCoefficient( + attribute_key="height", coefficient=1.5, missing_value_coefficient=0.3 + ), + ] + attributes = {"age": 30} + expected_score = 30 * 2.0 + 0.3 # height is missing, use missing_value_coefficient + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_numeric_attributes_all_missing(): + coefficients = [ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=2.0, missing_value_coefficient=0.5 + ), + BanditNumericAttributeCoefficient( + attribute_key="height", coefficient=1.5, missing_value_coefficient=0.3 + ), + ] + attributes = {} + expected_score = 0.5 + 0.3 # both are missing, use missing_value_coefficients + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_numeric_attributes_empty_coefficients(): + coefficients = [] + attributes = {"age": 30, "height": 170} + expected_score = 0.0 # no coefficients to apply + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_numeric_attributes_negative_coefficients(): + coefficients = [ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=-2.0, missing_value_coefficient=0.5 + ), + BanditNumericAttributeCoefficient( + attribute_key="height", coefficient=-1.5, missing_value_coefficient=0.3 + ), + ] + attributes = {"age": 30, "height": 170} + expected_score = 30 * -2.0 + 170 * -1.5 + assert score_numeric_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_some_missing(): + coefficients = [ + BanditCategoricalAttributeCoefficient( + attribute_key="color", + missing_value_coefficient=0.2, + value_coefficients={"red": 1.0, "blue": 0.5}, + ), + BanditCategoricalAttributeCoefficient( + attribute_key="size", + missing_value_coefficient=0.3, + value_coefficients={"large": 2.0, "small": 1.0}, + ), + ] + attributes = {"color": "red"} + expected_score = 1.0 + 0.3 # size is missing, use missing_value_coefficient + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_all_missing(): + coefficients = [ + BanditCategoricalAttributeCoefficient( + attribute_key="color", + missing_value_coefficient=0.2, + value_coefficients={"red": 1.0, "blue": 0.5}, + ), + BanditCategoricalAttributeCoefficient( + attribute_key="size", + missing_value_coefficient=0.3, + value_coefficients={"large": 2.0, "small": 1.0}, + ), + ] + attributes = {} + expected_score = 0.2 + 0.3 # both are missing, use missing_value_coefficients + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_empty_coefficients(): + coefficients = [] + attributes = {"color": "red", "size": "large"} + expected_score = 0.0 # no coefficients to apply + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_negative_coefficients(): + coefficients = [ + BanditCategoricalAttributeCoefficient( + attribute_key="color", + missing_value_coefficient=0.2, + value_coefficients={"red": -1.0, "blue": -0.5}, + ), + BanditCategoricalAttributeCoefficient( + attribute_key="size", + missing_value_coefficient=0.3, + value_coefficients={"large": -2.0, "small": -1.0}, + ), + ] + attributes = {"color": "red", "size": "large"} + expected_score = -1.0 + -2.0 + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_score_categorical_attributes_mixed_coefficients(): + coefficients = [ + BanditCategoricalAttributeCoefficient( + attribute_key="color", + missing_value_coefficient=0.2, + value_coefficients={"red": 1.0, "blue": -0.5}, + ), + BanditCategoricalAttributeCoefficient( + attribute_key="size", + missing_value_coefficient=0.3, + value_coefficients={"large": -2.0, "small": 1.0}, + ), + ] + attributes = {"color": "blue", "size": "small"} + expected_score = -0.5 + 1.0 + assert score_categorical_attributes(coefficients, attributes) == expected_score + + +def test_weigh_actions_single_action(): + action_scores = [("action1", 1.0)] + gamma = 0.1 + probability_floor = 0.1 + expected_weights = [("action1", 1.0)] + assert ( + bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) + == expected_weights + ) + + +def test_weigh_actions_multiple_actions(): + action_scores = [("action1", 1.0), ("action2", 0.5)] + gamma = 10 + probability_floor = 0.1 + weights = bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) + assert len(weights) == 2 + action_1_weight = next(weight for action, weight in weights if action == "action1") + assert action_1_weight == pytest.approx(6 / 7, rel=1e-6) + action_2_weight = next(weight for action, weight in weights if action == "action2") + assert action_2_weight == pytest.approx(1 / 7, rel=1e-6) + + +def test_weight_actions_probability_floor(): + action_scores = [("action1", 1.0), ("action2", 0.5), ("action3", 0.2)] + gamma = 10 + probability_floor = 0.3 + weights = bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) + assert len(weights) == 3 + + # note probability floor is normalized by number of actions: 0.3/3 = 0.1 + for _, weight in weights: + assert weight == pytest.approx(0.1, rel=1e-6) or weight > 0.1 + + +def test_weight_actions_gamma_effect(): + action_scores = [("action1", 1.0), ("action2", 0.5)] + small_gamma = 1.0 + large_gamma = 10.0 + probability_floor = 0.1 + weights_small_gamma = bandit_evaluator.weigh_actions( + action_scores, small_gamma, probability_floor + ) + weights_large_gamma = bandit_evaluator.weigh_actions( + action_scores, large_gamma, probability_floor + ) + + assert next( + weight for action, weight in weights_small_gamma if action == "action1" + ) < next(weight for action, weight in weights_large_gamma if action == "action1") + assert next( + weight for action, weight in weights_small_gamma if action == "action2" + ) > next(weight for action, weight in weights_large_gamma if action == "action2") + + +def test_weight_actions_all_equal_scores(): + action_scores = [("action1", 1.0), ("action2", 1.0), ("action3", 1.0)] + gamma = 0.1 + probability_floor = 0.1 + weights = bandit_evaluator.weigh_actions(action_scores, gamma, probability_floor) + assert len(weights) == 3 + for _, weight in weights: + assert weight == pytest.approx(1.0 / 3, rel=1e-2) + + +def test_evaluate_bandit(): + # Mock data + flag_key = "test_flag" + subject_key = "test_subject" + subject_attributes = Attributes( + numeric_attributes={"age": 25.0}, categorical_attributes={"location": "US"} + ) + action_contexts = [ + ActionContext( + action_key="action1", + attributes=Attributes( + numeric_attributes={"price": 10.0}, + categorical_attributes={"category": "A"}, + ), + ), + ActionContext( + action_key="action2", + attributes=Attributes( + numeric_attributes={"price": 20.0}, + categorical_attributes={"category": "B"}, + ), + ), + ] + coefficients = { + "action1": BanditCoefficients( + action_key="action1", + intercept=0.5, + subject_numeric_coefficients=[ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=0.1, missing_value_coefficient=0.0 + ) + ], + subject_categorical_coefficients=[ + BanditCategoricalAttributeCoefficient( + attribute_key="location", + missing_value_coefficient=0.0, + value_coefficients={"US": 0.2}, + ) + ], + action_numeric_coefficients=[ + BanditNumericAttributeCoefficient( + attribute_key="price", + coefficient=0.05, + missing_value_coefficient=0.0, + ) + ], + action_categorical_coefficients=[ + BanditCategoricalAttributeCoefficient( + attribute_key="category", + missing_value_coefficient=0.0, + value_coefficients={"A": 0.3}, + ) + ], + ), + "action2": BanditCoefficients( + action_key="action2", + intercept=0.3, + subject_numeric_coefficients=[ + BanditNumericAttributeCoefficient( + attribute_key="age", coefficient=0.1, missing_value_coefficient=0.0 + ) + ], + subject_categorical_coefficients=[ + BanditCategoricalAttributeCoefficient( + attribute_key="location", + missing_value_coefficient=0.0, + value_coefficients={"US": 0.2}, + ) + ], + action_numeric_coefficients=[ + BanditNumericAttributeCoefficient( + attribute_key="price", + coefficient=0.05, + missing_value_coefficient=0.0, + ) + ], + action_categorical_coefficients=[ + BanditCategoricalAttributeCoefficient( + attribute_key="category", + missing_value_coefficient=0.0, + value_coefficients={"B": 0.3}, + ) + ], + ), + } + bandit_model = BanditModelData( + gamma=0.1, + default_action_score=0.0, + action_probability_floor=0.1, + coefficients=coefficients, + ) + + evaluator = BanditEvaluator(sharder=DeterministicSharder({})) + + # Evaluate bandit + evaluation = evaluator.evaluate_bandit( + flag_key, subject_key, subject_attributes, action_contexts, bandit_model + ) + + # Assertions + assert evaluation.flag_key == flag_key + assert evaluation.subject_key == subject_key + assert evaluation.subject_attributes == subject_attributes + assert evaluation.action_key == "action1" + assert evaluation.gamma == bandit_model.gamma + assert evaluation.action_score == 4.0 + assert pytest.approx(evaluation.action_weight, rel=1e-2) == 0.4926 diff --git a/test/client_bandit_test.py b/test/client_bandit_test.py new file mode 100644 index 0000000..d13038d --- /dev/null +++ b/test/client_bandit_test.py @@ -0,0 +1,155 @@ +# Note: contains tests for client.py related to bandits to avoid +# making client_test.py too long. + + +import json +import os +from time import sleep +from typing import Dict +from eppo_client.bandit import BanditResult, ActionContext, Attributes + +import httpretty # type: ignore +import pytest + +from eppo_client.assignment_logger import AssignmentLogger +from eppo_client.configuration_requestor import BANDIT_ENDPOINT, UFC_ENDPOINT +from eppo_client import init, get_instance +from eppo_client.config import Config + +TEST_DIR = "test/test-data/ufc/bandit-tests" +FLAG_CONFIG_FILE = "test/test-data/ufc/bandit-flags-v1.json" +BANDIT_CONFIG_FILE = "test/test-data/ufc/bandit-models-v1.json" +test_data = [] +for file_name in [file for file in os.listdir(TEST_DIR)]: + with open("{}/{}".format(TEST_DIR, file_name)) as test_case_json: + test_case_dict = json.load(test_case_json) + test_data.append(test_case_dict) + + +MOCK_BASE_URL = "http://localhost:4001/api" + +DEFAULT_SUBJECT_ATTRIBUTES = Attributes( + numeric_attributes={"age": 30}, categorical_attributes={"country": "UK"} +) + + +class MockAssignmentLogger(AssignmentLogger): + def log_assignment(self, assignment_event: Dict): + print(f"Assignment Event: {assignment_event}") + + def log_bandit_action(self, bandit_event: Dict): + print(f"Bandit Event: {bandit_event}") + + +@pytest.fixture(scope="session", autouse=True) +def init_fixture(): + httpretty.enable() + with open(FLAG_CONFIG_FILE) as mock_ufc_response: + ufc_json = json.load(mock_ufc_response) + + with open(BANDIT_CONFIG_FILE) as mock_bandit_response: + bandit_json = json.load(mock_bandit_response) + + httpretty.register_uri( + httpretty.GET, + MOCK_BASE_URL + UFC_ENDPOINT, + body=json.dumps(ufc_json), + ) + httpretty.register_uri( + httpretty.GET, + MOCK_BASE_URL + BANDIT_ENDPOINT, + body=json.dumps(bandit_json), + ) + client = init( + Config( + base_url=MOCK_BASE_URL, + api_key="dummy", + assignment_logger=AssignmentLogger(), + ) + ) + sleep(0.1) # wait for initialization + yield + client._shutdown() + httpretty.disable() + httpretty.reset() + + +def test_is_initialized(): + client = get_instance() + assert client.is_initialized(), "Client should be initialized" + + +def test_get_bandit_action_bandit_does_not_exist(): + client = get_instance() + result = client.get_bandit_action( + "nonexistent_bandit", + "subject_key", + DEFAULT_SUBJECT_ATTRIBUTES, + [], + "default_variation", + ) + assert result == BanditResult("default_variation", None) + + +def test_get_bandit_action_flag_without_bandit(): + client = get_instance() + result = client.get_bandit_action( + "a_flag", "subject_key", DEFAULT_SUBJECT_ATTRIBUTES, [], "default_variation" + ) + assert result == BanditResult("default_variation", None) + + +def test_get_bandit_action_with_subject_attributes(): + # tests that allocation filtering based on subject attributes works correctly + client = get_instance() + result = client.get_bandit_action( + "banner_bandit_flag_uk_only", + "subject_key", + DEFAULT_SUBJECT_ATTRIBUTES, + [ActionContext.create("adidas", {}, {}), ActionContext.create("nike", {}, {})], + "default_variation", + ) + assert result.variation == "banner_bandit" + assert result.action in ["adidas", "nike"] + + +@pytest.mark.parametrize("test_case", test_data) +def test_bandit_generic_test_cases(test_case): + client = get_instance() + + flag = test_case["flag"] + default_value = test_case["defaultValue"] + + for subject in test_case["subjects"]: + result = client.get_bandit_action( + flag, + subject["subjectKey"], + Attributes( + numeric_attributes=subject["subjectAttributes"]["numeric_attributes"], + categorical_attributes=subject["subjectAttributes"][ + "categorical_attributes" + ], + ), + [ + ActionContext.create( + action["actionKey"], + action["numericAttributes"], + action["categoricalAttributes"], + ) + for action in subject["actions"] + ], + default_value, + ) + + expected_result = BanditResult( + subject["assignment"]["variation"], subject["assignment"]["action"] + ) + + assert result.variation == expected_result.variation, ( + f"Flag {flag} failed for subject {subject['subjectKey']}:" + f"expected assignment {expected_result.variation}, got {result.variation}" + ) + assert result.action == expected_result.action, ( + f"Flag {flag} failed for subject {subject['subjectKey']}:" + f"expected action {expected_result.action}, got {result.action}" + ) diff --git a/test/configuration_store_test.py b/test/configuration_store_test.py index 72eef3f..3bda8d0 100644 --- a/test/configuration_store_test.py +++ b/test/configuration_store_test.py @@ -26,6 +26,14 @@ def test_get_configuration_known_key(): assert store.get_configuration("flag") == mock_flag +def test_get_keys(): + store.set_configurations({"flag1": mock_flag, "flag2": mock_flag}) + keys = store.get_keys() + assert "flag1" in keys + assert "flag2" in keys + assert len(keys) == 2 + + def test_evicts_old_entries_when_max_size_exceeded(): store.set_configurations({"item_to_be_evicted": mock_flag}) assert store.get_configuration("item_to_be_evicted") == mock_flag