Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
save_state_dict,
sharded_optimizer_loading_epilogue,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
Expand Down Expand Up @@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
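extra_dp_size (int, optional): Size of the extra data parallel dimension. When greater than 1, it must evenly divide the world size, and ranks are arranged as an (extra_dp_size, world_size // extra_dp_size) process group mesh. Defaults to 1.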
"""

def __init__(
Expand All @@ -358,11 +360,16 @@ def __init__(
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
extra_dp_size: int = 1,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
if extra_dp_size > 1:
assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
inner_dp_size = dist.get_world_size() // extra_dp_size
self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
self.stage = stage
self.precision = precision
self.zero_optim_kwargs = dict(
Expand All @@ -383,6 +390,9 @@ def __init__(
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
)
if extra_dp_size > 1:
self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
self.lora_enabled = False
self.verbose = verbose
self.logger = get_dist_logger()
Expand Down
42 changes: 41 additions & 1 deletion colossalai/zero/low_level/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from typing import Optional
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
Expand Down Expand Up @@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
# update the tensor data
for p, q in zip(tensor_list, updated_params):
p.data = q.data


def all_gather_into_flat_tensor_nd(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
async_op: bool = False,
):
if isinstance(group, dist.ProcessGroup):
group = (group,)
sizes = [dist.get_world_size(pg) for pg in group]
ranks = [dist.get_rank(pg) for pg in group]
for i, pg in list(enumerate(group))[::-1]:
if i == 0:
out = output_tensor
else:
prev_sizes = sizes[:i]
prev_ranks = ranks[:i]
chunks = output_tensor.chunk(np.prod(prev_sizes))
out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
input_tensor = out
return handle


def get_nd_world_size(group) -> int:
    """Return the total number of ranks covered by ``group``.

    Args:
        group: A single process group, or a tuple of process groups (one per
            mesh axis).

    Returns:
        The world size of ``group`` when it is a single group; otherwise the
        product of the world sizes of every group in the tuple.
    """
    if not isinstance(group, tuple):
        return dist.get_world_size(group)
    per_axis_sizes = [dist.get_world_size(pg) for pg in group]
    return int(np.prod(per_axis_sizes))


def get_nd_rank(group) -> int:
    """Return this process's rank within ``group`` as a plain ``int``.

    Args:
        group: A single process group, or a tuple of process groups (one per
            mesh axis).

    Returns:
        The rank in ``group`` when it is a single group; otherwise the
        row-major flat index of this process's per-group ranks within the
        grid whose dimensions are the per-group world sizes.
    """
    if not isinstance(group, tuple):
        return dist.get_rank(group)
    coords = tuple(dist.get_rank(group=pg) for pg in group)
    dims = [dist.get_world_size(pg) for pg in group]
    # np.ravel_multi_index yields a numpy integer; convert it so the declared
    # ``int`` return type holds for both branches.
    return int(np.ravel_multi_index(coords, dims))
15 changes: 12 additions & 3 deletions colossalai/zero/low_level/bookkeeping/base_store.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from typing import Tuple, Union

import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup


class BaseStore:
def __init__(self, torch_pg: ProcessGroup):
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)
def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]):
if isinstance(torch_pg, tuple):
self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg]
self._world_size = int(np.prod(self.sizes))
self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes)
else:
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)
self.sizes = [self._world_size]
self.torch_pg = torch_pg

@property
Expand Down
14 changes: 11 additions & 3 deletions colossalai/zero/low_level/bookkeeping/tensor_bucket.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.quantization.fp8 import all_gather_fp8
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd


class TensorBucket:
Expand Down Expand Up @@ -65,12 +67,18 @@ def unflatten_and_copy(self, flat_tensor):

def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if isinstance(group, tuple):
world_size = np.prod([dist.get_world_size(pg) for pg in group])
else:
world_size = dist.get_world_size(group)
buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
if fp8_communication:
# TODO: fit fp8
all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
# dist.all_gather_into_tensor(buffer, flat, group=group)
all_gather_into_flat_tensor_nd(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
Expand Down
Loading