34 changes: 22 additions & 12 deletions colossalai/shardformer/layer/_operation.py
@@ -95,7 +95,7 @@ def backward(ctx, grad_output):
total_input = total_input.view(-1, total_input.shape[-1])

if ctx.async_grad_allreduce and fp8_communication:
-_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
+_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
@@ -566,7 +566,7 @@ def forward(ctx, input_, process_group, dim, fp8_communication=False):
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
if fp8_communication:
-reduce_scatter_fp8(output, input_list, group=process_group)
+reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
else:
dist.reduce_scatter(output, input_list, group=process_group)

@@ -577,7 +577,12 @@ def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
fp8_communication = ctx.fp8_communication
-return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
+return (
+    _gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"),
+    None,
+    None,
+    None,
+)


class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -618,7 +623,7 @@ def forward(
)

else:
-input_parallel = _gather(input_, dim, process_group, fp8_communication)
+input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")

output = torch.matmul(input_parallel, weight)

@@ -641,7 +646,7 @@ def backward(ctx, grad_output):
bias = bias.view(bias.shape)

if not overlap:
-input_parallel = _gather(input_, dim, process_group, fp8_communication)
+input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")

total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
@@ -728,8 +733,13 @@ def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale

-# to_cast.append(grad_output.cpu().detach().numpy())
-return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
+return (
+    _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
+    None,
+    None,
+    None,
+    None,
+)


class _ReduceForward(torch.autograd.Function):
@@ -743,7 +753,7 @@ class _ReduceForward(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, process_group, fp8_communication=False):
-return _reduce(input_, process_group, fp8_communication)
+return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")

@staticmethod
def backward(ctx, grad_output):
@@ -768,7 +778,7 @@ def forward(ctx, input_, process_group, fp8_communication=False):
@staticmethod
def backward(ctx, grad_output):
fp8_communication = ctx.fp8_communication
-return _reduce(grad_output, ctx.process_group, fp8_communication), None, None
+return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
@@ -786,7 +796,7 @@ def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=
ctx.dim = dim
ctx.grad_scale = grad_scale

-return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
+return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")

@staticmethod
def backward(ctx, grad_output):
@@ -851,13 +861,13 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)


-def _reduce(input_, process_group, fp8_communication=False):
+def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
return input_
else:
if fp8_communication:
-all_reduce_fp8(input_, group=process_group)
+all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
else:
dist.all_reduce(input_, group=process_group)
return input_
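Note on the convention this patch applies (not part of the diff itself): forward-path collectives are tagged with fp8_format="e4m3" and backward/gradient collectives with fp8_format="e5m2", i.e. the higher-precision e4m3 encoding for activations and the wider-range e5m2 encoding for gradients. Below is a minimal usage sketch of the updated _reduce helper; the process group, tensor shapes, and the standalone-script framing are illustrative assumptions, not taken from the patch.

import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import _reduce

# Sketch only: assumes torch.distributed has already been initialized
# (e.g. via torchrun) and that a CUDA device is available.
group = dist.group.WORLD

activation = torch.randn(4, 8, device="cuda")
grad = torch.randn(4, 8, device="cuda")

# Forward-style all-reduce: e4m3 keeps more mantissa bits (precision) for activations.
activation = _reduce(activation, group, fp8_communication=True, fp8_format="e4m3")

# Backward-style all-reduce: e5m2 keeps more exponent bits (range) for gradients.
grad = _reduce(grad, group, fp8_communication=True, fp8_format="e5m2")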