9 changes: 9 additions & 0 deletions deepspeed/pt/deepspeed_config.py
@@ -165,6 +165,12 @@ def get_optimizer_legacy_fusion(param_dict):
    return LEGACY_FUSION_DEFAULT


def get_zero_allow_untested_optimizer(param_dict):
    return get_scalar_param(param_dict,
                            ZERO_ALLOW_UNTESTED_OPTIMIZER,
                            ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)


def get_scheduler_name(param_dict):
    if SCHEDULER in param_dict.keys() and \
        TYPE in param_dict[SCHEDULER].keys():
@@ -271,6 +277,9 @@ def _initialize_params(self, param_dict):
        self.optimizer_params = get_optimizer_params(param_dict)
        self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict)

        self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(
            param_dict)

        self.scheduler_name = get_scheduler_name(param_dict)
        self.scheduler_params = get_scheduler_params(param_dict)

6 changes: 6 additions & 0 deletions deepspeed/pt/deepspeed_constants.py
@@ -31,6 +31,12 @@
SCHEDULER_PARAMS = "params"
MAX_GRAD_NORM = 'max_grad_norm'

#############################################
# Optimizer and lr scheduler
#############################################
ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False

#############################################
# Torch distributed constants
#############################################
8 changes: 7 additions & 1 deletion deepspeed/pt/deepspeed_light.py
@@ -271,6 +271,9 @@ def scheduler_params(self):
    def zero_optimization(self):
        return self._config.zero_enabled

    def zero_allow_untested_optimizer(self):
        return self._config.zero_allow_untested_optimizer

    def allgather_size(self):
        return self._config.allgather_size

@@ -444,7 +447,10 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
        logging.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))

        if self.zero_optimization():
            if self.optimizer_name != ADAM_OPTIMIZER:
            if self.optimizer_name() != ADAM_OPTIMIZER:
                assert self.zero_allow_untested_optimizer(), \
                    'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

                logging.warning(
                    "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                )
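For reference, a hedged sketch (not part of the diff) of the opt-in configuration the assertion message points to; everything except the new flag is an illustrative setting mirroring the unit test below. In a JSON configuration file the flag is written as "zero_allow_untested_optimizer": true.

# Illustrative config fragment in the dict form the unit tests use; only
# zero_allow_untested_optimizer is introduced by this change, the rest are example settings.
config_dict = {
    "train_batch_size": 4,
    "fp16": {
        "enabled": True
    },
    "zero_optimization": True,
    "zero_allow_untested_optimizer": True
}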
23 changes: 23 additions & 0 deletions tests/unit/simple_model.py
@@ -18,6 +18,29 @@ def forward(self, x, y):
        return self.cross_entropy_loss(hidden_dim, y)


class SimpleOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(SimpleOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SimpleOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                p.data.add_(-group['lr'], d_p)

        return loss


def random_dataloader(model, total_samples, hidden_dim, device):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
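As an aside, a minimal sanity sketch (illustrative, not part of this change) of what SimpleOptimizer computes: a bare SGD update p <- p - lr * grad with no master weights, fusion, or loss-scale handling, i.e. exactly the kind of optimizer ZeRO has not been validated against.

import torch

from simple_model import SimpleOptimizer

# One hand-rolled step: the parameter should move from 1.0 to 1.0 - 0.1 * 1.0 = 0.9.
w = torch.nn.Parameter(torch.ones(3))
opt = SimpleOptimizer([w], lr=0.1)
w.grad = torch.ones(3)
opt.step()
assert torch.allclose(w.data, torch.full((3, ), 0.9))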
28 changes: 27 additions & 1 deletion tests/unit/test_fp16.py
@@ -5,7 +5,7 @@
import json
import os
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict


def test_lamb_fp16_basic(tmpdir):
@@ -289,3 +289,29 @@ def _test_zero_static_scale(args):
            model.step()

    _test_zero_static_scale(args)


def test_zero_allow_untested_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 4,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True,
        },
        "zero_optimization": True,
        "zero_allow_untested_optimizer": False
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=[1])
    def _test_zero_allow_untested_optimizer(args):
        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=True)
        optimizer = SimpleOptimizer(model.parameters())
        with pytest.raises(AssertionError):
            model, optim, _, _ = deepspeed.initialize(args=args,
                                                      model=model,
                                                      optimizer=optimizer,
                                                      model_parameters=model.parameters())

    _test_zero_allow_untested_optimizer(args)
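
For contrast, a hypothetical companion test (a sketch, not part of this PR) of the opt-in path: the only changes are flipping the flag to True and dropping the pytest.raises guard, assuming ZeRO initializes cleanly with this bare optimizer once the check is bypassed.

def test_zero_allow_untested_optimizer_enabled(tmpdir):
    config_dict = {
        "train_batch_size": 4,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        },
        "zero_optimization": True,
        "zero_allow_untested_optimizer": True
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=[1])
    def _test_zero_allow_untested_optimizer_enabled(args):
        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=True)
        optimizer = SimpleOptimizer(model.parameters())
        # Expected to succeed: the engine should only log the untested-optimizer warning.
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  optimizer=optimizer,
                                                  model_parameters=model.parameters())

    _test_zero_allow_untested_optimizer_enabled(args)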