diff --git a/cifar/cifar10_deepspeed.py b/cifar/cifar10_deepspeed.py index d51d27b80..ed509b8e1 100755 --- a/cifar/cifar10_deepspeed.py +++ b/cifar/cifar10_deepspeed.py @@ -211,26 +211,11 @@ def forward(self, x): def create_moe_param_groups(model): - from deepspeed.moe.utils import is_moe_param - - params_with_weight_decay = {'params': [], 'name': 'weight_decay_params'} - moe_params_with_weight_decay = { - 'params': [], - 'moe': True, - 'name': 'weight_decay_moe_params' - } - - for module_ in model.modules(): - moe_params_with_weight_decay['params'].extend([ - p for n, p in list(module_._parameters.items()) - if p is not None and is_moe_param(p) - ]) - params_with_weight_decay['params'].extend([ - p for n, p in list(module_._parameters.items()) - if p is not None and not is_moe_param(p) - ]) - - return params_with_weight_decay, moe_params_with_weight_decay + from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer + + parameters = {'params': model.parameters(), 'name': 'parameters'} + + return split_params_into_different_moe_groups_for_optimizer(parameters) parameters = filter(lambda p: p.requires_grad, net.parameters())