Skip to content
This repository was archived by the owner on Nov 8, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions eppo_client/bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class BanditEvaluation:
action_score: float
action_weight: float
gamma: float
optimality_gap: float


@dataclass
Expand All @@ -89,14 +90,7 @@ def null_evaluation(
flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float
):
return BanditEvaluation(
flag_key,
subject_key,
subject_attributes,
None,
None,
0.0,
0.0,
gamma,
flag_key, subject_key, subject_attributes, None, None, 0.0, 0.0, gamma, 0.0

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 the formatting change tripped me up at first; I thought some parameters had been re-ordered. No, this just adds 0.0 as the optimality gap (last parameter)

)


Expand Down Expand Up @@ -129,9 +123,17 @@ def evaluate_bandit(
bandit_model.action_probability_floor,
)

selected_idx, selected_action = self.select_action(
flag_key, subject_key, action_weights
selected_action = self.select_action(flag_key, subject_key, action_weights)
selected_idx = next(
idx
for idx, action_context in enumerate(actions_with_contexts)
if action_context.action_key == selected_action
)

optimality_gap = (
max(score for _, score in action_scores) - action_scores[selected_idx][1]
)

return BanditEvaluation(
flag_key,
subject_key,
Expand All @@ -141,6 +143,7 @@ def evaluate_bandit(
action_scores[selected_idx][1],
action_weights[selected_idx][1],
bandit_model.gamma,
optimality_gap,
)

def score_actions(
Expand Down Expand Up @@ -192,7 +195,7 @@ def weigh_actions(
weights.append((best_action, remaining_weight))
return weights

def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str]:
def select_action(self, flag_key, subject_key, action_weights) -> str:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this function incorrect before the change, or did you prefer having the selected_idx logic outside?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the index corresponded to the sorted list rather than the original list, so returning it was useless and caused the bug.

# deterministic ordering
sorted_action_weights = sorted(
action_weights,
Expand All @@ -209,10 +212,10 @@ def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str
cumulative_weight = 0.0
shard_value = shard / self.total_shards

for idx, (action_key, weight) in enumerate(sorted_action_weights):
for action_key, weight in sorted_action_weights:
cumulative_weight += weight
if cumulative_weight > shard_value:
return idx, action_key
return action_key

# No action was selected (the cumulative weight never exceeded the shard value): raise an error
raise BanditEvaluationError(
Expand Down
1 change: 1 addition & 0 deletions eppo_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def get_bandit_action_detail(
"subject": subject_key,
"action": evaluation.action_key if evaluation else None,
"actionProbability": evaluation.action_weight if evaluation else None,
"optimalityGap": evaluation.optimality_gap if evaluation else None,
"modelVersion": bandit_data.model_version if evaluation else None,
"timestamp": datetime.datetime.utcnow().isoformat(),
"subjectNumericAttributes": (
Expand Down
2 changes: 1 addition & 1 deletion eppo_client/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.1.3"
__version__ = "3.1.4"
57 changes: 51 additions & 6 deletions test/client_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import os
from time import sleep
from typing import Dict
from typing import Dict, List
from eppo_client.bandit import BanditResult, ActionContext, Attributes

import httpretty # type: ignore
Expand Down Expand Up @@ -34,11 +34,17 @@


class MockAssignmentLogger(AssignmentLogger):
    """Assignment logger test double that prints and records every event it is given."""

    # Declared at class level. NOTE(review): in plain Python these lists would be
    # shared by every instance; whether the AssignmentLogger base class copies
    # them per instance is not visible here - confirm before creating more than
    # one MockAssignmentLogger.
    assignment_events: List[Dict] = []
    bandit_events: List[Dict] = []

    def log_assignment(self, assignment_event: Dict):
        """Print the assignment event and append it to ``assignment_events``."""
        print(f"Assignment Event: {assignment_event}")
        self.assignment_events.append(assignment_event)

    def log_bandit_action(self, bandit_event: Dict):
        """Print the bandit event and append it to ``bandit_events``."""
        print(f"Bandit Event: {bandit_event}")
        self.bandit_events.append(bandit_event)


mock_assignment_logger = MockAssignmentLogger()


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -64,7 +70,7 @@ def init_fixture():
Config(
base_url=MOCK_BASE_URL,
api_key="dummy",
assignment_logger=AssignmentLogger(),
assignment_logger=mock_assignment_logger,
)
)
sleep(0.1) # wait for initialization
Expand Down Expand Up @@ -102,16 +108,55 @@ def test_get_bandit_action_flag_without_bandit():
def test_get_bandit_action_with_subject_attributes():
# tests that allocation filtering based on subject attributes works correctly
client = get_instance()
actions = [
ActionContext.create("adidas", {"discount": 0.1}, {"from": "germany"}),
ActionContext.create("nike", {"discount": 0.2}, {"from": "usa"}),
]
result = client.get_bandit_action(
"banner_bandit_flag_uk_only",
"subject_key",
"alice",
DEFAULT_SUBJECT_ATTRIBUTES,
[ActionContext.create("adidas", {}, {}), ActionContext.create("nike", {}, {})],
actions,
"default_variation",
)
assert result.variation == "banner_bandit"
assert result.action in ["adidas", "nike"]

# testing assignment logger
assignment_log_statement = mock_assignment_logger.assignment_events[-1]
assert assignment_log_statement["featureFlag"] == "banner_bandit_flag_uk_only"
assert assignment_log_statement["variation"] == "banner_bandit"
assert assignment_log_statement["subject"] == "alice"

# testing bandit logger
bandit_log_statement = mock_assignment_logger.bandit_events[-1]
assert bandit_log_statement["flagKey"] == "banner_bandit_flag_uk_only"
assert bandit_log_statement["banditKey"] == "banner_bandit"
assert bandit_log_statement["subject"] == "alice"
assert (
bandit_log_statement["subjectNumericAttributes"]
== DEFAULT_SUBJECT_ATTRIBUTES.numeric_attributes
)
assert (
bandit_log_statement["subjectCategoricalAttributes"]
== DEFAULT_SUBJECT_ATTRIBUTES.categorical_attributes
)
assert bandit_log_statement["action"] == result.action
assert bandit_log_statement["optimalityGap"] >= 0
assert bandit_log_statement["actionProbability"] >= 0

chosen_action = next(
action for action in actions if action.action_key == result.action
)
assert (
bandit_log_statement["actionNumericAttributes"]
== chosen_action.attributes.numeric_attributes
)
assert (
bandit_log_statement["actionCategoricalAttributes"]
== chosen_action.attributes.categorical_attributes
)


@pytest.mark.parametrize("test_case", test_data)
def test_bandit_generic_test_cases(test_case):
Expand Down
2 changes: 2 additions & 0 deletions test/eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def test_flag_target_on_id():
assert result.variation == Variation(key="control", value="control")
result = evaluator.evaluate_flag(flag, "user-3", {})
assert result.variation is None
result = evaluator.evaluate_flag(flag, "user-1", {"id": "do-not-overwrite-me"})
assert result.variation is None


def test_catch_all_allocation():
Expand Down