diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 70e5617040..61dcd9a127 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -235,9 +235,6 @@ def __init__( self.reference_model_state_dict = get_cpu_state_dict( self.model.state_dict().items(), pin_memory=True ) - self.reference_model_buffers = get_cpu_state_dict( - self.model.named_buffers(), pin_memory=True - ) if init_optimizer: optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) @@ -768,32 +765,26 @@ def use_reference_model(self) -> Generator[None, None, None]: """ with torch.no_grad(): try: + # Save train model state_dict curr_state_dict = get_cpu_state_dict( self.model.state_dict().items(), pin_memory=True ) - curr_buffers = get_cpu_state_dict( - self.model.named_buffers(), pin_memory=True - ) + # Swap reference model state_dict to self.model for k, v in self.model.state_dict().items(): val = to_local_if_dtensor(v) val.copy_(self.reference_model_state_dict[k]) - for k, v in self.model.named_buffers(): - val = to_local_if_dtensor(v) - val.copy_(self.reference_model_buffers[k]) - + # - self.model is the original reference_model, now on CUDA + # - curr_state_dict is the train model, now on CPU yield finally: + # Restore train model state_dict for k, v in self.model.state_dict().items(): val = to_local_if_dtensor(v) val.copy_(curr_state_dict[k]) - for k, v in self.model.named_buffers(): - val = to_local_if_dtensor(v) - val.copy_(curr_buffers[k]) - def get_reference_policy_logprobs( self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None ) -> BatchedDataDict[ReferenceLogprobOutputSpec]: