diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 8ba68270e514..8a2415fab5cb 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -29,7 +29,7 @@ from colossalai.shardformer import ShardConfig from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.moe_tensor.api import is_moe_tensor -from colossalai.zero.low_level import LowLevelOptStrategy, LowLevelZeroOptimizer, MoeZeroStrategy +from colossalai.zero.low_level import LowLevelZeroOptimizer class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer): @@ -68,38 +68,19 @@ def __init__( if use_pipeline: init_pipeline_optimizer(optimizer, model) - assert ( - len(optimizer.param_groups) == 1 - ), "Currently only one parameter group is supported, and we will support multiple groups later." - zero_params = list(filter(lambda x: not is_moe_tensor(x), model.parameters())) - moe_params = list(filter(lambda x: is_moe_tensor(x), model.parameters())) - - optimizer.param_groups.clear() - optimizer.add_param_group({"params": zero_params}) - optimizer.add_param_group({"params": moe_params}) - strategies = [ - LowLevelOptStrategy( - param_group=optimizer.param_groups[0], - process_group=dp_process_group, - reduce_bucket_size=reduce_bucket_size, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - partition_grad=partition_grad, - cpu_offload=cpu_offload, - ), - MoeZeroStrategy( - param_group=optimizer.param_groups[1], - process_group=moe_extra_dp_process_group, - reduce_bucket_size=reduce_bucket_size, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - partition_grad=partition_grad, - cpu_offload=cpu_offload, - ), - ] + pg_param_list = { + dp_process_group: [], + moe_extra_dp_process_group: [], + } + for param in model.parameters(): + if is_moe_tensor(param): + pg_param_list[moe_extra_dp_process_group].append(param) + else: + pg_param_list[dp_process_group].append(param) + super().__init__( optimizer=optimizer, - group_strategies=strategies, + pg_param_list=pg_param_list, initial_scale=initial_scale, min_scale=min_scale, growth_factor=growth_factor, @@ -109,6 +90,11 @@ def __init__( max_scale=max_scale, clip_grad_norm=clip_grad_norm, verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, forced_dtype=forced_dtype, ) diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py index 7e4702dfd38c..270a6a6a4786 100644 --- a/colossalai/zero/low_level/__init__.py +++ b/colossalai/zero/low_level/__init__.py @@ -1,4 +1,3 @@ from .low_level_optim import LowLevelZeroOptimizer -from .low_level_strategy import LowLevelOptStrategy, LowLevelOptStrategyBase, MoeZeroStrategy -__all__ = ["LowLevelZeroOptimizer", "LowLevelOptStrategy", "MoeZeroStrategy", "LowLevelOptStrategyBase"] +__all__ = ["LowLevelZeroOptimizer"] diff --git a/colossalai/zero/low_level/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py index 427973772f9c..07f6cdb2d701 100644 --- a/colossalai/zero/low_level/bookkeeping/__init__.py +++ b/colossalai/zero/low_level/bookkeeping/__init__.py @@ -1,6 +1,5 @@ from .bucket_store import BucketStore from .gradient_store import GradientStore -from .parameter_store import ParameterStore from 
.tensor_bucket import TensorBucket -__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"] +__all__ = ["GradientStore", "BucketStore", "TensorBucket"] diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index d6898f74e7bd..5b1776062c48 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -5,6 +5,8 @@ from torch._utils import _flatten_dense_tensors from torch.distributed import ProcessGroup +from colossalai.accelerator.api import get_accelerator + from .base_store import BaseStore @@ -13,10 +15,13 @@ def __init__( self, torch_pg: ProcessGroup, reduce_bucket_size: int, + overlap_comm: bool = False, ): super().__init__(torch_pg) self.reduce_bucket_size = reduce_bucket_size self.reset_all() + if overlap_comm: + self.comm_stream = get_accelerator().Stream() def reset_all(self) -> None: # init diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index fc28b77959c7..e8c469146eba 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -6,7 +6,7 @@ class GradientStore(BaseStore): - def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True): + def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) """ self._grads_of_params mapping the parameter and its gradient slices @@ -20,8 +20,6 @@ def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool self._grads_of_params = dict() # stage 2 self._partition_grads = partition_grad - # grad accumulation - self.require_grad_sync = require_grad_sync self._working_index = 0 if partition_grad else self._local_rank # for zero2, it's `param_id: [grad_local_rank]` self.grad_to_param_mapping = dict() @@ -107,8 +105,7 @@ def get_working_grad_by_param_id(self, param_id) -> Tensor: for group in self._grads_of_params.values(): if param_id in group.keys(): return group[param_id][self._working_index] - - raise KeyError(f"Working gradient for param_id {param_id} not found.") + return None def reset_grads_by_group_id(self, group_id: int): self._grads_of_params[group_id] = dict() diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py deleted file mode 100644 index c03231f5fd1f..000000000000 --- a/colossalai/zero/low_level/bookkeeping/parameter_store.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Dict - -from torch import Tensor -from torch.distributed import ProcessGroup - -from .base_store import BaseStore - - -class ParameterStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): - super().__init__(torch_pg) - - # record the padding size of each param - self._padding_map = dict() - - # mapping working param and master param - self.master_to_working_param = dict() - self.working_to_master_param = dict() - - def record_param_padding_size(self, param: Tensor, padding_size: int): - """Record the padding size of a param - - Args: - param (Tensor): The parameter - padding_size (int): The padding size of the parameter - """ - - self._padding_map[id(param)] = padding_size - - def get_param_padding_size(self, param: Tensor) -> int: - """Return the padding size of the parameter - - Args: - param (Tensor): The parameter - - Returns: - int: the padding size of the parameter - """ - - return 
self._padding_map[id(param)] - - def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): - """Mapping master parameter and working parameter - - Args: - master_param (Tensor): The parameter copy in optimizer - working_param (Tensor): The parameter of the model - """ - - self.master_to_working_param[id(master_param)] = working_param - self.working_to_master_param[id(working_param)] = master_param - - def get_padding_map(self) -> Dict[int, Tensor]: - """Return the padding map - - Returns: - Dict[int, Tensor]: The padding map - """ - - return self._padding_map diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index bcbc7561dcd6..bcfdb44478d3 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -1,11 +1,15 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch import copy -from collections import defaultdict from contextlib import contextmanager -from typing import Dict, List, Optional +from functools import partial +from typing import Dict, Iterator, List, Optional, Tuple +from weakref import proxy import torch +import torch.distributed as dist import torch.nn as nn +from torch import Tensor, inf +from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.accelerator import get_accelerator @@ -16,15 +20,16 @@ ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, LowLevelOptStrategyBase -from ._utils import calculate_global_norm_from_list, has_inf_or_nan +from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor +from .bookkeeping import BucketStore, GradientStore class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): def __init__( self, - group_strategies: List[LowLevelOptStrategyBase], + num_working_param_groups: int, + grad_stores: Dict[nn.Parameter, GradientStore], initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, @@ -34,23 +39,33 @@ def __init__( max_scale: float = 2**32, ) -> None: super().__init__( - initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, ) - self.group_strategies = group_strategies + self.num_working_param_groups = num_working_param_groups + self.grad_stores = grad_stores def check_local_overflow(self) -> bool: - for strategy in self.group_strategies: - for avg_grad in strategy.working_grads: - if avg_grad is not None and has_inf_or_nan(avg_grad): - return True + for store in self.grad_stores.values(): + for group_id in range(self.num_working_param_groups): + for avg_grad in store.get_working_grads_by_group_id(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + return True return False class LowLevelZeroOptimizer(OptimizerWrapper): + """Optimizer used for ZeRO-1 and ZeRO-2.""" + def __init__( self, optimizer: Optimizer, - group_strategies: List[LowLevelOptStrategyBase] = None, + pg_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2.0, @@ -60,17 +75,56 @@ def __init__( max_scale: int = 2**24, clip_grad_norm: float = 0.0, # grad clipping verbose: bool = False, + 
reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload forced_dtype: Optional[torch.dtype] = None, - **strategy_kwargs, + master_weights: bool = True, # master weights ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) + self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() self._verbose = verbose + if pg_param_list is None: + pg_param_list = {dist.group.WORLD: []} + for group in self.optim.param_groups: + pg_param_list[dist.group.WORLD].extend(group["params"]) + + self.pg_param_list = pg_param_list + param_to_pg = {} + for grp, param_list in pg_param_list.items(): + for p in param_list: + assert isinstance(p, nn.Parameter) + param_to_pg[p] = grp + self.param_to_pg = param_to_pg + + # stage 2 + self._partition_grads = partition_grad + + self._cpu_offload = cpu_offload + + # grad accumulation + self.require_grad_sync = True + + # working and master params for mixed precision training + self._working_param_groups = dict() + self._master_param_groups_of_current_rank = dict() + + # communication params + self._overlap_communication = overlap_communication + self._reduce_bucket_size = reduce_bucket_size + self._communication_dtype = communication_dtype + # gradient clipping self._clip_grad_norm = clip_grad_norm + # master weights copy + self._master_weights = master_weights + if forced_dtype: for group in self.optim.param_groups: group_params = group["params"] @@ -81,23 +135,62 @@ def __init__( # check argument conflict self._sanity_checks() - if len(self.optim.param_groups) == 1 and group_strategies is None: - group_strategies = [LowLevelOptStrategy(param_group=self.optim.param_groups[0], **strategy_kwargs)] - elif len(self.optim.param_groups) > 1 and group_strategies is None: - raise ValueError("group_strategies must be provided when the optimizer has multiple param groups") - - self.workingparam2strategy: Dict[torch.nn.Parameter, LowLevelOptStrategyBase] = {} - for grp, strategy in zip(self.optim.param_groups, group_strategies): - assert grp["params"] is strategy.param_group["params"], "param groups should be in the same order" - for param in strategy.working_param_group: - self.workingparam2strategy[param] = strategy - self._group_strategies = group_strategies + self.require_grad_sync = True + + # ParameterStore will manage the tensor buffers used for zero + # it will not manage the tensors used by mixed precision training + + # record the padding size of each param + self._padding_map = dict() + + # mapping working param and master param + self.master_to_working_param = dict() + self.working_to_master_param = dict() + + # NOTE need to guarantee the order of process groups is the same across all ranks + self.grad_stores = {pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_param_list} + # param id to grad store, have to use id(param) as key since it is used in stores + self.pid2grad_store = {id(param): self.grad_stores[param_to_pg[param]] for param in param_to_pg} + self.bucket_stores = { + pg: BucketStore(pg, reduce_bucket_size, overlap_comm=self._overlap_communication) + for pg in self.pg_param_list + } + # param id to bucket store, have to use id(param) as key since it is used in stores + self.pid2bucket_store = {id(param): self.bucket_stores[param_to_pg[param]] for param in param_to_pg} + + # iterate over the param group in the
optimizer + # partition these param groups for data parallel training + # and add buffers to parameter store for future access + for group_id, param_group in enumerate(self.optim.param_groups): + group_params = list() + for param in param_group["params"]: + if param.requires_grad: + group_params.append(param) + + # add the working params to working_param_groups for bookkeeping + self._working_param_groups[group_id] = group_params + + master_param_current_rank = self._create_master_param_current_rank(group_params) + self._master_param_groups_of_current_rank[group_id] = master_param_current_rank + + # need to replace the params in the `params` field in the optimizer + # so that when the optimizer calls step(), it only updates the tensors + # managed by this data parallel rank + param_group["params"] = master_param_current_rank + + # reduction hook is only used if overlapping communication + # or stage 2 is used + # if it is stage 1 without overlapping, no hook will be attached + self.grad_handles = [] + if self._overlap_communication or self._partition_grads: + self._attach_reduction_hook() # initialize mixed precision mixin self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None if self._dtype is torch.float16: self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin( - self._group_strategies, + self.num_param_groups, + self.grad_stores, initial_scale=initial_scale, min_scale=min_scale, growth_factor=growth_factor, @@ -109,54 +202,264 @@ def __init__( elif self._dtype is torch.bfloat16: self.mixed_precision_mixin = BF16MixedPrecisionMixin() + def __del__(self): + for hook in self.grad_handles: + hook.remove() + + @property + def dtype(self): + return self._dtype + + @property + def num_param_groups(self): + return len(self._working_param_groups) + def _sanity_checks(self): assert get_accelerator().name in ["cuda", "npu"], "device is required" - inv = defaultdict(list) for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: - inv[param].append(param_group) - assert ( - param.dtype == self._dtype - ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False: + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + + def _create_master_param_current_rank(self, param_list): + # split each param evenly by world size + params_current_rank = [] + device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() + + for param in param_list: + padding_size = ( + self.pid2bucket_store[id(param)].world_size + - param.numel() % self.pid2bucket_store[id(param)].world_size + ) % self.pid2bucket_store[id(param)].world_size + self.record_param_padding_size(param, padding_size) + + with torch.no_grad(): + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + # reset working params' ptr when no master weights + if self._master_weights == False: + param.data = padding_param[: param.numel()].view(param.shape) + else: + padding_param = param.data.view(-1) + + splited_params = padding_param.split( + padding_param.numel() // self.pid2bucket_store[id(param)].world_size + ) + splited_params = splited_params[self.pid2bucket_store[id(param)].local_rank] + + # use fp32 when master_weights is True + if self._master_weights is True: + splited_param_current_rank = 
splited_params.detach().float().to(device) + else: + splited_param_current_rank = splited_params + + params_current_rank.append(splited_param_current_rank) + self.link_master_and_working_param(splited_param_current_rank, param) + + return params_current_rank + + ########################### + # Backward Reduction Hook # + ########################### + + def _attach_reduction_hook(self): + # we iterate over the working params + # on each param, we register a hook to its AccumulateGrad object + self_weakref = proxy(self) + + def _grad_handler(param, group_id): + # if run with no_sync context, would not sync grad when backward + if self_weakref.require_grad_sync: + self_weakref._add_to_bucket(param, group_id) + + for group_id in range(self.num_param_groups): + param_group = self._working_param_groups[group_id] + for param in param_group: + if param.requires_grad: + self.grad_handles.append( + param.register_post_accumulate_grad_hook(partial(_grad_handler, group_id=group_id)) + ) + + ####################### + # Reduction Functions # + ####################### + + def _run_reduction(self): + for bucket_store in self.bucket_stores.values(): + if bucket_store.num_elements_in_bucket() <= 0: + continue + + bucket_store.build_grad_in_bucket() + + flat_grads = bucket_store.get_flatten_grad() + flat_grads /= bucket_store.world_size + + # ready to add other tensors to bucket + bucket_store.reset_num_elements_in_bucket() + + if self._overlap_communication: + stream = bucket_store.comm_stream + # in case of the memory being reused in the default stream + flat_grads.record_stream(stream) + # waiting for ops in the default stream finishing + stream.wait_stream(get_accelerator().current_stream()) + else: + stream = get_accelerator().current_stream() + + with get_accelerator().stream(stream): + group_id = bucket_store.current_group_id + + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) + + if not self._partition_grads: + dist.all_reduce(flat_grads, group=bucket_store.torch_pg) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.world_size) + grad_in_bucket = bucket_store.get_grad() + self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id) + else: + flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg) + + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) + + grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank] + self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1) + + bucket_store.reset() + + def _update_unpartitoned_grad( + self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int + ) -> None: + for rank, grad_list in enumerate(origin_grad_list): + sync_tensor(flat_grad_list[rank], grad_list) + for grad in grad_list: + param_id = bucket_store.get_param_id_of_grad(grad) + self._add_grad(grad, bucket_store.world_size, group_id, param_id, rank) - for _, grps in inv.items(): - assert ( - len(grps) == 1 - ), "Parameters should only appear in one group, since we assume that each strategy only manages one param group" + def _update_partitoned_grad( + self, + bucket_store: 
BucketStore, + origin_grad_list: List, + flat_grad: torch.Tensor, + group_id: int, + partition_num: int, + ) -> None: + sync_tensor(flat_grad, origin_grad_list) + for grad in origin_grad_list: + param_id = bucket_store.get_param_id_of_grad(grad) + self._add_grad(grad, partition_num, group_id, param_id) + + def _add_grad( + self, + grad: torch.Tensor, + partition_num: int, + group_id: int, + param_id: int, + rank: int = 0, + ) -> None: + if len(self.pid2grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: + self.pid2grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id) + else: + self.pid2grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id) + + def _add_to_bucket(self, param, group_id): + param_size = param.numel() + + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # or got a grad of param from another group + # after reduction, the bucket will be empty + if ( + self.pid2bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size + or group_id != self.pid2bucket_store[id(param)].current_group_id + ): + self._run_reduction() + + padding_size = self.get_param_padding_size(param) + self.pid2bucket_store[id(param)].add_param_grad(group_id, param, padding_size) + + ################################ + # torch.optim.Optimizer methods + ################################ def backward(self, loss, retain_graph=False): - for strategy in self._group_strategies: - strategy.pre_backward(loss, retain_graph) + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and no_sync are not compatible" if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) loss.backward(retain_graph=retain_graph) - for strategy in self._group_strategies: - strategy.post_backward() + if not self.require_grad_sync: + return - # another way of doing this is to reassign tensor.grad, however this won't apply for zero-2 - # since the shape doesn't match - def get_param_grad(self, working_param): - strategy = self.workingparam2strategy[working_param] - return strategy.get_param_grad(working_param) + self._reduce_grad(self._partition_grads) + + # clear reduced grads + if self._overlap_communication: + get_accelerator().synchronize() + + def backward_by_grad(self, tensor, grad): + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" - def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): - # compute combined scale factor for this group - div_scale = 1.0 if self.mixed_precision_mixin is not None: - div_scale = self.mixed_precision_mixin.get_grad_div_scale() + grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) + torch.autograd.backward(tensor, grad) - if self._clip_grad_norm > 0.0: - # norm is in fact norm*scale - clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm - if clip > 1: - div_scale = clip * div_scale + if not self.require_grad_sync: + return + self._reduce_grad(self._partition_grads) - for grad in grad_groups_flat: - grad.data.mul_(1.0 / div_scale) + # clear reduced grads + if self._overlap_communication: + get_accelerator().synchronize() + + def zero_bucket_stores(self): + for bucket_store in self.bucket_stores.values(): + bucket_store.reset_all() + + def zero_grad_stores(self): + for grad_store in self.grad_stores.values(): + 
grad_store.reset_all_gradients() + + def zero_grad(self, set_to_none=True): + """ + Set parameter gradients to zero. If set_to_none = True, gradient + will be set to None to save memory. + + :param set_to_none: Whether set the gradient to None. Default value is True. + :type set_to_none: bool + """ + if self.mixed_precision_mixin is not None: + self.mixed_precision_mixin.pre_zero_grad() + for _, param_group in self._working_param_groups.items(): + for param in param_group: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + param.grad.detach() + param.grad.zero_() + self.zero_grad_stores() + self.zero_bucket_stores() + + #################### + # Update Parameter # + #################### def step(self, closure=None): assert closure is None, "closure is not supported by step()" @@ -166,19 +469,52 @@ def step(self, closure=None): if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): if self._verbose: self._logger.info(f"Found overflow. Skip step") - for strategy in self._group_strategies: - strategy.zero_working_grad() - strategy.zero_grad() + self.zero_grad() return - # TODO @botbw can be further refactored + # record all grads for unscale and clip grad_partition_groups = [] norm_groups = [] - for strategy in self._group_strategies: - strategy.pre_step() - grad_partition_groups.extend(strategy.working_grads) - norm_groups.append(strategy.get_grad_norm()) - strategy.zero_working_grad() + + # sometimes not all params are 'really' working + # for instance, when layer drop, the dropped layer has no grad + # and should not be updated + real_working_params = dict() + real_master_params = dict() + + for group_id in range(self.num_param_groups): + master_params = self._master_param_groups_of_current_rank[group_id] + working_params = self._working_param_groups[group_id] + real_working_params[group_id] = [] + real_master_params[group_id] = [] + working_grads = [] + for working_param, master_param in zip(working_params, master_params): + # if a working param requires grad and has no grad + # it is not 'really' working, e.g. 
the droped layer + # else the splited grad should be attached to the splited param + grad_store = self.pid2grad_store[id(working_param)] + grads = grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) + grad_index = 0 if self._partition_grads else grad_store.local_rank + if len(grads) > 0: + real_working_params[group_id].append(working_param) + grad = grads[grad_index] + # no need to copy fp32 grad if master_weights is False + if self._master_weights: + grad = grad.to(master_param.dtype).to(master_param.device) + master_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(master_param) + + # compute norm + norm_group = 0 + for grad_store in self.grad_stores.values(): + working_grads = grad_store.get_working_grads_by_group_id(group_id) + norm_group += self._compute_grad_norm(pg=grad_store.torch_pg, gradients=working_grads) + + norm_groups.append(norm_group) + + # update the params in the optimizer + self.optim.param_groups[group_id]["params"] = real_master_params[group_id] # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) @@ -187,28 +523,130 @@ def step(self, closure=None): # update the parameters self.optim.step() - for strategy in self._group_strategies: - strategy.post_step() + # release the grad + grad_partition_groups = [] + for group_id in range(self.num_param_groups): + release_param_grad(self._master_param_groups_of_current_rank[group_id]) + + # update working partition updated by the current rank + device = get_accelerator().get_current_device() + for group_id in range(self.num_param_groups): + master_working_param = self.optim.param_groups[group_id]["params"] + for idx, master_param in enumerate(master_working_param): + working_param = real_working_params[group_id][idx] + pg = self.param_to_pg[working_param] + all_splited_param = [ + torch.zeros(master_param.shape, device=device, dtype=self._dtype) for _ in range(pg.size()) + ] + dist.all_gather( + all_splited_param, + master_param.to(device).to(self._dtype), + group=pg, + ) + working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) + self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] + + def _compute_grad_norm(self, pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. - @property - def require_grad_sync(self) -> bool: - flag_set = set() - for strategy in self._group_strategies: - flag_set.add(strategy.require_grad_sync) - assert len(flag_set) == 1, "require_grad_sync should be the same for all strategies" - return flag_set.pop() + Args: + gradients (List[Tensor]): The gradients to compute norm + norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. 
+ + Returns: + float: The total norm of given gradients + """ + + if len(gradients) == 0: + return 0.0 + + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + total_norm_cuda = torch.tensor( + [float(total_norm)], + device=get_accelerator().get_current_device(), + dtype=torch.float, + ) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=pg) + total_norm = total_norm_cuda.item() + + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + total_norm_exponentiated += grad_norm_exponentiated + + # Sum across all model parallel GPUs. + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], + device=get_accelerator().get_current_device(), + dtype=torch.float, + ) + torch.distributed.all_reduce( + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=pg, + ) + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + ############################# + # Mixed Precision Utilities # + ############################# + + def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): + # compute combined scale factor for this group + div_scale = 1.0 + if self.mixed_precision_mixin is not None: + div_scale = self.mixed_precision_mixin.get_grad_div_scale() + + if self._clip_grad_norm > 0.0: + # norm is in fact norm*scale + clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm + if clip > 1: + div_scale = clip * div_scale + + for grad in grad_groups_flat: + grad.data.mul_(1.0 / div_scale) + + ############################ + # Gradient Synchronization # + ############################ + + # this method is used to sync gradient manually + def _sync_grad(self): + for group_id in range(self.num_param_groups): + param_group = self._working_param_groups[group_id] + for param in param_group: + if param.requires_grad and param.grad is not None: + self._add_to_bucket(param, group_id) + + self._run_reduction() + + def _reduce_grad(self, partition_grad): + # if not overlapping communication (no reduction hook is attached) when zero1 + # we need to manually reduce these gradients + if not partition_grad and not self._overlap_communication: + self._sync_grad() + else: + self._run_reduction() # this context comes from pytorch DDP @contextmanager def no_sync(self): old_require_grad_sync = self.require_grad_sync - for strategy in self._group_strategies: - strategy.require_grad_sync = False + self.require_grad_sync = False try: yield finally: - for strategy in self._group_strategies: - strategy.require_grad_sync = old_require_grad_sync + self.require_grad_sync = old_require_grad_sync + + ############## + # State Dict # + ############## def _pack_state(self, state: Dict) -> Dict: # comes from pytorch optimizer.state_dict() @@ -237,13 +675,24 @@ def state_dict(self) -> Dict: Returns: Dict: the pytorch form state_dict """ - state_dict = {} - for strategy in self._group_strategies: - partial_dict = strategy.state_dict(self.optim) - assert len(set(partial_dict.keys()) & set(state_dict.keys())) == 0, "state_dict key conflict" - state_dict.update(partial_dict) - state_dict = self._pack_state(state_dict) - return state_dict + zero_state = dict() + device = get_accelerator().get_current_device() + for param, state in self.optim.state.items(): + zero_state[param] = copy.deepcopy(state) + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "step": 
+ working_param = self.master_to_working_param[id(param)] + pg = self.param_to_pg[working_param] + gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] + dist.all_gather(gather_tensor, v.to(device), group=pg) + param_state = ( + torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) + zero_state[param][k] = param_state + + states_dict = self._pack_state(zero_state) + + return states_dict def load_state_dict(self, state_dict: Dict): """Load state dict, requires the state_dict be the pytorch form @@ -252,12 +701,75 @@ def load_state_dict(self, state_dict: Dict): state_dict (dict): A pytorch form state_dict """ zero_state_dict = copy.deepcopy(state_dict) - # cannot load state_dict into torch.optim.Optimizer strategy by strategy - # due to torch internal param group assertion - # thus load first and then scatter + idx2master = {} + cnt = 0 + for param_group in self.optim.param_groups: + for param in param_group["params"]: + idx2master[cnt] = param + cnt += 1 + for param_idx, state in zero_state_dict["state"].items(): + pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]] + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "step": + padding_size = (pg.size() - v.numel() % pg.size()) % pg.size() + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + v_list = v.split(v.numel() // pg.size()) + zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone() + self.optim.load_state_dict(zero_state_dict) - for strategy in self._group_strategies: - strategy.scatter_optim_state(self.optim.state) + + def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]: + """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. + Only include the 'state' in state_dict. + + Args: + max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024. 
+ + Yields: + Iterator[OrderedDict]: A generator of state dict shard + """ + ret_block = dict() + ret_block_size = 0 + + device = get_accelerator().get_current_device() + local_states = self.optim.state_dict()["state"] + + idx2master = {} + cnt = 0 + for param_group in self.optim.param_groups: + for param in param_group["params"]: + idx2master[cnt] = param + cnt += 1 + for param_idx, states in local_states.items(): + current_block_size = 0 + current_block = copy.deepcopy(states) + + master_param = idx2master[param_idx] + working_param = self.master_to_working_param[id(master_param)] + pg = self.param_to_pg[working_param] + + for k, v in states.items(): + if isinstance(v, torch.Tensor) and k != "step": + state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] + dist.all_gather(state_tensor, v.to(device), group=pg) + state_tensor = ( + torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) + current_block_size += state_tensor.numel() + current_block[k] = state_tensor + + if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0: + yield ret_block, ret_block_size + ret_block = dict() + ret_block_size = 0 + + ret_block[param_idx] = current_block + ret_block_size += current_block_size + + yield ret_block, ret_block_size def update_master_params(self, model: nn.Module) -> None: """Update master params from working params @@ -265,31 +777,74 @@ def update_master_params(self, model: nn.Module) -> None: Args: model (nn.Module): The model to update master params """ - for working_param in model.parameters(): - strategy = self.workingparam2strategy[working_param] - master_param = strategy.working2master(working_param=working_param) - strategy.update_master_param(master_param) + for p in model.parameters(): + p_id = id(p) + pg = self.param_to_pg[p] + if p_id in self.working_to_master_param: + master_param = self.working_to_master_param[p_id] + padding_size = self.get_param_padding_size(p) + working_param = p.data.view(-1) + if padding_size > 0: + working_param = torch.nn.functional.pad(working_param, [0, padding_size]) + master_param.copy_(working_param.chunk(pg.size())[pg.rank()]) def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: - mapp = {} - for strategy in self._group_strategies: - partial_map = strategy.working2master_map - assert len(set(partial_map.keys()) & set(mapp.keys())) == 0, "working_to_master_map key conflict" - mapp.update(partial_map) - return mapp + return self.working_to_master_param def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: - mapp = {} - for strategy in self._group_strategies: - partial_map = strategy.master2working_map - assert len(set(partial_map.keys()) & set(mapp.keys())) == 0, "master_to_working_map key conflict" - mapp.update(partial_map) - return mapp + return self.master_to_working_param def get_param_padding_map(self) -> Dict[int, torch.Tensor]: - mapp = {} - for strategy in self._group_strategies: - partial_map = strategy.padding_map - assert len(set(partial_map.keys()) & set(mapp.keys())) == 0, "param_padding_map key conflict" - mapp.update(partial_map) - return mapp + return self._padding_map + + def record_param_padding_size(self, param: Tensor, padding_size: int): + """Record the padding size of a param + + Args: + param (Tensor): The parameter + padding_size (int): The padding size of the parameter + """ + + self._padding_map[id(param)] = padding_size + + def get_param_padding_size(self, param: Tensor) -> int: + """Return the padding 
size of the parameter + + Args: + param (Tensor): The parameter + + Returns: + int: the padding size of the parameter + """ + + return self._padding_map[id(param)] + + def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): + """Mapping master parameter and working parameter + + Args: + master_param (Tensor): The parameter copy in optimizer + working_param (Tensor): The parameter of the model + """ + + self.master_to_working_param[id(master_param)] = working_param + self.working_to_master_param[id(working_param)] = master_param + + def get_padding_map(self) -> Dict[int, Tensor]: + """Return the padding map + + Returns: + Dict[int, Tensor]: The padding map + """ + + return self._padding_map + + def get_param_grad(self, working_param: nn.Parameter) -> Tensor: + grad_store = self.pid2grad_store[id(working_param)] + partial_grad = grad_store.get_working_grad_by_param_id(id(working_param)) + if partial_grad is None: + return None + tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)] + dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg) + grad_flat = torch.cat(tensor_list, dim=0) + return grad_flat[: working_param.numel()].reshape_as(working_param) diff --git a/colossalai/zero/low_level/low_level_strategy.py b/colossalai/zero/low_level/low_level_strategy.py deleted file mode 100644 index c8be5e0f7084..000000000000 --- a/colossalai/zero/low_level/low_level_strategy.py +++ /dev/null @@ -1,570 +0,0 @@ -# this code is inspired by the DeepSpeed library and implemented with our own design from scratch -import weakref -from abc import ABC, abstractmethod -from copy import deepcopy -from functools import partial -from typing import Any, Dict, List, Optional - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup - -from colossalai.accelerator import get_accelerator -from colossalai.tensor.moe_tensor.api import is_moe_tensor - -from ._utils import flatten, release_param_grad, sync_tensor -from .bookkeeping import BucketStore, GradientStore, ParameterStore - - -class LowLevelOptStrategyBase(ABC): - """ - Base class for low-level optimization strategies, this is to reduce the - coupling between different param group and corresponding process group - - This class contains necessary stores/data for optimizer: - 1. params bucket - 2. grads bucket - 3. 
reduce buckets - and necessary methods to do communication - """ - - # the store before refactoring supports multiple param groups - # but currently only one is used - DEFAULT_STORE_GROUP_ID = 0 - - def __init__( - self, - param_group, - dp_process_group, - master_weights, - partition_grad, - cpu_offload, - overlap_communication, - reduce_bucket_size, - communication_dtype, - ): - # param_group that current strategy is working on - self.param_group = param_group - self._dtype = self.param_group["params"][0].dtype - - if dp_process_group is None: # if dp_process_group is none, convert to default explicitly - dp_process_group = dist.group.WORLD - - self.dp_process_group = dp_process_group - - # if dp_process_group is none, will use the default one - self._local_rank = dist.get_rank(group=self.dp_process_group) - self._world_size = dist.get_world_size(group=self.dp_process_group) - - # master weights copy - self._master_weights = master_weights - - self._cpu_offload = cpu_offload - - # stage 2 - self._partition_grad = partition_grad - - # ParameterStore will manage the tensor buffers used for zero - # it will not manage the tensors used by mixed precision training - self._param_store = ParameterStore(dp_process_group) - self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad) - self._bucket_store = BucketStore(dp_process_group, reduce_bucket_size=reduce_bucket_size) - - # working and master params for mixed precision training - group_params = [] - for param in param_group["params"]: - if param.requires_grad: - group_params.append(param) - master_param_current_rank = self._create_master_param_current_rank(group_params) - param_group["params"] = master_param_current_rank - self.working_param_group: List[torch.Tensor] = group_params - self.master_param_group: List[torch.Tensor] = master_param_current_rank - - # by default this shouldn't be manipulate - self.require_grad_sync = True - - # communication params - self._overlap_communication = overlap_communication - self._communication_dtype = communication_dtype - - # initialize communication stream for - # communication-computation overlapping - if self._overlap_communication: - self._comm_stream = get_accelerator().Stream() - - # reduction hook is only used if overlapping communication - # or stage 2 is used - # if it is stage 1 without overlapping, no hook will be attached - self.grad_handles = [] - if self._overlap_communication or self._partition_grad: - self_weak_proxy = weakref.proxy(self) - - def _grad_handler(grad, param): - # if run with no_sync context, would not sync grad when backward - if self_weak_proxy.require_grad_sync: - self_weak_proxy._add_to_bucket(param) - return grad - - # we iterate over the working params - # on each param, we register a hook to its AccumulateGrad object - param_group = self.working_param_group - for param in param_group: - if param.requires_grad: - self.grad_handles.append( - param.register_post_accumulate_grad_hook(partial(_grad_handler, param=param)) - ) - - def __del__(self): - for handle in self.grad_handles: - handle.remove() - - def _create_master_param_current_rank(self, param_list): - # split each param evenly by world size - params_current_rank = [] - device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() - - for param in param_list: - padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size - self._param_store.record_param_padding_size(param, padding_size) - - with torch.no_grad(): - if padding_size > 0: - 
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) - # reset working params' ptr when no master weights - if self._master_weights == False: - param.data = padding_param[: param.numel()].view(param.shape) - else: - padding_param = param.data.view(-1) - - splited_params = padding_param.split(padding_param.numel() // self._world_size) - splited_params = splited_params[self._local_rank] - - # use fp32 when master_weights is True - if self._master_weights is True: - splited_param_current_rank = splited_params.detach().float().to(device) - else: - splited_param_current_rank = splited_params - - params_current_rank.append(splited_param_current_rank) - self._param_store.link_master_and_working_param(splited_param_current_rank, param) - - return params_current_rank - - def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None: - for rank, grad_list in enumerate(origin_grad_list): - sync_tensor(flat_grad_list[rank], grad_list) - for grad in grad_list: - param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(grad, self._world_size, group_id, param_id, rank) - - def _update_partitoned_grad( - self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int - ) -> None: - sync_tensor(flat_grad, origin_grad_list) - for grad in origin_grad_list: - param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(grad, partition_num, group_id, param_id) - - def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None: - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) - - def _add_to_bucket(self, param): - param_size = param.numel() - - # check if the bucket is full - # if full, will reduce the grads already in the bucket - # or got a grad of param from another group - # after reduction, the bucket will be empty - if ( - self._bucket_store.num_elements_in_bucket() + param_size > self._bucket_store.reduce_bucket_size - or LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID != self._bucket_store.current_group_id - ): - self._run_reduction() - - padding_size = self._param_store.get_param_padding_size(param) - self._bucket_store.add_param_grad(LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, param, padding_size) - - def _reduce_grad(self): - # if not overlapping communication (no reduction hook is attached) when zero1 - # we need to manually reduce these gradients - if not self._partition_grad and not self._overlap_communication: - self._sync_grad() - else: - self._run_reduction() - - def _sync_grad(self): - param_group = self.working_param_group - for param in param_group: - if param.requires_grad and param.grad is not None: - self._add_to_bucket(param) - - self._run_reduction() - - def _run_reduction(self): - if self._bucket_store.num_elements_in_bucket() <= 0: - return - - self._bucket_store.build_grad_in_bucket() - - flat_grads = self._bucket_store.get_flatten_grad() - flat_grads /= self._world_size - - # ready to add other tensors to bucket - self._bucket_store.reset_num_elements_in_bucket() - - if self._overlap_communication: - stream = self._comm_stream - # in case of the memory being reused in the default stream - flat_grads.record_stream(stream) - # waiting for ops in the default stream finishing - 
stream.wait_stream(get_accelerator().current_stream()) - else: - stream = get_accelerator().current_stream() - - with get_accelerator().stream(stream): - group_id = self._bucket_store.current_group_id - assert group_id == LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, "after refactoring, group_id should be 0" - - grad_dtype = flat_grads.dtype - if self._communication_dtype is not None: - flat_grads = flat_grads.to(self._communication_dtype) - - if not self._partition_grad: - dist.all_reduce(flat_grads, group=self.dp_process_group) - if flat_grads.dtype != grad_dtype: - flat_grads = flat_grads.to(grad_dtype) - - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) - grad_in_bucket = self._bucket_store.get_grad() - self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id) - else: - flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) - recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_process_group) - - if recieved_grad.dtype != grad_dtype: - recieved_grad = recieved_grad.to(grad_dtype) - - grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] - self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1) - - self._bucket_store.reset() - - ###################################################################### - # interfaces for child classes to manipulate the params, grads and buckets (and their stores) - @property - def master_params(self): - return self.master_param_group - - @property - def working_params(self): - return self.working_param_group - - @property - def working_grads(self): - return self._grad_store.get_working_grads_by_group_id(LowLevelOptStrategyBase.DEFAULT_STORE_GROUP_ID) - - @property - def master2working_map(self): - return self._param_store.master_to_working_param - - @property - def working2master_map(self): - return self._param_store.working_to_master_param - - @property - def padding_map(self): - return self._param_store._padding_map - - def master2working(self, master_param): - return self._param_store.master_to_working_param[id(master_param)] - - def working2master(self, working_param): - return self._param_store.working_to_master_param[id(working_param)] - - def get_param_padding_size(self, param): - return self._param_store.get_param_padding_size(param) - - def get_working_param_grads(self, working_param): - return self._grad_store.get_partitioned_gradients_by_param_id( - LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, id(working_param) - ) - - def state_dict(self, optim: torch.optim.Optimizer) -> Dict: - zero_state = {} - device = get_accelerator().get_current_device() - for working_param, master_param in zip(self.working_param_group, self.master_param_group): - zero_state[master_param] = deepcopy(optim.state[master_param]) - for k, v in zero_state[master_param].items(): - if isinstance(v, torch.Tensor) and k != "step": - gather_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) - ] - dist.all_gather(gather_tensor, v, group=self.dp_process_group) - param_state = ( - torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() - ) - zero_state[master_param][k] = param_state - return zero_state - - def update_master_param(self, master_param): - working_param = self.master2working(master_param) - padding_size = self.get_param_padding_size(working_param) - working_param = working_param.data.view(-1) 
-            if padding_size > 0:
-                working_param = torch.nn.functional.pad(working_param, [0, padding_size])
-            master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
-
-    def get_grad_norm(self, norm_type: int = 2) -> float:
-        r"""
-        Compute and return the gradient norm for gradient clipping.
-
-        Args:
-            gradients (List[Tensor]): The gradients to compute norm
-            norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
-
-        Returns:
-            float: The total norm of given gradients
-        """
-        gradients = self.working_grads
-
-        norm_type = float(norm_type)
-        if norm_type == torch.inf:
-            total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor(
-                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
-            )
-            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_process_group)
-            total_norm = total_norm_cuda.item()
-
-        else:
-            total_norm_exponentiated = 0.0
-            for grad in gradients:
-                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
-                total_norm_exponentiated += grad_norm_exponentiated
-
-            # Sum across all model parallel GPUs.
-            total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
-            )
-            torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_process_group
-            )
-            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
-
-        return total_norm
-
-    def zero_grad(self, set_to_none=True):
-        param_group = self.working_param_group
-        for param in param_group:
-            if set_to_none:
-                param.grad = None
-            else:
-                if param.grad is not None:
-                    param.grad.detach()
-                    param.grad.zero_()
-
-    def zero_working_grad(self):
-        self._grad_store.reset_grads_by_group_id(LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID)
-
-    def scatter_optim_state(self, optim_state):
-        with torch.no_grad():
-            param_group = self.param_group
-            for param in param_group["params"]:
-                state = optim_state
-                for k, v in state.items():
-                    if isinstance(v, torch.Tensor) and k != "step":
-                        padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
-                        v = v.flatten()
-                        if padding_size > 0:
-                            v = torch.nn.functional.pad(v, [0, padding_size])
-                        v_list = v.split(v.numel() // self._world_size)
-                        state[k] = v_list[self._local_rank].detach().clone()
-
-    def get_param_grad(self, param):
-        grad_maybe_partial = self.get_working_param_grads(param)
-        if len(grad_maybe_partial) == 0:
-            return None
-        if self._partition_grad:
-            tensor_list = [torch.empty_like(grad_maybe_partial[0]) for _ in range(self._world_size)]
-            dist.all_gather(tensor_list, grad_maybe_partial[0], group=self.dp_process_group)
-            grad_flat = torch.cat(tensor_list, dim=0)
-        else:
-            grad_flat = torch.cat(grad_maybe_partial, dim=0)
-        return grad_flat[: param.numel()].reshape_as(param)
-
-    ######################################################################
-    # interfaces for child classes to implement, which will be called at
-    # corresponding stage in LowLevelOptimizer
-
-    @abstractmethod
-    def pre_backward(self, loss, retain_graph=False) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def post_backward(self) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def pre_backward_by_grad(self, tensor, grad) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def post_backward_by_grad(self) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def pre_step(self) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def post_step(self) -> None:
-        raise NotImplementedError
-
-
-class LowLevelOptStrategy(LowLevelOptStrategyBase):
-    def __init__(
-        self,
-        param_group: Dict[str, Any],  # from optimizer.param_groups
-        dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
-        reduce_bucket_size: int = 1024 * 1024,  # communication
-        communication_dtype: Optional[torch.dtype] = None,
-        overlap_communication: bool = False,
-        partition_grad: bool = False,  # stage 2 flag
-        cpu_offload: bool = False,  # cpu offload
-        master_weights: bool = True,  # master weights
-    ):
-        super().__init__(
-            param_group=param_group,
-            dp_process_group=dp_process_group,
-            cpu_offload=cpu_offload,
-            partition_grad=partition_grad,
-            master_weights=master_weights,
-            reduce_bucket_size=reduce_bucket_size,
-            communication_dtype=communication_dtype,
-            overlap_communication=overlap_communication,
-        )
-
-        # temporary variables
-        self.__saved_master_params = None
-        self.__saved_working_params = None
-
-    ######################################################################
-    # pre-backward: sanity check
-    # post-backward: deal with grads
-
-    def pre_backward(self, loss, retain_graph=False):
-        assert not (
-            self._partition_grad and not self.require_grad_sync
-        ), "ZeRO2(partition_grad) and no_sync are not compatible"
-
-    def post_backward(self):
-        if not self.require_grad_sync:
-            return
-
-        self._reduce_grad()
-
-        # clear reduced grads
-        if self._overlap_communication:
-            get_accelerator().synchronize()
-
-        for param in self.working_param_group:
-            assert param.grad is None, "unreduced grad are not removed"
-
-    def pre_backward_by_grad(self, tensor, grad):
-        assert not (
-            self._partition_grad and not self.require_grad_sync
-        ), "ZeRO2(partition_grad) and no_sync are not compatible"
-
-    def post_backward_by_grad(self):
-        self.post_backward()
-
-    def pre_step(self) -> None:
-        # sometimes not all params are 'really' working
-        # for instance, when layer drop, the dropped layer has no grad
-        # and should not be updated
-        grad_index = 0 if self._partition_grad else self._local_rank
-        real_master_params, real_working_params = [], []
-        for working_param, master_param in zip(self.working_param_group, self.master_param_group):
-            # if a working param requires grad and has no grad
-            # it is not 'really' working, e.g. the droped layer
-            # else the splited grad should be attached to the splited param
-            grads = self.get_working_param_grads(working_param)
-            if len(grads) > 0:
-                real_master_params.append(master_param)
-                real_working_params.append(working_param)
-                grad = grads[grad_index]
-                # no need to copy fp32 grad if master_weights is False
-                if self._master_weights:
-                    grad = grad.to(master_param.dtype).to(master_param.device)
-                # TODO @botbw: in original code, grad_partition_groups is used
-                # however it seems it's the same as working_grads as long as
-                # we update the grads in store correctly
-                grads[grad_index] = master_param.grad = grad
-
-        # update the params in the optimizer and the working partition
-        # @botbw: to me, it seems like the original author only wants to keep the "real_xxx_params" when do the optimizer
-        # computation, and add "non real_xxx_params" back after since we might still need them for checkpoint
-        # not sure if it's necessary since None grads don't really bring lots of overhead
-        self.__saved_working_params = self.working_param_group
-        self.__saved_master_params = self.master_param_group
-        self.working_param_group = real_working_params
-        self.master_param_group = self.param_group["params"] = real_master_params
-
-    def post_step(self):
-        release_param_grad(self.master_param_group)
-
-        # update working partition updated by the current rank
-        device = get_accelerator().get_current_device()
-        for working_param, master_param in zip(
-            self.working_param_group, self.master_param_group
-        ):  # initial value of zhe two group are stored in tmp variables
-            all_splited_param = [
-                torch.zeros(master_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
-            ]
-            dist.all_gather(all_splited_param, master_param.to(device).to(self._dtype), group=self.dp_process_group)
-            working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
-
-        # restore tmp values
-        self.working_param_group = self.__saved_working_params
-        self.master_param_group = self.__saved_master_params
-        self.__saved_master_params = self.__saved_working_params = None
-        self.param_group["params"] = self.master_param_group
-
-
-class MoeZeroStrategy(LowLevelOptStrategy):
-    def __init__(
-        self,
-        param_group: Dict[str, Any],  # from optimizer.param_groups
-        reduce_bucket_size: int = 1024 * 1024,  # communication
-        communication_dtype: Optional[torch.dtype] = None,
-        overlap_communication: bool = False,
-        partition_grad: bool = False,  # stage 2 flag
-        cpu_offload: bool = False,  # cpu offload
-        dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
-        master_weights: bool = True,  # master weights
-    ):
-        for param in param_group["params"]:
-            if not is_moe_tensor(param):
-                raise ValueError(f"Mixture-of-Experts parameters are required for MoeZeroStrategy {type(param)}")
-
-        super().__init__(
-            param_group=param_group,
-            dp_process_group=dp_process_group,
-            cpu_offload=cpu_offload,
-            partition_grad=partition_grad,
-            master_weights=master_weights,
-            reduce_bucket_size=reduce_bucket_size,
-            communication_dtype=communication_dtype,
-            overlap_communication=overlap_communication,
-        )
-
-    # def get_param_grad(self, param):  # TODO @botbw: discuss whether it's intuitive to return grad of divided of full moe tensor
-    #     moe_partial_grad = super().get_param_grad(param)
-    #     moe_grad_list = [torch.empty_like(moe_partial_grad) for _ in range(self._world_size)]
-    #     dist.all_gather(moe_grad_list, moe_partial_grad, group=self.dp_process_group)
-    #     moe_grad = torch.cat(moe_grad_list, dim=0).reshape(param.shape[0] * self._world_size, *param.shape[1:])
-    #     return moe_grad
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
index e4f288bf956f..c0340eb96f70 100644
--- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
+++ b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
@@ -14,7 +14,6 @@
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
-from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy
 from tests.test_moe.moe_utils import loose_close
 
 tokens, n_experts = 7, 4
@@ -56,77 +55,65 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.
     ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()
 
-    zero_model = deepcopy(orig_model)
+    zero_model = deepcopy(orig_model).to(dtype)
     zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
 
     zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
-    zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters()))
-    moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters()))
-    zero_optimizer.param_groups.clear()
-    zero_optimizer.add_param_group({"params": zero_params})
-    zero_optimizer.add_param_group({"params": moe_params})
-    strategies = [
-        LowLevelOptStrategy(
-            param_group=zero_optimizer.param_groups[0],
-            dp_process_group=plugin.global_dp_group,
-            overlap_communication=False,
-            partition_grad=(stage == 2),
-        ),
-        MoeZeroStrategy(
-            param_group=zero_optimizer.param_groups[1],
-            dp_process_group=plugin.moe_dp_group,
-            overlap_communication=True,
-            partition_grad=(stage == 2),
-        ),
-    ]
+    pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
+    for p in zero_model.parameters():
+        if is_moe_tensor(p):
+            pg_param_list[plugin.moe_dp_group].append(p)
+        else:
+            pg_param_list[plugin.global_dp_group].append(p)
+
     zero_optimizer = LowLevelZeroOptimizer(
         zero_optimizer,
-        strategies,
+        pg_param_list=pg_param_list,
         master_weights=master_weights,
         initial_scale=1,
+        overlap_communication=False,
+        partition_grad=True,
     )
 
     ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
 
     # create
     seed_all(1453 + rank)
 
-    input_data = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
-    # zero-dp forward
-    zero_output, zero_logits = zero_model(input_data.to(dtype))
-    # torch-ddp forward
-    ori_output, ori_logits = ori_model(input_data.to(dtype))
-    loose_close(zero_output, ori_output, dtype=dtype)
+    for _ in range(2):
+        # zero-dp forward
+        input_data = torch.rand(1, tokens, hidden_size).cuda()
+        zero_output, zero_logits = zero_model(input_data.to(dtype))
+
+        # torch-ddp forward
+        ori_output, ori_logits = ori_model(input_data.to(dtype))
+        loose_close(zero_output, ori_output, dtype=dtype)
+
+        # zero-dp backward
+        zero_optimizer.backward(zero_output.mean().float())
 
-    # zero-dp backward
-    zero_optimizer.backward(zero_output.mean().float())
+        # torch-ddp backward
+        ori_output.mean().backward()
 
-    # torch-ddp backward
-    ori_output.mean().float().backward()
+        # check grad
+        name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
+        for n, p in zero_model.named_parameters():
+            zero_grad = zero_optimizer.get_param_grad(p)
+            if name_to_p[n].grad is None:
+                assert zero_grad is None
+                continue
 
-    # check grad
-    name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
-    for n, p in zero_model.named_parameters():
-        zero_grad = zero_optimizer.get_param_grad(p)
-        if p.grad is None:
-            """
-            For fixed input seed, the test input may cause a certain expert not to be routed to,
-            so its gradient is None instead of a tensor, which may lead to a potential bug.
-            """
-            # TODO(haze188) fix later
-            p.grad = torch.zeros_like(p)
-            continue
-        loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)
+            loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)
 
-    # zero-dp step
-    zero_optimizer.step()
+        # zero-dp step
+        zero_optimizer.step()
 
-    # original model step
-    ori_optimizer.step()
+        # original model step
+        ori_optimizer.step()
 
-    # check updated param
-    for n, p in zero_model.named_parameters():
-        loose_close(p.data, name_to_p[n].data, dtype=dtype)
+        # check updated param
+        for n, p in zero_model.named_parameters():
+            loose_close(p.data, name_to_p[n].data, dtype=dtype)
 
 
 def run_dist(rank, world_size, port):
@@ -142,4 +129,4 @@ def test_moe_zero_model(world_size):
 
 
 if __name__ == "__main__":
-    test_moe_zero_model(world_size=2)
+    test_moe_zero_model(world_size=4)
diff --git a/tests/test_zero/test_low_level/test_mem_leak.py b/tests/test_zero/test_low_level/test_mem_leak.py
new file mode 100644
index 000000000000..7fa59ccc50c8
--- /dev/null
+++ b/tests/test_zero/test_low_level/test_mem_leak.py
@@ -0,0 +1,61 @@
+import pytest
+import torch
+import torch.nn as nn
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.zero import LowLevelZeroOptimizer
+
+
+class MlpModel(nn.Module):
+    def __init__(self):
+        super(MlpModel, self).__init__()
+        self.linear1 = nn.Linear(123, 253)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        return x
+
+
+DEL_CALLED = False
+
+
+class TestLowLevelZeroOptimizer(LowLevelZeroOptimizer):
+    def __del__(self):
+        super().__del__()
+        global DEL_CALLED
+        DEL_CALLED = True
+
+
+def exam_mem_leak(world_size):
+    """
+    In this test, we test whether del will be called after the optimizer
+    is out of scope.
+    """
+    # create models
+    zero_model = MlpModel().cuda()
+
+    # we only test stage 1 here
+    # in `check_sharded_param_consistency.py`, we will test whether
+    # level 1 and 2 will produce exactly the same results
+    zero_optimizer = TestLowLevelZeroOptimizer(torch.optim.SGD(zero_model.parameters(), lr=1))
+
+    del zero_optimizer
+
+    assert DEL_CALLED
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+
+    exam_mem_leak(world_size=world_size)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_zero_1_2():
+    spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_zero_1_2()
diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index 23baf6617b9a..8df35bdaa968 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -123,7 +123,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     seed_all(1453)
 
     # create models
-    torch_model = MlpModel().cuda()
+    torch_model = MlpModel().cuda().to(dtype)
     zero_model = copy.deepcopy(torch_model).to(dtype)
 
     torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
@@ -145,39 +145,41 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
 
     seed_all(1453 + local_rank)
-    # create
-    input_data = torch.rand(32, 123).cuda()
 
-    # zero-dp forward
-    zero_output = zero_model(input_data.to(dtype))
+    for _ in range(2):
+        # create
+        input_data = torch.rand(32, 123).cuda().to(dtype)
 
-    # torch-ddp forward
-    torch_output = torch_model(input_data)
-    loose_close(zero_output, torch_output, dtype=dtype)
+        # zero-dp forward
+        zero_output = zero_model(input_data)
 
-    # zero-dp backward
-    zero_optimizer.backward(zero_output.mean().float())
+        # torch-ddp forward
+        torch_output = torch_model(input_data)
+        loose_close(zero_output, torch_output, dtype=dtype)
 
-    # torch-ddp backward
-    torch_output.mean().backward()
+        # zero-dp backward
+        zero_optimizer.backward(zero_output.mean())
 
-    # check grad
-    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        zero_grad = zero_optimizer.get_param_grad(z1p)
-        if p.grad is None:
-            assert zero_grad is None
-            continue
-        loose_close(p.grad, zero_grad, dtype=dtype)
+        # torch-ddp backward
+        torch_output.mean().backward()
 
-    # zero-dp step
-    zero_optimizer.step()
+        # check grad
+        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+            zero_grad = zero_optimizer.get_param_grad(z1p)
+            if p.grad is None:
+                assert zero_grad is None
+                continue
+            loose_close(p.grad, zero_grad, dtype=dtype)
 
-    # torch ddp step
-    torch_optimizer.step()
+        # zero-dp step
+        zero_optimizer.step()
 
-    # check updated param
-    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        loose_close(p, z1p, dtype=dtype)
+        # torch ddp step
+        torch_optimizer.step()
+
+        # check updated param
+        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+            loose_close(p, z1p, dtype=dtype)
 
 
 def run_dist(rank, world_size, port):