From 015eafb83537233bb5b4b08a6fc9a290bf4ff86d Mon Sep 17 00:00:00 2001
From: "Liu, Kaixuan"
Date: Mon, 13 Apr 2026 14:14:57 +0000
Subject: [PATCH] avoid wrapping 4bit-quantized model into DP

Signed-off-by: Liu, Kaixuan
---
 src/transformers/trainer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 235189fe8320..b9fc93376cbb 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2414,8 +2414,12 @@ def _wrap_model(self, model: nn.Module, training: bool = True, dataloader: DataL
                 return model
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
 
-        # Multi-gpu training, 8bit models does not support DP
-        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
+        # Multi-gpu training, quantized models do not support DP
+        if (
+            self.args.n_gpu > 1
+            and not getattr(model, "is_loaded_in_8bit", False)
+            and not getattr(model, "is_loaded_in_4bit", False)
+        ):
             model = nn.DataParallel(model)
 
         # Note: in torch.distributed mode, there's no point in wrapping the model
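
Note (not part of the patch): the sketch below shows how a model ends up carrying the
`is_loaded_in_4bit` attribute that the new condition inspects, so that Trainer._wrap_model
now leaves such a model unwrapped instead of putting it in nn.DataParallel, mirroring the
existing behaviour for 8-bit models. It assumes a CUDA machine with `bitsandbytes`
installed; the checkpoint name is only an illustrative example.

    # Illustrative sketch, not part of the patch.
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",            # example checkpoint, any causal LM works
        quantization_config=quant_config,
        device_map="auto",
    )
    # The flag checked by the patched _wrap_model condition:
    print(getattr(model, "is_loaded_in_4bit", False))  # True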