This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Cannot load trainer with AMP #16858

@eric-haibin-lin

Description


test.py:

import mxnet as mx
from fp16_utils import LAMB2  # importing registers the custom 'lamb2' optimizer

net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize(mx.init.Uniform(0.1), ctx=mx.gpu())

from mxnet.contrib import amp
amp.init()

trainer = mx.gluon.Trainer(net.collect_params(), 'lamb2')
# fails here with AttributeError (see traceback below)
trainer.load_states('/fsx/tf-amp-fp32-grad-248-2-14-restart/0002999.states.00')
amp.init_trainer(trainer)
print(trainer)
print(trainer._optimizer)
print(trainer._optimizer.new_update_multi_precision)
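
(The states file lives on a site-specific path. It is assumed here to have been written earlier, on the training job that produced the checkpoint, by the matching save call, roughly as below; that call is not shown in the issue.)

# assumed counterpart on the original training job, where AMP had already been initialized
trainer.save_states('/fsx/tf-amp-fp32-grad-248-2-14-restart/0002999.states.00')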

fp16_utils.py:

from mxnet.optimizer import Optimizer, register
from mxnet.ndarray import zeros, ones_like, NDArray
from mxnet.ndarray import square, power, sqrt, maximum, minimum, clip, where
from mxnet.engine import bulk
import os
import logging
import warnings

@register
class LAMB2(Optimizer):
    """The LAMB optimizer proposed in
    `Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes
    <https://arxiv.org/abs/1904.00962>`_.
    """

    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
                 lower_bound=1e-3, upper_bound=10.0, bias_correction=False, verbose=False, **kwargs):
        super(LAMB2, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.bias_correction = bias_correction
        if int(os.environ.get('EPS_AFTER_SQRT', 0)):
            self._eps_after_sqrt = True
            logging.info('self._eps_after_sqrt = ' + str(self._eps_after_sqrt))
        else:
            self._eps_after_sqrt = False
        self._bulk = int(os.environ.get('LAMB_BULK', 0))
        logging.info(" bulk = " + str(self._bulk))
        self._verbose = verbose
        if int(os.environ.get('USE_BOUND', False)):
            logging.info("using upper lower bound")
            self._use_bound = True
        else:
            self._use_bound = False
        if int(os.environ.get('USE_PROJ', False)):
            logging.info("use projection")
            self._use_proj = True
        else:
            self._use_proj = False
        if int(os.environ.get('FORCE_WD', False)):
            logging.info("force wd")
            self._force_wd = True
        else:
            self._force_wd = False
        if int(os.environ.get('ADJUST_BOUND', False)):
            logging.info("adjusting bound ")
            self._adjust_bound = True
        else:
            self._adjust_bound = False
        if int(os.environ.get('SCALE_NORM', False)):
            logging.info("scale by per layer norm")
            self._scale_norm = True
        else:
            self._scale_norm = False
        if int(os.environ.get('L1_NORM', False)):
            logging.info("use l1 norm for scaling ")
            self._l1_norm = True
        else:
            self._l1_norm = False
        logging.info('attrs = {}'.format(str(self.__dict__)))
        self._logged_missing_key = False

    def update(self, index, weight, grad, state):
        if self._verbose:
            logging.info('rescale gradient factor = {}'.format(str(self.rescale_grad)))
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        t = self._index_update_count[index]
        # flags may be missing, e.g. if the optimizer instance was restored from older pickled states
        self._use_bound = self._use_bound if hasattr(self, '_use_bound') else False
        self._adjust_bound = self._adjust_bound if hasattr(self, '_adjust_bound') else False
        self._force_wd = self._force_wd if hasattr(self, '_force_wd') else False
        self._use_proj = self._use_proj if hasattr(self, '_use_proj') else False
        try:
            name = self.idx2name[index]
        except KeyError as e:
            if not self._logged_missing_key:
                warnings.warn(str(e))
            self._logged_missing_key = True
            name = ''

        with bulk(self._bulk):
            # preprocess grad
            grad *= self.rescale_grad
            if self._scale_norm:
                if self._l1_norm:
                    grad /= grad.norm(ord=1)
                else:
                    grad /= grad.norm()
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)

            mean, var = state
            mean *= self.beta1
            mean += (1. - self.beta1) * grad
            var *= self.beta2
            var += (1. - self.beta2) * square(grad)

            r1 = weight.norm()
            if not self.bias_correction:
                r1 = minimum(maximum(r1, self.lower_bound), self.upper_bound)
                sqrt_var = sqrt(var)
                sqrt_var += self.epsilon
                g = mean / sqrt_var
                g += wd * weight
            else:
                # apply bias correction
                if self._use_bound:
                    upper_bound = self.upper_bound
                    r1 = minimum(maximum(r1, self.lower_bound), upper_bound)
                mean_hat = mean / (1. - power(self.beta1, t))
                var_hat = var / (1. - power(self.beta2, t))
                if self._eps_after_sqrt:
                    sqrt(var_hat, out=var_hat)
                    var_hat += self.epsilon
                else:
                    var_hat += self.epsilon
                    sqrt(var_hat, out=var_hat)
                mean_hat /= var_hat
                if not self._use_proj or self._force_wd:
                    mean_hat += wd * weight
                g = mean_hat

            r2 = g.norm()

            # calculate lamb_trust_ratio
            ratio = r1 / r2
            # nan_or_zero is NaN when ratio is NaN or 0, and 0 otherwise
            nan_or_zero = 1 - ratio / ratio
            # fall back to a trust ratio of 1 in that case
            r = where(nan_or_zero, ones_like(ratio), ratio)
            lr *= r

            # update weight
            g *= lr
            weight[:] -= g
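
update() above unpacks state into (mean, var), but the posted snippet does not include the method that creates that state, so fp16_utils.py as shown is only a partial listing (the reported error happens in load_states, before any optimizer state would be used, so the omission does not affect the repro). A minimal sketch of what that method presumably looks like inside the LAMB2 class, following the usual Adam-style layout; not taken from the original file:

    def create_state(self, index, weight):
        # assumed sketch: one buffer for the first moment (mean), one for the second (var)
        return (zeros(weight.shape, weight.context, dtype=weight.dtype),
                zeros(weight.shape, weight.context, dtype=weight.dtype))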

error message:

Traceback (most recent call last):
  File "test.py", line 13, in <module>
    trainer.load_states('/fsx/tf-amp-fp32-grad-248-2-14-restart/0002999.states.00')
  File "/home/ec2-user/.local/lib/python3.7/site-packages/mxnet/gluon/trainer.py", line 491, in load_states
    updater.set_states(states)
  File "/home/ec2-user/.local/lib/python3.7/site-packages/mxnet/optimizer/optimizer.py", line 1746, in set_states
    states = pickle.loads(states)
AttributeError: 'LAMB2' object has no attribute 'new_update_multi_precision'
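
new_update_multi_precision is not defined by LAMB2 itself; it appears to be attached to the optimizer instance by amp.init_trainer, so the pickled states written by an AMP-initialized trainer refer to an attribute that a freshly constructed LAMB2 does not yet have. Assuming the file was saved by Trainer.save_states without a distributed kvstore, it is a plain pickle blob that load_states hands to pickle.loads verbatim, so the failure should be reproducible without any Gluon or AMP setup; a hypothetical stand-alone check:

import pickle
from fp16_utils import LAMB2   # make the class importable for unpickling

with open('/fsx/tf-amp-fp32-grad-248-2-14-restart/0002999.states.00', 'rb') as f:
    blob = f.read()

try:
    pickle.loads(blob)         # expected to raise the same AttributeError
except AttributeError as e:
    print('reproduced outside Gluon:', e)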
