From 72455f6bb018805dcd14a86bc18d61f6ed8d856c Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 18 Mar 2022 16:19:35 -0400 Subject: [PATCH] fix scaler check for non fp16 mode in trainer --- src/transformers/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 23b97e780ee0..f6449d71cde8 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1997,7 +1997,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler) return loss_mb.reduce_mean().detach().to(self.args.device) - with self.autocast_smart_context_manager(enabled=self.scaler.is_enabled()): + with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()): loss = self.compute_loss(model, inputs) if self.args.n_gpu > 1: @@ -2654,7 +2654,7 @@ def prediction_step( logits = smp_nested_concat(logits_mb) else: if has_labels: - with self.autocast_smart_context_manager(enabled=self.scaler.is_enabled()): + with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()): loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss = loss.mean().detach() @@ -2664,7 +2664,7 @@ def prediction_step( logits = outputs[1:] else: loss = None - with self.autocast_smart_context_manager(enabled=self.scaler.is_enabled()): + with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()): outputs = model(**inputs) if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)