diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 6a5d6b593d..158c9824eb 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -58,7 +58,7 @@ class ClippedPGLossFn(LossFunction): - r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio - A_t is the advantage estimate - ε is the clip parameter (ratio_eps) - - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), + - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.) - ratio_eps_min: minimum value for the clip parameter - ratio_eps_max: maximum value for the clip parameter diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index 5ae8bee9a8..f6cb3c8079 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -353,14 +353,13 @@ def get_formatted_message_log( Returns: The message log with updated 'token_ids' and 'content' fields. """ - cu_message = [] + new_message_log = [] prev_formatted_message = "" template = task_data_spec.custom_template for i, message in enumerate(message_log): - cu_message.append(message.copy()) formatted_message = tokenizer.apply_chat_template( - cu_message, + message_log[: i + 1], chat_template=template, add_generation_prompt=False, tokenize=False, @@ -383,13 +382,17 @@ def get_formatted_message_log( message_chunk = message_chunk.rstrip("\n") if not message_chunk.endswith(tokenizer.eos_token): message_chunk += tokenizer.eos_token - message["token_ids"] = tokenizer( + + new_message = message.copy() + new_message["token_ids"] = tokenizer( message_chunk, return_tensors="pt", add_special_tokens=False )["input_ids"][0] - message["content"] = message_chunk + new_message["content"] = message_chunk + new_message_log.append(new_message) + prev_formatted_message = formatted_message - return message_log + return new_message_log def remap_dataset_keys(