Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions colossalai/gemini/chunk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@
from colossalai.utils import is_ddp_ignored


def safe_div(a, b):
    """Divide ``a`` by ``b``, returning 0 instead of raising on an undefined ratio.

    Used for percentage reporting (e.g. wasted-size ratios) where a zero
    numerator or a zero denominator simply means "nothing to report".

    Args:
        a: numerator.
        b: denominator.

    Returns:
        0 when ``a == 0`` or ``b == 0`` (the ratio is zero or undefined),
        otherwise ``a / b`` as a float.
    """
    # Original guard: a zero numerator short-circuits to 0 without touching b.
    if a == 0:
        return 0
    # A function called "safe_div" must not raise ZeroDivisionError: treat a
    # zero denominator as "no meaningful ratio" and return 0 as well.
    if b == 0:
        return 0
    return a / b


def init_chunk_manager(model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
search_range_mb: Optional[float] = None,
min_chunk_size_mb: Optional[float] = None,
filter_exlarge_params: Optional[bool] = None) -> ChunkManager:

kwargs_dict = dict()

if hidden_dim:
Expand Down Expand Up @@ -50,7 +55,7 @@ def init_chunk_manager(model: nn.Module,
if dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
sep='',
flush=True)
dist.barrier()
Expand Down
15 changes: 13 additions & 2 deletions colossalai/nn/optimizer/zero_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple

Expand Down Expand Up @@ -78,8 +79,16 @@ def __init__(self,
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"

params_list = [p for p in module.parameters() if not is_ddp_ignored(p)]
for p, fp32_p in zip(params_list, module.fp32_params):
ddp_param_list = []
for name, param in module.named_parameters():
if is_ddp_ignored(param):
if param.requires_grad:
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
"You should handle its optimizer update by yourself!")
else:
ddp_param_list.append(param)

for p, fp32_p in zip(ddp_param_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
Expand Down Expand Up @@ -290,6 +299,8 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter):
fake_params_list = list()

for param in group['params']:
if is_ddp_ignored(param):
continue
chunk16 = self.chunk_manager.get_chunk(param)
range_pair = get_range_pair(chunk16, param)
if range_pair[0] >= range_pair[1]:
Expand Down