From 5bd9894c770e6d92b0aa6994d0fc5217fb582cd3 Mon Sep 17 00:00:00 2001
From: Sahil Jain
Date: Fri, 21 Mar 2025 13:40:30 -0700
Subject: [PATCH] Enable amp with autocast

Signed-off-by: Sahil Jain
---
 nemo_reinforcer/models/policy/hf_policy.py | 40 ++++++++++++----------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py
index 1a5d23232e..cf932e98d6 100644
--- a/nemo_reinforcer/models/policy/hf_policy.py
+++ b/nemo_reinforcer/models/policy/hf_policy.py
@@ -73,9 +73,9 @@ def __init__(
         world_size = torch.distributed.get_world_size()
         model_name = self.cfg["model_name"]
         if self.cfg["precision"] == "float32":
-            dtype = torch.float32
+            self.dtype = torch.float32
         elif self.cfg["precision"] == "bfloat16":
-            dtype = torch.bfloat16
+            self.dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision: {self.cfg['precision']}")
 
@@ -83,12 +83,12 @@ def __init__(
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu",  # load weights onto CPU initially
-            torch_dtype=dtype,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=torch.float32,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )
         self.reference_model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu",  # load weights onto CPU initially
-            torch_dtype=dtype,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=torch.float32,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -272,16 +272,17 @@ def train(
                 # For right-padded sequence, set 1s at the beginning of the sequence
                 attention_mask[i, :length] = 1
 
-            outputs = self.model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                use_cache=False,
-            )
-            # Get logprobs
-            if not hasattr(outputs, "logits"):
-                logits = self.model.lm_head(outputs.last_hidden_state)
-            else:
-                logits = outputs.logits
+            with torch.autocast(device_type="cuda", dtype=self.dtype):
+                outputs = self.model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    use_cache=False,
+                )
+                # Get logprobs
+                if not hasattr(outputs, "logits"):
+                    logits = self.model.lm_head(outputs.last_hidden_state)
+                else:
+                    logits = outputs.logits
 
             loss, loss_metrics = loss_fn(logits, mb)
 
@@ -358,11 +359,12 @@ def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict:
             attention_mask[i, :length] = 1
 
         # Process with the model directly using right-padded inputs
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            use_cache=False,
-        )
+        with torch.autocast(device_type="cuda", dtype=self.dtype):
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                use_cache=False,
+            )
         log_probs = torch.nn.functional.log_softmax(
             outputs.logits.to(torch.float32), dim=-1
         )
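
Note: the change above is the standard torch.autocast mixed-precision pattern: the policy and reference models keep float32 master weights, and only the forward pass runs in the configured lower precision. A minimal, self-contained sketch of the same pattern follows; the model name, prompt, and variable names are illustrative placeholders (not taken from hf_policy.py), and a CUDA device is assumed.

# Illustrative sketch of the autocast pattern introduced by this patch
# (placeholder model/prompt; not the actual HfPolicy code).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"              # placeholder; the patch uses cfg["model_name"]
autocast_dtype = torch.bfloat16  # corresponds to cfg["precision"] == "bfloat16"

# Load master weights in full precision, as the patch now does unconditionally.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")

# Run the forward pass under autocast: matmuls execute in bfloat16 while the
# fp32 weights are left untouched.
with torch.autocast(device_type="cuda", dtype=autocast_dtype):
    outputs = model(**inputs, use_cache=False)

# Cast logits back to float32 before log_softmax, mirroring get_logprobs().
log_probs = torch.nn.functional.log_softmax(outputs.logits.to(torch.float32), dim=-1)
print(log_probs.shape)  # (batch, seq_len, vocab_size)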