Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def __init__(
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
Expand All @@ -315,6 +316,7 @@ def __init__(
partition_grad=(stage == 2),
cpu_offload=cpu_offload,
master_weights=master_weights,
fp8_communication=fp8_communication,
)
self.lora_enabled = False
self.verbose = verbose
Expand Down
13 changes: 9 additions & 4 deletions colossalai/zero/low_level/bookkeeping/tensor_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8


class TensorBucket:
def __init__(self, size):
Expand Down Expand Up @@ -61,11 +63,14 @@ def unflatten_and_copy(self, flat_tensor):
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)

def all_gather(self, group=None):
def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
dist.all_gather(buffers, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if fp8_communication:
all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert this change

unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
Expand Down
26 changes: 21 additions & 5 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8

from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(
dp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
fp8_communication: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)

Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(
self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
self._fp8_communication = fp8_communication

# gradient clipping
self._clip_grad_norm = clip_grad_norm
Expand Down Expand Up @@ -323,7 +326,10 @@ def _run_reduction(self):
flat_grads = flat_grads.to(self._communication_dtype)

if not self._partition_grads:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if self._fp8_communication:
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
else:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)

Expand All @@ -333,7 +339,14 @@ def _run_reduction(self):
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
if self._fp8_communication:
reduce_scatter_fp8(
recieved_grad,
flat_grads_list,
group=bucket_store.torch_pg,
)
else:
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)

if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
Expand Down Expand Up @@ -553,18 +566,21 @@ def step(self, closure=None):
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
if self._fp8_communication:
all_gather_into_tensor_flat_fp8(buffer_tensor, param_to_gather, pg, fp8_format="e4m3")
else:
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)

def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Expand Down
23 changes: 18 additions & 5 deletions tests/test_zero/test_low_level/test_zero1_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
return splited_grad


def exam_zero_1_2():
@parameterize("fp8_communication", [True, False])
def exam_zero_1_2(fp8_communication: bool):
"""
In this test, we want to test whether zero stage 1 and 2
deliver the same numerical results despite different communication
Expand All @@ -73,10 +74,18 @@ def exam_zero_1_2():
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(
zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True
zero1_optimizer,
overlap_communication=True,
initial_scale=128,
verbose=True,
fp8_communication=fp8_communication,
)
zero2_optimizer = LowLevelZeroOptimizer(
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128
zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=128,
fp8_communication=fp8_communication,
)
# create data
seed_all(2001 + local_rank)
Expand All @@ -97,15 +106,19 @@ def exam_zero_1_2():
if g1 is None or g2 is None:
assert g1 is None and g2 is None
continue
assert torch.allclose(g1, g2)
if fp8_communication:
loose_close(g1, g2, dtype=torch.float16)
else:
assert torch.allclose(g1, g2)

# step
zero1_optimizer.step()
zero2_optimizer.step()

# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.allclose(z1p, z2p)
if not fp8_communication:
assert torch.allclose(z1p, z2p)


@parameterize("dtype", [torch.float16, torch.bfloat16])
Expand Down