diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 5a1a40460e16..c654e66abc59 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -864,19 +864,20 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
 
 
 class GatheredParameters:
-    def __init__(self, param, modifier_rank=None, fwd_module=None, enabled=True):
-        """A context that collects a parameter that was partitioned via a
-        :class:`deepspeed.zero.Init` context. The parameter is partitioned
+    def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
+        """A context that collects parameters that were partitioned via a
+        :class:`deepspeed.zero.Init` context. The parameters are partitioned
         again upon exit.
 
         Args:
-            param (``torch.nn.Parameter``): The parameter to collect.
+            params (``torch.nn.Parameter``): A single parameter or a list of parameters to collect.
+                It's assumed that all parameters are zero params.
             modifier_rank (int, optional): If specified, this rank's parameter will be
-                broadcasted after the context. This argument is required if ``param`` is
-                modified all processes should have a consistent view of the data. Defaults
+                broadcasted on exit from the context. This argument is required if ``params`` are
+                modified, so that all processes have a consistent view of the data. Defaults
                 to ``None``.
-            fwd_module (``torch.nn.Module``, optional): If specified, ``param`` will be
-                registered as an external parameter of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
+            fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
+                registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
             enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.
 
         Examples
@@ -911,41 +912,79 @@ def forward(self, input):
                                                            fwd_module=self):
                         y = self.layer2(x, self.layer1.weight)
                     return y
+
+
+        #. Pretrained model loading
+
+            .. code-block:: python
+
+                with deepspeed.zero.Init():
+                    model = MyModel()
+
+                state_dict = torch.load(model_path, map_location="cpu")
+
+                def load(module: nn.Module, prefix=""):
+                    # because zero3 puts placeholders in model params, this context
+                    # manager gathers (unpartitions) the params of the current layer, then loads from
+                    # the state dict and then re-partitions them again
+                    with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
+                        if torch.distributed.get_rank() == 0:
+                            module._load_from_state_dict(state_dict, prefix)
+
+                    for name, child in module._modules.items():
+                        if child is not None:
+                            load(child, prefix + name + ".")
+
+                load(model, prefix="")
+
+        If this approach is not used, then the full model will first get copied to each GPU. For models
+        bigger than the memory of a single gpu this method is required.
         """
 
         self.enabled = enabled
         if not enabled:
            return
 
-        # This is a no-op, just return.
-        if not is_zero_param(param):
+        if not isinstance(params, list):
+            params = [params]
+
+        # enable if at least one is zero-param, otherwise a noop
+        if not any(is_zero_param(p) for p in params):
            self.enabled = False
            return
 
-        self.param = param
+        self.params = params
        self.src_rank = None
        if modifier_rank is not None:
-            if self.param.ds_process_group == torch.distributed.group.WORLD:
+            if self.params[0].ds_process_group == torch.distributed.group.WORLD:
                self.src_rank = modifier_rank
            else:
                # A group was specified; convert DP rank to global rank
-                self.src_rank = _get_global_rank(self.param.ds_process_group,
+                self.src_rank = _get_global_rank(self.params[0].ds_process_group,
                                                 modifier_rank)
        self.fwd_module = fwd_module
        if self.fwd_module is not None:
            # is a no-op if already registered
-            register_external_parameter(self.fwd_module, self.param)
+            for p in self.params:
+                register_external_parameter(self.fwd_module, p)
 
    def __enter__(self):
        if not self.enabled:
            return
-        self.param.all_gather()
+        self.params[0].all_gather(param_list=self.params)
 
    def __exit__(self, *exc):
        if not self.enabled:
            return
-        if self.src_rank is not None:
-            torch.distributed.broadcast(self.param,
+        if self.src_rank is None:
+            return
+
+        handles = [
+            torch.distributed.broadcast(p,
                                        self.src_rank,
-                                        group=self.param.ds_process_group)
-        self.param.partition(has_been_updated=self.src_rank is not None)
+                                        group=p.ds_process_group,
+                                        async_op=True) for p in self.params
+        ]
+        for h in handles:
+            h.wait()
+        self.params[0].partition(param_list=self.params, has_been_updated=True)
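
For reference, a minimal usage sketch of the list form this patch enables. It is an illustration only, assuming a ZeRO stage-3 run where torch.distributed is already initialized (e.g. via deepspeed.init_distributed) and the model was built under deepspeed.zero.Init; the Linear layer and the zero-initialization are placeholders, not part of the patch.

    import torch
    import deepspeed

    # assumes a distributed launcher / deepspeed.init_distributed() has already run
    with deepspeed.zero.Init():
        model = torch.nn.Linear(10, 10)

    # gather the weight and bias together, modify them on rank 0 only,
    # then let the context broadcast and re-partition them on exit
    params = list(model.parameters(recurse=False))
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            for p in params:
                torch.nn.init.zeros_(p)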