From a5491eef8331800729e2973039140fa24e092d2b Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Mon, 30 Aug 2021 16:41:47 -0400
Subject: [PATCH 1/3] Override LRScheduler when using LRModifiers

---
 src/transformers/sparse.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index 941f777827b7..2c5006ae73e2 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -102,6 +102,24 @@ def create_optimizer(self):
             self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
         )

+    def create_scheduler(self, num_training_steps: int):
+        """
+        Override LR scheduler if the SparseML manager has LR modifiers, otherwise
+        set default scheduler
+        """
+        if self.lr_scheduler is not None:
+            # scheduler already set
+            return
+
+        if self.manager.learning_rate_modifiers:
+            # allow SparseML to manage LR and set a dummy scheduler
+            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer, lambda _: 1.0, -1
+            )
+        else:
+            # default scheduler
+            super().create_scheduler(num_training_steps)
+
     def save_model(self, output_dir: Optional[str] = None):
         """
         Save model during or after training. The sparsification recipe will also be saved.

From a365a0789efb114e8a80ecfd741efdfc6f99f470 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Mon, 30 Aug 2021 17:09:56 -0400
Subject: [PATCH 2/3] Disable FP16 on QAT start

---
 src/transformers/sparse.py  | 8 ++++++++
 src/transformers/trainer.py | 4 ++++
 2 files changed, 12 insertions(+)

diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index 2c5006ae73e2..60f120e44cbb 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -120,6 +120,14 @@ def create_scheduler(self, num_training_steps: int):
             # default scheduler
             super().create_scheduler(num_training_steps)

+    def qat_active(self, epoch: int):
+        if not self.manager.quantization_modifiers:
+            return False
+
+        qat_start = min([mod.start_epoch for mod in self.manager.quantization_modifiers])
+
+        return qat_start < epoch + 1
+
     def save_model(self, output_dir: Optional[str] = None):
         """
         Save model during or after training. The sparsification recipe will also be saved.
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index fd1a03930734..be384a0927c6 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1215,6 +1215,10 @@ def train(
                     break

         for epoch in range(epochs_trained, num_train_epochs):
+            if self.use_amp and hasattr(self, "qat_active") and callable(self.qat_active) and self.qat_active(epoch):
+                logger.info("entering QAT phase, disabling FP16 training")
+                self.use_amp = False
+
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
             elif isinstance(train_dataloader.dataset, IterableDatasetShard):

From b6dde37a7ab19ef297915d0b5016e561ff794784 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Mon, 30 Aug 2021 17:33:45 -0400
Subject: [PATCH 3/3] keep wrapped scaler object for training after disabling

---
 src/transformers/trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index be384a0927c6..b299a96744b8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1217,7 +1217,7 @@ def train(
         for epoch in range(epochs_trained, num_train_epochs):
             if self.use_amp and hasattr(self, "qat_active") and callable(self.qat_active) and self.qat_active(epoch):
                 logger.info("entering QAT phase, disabling FP16 training")
-                self.use_amp = False
+                self.scaler._enabled = False

             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
@@ -1736,7 +1736,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             return loss_mb.reduce_mean().detach().to(self.args.device)

         if self.use_amp:
-            with autocast():
+            with autocast(enabled=self.scaler.is_enabled()):
                 loss = self.compute_loss(model, inputs)
         else:
             loss = self.compute_loss(model, inputs)
@@ -2381,7 +2381,7 @@ def prediction_step(
             else:
                 loss = None
             if self.use_amp:
-                with autocast():
+                with autocast(enabled=self.scaler.is_enabled()):
                     outputs = model(**inputs)
             else:
                 outputs = model(**inputs)
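
For readers skimming the series: PATCH 1/3 hinges on installing a scheduler that never changes the learning rate, so the trainer's scheduler plumbing keeps working while the SparseML recipe drives the LR itself. Below is a minimal standalone sketch of that dummy-scheduler trick, assuming nothing beyond stock PyTorch; the toy model and optimizer are placeholders for illustration, not code from the patch.

import torch

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The lambda always returns 1.0, so the scheduler never alters the optimizer's LR;
# the positional -1 is last_epoch, matching the call in PATCH 1/3.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0, -1)

for step in range(3):
    optimizer.step()
    lr_scheduler.step()
    # LR stays at the base value no matter how many times step() is called
    print(lr_scheduler.get_last_lr())  # [0.1]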
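
Similarly, PATCH 2/3 and PATCH 3/3 rely on switching FP16 off once quantization-aware training starts while keeping the same GradScaler object in place, so later scaler.scale()/step()/update() calls degrade to plain FP32 updates. The sketch below shows that pattern in isolation, assuming only stock torch.cuda.amp; the model, data, and qat_start_epoch threshold are made-up placeholders, and flipping the private _enabled attribute mirrors what the patch does rather than a public API.

import torch
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# FP16 is only enabled when a GPU is available
scaler = GradScaler(enabled=torch.cuda.is_available())

qat_start_epoch = 2  # hypothetical epoch at which the recipe starts QAT

for epoch in range(4):
    if scaler.is_enabled() and epoch + 1 > qat_start_epoch:
        # same idea as PATCH 3/3: keep the wrapped scaler object but disable it,
        # so the scale()/step()/update() calls below become plain FP32 updates
        scaler._enabled = False

    inputs = torch.randn(8, 16, device=device)
    targets = torch.randn(8, 4, device=device)

    # same idea as the trainer changes: autocast follows the scaler's state
    with autocast(enabled=scaler.is_enabled()):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()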