From 5aac164745fefdbcf499bf27d41c17e4785ff511 Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Tue, 10 Jan 2023 15:44:44 -0800 Subject: [PATCH] add support for distributed Offline Eval (#708) Summary: Pull Request resolved: https://github.com/facebookresearch/ReAgent/pull/708 Adding support for distributed Offline Eval. This requires maintaining local buffers in each trainer instance and syncing them across all trainers periodically. The sync happens under one of 2 conditions: 1. When the "critical" weight of data has been consumed (will be set approximately equal to the size of 1-hr partition) 2. At the end of the training epoch (if data has been consumed since last sync) Also, updating the FREE pipeline to remove the restriction on number of nodes for Offline Eval runs Differential Revision: D42407669 fbshipit-source-id: 634c94a594bedbd98d175d0c41371a717bab0306 --- reagent/evaluation/cb/base_evaluator.py | 47 ++++++++++++------ reagent/evaluation/cb/policy_evaluator.py | 49 ++++++++++++++----- .../test/evaluation/cb/test_integration.py | 6 ++- .../evaluation/cb/test_policy_evaluator.py | 39 ++++++++++++--- reagent/training/cb/base_trainer.py | 13 +++-- 5 files changed, 114 insertions(+), 40 deletions(-) diff --git a/reagent/evaluation/cb/base_evaluator.py b/reagent/evaluation/cb/base_evaluator.py index 4f73af5c2..403c35b77 100644 --- a/reagent/evaluation/cb/base_evaluator.py +++ b/reagent/evaluation/cb/base_evaluator.py @@ -5,6 +5,7 @@ import torch from reagent.core.types import CBInput +from reagent.core.utils import get_rank from reagent.evaluation.cb.utils import zero_out_skipped_obs_weights from torch.utils.tensorboard import SummaryWriter @@ -16,12 +17,14 @@ class BaseOfflineEval(torch.nn.Module, ABC): """ Base class for Contextual Bandit Offline Evaluation algorithms. All algorihtms support evaluation of non-stationary policies, as required for exploration-exploitation. - - IMPORTANT: current implementation doesn't support distributed training, only use it with a single training instance. """ sum_weight: torch.Tensor all_data_sum_weight: torch.Tensor + sum_weight_local: torch.Tensor + all_data_sum_weight_local: torch.Tensor + sum_weight_since_update_local: torch.Tensor + num_eval_model_updates: torch.Tensor def __init__( self, @@ -37,8 +40,12 @@ def __init__( self.summary_writer = summary_writer self.register_buffer("sum_weight", torch.zeros(1, dtype=torch.float)) self.register_buffer("all_data_sum_weight", torch.zeros(1, dtype=torch.float)) + self.register_buffer("sum_weight_local", torch.zeros(1, dtype=torch.float)) + self.register_buffer( + "all_data_sum_weight_local", torch.zeros(1, dtype=torch.float) + ) self.register_buffer( - "sum_weight_since_update", torch.zeros(1, dtype=torch.float) + "sum_weight_since_update_local", torch.zeros(1, dtype=torch.float) ) self.register_buffer("num_eval_model_updates", torch.zeros(1, dtype=torch.int)) @@ -86,6 +93,14 @@ def _process_used_data( """ pass + @abstractmethod + def _aggregate_across_instances(self) -> None: + """ + Aggregate local data across all instances of the evaluator. + Used for distributed training. + """ + pass + @abstractmethod def get_avg_reward(self) -> float: """ @@ -108,18 +123,20 @@ def attach_summary_writer(self, summary_writer: SummaryWriter) -> None: self.summary_writer = summary_writer def log_metrics(self, global_step: Optional[int] = None) -> None: - logger.info(self.get_formatted_result_string()) - summary_writer = self.summary_writer - if summary_writer is not None: - metric_dict = { - "avg_reward": self.get_avg_reward(), - "sum_weight": self.sum_weight.item(), - "all_data_sum_weight": self.all_data_sum_weight.item(), - "num_eval_model_updates": self.num_eval_model_updates.item(), - } - summary_writer.add_scalars( - "Offline_Eval", metric_dict, global_step=global_step - ) + if get_rank() == 0: + # only log from the main process + logger.info(self.get_formatted_result_string()) + summary_writer = self.summary_writer + if summary_writer is not None: + metric_dict = { + "avg_reward": self.get_avg_reward(), + "sum_weight": self.sum_weight.item(), + "all_data_sum_weight": self.all_data_sum_weight.item(), + "num_eval_model_updates": self.num_eval_model_updates.item(), + } + summary_writer.add_scalars( + "Offline_Eval", metric_dict, global_step=global_step + ) def get_formatted_result_string(self) -> str: return f"Avg reward {self.get_avg_reward():0.3f} based on {int(self.sum_weight.item())} processed observations (out of {int(self.all_data_sum_weight.item())} observations). The eval model has been updated {self.num_eval_model_updates.item()} times" diff --git a/reagent/evaluation/cb/policy_evaluator.py b/reagent/evaluation/cb/policy_evaluator.py index 1ba94ff14..ddaee634d 100644 --- a/reagent/evaluation/cb/policy_evaluator.py +++ b/reagent/evaluation/cb/policy_evaluator.py @@ -1,8 +1,14 @@ +import logging + import torch +from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available from reagent.core.types import CBInput from reagent.evaluation.cb.base_evaluator import BaseOfflineEval +logger = logging.getLogger(__name__) + + EPSILON = 1e-9 @@ -11,18 +17,22 @@ class PolicyEvaluator(BaseOfflineEval): An offline evaluator for Contextual Bandits, based on the paper https://arxiv.org/pdf/1003.0146.pdf (Algorithm 3) """ - avg_reward_weighted: torch.Tensor + sum_reward_weighted: torch.Tensor + sum_reward_weighted_local: torch.Tensor def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.register_buffer("avg_reward_weighted", torch.zeros(1, dtype=torch.float)) + self.register_buffer("sum_reward_weighted", torch.zeros(1, dtype=torch.float)) + self.register_buffer( + "sum_reward_weighted_local", torch.zeros(1, dtype=torch.float) + ) @torch.no_grad() def _process_all_data(self, batch: CBInput) -> None: if batch.weight is not None: - self.all_data_sum_weight += batch.weight.sum() + self.all_data_sum_weight_local += batch.weight.sum() else: - self.all_data_sum_weight += len(batch) + self.all_data_sum_weight_local += len(batch) @torch.no_grad() def _process_used_data(self, batch: CBInput) -> None: @@ -33,13 +43,30 @@ def _process_used_data(self, batch: CBInput) -> None: """ assert batch.reward is not None assert batch.weight is not None - batch_sum_weight = batch.weight.sum() assert batch.weight.shape == batch.reward.shape - self.avg_reward_weighted = ( - self.avg_reward_weighted * self.sum_weight - + (batch.weight * batch.reward).sum() - ) / (self.sum_weight + batch_sum_weight + EPSILON) - self.sum_weight += batch_sum_weight + self.sum_reward_weighted_local += (batch.weight * batch.reward).sum() + self.sum_weight_local += batch.weight.sum() + + def _aggregate_across_instances(self) -> None: + # sum local values across all trainers, add to the global value + # clone the tensors to avoid modifying them inplace + self.sum_reward_weighted += sync_ddp_if_available( + self.sum_reward_weighted_local.clone(), reduce_op=ReduceOp.SUM + ) + self.sum_weight += sync_ddp_if_available( + self.sum_weight_local.clone(), reduce_op=ReduceOp.SUM + ) + self.all_data_sum_weight += sync_ddp_if_available( + self.all_data_sum_weight_local.clone(), reduce_op=ReduceOp.SUM + ) + # reset local values to zero + self.sum_reward_weighted_local.zero_() + self.sum_weight_local.zero_() + self.all_data_sum_weight_local.zero_() def get_avg_reward(self) -> float: - return self.avg_reward_weighted.item() + assert ( + self.sum_weight_local.item() == 0.0 + ), f"Non-zero local weight {self.sum_weight_local.item()} in the evaluator. _aggregate_across_instances() Should have beed called to aggregate across all instances and zero-out the local values." + # return the average reward + return (self.sum_reward_weighted / (self.sum_weight + EPSILON)).item() diff --git a/reagent/test/evaluation/cb/test_integration.py b/reagent/test/evaluation/cb/test_integration.py index 9a93af35f..1fb53732e 100644 --- a/reagent/test/evaluation/cb/test_integration.py +++ b/reagent/test/evaluation/cb/test_integration.py @@ -258,7 +258,11 @@ def test_eval_during_training(self): ) # check total weight (number of observations). Should be 3 - self.assertAlmostEqual(self.eval_module.sum_weight.item(), 3.0, places=4) + self.assertAlmostEqual( + (self.eval_module.sum_weight + self.eval_module.sum_weight_local).item(), + 3.0, + places=4, + ) # metrics should have been logged once, at the end of epoch # TODO: test logging logic triggered by eval_model_update_critical_weight diff --git a/reagent/test/evaluation/cb/test_policy_evaluator.py b/reagent/test/evaluation/cb/test_policy_evaluator.py index 3c7e9ae86..9d351ad09 100644 --- a/reagent/test/evaluation/cb/test_policy_evaluator.py +++ b/reagent/test/evaluation/cb/test_policy_evaluator.py @@ -55,21 +55,34 @@ def test_process_all_data(self): self.eval_module._process_all_data(self.batch) state_dict_after = copy.deepcopy(self.eval_module.state_dict()) - # all_data_sum_weight got updated properly + # all_data_sum_weight_local got updated properly self.assertAlmostEqual( - state_dict_after["all_data_sum_weight"].item() - - state_dict_before["all_data_sum_weight"].item(), + state_dict_after["all_data_sum_weight_local"].item() + - state_dict_before["all_data_sum_weight_local"].item(), len(self.batch), ) + # all_data_sum_weight didn't change (bcs we haven't aggregated across instances yet) + self.assertAlmostEqual( + state_dict_after["all_data_sum_weight"].item(), + state_dict_before["all_data_sum_weight"].item(), + ) - # sum_weight and avg_reward_weighted didn't change + # sum_weight and sum_reward_weighted didn't change (as well as local values) self.assertAlmostEqual( state_dict_after["sum_weight"].item(), state_dict_before["sum_weight"].item(), ) self.assertAlmostEqual( - state_dict_after["avg_reward_weighted"].item(), - state_dict_before["avg_reward_weighted"].item(), + state_dict_after["sum_weight_local"].item(), + state_dict_before["sum_weight_local"].item(), + ) + self.assertAlmostEqual( + state_dict_after["sum_reward_weighted"].item(), + state_dict_before["sum_reward_weighted"].item(), + ) + self.assertAlmostEqual( + state_dict_after["sum_reward_weighted_local"].item(), + state_dict_before["sum_reward_weighted_local"].item(), ) def test_process_used_data_reject_all(self): @@ -88,14 +101,22 @@ def test_process_used_data_accept_some(self): policy_network = LinearRegressionUCB(2) eval_module = PolicyEvaluator(policy_network) state_dict_before = copy.deepcopy(eval_module.state_dict()) + weight_value = 2.0 batch = replace( self.batch, - weight=torch.tensor([[0.0], [1.0]]), + weight=torch.tensor([[0.0], [weight_value]]), ) eval_module._process_used_data(batch) + eval_module._aggregate_across_instances() state_dict_after = copy.deepcopy(eval_module.state_dict()) self.assertFalse(_compare_state_dicts(state_dict_before, state_dict_after)) - self.assertEqual(eval_module.sum_weight.item(), 1.0) + self.assertEqual(eval_module.sum_weight_local.item(), 0.0) + self.assertEqual(eval_module.sum_weight.item(), weight_value) + self.assertEqual( + eval_module.sum_reward_weighted.item(), + weight_value * self.batch.reward[1, 0].item(), + ) + self.assertEqual(eval_module.sum_reward_weighted_local.item(), 0.0) self.assertEqual(eval_module.get_avg_reward(), self.batch.reward[1, 0].item()) def test_update_eval_model(self): @@ -129,6 +150,7 @@ def test_update_eval_model(self): def test_ingest_batch(self): model_actions = torch.tensor([[1], [1]], dtype=torch.long) _ = self.eval_module.ingest_batch(self.batch, model_actions) + self.eval_module._aggregate_across_instances() # correct average reward self.assertEqual( self.eval_module.get_avg_reward(), self.batch.reward[1, 0].item() @@ -137,6 +159,7 @@ def test_ingest_batch(self): def test_formatted_output(self): model_actions = torch.tensor([[1], [1]], dtype=torch.long) _ = self.eval_module.ingest_batch(self.batch, model_actions) + self.eval_module._aggregate_across_instances() output = self.eval_module.get_formatted_result_string() self.assertIsInstance(output, str) diff --git a/reagent/training/cb/base_trainer.py b/reagent/training/cb/base_trainer.py index c836bda1f..c39b275ca 100644 --- a/reagent/training/cb/base_trainer.py +++ b/reagent/training/cb/base_trainer.py @@ -61,17 +61,17 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0): # update the model if we've processed enough samples eval_model_update_critical_weight = self.eval_model_update_critical_weight if eval_model_update_critical_weight is not None: - # TODO: support distributed training by aggregating sum_weight across trainers. if ( - eval_module.sum_weight_since_update.item() + eval_module.sum_weight_since_update_local.item() >= eval_model_update_critical_weight ): logger.info( - f"Updating the evaluated model after {eval_module.sum_weight_since_update.item()} observations" + f"Updating the evaluated model after {eval_module.sum_weight_since_update_local.item()} observations" ) eval_module.update_eval_model(self.scorer) - eval_module.sum_weight_since_update.zero_() + eval_module.sum_weight_since_update_local.zero_() eval_module.num_eval_model_updates += 1 + eval_module._aggregate_across_instances() eval_module.log_metrics(global_step=self.global_step) with torch.no_grad(): eval_scores = eval_module.eval_model(batch.context_arm_features) @@ -89,7 +89,7 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0): else: model_actions = torch.argmax(eval_scores, dim=1).reshape(-1, 1) new_batch = eval_module.ingest_batch(batch, model_actions) - eval_module.sum_weight_since_update += ( + eval_module.sum_weight_since_update_local += ( batch.weight.sum() if batch.weight is not None else len(batch) ) else: @@ -99,4 +99,7 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0): def on_train_epoch_end(self): eval_module = self.eval_module # assign to local var to keep pyre happy if eval_module is not None: + if eval_module.sum_weight_since_update_local.item() > 0: + # only aggregate if we've processed new data since last aggregation. + eval_module._aggregate_across_instances() eval_module.log_metrics(global_step=self.global_step)