diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index af908e48e4b8..8b9e14aa6fe0 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3649,7 +3649,7 @@ def training_step( return loss_mb.reduce_mean().detach().to(self.args.device) with self.compute_loss_context_manager(): - if self.model_accepts_loss_kwargs: + if not self.model_accepts_loss_kwargs: loss = self.compute_loss(model, inputs) else: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)