diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py
index 688d73e23ffd..fa308559b7be 100755
--- a/python/mxnet/contrib/amp/amp.py
+++ b/python/mxnet/contrib/amp/amp.py
@@ -23,7 +23,6 @@
            'list_widest_type_cast', 'list_loss_output_functions', 'list_lp16_use_fp32_params',
            'convert_symbol']
 
-from types import MethodType
 from array import array
 import ctypes
 import logging
@@ -341,21 +340,6 @@ def init_trainer(optimizer_or_trainer):
     if isinstance(optimizer_or_trainer, trainer.Trainer):
         optimizer_or_trainer._amp_loss_scaler = loss_scaler
         optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
-        skip_update = optimizer_or_trainer._amp_loss_scaler.wait_and_update
-        optimizer_or_trainer._optimizer.old_update_multi_precision = \
-            optimizer_or_trainer._optimizer.update_multi_precision
-        def new_update_multi_precision(self, index, weight, grad, state):
-            if not skip_update():
-                self.old_update_multi_precision(index, weight, grad, state)
-        optimizer_or_trainer._optimizer.update_multi_precision = \
-            MethodType(new_update_multi_precision, optimizer_or_trainer._optimizer)
-        launch_check_overflow = optimizer_or_trainer._amp_loss_scaler.launch_check_overflow
-        optimizer_or_trainer._old_update = optimizer_or_trainer._update
-        def new_update(self, ignore_stale_grad=False):
-            launch_check_overflow(self._params)
-            self._old_update(ignore_stale_grad)
-        optimizer_or_trainer._update = MethodType(new_update, optimizer_or_trainer)
-
     elif isinstance(optimizer_or_trainer, opt.Optimizer):
         # TODO(ptredak): make it work with the optimizer
         raise TypeError("AMP is currently only compatible with Gluon Trainer")
diff --git a/python/mxnet/contrib/amp/loss_scaler.py b/python/mxnet/contrib/amp/loss_scaler.py
index a2600bcc2a49..3a177cebe67e 100755
--- a/python/mxnet/contrib/amp/loss_scaler.py
+++ b/python/mxnet/contrib/amp/loss_scaler.py
@@ -37,16 +37,13 @@ def __init__(self):
         self._max_loss_scale = 2.**24
         self._scale_seq_len = 2000
         self._unskipped = 0
-        self._has_overflow = False
 
     @property
     def loss_scale(self):
         return self._loss_scale
 
-    def launch_check_overflow(self, params):
-        """Launch overflow checking for gradients."""
-        self._wait_for_outputs = True
-        self._has_overflow = False
+    def has_overflow(self, params):
+        """Check gradients for overflow."""
         with ag.pause():
             chunk_size = 200
             valid_params = [p._grad[0] for p in params if p._grad is not None]
@@ -56,22 +53,16 @@ def launch_check_overflow(self, params):
                 multi_all_finite(*valid_params[idx:idx+chunk_size],
                                  num_arrays=len(valid_params[idx:idx+chunk_size]),
                                  init_output=False, out=gpu_output)
-            self.output = gpu_output
-
-    def wait_and_update(self):
-        """Wait for the results of overflow checking and update the loss scale."""
-        if self._wait_for_outputs:
-            self._has_overflow = not bool(self.output.asnumpy())
-            self._loss_scale = self._next_loss_scale
-            if self._has_overflow:
-                self._next_loss_scale = self._loss_scale / 2.
-                self._unskipped = 0
-                logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
-            else:
-                self._unskipped += 1
-            if self._unskipped == self._scale_seq_len:
-                self._unskipped = 0
-                self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
-                logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
-            self._wait_for_outputs = False
-        return self._has_overflow
+            has_overflow = not bool(gpu_output.asnumpy())
+        self._loss_scale = self._next_loss_scale
+        if has_overflow:
+            self._next_loss_scale = self._loss_scale / 2.
+            self._unskipped = 0
+            logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
+        else:
+            self._unskipped += 1
+        if self._unskipped == self._scale_seq_len:
+            self._unskipped = 0
+            self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
+            logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
+        return has_overflow
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index fed3c440ac21..dd8551d0c37c 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -428,6 +428,11 @@ def update(self, batch_size, ignore_stale_grad=False):
         self._update(ignore_stale_grad)
 
     def _update(self, ignore_stale_grad=False):
+        loss_scaler = getattr(self, '_amp_loss_scaler', None)
+        if loss_scaler is not None:
+            if loss_scaler.has_overflow(self._params):
+                return  # skip on overflow
+
         updates = [[] for _ in self._updaters]
 
         for i, param in enumerate(self._params):
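For context: this change drops the deferred launch_check_overflow / wait_and_update pair, which had to monkey-patch the optimizer and the trainer through MethodType, and instead has Trainer._update call a single synchronous LossScaler.has_overflow and return early when any gradient is non-finite. The snippet below is a minimal, framework-free sketch of the same dynamic loss-scaling policy, not the MXNet implementation itself: the names DynamicLossScaler and grads_are_finite and the initial scale of 2.**16 are illustrative only, and the finiteness test uses math.isfinite on plain Python lists rather than MXNet's multi_all_finite kernel on NDArray gradients.

import logging
import math


def grads_are_finite(grads):
    """Hypothetical stand-in for MXNet's multi_all_finite kernel:
    True only if every gradient value is finite."""
    return all(math.isfinite(g) for grad in grads for g in grad)


class DynamicLossScaler:
    """Sketch of the policy in loss_scaler.py: halve the scale (and skip
    the update) on overflow, double it after scale_seq_len clean steps,
    capped at max_loss_scale."""

    def __init__(self, init_scale=2.**16, scale_seq_len=2000, max_loss_scale=2.**24):
        self.loss_scale = init_scale          # initial value is illustrative
        self._next_loss_scale = init_scale
        self._scale_seq_len = scale_seq_len
        self._max_loss_scale = max_loss_scale
        self._unskipped = 0

    def has_overflow(self, grads):
        has_overflow = not grads_are_finite(grads)
        self.loss_scale = self._next_loss_scale
        if has_overflow:
            self._next_loss_scale = self.loss_scale / 2.
            self._unskipped = 0
            logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
        else:
            self._unskipped += 1
        if self._unskipped == self._scale_seq_len:
            self._unskipped = 0
            self._next_loss_scale = min(self._max_loss_scale, self.loss_scale * 2.)
            logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
        return has_overflow


# Usage mirroring the new hook in Trainer._update: check the gradients
# once per step and skip the parameter updates for that batch on overflow.
scaler = DynamicLossScaler()
grads = [[0.1, -0.2], [float('inf')]]   # pretend these came from backward()
if not scaler.has_overflow(grads):
    pass  # apply the optimizer updates here

One apparent trade-off of the change: the blocking gpu_output.asnumpy() now runs up front in _update instead of being deferred into the patched update_multi_precision, which simplifies the control flow at the cost of synchronizing on the overflow check before any updates are issued.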