diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 3194467c2f9..084138ea34e 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -1053,6 +1053,87 @@ def test_dpo_loss_js_div_f(self):
         )
         self.assertTrue(torch.isfinite(losses).cpu().numpy().all())

+    def test_dpo_trainer_use_num_logits_to_keep(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=1,
+                learning_rate=9e-1,
+                eval_strategy="steps",
+                beta=0.1,
+                use_num_logits_to_keep=True,
+                rpo_alpha=0.5,
+                report_to="none",
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            # trainer with use_num_logits_to_keep enabled
+            trainer = DPOTrainer(
+                model=model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            training_args.use_num_logits_to_keep = False
+            trainer2 = DPOTrainer(
+                model=model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            # Fake batch
+            prompt_input_ids = torch.randint(1, 1000, (2, 10))
+            chosen_input_ids = torch.randint(1, 1000, (2, 5))
+            rejected_input_ids = torch.randint(1, 1000, (2, 7))
+            prompt_attention_mask = torch.ones_like(prompt_input_ids)
+            chosen_attention_mask = torch.ones_like(chosen_input_ids)
+            rejected_attention_mask = torch.ones_like(rejected_input_ids)
+
+            batch = {
+                "prompt_input_ids": prompt_input_ids.to(model.device),
+                "chosen_input_ids": chosen_input_ids.to(model.device),
+                "rejected_input_ids": rejected_input_ids.to(model.device),
+                "prompt_attention_mask": prompt_attention_mask.to(model.device),
+                "chosen_attention_mask": chosen_attention_mask.to(model.device),
+                "rejected_attention_mask": rejected_attention_mask.to(model.device),
+            }
+
+            output = trainer.concatenated_forward(model, batch)
+            output2 = trainer2.concatenated_forward(model, batch)
+
+            np.testing.assert_allclose(output["nll_loss"].item(), output2["nll_loss"].item(), atol=1e-5)
+            np.testing.assert_allclose(
+                output["mean_chosen_logits"].item(), output2["mean_chosen_logits"].item(), atol=1e-5
+            )
+            np.testing.assert_allclose(
+                output["mean_rejected_logits"].item(), output2["mean_rejected_logits"].item(), atol=1e-5
+            )
+
+            for i in range(output["chosen_logps"].shape[0]):
+                np.testing.assert_allclose(
+                    output["chosen_logps"][i].item(), output2["chosen_logps"][i].item(), atol=1e-5
+                )
+                np.testing.assert_allclose(
+                    output["rejected_logps"][i].item(), output2["rejected_logps"][i].item(), atol=1e-5
+                )
+
+            trainer.train()
+

 @require_vision
 class DPOVisionTrainerTester(unittest.TestCase):
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index b84dfb47dd8..3e653fa3f84 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -132,6 +132,11 @@ class DPOConfig(TrainingArguments):
             α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
             weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
             DPO loss. The paper recommends `rpo_alpha=1.0`.
+        use_num_logits_to_keep (`bool`, *optional*, defaults to `False`):
+            If `True`, only a specified number of logits are computed in the forward pass of the causal LM. This can
+            be useful for saving memory and speeding up training by not computing the logits for all tokens,
+            especially when working with very long prompts where labels are ignored (`-100`).
+            [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM)
     """

     learning_rate: float = 1e-6
@@ -176,6 +181,7 @@ class DPOConfig(TrainingArguments):
     ref_model_mixup_alpha: float = 0.9
     ref_model_sync_steps: int = 64
     rpo_alpha: Optional[float] = None
+    use_num_logits_to_keep: bool = False

     def __post_init__(self):
         if self.max_target_length is not None:
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 0c9883387a6..38acd85994d 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -47,7 +47,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_peft_available
+from transformers.utils import check_min_version, is_peft_available
 from transformers.utils.deprecation import deprecate_kwarg

 from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
@@ -396,6 +396,7 @@ def make_inputs_require_grad(module, input, output):
         self.truncation_mode = args.truncation_mode
         self.max_completion_length = args.max_completion_length
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
+        self.use_num_logits_to_keep = args.use_num_logits_to_keep

         # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
         # keep track of first called to avoid computation of future calls
@@ -529,6 +530,11 @@ def make_inputs_require_grad(module, input, output):
             )
             self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
+
+        # num_logits_to_keep is supported since transformers v4.45.0
+        if self.use_num_logits_to_keep:
+            check_min_version("4.45.0")
+
         if self.loss_type == "bco_pair":
             self.running = RunningMoments(self.accelerator)
@@ -1087,9 +1093,9 @@ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, to
            # Get the first column idx that is all zeros and remove every column after that
            empty_cols = torch.sum(attention_mask, dim=0) == 0
            first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
-            input_ids = input_ids[:, : first_empty_col]
-            attention_mask = attention_mask[:, : first_empty_col]
-            loss_mask = loss_mask[:, : first_empty_col]
+            input_ids = input_ids[:, :first_empty_col]
+            attention_mask = attention_mask[:, :first_empty_col]
+            loss_mask = loss_mask[:, :first_empty_col]

            # Truncate right
            if self.args.max_length is not None:
@@ -1097,6 +1103,15 @@
                attention_mask = attention_mask[:, : self.args.max_length]
                loss_mask = loss_mask[:, : self.args.max_length]

+        if self.use_num_logits_to_keep:
+            # Compute num_logits_to_keep based on loss_mask pattern:
+            # [[0, 0, 0, x, x, x, x],
+            #  [0, 0, 0, x, x, x, 0]]
+            #         ^ start computing logits from here ([:, -(7-3+1):])
+            first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
+            num_logits_to_keep = loss_mask.shape[1] - first_compute_index
+            model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1  # +1 for the first label
+
         outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)

         # Offset the logits by one to align with the labels
@@ -1104,6 +1119,16 @@
         labels = input_ids[:, 1:].clone()
         loss_mask = loss_mask[:, 1:].bool()

+        if self.use_num_logits_to_keep:
+            # Align labels with logits
+            # logits:    -,  -, [x2, x3, x4, x5, x6]
+            #                     ^ --------- ^       after logits[:, :-1, :]
+            # labels:  [y0, y1, y2, y3, y4, y5, y6]
+            #                     ^ --------- ^       with num_logits_to_keep=4, [:, -4:]
+            # loss_mask: [0,  0,  0,  1,  1,  1,  1]
+            labels = labels[:, -num_logits_to_keep:]
+            loss_mask = loss_mask[:, -num_logits_to_keep:]
+
         if logits.shape[:2] != labels.shape[:2]:
             # for llava, the returned logits include the image tokens (placed before the text tokens)
             seq_len = labels.shape[1]
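Usage sketch: a minimal example of enabling the new `use_num_logits_to_keep` flag, reusing the tiny model and dataset from the test above; the `output_dir` value is illustrative.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = DPOConfig(
    output_dir="dpo-num-logits-to-keep",  # illustrative path
    use_num_logits_to_keep=True,  # compute logits only where needed; requires transformers >= 4.45.0
    report_to="none",
)
trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
trainer.train()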