diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1b8dacb632cc..8665b27acfa9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2415,8 +2415,12 @@ def _wrap_model(self, model: nn.Module, training: bool = True, dataloader: DataL
                 return model
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
 
-        # Multi-gpu training, 8bit models does not support DP
-        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
+        # Multi-gpu training, quantized models do not support DP
+        if (
+            self.args.n_gpu > 1
+            and not getattr(model, "is_loaded_in_8bit", False)
+            and not getattr(model, "is_loaded_in_4bit", False)
+        ):
             model = nn.DataParallel(model)
 
         # Note: in torch.distributed mode, there's no point in wrapping the model
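
For context, here is a minimal standalone sketch of the guard this hunk introduces: `nn.DataParallel` is only applied when more than one GPU is available and the model carries neither the `is_loaded_in_8bit` nor the `is_loaded_in_4bit` flag that bitsandbytes-quantized checkpoints set. The helper name and the toy models below are hypothetical and for illustration only, not part of the patch.

```python
import torch.nn as nn


def maybe_wrap_data_parallel(model: nn.Module, n_gpu: int) -> nn.Module:
    """Hypothetical helper mirroring the guard in this hunk:
    skip nn.DataParallel for 8-bit / 4-bit quantized models, which do not support DP."""
    if (
        n_gpu > 1
        and not getattr(model, "is_loaded_in_8bit", False)
        and not getattr(model, "is_loaded_in_4bit", False)
    ):
        model = nn.DataParallel(model)
    return model


if __name__ == "__main__":
    # A plain model gets wrapped when several GPUs are requested...
    plain = nn.Linear(4, 4)
    print(type(maybe_wrap_data_parallel(plain, n_gpu=2)).__name__)  # DataParallel

    # ...while a model flagged as 4-bit quantized is left untouched.
    quantized = nn.Linear(4, 4)
    quantized.is_loaded_in_4bit = True
    print(type(maybe_wrap_data_parallel(quantized, n_gpu=2)).__name__)  # Linear
```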