Skip to content
This repository was archived by the owner on Nov 8, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions eppo_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from eppo_client.configuration_store import ConfigurationStore
from eppo_client.http_client import HttpClient, SdkParams
from eppo_client.models import Flag
from eppo_client.models import BanditData, Flag
from eppo_client.read_write_lock import ReadWriteLock
from eppo_client.version import __version__

Expand All @@ -30,9 +30,12 @@ def init(config: Config) -> EppoClient:
apiKey=config.api_key, sdkName="python", sdkVersion=__version__
)
http_client = HttpClient(base_url=config.base_url, sdk_params=sdk_params)
config_store: ConfigurationStore[Flag] = ConfigurationStore()
flag_config_store: ConfigurationStore[Flag] = ConfigurationStore()
bandit_config_store: ConfigurationStore[BanditData] = ConfigurationStore()
config_requestor = ExperimentConfigurationRequestor(
http_client=http_client, config_store=config_store
http_client=http_client,
flag_config_store=flag_config_store,
bandit_config_store=bandit_config_store,
)
assignment_logger = config.assignment_logger
is_graceful_mode = config.is_graceful_mode
Expand Down
3 changes: 3 additions & 0 deletions eppo_client/assignment_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ class AssignmentLogger(BaseModel):

def log_assignment(self, assignment_event: Dict):
pass

def log_bandit_action(self, bandit_event: Dict):
pass
Comment on lines +12 to +13
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aarsilv in particular: added the log_bandit_action method to the assignment logger for now, it seems like the simplest solution but happy to adjust

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's fine! I don't know python well but the main goal will be that only customers using bandits will need to implement / worry about this

282 changes: 282 additions & 0 deletions eppo_client/bandit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
from dataclasses import dataclass
import logging
from typing import Dict, List, Optional, Tuple

from eppo_client.models import (
BanditCategoricalAttributeCoefficient,
BanditCoefficients,
BanditModelData,
BanditNumericAttributeCoefficient,
)
from eppo_client.sharders import Sharder


logger = logging.getLogger(__name__)


class BanditEvaluationError(Exception):
pass


@dataclass
class Attributes:
numeric_attributes: Dict[str, float]
categorical_attributes: Dict[str, str]


@dataclass
class ActionContext:
action_key: str
attributes: Attributes

@classmethod
def create(
cls,
action_key: str,
numeric_attributes: Dict[str, float],
categorical_attributes: Dict[str, str],
):
"""
Create an instance of ActionContext.

Args:
action_key (str): The key representing the action.
numeric_attributes (Dict[str, float]): A dictionary of numeric attributes.
categorical_attributes (Dict[str, str]): A dictionary of categorical attributes.

Returns:
ActionContext: An instance of ActionContext with the provided action key and attributes.
"""
return cls(
action_key,
Attributes(
numeric_attributes=numeric_attributes,
categorical_attributes=categorical_attributes,
),
)

@property
def numeric_attributes(self):
return self.attributes.numeric_attributes

@property
def categorical_attributes(self):
return self.attributes.categorical_attributes


@dataclass
class BanditEvaluation:
flag_key: str
subject_key: str
subject_attributes: Attributes
action_key: Optional[str]
action_attributes: Optional[Attributes]
action_score: float
action_weight: float
gamma: float


@dataclass
class BanditResult:
variation: str
action: Optional[str]

def to_string(self) -> str:
return coalesce(self.action, self.variation)


def null_evaluation(
flag_key: str, subject_key: str, subject_attributes: Attributes, gamma: float
):
return BanditEvaluation(
flag_key,
subject_key,
subject_attributes,
None,
None,
0.0,
0.0,
gamma,
)


@dataclass
class BanditEvaluator:
sharder: Sharder
total_shards: int = 10_000

def evaluate_bandit(
self,
flag_key: str,
subject_key: str,
subject_attributes: Attributes,
actions_with_contexts: List[ActionContext],
bandit_model: BanditModelData,
) -> BanditEvaluation:
# handle the edge case that there are no actions
if not actions_with_contexts:
return null_evaluation(
flag_key, subject_key, subject_attributes, bandit_model.gamma
)

action_scores = self.score_actions(
subject_attributes, actions_with_contexts, bandit_model
)

action_weights = self.weigh_actions(
action_scores,
bandit_model.gamma,
bandit_model.action_probability_floor,
)

selected_idx, selected_action = self.select_action(
flag_key, subject_key, action_weights
)
return BanditEvaluation(
flag_key,
subject_key,
subject_attributes,
selected_action,
actions_with_contexts[selected_idx].attributes,
action_scores[selected_idx][1],
action_weights[selected_idx][1],
bandit_model.gamma,
)

def score_actions(
self,
subject_attributes: Attributes,
actions_with_contexts: List[ActionContext],
bandit_model: BanditModelData,
) -> List[Tuple[str, float]]:
return [
(
action_context.action_key,
(
score_action(
subject_attributes,
action_context.attributes,
bandit_model.coefficients[action_context.action_key],
)
if action_context.action_key in bandit_model.coefficients
else bandit_model.default_action_score
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

),
)
for action_context in actions_with_contexts
]

def weigh_actions(
self, action_scores, gamma, probability_floor
) -> List[Tuple[str, float]]:
number_of_actions = len(action_scores)
best_action, best_score = max(action_scores, key=lambda t: t[1])

# adjust probability floor for number of actions to control the sum
min_probability = probability_floor / number_of_actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way it's coded in Java is the probability floor is an absolute floor irrespective of the number of actions (src). Either way, it has its issues, but dividing it by the number of actions seems safer if you'd like that to become the standard.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup I think this is safer; otherwise there is no good way to set the probability floor generally. Suppose 1 bandit gets called with 1000s of actions, and the other one with 5; what probability floor should we be using if we don't normalize?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯


# weight all but the best action
weights = [
(
action_key,
max(
min_probability,
1.0 / (number_of_actions + gamma * (best_score - score)),
),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In java, we'd do the extra step of rounding to the shard space (e.g., the closest ten thousandth) to keep weights consistent across programming languages that may have different decimal number implementations under the hood. (src)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I figure they all use the same standard for doubles right? and even if not the precision should be much better than 1/10000 -- going to leave as is for now but happy to adjust in a future PR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah you're right, all our core languages use IEEE 754 so should be good 🤞

for action_key, score in action_scores
if action_key != best_action
]

# remaining weight goes to best action
remaining_weight = max(0.0, 1.0 - sum(weight for _, weight in weights))
weights.append((best_action, remaining_weight))
return weights

def select_action(self, flag_key, subject_key, action_weights) -> Tuple[int, str]:
# deterministic ordering
sorted_action_weights = sorted(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To handle the (unlikely) edge case of two actions getting the same shard, we were tie-breaking with the action name (src)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also tie-breaking with action name (t[0]), just hashed so it's not alphabetical, maybe overkill?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, they can still land in the same bucket 🤔

action_weights,
key=lambda t: (
self.sharder.get_shard(
f"{flag_key}-{subject_key}-{t[0]}", self.total_shards
),
t[0], # tie-break using action name
),
)

# select action based on weights
shard = self.sharder.get_shard(f"{flag_key}-{subject_key}", self.total_shards)
cumulative_weight = 0.0
shard_value = shard / self.total_shards

for idx, (action_key, weight) in enumerate(sorted_action_weights):
cumulative_weight += weight
if cumulative_weight > shard_value:
return idx, action_key

# If no action is selected, return the last action (fallback)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be very unexpected given the top scoring action should fill the rest of the space--thinking we throw an error instead (src)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to throw an error at runtime in the SDK -- let's log an error though

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was assuming there would be some error handler at the top-level (unless graceful mode is off) that catches and logs an error and returns the default value.

Otherwise we could get in a situation where a bug is introduced and all bandits start selecting the last action all the time--seems worse than returning the default value

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've convinced me, also did not add top level error handling -- fixed

raise BanditEvaluationError(
f"[Eppo SDK] No action selected for {flag_key} {subject_key}"
)


def score_action(
subject_attributes: Attributes,
action_attributes: Attributes,
coefficients: BanditCoefficients,
) -> float:
score = coefficients.intercept
score += score_numeric_attributes(
coefficients.subject_numeric_coefficients,
subject_attributes.numeric_attributes,
)
score += score_categorical_attributes(
coefficients.subject_categorical_coefficients,
subject_attributes.categorical_attributes,
)
score += score_numeric_attributes(
coefficients.action_numeric_coefficients,
action_attributes.numeric_attributes,
)
score += score_categorical_attributes(
coefficients.action_categorical_coefficients,
action_attributes.categorical_attributes,
)
return score


def coalesce(value, default=0):
return value if value is not None else default
Comment on lines +248 to +249
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JavaScript has null coalescing built in just sayin'

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😬



def score_numeric_attributes(
coefficients: List[BanditNumericAttributeCoefficient],
attributes: Dict[str, float],
) -> float:
score = 0.0
for coefficient in coefficients:
if (
coefficient.attribute_key in attributes
and attributes[coefficient.attribute_key] is not None
):
score += coefficient.coefficient * attributes[coefficient.attribute_key]
else:
score += coefficient.missing_value_coefficient

return score


def score_categorical_attributes(
coefficients: List[BanditCategoricalAttributeCoefficient],
attributes: Dict[str, str],
) -> float:
score = 0.0
for coefficient in coefficients:
if coefficient.attribute_key in attributes:
score += coefficient.value_coefficients.get(
attributes[coefficient.attribute_key],
coefficient.missing_value_coefficient,
)
else:
score += coefficient.missing_value_coefficient
return score
Loading