diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 3401f121bca0..0ace7b23ae27 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -58,11 +58,13 @@ def initialize(args=None, mpu=None, dist_init_required=None, collate_fn=None, + config=None, config_params=None): """Initialize the DeepSpeed Engine. Arguments: - args: an object containing local_rank and deepspeed_config fields. This is optional if `config_params` is passed. + args: an object containing local_rank and deepspeed_config fields. + This is optional if `config` is passed. model: Required: nn.module class before apply any wrappers @@ -87,8 +89,10 @@ def initialize(args=None, mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. - config_params: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config - as a dictionary instead. + config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config + as an argument instead, as a path or a dictionary. + + config_params: Optional: Same as `config`, kept for backwards compatibility. 
Returns: A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler`` @@ -122,6 +126,7 @@ def initialize(args=None, mpu=mpu, dist_init_required=dist_init_required, collate_fn=collate_fn, + config=config, config_params=config_params) else: assert mpu is None, "mpu must be None with pipeline parallelism" @@ -134,6 +139,7 @@ def initialize(args=None, mpu=model.mpu(), dist_init_required=dist_init_required, collate_fn=collate_fn, + config=config, config_params=config_params) return_items = [ diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index faa60f20efa3..db40dcfba708 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -725,11 +725,11 @@ def reset(): size_offsets = [] -def _configure_using_config_file(deepspeed_config, mpu=None): +def _configure_using_config_file(config, mpu=None): global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME - config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config + config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config if dist.get_rank() == 0: logger.info(config.repr()) PARTITION_ACTIVATIONS = config.partition_activations diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 3fa0b32a6032..d9e0e399b150 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -2,6 +2,8 @@ Copyright (c) Microsoft Corporation Licensed under the MIT license. 
""" +import os +from typing import Union import torch import json @@ -521,17 +523,19 @@ def write_config(self, filename): class DeepSpeedConfig(object): - def __init__(self, json_file, mpu=None, param_dict=None): + def __init__(self, config: Union[str, dict], mpu=None): super(DeepSpeedConfig, self).__init__() - - if param_dict is None: + if isinstance(config, dict): + self._param_dict = config + elif os.path.exists(config): self._param_dict = json.load( - open(json_file, + open(config, 'r'), object_pairs_hook=dict_raise_error_on_duplicate_keys) else: - self._param_dict = param_dict - + raise ValueError( + f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}" + ) try: self.global_rank = torch.distributed.get_rank() if mpu is None: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9c5f70ec2aac..87ccf1c99e5a 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -110,6 +110,7 @@ def __init__(self, mpu=None, dist_init_required=None, collate_fn=None, + config=None, config_params=None, dont_change_device=False): super(DeepSpeedEngine, self).__init__() @@ -127,13 +128,17 @@ def __init__(self, self.skipped_steps = 0 self.gradient_average = True self.warn_unscaled_loss = True - self.config_params = config_params + self.config = config self.loaded_checkpoint_mp_world_size = None self.loaded_checkpoint_dp_world_size = None self.enable_backward_allreduce = True self.progressive_layer_drop = None self.dist_backend = "nccl" + # Set config using config_params for backwards compat + if self.config is None and config_params is not None: + self.config = config_params + if dist_init_required is None: dist_init_required = not dist.is_initialized() @@ -515,9 +520,10 @@ def _configure_with_arguments(self, args, mpu): if hasattr(args, 'local_rank'): args.local_rank = self.local_rank - config_file = args.deepspeed_config if hasattr(args, - 'deepspeed_config') else None - self._config = 
DeepSpeedConfig(config_file, mpu, param_dict=self.config_params) + if self.config is None: + self.config = args.deepspeed_config if hasattr(args, + 'deepspeed_config') else None + self._config = DeepSpeedConfig(self.config, mpu) # Validate command line arguments def _do_args_sanity_check(self, args): @@ -538,7 +544,7 @@ def _do_args_sanity_check(self, args): assert env_local_rank == args.local_rank, \ f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}." - if self.config_params is None: + if self.config is None: assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \ 'DeepSpeed requires --deepspeed_config to specify configuration file' diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index e831911efd62..d1e29afa0a5d 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -271,8 +271,7 @@ def __init__(self, mem_efficient_linear=True, remote_device=None, pin_memory=False, - deepspeed_config=None, - param_dict=None, + config=None, enabled=True): """A context to enable massive model construction for training with ZeRO-3. Models are automatically partitioned (or, sharded) across the @@ -293,10 +292,8 @@ def __init__(self, pin_memory (bool, optional): Potentially increase performance by using pinned memory for model weights. ``remote_device`` must be ``"cpu"``. Defaults to ``False``. - deepspeed_config (``json file``, optional): If provided, provides configuration + config (``json file`` or dict, optional): If provided, provides configuration for swapping fp16 params to NVMe. - param_dict (dict, optional): Instead of requiring a deepspeed_config you can pass your deepspeed config - as a dictionary instead for swapping fp16 params to NVMe. enabled (bool, optional): If ``False``, this context has no effect. Defaults to ``True``. 
@@ -386,7 +383,7 @@ def get_model(): #It is the device where parameters are fully instantiated using allgather self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - self._validate_remote_device(remote_device, deepspeed_config, param_dict) + self._validate_remote_device(remote_device, config) #Remote device is the device where parameter partiitons are stored #It can be same as local_device or it could be CPU or NVMe. @@ -396,7 +393,7 @@ def get_model(): # Enable fp16 param swapping to NVMe if self.remote_device == OFFLOAD_NVME_DEVICE: - _ds_config = DeepSpeedConfig(deepspeed_config, param_dict=param_dict) + _ds_config = DeepSpeedConfig(config) self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config) else: self.param_swapper = None @@ -410,9 +407,9 @@ def get_model(): self._convert_to_deepspeed_param(param) param.partition() - def _validate_remote_device(self, remote_device, ds_config, param_dict): + def _validate_remote_device(self, remote_device, ds_config): if ds_config is not None: - _ds_config = DeepSpeedConfig(ds_config, param_dict=param_dict) + _ds_config = DeepSpeedConfig(ds_config) if remote_device in [None, OFFLOAD_CPU_DEVICE]: if _ds_config.zero_config.offload_param is not None: offload_param_device = _ds_config.zero_config.offload_param[ diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 7de3a40fabeb..01004a0fa867 100755 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -229,7 +229,7 @@ def _helper(): def test_none_args(tmpdir): - config_dict = { + config = { "train_batch_size": 1, "optimizer": { "type": "Adam", @@ -245,7 +245,7 @@ def test_none_args(tmpdir): @distributed_test(world_size=1) def _helper(): model = SimpleModel(hidden_dim=10) - model, _, _, _ = deepspeed.initialize(args=None, model=model, config_params=config_dict) + model, _, _, _ = deepspeed.initialize(args=None, model=model, config=config) data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, 
@@ -257,7 +257,7 @@ def _helper(): def test_no_args(tmpdir): - config_dict = { + config = { "train_batch_size": 1, "optimizer": { "type": "Adam", @@ -273,7 +273,7 @@ def test_no_args(tmpdir): @distributed_test(world_size=1) def _helper(): model = SimpleModel(hidden_dim=10) - model, _, _, _ = deepspeed.initialize(model=model, config_params=config_dict) + model, _, _, _ = deepspeed.initialize(model=model, config=config) data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, @@ -285,7 +285,7 @@ def _helper(): def test_no_model(tmpdir): - config_dict = { + config = { "train_batch_size": 1, "optimizer": { "type": "Adam", @@ -302,7 +302,7 @@ def test_no_model(tmpdir): def _helper(): model = SimpleModel(hidden_dim=10) with pytest.raises(AssertionError): - model, _, _, _ = deepspeed.initialize(model=None, config_params=config_dict) + model, _, _, _ = deepspeed.initialize(model=None, config=config) with pytest.raises(AssertionError): - model, _, _, _ = deepspeed.initialize(model, config_params=config_dict) + model, _, _, _ = deepspeed.initialize(model, config=config) diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index dbd40c322be9..defd64318dd2 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -196,25 +196,19 @@ def _test_adamw_fp16_basic(args, model, hidden_dim): def test_dict_config_adamw_fp16_basic(): - config_dict = { - "train_batch_size": 1, - "steps_per_print": 1, - "fp16": { - "enabled": True - } - } + config = {"train_batch_size": 1, "steps_per_print": 1, "fp16": {"enabled": True}} args = create_deepspeed_args() hidden_dim = 10 model = SimpleModel(hidden_dim) @distributed_test(world_size=[1]) - def _test_adamw_fp16_basic(args, model, hidden_dim, config_dict): + def _test_adamw_fp16_basic(args, model, hidden_dim, config): optimizer = torch.optim.AdamW(params=model.parameters()) model, _, _, _ = deepspeed.initialize(args=args, model=model, optimizer=optimizer, - config_params=config_dict) + config=config) 
data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, @@ -224,10 +218,7 @@ def _test_adamw_fp16_basic(args, model, hidden_dim, config_dict): model.backward(loss) model.step() - _test_adamw_fp16_basic(args=args, - model=model, - hidden_dim=hidden_dim, - config_dict=config_dict) + _test_adamw_fp16_basic(args=args, model=model, hidden_dim=hidden_dim, config=config) def test_adamw_fp16_empty_grad(tmpdir): diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py index 5ccccb5c18a0..ac1341c685b3 100644 --- a/tests/unit/test_zero_context.py +++ b/tests/unit/test_zero_context.py @@ -65,7 +65,7 @@ def test_gather_update(): assert torch.equal(l.weight, torch.zeros_like(l.weight)) -config_dict = { +config = { "train_batch_size": 1, "steps_per_print": 1, "optimizer": { @@ -109,7 +109,7 @@ def forward(self, input): engine, optim, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), - config_params=config_dict) + config=config) with deepspeed.zero.GatheredParameters(net.linear1.weight): assert net.linear1.weight.numel() == net.dim**2 @@ -214,7 +214,7 @@ def test_ext_param_return(): engine, optim, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), - config_params=config_dict) + config=config) for _ in range(5): input = torch.rand(net.dim).to(engine.device).half() @@ -234,7 +234,7 @@ def test_ext_param_returnobj(): engine, optim, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), - config_params=config_dict) + config=config) for _ in range(5): input = torch.rand(net.dim).to(engine.device).half()