python/mxnet/contrib/amp/amp.py (16 changes: 0 additions & 16 deletions)
@@ -23,7 +23,6 @@
            'list_widest_type_cast', 'list_loss_output_functions', 'list_lp16_use_fp32_params',
            'convert_symbol']
 
-from types import MethodType
 from array import array
 import ctypes
 import logging
@@ -341,21 +340,6 @@ def init_trainer(optimizer_or_trainer):
     if isinstance(optimizer_or_trainer, trainer.Trainer):
         optimizer_or_trainer._amp_loss_scaler = loss_scaler
         optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
-        skip_update = optimizer_or_trainer._amp_loss_scaler.wait_and_update
-        optimizer_or_trainer._optimizer.old_update_multi_precision = \
-            optimizer_or_trainer._optimizer.update_multi_precision
-        def new_update_multi_precision(self, index, weight, grad, state):
-            if not skip_update():
-                self.old_update_multi_precision(index, weight, grad, state)
-        optimizer_or_trainer._optimizer.update_multi_precision = \
-            MethodType(new_update_multi_precision, optimizer_or_trainer._optimizer)
-        launch_check_overflow = optimizer_or_trainer._amp_loss_scaler.launch_check_overflow
-        optimizer_or_trainer._old_update = optimizer_or_trainer._update
-        def new_update(self, ignore_stale_grad=False):
-            launch_check_overflow(self._params)
-            self._old_update(ignore_stale_grad)
-        optimizer_or_trainer._update = MethodType(new_update, optimizer_or_trainer)
-
     elif isinstance(optimizer_or_trainer, opt.Optimizer):
         # TODO(ptredak): make it work with the optimizer
         raise TypeError("AMP is currently only compatible with Gluon Trainer")
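With the `MethodType` patching removed, `init_trainer` now only attaches the loss scaler to the trainer; the overflow check itself moves into `Trainer._update` (see the gluon/trainer.py hunk below). For context, here is a minimal usage sketch of the entry points this file exposes, assuming the standard contrib AMP API (`amp.init`, `amp.init_trainer`, `amp.scale_loss`) and a GPU context; the toy network and data are illustrative, not taken from this PR.

```python
import mxnet as mx
from mxnet import autograd, gluon
from mxnet.contrib import amp

amp.init()                      # patch operators for mixed precision before building the net
ctx = mx.gpu(0)                 # AMP targets float16 math on GPU

net = gluon.nn.Dense(10)
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})
amp.init_trainer(trainer)       # attaches the dynamic LossScaler to the trainer

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
data = mx.nd.random.uniform(shape=(4, 8), ctx=ctx)
label = mx.nd.array([0, 1, 2, 3], ctx=ctx)

with autograd.record():
    out = net(data)
    loss = loss_fn(out, label)
    with amp.scale_loss(loss, trainer) as scaled_loss:
        autograd.backward(scaled_loss)
trainer.step(4)                 # update() -> _update(), which now consults the scaler
```

The user-facing flow is unchanged by this diff; only what happens inside `trainer.step` differs.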
python/mxnet/contrib/amp/loss_scaler.py (39 changes: 15 additions & 24 deletions)
@@ -37,16 +37,13 @@ def __init__(self):
         self._max_loss_scale = 2.**24
         self._scale_seq_len = 2000
         self._unskipped = 0
-        self._has_overflow = False
 
     @property
     def loss_scale(self):
         return self._loss_scale
 
-    def launch_check_overflow(self, params):
-        """Launch overflow checking for gradients."""
-        self._wait_for_outputs = True
-        self._has_overflow = False
+    def has_overflow(self, params):
+        """Check gradients for overflow."""
         with ag.pause():
             chunk_size = 200
             valid_params = [p._grad[0] for p in params if p._grad is not None]
@@ -56,22 +53,16 @@ def launch_check_overflow(self, params):
                 multi_all_finite(*valid_params[idx:idx+chunk_size],
                                  num_arrays=len(valid_params[idx:idx+chunk_size]),
                                  init_output=False, out=gpu_output)
-            self.output = gpu_output
-
-    def wait_and_update(self):
-        """Wait for the results of overflow checking and update the loss scale."""
-        if self._wait_for_outputs:
-            self._has_overflow = not bool(self.output.asnumpy())
-            self._loss_scale = self._next_loss_scale
-            if self._has_overflow:
-                self._next_loss_scale = self._loss_scale / 2.
-                self._unskipped = 0
-                logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
-            else:
-                self._unskipped += 1
-                if self._unskipped == self._scale_seq_len:
-                    self._unskipped = 0
-                    self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
-                    logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
-            self._wait_for_outputs = False
-        return self._has_overflow
+        has_overflow = not bool(gpu_output.asnumpy())
+        self._loss_scale = self._next_loss_scale
+        if has_overflow:
+            self._next_loss_scale = self._loss_scale / 2.
+            self._unskipped = 0
+            logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
+        else:
+            self._unskipped += 1
+            if self._unskipped == self._scale_seq_len:
+                self._unskipped = 0
+                self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
+                logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
+        return has_overflow
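The new `has_overflow` folds the old `launch_check_overflow` / `wait_and_update` pair into one synchronous call, but the scaling policy is unchanged: halve the scale after an overflow, double it after `_scale_seq_len` clean iterations, and cap it at `_max_loss_scale`. Below is a framework-free sketch of that policy; the class name and defaults are illustrative and not part of `loss_scaler.py`.

```python
class ToyLossScaler:
    """Dynamic loss-scaling policy only, detached from NDArray and MXNet."""

    def __init__(self, init_scale=2.**16, max_scale=2.**24, seq_len=2000):
        self.scale = init_scale         # scale applied to the current step
        self._next_scale = init_scale
        self._max_scale = max_scale
        self._seq_len = seq_len         # clean steps required before growing
        self._unskipped = 0

    def update(self, overflow):
        """Advance the scale; return True if this step's update should be skipped."""
        self.scale = self._next_scale
        if overflow:
            self._next_scale = self.scale / 2.
            self._unskipped = 0
        else:
            self._unskipped += 1
            if self._unskipped == self._seq_len:
                self._unskipped = 0
                self._next_scale = min(self._max_scale, self.scale * 2.)
        return overflow


scaler = ToyLossScaler(seq_len=3)
for overflow in [False, False, False, True, False]:
    skipped = scaler.update(overflow)
    print(scaler.scale, skipped)   # doubles after 3 clean steps, halves after the overflow
```

Keeping the policy this small is what lets the real class consist of little more than the chunked `multi_all_finite` reduction over the gradients plus the bookkeeping above.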
python/mxnet/gluon/trainer.py (5 changes: 5 additions & 0 deletions)
@@ -428,6 +428,11 @@ def update(self, batch_size, ignore_stale_grad=False):
         self._update(ignore_stale_grad)
 
     def _update(self, ignore_stale_grad=False):
+        loss_scaler = getattr(self, '_amp_loss_scaler', None)
+        if loss_scaler is not None:
+            if loss_scaler.has_overflow(self._params):
+                return  # skip on overflow
+
         updates = [[] for _ in self._updaters]
 
         for i, param in enumerate(self._params):
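Because the hook reads `_amp_loss_scaler` via `getattr` with a `None` default, trainers that never went through `amp.init_trainer` take the old code path untouched. A small sketch of that no-AMP case, runnable on CPU; the toy network, data, and loss are illustrative only.

```python
import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(2)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

x = mx.nd.random.uniform(shape=(4, 3))
y = mx.nd.random.uniform(shape=(4, 2))
with autograd.record():
    loss = ((net(x) - y) ** 2).mean()
loss.backward()

trainer.step(4)   # _update() runs as before: no _amp_loss_scaler attribute is set
print(getattr(trainer, '_amp_loss_scaler', None))   # -> None
```

With AMP enabled, the same `trainer.step` call silently skips the parameter update for any batch whose gradients are non-finite, which is the behaviour the deleted `MethodType` patches used to bolt on from the outside.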