From 72d464799e03febf63cfe242498a17eb4c49fe6c Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 20 Sep 2024 16:01:28 +0800 Subject: [PATCH 01/19] Support num_logits_to_keep, which computes necessary logits in the forward pass. --- trl/trainer/dpo_config.py | 6 ++++++ trl/trainer/dpo_trainer.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index a517c3cbfef..9a634eda66d 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -128,6 +128,11 @@ class DPOConfig(TrainingArguments): α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends `rpo_alpha=1.0`. + use_num_logits_to_keep (`bool`, *optional*, defaults to `False`): + If `True`, only a specified number of logits are used to compute the loss. This can be useful for saving memory + and speeding up training by not computing the loss for all logits, especially in scenarios where only the + top logits are relevant or when working with very long sequences. 
+ [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM) """ beta: float = 0.1 @@ -170,6 +175,7 @@ class DPOConfig(TrainingArguments): ref_model_mixup_alpha: float = 0.9 ref_model_sync_steps: int = 64 rpo_alpha: Optional[float] = None + use_num_logits_to_keep: bool = False def __post_init__(self): if self.max_target_length is not None: diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3c9ff4624b2..5074364043b 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -451,6 +451,7 @@ def __init__( ref_adapter_name: Optional[str] = None, reference_free: bool = False, force_use_ref_model: bool = False, + use_num_logits_to_keep: bool = False, ): if not isinstance(model, str) and ref_model is model: raise ValueError( @@ -532,6 +533,12 @@ def __init__( ) args.force_use_ref_model = force_use_ref_model + if use_num_logits_to_keep: + warnings.warn( + "You passed `use_num_logits_to_keep` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." 
+ ) + args.use_num_logits_to_keep = use_num_logits_to_keep + if not is_peft_available() and peft_config is not None: raise ValueError( "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" @@ -763,6 +770,7 @@ def make_inputs_require_grad(module, input, output): self.truncation_mode = args.truncation_mode self.max_completion_length = args.max_completion_length self.precompute_ref_log_probs = args.precompute_ref_log_probs + self.use_num_logits_to_keep = args.use_num_logits_to_keep # Since ref_logs are precomputed on the first call to get_train/eval_dataloader # keep track of first called to avoid computation of future calls @@ -1062,6 +1070,7 @@ def concatenated_inputs( label_pad_token_id: int = -100, padding_value: int = 0, device: Optional[torch.device] = None, + use_num_logits_to_keep: bool = False, ) -> Dict[str, torch.LongTensor]: """Concatenate the chosen and rejected inputs into a single tensor. @@ -1082,6 +1091,11 @@ def concatenated_inputs( else: max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + # Support num_logits_to_keep, which computes necessary logits in the forward pass. + # This saves memory for long prompts where labels are -100 (unused). 
+ num_logits_to_keep = batch["chosen_labels"].shape[1] - (batch["chosen_labels"] != -100).nonzero(as_tuple=True)[1].min() + concatenated_batch["num_logits_to_keep"] = num_logits_to_keep.cpu() + for k in batch: if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): if "labels" in k or is_encoder_decoder: @@ -1126,6 +1140,8 @@ def concatenated_inputs( concatenated_batch["pixel_attention_mask"] = torch.cat( [batch["prompt_pixel_attention_mask"], batch["prompt_pixel_attention_mask"]], dim=0 ) + if use_num_logits_to_keep: + concatenated_batch["concatenated_labels"] = concatenated_batch["concatenated_labels"][:, -concatenated_batch["num_logits_to_keep"]:] return concatenated_batch def dpo_loss( @@ -1358,6 +1374,7 @@ def concatenated_forward( label_pad_token_id=self.label_pad_token_id, padding_value=self.padding_value, device=self.accelerator.device, + use_num_logits_to_keep=self.use_num_logits_to_keep ) len_chosen = batch["chosen_labels"].shape[0] @@ -1379,6 +1396,7 @@ def concatenated_forward( concatenated_batch["concatenated_input_ids"], attention_mask=concatenated_batch["concatenated_attention_mask"], use_cache=False, + num_logits_to_keep=concatenated_batch["num_logits_to_keep"], # save memory **model_kwargs, ) all_logits = outputs.logits From 13cb564e5936db861cbb1bb99bc8a492233e0cff Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 20 Sep 2024 16:36:20 +0800 Subject: [PATCH 02/19] update doc --- trl/trainer/dpo_config.py | 6 +++--- trl/trainer/dpo_trainer.py | 17 ++++++----------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 9a634eda66d..7cbc2359959 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -129,9 +129,9 @@ class DPOConfig(TrainingArguments): weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends `rpo_alpha=1.0`. 
use_num_logits_to_keep (`bool`, *optional*, defaults to `False`): - If `True`, only a specified number of logits are used to compute the loss. This can be useful for saving memory - and speeding up training by not computing the loss for all logits, especially in scenarios where only the - top logits are relevant or when working with very long sequences. + If `True`, only a specified number of logits are computed in the forward pass. This can be useful for + saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios + when working with very long prompts where labels are -100 (ignored). [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM) """ diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 5074364043b..50ef92d7617 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -450,8 +450,7 @@ def __init__( model_adapter_name: Optional[str] = None, ref_adapter_name: Optional[str] = None, reference_free: bool = False, - force_use_ref_model: bool = False, - use_num_logits_to_keep: bool = False, + force_use_ref_model: bool = False ): if not isinstance(model, str) and ref_model is model: raise ValueError( @@ -533,12 +532,6 @@ def __init__( ) args.force_use_ref_model = force_use_ref_model - if use_num_logits_to_keep: - warnings.warn( - "You passed `use_num_logits_to_keep` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." - ) - args.use_num_logits_to_keep = use_num_logits_to_keep - if not is_peft_available() and peft_config is not None: raise ValueError( "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" @@ -1080,6 +1073,7 @@ def concatenated_inputs( label_pad_token_id: The label pad token id. padding_value: The padding value to use for the concatenated inputs_ids. device: The device for the concatenated inputs. 
+ use_num_logits_to_keep: Whether to computes necessary logits in the forward pass. Returns: A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. @@ -1092,9 +1086,10 @@ def concatenated_inputs( max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) # Support num_logits_to_keep, which computes necessary logits in the forward pass. - # This saves memory for long prompts where labels are -100 (unused). - num_logits_to_keep = batch["chosen_labels"].shape[1] - (batch["chosen_labels"] != -100).nonzero(as_tuple=True)[1].min() - concatenated_batch["num_logits_to_keep"] = num_logits_to_keep.cpu() + # This saves memory for long prompts where labels are -100 (ignored). + if use_num_logits_to_keep: + num_logits_to_keep = batch["chosen_labels"].shape[1] - (batch["chosen_labels"] != -100).nonzero(as_tuple=True)[1].min() + concatenated_batch["num_logits_to_keep"] = num_logits_to_keep.cpu() for k in batch: if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): From efb12cbb786e1b2e30c66477d349fdf62b7c169d Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Thu, 26 Sep 2024 10:54:54 +0800 Subject: [PATCH 03/19] bug fix --- tests/test_dpo_trainer.py | 58 ++++++++++++++++++++++++++++++++++++++ trl/trainer/dpo_config.py | 6 ++-- trl/trainer/dpo_trainer.py | 51 +++++++++++++++++++++++---------- 3 files changed, 97 insertions(+), 18 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 7d6e5c46701..a5c0f774934 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1043,6 +1043,64 @@ def test_dpo_loss_js_div_f(self): assert torch.isfinite(losses).cpu().numpy().all() + def test_dpo_trainer_use_num_logits_to_keep(self): + model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = 
DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + generate_during_eval=True, + report_to="none", + use_num_logits_to_keep=True, + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + tokenizer=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + # Fake batch + chosen_labels = torch.tensor([[-100, -100, -100, -100, -100, 4, 5, -100], + [-100, -100, -100, -100, 2, 4, 5, 6]]) + rejected_labels = torch.tensor([[-100, -100, -100, -100, 100, 7, 5, 909], + [-100, -100, -100, -100, 88, 4, 5, 6]]) + chosen_input_ids = torch.tensor([[700, 3, 5, 8, 9, 76, 4, 5, -100], + [700, 3, 5, 8, 9, 2, 4, 5, 6]]) + rejected_input_ids = torch.tensor([[700, 3, 5, 8, 9, 100, 7, 5, 909], + [700, 3, 5, 8, 9, 88, 4, 5, 6]]) + chosen_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1]]) + rejected_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1]]) + + batch = {"chosen_labels": chosen_labels, "rejected_labels": rejected_labels, + "chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids, + "chosen_attention_mask": chosen_attention_mask, "rejected_attention_mask": rejected_attention_mask} + + _, _, chosen_logits, rejected_logits, _ = trainer.concatenated_forward(model, batch) + + assert 4 == chosen_logits.shape[1] + assert 4 == rejected_logits.shape[1] + trainer.train() + + class DPOVisionTrainerTester(unittest.TestCase): @parameterized.expand( [ diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 7cbc2359959..9d21bc0c75f 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -129,9 +129,9 @@ class 
DPOConfig(TrainingArguments): weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends `rpo_alpha=1.0`. use_num_logits_to_keep (`bool`, *optional*, defaults to `False`): - If `True`, only a specified number of logits are computed in the forward pass. This can be useful for - saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios - when working with very long prompts where labels are -100 (ignored). + If `True`, only a specified number of logits are computed in the forward pass of CausalLM. This can be useful + for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios + when working with very long prompts where labels are ignored (-100). [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM) """ diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 50ef92d7617..78afcae29a1 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -57,6 +57,7 @@ peft_module_casting_to_bf16, trl_sanitze_kwargs_for_tagging, ) +from packaging.version import Version if is_peft_available(): @@ -892,6 +893,16 @@ def make_inputs_require_grad(module, input, output): ) self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + # num_logits_to_keep is supported since transformers v4.45.0 + if self.use_num_logits_to_keep: + import transformers + transformers_version = transformers.__version__ + if Version(transformers_version) < Version("4.45.0"): + raise ValueError( + f"num_logits_to_keep is only supported since transformers v4.45.0. Your current version is {transformers_version}." + ) + if self.loss_type == "bco_pair": self.running = RunningMoments(self.accelerator) @@ -1073,7 +1084,7 @@ def concatenated_inputs( label_pad_token_id: The label pad token id. 
padding_value: The padding value to use for the concatenated inputs_ids. device: The device for the concatenated inputs. - use_num_logits_to_keep: Whether to computes necessary logits in the forward pass. + use_num_logits_to_keep: Whether to compute necessary logits in the forward pass of CausalLM. Returns: A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. """ @@ -1085,12 +1096,6 @@ def concatenated_inputs( else: max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - # Support num_logits_to_keep, which computes necessary logits in the forward pass. - # This saves memory for long prompts where labels are -100 (ignored). - if use_num_logits_to_keep: - num_logits_to_keep = batch["chosen_labels"].shape[1] - (batch["chosen_labels"] != -100).nonzero(as_tuple=True)[1].min() - concatenated_batch["num_logits_to_keep"] = num_logits_to_keep.cpu() - for k in batch: if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): if "labels" in k or is_encoder_decoder: @@ -1135,8 +1140,16 @@ def concatenated_inputs( concatenated_batch["pixel_attention_mask"] = torch.cat( [batch["prompt_pixel_attention_mask"], batch["prompt_pixel_attention_mask"]], dim=0 ) - if use_num_logits_to_keep: + + # Support num_logits_to_keep, which computes necessary logits in the forward pass. + # This saves memory for long prompts where labels are -100 (label_pad_token_id). 
+ if use_num_logits_to_keep and not is_encoder_decoder: + concatenated_batch["num_logits_to_keep"] = 0 + min_compute_index = (concatenated_batch["concatenated_labels"] != label_pad_token_id).nonzero(as_tuple=True)[1].min() + num_logits_to_keep = concatenated_batch["concatenated_labels"].shape[1] - min_compute_index + concatenated_batch["num_logits_to_keep"] = num_logits_to_keep.cpu() concatenated_batch["concatenated_labels"] = concatenated_batch["concatenated_labels"][:, -concatenated_batch["num_logits_to_keep"]:] + return concatenated_batch def dpo_loss( @@ -1387,13 +1400,21 @@ def concatenated_forward( if self.aux_loss_enabled: model_kwargs["output_router_logits"] = True - outputs = model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - num_logits_to_keep=concatenated_batch["num_logits_to_keep"], # save memory - **model_kwargs, - ) + if self.use_num_logits_to_keep and not self.is_encoder_decoder: + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + num_logits_to_keep=concatenated_batch["num_logits_to_keep"], # save memory + **model_kwargs, + ) + else: + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) all_logits = outputs.logits if all_logits.shape[:2] != concatenated_batch["concatenated_labels"].shape[:2]: From 1b2186550b8d730cc3da03fba8c6a2d515a475a9 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Thu, 26 Sep 2024 10:59:22 +0800 Subject: [PATCH 04/19] update --- trl/trainer/dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 78afcae29a1..b3433350585 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -451,7 +451,7 @@ def __init__( 
model_adapter_name: Optional[str] = None, ref_adapter_name: Optional[str] = None, reference_free: bool = False, - force_use_ref_model: bool = False + force_use_ref_model: bool = False, ): if not isinstance(model, str) and ref_model is model: raise ValueError( From 03b1b673a16cd99f1cc5cd8f15b37cc3c7a397b1 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Thu, 26 Sep 2024 11:31:04 +0800 Subject: [PATCH 05/19] check if model supports num_logits_to_keep --- trl/trainer/dpo_trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index b3433350585..0ed6e69b25a 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1143,7 +1143,7 @@ def concatenated_inputs( # Support num_logits_to_keep, which computes necessary logits in the forward pass. # This saves memory for long prompts where labels are -100 (label_pad_token_id). - if use_num_logits_to_keep and not is_encoder_decoder: + if use_num_logits_to_keep: concatenated_batch["num_logits_to_keep"] = 0 min_compute_index = (concatenated_batch["concatenated_labels"] != label_pad_token_id).nonzero(as_tuple=True)[1].min() num_logits_to_keep = concatenated_batch["concatenated_labels"].shape[1] - min_compute_index concatenated_batch["num_logits_to_keep"] = num_logits_to_keep.cpu() concatenated_batch["concatenated_labels"] = concatenated_batch["concatenated_labels"][:, -concatenated_batch["num_logits_to_keep"]:] @@ -1375,6 +1375,12 @@ def concatenated_forward( We do this to avoid doing two forward passes, because it's faster for FSDP. """ + model_support_num_logits_to_keep = "num_logits_to_keep" in set(inspect.signature(model.forward).parameters.keys()) + if self.use_num_logits_to_keep and not model_support_num_logits_to_keep: + self.use_num_logits_to_keep = False + warnings.warn( + "The model does not support num_logits_to_keep. The parameter use_num_logits_to_keep will be ignored." 
+ ) concatenated_batch = self.concatenated_inputs( batch, is_encoder_decoder=self.is_encoder_decoder, @@ -1382,7 +1388,7 @@ def concatenated_forward( label_pad_token_id=self.label_pad_token_id, padding_value=self.padding_value, device=self.accelerator.device, - use_num_logits_to_keep=self.use_num_logits_to_keep + use_num_logits_to_keep=self.use_num_logits_to_keep, ) len_chosen = batch["chosen_labels"].shape[0] @@ -1400,7 +1406,7 @@ def concatenated_forward( if self.aux_loss_enabled: model_kwargs["output_router_logits"] = True - if self.use_num_logits_to_keep and not self.is_encoder_decoder: + if self.use_num_logits_to_keep: outputs = model( concatenated_batch["concatenated_input_ids"], attention_mask=concatenated_batch["concatenated_attention_mask"], From 6ca9cbe0ec4cafeb53bf10d84524f6858a3da5c6 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 27 Sep 2024 16:18:05 +0800 Subject: [PATCH 06/19] ruff format --- tests/test_dpo_trainer.py | 36 +++++++++++++++++++----------------- trl/trainer/dpo_trainer.py | 6 +++--- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index a5c0f774934..2f130362a30 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1042,7 +1042,6 @@ def test_dpo_loss_js_div_f(self): ) assert torch.isfinite(losses).cpu().numpy().all() - def test_dpo_trainer_use_num_logits_to_keep(self): model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM" tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -1077,22 +1076,25 @@ def test_dpo_trainer_use_num_logits_to_keep(self): ) # Fake batch - chosen_labels = torch.tensor([[-100, -100, -100, -100, -100, 4, 5, -100], - [-100, -100, -100, -100, 2, 4, 5, 6]]) - rejected_labels = torch.tensor([[-100, -100, -100, -100, 100, 7, 5, 909], - [-100, -100, -100, -100, 88, 4, 5, 6]]) - chosen_input_ids = torch.tensor([[700, 3, 5, 8, 9, 76, 4, 5, -100], - [700, 3, 5, 8, 9, 2, 4, 5, 6]]) - rejected_input_ids = torch.tensor([[700, 
3, 5, 8, 9, 100, 7, 5, 909], - [700, 3, 5, 8, 9, 88, 4, 5, 6]]) - chosen_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1]]) - rejected_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1]]) - - batch = {"chosen_labels": chosen_labels, "rejected_labels": rejected_labels, - "chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids, - "chosen_attention_mask": chosen_attention_mask, "rejected_attention_mask": rejected_attention_mask} + chosen_labels = torch.tensor( + [[-100, -100, -100, -100, -100, 4, 5, -100], [-100, -100, -100, -100, 2, 4, 5, 6]] + ) + rejected_labels = torch.tensor( + [[-100, -100, -100, -100, 100, 7, 5, 909], [-100, -100, -100, -100, 88, 4, 5, 6]] + ) + chosen_input_ids = torch.tensor([[700, 3, 5, 8, 9, 76, 4, 5, -100], [700, 3, 5, 8, 9, 2, 4, 5, 6]]) + rejected_input_ids = torch.tensor([[700, 3, 5, 8, 9, 100, 7, 5, 909], [700, 3, 5, 8, 9, 88, 4, 5, 6]]) + chosen_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]) + rejected_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]) + + batch = { + "chosen_labels": chosen_labels, + "rejected_labels": rejected_labels, + "chosen_input_ids": chosen_input_ids, + "rejected_input_ids": rejected_input_ids, + "chosen_attention_mask": chosen_attention_mask, + "rejected_attention_mask": rejected_attention_mask, + } _, _, chosen_logits, rejected_logits, _ = trainer.concatenated_forward(model, batch) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 0ed6e69b25a..baa9186e484 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -29,6 +29,7 @@ from accelerate.utils import is_deepspeed_available, tqdm from datasets import Dataset from huggingface_hub.utils._deprecation import _deprecate_arguments +from packaging.version import Version from torch.utils.data import DataLoader from transformers import ( 
AutoModelForCausalLM, @@ -57,7 +58,6 @@ peft_module_casting_to_bf16, trl_sanitze_kwargs_for_tagging, ) -from packaging.version import Version if is_peft_available(): @@ -893,7 +893,7 @@ def make_inputs_require_grad(module, input, output): ) self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) - + # num_logits_to_keep is supported since transformers v4.45.0 if self.use_num_logits_to_keep: import transformers @@ -902,7 +902,7 @@ def make_inputs_require_grad(module, input, output): raise ValueError( f"num_logits_to_keep is only supported since transformers v4.45.0. Your current version is {transformers_version}." ) - + if self.loss_type == "bco_pair": self.running = RunningMoments(self.accelerator) From f2ac776e01a15cbf422174f320d0e9738697d0ec Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Sun, 29 Sep 2024 10:27:37 +0800 Subject: [PATCH 07/19] update test file --- tests/test_dpo_trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 2f130362a30..17367077486 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1058,8 +1058,6 @@ def test_dpo_trainer_use_num_logits_to_keep(self): learning_rate=9e-1, eval_strategy="steps", beta=0.1, - generate_during_eval=True, - report_to="none", use_num_logits_to_keep=True, ) From cbe58bbbb5808d4b6e6962f232a4cc4ec5053db4 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Sun, 29 Sep 2024 17:35:19 +0800 Subject: [PATCH 08/19] peft model support --- trl/trainer/dpo_trainer.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 9ecf26c3ca9..f982dc47677 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1381,14 +1381,7 @@ def concatenated_forward( We do this to avoid doing two forward passes, because it's faster for FSDP. 
""" - model_support_num_logits_to_keep = "num_logits_to_keep" in set( - inspect.signature(model.forward).parameters.keys() - ) - if self.use_num_logits_to_keep and not model_support_num_logits_to_keep: - self.use_num_logits_to_keep = False - warnings.warn( - "The model does not support num_logits_to_keep. The parameter use_num_logits_to_keep will be ignored." - ) + concatenated_batch = self.concatenated_inputs( batch, is_encoder_decoder=self.is_encoder_decoder, From 725ccf06de9ef2203b328f7bcc2c0b3f54e7a76a Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Sun, 29 Sep 2024 19:52:30 +0800 Subject: [PATCH 09/19] test passed --- tests/test_dpo_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index c749b0ea5fd..5d5b99f432c 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1082,10 +1082,10 @@ def test_dpo_trainer_use_num_logits_to_keep(self): rejected_labels = torch.tensor( [[-100, -100, -100, -100, 100, 7, 5, 909], [-100, -100, -100, -100, 88, 4, 5, 6]] ) - chosen_input_ids = torch.tensor([[700, 3, 5, 8, 9, 76, 4, 5, -100], [700, 3, 5, 8, 9, 2, 4, 5, 6]]) - rejected_input_ids = torch.tensor([[700, 3, 5, 8, 9, 100, 7, 5, 909], [700, 3, 5, 8, 9, 88, 4, 5, 6]]) - chosen_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]) - rejected_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]) + chosen_input_ids = torch.randint_like(chosen_labels, high=1000) + rejected_input_ids = torch.randint_like(rejected_labels, high=1000) + chosen_attention_mask = torch.ones_like(chosen_input_ids) + rejected_attention_mask = torch.ones_like(rejected_input_ids) batch = { "chosen_labels": chosen_labels, From d1ada5068cf99224c3d146c86c0edbeeb3d8e097 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 11 Oct 2024 18:06:56 +0800 Subject: [PATCH 10/19] update --- tests/test_dpo_trainer.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index e9d7e2c6cd5..356ce87685c 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1136,7 +1136,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self): "rejected_attention_mask": rejected_attention_mask, } - _, _, chosen_logits, rejected_logits, _ = trainer.concatenated_forward(model, batch) + _, _, chosen_logits, rejected_logits, _, _ = trainer.concatenated_forward(model, batch) assert 4 == chosen_logits.shape[1] assert 4 == rejected_logits.shape[1] From 56241e064192df9822fd9b3e64ae67184225c7b8 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Wed, 23 Oct 2024 18:01:09 +0800 Subject: [PATCH 11/19] apply use_num_logits_to_keep --- trl/trainer/dpo_trainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index d205c98bf77..b58bd3363e4 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1270,6 +1270,11 @@ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, to attention_mask = attention_mask[:, : self.args.max_length] loss_mask = loss_mask[:, : self.args.max_length] + if self.use_num_logits_to_keep: + first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() + num_logits_to_keep = loss_mask.shape[1] - first_compute_index + model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs) # Offset the logits by one to align with the labels @@ -1277,6 +1282,10 @@ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, to labels = input_ids[:, 1:].clone() loss_mask = loss_mask[:, 1:].bool() + if self.use_num_logits_to_keep: + labels = labels[:, -num_logits_to_keep:] + loss_mask = loss_mask[:, -num_logits_to_keep:] + if logits.shape[:2] != labels.shape[:2]: # for llava, the returned logits include the image tokens (placed before the 
text tokens) seq_len = labels.shape[1] From dfb2fa9908f8813315a611d7afe5a81401e609c6 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 25 Oct 2024 11:37:35 +0800 Subject: [PATCH 12/19] fix num_logits_to_keep compute bug --- tests/test_dpo_trainer.py | 47 +++++++++++++++++++++++--------------- trl/trainer/dpo_trainer.py | 2 +- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 5e4a562e6ae..32516da71d9 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1073,31 +1073,40 @@ def test_dpo_trainer_use_num_logits_to_keep(self): eval_dataset=dummy_dataset["test"], ) - # Fake batch - chosen_labels = torch.tensor( - [[-100, -100, -100, -100, -100, 4, 5, -100], [-100, -100, -100, -100, 2, 4, 5, 6]] - ) - rejected_labels = torch.tensor( - [[-100, -100, -100, -100, 100, 7, 5, 909], [-100, -100, -100, -100, 88, 4, 5, 6]] + training_args.use_num_logits_to_keep = False + trainer2 = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + tokenizer=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], ) - chosen_input_ids = torch.randint_like(chosen_labels, high=1000) - rejected_input_ids = torch.randint_like(rejected_labels, high=1000) + + # Fake batch + prompt_input_ids = torch.randint(1, 1000, (2, 10)) + chosen_input_ids = torch.randint(1, 1000, (2, 5)) + rejected_input_ids = torch.randint(1, 1000, (2, 7)) + prompt_attention_mask = torch.ones_like(prompt_input_ids) chosen_attention_mask = torch.ones_like(chosen_input_ids) rejected_attention_mask = torch.ones_like(rejected_input_ids) batch = { - "chosen_labels": chosen_labels, - "rejected_labels": rejected_labels, - "chosen_input_ids": chosen_input_ids, - "rejected_input_ids": rejected_input_ids, - "chosen_attention_mask": chosen_attention_mask, - "rejected_attention_mask": rejected_attention_mask, + "prompt_input_ids": prompt_input_ids.to(model.device), + "chosen_input_ids": 
chosen_input_ids.to(model.device), + "rejected_input_ids": rejected_input_ids.to(model.device), + "prompt_attention_mask": prompt_attention_mask.to(model.device), + "chosen_attention_mask": chosen_attention_mask.to(model.device), + "rejected_attention_mask": rejected_attention_mask.to(model.device), } - _, _, chosen_logits, rejected_logits, _, _ = trainer.concatenated_forward(model, batch) + output = trainer.concatenated_forward(model, batch) + output2 = trainer2.concatenated_forward(model, batch) + + for i in range(output["chosen_logps"].shape[0]): + assert output["chosen_logps"][i].item() == pytest.approx(output2["chosen_logps"][i].item()) + assert output["rejected_logps"][i].item() == pytest.approx(output2["rejected_logps"][i].item()) - assert 4 == chosen_logits.shape[1] - assert 4 == rejected_logits.shape[1] trainer.train() @@ -1190,4 +1199,6 @@ def test_vdpo_trainer(self, model_id): if __name__ == "__main__": - unittest.main() + # unittest.main() + dpt = DPOTrainerTester() + dpt.test_dpo_trainer_use_num_logits_to_keep() diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index b58bd3363e4..b9080581a60 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1273,7 +1273,7 @@ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, to if self.use_num_logits_to_keep: first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() num_logits_to_keep = loss_mask.shape[1] - first_compute_index - model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs) From 9e9e2ef4914454647500e4c5427b24a7881146cd Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 25 Oct 2024 11:51:13 +0800 Subject: [PATCH 13/19] compare all outputs --- tests/test_dpo_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_dpo_trainer.py 
b/tests/test_dpo_trainer.py index 32516da71d9..6e17a9e394b 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1059,6 +1059,8 @@ def test_dpo_trainer_use_num_logits_to_keep(self): eval_strategy="steps", beta=0.1, use_num_logits_to_keep=True, + rpo_alpha=0.5, + report_to="none", ) dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") @@ -1103,6 +1105,9 @@ def test_dpo_trainer_use_num_logits_to_keep(self): output = trainer.concatenated_forward(model, batch) output2 = trainer2.concatenated_forward(model, batch) + assert output["nll_loss"].item() == pytest.approx(output2["nll_loss"].item()) + assert output["mean_chosen_logits"].item() == pytest.approx(output2["mean_chosen_logits"].item()) + for i in range(output["chosen_logps"].shape[0]): assert output["chosen_logps"][i].item() == pytest.approx(output2["chosen_logps"][i].item()) assert output["rejected_logps"][i].item() == pytest.approx(output2["rejected_logps"][i].item()) From 858d4043f1d67b8a50d31de9cdb94cb331e755d1 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 25 Oct 2024 11:53:10 +0800 Subject: [PATCH 14/19] pytest --- tests/test_dpo_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 6e17a9e394b..4894e693b7b 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1204,6 +1204,4 @@ def test_vdpo_trainer(self, model_id): if __name__ == "__main__": - # unittest.main() - dpt = DPOTrainerTester() - dpt.test_dpo_trainer_use_num_logits_to_keep() + unittest.main() From dc7849a088b6dfde8df5f6acb4228ba09edb131c Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 25 Oct 2024 13:25:14 +0800 Subject: [PATCH 15/19] pass test --- tests/test_dpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 4894e693b7b..344adc29121 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1107,6 
+1107,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self): assert output["nll_loss"].item() == pytest.approx(output2["nll_loss"].item()) assert output["mean_chosen_logits"].item() == pytest.approx(output2["mean_chosen_logits"].item()) + assert output["mean_rejected_logits"].item() == pytest.approx(output2["mean_rejected_logits"].item()) for i in range(output["chosen_logps"].shape[0]): assert output["chosen_logps"][i].item() == pytest.approx(output2["chosen_logps"][i].item()) From 70838a7e94643bb828eb7434fe8a48deb0998859 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Wed, 6 Nov 2024 16:13:26 +0800 Subject: [PATCH 16/19] use check_min_version --- trl/trainer/dpo_trainer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 907335f3896..96affb8d012 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -31,7 +31,6 @@ from accelerate import PartialState from accelerate.utils import is_deepspeed_available, tqdm from datasets import Dataset -from packaging.version import Version from torch.utils.data import DataLoader from transformers import ( AutoModelForCausalLM, @@ -534,13 +533,9 @@ def make_inputs_require_grad(module, input, output): # num_logits_to_keep is supported since transformers v4.45.0 if self.use_num_logits_to_keep: - import transformers + from transformers.utils import check_min_version - transformers_version = transformers.__version__ - if Version(transformers_version) < Version("4.45.0"): - raise ValueError( - f"num_logits_to_keep is only supported since transformers v4.45.0. Your current version is {transformers_version}." 
- ) + check_min_version("4.45.0") if self.loss_type == "bco_pair": self.running = RunningMoments(self.accelerator) From c4b5c4fad33e0bb7d2ea05e68bc51d9479d2e969 Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Wed, 6 Nov 2024 16:15:36 +0800 Subject: [PATCH 17/19] format --- trl/trainer/dpo_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 96affb8d012..d267eb4b58e 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -47,7 +47,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput -from transformers.utils import is_peft_available +from transformers.utils import check_min_version, is_peft_available from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt @@ -533,8 +533,6 @@ def make_inputs_require_grad(module, input, output): # num_logits_to_keep is supported since transformers v4.45.0 if self.use_num_logits_to_keep: - from transformers.utils import check_min_version - check_min_version("4.45.0") if self.loss_type == "bco_pair": From 32ed085dabd96f0e8331d488e563f8e2b6ea2fcb Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Wed, 6 Nov 2024 17:00:03 +0800 Subject: [PATCH 18/19] test_dpo_trainer_use_num_logits_to_keep passed --- tests/test_dpo_trainer.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 8e7e197040a..084138ea34e 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1116,13 +1116,21 @@ def test_dpo_trainer_use_num_logits_to_keep(self): output = trainer.concatenated_forward(model, batch) output2 = trainer2.concatenated_forward(model, batch) - assert output["nll_loss"].item() == 
pytest.approx(output2["nll_loss"].item()) - assert output["mean_chosen_logits"].item() == pytest.approx(output2["mean_chosen_logits"].item()) - assert output["mean_rejected_logits"].item() == pytest.approx(output2["mean_rejected_logits"].item()) + np.testing.assert_allclose(output["nll_loss"].item(), output2["nll_loss"].item(), atol=1e-5) + np.testing.assert_allclose( + output["mean_chosen_logits"].item(), output2["mean_chosen_logits"].item(), atol=1e-5 + ) + np.testing.assert_allclose( + output["mean_rejected_logits"].item(), output2["mean_rejected_logits"].item(), atol=1e-5 + ) for i in range(output["chosen_logps"].shape[0]): - assert output["chosen_logps"][i].item() == pytest.approx(output2["chosen_logps"][i].item()) - assert output["rejected_logps"][i].item() == pytest.approx(output2["rejected_logps"][i].item()) + np.testing.assert_allclose( + output["chosen_logps"][i].item(), output2["chosen_logps"][i].item(), atol=1e-5 + ) + np.testing.assert_allclose( + output["rejected_logps"][i].item(), output2["rejected_logps"][i].item(), atol=1e-5 + ) trainer.train() From f3c044ee11a9eb720f53dc5cc43f23badffceddd Mon Sep 17 00:00:00 2001 From: xiaoyang Date: Fri, 8 Nov 2024 11:32:22 +0800 Subject: [PATCH 19/19] add some comments --- trl/trainer/dpo_trainer.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 954af9b3c28..38acd85994d 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1093,9 +1093,9 @@ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, to # Get the first column idx that is all zeros and remove every column after that empty_cols = torch.sum(attention_mask, dim=0) == 0 first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) - input_ids = input_ids[:, : first_empty_col] - attention_mask = attention_mask[:, : first_empty_col] - loss_mask = loss_mask[:, : 
first_empty_col] + input_ids = input_ids[:, :first_empty_col] + attention_mask = attention_mask[:, :first_empty_col] + loss_mask = loss_mask[:, :first_empty_col] # Truncate right if self.args.max_length is not None: @@ -1104,6 +1104,10 @@ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, to loss_mask = loss_mask[:, : self.args.max_length] if self.use_num_logits_to_keep: + # Compute num_logits_to_keep based on loss_mask pattern: + # [[0, 0, 0, x, x, x, x], + # [0, 0, 0, x, x, x, 0]] + # ^ start computing logits from here ([:, -(7-3+1):]) first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() num_logits_to_keep = loss_mask.shape[1] - first_compute_index model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label @@ -1116,6 +1120,12 @@ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, to loss_mask = loss_mask[:, 1:].bool() if self.use_num_logits_to_keep: + # Align labels with logits + # logits: -, -, [x2, x3, x4, x5, x6] + # ^ --------- ^ after logits[:, :-1, :] + # labels: [y0, y1, y2, y3, y4, y5, y6] + # ^ --------- ^ with num_logits_to_keep=4, [:, -4:] + # loss_mask: [0, 0, 0, 1, 1, 1, 1] labels = labels[:, -num_logits_to_keep:] loss_mask = loss_mask[:, -num_logits_to_keep:]