From b6d3da7ba4ef3fc42a61d4327e110a10564cbf62 Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Thu, 26 Mar 2020 11:42:18 +0100
Subject: [PATCH 01/11] added zero_allow_untested_optimizer flag helpers

---
 deepspeed/pt/deepspeed_config.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/deepspeed/pt/deepspeed_config.py b/deepspeed/pt/deepspeed_config.py
index 5310a03d3ff6..47828da2f10d 100755
--- a/deepspeed/pt/deepspeed_config.py
+++ b/deepspeed/pt/deepspeed_config.py
@@ -165,6 +165,12 @@ def get_optimizer_legacy_fusion(param_dict):
         return LEGACY_FUSION_DEFAULT
 
 
+def get_zero_allow_untested_optimizer(param_dict):
+    return get_scalar_param(param_dict,
+                            ZERO_ALLOW_UNTESTED_OPTIMIZER,
+                            ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)
+
+
 def get_scheduler_name(param_dict):
     if SCHEDULER in param_dict.keys() and \
             TYPE in param_dict[SCHEDULER].keys():
@@ -271,6 +277,9 @@ def _initialize_params(self, param_dict):
         self.optimizer_params = get_optimizer_params(param_dict)
         self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict)
 
+        self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(
+            param_dict)
+
         self.scheduler_name = get_scheduler_name(param_dict)
         self.scheduler_params = get_scheduler_params(param_dict)

From 1fe4d1bacfb9f3a94e769979593c5b759960be77 Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Thu, 26 Mar 2020 11:42:32 +0100
Subject: [PATCH 02/11] add zero_allow_untested_optimizer config constants

---
 deepspeed/pt/deepspeed_constants.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/deepspeed/pt/deepspeed_constants.py b/deepspeed/pt/deepspeed_constants.py
index 074d8128c116..429166b11953 100644
--- a/deepspeed/pt/deepspeed_constants.py
+++ b/deepspeed/pt/deepspeed_constants.py
@@ -31,6 +31,12 @@
 SCHEDULER_PARAMS = "params"
 MAX_GRAD_NORM = 'max_grad_norm'
 
+#############################################
+# Optimizer and lr scheduler
+#############################################
+ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
+ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False
+
 #############################################
 # Torch distributed constants
 #############################################

From d838a8715999ab2859492294bed9c4af8aa94f9d Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Thu, 26 Mar 2020 11:43:20 +0100
Subject: [PATCH 03/11] zero_allow_untested_optimizer logic with assertion

---
 deepspeed/pt/deepspeed_light.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 740727c837e5..829efeebe1fc 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -271,6 +271,9 @@ def scheduler_params(self):
     def zero_optimization(self):
         return self._config.zero_enabled
 
+    def zero_allow_untested_optimizer(self):
+        return self._config.zero_allow_untested_optimizer
+
     def allgather_size(self):
         return self._config.allgather_size
 
@@ -443,6 +446,9 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
 
         if self.zero_optimization():
             if self.optimizer_name != ADAM_OPTIMIZER:
+                assert self.zero_allow_untested_optimizer(), \
+                    '{} is not a tested ZeRO Optimizer. To use it you need to add `"zero_allow_untested_optimizer": true` in the DeepSpeed json configuration file.'.format(self.optimizer_name())
+
                 logging.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                 )

From cb621b6f9b052421807e353d9d8089ba26d1f6db Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Thu, 26 Mar 2020 11:45:02 +0100
Subject: [PATCH 04/11] modified assertion message

---
 deepspeed/pt/deepspeed_light.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 829efeebe1fc..33ae85fe90f0 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -447,7 +447,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
         if self.zero_optimization():
             if self.optimizer_name != ADAM_OPTIMIZER:
                 assert self.zero_allow_untested_optimizer(), \
-                    '{} is not a tested ZeRO Optimizer. To use it you need to add `"zero_allow_untested_optimizer": true` in the DeepSpeed json configuration file.'.format(self.optimizer_name())
+                    '{} is not a tested ZeRO Optimizer. Please add `"zero_allow_untested_optimizer": true` in the configuration file to use it.'.format(self.optimizer_name())
 
                 logging.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"

From dbdc3f6a7eb5f6d1ce8e6b11f5b4a5d484f2797b Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Thu, 26 Mar 2020 12:18:50 +0100
Subject: [PATCH 05/11] changed assertion error message

---
 deepspeed/pt/deepspeed_light.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 33ae85fe90f0..8c644019b946 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -447,7 +447,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
         if self.zero_optimization():
             if self.optimizer_name != ADAM_OPTIMIZER:
                 assert self.zero_allow_untested_optimizer(), \
-                    '{} is not a tested ZeRO Optimizer. Please add `"zero_allow_untested_optimizer": true` in the configuration file to use it.'.format(self.optimizer_name())
+                    'You are using an untested ZeRO Optimizer. Please add `"zero_allow_untested_optimizer": true` in the configuration file to use it.'
 
                 logging.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"

From e84c884108e9aef39a17c7dc4cd249e72e3701a4 Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Thu, 26 Mar 2020 12:20:39 +0100
Subject: [PATCH 06/11] changed error message for utf-8 compatibility

---
 deepspeed/pt/deepspeed_light.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 8c644019b946..884326cfa5dc 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -447,7 +447,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
         if self.zero_optimization():
             if self.optimizer_name != ADAM_OPTIMIZER:
                 assert self.zero_allow_untested_optimizer(), \
-                    'You are using an untested ZeRO Optimizer. Please add `"zero_allow_untested_optimizer": true` in the configuration file to use it.'
+                    'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
 
                 logging.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                 )

From b3b66d33a9105739510279c91a31f074f7ee8c69 Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Thu, 26 Mar 2020 12:48:24 +0100
Subject: [PATCH 07/11] added lower to optimizer_name

---
 deepspeed/pt/deepspeed_light.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 884326cfa5dc..740a1aca28da 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -445,7 +445,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
         logging.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))
 
         if self.zero_optimization():
-            if self.optimizer_name != ADAM_OPTIMIZER:
+            if self.optimizer_name.lower() != ADAM_OPTIMIZER.lower():
                 assert self.zero_allow_untested_optimizer(), \
                     'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

From 5a5e4ad1bf406a09b31ee6aadb845fc1dff09869 Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Thu, 26 Mar 2020 12:56:16 +0100
Subject: [PATCH 08/11] fixed optimizer_name and code style

---
 deepspeed/pt/deepspeed_light.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 740a1aca28da..22a0e9d6324c 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -445,7 +445,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
         logging.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))
 
         if self.zero_optimization():
-            if self.optimizer_name.lower() != ADAM_OPTIMIZER.lower():
+            if self.optimizer_name() != ADAM_OPTIMIZER:
                 assert self.zero_allow_untested_optimizer(), \
                     'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
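
With patches 01-08 in place, a user who wants to run an untested optimizer under ZeRO opts in through the DeepSpeed JSON configuration file, exactly as the assertion message instructs. A minimal sketch of such a configuration follows; zero_allow_untested_optimizer is the only key introduced by this series, the other keys mirror the test configuration added in patch 09, and the values are purely illustrative:

    {
        "train_batch_size": 4,
        "steps_per_print": 1,
        "fp16": {
            "enabled": true
        },
        "zero_optimization": true,
        "zero_allow_untested_optimizer": true
    }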
From 50c7091d8f7c0c73f6d8b6e8d241cd9714c1df1b Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Fri, 27 Mar 2020 11:20:24 +0100
Subject: [PATCH 09/11] Added unit test and CustomOptimizer helper class

---
 tests/unit/simple_model.py | 23 +++++++++++++++++++++++
 tests/unit/test_fp16.py    | 28 +++++++++++++++++++++++++++-
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index 4561ae787c06..bdd9559102a5 100644
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -18,6 +18,29 @@ def forward(self, x, y):
         return self.cross_entropy_loss(hidden_dim, y)
 
 
+class SimpleOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, lr=0.11072018):
+        defaults = dict(lr=lr)
+        super(SimpleOptimizer, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(SimpleOptimizer, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                d_p = p.grad.data
+                p.data.add_(-group['lr'], d_p)
+
+        return loss
+
+
 def random_dataloader(model, total_samples, hidden_dim, device):
     batch_size = model.train_micro_batch_size_per_gpu()
     train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py
index b2c91bff4c09..e566cee306d4 100755
--- a/tests/unit/test_fp16.py
+++ b/tests/unit/test_fp16.py
@@ -5,7 +5,7 @@
 import json
 import os
 from common import distributed_test
-from simple_model import SimpleModel, random_dataloader, args_from_dict
+from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
 
 
 def test_lamb_fp16_basic(tmpdir):
@@ -289,3 +289,29 @@ def _test_zero_static_scale(args):
             model.step()
 
     _test_zero_static_scale(args)
+
+
+def test_zero_allow_untested_optimizer(tmpdir):
+    config_dict = {
+        "train_batch_size": 4,
+        "steps_per_print": 1,
+        "fp16": {
+            "enabled": True,
+        },
+        "zero_optimization": True,
+        "zero_allow_untested_optimizer": False
+    }
+    args = args_from_dict(tmpdir, config_dict)
+
+    @distributed_test(world_size=[1])
+    def _test_zero_allow_untested_optimizer(args):
+        hidden_dim = 10
+        optimizer = SimpleOptimizer()
+        model = SimpleModel(hidden_dim, empty_grad=True)
+        with pytest.raises(AssertionError):
+            model, optim, _,_ = deepspeed.initialize(args=args,
+                                                     model=model,
+                                                     optimizer=optimizer,
+                                                     model_parameters=model.parameters())
+
+    _test_zero_allow_untested_optimizer(args)

From fa22d1b578c05670ba1017361a12fc1c9d5c1487 Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Fri, 27 Mar 2020 11:24:47 +0100
Subject: [PATCH 10/11] fixed bug in sample optimizer

---
 tests/unit/simple_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index bdd9559102a5..0f103a5ee41c 100644
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -19,9 +19,9 @@ def forward(self, x, y):
 
 
 class SimpleOptimizer(torch.optim.Optimizer):
-    def __init__(self, params, lr=0.11072018):
+    def __init__(self, lr=0.11072018):
         defaults = dict(lr=lr)
-        super(SimpleOptimizer, self).__init__(params, defaults)
+        super(SimpleOptimizer, self).__init__(defaults)
 
     def __setstate__(self, state):
         super(SimpleOptimizer, self).__setstate__(state)

From 632a819cfee9945f4812bd0e205e256ccd05ee84 Mon Sep 17 00:00:00 2001
From: Calogero Zarbo
Date: Fri, 27 Mar 2020 11:29:14 +0100
Subject: [PATCH 11/11] fixed bug in SimpleOptimizer

---
 tests/unit/simple_model.py | 4 ++--
 tests/unit/test_fp16.py    | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index 0f103a5ee41c..bdd9559102a5 100644
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -19,9 +19,9 @@ def forward(self, x, y):
 
 
 class SimpleOptimizer(torch.optim.Optimizer):
-    def __init__(self, lr=0.11072018):
+    def __init__(self, params, lr=0.11072018):
         defaults = dict(lr=lr)
-        super(SimpleOptimizer, self).__init__(defaults)
+        super(SimpleOptimizer, self).__init__(params, defaults)
 
     def __setstate__(self, state):
         super(SimpleOptimizer, self).__setstate__(state)
diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py
index e566cee306d4..636e941c069e 100755
--- a/tests/unit/test_fp16.py
+++ b/tests/unit/test_fp16.py
@@ -306,8 +306,8 @@ def test_zero_allow_untested_optimizer(tmpdir):
     @distributed_test(world_size=[1])
     def _test_zero_allow_untested_optimizer(args):
         hidden_dim = 10
-        optimizer = SimpleOptimizer()
         model = SimpleModel(hidden_dim, empty_grad=True)
+        optimizer = SimpleOptimizer(model.parameters())
         with pytest.raises(AssertionError):
             model, optim, _,_ = deepspeed.initialize(args=args,
                                                      model=model,
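
The unit test above exercises only the failing path (flag left at False, so initialization asserts). For reference, the opt-in path can be sketched end to end. The following is a hypothetical usage example, not code from the patches: it reuses the SimpleModel, SimpleOptimizer, and args_from_dict helpers from tests/unit/simple_model.py, assumes the same deepspeed.initialize signature exercised by the unit test, and sets zero_allow_untested_optimizer to True so that initialization proceeds with a warning instead of raising.

    import deepspeed
    from simple_model import SimpleModel, SimpleOptimizer, args_from_dict

    def run_untested_optimizer_example(tmpdir):
        # Same keys as the unit test's config_dict, but opting in to the
        # untested optimizer so the ZeRO assertion does not fire.
        config_dict = {
            "train_batch_size": 4,
            "steps_per_print": 1,
            "fp16": {
                "enabled": True,
            },
            "zero_optimization": True,
            "zero_allow_untested_optimizer": True
        }
        args = args_from_dict(tmpdir, config_dict)  # tmpdir as provided by pytest

        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=True)
        optimizer = SimpleOptimizer(model.parameters())

        # With the flag enabled, DeepSpeed only logs the "proceed with
        # caution" warning instead of raising an AssertionError.
        model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                             model=model,
                                                             optimizer=optimizer,
                                                             model_parameters=model.parameters())
        return model_engine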