diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py
index 2b1be7e53de2..7977d232b1fa 100755
--- a/deepspeed/ops/adam/cpu_adam.py
+++ b/deepspeed/ops/adam/cpu_adam.py
@@ -74,7 +74,7 @@ def __init__(self,
 
         self.opt_id = DeepSpeedCPUAdam.optimizer_id
         DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
-
+        self.adam_w_mode = adamw_mode
         self.ds_opt_adam = CPUAdamBuilder().load()
 
         self.ds_opt_adam.create_adam(self.opt_id,
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 4cc09a8e3bf1..11e1d4037c8e 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -40,6 +40,10 @@
 # extra optimizer parameters for adam/adamw
 TORCH_ADAM_PARAM = "torch_adam"
 
+# default to adamw logic for adam/adamw optimizers unless user explicitly opts out
+ADAM_W_MODE = "adam_w_mode"
+ADAM_W_MODE_DEFAULT = True
+
 
 class DeepSpeedConfigError(Exception):
     pass
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 7faddfe566ad..1462225ac2bd 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -22,7 +22,7 @@
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
     ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
-    TORCH_ADAM_PARAM
+    TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT
 
 from deepspeed.runtime.dataloader import DeepSpeedDataLoader
 from deepspeed.runtime.constants import \
@@ -640,26 +640,30 @@ def _configure_basic_optimizer(self, model_parameters):
 
         if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
             torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
-            adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER
-
-            # zero-offload  torch-adam  adam_w_mode  optimizer
-            # T|F           T           T            torch.optim.AdamW
-            # T|F           T           F            torch.optim.Adam
-            # T             F           T|F          DeepSpeedCPUAdam(adam_w_mode)
-            # F             F           T|F          FusedAdam(adam_w_mode)
+            adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE,
+                                                   ADAM_W_MODE_DEFAULT)
+
+            # Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set
+            effective_adam_w_mode = self.optimizer_name(
+            ) == ADAMW_OPTIMIZER or adam_w_mode
+
             if torch_adam:
-                if adam_w_mode:
-                    optimizer = torch.optim.AdamW(model_parameters,
-                                                  **optimizer_parameters)
-                else:
+                if not effective_adam_w_mode:
                     optimizer = torch.optim.Adam(model_parameters,
                                                  **optimizer_parameters)
-            elif self.zero_cpu_offload():
-                optimizer = DeepSpeedCPUAdam(model_parameters,
-                                             **optimizer_parameters,
-                                             adamw_mode=adam_w_mode)
+                else:
+                    optimizer = torch.optim.AdamW(model_parameters,
+                                                  **optimizer_parameters)
             else:
-                optimizer_parameters['adam_w_mode'] = adam_w_mode
-                optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+                if self.zero_cpu_offload():
+                    from deepspeed.ops.adam import DeepSpeedCPUAdam
+                    optimizer = DeepSpeedCPUAdam(model_parameters,
+                                                 **optimizer_parameters,
+                                                 adamw_mode=effective_adam_w_mode)
+                else:
+                    from deepspeed.ops.adam import FusedAdam
+                    optimizer = FusedAdam(model_parameters,
+                                          **optimizer_parameters,
+                                          adam_w_mode=effective_adam_w_mode)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
             from deepspeed.ops.lamb import FusedLamb
diff --git a/tests/unit/test_adamw.py b/tests/unit/test_adamw.py
new file mode 100644
index 000000000000..83e0b5436546
--- /dev/null
+++ b/tests/unit/test_adamw.py
@@ -0,0 +1,73 @@
+import deepspeed
+import torch
+import pytest
+
+from common import distributed_test
+from deepspeed.ops.adam import FusedAdam
+from deepspeed.ops.adam import DeepSpeedCPUAdam
+from simple_model import SimpleModel, args_from_dict
+
+# yapf: disable
+#'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer
+adam_configs = [["AdamW", False, False, False, (FusedAdam, True)],
+                ["AdamW", False, True,  False, (torch.optim.AdamW, None)],
+                ["AdamW", True,  False, False, (DeepSpeedCPUAdam, True)],
+                ["AdamW", True,  True,  False, (torch.optim.AdamW, None)],
+                ["AdamW", False, False, True,  (FusedAdam, True)],
+                ["AdamW", False, True,  True,  (torch.optim.AdamW, None)],
+                ["AdamW", True,  False, True,  (DeepSpeedCPUAdam, True)],
+                ["AdamW", True,  True,  True,  (torch.optim.AdamW, None)],
+                ["Adam",  False, False, False, (FusedAdam, False)],
+                ["Adam",  False, True,  False, (torch.optim.Adam, None)],
+                ["Adam",  True,  False, False, (DeepSpeedCPUAdam, False)],
+                ["Adam",  True,  True,  False, (torch.optim.Adam, None)],
+                ["Adam",  False, False, True,  (FusedAdam, True)],
+                ["Adam",  False, True,  True,  (torch.optim.AdamW, None)],
+                ["Adam",  True,  False, True,  (DeepSpeedCPUAdam, True)],
+                ["Adam",  True,  True,  True,  (torch.optim.AdamW, None)]]
+
+@pytest.mark.parametrize(
+    'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer',
+    adam_configs)
+def test_adam_configs(tmpdir,
+                      optimizer,
+                      zero_offload,
+                      torch_adam,
+                      adam_w_mode,
+                      resulting_optimizer):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": optimizer,
+            "params": {
+                "lr": 0.00015,
+                "torch_adam": torch_adam,
+                "adam_w_mode": adam_w_mode
+            }
+        },
+        "gradient_clipping": 1.0,
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": {
+            "stage": 2,
+            "cpu_offload": zero_offload
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+
+    @distributed_test(world_size=[1])
+    def helper(args):
+        model = SimpleModel(10)
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        # get base optimizer under zero
+        ds_optimizer = model.optimizer.optimizer
+        opt_class, adam_w_mode = resulting_optimizer
+        assert isinstance(ds_optimizer, opt_class)
+        if adam_w_mode in [True, False]:
+            assert ds_optimizer.adam_w_mode == adam_w_mode
+
+    helper(args)
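
For reference, the new `adam_w_mode` flag lives in the optimizer `params` block of the DeepSpeed config, next to the existing `torch_adam` flag, and defaults to `ADAM_W_MODE_DEFAULT` (True). Below is a minimal usage sketch; the toy model, learning rate, and batch size are placeholders, and the expected resulting optimizer noted in the comments follows the selection logic and test matrix above.

```python
import torch
import deepspeed

# Placeholder model; any nn.Module works here.
model = torch.nn.Linear(10, 1)

# "Adam" with adam_w_mode=False explicitly opts out of the AdamW default;
# with torch_adam=False and no ZeRO CPU offload this should resolve to
# FusedAdam(adam_w_mode=False), per the selection table in the test above.
ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4,
            "torch_adam": False,   # keep DeepSpeed's fused/CPU Adam kernels
            "adam_w_mode": False   # opt out of decoupled weight decay (default is True)
        }
    }
}

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config_params=ds_config)
```

Because the default is True, a config that requests `"type": "Adam"` without setting `adam_w_mode` now gets AdamW-style decoupled weight decay (FusedAdam or DeepSpeedCPUAdam with `adam_w_mode=True`); setting `adam_w_mode` to false, as sketched here, is the explicit opt-out.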