Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.distributed_c10d import _get_default_group
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
Expand Down Expand Up @@ -352,7 +352,7 @@ def __init__(
max_norm: float = 0.0,
norm_type: float = 2.0,
tp_size: int = 1,
extra_dp_size:int = 1,
extra_dp_size: int = 1,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
Expand Down Expand Up @@ -412,10 +412,14 @@ def __init__(
self.extra_dp_size = extra_dp_size
world_size = dist.get_world_size()
self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size."
assert (
world_size == (self.tp_size * self.extra_dp_size) * self.zero_size
), f"The global group size can't be evenly divided by the subgroup size."

self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
self.zero_group = (
self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
)
self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None

Expand Down Expand Up @@ -469,7 +473,13 @@ def configure(
shardformer = ShardFormer(self.shard_config)
model, _ = shardformer.optimize(model)

model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)
model = GeminiDDP(
model,
**self.gemini_config,
zero_group=self.zero_group,
extra_dp_group=self.extra_dp_group,
verbose=self.verbose,
)

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
Expand Down
2 changes: 1 addition & 1 deletion colossalai/zero/gemini/chunk/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,4 +649,4 @@ def init_grad_chunk(self) -> "Chunk":
self.grad_chunk.l2_norm = None
alloc_storage(self.grad_chunk.cuda_global_chunk)

return self.grad_chunk
return self.grad_chunk
2 changes: 1 addition & 1 deletion colossalai/zero/gemini/chunk/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,4 @@ def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
# Release accumulated_grad
free_storage(accumulated_grad)

return grad_chunk
return grad_chunk
2 changes: 1 addition & 1 deletion colossalai/zero/gemini/gemini_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,4 +905,4 @@ def state_dict_shard(
if block is not None:
yield block, block_size

yield sharder.current_block, sharder.current_block_size
yield sharder.current_block, sharder.current_block_size
4 changes: 3 additions & 1 deletion examples/language/llama2/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def empty_init():
model.config.num_hidden_layers,
model.config.hidden_size,
model.config.vocab_size,
args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
args.grad_checkpoint,
args.ignore_steps,
dp_world_size=dp_size,
)

optimizer = HybridAdam(model.parameters())
Expand Down
8 changes: 6 additions & 2 deletions examples/language/llama2/performance_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
batch_size, seq_len = input_ids.shape

self.num_samples += batch_size
checkpoint_activations_factor = (3 + int(self.enable_grad_checkpoint))
self.flop_megatron += (24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)) * (1. + (seq_len / (6. * self.hidden_size)) + (self.vocab_size / (16. * self.num_layers * self.hidden_size)))
checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint)
self.flop_megatron += (
24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)
) * (
1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size))
)
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))

def on_fit_end(self) -> None:
Expand Down