diff --git a/eppo_client/bandit.py b/eppo_client/bandit.py index 2d05e82..8f9be69 100644 --- a/eppo_client/bandit.py +++ b/eppo_client/bandit.py @@ -160,8 +160,12 @@ def weigh_actions( self, action_scores, gamma, probability_floor ) -> Dict[str, float]: number_of_actions = len(action_scores) - best_action = max(action_scores, key=action_scores.get) - best_score = action_scores[best_action] + # Find the max score + best_score = max(action_scores.values()) + # Get all the keys that have the same best score (if there's more than one) + best_action_keys = [k for k, v in action_scores.items() if v == best_score] + # Get the lowest lexicographically ordered key. + best_action = min(best_action_keys) # adjust probability floor for number of actions to control the sum min_probability = probability_floor / number_of_actions