diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index cf908ee70a..16b048845a 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -77,18 +77,18 @@ def __init__( self.model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially - torch_dtype=torch.bfloat16, # use half precision to save memory + torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) self.reference_model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially - torch_dtype=torch.bfloat16, # use half precision to save memory + torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) - self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) # If no pad token is defined, you might need: - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token # ------------------------------------------------ # 3) Move to GPU + Composable FSDP @@ -99,23 +99,10 @@ def do_fsdp(model): # Create a device mesh with 'world_size' GPUs in a 1D arrangement. mesh = init_device_mesh("cuda", (world_size,)) - # Mixed precision training - # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision - param_dtype = torch.bfloat16 # use lower precision for model parameters - reduce_dtype = torch.float32 # use higher precision for gradient reduction - buffer_dtype = torch.float32 # use higher precision for optimizer states - - mp = MixedPrecision( - param_dtype=param_dtype, - reduce_dtype=reduce_dtype, - buffer_dtype=buffer_dtype, - ) - return FullyShardedDataParallel( model, device_mesh=mesh, auto_wrap_policy=size_based_auto_wrap_policy, - mixed_precision=mp, ) self.model.to("cuda")