47 changes: 32 additions & 15 deletions reagent/evaluation/cb/base_evaluator.py
@@ -5,6 +5,7 @@

import torch
from reagent.core.types import CBInput
from reagent.core.utils import get_rank
from reagent.evaluation.cb.utils import zero_out_skipped_obs_weights
from torch.utils.tensorboard import SummaryWriter

@@ -16,12 +17,14 @@ class BaseOfflineEval(torch.nn.Module, ABC):
"""
Base class for Contextual Bandit Offline Evaluation algorithms. All algorithms support evaluation of non-stationary
policies, as required for exploration-exploitation.

IMPORTANT: the current implementation doesn't support distributed training; use it only with a single training instance.
"""

sum_weight: torch.Tensor
all_data_sum_weight: torch.Tensor
sum_weight_local: torch.Tensor
all_data_sum_weight_local: torch.Tensor
sum_weight_since_update_local: torch.Tensor
num_eval_model_updates: torch.Tensor

def __init__(
self,
@@ -37,8 +40,12 @@ def __init__(
self.summary_writer = summary_writer
self.register_buffer("sum_weight", torch.zeros(1, dtype=torch.float))
self.register_buffer("all_data_sum_weight", torch.zeros(1, dtype=torch.float))
self.register_buffer("sum_weight_local", torch.zeros(1, dtype=torch.float))
self.register_buffer(
"all_data_sum_weight_local", torch.zeros(1, dtype=torch.float)
)
self.register_buffer(
-"sum_weight_since_update", torch.zeros(1, dtype=torch.float)
+"sum_weight_since_update_local", torch.zeros(1, dtype=torch.float)
)
self.register_buffer("num_eval_model_updates", torch.zeros(1, dtype=torch.int))

@@ -86,6 +93,14 @@ def _process_used_data(
"""
pass

@abstractmethod
def _aggregate_across_instances(self) -> None:
"""
Aggregate local data across all instances of the evaluator.
Used for distributed training.
"""
pass

@abstractmethod
def get_avg_reward(self) -> float:
"""
@@ -108,18 +123,20 @@ def attach_summary_writer(self, summary_writer: SummaryWriter) -> None:
self.summary_writer = summary_writer

def log_metrics(self, global_step: Optional[int] = None) -> None:
-logger.info(self.get_formatted_result_string())
-summary_writer = self.summary_writer
-if summary_writer is not None:
-metric_dict = {
-"avg_reward": self.get_avg_reward(),
-"sum_weight": self.sum_weight.item(),
-"all_data_sum_weight": self.all_data_sum_weight.item(),
-"num_eval_model_updates": self.num_eval_model_updates.item(),
-}
-summary_writer.add_scalars(
-"Offline_Eval", metric_dict, global_step=global_step
-)
+if get_rank() == 0:
+# only log from the main process
+logger.info(self.get_formatted_result_string())
+summary_writer = self.summary_writer
+if summary_writer is not None:
+metric_dict = {
+"avg_reward": self.get_avg_reward(),
+"sum_weight": self.sum_weight.item(),
+"all_data_sum_weight": self.all_data_sum_weight.item(),
+"num_eval_model_updates": self.num_eval_model_updates.item(),
+}
+summary_writer.add_scalars(
+"Offline_Eval", metric_dict, global_step=global_step
+)

def get_formatted_result_string(self) -> str:
return f"Avg reward {self.get_avg_reward():0.3f} based on {int(self.sum_weight.item())} processed observations (out of {int(self.all_data_sum_weight.item())} observations). The eval model has been updated {self.num_eval_model_updates.item()} times"
49 changes: 38 additions & 11 deletions reagent/evaluation/cb/policy_evaluator.py
@@ -1,8 +1,14 @@
import logging

import torch
from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available
from reagent.core.types import CBInput
from reagent.evaluation.cb.base_evaluator import BaseOfflineEval


logger = logging.getLogger(__name__)


EPSILON = 1e-9


@@ -11,18 +17,22 @@ class PolicyEvaluator(BaseOfflineEval):
An offline evaluator for Contextual Bandits, based on the paper https://arxiv.org/pdf/1003.0146.pdf (Algorithm 3)
"""

-avg_reward_weighted: torch.Tensor
+sum_reward_weighted: torch.Tensor
+sum_reward_weighted_local: torch.Tensor

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.register_buffer("avg_reward_weighted", torch.zeros(1, dtype=torch.float))
self.register_buffer("sum_reward_weighted", torch.zeros(1, dtype=torch.float))
self.register_buffer(
"sum_reward_weighted_local", torch.zeros(1, dtype=torch.float)
)

@torch.no_grad()
def _process_all_data(self, batch: CBInput) -> None:
if batch.weight is not None:
-self.all_data_sum_weight += batch.weight.sum()
+self.all_data_sum_weight_local += batch.weight.sum()
else:
-self.all_data_sum_weight += len(batch)
+self.all_data_sum_weight_local += len(batch)

@torch.no_grad()
def _process_used_data(self, batch: CBInput) -> None:
@@ -33,13 +43,30 @@ def _process_used_data(self, batch: CBInput) -> None:
"""
assert batch.reward is not None
assert batch.weight is not None
-batch_sum_weight = batch.weight.sum()
assert batch.weight.shape == batch.reward.shape
-self.avg_reward_weighted = (
-self.avg_reward_weighted * self.sum_weight
-+ (batch.weight * batch.reward).sum()
-) / (self.sum_weight + batch_sum_weight + EPSILON)
-self.sum_weight += batch_sum_weight
+self.sum_reward_weighted_local += (batch.weight * batch.reward).sum()
+self.sum_weight_local += batch.weight.sum()

def _aggregate_across_instances(self) -> None:
# sum local values across all trainers, add to the global value
# clone the tensors to avoid modifying them inplace
self.sum_reward_weighted += sync_ddp_if_available(
self.sum_reward_weighted_local.clone(), reduce_op=ReduceOp.SUM
)
self.sum_weight += sync_ddp_if_available(
self.sum_weight_local.clone(), reduce_op=ReduceOp.SUM
)
self.all_data_sum_weight += sync_ddp_if_available(
self.all_data_sum_weight_local.clone(), reduce_op=ReduceOp.SUM
)
# reset local values to zero
self.sum_reward_weighted_local.zero_()
self.sum_weight_local.zero_()
self.all_data_sum_weight_local.zero_()

def get_avg_reward(self) -> float:
-return self.avg_reward_weighted.item()
+assert (
+self.sum_weight_local.item() == 0.0
+), f"Non-zero local weight {self.sum_weight_local.item()} in the evaluator. _aggregate_across_instances() should have been called to aggregate across all instances and zero out the local values."
+# return the average reward
+return (self.sum_reward_weighted / (self.sum_weight + EPSILON)).item()
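With the rename from avg_reward_weighted to sum_reward_weighted, the estimator becomes an explicit weighted mean computed at read time. A quick arithmetic check, as a hypothetical single-process example independent of the PR's classes: accepted observations with weights [1.0, 2.0] and rewards [0.5, 1.0] should yield an average reward of about 0.833:

    import torch

    EPSILON = 1e-9  # same stabilizer as the module-level constant above

    weight = torch.tensor([[1.0], [2.0]])
    reward = torch.tensor([[0.5], [1.0]])

    sum_reward_weighted = (weight * reward).sum()  # 1.0*0.5 + 2.0*1.0 = 2.5
    sum_weight = weight.sum()                      # 3.0
    avg_reward = (sum_reward_weighted / (sum_weight + EPSILON)).item()
    print(round(avg_reward, 3))  # 0.833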
6 changes: 5 additions & 1 deletion reagent/test/evaluation/cb/test_integration.py
@@ -258,7 +258,11 @@ def test_eval_during_training(self):
)

# check total weight (number of observations). Should be 3
-self.assertAlmostEqual(self.eval_module.sum_weight.item(), 3.0, places=4)
+self.assertAlmostEqual(
+(self.eval_module.sum_weight + self.eval_module.sum_weight_local).item(),
+3.0,
+places=4,
+)

# metrics should have been logged once, at the end of epoch
# TODO: test logging logic triggered by eval_model_update_critical_weight
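(The assertion above checks the sum of the global and local buffers because, at the point where the test runs it, the trainer may not yet have called _aggregate_across_instances, so part of the weight can still be sitting in sum_weight_local.)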
39 changes: 31 additions & 8 deletions reagent/test/evaluation/cb/test_policy_evaluator.py
@@ -55,21 +55,34 @@ def test_process_all_data(self):
self.eval_module._process_all_data(self.batch)
state_dict_after = copy.deepcopy(self.eval_module.state_dict())

-# all_data_sum_weight got updated properly
+# all_data_sum_weight_local got updated properly
self.assertAlmostEqual(
state_dict_after["all_data_sum_weight"].item()
- state_dict_before["all_data_sum_weight"].item(),
state_dict_after["all_data_sum_weight_local"].item()
- state_dict_before["all_data_sum_weight_local"].item(),
len(self.batch),
)
# all_data_sum_weight didn't change (because we haven't aggregated across instances yet)
self.assertAlmostEqual(
state_dict_after["all_data_sum_weight"].item(),
state_dict_before["all_data_sum_weight"].item(),
)

-# sum_weight and avg_reward_weighted didn't change
+# sum_weight and sum_reward_weighted didn't change (nor did the local values)
self.assertAlmostEqual(
state_dict_after["sum_weight"].item(),
state_dict_before["sum_weight"].item(),
)
self.assertAlmostEqual(
state_dict_after["avg_reward_weighted"].item(),
state_dict_before["avg_reward_weighted"].item(),
state_dict_after["sum_weight_local"].item(),
state_dict_before["sum_weight_local"].item(),
)
self.assertAlmostEqual(
state_dict_after["sum_reward_weighted"].item(),
state_dict_before["sum_reward_weighted"].item(),
)
self.assertAlmostEqual(
state_dict_after["sum_reward_weighted_local"].item(),
state_dict_before["sum_reward_weighted_local"].item(),
)

def test_process_used_data_reject_all(self):
@@ -88,14 +101,22 @@ def test_process_used_data_accept_some(self):
policy_network = LinearRegressionUCB(2)
eval_module = PolicyEvaluator(policy_network)
state_dict_before = copy.deepcopy(eval_module.state_dict())
weight_value = 2.0
batch = replace(
self.batch,
-weight=torch.tensor([[0.0], [1.0]]),
+weight=torch.tensor([[0.0], [weight_value]]),
)
eval_module._process_used_data(batch)
eval_module._aggregate_across_instances()
state_dict_after = copy.deepcopy(eval_module.state_dict())
self.assertFalse(_compare_state_dicts(state_dict_before, state_dict_after))
-self.assertEqual(eval_module.sum_weight.item(), 1.0)
+self.assertEqual(eval_module.sum_weight_local.item(), 0.0)
+self.assertEqual(eval_module.sum_weight.item(), weight_value)
self.assertEqual(
eval_module.sum_reward_weighted.item(),
weight_value * self.batch.reward[1, 0].item(),
)
self.assertEqual(eval_module.sum_reward_weighted_local.item(), 0.0)
self.assertEqual(eval_module.get_avg_reward(), self.batch.reward[1, 0].item())

def test_update_eval_model(self):
@@ -129,6 +150,7 @@ def test_update_eval_model(self):
def test_ingest_batch(self):
model_actions = torch.tensor([[1], [1]], dtype=torch.long)
_ = self.eval_module.ingest_batch(self.batch, model_actions)
self.eval_module._aggregate_across_instances()
# correct average reward
self.assertEqual(
self.eval_module.get_avg_reward(), self.batch.reward[1, 0].item()
@@ -137,6 +159,7 @@ def test_formatted_output(self):
def test_formatted_output(self):
model_actions = torch.tensor([[1], [1]], dtype=torch.long)
_ = self.eval_module.ingest_batch(self.batch, model_actions)
self.eval_module._aggregate_across_instances()
output = self.eval_module.get_formatted_result_string()
self.assertIsInstance(output, str)

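These tests rely on a _compare_state_dicts helper that the diff does not show; it is defined elsewhere in the test module. A plausible sketch of what it checks (an assumption for readability, not the actual helper):

    import torch

    def _compare_state_dicts(state_dict_1, state_dict_2):
        # Hypothetical: True iff both state dicts hold the same keys and
        # element-wise-equal tensors.
        if state_dict_1.keys() != state_dict_2.keys():
            return False
        return all(
            torch.equal(state_dict_1[k], state_dict_2[k]) for k in state_dict_1
        )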
13 changes: 8 additions & 5 deletions reagent/training/cb/base_trainer.py
@@ -61,17 +61,17 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0):
# update the model if we've processed enough samples
eval_model_update_critical_weight = self.eval_model_update_critical_weight
if eval_model_update_critical_weight is not None:
-# TODO: support distributed training by aggregating sum_weight across trainers.
if (
-eval_module.sum_weight_since_update.item()
+eval_module.sum_weight_since_update_local.item()
>= eval_model_update_critical_weight
):
logger.info(
f"Updating the evaluated model after {eval_module.sum_weight_since_update.item()} observations"
f"Updating the evaluated model after {eval_module.sum_weight_since_update_local.item()} observations"
)
eval_module.update_eval_model(self.scorer)
-eval_module.sum_weight_since_update.zero_()
+eval_module.sum_weight_since_update_local.zero_()
eval_module.num_eval_model_updates += 1
eval_module._aggregate_across_instances()
eval_module.log_metrics(global_step=self.global_step)
with torch.no_grad():
eval_scores = eval_module.eval_model(batch.context_arm_features)
Expand All @@ -89,7 +89,7 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0):
else:
model_actions = torch.argmax(eval_scores, dim=1).reshape(-1, 1)
new_batch = eval_module.ingest_batch(batch, model_actions)
-eval_module.sum_weight_since_update += (
+eval_module.sum_weight_since_update_local += (
batch.weight.sum() if batch.weight is not None else len(batch)
)
else:
Expand All @@ -99,4 +99,7 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0):
def on_train_epoch_end(self):
eval_module = self.eval_module # assign to local var to keep pyre happy
if eval_module is not None:
if eval_module.sum_weight_since_update_local.item() > 0:
# only aggregate if we've processed new data since the last aggregation.
eval_module._aggregate_across_instances()
eval_module.log_metrics(global_step=self.global_step)
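Net effect of the trainer changes: the eval-model refresh is gated on the per-rank weight seen since the last update, and each refresh (plus the end of every epoch) now triggers a cross-instance aggregation before metrics are logged. A condensed restatement of that gating as a hypothetical free function (names mirror the diff; this is a sketch, not the trainer itself):

    def maybe_refresh_eval_model(eval_module, scorer, critical_weight, global_step):
        # Hypothetical condensation of the training_step logic above.
        if critical_weight is None:
            return
        if eval_module.sum_weight_since_update_local.item() >= critical_weight:
            eval_module.update_eval_model(scorer)              # snapshot current policy
            eval_module.sum_weight_since_update_local.zero_()  # restart the window
            eval_module.num_eval_model_updates += 1
            eval_module._aggregate_across_instances()          # fold local sums into global
            eval_module.log_metrics(global_step=global_step)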