diff --git a/deepspeed/pt/deepspeed_config.py b/deepspeed/pt/deepspeed_config.py
index 5310a03d3ff6..47828da2f10d 100755
--- a/deepspeed/pt/deepspeed_config.py
+++ b/deepspeed/pt/deepspeed_config.py
@@ -165,6 +165,12 @@ def get_optimizer_legacy_fusion(param_dict):
     return LEGACY_FUSION_DEFAULT
 
 
+def get_zero_allow_untested_optimizer(param_dict):
+    return get_scalar_param(param_dict,
+                            ZERO_ALLOW_UNTESTED_OPTIMIZER,
+                            ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)
+
+
 def get_scheduler_name(param_dict):
     if SCHEDULER in param_dict.keys() and \
         TYPE in param_dict[SCHEDULER].keys():
@@ -271,6 +277,9 @@ def _initialize_params(self, param_dict):
         self.optimizer_params = get_optimizer_params(param_dict)
         self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict)
 
+        self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(
+            param_dict)
+
         self.scheduler_name = get_scheduler_name(param_dict)
         self.scheduler_params = get_scheduler_params(param_dict)
 
diff --git a/deepspeed/pt/deepspeed_constants.py b/deepspeed/pt/deepspeed_constants.py
index 074d8128c116..429166b11953 100644
--- a/deepspeed/pt/deepspeed_constants.py
+++ b/deepspeed/pt/deepspeed_constants.py
@@ -31,6 +31,12 @@
 SCHEDULER_PARAMS = "params"
 MAX_GRAD_NORM = 'max_grad_norm'
 
+#############################################
+# Optimizer and lr scheduler
+#############################################
+ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
+ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False
+
 #############################################
 # Torch distributed constants
 #############################################
diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index a4b228f13160..81375eba78d8 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -271,6 +271,9 @@ def scheduler_params(self):
     def zero_optimization(self):
         return self._config.zero_enabled
 
+    def zero_allow_untested_optimizer(self):
+        return self._config.zero_allow_untested_optimizer
+
     def allgather_size(self):
         return self._config.allgather_size
 
@@ -444,7 +447,10 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
         logging.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))
 
         if self.zero_optimization():
-            if self.optimizer_name != ADAM_OPTIMIZER:
+            if self.optimizer_name() != ADAM_OPTIMIZER:
+                assert self.zero_allow_untested_optimizer(), \
+                    'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
+                logging.warning(
+                    "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
+                )
diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index 4561ae787c06..bdd9559102a5 100644
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -18,6 +18,29 @@ def forward(self, x, y):
         return self.cross_entropy_loss(hidden_dim, y)
 
 
+class SimpleOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, lr=0.11072018):
+        defaults = dict(lr=lr)
+        super(SimpleOptimizer, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(SimpleOptimizer, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                d_p = p.grad.data
+                p.data.add_(-group['lr'], d_p)
+
+        return loss
+
+
 def random_dataloader(model, total_samples, hidden_dim, device):
     batch_size = model.train_micro_batch_size_per_gpu()
     train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py
index b2c91bff4c09..636e941c069e 100755
--- a/tests/unit/test_fp16.py
+++ b/tests/unit/test_fp16.py
@@ -5,7 +5,7 @@
 import json
 import os
 from common import distributed_test
-from simple_model import SimpleModel, random_dataloader, args_from_dict
+from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
 
 
 def test_lamb_fp16_basic(tmpdir):
@@ -289,3 +289,29 @@ def _test_zero_static_scale(args):
             model.step()
 
     _test_zero_static_scale(args)
+
+
+def test_zero_allow_untested_optimizer(tmpdir):
+    config_dict = {
+        "train_batch_size": 4,
+        "steps_per_print": 1,
+        "fp16": {
+            "enabled": True,
+        },
+        "zero_optimization": True,
+        "zero_allow_untested_optimizer": False
+    }
+    args = args_from_dict(tmpdir, config_dict)
+
+    @distributed_test(world_size=[1])
+    def _test_zero_allow_untested_optimizer(args):
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim, empty_grad=True)
+        optimizer = SimpleOptimizer(model.parameters())
+        with pytest.raises(AssertionError):
+            model, optim, _,_ = deepspeed.initialize(args=args,
+                                                     model=model,
+                                                     optimizer=optimizer,
+                                                     model_parameters=model.parameters())
+
+    _test_zero_allow_untested_optimizer(args)
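
Usage sketch (illustrative, not part of the patch): with ZeRO enabled, a client optimizer other than DeepSpeed's Adam is only accepted once "zero_allow_untested_optimizer" is set to true in the DeepSpeed config; otherwise the new assertion in _configure_optimizer fires, which is what the unit test above checks. The snippet below is a minimal sketch assuming the test helpers SimpleModel and SimpleOptimizer are importable, that it runs under a normal single-GPU DeepSpeed launch, and that "ds_config.json" is a placeholder filename.

    import argparse
    import json
    import deepspeed
    from simple_model import SimpleModel, SimpleOptimizer

    # Hypothetical config mirroring the unit test, but with the new flag enabled.
    config = {
        "train_batch_size": 4,
        "fp16": {"enabled": True},
        "zero_optimization": True,
        "zero_allow_untested_optimizer": True
    }
    with open("ds_config.json", "w") as f:
        json.dump(config, f)

    # The tests build args the same way: an object carrying deepspeed_config and local_rank.
    args = argparse.Namespace(deepspeed_config="ds_config.json", local_rank=0)

    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=True)
    optimizer = SimpleOptimizer(model.parameters())  # stands in for any non-Adam optimizer

    # With the flag set to true, initialization proceeds and only logs the
    # "untested optimizer" warning instead of failing the assertion.
    model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                         model=model,
                                                         optimizer=optimizer,
                                                         model_parameters=model.parameters())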