This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Cannot load trainer with AMP #16858
Closed
Description
test.py:
import mxnet as mx
import os
import logging
from fp16_utils import LAMB2  # imported for its @register side effect, which registers 'lamb2'

net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize(mx.init.Uniform(0.1), ctx=mx.gpu())

from mxnet.contrib import amp
amp.init()

trainer = mx.gluon.Trainer(net.collect_params(), 'lamb2')
trainer.load_states('/fsx/tf-amp-fp32-grad-248-2-14-restart/0002999.states.00')  # fails here
amp.init_trainer(trainer)

print(trainer)
print(trainer._optimizer)
print(trainer._optimizer.new_update_multi_precision)
fp16_utils.py:
from mxnet.optimizer import Optimizer, register
from mxnet.ndarray import zeros, ones_like, NDArray
from mxnet.ndarray import square, power, sqrt, maximum, minimum, clip, where
from mxnet.engine import bulk
import os
import logging
import warnings  # warnings.warn is used in update() below

@register
class LAMB2(Optimizer):
    """The LAMB optimizer proposed in
    """
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
                 lower_bound=1e-3, upper_bound=10.0, bias_correction=False, verbose=False, **kwargs):
        super(LAMB2, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.bias_correction = bias_correction
        # optional behaviour toggles, read from environment variables
        if os.environ.get('EPS_AFTER_SQRT', False):
            self._eps_after_sqrt = True
            logging.info('self._eps_after_sqrt = ' + str(self._eps_after_sqrt))
        else:
            self._eps_after_sqrt = False
        self._bulk = int(os.environ.get('LAMB_BULK', 0))
        logging.info(" bulk = " + str(self._bulk))
        self._verbose = verbose
        if int(os.environ.get('USE_BOUND', False)):
            logging.info("using upper lower bound")
            self._use_bound = True
        else:
            self._use_bound = False
        if int(os.environ.get('USE_PROJ', False)):
            logging.info("use projection")
            self._use_proj = True
        else:
            self._use_proj = False
        if int(os.environ.get('FORCE_WD', False)):
            logging.info("force wd")
            self._force_wd = True
        else:
            self._force_wd = False
        if int(os.environ.get('ADJUST_BOUND', False)):
            logging.info("adjusting bound")
            self._adjust_bound = True
        else:
            self._adjust_bound = False
        if int(os.environ.get('SCALE_NORM', False)):
            logging.info("scale by per layer norm")
            self._scale_norm = True
        else:
            self._scale_norm = False
        if int(os.environ.get('L1_NORM', False)):
            logging.info("use l1 norm for scaling")
            self._l1_norm = True
        else:
            self._l1_norm = False
        logging.info('attrs = {}'.format(str(self.__dict__)))
        self._logged_missing_key = False

    def update(self, index, weight, grad, state):
        if self._verbose:
            logging.info('rescale gradient factor = {}'.format(str(self.rescale_grad)))
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        t = self._index_update_count[index]
        # default the flags if they are missing (e.g. on instances restored from older pickled states)
        self._use_bound = self._use_bound if hasattr(self, '_use_bound') else False
        self._adjust_bound = self._adjust_bound if hasattr(self, '_adjust_bound') else False
        self._force_wd = self._force_wd if hasattr(self, '_force_wd') else False
        self._use_proj = self._use_proj if hasattr(self, '_use_proj') else False
        try:
            name = self.idx2name[index]
        except KeyError as e:
            if not self._logged_missing_key:
                warnings.warn(str(e))
                self._logged_missing_key = True
            name = ''
        with bulk(self._bulk):
            # preprocess grad
            grad *= self.rescale_grad
            if self._scale_norm:
                if self._l1_norm:
                    grad /= grad.norm(ord=1)
                else:
                    grad /= grad.norm()
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
            mean, var = state
            mean *= self.beta1
            mean += (1. - self.beta1) * grad
            var *= self.beta2
            var += (1. - self.beta2) * square(grad)
            r1 = weight.norm()
            if not self.bias_correction:
                r1 = minimum(maximum(r1, self.lower_bound), self.upper_bound)
                sqrt_var = sqrt(var)
                sqrt_var += self.epsilon
                g = mean / sqrt_var
                g += wd * weight
            else:
                # apply bias correction
                if self._use_bound:
                    upper_bound = self.upper_bound
                    r1 = minimum(maximum(r1, self.lower_bound), upper_bound)
                mean_hat = mean / (1. - power(self.beta1, t))
                var_hat = var / (1. - power(self.beta2, t))
                if self._eps_after_sqrt:
                    sqrt(var_hat, out=var_hat)
                    var_hat += self.epsilon
                else:
                    var_hat += self.epsilon
                    sqrt(var_hat, out=var_hat)
                mean_hat /= var_hat
                if not self._use_proj or self._force_wd:
                    mean_hat += wd * weight
                g = mean_hat
            r2 = g.norm()
            # calculate lamb_trust_ratio
            ratio = r1 / r2
            # becomes NaN if ratio == NaN or 0, otherwise 0
            nan_or_zero = 1 - ratio / ratio
            r = where(nan_or_zero, ones_like(ratio), ratio)
            lr *= r
            # update weight
            g *= lr
            weight[:] -= g
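
For reference, the @register decorator is what lets mx.gluon.Trainer resolve the optimizer from the string 'lamb2'. A minimal check of the registration, assuming fp16_utils.py is on the Python path:

import mxnet as mx
import fp16_utils  # importing the module runs @register, adding LAMB2 under the name 'lamb2'

opt = mx.optimizer.create('lamb2', learning_rate=0.001)
print(type(opt))  # expected: <class 'fp16_utils.LAMB2'>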
error message:
Traceback (most recent call last):
File "test.py", line 13, in <module>
trainer.load_states('/fsx/tf-amp-fp32-grad-248-2-14-restart/0002999.states.00')
File "/home/ec2-user/.local/lib/python3.7/site-packages/mxnet/gluon/trainer.py", line 491, in load_states
updater.set_states(states)
File "/home/ec2-user/.local/lib/python3.7/site-packages/mxnet/optimizer/optimizer.py", line 1746, in set_states
states = pickle.loads(states)
AttributeError: 'LAMB2' object has no attribute 'new_update_multi_precision'
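
A minimal sketch of a possible workaround (an assumption, not a confirmed fix): since the pickled states appear to come from an AMP-initialized trainer, call amp.init_trainer before load_states so the unpickled LAMB2 instance finds the attribute that AMP patches in:

import mxnet as mx
from mxnet.contrib import amp
from fp16_utils import LAMB2  # registers 'lamb2'

amp.init()
net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize(mx.init.Uniform(0.1), ctx=mx.gpu())
trainer = mx.gluon.Trainer(net.collect_params(), 'lamb2')

# Assumption: the saved states were written by an AMP-patched trainer,
# so patch this trainer first, then restore the states.
amp.init_trainer(trainer)
trainer.load_states('/fsx/tf-amp-fp32-grad-248-2-14-restart/0002999.states.00')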