nemo_reinforcer/models/policy/hf_policy.py (21 additions & 19 deletions)
@@ -73,22 +73,22 @@ def __init__(
         world_size = torch.distributed.get_world_size()
         model_name = self.cfg["model_name"]
         if self.cfg["precision"] == "float32":
-            dtype = torch.float32
+            self.dtype = torch.float32
         elif self.cfg["precision"] == "bfloat16":
-            dtype = torch.bfloat16
+            self.dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision: {self.cfg['precision']}")

         print(f"[Rank {rank}] Loading model {model_name} on CPU...")
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu", # load weights onto CPU initially
-            torch_dtype=dtype, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )
         self.reference_model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu", # load weights onto CPU initially
-            torch_dtype=dtype, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )

         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
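The hunk above keeps both model copies in float32 on load (pending NVIDIA/reinforcer#13) and records the configured precision as `self.dtype` so it can be applied later through autocast. A minimal sketch of that pattern, assuming only `torch` and `transformers`; the class name and constructor signature are illustrative, not the actual hf_policy.py API.

```python
# Sketch only: mirrors the precision handling in the hunk above. Weights stay
# in float32; the chosen dtype is kept for torch.autocast at forward time.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class PolicySketch:
    def __init__(self, model_name: str, precision: str):
        if precision == "float32":
            self.dtype = torch.float32
        elif precision == "bfloat16":
            self.dtype = torch.bfloat16
        else:
            raise ValueError(f"Unknown precision: {precision}")

        # Load onto CPU in full precision; lower precision is applied only
        # through autocast during the forward pass, not to the stored weights.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="cpu", torch_dtype=torch.float32
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
```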
@@ -272,16 +272,17 @@ def train(
                 # For right-padded sequence, set 1s at the beginning of the sequence
                 attention_mask[i, :length] = 1

-            outputs = self.model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                use_cache=False,
-            )
-            # Get logprobs
-            if not hasattr(outputs, "logits"):
-                logits = self.model.lm_head(outputs.last_hidden_state)
-            else:
-                logits = outputs.logits
+            with torch.autocast(device_type="cuda", dtype=self.dtype):
+                outputs = self.model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    use_cache=False,
+                )
+                # Get logprobs
+                if not hasattr(outputs, "logits"):
+                    logits = self.model.lm_head(outputs.last_hidden_state)
+                else:
+                    logits = outputs.logits

             loss, loss_metrics = loss_fn(logits, mb)

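For context on the `train` hunk above: the forward pass and logit extraction now run inside `torch.autocast` with the stored dtype, while the loss call stays outside the context, as the unchanged line shows. The sketch below illustrates that shape with placeholder names (`model`, `loss_fn`, and `mb` stand in for whatever the real method uses); it is not the exact hf_policy.py code.

```python
# Illustrative sketch of the autocast pattern introduced in the train() hunk.
import torch


def forward_and_loss(model, input_ids, attention_mask, dtype, loss_fn, mb):
    with torch.autocast(device_type="cuda", dtype=dtype):
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        )
        # Some models return hidden states only; fall back to lm_head in that case.
        if not hasattr(outputs, "logits"):
            logits = model.lm_head(outputs.last_hidden_state)
        else:
            logits = outputs.logits

    # The loss is computed outside the autocast region, matching the diff's
    # unchanged context line.
    loss, loss_metrics = loss_fn(logits, mb)
    return loss, loss_metrics
```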
@@ -358,11 +359,12 @@ def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict:
             attention_mask[i, :length] = 1

         # Process with the model directly using right-padded inputs
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            use_cache=False,
-        )
+        with torch.autocast(device_type="cuda", dtype=self.dtype):
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                use_cache=False,
+            )
         log_probs = torch.nn.functional.log_softmax(
             outputs.logits.to(torch.float32), dim=-1
         )
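The same pattern applies in `get_logprobs`: the forward pass runs under autocast, and the logits are upcast to float32 before `log_softmax` so the log-probabilities are computed at full precision. Below is an illustrative sketch with placeholder names; the final gather step for per-token log-probabilities is an assumption about typical usage, not something shown in this hunk.

```python
# Sketch of the get_logprobs pattern above (placeholder names, not the real API).
import torch
import torch.nn.functional as F


def token_logprobs(model, input_ids, attention_mask, dtype):
    with torch.autocast(device_type="cuda", dtype=dtype):
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        )
    # Upcast to float32 before log_softmax so reduced-precision logits do not
    # degrade the log-probabilities.
    log_probs = F.log_softmax(outputs.logits.to(torch.float32), dim=-1)
    # Assumed follow-up step: log-prob of each observed next token (shift by one).
    return log_probs[:, :-1, :].gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
```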