82 changes: 48 additions & 34 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -2,7 +2,7 @@
from collections import OrderedDict
from contextlib import nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union

import torch
import torch.distributed as dist
@@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
"""

def __init__(self,
@@ -59,7 +60,9 @@ def __init__(self,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True) -> None:
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
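Note: the new `mixed_precision` argument is validated up front. A minimal usage sketch of the changed constructor (assuming `gemini_manager` is an already-configured `GeminiManager`; the model is a placeholder):

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)
zero_model = ZeroDDP(
    model,
    gemini_manager,
    pin_memory=True,
    mixed_precision=torch.bfloat16,  # anything other than fp16/bf16 fails the assert above
)
```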
@@ -71,6 +74,7 @@ def __init__(self,
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision

self._logger = get_dist_logger()

@@ -96,34 +100,38 @@ def __init__(self,
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
super().__init__(module, process_group=ColoProcessGroup())
self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._cast_buffers()

def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):

r"""
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
"""

if memo is None:
memo = set()
self_non_persistent_set = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set

def _get_non_persistent_buffers_set(self,
module,
memo: Optional[Set[nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
r"""
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
"""

if memo is None:
memo = set()
self_non_persistent_set = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(
map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
remove_duplicate)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set
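For intuition, a small sketch of what the reformatted helper is expected to return (module and buffer names here are illustrative):

```python
import torch
import torch.nn as nn

class Sub(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False buffers are excluded from state_dict by PyTorch
        self.register_buffer('cache', torch.zeros(4), persistent=False)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()
        self.register_buffer('step', torch.zeros(1))  # persistent (the default)

# Calling the helper on Net() would yield {'sub.cache'}: fully-qualified
# names of non-persistent buffers, collected recursively with dotted
# prefixes, so they can be treated specially later.
```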

def _post_forward(self):
"""This function is only triggered for inference.
@@ -147,7 +155,7 @@ def forward(self, *args, **kwargs):
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"

args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
self.module.zero_grad(set_to_none=True)
if not grad_flag:
outputs = self._inference_forward(*args, **kwargs)
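The inputs are now cast to `self.mixed_precision` instead of a hard-coded `torch.half`. `_cast_float` is the library's helper; a rough sketch of the behavior it needs to have here (not its exact implementation):

```python
import torch

def _cast_float_sketch(obj, dtype: torch.dtype):
    """Recursively cast floating-point tensors in nested containers."""
    if isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        return obj.to(dtype)
    if isinstance(obj, (list, tuple)):
        return type(obj)(_cast_float_sketch(x, dtype) for x in obj)
    if isinstance(obj, dict):
        return {k: _cast_float_sketch(v, dtype) for k, v in obj.items()}
    return obj  # ints, bools, and non-tensor leaves pass through unchanged
```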
@@ -566,14 +574,14 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi

# move ignored parameters to CUDA
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
continue

# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
# create a fp16 parameter
p.data = p.data.half()
p.data = p.data.to(self.mixed_precision)

# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
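This hunk is the standard master-weights scheme, now parameterized by dtype: the optimizer keeps an fp32 copy while forward/backward run in fp16 or bf16. Distilled to its essence (chunking and process groups omitted):

```python
import torch
import torch.nn as nn

def split_master_and_working(p: nn.Parameter, mixed_precision: torch.dtype) -> torch.Tensor:
    fp32_master = p.data.float()         # full-precision copy for the optimizer update
    p.data = p.data.to(mixed_precision)  # reduced-precision copy used in forward/backward
    return fp32_master
```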
@@ -609,7 +617,7 @@ def _cast_buffers(self):
buffer.materialize()
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()
buffer.data = buffer.to(self.mixed_precision)

def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
"""Convert parameter to ColoParameter in-place.
@@ -732,6 +740,7 @@ def __init__(self,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None,
mixed_precision: torch.dtype = torch.float16,
verbose: bool = False) -> None:
"""
A torch.Module wrapper using ZeRO-DP and Gemini.
@@ -772,5 +781,10 @@ def __init__(self,
strict_ddp_flag=strict_ddp_mode,
verbose=verbose)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
scatter_after_inference)
super().__init__(module,
gemini_manager,
pin_memory,
force_outputs_fp32,
strict_ddp_mode,
scatter_after_inference,
mixed_precision=mixed_precision)
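At the wrapper level the knob is just threaded through. A hedged end-to-end sketch (the arguments other than `mixed_precision` are assumptions about the full `GeminiDDP` signature, which this diff only shows in part):

```python
import torch
import torch.nn as nn

model = GeminiDDP(
    nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512)),
    device=torch.device('cuda'),
    placement_policy='auto',
    mixed_precision=torch.bfloat16,  # defaults to torch.float16
)
```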
92 changes: 43 additions & 49 deletions colossalai/zero/gemini/gemini_optimizer.py
@@ -1,15 +1,14 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import math
import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple

import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
@@ -22,9 +21,26 @@
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}


class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

def __init__(self,
module: ZeroDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.module = module

def check_local_overflow(self) -> bool:
return self.module.overflow_counter > 0

def pre_zero_grad(self) -> None:
self.module.overflow_counter = 0
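`check_local_overflow` only reports this rank's counter; the base `FP16MixedPrecisionMixin` is expected to combine that flag across ranks before deciding to skip the step, which is what the removed `_check_overflow` (further down) used to do inline. The equivalent collective check, as a sketch:

```python
import torch
import torch.distributed as dist

def any_rank_overflowed(local_overflow: bool) -> bool:
    # Reduce a per-rank overflow flag so every rank agrees on skipping the step.
    flag = torch.tensor([1.0 if local_overflow else 0.0], device='cuda')
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item() > 0
```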


class ZeroOptimizer(ColossalaiOptimizer):
@@ -79,7 +95,6 @@ def __init__(self,
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
self.optim_state = OptimState.UNSCALED
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
@@ -107,15 +122,20 @@ def __init__(self,

self.__init__optimizer()

# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
if module.mixed_precision is torch.float16:
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif module.mixed_precision is torch.bfloat16:
self.mix_precision_mixin = BF16MixedPrecisionMixin()
else:
raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")

self._logger = get_dist_logger()

self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
@@ -151,15 +171,6 @@ def _update_fp16_params(self):
for chunk16 in self.chunk16_set:
chunk16.optim_update()

def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(self.module.overflow_counter)

# all-reduce across global group
dist.all_reduce(self._found_overflow)

return self._found_overflow.item() > 0

def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
c16.l2_norm = None
@@ -190,40 +201,25 @@ def _calc_global_norm(self) -> float:
return global_norm

def _get_combined_scale(self):
loss_scale = 1

if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
div_scale = self.mix_precision_mixin.get_grad_div_scale()

combined_scale = loss_scale
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
combined_scale = clip * loss_scale
div_scale = clip * div_scale

if combined_scale == 1:
return -1
else:
return combined_scale

@property
def loss_scale(self):
return self.grad_scaler.scale.item()
return -1 if div_scale == 1.0 else div_scale
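The rewritten `_get_combined_scale` folds loss-scale division and gradient clipping into a single divisor; `-1` signals the optimizer that no division is needed. A worked example with illustrative numbers:

```python
# Under fp16, get_grad_div_scale() returns the current loss scale, e.g. 65536.0.
# Suppose the true (unscaled) gradient norm is total_norm / 65536 = 8.0 and
# max_norm = 1.0. Then clip = (8.0 + 1e-6) / 1.0 ≈ 8.0, so the combined
# div_scale ≈ 65536 * 8: a single division both unscales and clips the grads.
# Under bf16 the div scale starts at 1.0, so the method returns -1 unless
# clipping is needed, in which case div_scale == clip.
```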

def zero_grad(self, *args, **kwargs):
self.module.overflow_counter = 0
self.mix_precision_mixin.pre_zero_grad()
return self.optim.zero_grad(set_to_none=True)

def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
self._set_grad_ptr()

found_inf = self._check_overflow()
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
if self.mix_precision_mixin.should_skip_step():
if self.verbose:
self._logger.info(f'Found overflow. Skip step')
self._clear_global_norm() # clear recorded norm
@@ -234,7 +230,6 @@ def step(self, *args, **kwargs):
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)

ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self._register_states()
@@ -246,16 +241,15 @@ def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: flo
raise NotImplementedError

def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss)

def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
# This function is called except the last stage of pipeline parallel
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
self.optim_state = OptimState.SCALED
grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
self.module.backward_by_grad(tensor, grad)
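Taken together, the mixin now owns all precision bookkeeping. A sketch of a typical training step (setup assumed: `zero_model` is a `ZeroDDP` module, `optimizer` a `ZeroOptimizer` over its parameters, `criterion` and `dataloader` hypothetical):

```python
for inputs, targets in dataloader:
    optimizer.zero_grad()            # resets the overflow counter via pre_zero_grad()
    outputs = zero_model(inputs.cuda())
    loss = criterion(outputs, targets.cuda())
    optimizer.backward(loss)         # fp16: loss is scaled in pre_backward; bf16: pass-through
    optimizer.step()                 # skipped entirely if should_skip_step() reports overflow
```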

def _maybe_move_fp32_params(self):