diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 7b5aec2aa405..4196a10ba9f6 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -448,7 +448,7 @@ def configure(
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **zero_optim_kwargs, verbose=self.verbose
+                optimizer, **self.zero_optim_kwargs, verbose=self.verbose
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index d366d1e339cd..b0210ac581d1 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -1,15 +1,11 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
+from collections import defaultdict
 from contextlib import contextmanager
-from functools import partial
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
-from torch import Tensor, inf
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
 from colossalai.accelerator import get_accelerator
@@ -20,17 +16,15 @@
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, LowLevelOptStrategyBase
 
-from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
-from .bookkeeping import BucketStore, GradientStore, ParameterStore
+from ._utils import calculate_global_norm_from_list, has_inf_or_nan
 
 
 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
     def __init__(
         self,
-        num_working_param_groups: int,
-        grad_store: GradientStore,
+        group_strategies: List[LowLevelOptStrategyBase],
         initial_scale: float = 2**16,
         min_scale: float = 1,
         growth_factor: float = 2,
@@ -40,31 +34,23 @@ def __init__(
         max_scale: float = 2**32,
     ) -> None:
         super().__init__(
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
+            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
         )
-        self.num_working_param_groups = num_working_param_groups
-        self.grad_store = grad_store
+        self.group_strategies = group_strategies
 
     def check_local_overflow(self) -> bool:
-        for group_id in range(self.num_working_param_groups):
-            for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id):
+        for strategy in self.group_strategies:
+            for avg_grad in strategy.working_grads:
                 if avg_grad is not None and has_inf_or_nan(avg_grad):
                     return True
        return False
 
 
 class LowLevelZeroOptimizer(OptimizerWrapper):
-    """Optimizer used for ZeRO-1 and ZeRO-2."""
-
     def __init__(
         self,
         optimizer: Optimizer,
+        group_strategies: List[LowLevelOptStrategyBase] = None,
         initial_scale: int = 2**16,  # grad scaler config
         min_scale: int = 1,
         growth_factor: float = 2.0,
@@ -74,34 +60,17 @@ def __init__(
         max_scale: int = 2**24,
         clip_grad_norm: float = 0.0,  # grad clipping
         verbose: bool = False,
-        reduce_bucket_size: int = 1024 * 1024,  # communication
-        communication_dtype: Optional[torch.dtype] = None,
-        overlap_communication: bool = False,
-        partition_grad: bool = False,  # stage 2 flag
-        cpu_offload: bool = False,  # cpu offload
-        dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         forced_dtype: Optional[torch.dtype] = None,
-        moe_extra_dp_process_group: Optional[ProcessGroup] = None,
-        master_weights: bool = True,  # master weights
+        **strategy_kwargs,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
-
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._logger = get_dist_logger()
         self._verbose = verbose
-        self._cpu_offload = cpu_offload
-
-        # working and master params for mixed precision training
-        self._working_param_groups = dict()
-        self._master_param_groups_of_current_rank = dict()
-
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
 
-        # master weights copy
-        self._master_weights = master_weights
-
         if forced_dtype:
             for group in self.optim.param_groups:
                 group_params = group["params"]
@@ -112,79 +81,23 @@ def __init__(
         # check argument conflict
         self._sanity_checks()
 
-        # ParameterStore will manage the tensor buffers used for zero
-        # it will not manage the tensors used by mixed precision training
-        self._param_store = ParameterStore(dp_process_group)
-        self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad, require_grad_sync=True)
-        self._bucket_store = BucketStore(
-            dp_process_group, reduce_bucket_size, overlap_communication, communication_dtype, moe_extra_dp_process_group
-        )
-
-        # moe param should not be stored in working_groups
-        # because they have different parallel strategy
-        # so we need to store them separately in param_groups
-        # instead of working_groups
-        self.working_moe_params = list()
-
-        # iterate over the param group in the optimizer
-        # partition these param groups for data parallel training
-        # and add buffers to parameter store for future access
-        for group_id, param_group in enumerate(self.optim.param_groups):
-            group_params = list()
-            for param in param_group["params"]:
-                if param.requires_grad:
-                    if self._bucket_store.moe_extra_dp_pg is not None:
-                        # skip moe param
-                        if is_moe_tensor(param):
-                            self.working_moe_params.append(param)
-                            continue
-                    group_params.append(param)
-
-            # add the working params to working_param_groups for bookkeeping
-            self._working_param_groups[group_id] = group_params
-
-            master_param_current_rank = self._create_master_param_current_rank(group_params)
-            self._master_param_groups_of_current_rank[group_id] = master_param_current_rank
-
-            # need to replace the params in the `params` field in the optimizer
-            # so that when the optimizer calls step(), it only updates the tensors
-            # managed by this data parallel rank
-            param_group["params"] = master_param_current_rank
-
-        # if there are moe params, store in addtional group in optim
-        if len(self.working_moe_params) > 0:
-            self._sync_master_param = False
-            param_group = dict()
-            # create fp32 master param
-            for key, value in self.optim.param_groups[0].items():
-                if key != "params":
-                    param_group[key] = value
-            self.master_moe_params = []
-            for param in self.working_moe_params:
-                if self._master_weights:
-                    self.master_moe_params.append(param.clone().to(torch.float32).detach())
-                else:
-                    self.master_moe_params.append(param.detach())
-            # create mapping from master to working for optimizer io
-            self.moe_master_to_working_map = {}
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
-            # add to optim
-            param_group["params"] = self.master_moe_params
-            self.optim.param_groups.append(param_group)
+        if len(self.optim.param_groups) == 1 and group_strategies is None:
+            group_strategies = [LowLevelOptStrategy(param_group=self.optim.param_groups[0], **strategy_kwargs)]
+        elif len(self.optim.param_groups) > 1 and group_strategies is None:
+            raise ValueError("group_strategies must be provided when the optimizer has multiple param groups")
 
-        # reduction hook is only used if overlapping communication
-        # or stage 2 is used
-        # if it is stage 1 without overlapping, no hook will be attached
-        if self._bucket_store._overlap_communication or self._grad_store._partition_grads:
-            self._attach_reduction_hook()
+        self.param2strategy: Dict[torch.nn.Parameter, LowLevelOptStrategyBase] = {}
+        for grp, strategy in zip(self.optim.param_groups, group_strategies):
+            assert grp["params"] is strategy.param_group["params"], "param groups should be in the same order"
+            for param in strategy.working_param_group:
+                self.param2strategy[param] = strategy
+        self._group_strategies = group_strategies
 
         # initialize mixed precision mixin
         self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
         if self._dtype is torch.float16:
             self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(
-                self.num_param_groups,
-                self._grad_store,
+                self._group_strategies,
                 initial_scale=initial_scale,
                 min_scale=min_scale,
                 growth_factor=growth_factor,
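
With this refactor, parameter partitioning and gradient communication move out of the optimizer into per-group strategy objects. A minimal construction sketch under stated assumptions — `model` is a placeholder and `torch.distributed` is assumed initialized; the kwargs mirror the `LowLevelOptStrategy` signature introduced later in this PR:

```python
import torch
from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy

optim = torch.optim.SGD(model.parameters(), lr=1e-3)  # `model` is assumed given
strategy = LowLevelOptStrategy(
    param_group=optim.param_groups[0],  # the strategy takes ownership of this group
    process_group=None,                 # None falls back to dist.group.WORLD
    partition_grad=False,               # False: ZeRO-1; True: ZeRO-2
    overlap_communication=False,
)
zero_optim = LowLevelZeroOptimizer(optim, group_strategies=[strategy])
```

With a single param group and no explicit strategies, the constructor builds this default `LowLevelOptStrategy` itself from `**strategy_kwargs`.
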
@@ -196,489 +109,86 @@ def __init__(
         elif self._dtype is torch.bfloat16:
             self.mixed_precision_mixin = BF16MixedPrecisionMixin()
 
-    def __del__(self):
-        self.remove_hooks()
-
-    @property
-    def dtype(self):
-        return self._dtype
-
-    @property
-    def num_param_groups(self):
-        return len(self._working_param_groups)
-
-    def _sanity_checks(self):
-        assert get_accelerator().name in ["cuda", "npu"], "device is required"
-        for param_group in self.optim.param_groups:
-            group_params = param_group["params"]
-            for param in group_params:
-                if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False:
-                    assert (
-                        param.dtype == self._dtype
-                    ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
-
-    def _create_master_param_current_rank(self, param_list):
-        # split each param evenly by world size
-        params_current_rank = []
-        device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
-
-        for param in param_list:
-            padding_size = (
-                self._bucket_store.zero_world_size - param.numel() % self._bucket_store.zero_world_size
-            ) % self._bucket_store.zero_world_size
-            self._param_store.record_param_padding_size(param, padding_size)
-
-            with torch.no_grad():
-                if padding_size > 0:
-                    padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                    # reset working params' ptr when no master weights
-                    if self._master_weights == False:
-                        param.data = padding_param[: param.numel()].view(param.shape)
-                else:
-                    padding_param = param.data.view(-1)
-
-                if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(param):
-                    splited_params = padding_param.split(
-                        padding_param.numel() // self._bucket_store.moe_extra_dp_pg_size
-                    )
-                    splited_params = splited_params[self._bucket_store.moe_extra_dp_pg_rank]
-                else:
-                    splited_params = padding_param.split(padding_param.numel() // self._bucket_store.zero_world_size)
-                    splited_params = splited_params[self._bucket_store.zero_local_rank]
-
-                # use fp32 when master_weights is True
-                if self._master_weights is True:
-                    splited_param_current_rank = splited_params.detach().float().to(device)
-                else:
-                    splited_param_current_rank = splited_params
-
-                # Send the splited view to the optimizer to match ZeRO 2 grad shape
-                params_current_rank.append(splited_param_current_rank)
-                self._param_store.link_master_and_working_param(splited_param_current_rank, param)
-
-        return params_current_rank
-
-    ###########################
-    # Backward Reduction Hook #
-    ###########################
-
-    @staticmethod
-    def grad_handler(
-        param: nn.Parameter,
-        group_id: int,
-        bucket_store: BucketStore,
-        param_store: ParameterStore,
-        grad_store: GradientStore,
-    ):
-        # if run with no_sync context, would not sync grad when backward
-        if grad_store.require_grad_sync:
-            LowLevelZeroOptimizer.add_to_bucket(param, group_id, bucket_store, param_store, grad_store)
-
-    def _attach_reduction_hook(self):
-        # we iterate over the working params
-        # on each param, we register a hook to its AccumulateGrad object
-        for group_id in range(self.num_param_groups):
-            param_group = self._working_param_groups[group_id]  # TODO(haze188) refactor moe: moe-param hook for reduce
-            for param in param_group:
-                if param.requires_grad:
-                    param._grad_handle = param.register_post_accumulate_grad_hook(
-                        partial(
-                            LowLevelZeroOptimizer.grad_handler,
-                            group_id=group_id,
-                            bucket_store=self._bucket_store,
-                            param_store=self._param_store,
-                            grad_store=self._grad_store,
-                        )
-                    )
-
-    #######################
-    # Reduction Functions #
-    #######################
-    @staticmethod
-    def run_reduction(bucket_store: BucketStore, grad_store: GradientStore):
-        if bucket_store.num_elements_in_bucket() > 0:
-            bucket_store.build_grad_in_bucket()
-            if bucket_store.moe_extra_dp_pg is None:
-                flat_grads = bucket_store.get_flatten_grad()
-                flat_grads /= bucket_store.zero_world_size
-            else:
-                # record moe and non moe param
-                moe_list = []
-                for param in bucket_store._param_list:
-                    moe_list.append(is_moe_tensor(param))
-
-                # divide them into different groups
-                moe_grad_list = []
-                non_moe_grad_list = []
-                for grad_list in bucket_store._grad_in_bucket.values():
-                    non_moe_cur_grad = []
-                    moe_cur_grad = []
-                    for i in range(len(grad_list)):
-                        if moe_list[i] == True:
-                            moe_cur_grad.append(grad_list[i])
-                        else:
-                            non_moe_cur_grad.append(grad_list[i])
-                    if len(moe_cur_grad) > 0:
-                        moe_grad_list.append(moe_cur_grad)
-                    if len(non_moe_cur_grad) > 0:
-                        non_moe_grad_list.append(non_moe_cur_grad)
-
-                if len(non_moe_grad_list) > 0:
-                    non_moe_flat_grads = []
-                    for grad_list in non_moe_grad_list:
-                        non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
-                    non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
-                    non_moe_flat_grads /= bucket_store.zero_world_size
-
-                if len(moe_grad_list) > 0:
-                    moe_flat_grads = []
-                    for grad_list in moe_grad_list:
-                        moe_flat_grads.append(_flatten_dense_tensors(grad_list))
-                    moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
-
-            # ready to add other tensors to bucket
-            bucket_store.reset_num_elements_in_bucket()
-
-            if bucket_store._overlap_communication:
-                stream = bucket_store.comm_stream
-                # in case of the memory being reused in the default stream
-                if bucket_store.moe_extra_dp_pg is None:
-                    flat_grads.record_stream(stream)
-                else:
-                    if len(non_moe_grad_list) > 0:
-                        non_moe_flat_grads.record_stream(stream)
-                    if len(moe_grad_list) > 0:
-                        moe_flat_grads.record_stream(stream)
-                # waiting for ops in the default stream finishing
-                stream.wait_stream(get_accelerator().current_stream())
-            else:
-                stream = get_accelerator().current_stream()
-
-            with get_accelerator().stream(stream):
-                group_id = bucket_store.current_group_id
-
-                if bucket_store.moe_extra_dp_pg is None:
-                    grad_dtype = flat_grads.dtype
-                    if bucket_store._communication_dtype is not None:
-                        flat_grads = flat_grads.to(bucket_store._communication_dtype)
-
-                if not grad_store._partition_grads:
-                    if bucket_store.moe_extra_dp_pg is None:
-                        dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
-                        if flat_grads.dtype != grad_dtype:
-                            flat_grads = flat_grads.to(grad_dtype)
-
-                        flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.zero_world_size)
-                        grad_in_bucket = bucket_store.get_grad()
-                        LowLevelZeroOptimizer.update_unpartitoned_grad(
-                            bucket_store, grad_store, grad_in_bucket.values(), flat_grads_per_rank, group_id
-                        )
-
-                    # sync extra zero group
-                    else:
-                        # sync non moe param in global dp group
-
-                        if len(non_moe_grad_list) > 0:
-                            dist.all_reduce(non_moe_flat_grads, group=bucket_store.torch_pg)
-                            flat_grads_per_rank = non_moe_flat_grads.split(
-                                non_moe_flat_grads.numel() // bucket_store.zero_world_size
-                            )
-                            LowLevelZeroOptimizer.update_unpartitoned_grad(
-                                bucket_store, grad_store, non_moe_grad_list, flat_grads_per_rank, group_id
-                            )
-
-                        # sync moe param only in zero group
-                        if len(moe_grad_list) > 0:
-                            dist.all_reduce(moe_flat_grads, group=bucket_store.moe_extra_dp_pg)
-                            flat_grads_per_rank = moe_flat_grads.split(
-                                moe_flat_grads.numel() // bucket_store.zero_world_size
-                            )
-                            LowLevelZeroOptimizer.update_unpartitoned_grad(
-                                bucket_store, grad_store, moe_grad_list, flat_grads_per_rank, group_id
-                            )
-
-                else:
-                    if bucket_store.moe_extra_dp_pg is None:
-                        flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size))
-                        received_grad = torch.zeros_like(flat_grads_list[0])
-                        dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
-                        if received_grad.dtype != grad_dtype:
-                            received_grad = received_grad.to(grad_dtype)
-
-                        grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
-                        LowLevelZeroOptimizer.update_partitoned_grad(
-                            bucket_store, grad_store, grad_in_bucket_current_rank, received_grad, group_id, 1
-                        )
-                    else:
-                        # categorize moe and non moe param
-                        grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
-                        moe_grad_in_bucket_current_rank = []
-                        non_moe_grad_in_bucket_current_rank = []
-                        for idx, grad in enumerate(grad_in_bucket_current_rank):
-                            if moe_list[idx] == True:
-                                moe_grad_in_bucket_current_rank.append(grad)
-                            else:
-                                non_moe_grad_in_bucket_current_rank.append(grad)
-
-                        if len(non_moe_grad_list) > 0:
-                            flat_grads_list = list(
-                                non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size)
-                            )
-                            received_grad = torch.zeros_like(flat_grads_list[0])
-                            dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
-                            LowLevelZeroOptimizer.update_partitoned_grad(
-                                bucket_store,
-                                grad_store,
-                                non_moe_grad_in_bucket_current_rank,
-                                received_grad,
-                                group_id,
-                                1,
-                            )
-
-                        if len(moe_grad_list) > 0:
-                            flat_grads_list = list(
-                                moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size)
-                            )
-                            received_grad = torch.zeros_like(flat_grads_list[0])
-                            dist.reduce_scatter(
-                                received_grad,
-                                flat_grads_list,
-                                group=bucket_store.moe_extra_dp_pg,
-                            )
-                            param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size
-                            received_grad = list(received_grad.split(len(received_grad) // param_slice))
-                            for split_recieved_grad in received_grad:
-                                split_recieved_grad = _unflatten_dense_tensors(
-                                    split_recieved_grad, moe_grad_in_bucket_current_rank
-                                )
-                                for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
-                                    param_id = bucket_store.get_param_id_of_grad(grad)
-                                    LowLevelZeroOptimizer.add_grad(
-                                        grad_store, real_grad, param_slice, group_id, param_id
-                                    )
-
-            bucket_store.reset()
-
-    @staticmethod
-    def update_unpartitoned_grad(
-        bucket_store: BucketStore,
-        grad_store: GradientStore,
-        origin_grad_list: List,
-        flat_grad_list: List,
-        group_id: int,
-    ) -> None:
-        for rank, grad_list in enumerate(origin_grad_list):
-            sync_tensor(flat_grad_list[rank], grad_list)
-            for grad in grad_list:
-                param_id = bucket_store.get_param_id_of_grad(grad)
-                LowLevelZeroOptimizer.add_grad(grad_store, grad, bucket_store.zero_world_size, group_id, param_id, rank)
-
-    @staticmethod
-    def update_partitoned_grad(
-        bucket_store: BucketStore,
-        grad_store: GradientStore,
-        origin_grad_list: List,
-        flat_grad: torch.Tensor,
-        group_id: int,
-        partition_num: int,
-    ) -> None:
-        sync_tensor(flat_grad, origin_grad_list)
-        for grad in origin_grad_list:
-            param_id = bucket_store.get_param_id_of_grad(grad)
-            LowLevelZeroOptimizer.add_grad(grad_store, grad, partition_num, group_id, param_id)
-
-    @staticmethod
-    def add_grad(
-        grad_store: GradientStore,
-        grad: torch.Tensor,
-        partition_num: int,
-        group_id: int,
-        param_id: int,
-        rank: int = 0,
-    ) -> None:
-        if len(grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
-            grad_store.append_gradients_by_param_id(grad, group_id, param_id)
-        else:
-            grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
-
-    @staticmethod
-    def add_to_bucket(
-        param: nn.Parameter,
-        group_id: int,
-        bucket_store: BucketStore,
-        param_store: ParameterStore,
-        grad_store: GradientStore,
-    ):
-        param_size = param.numel()
-
-        # check if the bucket is full
-        # if full, will reduce the grads already in the bucket
-        # or got a grad of param from another group
-        # after reduction, the bucket will be empty
-        if (
-            bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size
-            or group_id != bucket_store.current_group_id
-        ):
-            LowLevelZeroOptimizer.run_reduction(bucket_store, grad_store)
-
-        padding_size = param_store.get_param_padding_size(param)
-        bucket_store.add_param_grad(group_id, param, padding_size)
-
-    ################################
-    # torch.optim.Optimizer methods
-    ################################
-
     def backward(self, loss, retain_graph=False):
-        assert not (
-            self._grad_store._partition_grads and not self._grad_store.require_grad_sync
-        ), "ZeRO2(partition_grads) and no_sync are not compatible"
+        for strategy in self._group_strategies:
+            strategy.pre_backward(loss, retain_graph)
 
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
 
         loss.backward(retain_graph=retain_graph)
 
-        if not self._grad_store.require_grad_sync:
-            return
-
-        self._reduce_grad(self._grad_store._partition_grads)
-
-        # clear reduced grads
-        if self._bucket_store._overlap_communication:
-            get_accelerator().synchronize()
-
-        self.zero_grad()
+        for strategy in self._group_strategies:
+            strategy.post_backward()
 
-    def backward_by_grad(self, tensor, grad):
-        assert not (
-            self._grad_store._partition_grads and not self._grad_store.require_grad_sync
-        ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+    def state_dict(self) -> Dict:
+        """Return a state_dict in the same form as DDP
-        if self.mixed_precision_mixin is not None:
-            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
-        torch.autograd.backward(tensor, grad)
+        Returns:
+            Dict: the pytorch form state_dict
+        """
+        zero_state = dict()
+        for strategy in self._group_strategies:
+            param_group = strategy.param_group
+            for param in param_group["params"]:
+                state = self.optim.state[param]
+                zero_state[param] = copy.deepcopy(state)
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        param_state = strategy.allgather_optim_state(param, v)
+                        zero_state[param][k] = param_state
 
-        if not self._grad_store.require_grad_sync:
-            return
-        self._reduce_grad(self._grad_store._partition_grads)
+        states_dict = self._pack_state(zero_state)
 
-        # clear reduced grads
-        if self._bucket_store._overlap_communication:
-            get_accelerator().synchronize()
+        return states_dict
 
-        self.zero_grad()
+    def load_state_dict(self, state_dict: Dict):
+        """Load state dict; requires the state_dict to be in the pytorch form
 
-    def zero_grad(self, set_to_none=True):
+        Args:
+            state_dict (dict): A pytorch form state_dict
         """
-        Set parameter gradients to zero. If set_to_none = True, gradient
-        will be set to None to save memory.
+        zero_state_dict = copy.deepcopy(state_dict)
+        self.optim.load_state_dict(zero_state_dict)
+        for strategy in self._group_strategies:
+            strategy.scatter_optim_state(self.optim.state)
 
-        :param set_to_none: Whether set the gradient to None. Default value is True.
-        :type set_to_none: bool
-        """
-        if self.mixed_precision_mixin is not None:
-            self.mixed_precision_mixin.pre_zero_grad()
-        for _, param_group in self._working_param_groups.items():
-            for param in param_group:
-                if set_to_none:
-                    param.grad = None
-                else:
-                    if param.grad is not None:
-                        param.grad.detach()
-                        param.grad.zero_()
-        self._bucket_store.reset_all()
+    def update_master_params(self, model: nn.Module) -> None:
+        """Update master params from working params
 
-    ####################
-    # Update Parameter #
-    ####################
+        Args:
+            model (nn.Module): The model to update master params
+        """
+        all_working_params = []
+        for strategy in self._group_strategies:
+            all_working_params.extend(strategy.working_params)
+            strategy.update_master_params()
+        assert set(map(lambda x: id(x), all_working_params)) == set(
+            map(lambda x: id(x), model.parameters())
+        ), "model parameters should be the same"
 
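
The refactored `state_dict`/`load_state_dict` pair keeps the external checkpoint format identical to plain PyTorch: states are all-gathered into working-param shape on save and re-sharded per rank on load. A hedged round-trip sketch — `zero_optim` and the checkpoint path are placeholders, and the same parallel layout is assumed on reload:

```python
import torch

full_state = zero_optim.state_dict()      # all-gathers the sharded optimizer states
torch.save(full_state, "optim_ckpt.pt")   # placeholder path

# later, on the same parallel layout
zero_optim.load_state_dict(torch.load("optim_ckpt.pt"))  # re-splits states per rank
```
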
Skip step") - self.zero_grad() + for strategy in self._group_strategies: + strategy.zero_working_grad() + strategy.zero_grad() return - # record all grads for unscale and clip + # TODO @botbw can be further refactored grad_partition_groups = [] norm_groups = [] - - # sometimes not all params are 'really' working - # for instance, when layer drop, the dropped layer has no grad - # and should not be updated - real_working_params = dict() - real_master_params = dict() - grad_index = 0 if self._grad_store._partition_grads else self._bucket_store.zero_local_rank - for group_id in range(self.num_param_groups): - master_params = self._master_param_groups_of_current_rank[group_id] - real_working_params[group_id] = [] - real_master_params[group_id] = [] - for splited_param in master_params: - working_param = self._param_store.master_to_working_param[id(splited_param)] - # if a working param requires grad and has no grad - # it is not 'really' working, e.g. the droped layer - # else the splited grad should be attached to the splited param - grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) - if len(grads) > 0: - # moe hybrid zero - if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor( - working_param - ): # TODO(@haze188) refactor: this code may be useless, never run - real_working_params[group_id].append(working_param) - if self._grad_store._partition_grads: - grad = grads - else: - param_slice = self._bucket_store.zero_world_size // self._bucket_store.moe_extra_dp_pg_size - grad = grads[ - self._bucket_store.moe_extra_dp_pg_rank - * param_slice : (self._bucket_store.moe_extra_dp_pg_rank + 1) - * param_slice - ] - grad = flatten(grad) - else: - real_working_params[group_id].append(working_param) - grad = grads[grad_index] - # no need to copy fp32 grad if master_weights is False - if self._master_weights: - grad = grad.to(splited_param.dtype).to(splited_param.device) - splited_param.grad = grad - grad_partition_groups.append(grad) - real_master_params[group_id].append(splited_param) - - # compute norm - working_grads = self._grad_store.get_working_grads_by_group_id(group_id) - norm_group = self._compute_grad_norm(gradients=working_grads) - norm_groups.append(norm_group) - - self._grad_store.reset_grads_by_group_id(group_id) - - # update the params in the optimizer - self.optim.param_groups[group_id]["params"] = real_master_params[group_id] - - # update param for moe ep - # move grad to master param and compute norm - - if len(self.working_moe_params) > 0: - moe_grads = [] - for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): - if master_moe_param.grad is not None: - raise RuntimeError("Moe param should not have grad here") - grad = working_moe_param.grad - # no need to copy fp32 grad if master_weights is False - if self._master_weights: - grad = grad.to(master_moe_param.dtype).to(master_moe_param.device) - master_moe_param.grad = grad - working_moe_param.grad = None - moe_grads.append(grad) - grad_partition_groups.append(grad) - norm_group = self._compute_grad_norm(gradients=moe_grads) - norm_groups.append(norm_group) - self.optim.param_groups[-1]["params"] = self.master_moe_params - del moe_grads + for strategy in self._group_strategies: + strategy.pre_step() + grad_partition_groups.extend(strategy.working_grads) + norm_groups.append(strategy.get_grad_norm()) + strategy.zero_working_grad() # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) @@ -687,99 
@@ -687,99 +197,30 @@ def step(self, closure=None):
 
         # update the parameters
         self.optim.step()
 
-        # release moe grad
-        if len(self.working_moe_params) > 0:
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                master_moe_param.grad = None
-
-                working_moe_param.data = (
-                    master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
-                )
-
-        # release the grad
-        grad_partition_groups = []
-        for group_id in range(self.num_param_groups):
-            release_param_grad(self._master_param_groups_of_current_rank[group_id])
-
-        # update working partition updated by the current rank
-        device = get_accelerator().get_current_device()
-        for group_id in range(self.num_param_groups):
-            master_working_param = self.optim.param_groups[group_id]["params"]
-            for idx, splited_param in enumerate(master_working_param):
-                working_param = real_working_params[group_id][idx]
-                if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.moe_extra_dp_pg_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.moe_extra_dp_pg,
-                    )
-                else:
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.zero_world_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.torch_pg,
-                    )
-                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
-            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
-
-    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
-        r"""
-        Compute and return the gradient norm for gradient clipping.
-
-        Args:
-            gradients (List[Tensor]): The gradients to compute norm
-            norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
-
-        Returns:
-            float: The total norm of given gradients
-        """
-
-        if len(gradients) == 0:
-            return 0.0
-
-        norm_type = float(norm_type)
-        if norm_type == inf:
-            total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor(
-                [float(total_norm)],
-                device=get_accelerator().get_current_device(),
-                dtype=torch.float,
-            )
-            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self._bucket_store.torch_pg)
-            total_norm = total_norm_cuda.item()
-
-        else:
-            total_norm_exponentiated = 0.0
-            for grad in gradients:
-                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
-                total_norm_exponentiated += grad_norm_exponentiated
+        for strategy in self._group_strategies:
+            strategy.post_step()
 
-            # Sum across all model parallel GPUs.
-            total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)],
-                device=get_accelerator().get_current_device(),
-                dtype=torch.float,
-            )
-            torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda,
-                op=torch.distributed.ReduceOp.SUM,
-                group=self._bucket_store.torch_pg,
-            )
-            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+    @property
+    def require_grad_sync(self) -> bool:
+        flag_set = set()
+        for strategy in self._group_strategies:
+            flag_set.add(strategy.require_grad_sync)
+        assert len(flag_set) == 1, "require_grad_sync should be the same for all strategies"
+        return flag_set.pop()
 
-        return total_norm
+    # this context comes from pytorch DDP
+    @contextmanager
+    def no_sync(self):
+        old_require_grad_sync = self.require_grad_sync
+        for strategy in self._group_strategies:
+            strategy.require_grad_sync = False
+        try:
+            yield
+        finally:
+            for strategy in self._group_strategies:
+                strategy.require_grad_sync = old_require_grad_sync
 
-    #############################
-    # Mixed Precision Utilities #
-    #############################
+    ##################################################################################
 
     def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         # compute combined scale factor for this group
@@ -796,47 +237,21 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         for grad in grad_groups_flat:
             grad.data.mul_(1.0 / div_scale)
 
-    ############################
-    # Gradient Synchronization #
-    ############################
-
-    # this method is used to sync gradient manually
-    def _sync_grad(self):
-        for group_id in range(self.num_param_groups):
-            param_group = self._working_param_groups[group_id]
-            for param in param_group:
-                if param.requires_grad and param.grad is not None:
-                    LowLevelZeroOptimizer.add_to_bucket(
-                        param,
-                        group_id,
-                        self._bucket_store,
-                        self._param_store,
-                        self._grad_store,
-                    )
-
-        LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
-
-    def _reduce_grad(self, partition_grad):
-        # if not overlapping communication (no reduction hook is attached) when zero1
-        # we need to manually reduce these gradients
-        if not partition_grad and not self._bucket_store._overlap_communication:
-            self._sync_grad()
-        else:
-            LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
-
-    # this context comes from pytorch DDP
-    @contextmanager
-    def no_sync(self):
-        old_require_grad_sync = self._grad_store.require_grad_sync
-        self._grad_store.require_grad_sync = False
-        try:
-            yield
-        finally:
-            self._grad_store.require_grad_sync = old_require_grad_sync
+    def _sanity_checks(self):
+        assert get_accelerator().name in ["cuda", "npu"], "device is required"
+        inv = defaultdict(list)
+        for param_group in self.optim.param_groups:
+            group_params = param_group["params"]
+            for param in group_params:
+                inv[param].append(param_group)
+                assert (
+                    param.dtype == self._dtype
+                ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
 
-    ##############
-    # State Dict #
-    ##############
+        for _, grps in inv.items():
+            assert (
+                len(grps) == 1
+            ), "Parameters should only appear in one group, since we assume that each strategy only manages one param group"
 
     def _pack_state(self, state: Dict) -> Dict:
         # comes from pytorch optimizer.state_dict()
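
The `no_sync` context above mirrors `torch.nn.parallel.DistributedDataParallel.no_sync` and simply flips `require_grad_sync` on every strategy. A gradient-accumulation sketch, assuming ZeRO-1 (the strategy's `pre_backward` asserts that `partition_grad=True` is not combined with `no_sync`); `model`, `loader`, `criterion`, and `accum_steps` are placeholders:

```python
for it, (data, label) in enumerate(loader):
    loss = criterion(model(data), label)
    if (it + 1) % accum_steps != 0:
        with zero_optim.no_sync():   # accumulate locally, skip grad reduction
            zero_optim.backward(loss)
    else:
        zero_optim.backward(loss)    # reduce the accumulated gradients
        zero_optim.step()
```
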
@@ -859,178 +274,8 @@ def pack_group(group):
 
         return {"state": packed_state, "param_groups": param_groups}
 
-    def state_dict(self) -> Dict:
-        """Return a state_dict same with DDP
-
-        Returns:
-            Dict: the pytorch form state_dict
-        """
-        zero_state = dict()
-        device = get_accelerator().get_current_device()
-        for param, state in self.optim.state.items():
-            zero_state[param] = copy.deepcopy(state)
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    working_param = self._param_store.master_to_working_param[id(param)]
-                    if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                        gather_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype)
-                            for _ in range(self._bucket_store.moe_extra_dp_pg_size)
-                        ]
-                        dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
-                    else:
-                        gather_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype)
-                            for _ in range(self._bucket_store.zero_world_size)
-                        ]
-                        dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.torch_pg)
-                    param_state = (
-                        torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
-                    )
-                    zero_state[param][k] = param_state
-
-        states_dict = self._pack_state(zero_state)
-
-        return states_dict
-
-    def load_state_dict(self, state_dict: Dict):
-        """Load state dict, requires the state_dict be the pytorch form
-
-        Args:
-            state_dict (dict): A pytorch form state_dict
-        """
-        zero_state_dict = copy.deepcopy(state_dict)
-        for param_idx, state in zero_state_dict["state"].items():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    padding_size = (
-                        self._bucket_store.zero_world_size - v.numel() % self._bucket_store.zero_world_size
-                    ) % self._bucket_store.zero_world_size
-                    with torch.no_grad():
-                        v = v.flatten()
-                        if padding_size > 0:
-                            v = torch.nn.functional.pad(v, [0, padding_size])
-                        if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                            v_list = v.split(v.numel() // self._bucket_store.moe_extra_dp_pg_size)
-                            zero_state_dict["state"][param_idx][k] = (
-                                v_list[self._bucket_store.moe_extra_dp_pg_rank].detach().clone()
-                            )
-                        else:
-                            v_list = v.split(v.numel() // self._bucket_store.zero_world_size)
-                            zero_state_dict["state"][param_idx][k] = (
-                                v_list[self._bucket_store.zero_local_rank].detach().clone()
-                            )
-
-        self.optim.load_state_dict(zero_state_dict)
-
-    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
-        """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
-        Only include the 'state' in state_dict.
-
-        Args:
-            max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024.
-
-        Yields:
-            Iterator[OrderedDict]: A generator of state dict shard
-        """
-        ret_block = dict()
-        ret_block_size = 0
-
-        device = get_accelerator().get_current_device()
-        local_states = self.optim.state_dict()["state"]
-        for param_idx, states in local_states.items():
-            current_block_size = 0
-            current_block = copy.deepcopy(states)
-
-            # find the working param of current param_id
-            for group_id, pg in self._master_param_groups_of_current_rank.items():
-                if (group_id + 1) * len(pg) < param_idx:
-                    continue
-                master_param = pg[param_idx - (group_id) * len(pg)]
-                working_param = self._param_store.master_to_working_param[id(master_param)]
-
-            for k, v in states.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                        state_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype)
-                            for _ in range(self._bucket_store.moe_extra_dp_pg_size)
-                        ]
-                        dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
-                    else:
-                        state_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype)
-                            for _ in range(self._bucket_store.zero_world_size)
-                        ]
-                        dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.torch_pg)
-                    state_tensor = (
-                        torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
-                    )
-                    current_block_size += state_tensor.numel()
-                    current_block[k] = state_tensor
-
-            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
-                yield ret_block, ret_block_size
-                ret_block = dict()
-                ret_block_size = 0
-
-            ret_block[param_idx] = current_block
-            ret_block_size += current_block_size
-
-        yield ret_block, ret_block_size
-
-    def update_master_params(self, model: nn.Module) -> None:
-        """Update master params from working params
-
-        Args:
-            model (nn.Module): The model to update master params
-        """
-        for p in model.parameters():
-            p_id = id(p)
-            if p_id in self._param_store.working_to_master_param:
-                master_param = self._param_store.working_to_master_param[p_id]
-                padding_size = self._param_store.get_param_padding_size(p)
-                working_param = p.data.view(-1)
-                if padding_size > 0:
-                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
-                if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(p):
-                    master_param.copy_(
-                        working_param.chunk(self._bucket_store.moe_extra_dp_pg_size)[
-                            self._bucket_store.moe_extra_dp_pg_rank
-                        ]
-                    )
-                else:
-                    master_param.copy_(
-                        working_param.chunk(self._bucket_store.zero_world_size)[self._bucket_store.zero_local_rank]
-                    )
-        if hasattr(self, "master_moe_params"):
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                master_moe_param.copy_(working_moe_param)
-
-    def remove_hooks(self) -> None:
-        """remove the registered hooks
-
-        Args:
-            plugin (LowLevelZeroPlugin): the plugin to bound this method.
- """ - for group_id in range(self.num_param_groups): - param_group = self._working_param_groups[group_id] - for param in param_group: - if param.requires_grad: - assert hasattr(param, "_grad_handle") - param._grad_handle.remove() - delattr(param, "_grad_handle") - - def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: - return self._param_store.working_to_master_param - - def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: - if hasattr(self, "moe_master_to_working_map"): - return { - **self._param_store.master_to_working_param, - **self.moe_master_to_working_map, - } - return self._param_store.master_to_working_param - - def get_param_padding_map(self) -> Dict[int, torch.Tensor]: - return self._param_store.get_padding_map() + # another way of doing this is to reassign tensor.grad, however this won't apply for zero-2 + # since the shape doesn't match + def get_param_grad(self, param): + strategy = self.param2strategy[param] + return strategy.get_param_grad(param) diff --git a/colossalai/zero/low_level/low_level_strategy.py b/colossalai/zero/low_level/low_level_strategy.py new file mode 100644 index 000000000000..16effac9c80a --- /dev/null +++ b/colossalai/zero/low_level/low_level_strategy.py @@ -0,0 +1,533 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.accelerator import get_accelerator +from colossalai.tensor.moe_tensor.api import is_moe_tensor + +from ._utils import flatten, release_param_grad, sync_tensor +from .bookkeeping import BucketStore, GradientStore, ParameterStore + + +class LowLevelOptStrategyBase(ABC): + """ + Base class for low-level optimization strategies, this is to reduce the + coupling between different param group and corresponding process group + + This class contains necessary stores/data for optimizer: + 1. params bucket + 2. grads bucket + 3. 
diff --git a/colossalai/zero/low_level/low_level_strategy.py b/colossalai/zero/low_level/low_level_strategy.py
new file mode 100644
index 000000000000..16effac9c80a
--- /dev/null
+++ b/colossalai/zero/low_level/low_level_strategy.py
@@ -0,0 +1,533 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
+from abc import ABC, abstractmethod
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from colossalai.accelerator import get_accelerator
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+
+from ._utils import flatten, release_param_grad, sync_tensor
+from .bookkeeping import BucketStore, GradientStore, ParameterStore
+
+
+class LowLevelOptStrategyBase(ABC):
+    """
+    Base class for low-level optimization strategies; this reduces the
+    coupling between different param groups and their corresponding process groups.
+
+    This class contains the stores/data the optimizer needs:
+    1. params bucket
+    2. grads bucket
+    3. reduce buckets
+    and the methods to do communication
+    """
+
+    # the store before refactoring supports multiple param groups
+    # but currently only one is used
+    DEFAULT_STORE_GROUP_ID = 0
+
+    def __init__(
+        self,
+        param_group,
+        process_group,
+        master_weights,
+        partition_grad,
+        cpu_offload,
+        overlap_communication,
+        reduce_bucket_size,
+        communication_dtype,
+    ):
+        # param_group that current strategy is working on
+        self.param_group = param_group
+        self._dtype = self.param_group["params"][0].dtype
+
+        if process_group is None:  # if process_group is none, convert to default explicitly
+            process_group = dist.group.WORLD
+
+        self.process_group = process_group
+
+        # if process_group is none, will use the default one
+        self._local_rank = dist.get_rank(group=self.process_group)
+        self._world_size = dist.get_world_size(group=self.process_group)
+
+        # master weights copy
+        self._master_weights = master_weights
+
+        self._cpu_offload = cpu_offload
+
+        # stage 2
+        self._partition_grad = partition_grad
+
+        # ParameterStore will manage the tensor buffers used for zero
+        # it will not manage the tensors used by mixed precision training
+        self._param_store = ParameterStore(process_group)
+        self._grad_store = GradientStore(process_group, partition_grad=partition_grad)
+        self._bucket_store = BucketStore(process_group)
+
+        # working and master params for mixed precision training
+        group_params = []
+        for param in param_group["params"]:
+            if param.requires_grad:
+                group_params.append(param)
+        master_param_current_rank = self._create_master_param_current_rank(group_params)
+        param_group["params"] = master_param_current_rank
+        self.working_param_group: List[torch.Tensor] = group_params
+        self.master_param_group: List[torch.Tensor] = master_param_current_rank
+
+        # by default this shouldn't be manipulated
+        self.require_grad_sync = True
+
+        # communication params
+        self._overlap_communication = overlap_communication
+        self._reduce_bucket_size = reduce_bucket_size
+        self._communication_dtype = communication_dtype
+
+        # initialize communication stream for
+        # communication-computation overlapping
+        if self._overlap_communication:
+            self._comm_stream = get_accelerator().Stream()
+
+        # reduction hook is only used if overlapping communication
+        # or stage 2 is used
+        # if it is stage 1 without overlapping, no hook will be attached
+        if self._overlap_communication or self._partition_grad:
+            # we iterate over the working params
+            # on each param, we register a hook to its AccumulateGrad object
+            param_group = self.working_param_group
+            for param in param_group:
+                if param.requires_grad:
+
+                    def _grad_handler(grad, param):
+                        # if run with no_sync context, would not sync grad when backward
+                        if self.require_grad_sync:
+                            self._add_to_bucket(param)
+                        return grad
+
+                    param.register_hook(partial(_grad_handler, param=param))
+
+    def _create_master_param_current_rank(self, param_list):
+        # split each param evenly by world size
+        params_current_rank = []
+        device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
+
+        for param in param_list:
+            padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
+            self._param_store.record_param_padding_size(param, padding_size)
+
+            with torch.no_grad():
+                if padding_size > 0:
+                    padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+                    # reset working params' ptr when no master weights
+                    if self._master_weights == False:
+                        param.data = padding_param[: param.numel()].view(param.shape)
+                else:
+                    padding_param = param.data.view(-1)
+
+                splited_params = padding_param.split(padding_param.numel() // self._world_size)
+                splited_params = splited_params[self._local_rank]
+
+                # use fp32 when master_weights is True
+                if self._master_weights is True:
+                    splited_param_current_rank = splited_params.detach().float().to(device)
+                else:
+                    splited_param_current_rank = splited_params
+
+                params_current_rank.append(splited_param_current_rank)
+                self._param_store.link_master_and_working_param(splited_param_current_rank, param)
+
+        return params_current_rank
+
+    def _update_unpartitioned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None:
+        for rank, grad_list in enumerate(origin_grad_list):
+            sync_tensor(flat_grad_list[rank], grad_list)
+            for grad in grad_list:
+                param_id = self._bucket_store.get_param_id_of_grad(grad)
+                self._add_grad(grad, self._world_size, group_id, param_id, rank)
+
+    def _update_partitioned_grad(
+        self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+    ) -> None:
+        sync_tensor(flat_grad, origin_grad_list)
+        for grad in origin_grad_list:
+            param_id = self._bucket_store.get_param_id_of_grad(grad)
+            self._add_grad(grad, partition_num, group_id, param_id)
+
+    def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+        if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
+            self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+        else:
+            self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
+
+    def _add_to_bucket(self, param):
+        param_size = param.numel()
+
+        # check if the bucket is full
+        # if full, will reduce the grads already in the bucket
+        # or got a grad of param from another group
+        # after reduction, the bucket will be empty
+        if (
+            self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
+            or LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID != self._bucket_store.current_group_id
+        ):
+            self._run_reduction()
+
+        padding_size = self._param_store.get_param_padding_size(param)
+        self._bucket_store.add_param_grad(LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, param, padding_size)
+
+    def _reduce_grad(self):
+        # if not overlapping communication (no reduction hook is attached) when zero1
+        # we need to manually reduce these gradients
+        if not self._partition_grad and not self._overlap_communication:
+            self._sync_grad()
+        else:
+            self._run_reduction()
+
+    def _sync_grad(self):
+        param_group = self.working_param_group
+        for param in param_group:
+            if param.requires_grad and param.grad is not None:
+                self._add_to_bucket(param)
+
+        self._run_reduction()
+
+    def _run_reduction(self):
+        if self._bucket_store.num_elements_in_bucket() <= 0:
+            return
+
+        self._bucket_store.build_grad_in_bucket()
+
+        flat_grads = self._bucket_store.get_flatten_grad()
+        flat_grads /= self._world_size
+
+        # ready to add other tensors to bucket
+        self._bucket_store.reset_num_elements_in_bucket()
+
+        if self._overlap_communication:
+            stream = self._comm_stream
+            # in case of the memory being reused in the default stream
+            flat_grads.record_stream(stream)
+            # waiting for ops in the default stream finishing
+            stream.wait_stream(get_accelerator().current_stream())
+        else:
+            stream = get_accelerator().current_stream()
+
+        with get_accelerator().stream(stream):
+            group_id = self._bucket_store.current_group_id
group_id should be 0" + + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) + + if not self._partition_grad: + dist.all_reduce(flat_grads, group=self.process_group) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + grad_in_bucket = self._bucket_store.get_grad() + self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id) + else: + flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.process_group) + + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) + + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1) + + self._bucket_store.reset() + + ###################################################################### + # interfaces for child classes to manipulate the params, grads and buckets (and their stores) + @property + def master_params(self): + return self.master_param_group + + @property + def working_params(self): + return self.working_param_group + + @property + def working_grads(self): + return self._grad_store.get_working_grads_by_group_id(LowLevelOptStrategyBase.DEFAULT_STORE_GROUP_ID) + + def get_param_padding_size(self, param): + return self._param_store.get_param_padding_size(param) + + def get_working_param_grads(self, working_param): + return self._grad_store.get_partitioned_gradients_by_param_id( + LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, id(working_param) + ) + + def update_master_params(self, working_param): + for working_param, master_param in zip(self.working_params, self.master_params): + padding_size = self.get_param_padding_size(working_param) + if padding_size > 0: + working_param = torch.nn.functional.pad(working_param, [0, padding_size]) + master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + + def get_grad_norm(self, norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + gradients (List[Tensor]): The gradients to compute norm + norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. + + Returns: + float: The total norm of given gradients + """ + gradients = self.working_grads + + norm_type = float(norm_type) + if norm_type == torch.inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float + ) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.process_group) + total_norm = total_norm_cuda.item() + + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + total_norm_exponentiated += grad_norm_exponentiated + + # Sum across all model parallel GPUs. 
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
+            torch.distributed.all_reduce(
+                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.process_group
+            )
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
+    def zero_grad(self, set_to_none=True):
+        param_group = self.working_param_group
+        for param in param_group:
+            if set_to_none:
+                param.grad = None
+            else:
+                if param.grad is not None:
+                    param.grad.detach()
+                    param.grad.zero_()
+
+    def zero_working_grad(self):
+        self._grad_store.reset_grads_by_group_id(LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID)
+
+    def allgather_optim_state(self, master_param, master_state) -> torch.Tensor:
+        device = get_accelerator().get_current_device()
+        working_param = self._param_store.master_to_working_param[id(master_param)]
+        gather_tensor = [
+            torch.zeros(master_state.shape, device=device, dtype=master_state.dtype) for _ in range(self._world_size)
+        ]
+        dist.all_gather(gather_tensor, master_state, group=self.process_group)
+        param_state = torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
+        return param_state
+
+    def scatter_optim_state(self, optim_state):
+        with torch.no_grad():
+            param_group = self.param_group
+            for param in param_group["params"]:
+                state = optim_state[param]
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
+                        v = v.flatten()
+                        if padding_size > 0:
+                            v = torch.nn.functional.pad(v, [0, padding_size])
+                        v_list = v.split(v.numel() // self._world_size)
+                        state[k] = v_list[self._local_rank].detach().clone()
+
+    def get_param_grad(self, param):
+        grad_maybe_partial = self.get_working_param_grads(param)
+        if len(grad_maybe_partial) == 0:
+            return None
+        if self._partition_grad:
+            tensor_list = [torch.empty_like(grad_maybe_partial[0]) for _ in range(self._world_size)]
+            dist.all_gather(tensor_list, grad_maybe_partial[0], group=self.process_group)
+            grad_flat = torch.cat(tensor_list, dim=0)
+        else:
+            grad_flat = torch.cat(grad_maybe_partial, dim=0)
+        return grad_flat[: param.numel()].reshape_as(param)
+
+    ######################################################################
+    # interfaces for child classes to implement, which will be called at
+    # corresponding stage in LowLevelOptimizer
+
+    @abstractmethod
+    def pre_backward(self, loss, retain_graph=False) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def post_backward(self) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def pre_backward_by_grad(self, tensor, grad) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def post_backward_by_grad(self) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def pre_step(self) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def post_step(self) -> None:
+        raise NotImplementedError
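
Every concrete strategy has to fill in the six hooks above; the base class supplies the stores and the reduction machinery. An illustrative skeleton of a custom subclass, purely hypothetical and deliberately minimal — it reuses the base reduction path but is not a working replacement for the shipped `LowLevelOptStrategy`:

```python
class LoggingStrategy(LowLevelOptStrategyBase):
    """Hypothetical subclass: reuse base reduction, log around step."""

    def pre_backward(self, loss, retain_graph=False):
        pass  # e.g. sanity checks, as LowLevelOptStrategy does

    def post_backward(self):
        self._reduce_grad()  # base-class bucket reduction

    def pre_backward_by_grad(self, tensor, grad):
        pass

    def post_backward_by_grad(self):
        self.post_backward()

    def pre_step(self):
        print(f"[rank {self._local_rank}] grads ready: {len(self.working_grads)}")

    def post_step(self):
        pass
```
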
+
+
+class LowLevelOptStrategy(LowLevelOptStrategyBase):
+    def __init__(
+        self,
+        param_group: Dict[str, Any],  # from optimizer.param_groups
+        process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
+        reduce_bucket_size: int = 1024 * 1024,  # communication
+        communication_dtype: Optional[torch.dtype] = None,
+        overlap_communication: bool = False,
+        partition_grad: bool = False,  # stage 2 flag
+        cpu_offload: bool = False,  # cpu offload
+        master_weights: bool = True,  # master weights
+    ):
+        super().__init__(
+            param_group=param_group,
+            process_group=process_group,
+            cpu_offload=cpu_offload,
+            partition_grad=partition_grad,
+            master_weights=master_weights,
+            reduce_bucket_size=reduce_bucket_size,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+        )
+
+        # temporary variables
+        self.__saved_master_params = None
+        self.__saved_working_params = None
+
+    ######################################################################
+    # pre-backward: sanity check
+    # post-backward: deal with grads
+
+    def pre_backward(self, loss, retain_graph=False):
+        assert not (
+            self._partition_grad and not self.require_grad_sync
+        ), "ZeRO2(partition_grad) and no_sync are not compatible"
+
+    def post_backward(self):
+        if not self.require_grad_sync:
+            return
+
+        self._reduce_grad()
+
+        # clear reduced grads
+        if self._overlap_communication:
+            get_accelerator().synchronize()
+
+        for param in self.working_param_group:
+            assert param.grad is None, "unreduced grads are not removed"
+
+    def pre_backward_by_grad(self, tensor, grad):
+        assert not (
+            self._partition_grad and not self.require_grad_sync
+        ), "ZeRO2(partition_grad) and no_sync are not compatible"
+
+    def post_backward_by_grad(self):
+        self.post_backward()
+
+    def pre_step(self) -> None:
+        # sometimes not all params are 'really' working
+        # for instance, when layer drop, the dropped layer has no grad
+        # and should not be updated
+        grad_index = 0 if self._partition_grad else self._local_rank
+        real_master_params, real_working_params = [], []
+        for working_param, master_param in zip(self.working_param_group, self.master_param_group):
+            # if a working param requires grad and has no grad
+            # it is not 'really' working, e.g. the dropped layer
+            # else the splited grad should be attached to the splited param
+            grads = self.get_working_param_grads(working_param)
+            if len(grads) > 0:
+                real_master_params.append(master_param)
+                real_working_params.append(working_param)
+                grad = grads[grad_index]
+                # no need to copy fp32 grad if master_weights is False
+                if self._master_weights:
+                    grad = grad.to(master_param.dtype).to(master_param.device)
+                # TODO @botbw: in original code, grad_partition_groups is used
+                # however it seems it's the same as working_grads as long as
+                # we update the grads in store correctly
+                grads[grad_index] = master_param.grad = grad
+
+        # update the params in the optimizer and the working partition
+        # @botbw: to me, it seems like the original author only wants to keep the "real_xxx_params" when doing the
+        # optimizer computation, and add the "non real_xxx_params" back after, since we might still need them for
+        # checkpointing; not sure if it's necessary since None grads don't really bring lots of overhead
+        self.__saved_working_params = self.working_param_group
+        self.__saved_master_params = self.master_param_group
+        self.working_param_group = real_working_params
+        self.master_param_group = self.param_group["params"] = real_master_params
+
+    def post_step(self):
+        release_param_grad(self.master_param_group)
+
+        # update working partition updated by the current rank
+        device = get_accelerator().get_current_device()
+        for working_param, master_param in zip(self.working_param_group, self.master_param_group):
+            all_splited_param = [
+                torch.zeros(master_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
+            ]
+            dist.all_gather(all_splited_param, master_param.to(device).to(self._dtype), group=self.process_group)
+
+
+class MoeZeroStrategy(LowLevelOptStrategy):
+    def __init__(
+        self,
+        param_group: Dict[str, Any],  # from optimizer.param_groups
+        reduce_bucket_size: int = 1024 * 1024,  # communication
+        communication_dtype: Optional[torch.dtype] = None,
+        overlap_communication: bool = False,
+        partition_grad: bool = False,  # stage 2 flag
+        cpu_offload: bool = False,  # cpu offload
+        process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
+        master_weights: bool = True,  # master weights
+    ):
+        for param in param_group["params"]:
+            if not is_moe_tensor(param):
+                raise ValueError(f"MoeZeroStrategy only accepts MoE parameters, but got {type(param)}")
+
+        super().__init__(
+            param_group=param_group,
+            process_group=process_group,
+            cpu_offload=cpu_offload,
+            partition_grad=partition_grad,
+            master_weights=master_weights,
+            reduce_bucket_size=reduce_bucket_size,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+        )
+
+    # def get_param_grad(self, param):  # TODO @botbw: discuss whether it's more intuitive to return the grad of the sharded or the full moe tensor
+    #     moe_partial_grad = super().get_param_grad(param)
+    #     moe_grad_list = [torch.empty_like(moe_partial_grad) for _ in range(self._world_size)]
+    #     dist.all_gather(moe_grad_list, moe_partial_grad, group=self.process_group)
+    #     moe_grad = torch.cat(moe_grad_list, dim=0).reshape(param.shape[0] * self._world_size, *param.shape[1:])
+    #     return moe_grad
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 0811f28bc8d7..131932dcb3b3 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -115,7 +115,6 @@ def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) ->
     for (local_name, local_param), (ep_name, ep_param) in zip(
         local_model.named_parameters(), ep_model.named_parameters()
     ):
-        assert local_name in ep_name, print(f"{local_name} != {ep_name}")
         if "experts" not in local_name:
             if assert_grad_flag:
                 assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
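Because `MoeZeroStrategy.get_param_grad` is left commented out above (pending discussion), `get_param_grad` returns at most the rank-local shard of an expert-parallel gradient, and the tests below reassemble the full gradient themselves. A sketch of that reassembly, assuming `ep_group`/`ep_size` come from `MOE_MANAGER` exactly as in the tests:

```python
import torch
import torch.distributed as dist


def gather_moe_grad(partial_grad: torch.Tensor, ep_group, ep_size: int) -> torch.Tensor:
    # each EP rank holds the expert rows it owns, so the shards are
    # gathered over the expert-parallel group and stacked along dim 0
    grad_list = [torch.empty_like(partial_grad) for _ in range(ep_size)]
    dist.all_gather(grad_list, partial_grad, group=ep_group)
    return torch.cat(grad_list, dim=0)  # full (num_experts, ...) gradient
```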
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py
new file mode 100644
index 000000000000..c0722881bfcd
--- /dev/null
+++ b/tests/test_moe/test_moe_zero_fwd_bwd.py
@@ -0,0 +1,107 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import colossalai
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
+from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, sync_local_from_ep
+
+
+def run_zero_test(local_rank):
+    world_size = dist.get_world_size()
+    assert world_size >= 4, f"{world_size=}: at least 4 processes are required for this test (ep=2, moe_dp=2)"
+    criterion = torch.nn.CrossEntropyLoss()
+
+    ep_size = 2
+    extra_dp_size = world_size // ep_size
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=extra_dp_size, fixed_ep_size=ep_size, fixed_pp_size=1)
+
+    zero_model = MoeModel().bfloat16().cuda()
+
+    dp_group = dist.group.WORLD
+    ep_group = MOE_MANAGER.parallel_info_dict[ep_size].ep_group
+    moe_extra_dp_group = MOE_MANAGER.parallel_info_dict[ep_size].dp_group
+
+    zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters()))
+    moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters()))
+    print(f"{len(zero_params)=}, {len(moe_params)=}")
+    lr = 1e-3
+    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=lr)
+    zero_optimizer.param_groups.clear()
+    zero_optimizer.add_param_group({"params": zero_params})
+    zero_optimizer.add_param_group({"params": moe_params})
+
+    strategies = [
+        LowLevelOptStrategy(
+            param_group=zero_optimizer.param_groups[0],
+            process_group=dp_group,
+            overlap_communication=False,
+            partition_grad=True,
+        ),
+        MoeZeroStrategy(
+            param_group=zero_optimizer.param_groups[1],
+            process_group=moe_extra_dp_group,
+            overlap_communication=True,
+            partition_grad=False,
+        ),
+    ]
+    zero_optimizer = LowLevelZeroOptimizer(
+        zero_optimizer,
+        strategies,
+    )
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel=None)
+    ddp_model = DDP(MoeModel().bfloat16().cuda(), static_graph=True)
+    delete_moe_info(ddp_model)
+    torch_optim = torch.optim.SGD(ddp_model.parameters(), lr=lr)
+    sync_local_from_ep(ddp_model, zero_model)
+
+    seed_all(42 + local_rank)
+    data = torch.randn(16, 4).bfloat16().cuda()
+    label = torch.randint(0, 4, (16,)).cuda()
+
+    ddp_model.train()
+    zero_model.train()
+    ddp_out = criterion(ddp_model(data), label).float()
+    zero_out = criterion(zero_model(data), label).float()
+    assert torch.allclose(ddp_out, zero_out)
+    print(f"{local_rank=} {ddp_out.mean()=}")
+
+    ddp_out.backward()
+    zero_optimizer.backward(zero_out)
+
+    for (zero_name, zero_param), (ddp_name, ddp_param) in zip(
+        zero_model.named_parameters(), ddp_model.named_parameters()
+    ):
+        torch_grad = ddp_param.grad
+        zero_grad = zero_optimizer.get_param_grad(zero_param)
+        if is_moe_tensor(zero_param):
+            moe_grad_list = [torch.empty_like(zero_grad) for _ in range(ep_size)]
+            dist.all_gather(moe_grad_list, zero_grad, group=ep_group)
+            zero_grad = torch.cat(moe_grad_list, dim=0)
+        loose_close(torch_grad, zero_grad, dtype=torch_grad.dtype)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_zero_test(rank)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_moe_zero_fwd_bwd(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_moe_zero_fwd_bwd(world_size=4)
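The grad comparisons above go through `loose_close` from `tests/test_moe/moe_utils.py` rather than `torch.allclose` with defaults, since bf16 training leaves more slack than fp32. A plausible minimal version of that helper — the exact tolerances are assumptions, not copied from the repository:

```python
import torch


def loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype = torch.float32) -> None:
    # widen the tolerance for low-precision dtypes before comparing in fp32
    rtol, atol = {
        torch.bfloat16: (4e-3, 4e-3),
        torch.float16: (2e-3, 2e-3),
    }.get(dtype, (1e-5, 1e-8))
    assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol)
```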
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
new file mode 100644
index 000000000000..3bbd90fd6aac
--- /dev/null
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -0,0 +1,124 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import colossalai
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
+from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, sync_local_from_ep
+
+
+def run_zero_test(local_rank):
+    world_size = dist.get_world_size()
+    assert world_size >= 4, f"{world_size=}: at least 4 processes are required for this test (ep=2, moe_dp=2)"
+    criterion = torch.nn.CrossEntropyLoss()
+
+    ep_size = 2
+    extra_dp_size = world_size // ep_size
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=extra_dp_size, fixed_ep_size=ep_size, fixed_pp_size=1)
+
+    zero_model = MoeModel().bfloat16().cuda()
+
+    dp_group = dist.group.WORLD
+    ep_group = MOE_MANAGER.parallel_info_dict[ep_size].ep_group
+    moe_extra_dp_group = MOE_MANAGER.parallel_info_dict[ep_size].dp_group
+
+    zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters()))
+    moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters()))
+    print(f"{len(zero_params)=}, {len(moe_params)=}")
+    lr = 1e-3
+    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=lr)
+    zero_optimizer.param_groups.clear()
+    zero_optimizer.add_param_group({"params": zero_params})
+    zero_optimizer.add_param_group({"params": moe_params})
+
+    strategies = [
+        LowLevelOptStrategy(
+            param_group=zero_optimizer.param_groups[0],
+            process_group=dp_group,
+            overlap_communication=False,
+            partition_grad=True,
+        ),
+        MoeZeroStrategy(
+            param_group=zero_optimizer.param_groups[1],
+            process_group=moe_extra_dp_group,
+            overlap_communication=True,
+            partition_grad=False,
+        ),
+    ]
+    zero_optimizer = LowLevelZeroOptimizer(
+        zero_optimizer,
+        strategies,
+    )
+
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel=None)
+    ddp_model = DDP(MoeModel().bfloat16().cuda(), static_graph=True)
+    delete_moe_info(ddp_model)
+    torch_optim = torch.optim.SGD(ddp_model.parameters(), lr=lr)
+    sync_local_from_ep(ddp_model, zero_model)
+
+    seed_all(42 + local_rank)
+    data = torch.randn(16, 4).bfloat16().cuda()
+    label = torch.randint(0, 4, (16,)).cuda()
+
+    ddp_model.train()
+    zero_model.train()
+    ddp_out = criterion(ddp_model(data), label).float()
+    zero_out = criterion(zero_model(data), label).float()
+    assert torch.allclose(ddp_out, zero_out)
+    print(f"{local_rank=} {ddp_out.mean()=}")
+
+    ddp_out.backward()
+    zero_optimizer.backward(zero_out)
+
+    for (zero_name, zero_param), (ddp_name, ddp_param) in zip(
+        zero_model.named_parameters(), ddp_model.named_parameters()
+    ):
+        torch_grad = ddp_param.grad
+        zero_grad = zero_optimizer.get_param_grad(zero_param)
+        if is_moe_tensor(zero_param):
+            moe_grad_list = [torch.empty_like(zero_grad) for _ in range(ep_size)]
+            dist.all_gather(moe_grad_list, zero_grad, group=ep_group)
+            zero_grad = torch.cat(moe_grad_list, dim=0)
+        loose_close(torch_grad, zero_grad, dtype=torch_grad.dtype)
+
+    torch_optim.step()
+    zero_optimizer.step()
+
+    for (zero_name, zero_param), (ddp_name, ddp_param) in zip(
+        zero_model.named_parameters(), ddp_model.named_parameters()
+    ):
+        if is_moe_tensor(zero_param):
+            moe_param_list = [torch.empty_like(zero_param) for _ in range(ep_size)]
+            dist.all_gather(moe_param_list, zero_param, group=ep_group)
+            zero_param = torch.cat(moe_param_list, dim=0)
+        assert ddp_param.dtype == zero_param.dtype
+        loose_close(
+            ddp_param,
+            zero_param,
+            dtype=ddp_param.dtype,
+        )
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_zero_test(rank)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_moe_zero_optim(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_moe_zero_optim(world_size=4)
diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index 06a29bd1dde2..23baf6617b9a 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -91,10 +91,13 @@ def exam_zero_1_2():
         zero2_optimizer.backward(zero2_output.mean().float())
 
     # check grad
-    z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
-    z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
-    for z1g, z2g in zip(z1g_list, z2g_list):
-        assert torch.equal(z1g, z2g)
+    for p1, p2 in zip(zero1_model.parameters(), zero2_model.parameters()):
+        g1 = zero1_optimizer.get_param_grad(p1)
+        g2 = zero2_optimizer.get_param_grad(p2)
+        if g1 is None or g2 is None:
+            assert g1 is None and g2 is None
+            continue
+        assert torch.allclose(g1, g2)
 
     # step
     zero1_optimizer.step()
@@ -102,7 +105,7 @@
 
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
-        assert torch.equal(z1p.data, z2p.data)
+        assert torch.allclose(z1p, z2p)
 
 
 @parameterize("dtype", [torch.float16, torch.bfloat16])
@@ -160,11 +163,11 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
 
     # check grad
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        if p.grad is not None:
-            zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
-            torch_grad_list = split_ddp_grad(p.grad, world_size)
-            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
-                loose_close(zero_grad, torch_grad, dtype=dtype)
+        zero_grad = zero_optimizer.get_param_grad(z1p)
+        if p.grad is None:
+            assert zero_grad is None
+            continue
+        loose_close(p.grad, zero_grad, dtype=dtype)
 
     # zero-dp step
     zero_optimizer.step()
@@ -174,7 +177,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
 
     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        loose_close(p.data, z1p.data, dtype=dtype)
+        loose_close(p, z1p, dtype=dtype)
 
 
 def run_dist(rank, world_size, port):
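The switch from `torch.equal` to `torch.allclose` in these checks reflects that `get_param_grad` now reassembles the gradient from shards, and floating-point reductions performed in a different order can disagree in the low bits even when they are mathematically identical. A quick illustration of why bitwise equality is too strict:

```python
import torch

# Floating-point addition is not associative: the same three terms summed in
# two different orders give different fp32 results.
a = torch.tensor([1.0, 1e8, -1e8], dtype=torch.float32)
left = (a[0] + a[1]) + a[2]   # 0.0 -- the 1.0 is absorbed when added to 1e8 first
right = a[0] + (a[1] + a[2])  # 1.0 -- the large terms cancel first
print(left.item(), right.item())  # torch.equal would fail; allclose tolerates this
```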