31 changes: 18 additions & 13 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -10,6 +10,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
@@ -34,8 +35,7 @@
SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}

DP_AXIS = 0
TP_AXIS = 1
ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2

def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
@@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase):
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
tp_size (int, optional): If 'tp_size' is set to a value greater than 1, the tensor parallelism strategy implemented in Shardformer is enabled; 'tp_size' determines the size of the tensor parallel process group. Defaults to 1.
extra_dp_size (int, optional): If 'extra_dp_size' is set to a value greater than 1, an additional data-parallel group is created to run with a DDP-like strategy. Defaults to 1.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
@@ -347,8 +347,8 @@ def __init__(
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
enable_tensor_parallelism: bool = False,
tp_size: int = 1,
extra_dp_size: int = 1,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
@@ -393,7 +393,7 @@ def __init__(
max_norm=max_norm,
norm_type=norm_type,
)
self.enable_tensor_parallelism = enable_tensor_parallelism
self.enable_tensor_parallelism = tp_size > 1
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
@@ -402,12 +402,17 @@ def __init__(
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose

self.tp_size = tp_size if self.enable_tensor_parallelism else 1
self.dp_size = dist.get_world_size() // self.tp_size
assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.tp_size = tp_size
self.extra_dp_size = extra_dp_size
world_size = dist.get_world_size()
self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size."

self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
enable_tensor_parallelism=self.enable_tensor_parallelism,
@@ -458,7 +463,7 @@ def configure(
shardformer = ShardFormer(self.shard_config)
model, _ = shardformer.optimize(model)

model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
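For reference, a minimal usage sketch of the reworked plugin on an 8-GPU run, assuming distributed has already been initialized (e.g. via colossalai.launch_from_torch); the sizes below are illustrative and must satisfy world_size == zero_size * extra_dp_size * tp_size:

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# 8 GPUs factored as 2 (ZeRO) x 2 (extra DP) x 2 (TP); zero_size is derived
# internally as world_size // (tp_size * extra_dp_size).
plugin = GeminiPlugin(
    precision="fp16",
    initial_scale=2**5,
    max_norm=1.0,
    tp_size=2,          # > 1 enables the Shardformer tensor-parallel path
    extra_dp_size=2,    # > 1 adds the DDP-like replication group
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)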
20 changes: 15 additions & 5 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -61,12 +61,13 @@ class Chunk:
def __init__(
self,
chunk_size: int,
process_group: ProcessGroup,
zero_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False,
extra_dp_group: ProcessGroup = None,
) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
@@ -76,7 +77,7 @@ def __init__(

Args:
chunk_size (int): the number of elements in the chunk
process_group (ProcessGroup): the process group of this chunk
zero_group (ProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
The default value is None, which is the current GPU
@@ -90,9 +91,11 @@ def __init__(
self.chunk_size = chunk_size
self.utilized_size = 0

self.torch_pg = process_group
self.torch_pg = zero_group
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
self.extra_dp_group = extra_dp_group
self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1

# the chunk size should be divisible by the dp degree
if not keep_gathered:
@@ -384,14 +387,20 @@ def reduce(self):
# just move cuda_global_chunk to cuda_shard
# the communication is not necessary
self.__scatter()
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)

free_storage(self.cuda_global_chunk)
self.is_gathered = False
@@ -608,10 +617,11 @@ def init_grad_chunk(self) -> "Chunk":
# grad chunk is not initialized
grad_chunk = Chunk(
chunk_size=self.chunk_size,
process_group=self.torch_pg,
zero_group=self.torch_pg,
dtype=self.dtype,
keep_gathered=self.keep_gathered,
pin_memory=self.pin_memory,
extra_dp_group=self.extra_dp_group,
)
grad_chunk.num_tensors = self.num_tensors
grad_chunk.utilized_size = self.utilized_size
@@ -640,4 +650,4 @@ def init_grad_chunk(self) -> "Chunk":
self.grad_chunk.l2_norm = None
alloc_storage(self.grad_chunk.cuda_global_chunk)

return self.grad_chunk
return self.grad_chunk
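The reduce() change above adds a second reduction stage: the chunk is first reduce-scattered (or all-reduced) within the ZeRO group as before, and the resulting shard is then all-reduced across the extra DP group. A standalone sketch of that two-stage pattern with plain torch.distributed calls (group handles and shapes are illustrative; assumes the chunk length is divisible by the ZeRO group size):

import torch
import torch.distributed as dist

def two_stage_reduce(cuda_global_chunk: torch.Tensor, zero_group, extra_dp_group=None) -> torch.Tensor:
    # Stage 1: reduce-scatter the full chunk across the ZeRO group; each rank keeps one shard.
    zero_size = dist.get_world_size(zero_group)
    shard = torch.empty(cuda_global_chunk.numel() // zero_size,
                        dtype=cuda_global_chunk.dtype, device=cuda_global_chunk.device)
    input_list = list(torch.chunk(cuda_global_chunk, chunks=zero_size, dim=0))
    dist.reduce_scatter(shard, input_list, group=zero_group)
    # Stage 2: sum the shard across the extra DP replicas, if that group exists.
    if extra_dp_group is not None:
        dist.all_reduce(shard, group=extra_dp_group)
    return shard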
10 changes: 6 additions & 4 deletions colossalai/zero/gemini/chunk/manager.py
@@ -38,7 +38,8 @@ def register_tensor(
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
zero_group: ProcessGroup,
extra_dp_group: ProcessGroup = None,
cpu_offload: bool = False,
pin_memory: bool = False,
) -> None:
@@ -76,15 +77,16 @@ def register_tensor(

if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
dp_size = dist.get_world_size(process_group)
dp_size = dist.get_world_size(zero_group)
chunk_size = chunk_size + (-chunk_size % dp_size)

chunk = Chunk(
chunk_size=chunk_size,
process_group=process_group,
zero_group=zero_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
extra_dp_group=extra_dp_group,
**chunk_kwargs,
)

@@ -288,4 +290,4 @@ def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
# Release accumulated_grad
free_storage(accumulated_grad)

return grad_chunk
return grad_chunk
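In register_tensor above, chunk_size + (-chunk_size % dp_size) rounds the chunk size up to the next multiple of the ZeRO group size so the chunk can be sharded evenly; a quick illustration with made-up numbers:

chunk_size, dp_size = 1_000_003, 4
padded = chunk_size + (-chunk_size % dp_size)   # (-1_000_003) % 4 == 1 in Python
assert padded == 1_000_004 and padded % dp_size == 0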
26 changes: 17 additions & 9 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -86,9 +86,10 @@ def __init__(
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
zero_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
master_weights: bool = True,
extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
@@ -105,7 +106,7 @@ def __init__(
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
process_group=zero_group,
verbose=verbose,
)
self.gemini_manager = GeminiManager(
@@ -128,7 +129,8 @@ def __init__(
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
self.dp_process_group = process_group or _get_default_group()
self.zero_group = zero_group or _get_default_group()
self.extra_dp_group = extra_dp_group

self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
@@ -377,8 +379,12 @@ def grad_handle(self, p, grad):
self.chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if self.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if self.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
self.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
@@ -733,7 +739,7 @@ def load_parameter(chunk_slice, data):
unexpected_keys.append(key)

def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
dp_world_size = dist.get_world_size(self.dp_process_group)
zero_world_size = dist.get_world_size(self.zero_group)
for p in param_order.generate():
self._preprocess_param(p)
assert type(p) is ColoParameter
@@ -753,8 +759,9 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi
self.chunk_manager.register_tensor(
tensor=p,
group_type="fp16_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
config_key=zero_world_size,
zero_group=self.zero_group,
extra_dp_group=self.extra_dp_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
@@ -767,8 +774,9 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi
self.chunk_manager.register_tensor(
tensor=fp32_p,
group_type="fp32_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
config_key=zero_world_size,
zero_group=self.zero_group,
extra_dp_group=self.extra_dp_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
@@ -881,4 +889,4 @@ def state_dict_shard(
if block is not None:
yield block, block_size

yield sharder.current_block, sharder.current_block_size
yield sharder.current_block, sharder.current_block_size
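The grad_handle change above divides the reduced gradient by the ZeRO group size and, when the extra DP group exists, again by extra_dp_size, so the final value is the mean over all zero_size * extra_dp_size replicas rather than their sum. A toy sanity check of that scaling (values are illustrative):

# every replica contributes the same local gradient g; the collectives sum them
g, zero_size, extra_dp_size = 0.5, 4, 2
summed = g * zero_size * extra_dp_size
averaged = summed / zero_size / extra_dp_size   # the two div_ calls in grad_handle
assert averaged == g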
24 changes: 16 additions & 8 deletions tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -1,5 +1,6 @@
from contextlib import nullcontext
from typing import Optional
import pytest

import torch
import torch.distributed as dist
@@ -17,14 +18,15 @@
from tests.kit.model_zoo import model_zoo


def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
try:
if init_method == "lazy":
ctx = LazyInitContext()
else:
ctx = nullcontext()
enable_all_optimization = True if enable_tensor_parallelism else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
enable_all_optimization = True if tp_size > 1 else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
@@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tenso

@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("enable_tensor_parallelism", [True, False])
def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
"""check gemini plugin over model zoo

Args:
@@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa

# TODO debug blip2 when using tp, something wrong with shift_logits's shape
if "transformers_blip2" in name:
enable_tensor_parallelism = False
tp_size = 1

err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
@@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)

@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
spawn(run_dist, 8, early_stop=early_stop)


if __name__ == "__main__":
test_gemini_plugin(early_stop=False)
test_gemini_plugin(early_stop=False)
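The new 8-rank test exercises the full 3D mesh: with zero_size=2 and tp_size=2 from the parametrization, run_fn derives the remaining axis from the launched world size. The arithmetic it relies on, spelled out:

world_size, zero_size, tp_size = 8, 2, 2
extra_dp_size = world_size // (zero_size * tp_size)   # == 2, so the plugin builds a 2 x 2 x 2 process group mesh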