diff --git a/reagent/evaluation/cb/base_evaluator.py b/reagent/evaluation/cb/base_evaluator.py
index 4f73af5c2..403c35b77 100644
--- a/reagent/evaluation/cb/base_evaluator.py
+++ b/reagent/evaluation/cb/base_evaluator.py
@@ -5,6 +5,7 @@
 import torch
 from reagent.core.types import CBInput
+from reagent.core.utils import get_rank
 from reagent.evaluation.cb.utils import zero_out_skipped_obs_weights
 from torch.utils.tensorboard import SummaryWriter
@@ -16,12 +17,14 @@ class BaseOfflineEval(torch.nn.Module, ABC):
     """
     Base class for Contextual Bandit Offline Evaluation algorithms. All algorihtms support evaluation
         of non-stationary policies, as required for exploration-exploitation.
-
-    IMPORTANT: current implementation doesn't support distributed training, only use it with a single training instance.
     """

     sum_weight: torch.Tensor
     all_data_sum_weight: torch.Tensor
+    sum_weight_local: torch.Tensor
+    all_data_sum_weight_local: torch.Tensor
+    sum_weight_since_update_local: torch.Tensor
+    num_eval_model_updates: torch.Tensor

     def __init__(
         self,
@@ -37,8 +40,12 @@ def __init__(
         self.summary_writer = summary_writer
         self.register_buffer("sum_weight", torch.zeros(1, dtype=torch.float))
         self.register_buffer("all_data_sum_weight", torch.zeros(1, dtype=torch.float))
+        self.register_buffer("sum_weight_local", torch.zeros(1, dtype=torch.float))
+        self.register_buffer(
+            "all_data_sum_weight_local", torch.zeros(1, dtype=torch.float)
+        )
         self.register_buffer(
-            "sum_weight_since_update", torch.zeros(1, dtype=torch.float)
+            "sum_weight_since_update_local", torch.zeros(1, dtype=torch.float)
         )
         self.register_buffer("num_eval_model_updates", torch.zeros(1, dtype=torch.int))
@@ -86,6 +93,14 @@ def _process_used_data(
         """
         pass

+    @abstractmethod
+    def _aggregate_across_instances(self) -> None:
+        """
+        Aggregate local data across all instances of the evaluator.
+        Used for distributed training.
+        """
+        pass
+
     @abstractmethod
     def get_avg_reward(self) -> float:
         """
@@ -108,18 +123,20 @@ def attach_summary_writer(self, summary_writer: SummaryWriter) -> None:
         self.summary_writer = summary_writer

     def log_metrics(self, global_step: Optional[int] = None) -> None:
-        logger.info(self.get_formatted_result_string())
-        summary_writer = self.summary_writer
-        if summary_writer is not None:
-            metric_dict = {
-                "avg_reward": self.get_avg_reward(),
-                "sum_weight": self.sum_weight.item(),
-                "all_data_sum_weight": self.all_data_sum_weight.item(),
-                "num_eval_model_updates": self.num_eval_model_updates.item(),
-            }
-            summary_writer.add_scalars(
-                "Offline_Eval", metric_dict, global_step=global_step
-            )
+        if get_rank() == 0:
+            # only log from the main process
+            logger.info(self.get_formatted_result_string())
+            summary_writer = self.summary_writer
+            if summary_writer is not None:
+                metric_dict = {
+                    "avg_reward": self.get_avg_reward(),
+                    "sum_weight": self.sum_weight.item(),
+                    "all_data_sum_weight": self.all_data_sum_weight.item(),
+                    "num_eval_model_updates": self.num_eval_model_updates.item(),
+                }
+                summary_writer.add_scalars(
+                    "Offline_Eval", metric_dict, global_step=global_step
+                )

     def get_formatted_result_string(self) -> str:
         return f"Avg reward {self.get_avg_reward():0.3f} based on {int(self.sum_weight.item())} processed observations (out of {int(self.all_data_sum_weight.item())} observations). The eval model has been updated {self.num_eval_model_updates.item()} times"
diff --git a/reagent/evaluation/cb/policy_evaluator.py b/reagent/evaluation/cb/policy_evaluator.py
index 1ba94ff14..ddaee634d 100644
--- a/reagent/evaluation/cb/policy_evaluator.py
+++ b/reagent/evaluation/cb/policy_evaluator.py
@@ -1,8 +1,14 @@
+import logging
+
 import torch
+from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available
 from reagent.core.types import CBInput
 from reagent.evaluation.cb.base_evaluator import BaseOfflineEval

+logger = logging.getLogger(__name__)
+
+
 EPSILON = 1e-9
@@ -11,18 +17,22 @@ class PolicyEvaluator(BaseOfflineEval):
     An offline evaluator for Contextual Bandits, based on the paper
     https://arxiv.org/pdf/1003.0146.pdf (Algorithm 3)
     """

-    avg_reward_weighted: torch.Tensor
+    sum_reward_weighted: torch.Tensor
+    sum_reward_weighted_local: torch.Tensor

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.register_buffer("avg_reward_weighted", torch.zeros(1, dtype=torch.float))
+        self.register_buffer("sum_reward_weighted", torch.zeros(1, dtype=torch.float))
+        self.register_buffer(
+            "sum_reward_weighted_local", torch.zeros(1, dtype=torch.float)
+        )

     @torch.no_grad()
     def _process_all_data(self, batch: CBInput) -> None:
         if batch.weight is not None:
-            self.all_data_sum_weight += batch.weight.sum()
+            self.all_data_sum_weight_local += batch.weight.sum()
         else:
-            self.all_data_sum_weight += len(batch)
+            self.all_data_sum_weight_local += len(batch)

     @torch.no_grad()
     def _process_used_data(self, batch: CBInput) -> None:
@@ -33,13 +43,30 @@ def _process_used_data(self, batch: CBInput) -> None:
         """
         assert batch.reward is not None
         assert batch.weight is not None
-        batch_sum_weight = batch.weight.sum()
         assert batch.weight.shape == batch.reward.shape
-        self.avg_reward_weighted = (
-            self.avg_reward_weighted * self.sum_weight
-            + (batch.weight * batch.reward).sum()
-        ) / (self.sum_weight + batch_sum_weight + EPSILON)
-        self.sum_weight += batch_sum_weight
+        self.sum_reward_weighted_local += (batch.weight * batch.reward).sum()
+        self.sum_weight_local += batch.weight.sum()
+
+    def _aggregate_across_instances(self) -> None:
+        # sum local values across all trainers, add to the global value
+        # clone the tensors to avoid modifying them inplace
+        self.sum_reward_weighted += sync_ddp_if_available(
+            self.sum_reward_weighted_local.clone(), reduce_op=ReduceOp.SUM
+        )
+        self.sum_weight += sync_ddp_if_available(
+            self.sum_weight_local.clone(), reduce_op=ReduceOp.SUM
+        )
+        self.all_data_sum_weight += sync_ddp_if_available(
+            self.all_data_sum_weight_local.clone(), reduce_op=ReduceOp.SUM
+        )
+        # reset local values to zero
+        self.sum_reward_weighted_local.zero_()
+        self.sum_weight_local.zero_()
+        self.all_data_sum_weight_local.zero_()

     def get_avg_reward(self) -> float:
-        return self.avg_reward_weighted.item()
+        assert (
+            self.sum_weight_local.item() == 0.0
+        ), f"Non-zero local weight {self.sum_weight_local.item()} in the evaluator. _aggregate_across_instances() should have been called to aggregate across all instances and zero out the local values."
+        # return the average reward
+        return (self.sum_reward_weighted / (self.sum_weight + EPSILON)).item()
diff --git a/reagent/test/evaluation/cb/test_integration.py b/reagent/test/evaluation/cb/test_integration.py
index 9a93af35f..1fb53732e 100644
--- a/reagent/test/evaluation/cb/test_integration.py
+++ b/reagent/test/evaluation/cb/test_integration.py
@@ -258,7 +258,11 @@ def test_eval_during_training(self):
         )

         # check total weight (number of observations). Should be 3
-        self.assertAlmostEqual(self.eval_module.sum_weight.item(), 3.0, places=4)
+        self.assertAlmostEqual(
+            (self.eval_module.sum_weight + self.eval_module.sum_weight_local).item(),
+            3.0,
+            places=4,
+        )

         # metrics should have been logged once, at the end of epoch
         # TODO: test logging logic triggered by eval_model_update_critical_weight
diff --git a/reagent/test/evaluation/cb/test_policy_evaluator.py b/reagent/test/evaluation/cb/test_policy_evaluator.py
index 3c7e9ae86..9d351ad09 100644
--- a/reagent/test/evaluation/cb/test_policy_evaluator.py
+++ b/reagent/test/evaluation/cb/test_policy_evaluator.py
@@ -55,21 +55,34 @@ def test_process_all_data(self):
         self.eval_module._process_all_data(self.batch)
         state_dict_after = copy.deepcopy(self.eval_module.state_dict())

-        # all_data_sum_weight got updated properly
+        # all_data_sum_weight_local got updated properly
         self.assertAlmostEqual(
-            state_dict_after["all_data_sum_weight"].item()
-            - state_dict_before["all_data_sum_weight"].item(),
+            state_dict_after["all_data_sum_weight_local"].item()
+            - state_dict_before["all_data_sum_weight_local"].item(),
             len(self.batch),
         )
+        # all_data_sum_weight didn't change (because we haven't aggregated across instances yet)
+        self.assertAlmostEqual(
+            state_dict_after["all_data_sum_weight"].item(),
+            state_dict_before["all_data_sum_weight"].item(),
+        )

-        # sum_weight and avg_reward_weighted didn't change
+        # sum_weight and sum_reward_weighted didn't change (and neither did the local values)
         self.assertAlmostEqual(
             state_dict_after["sum_weight"].item(),
             state_dict_before["sum_weight"].item(),
         )
         self.assertAlmostEqual(
-            state_dict_after["avg_reward_weighted"].item(),
-            state_dict_before["avg_reward_weighted"].item(),
+            state_dict_after["sum_weight_local"].item(),
+            state_dict_before["sum_weight_local"].item(),
+        )
+        self.assertAlmostEqual(
+            state_dict_after["sum_reward_weighted"].item(),
+            state_dict_before["sum_reward_weighted"].item(),
+        )
+        self.assertAlmostEqual(
+            state_dict_after["sum_reward_weighted_local"].item(),
+            state_dict_before["sum_reward_weighted_local"].item(),
         )

     def test_process_used_data_reject_all(self):
@@ -88,14 +101,22 @@ def test_process_used_data_accept_some(self):
         policy_network = LinearRegressionUCB(2)
         eval_module = PolicyEvaluator(policy_network)
         state_dict_before = copy.deepcopy(eval_module.state_dict())
+        weight_value = 2.0
         batch = replace(
             self.batch,
-            weight=torch.tensor([[0.0], [1.0]]),
+            weight=torch.tensor([[0.0], [weight_value]]),
         )
         eval_module._process_used_data(batch)
+        eval_module._aggregate_across_instances()
         state_dict_after = copy.deepcopy(eval_module.state_dict())
         self.assertFalse(_compare_state_dicts(state_dict_before, state_dict_after))
-        self.assertEqual(eval_module.sum_weight.item(), 1.0)
+        self.assertEqual(eval_module.sum_weight_local.item(), 0.0)
+        self.assertEqual(eval_module.sum_weight.item(), weight_value)
+        self.assertEqual(
+            eval_module.sum_reward_weighted.item(),
+            weight_value * self.batch.reward[1, 0].item(),
+        )
+        self.assertEqual(eval_module.sum_reward_weighted_local.item(), 0.0)
         self.assertEqual(eval_module.get_avg_reward(), self.batch.reward[1, 0].item())

     def test_update_eval_model(self):
@@ -129,6 +150,7 @@ def test_update_eval_model(self):
     def test_ingest_batch(self):
         model_actions = torch.tensor([[1], [1]], dtype=torch.long)
         _ = self.eval_module.ingest_batch(self.batch, model_actions)
+        self.eval_module._aggregate_across_instances()
         # correct average reward
         self.assertEqual(
             self.eval_module.get_avg_reward(), self.batch.reward[1, 0].item()
         )
@@ -137,6 +159,7 @@ def test_formatted_output(self):
         model_actions = torch.tensor([[1], [1]], dtype=torch.long)
         _ = self.eval_module.ingest_batch(self.batch, model_actions)
+        self.eval_module._aggregate_across_instances()
         output = self.eval_module.get_formatted_result_string()
         self.assertIsInstance(output, str)
diff --git a/reagent/training/cb/base_trainer.py b/reagent/training/cb/base_trainer.py
index c836bda1f..c39b275ca 100644
--- a/reagent/training/cb/base_trainer.py
+++ b/reagent/training/cb/base_trainer.py
@@ -61,17 +61,17 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0):
             # update the model if we've processed enough samples
             eval_model_update_critical_weight = self.eval_model_update_critical_weight
             if eval_model_update_critical_weight is not None:
-                # TODO: support distributed training by aggregating sum_weight across trainers.
                 if (
-                    eval_module.sum_weight_since_update.item()
+                    eval_module.sum_weight_since_update_local.item()
                     >= eval_model_update_critical_weight
                 ):
                     logger.info(
-                        f"Updating the evaluated model after {eval_module.sum_weight_since_update.item()} observations"
+                        f"Updating the evaluated model after {eval_module.sum_weight_since_update_local.item()} observations"
                     )
                     eval_module.update_eval_model(self.scorer)
-                    eval_module.sum_weight_since_update.zero_()
+                    eval_module.sum_weight_since_update_local.zero_()
                     eval_module.num_eval_model_updates += 1
+                    eval_module._aggregate_across_instances()
                     eval_module.log_metrics(global_step=self.global_step)
             with torch.no_grad():
                 eval_scores = eval_module.eval_model(batch.context_arm_features)
@@ -89,7 +89,7 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0):
                 else:
                     model_actions = torch.argmax(eval_scores, dim=1).reshape(-1, 1)
             new_batch = eval_module.ingest_batch(batch, model_actions)
-            eval_module.sum_weight_since_update += (
+            eval_module.sum_weight_since_update_local += (
                 batch.weight.sum() if batch.weight is not None else len(batch)
             )
         else:
@@ -99,4 +99,7 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0):
     def on_train_epoch_end(self):
        eval_module = self.eval_module  # assign to local var to keep pyre happy
        if eval_module is not None:
+           if eval_module.sum_weight_since_update_local.item() > 0:
+               # only aggregate if we've processed new data since last aggregation.
+               eval_module._aggregate_across_instances()
            eval_module.log_metrics(global_step=self.global_step)
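
The pattern introduced by this diff is: each rank accumulates into the *_local buffers during training, and _aggregate_across_instances() all-reduces those partial sums into the global buffers (then zeroes the local ones) before any metric is read. The sketch below shows the same accumulate-then-sync idea in isolation. It is not part of this diff: the RunningWeightedReward class and its method names are invented for illustration, and it assumes a PyTorch Lightning version that still exposes pytorch_lightning.utilities.distributed.sync_ddp_if_available, matching the import added above. With no process group initialized the sync is a no-op, so the snippet also runs in a single process.

# Standalone sketch (illustrative only, not ReAgent code) of the
# accumulate-locally / all-reduce-on-demand pattern used above.
import torch
from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available


class RunningWeightedReward(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # global (already synced) totals and per-rank partial sums
        self.register_buffer("sum_reward_weighted", torch.zeros(1))
        self.register_buffer("sum_weight", torch.zeros(1))
        self.register_buffer("sum_reward_weighted_local", torch.zeros(1))
        self.register_buffer("sum_weight_local", torch.zeros(1))

    @torch.no_grad()
    def update(self, reward: torch.Tensor, weight: torch.Tensor) -> None:
        # cheap per-batch accumulation; no communication happens here
        self.sum_reward_weighted_local += (weight * reward).sum()
        self.sum_weight_local += weight.sum()

    def merge(self) -> None:
        # all-reduce the partial sums across ranks (a no-op when no process
        # group is initialized), fold them into the totals, then reset.
        # Being a collective call, this must run on every rank.
        self.sum_reward_weighted += sync_ddp_if_available(
            self.sum_reward_weighted_local.clone(), reduce_op=ReduceOp.SUM
        )
        self.sum_weight += sync_ddp_if_available(
            self.sum_weight_local.clone(), reduce_op=ReduceOp.SUM
        )
        self.sum_reward_weighted_local.zero_()
        self.sum_weight_local.zero_()

    def avg_reward(self) -> float:
        return (self.sum_reward_weighted / (self.sum_weight + 1e-9)).item()


if __name__ == "__main__":
    tracker = RunningWeightedReward()
    tracker.update(reward=torch.tensor([1.0, 0.0]), weight=torch.tensor([2.0, 1.0]))
    tracker.merge()
    print(tracker.avg_reward())  # ~0.667: (2*1 + 1*0) / (2 + 1)

Keeping the collective call out of the per-batch path (only _aggregate_across_instances() communicates) is what lets _process_used_data and _process_all_data stay communication-free during training.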