Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo_reinforcer/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class ClippedPGLossFn(LossFunction):
- r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio
- A_t is the advantage estimate
- ε is the clip parameter (ratio_eps)
- As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476),
- As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476),
we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.)
- ratio_eps_min: minimum value for the clip parameter
- ratio_eps_max: maximum value for the clip parameter
Expand Down
15 changes: 9 additions & 6 deletions nemo_reinforcer/data/llm_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,14 +353,13 @@ def get_formatted_message_log(
Returns:
The message log with updated 'token_ids' and 'content' fields.
"""
cu_message = []
new_message_log = []
prev_formatted_message = ""
template = task_data_spec.custom_template

for i, message in enumerate(message_log):
cu_message.append(message.copy())
formatted_message = tokenizer.apply_chat_template(
cu_message,
message_log[: i + 1],
chat_template=template,
add_generation_prompt=False,
tokenize=False,
Expand All @@ -383,13 +382,17 @@ def get_formatted_message_log(
message_chunk = message_chunk.rstrip("\n")
if not message_chunk.endswith(tokenizer.eos_token):
message_chunk += tokenizer.eos_token
message["token_ids"] = tokenizer(

new_message = message.copy()
new_message["token_ids"] = tokenizer(
message_chunk, return_tensors="pt", add_special_tokens=False
)["input_ids"][0]
message["content"] = message_chunk
new_message["content"] = message_chunk
new_message_log.append(new_message)

prev_formatted_message = formatted_message

return message_log
return new_message_log


def remap_dataset_keys(
Expand Down