transformer_engine/pytorch/fp8.py (8 additions, 0 deletions)
@@ -177,6 +177,11 @@ def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]:
         """Return AMAX reduction wait handle of forward prop."""
         return cls.amax_reduce_handle_fwd

+    @classmethod
+    def set_amax_reduce_handle_fwd(cls, async_handle: Optional[torch.distributed.Work]) -> None:
+        """Set AMAX reduction wait handle of forward prop."""
+        cls.amax_reduce_handle_fwd = async_handle
+
     @classmethod
     def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None:
         """Sets up the function to call during autocast exit."""
@@ -407,6 +412,9 @@ def fp8_autocast_enter(
     ) -> None:
         """Set state and tracking variables for entry into FP8 region."""
         if cls.FP8_AUTOCAST_DEPTH == 0:
+            if cls.amax_reduce_handle_fwd is not None:
+                cls.amax_reduce_handle_fwd.wait()
+                cls.amax_reduce_handle_fwd = None
             if callable(cls.amax_forward_global_reduce_func):
                 cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func()  # pylint: disable=not-callable
             cls.delete_key_from_amax_buffer(forward=True)
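For context on the new wait in `fp8_autocast_enter`: `amax_reduce_handle_fwd` is the async work object returned by the non-blocking amax all-reduce launched at autocast exit, and this PR resolves it next to the points where the reduced buffer is actually consumed instead of in the `first_module` branch of `prepare_forward`. A minimal sketch of that handle pattern, with placeholder names (`launch_amax_reduction`, `amax_buffer`, `fp8_group`) that are not Transformer Engine internals:

```python
import torch
import torch.distributed as dist

def launch_amax_reduction(amax_buffer: torch.Tensor, fp8_group):
    """Start a non-blocking MAX all-reduce of the amax history; return its wait handle."""
    return dist.all_reduce(amax_buffer, op=dist.ReduceOp.MAX, group=fp8_group, async_op=True)

def consume_reduced_amax(handle, amax_buffer: torch.Tensor) -> torch.Tensor:
    """Block on the outstanding reduction, if any, before reading the reduced values."""
    if handle is not None:
        handle.wait()  # mirrors the waits added in fp8.py and base.py
    return amax_buffer
```

Deferring the wait this way lets the reduction overlap with whatever runs between iterations; the handle only has to be resolved before the reduced amax values are read.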
transformer_engine/pytorch/module/base.py (28 additions, 16 deletions)
@@ -69,6 +69,7 @@ def _prepare_backward(
     fp8_meta: Dict[str, Any],
     tp_group: dist_group_type,
     tp_size: int,
+    is_training: bool,
     name: str = ""
 ) -> Generator[None, None, None]:
     """Checks and prep for BWD."""
@@ -79,24 +80,29 @@
     _amax_reduce_handle_bwd = None

     # Update amax and scale; Skip all setup for global amax reduction
-    if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
-        # From previous iteration
-        FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
-        amax_and_scale_update(fp8_meta, False)
-        FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)
+    if is_training:
+        if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
+            # From previous iteration
+            FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
+            amax_and_scale_update(fp8_meta, False)
+            FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)

-        # Get new backward key.
-        fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
+            # Get new backward key.
+            fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)

-        FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
-    else:
-        amax_and_scale_update(fp8_meta, False)
+            FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
+        else:
+            amax_and_scale_update(fp8_meta, False)

     with torch.cuda.nvtx.range(name + " backward"):
         yield

-    if (fp8 and fp8_meta["recipe"].reduce_amax
-            and get_distributed_world_size(fp8_meta["fp8_group"]) > 1):
+    if (
+        fp8
+        and is_training
+        and fp8_meta["recipe"].reduce_amax
+        and get_distributed_world_size(fp8_meta["fp8_group"]) > 1
+    ):
         if fp8_meta["first_module"]:
             _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
                 fp8_meta,
@@ -567,11 +573,21 @@ def prepare_forward(
         if self.fp8_meta.get("update_amax_and_scale_fwd", False):
             if (self.fp8_meta["recipe"].reduce_amax
                     and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
+
+                # Get reduced amax
+                async_handle = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
+                if async_handle is not None:
+                    async_handle.wait()
+                    FP8GlobalStateManager.set_amax_reduce_handle_fwd(None)
                 FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
+
+                # Update scaling factor
                 amax_and_scale_update(
                     self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                 )
+
                 FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
+
             else:
                 amax_and_scale_update(
                     self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
@@ -583,10 +599,6 @@
                 and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
             self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
             if self.fp8_meta["first_module"]:
-                # Wait for the prior AMAX reduction to finish
-                amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
-                if amax_reduce_handle_fwd is not None:
-                    amax_reduce_handle_fwd.wait()
                 self.fp8_meta["autocast_id_fwd"] = (
                     FP8GlobalStateManager.new_fp8_context_id())
                 FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
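Net effect of the `base.py` changes: forward and backward amax bookkeeping now run only when the module is training, and the wait on the pending forward reduction happens right before `copy_amax_from_global_buffer` rather than in the `first_module` branch. A rough, single-GPU sketch of how the flag reaches these paths (it is just `torch.nn.Module.training`); FP8 execution assumes capable hardware, and a multi-rank `fp8_group` would additionally exercise the reduction branches above:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

fp8_recipe = DelayedScaling(reduce_amax=True)
layer = te.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device="cuda")

# Training mode: is_training=True, so _prepare_backward updates amax/scale
# and (with a multi-rank fp8_group) queues the backward amax reduction.
layer.train()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(x)
out.sum().backward()

# Eval mode: the flag threaded through this PR is False, so even if a
# backward pass is run, _prepare_backward skips the amax machinery.
layer.eval()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(x)
out.sum().backward()
```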
transformer_engine/pytorch/module/layernorm_linear.py (10 additions, 1 deletion)
@@ -79,6 +79,7 @@ def forward(
         return_layernorm_output: bool,
         return_layernorm_output_gathered: bool,
         is_grad_enabled: bool,
+        is_training: bool,
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
@@ -319,6 +320,7 @@ def forward(
         ctx.ub_name = ub_name
         ctx.requires_dgrad = inp.requires_grad
         ctx.normalization = normalization
+        ctx.is_training = is_training

         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
@@ -343,7 +345,12 @@ def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         with _prepare_backward(
-            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
+            ctx.fp8,
+            ctx.fp8_meta,
+            ctx.tp_group,
+            ctx.tp_size,
+            ctx.is_training,
+            name="_LayerNormLinear",
         ):
             (
                 inputmat,
@@ -645,6 +652,7 @@ def backward(
             None,
             None,
             None,
+            None,
         )


@@ -1103,6 +1111,7 @@ def forward(
             self.return_layernorm_output,
             self.return_layernorm_output_gathered,
             torch.is_grad_enabled(),
+            self.training,
             self.fwd_ln_sm_margin,
             self.bwd_ln_sm_margin,
             self.zero_centered_gamma,
transformer_engine/pytorch/module/layernorm_mlp.py (10 additions, 1 deletion)
@@ -109,6 +109,7 @@ def forward(
         bias_gelu_nvfusion: bool,
         set_parallel_mode: bool,
         is_grad_enabled: bool,
+        is_training: bool,
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
@@ -536,6 +537,7 @@ def forward(
         ctx.ub_overlap_ag = ub_overlap_ag
         ctx.requires_dgrad = inp.requires_grad
         ctx.normalization = normalization
+        ctx.is_training = is_training

         # Row Parallel Linear
         if ub_overlap_rs:
@@ -562,7 +564,12 @@ def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         with _prepare_backward(
-            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
+            ctx.fp8,
+            ctx.fp8_meta,
+            ctx.tp_group,
+            ctx.tp_size,
+            ctx.is_training,
+            name="_LayerNormMLP",
         ):
             (
                 inputmat,
@@ -1087,6 +1094,7 @@ def backward(
             None,
             None,
             None,
+            None,
         )


@@ -1484,6 +1492,7 @@ def forward(
             self.bias_gelu_nvfusion,
             self.set_parallel_mode,
             torch.is_grad_enabled(),
+            self.training,
             self.fwd_ln_sm_margin,
             self.bwd_ln_sm_margin,
             self.zero_centered_gamma,
transformer_engine/pytorch/module/linear.py (10 additions, 1 deletion)
@@ -77,6 +77,7 @@ def forward(
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
+        is_training: bool,
         primary_weights_in_fp8: bool,
         ub_overlap_rs: bool,
         ub_overlap_ag: bool,
@@ -313,6 +314,7 @@ def forward(
         ctx.ub_name = ub_name
         ctx.tp_size = tp_size
         ctx.requires_dgrad = inp.requires_grad
+        ctx.is_training = is_training

         # Row Parallel Linear
         if ub_overlap_rs:
@@ -331,7 +333,12 @@ def backward(
         ctx, grad_output: torch.Tensor
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         with _prepare_backward(
-            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
+            ctx.fp8,
+            ctx.fp8_meta,
+            ctx.tp_group,
+            ctx.tp_size,
+            ctx.is_training,
+            name="_Linear",
         ):
             (
                 inputmat,
@@ -542,6 +549,7 @@ def backward(
             None,
             None,
             None,
+            None,
         )


@@ -915,6 +923,7 @@ def forward(
             self.activation_dtype,
             self.parallel_mode,
             torch.is_grad_enabled(),
+            self.training,
             self.primary_weights_in_fp8,
             self.ub_overlap_rs,
             self.ub_overlap_ag,
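The same three-step pattern repeats across `linear.py`, `layernorm_linear.py`, and `layernorm_mlp.py`: `self.training` is passed positionally into the autograd function, stashed on `ctx` for use in `backward`, and `backward` gains one extra `None` because every input to `forward` needs a matching gradient slot. A generic sketch of that contract; `_ExampleLinear` is hypothetical and not part of Transformer Engine:

```python
import torch

class _ExampleLinear(torch.autograd.Function):
    """Toy autograd function showing how a non-tensor flag is threaded through."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, weight: torch.Tensor, is_training: bool):
        ctx.save_for_backward(inp, weight)
        ctx.is_training = is_training  # stashed for backward, like ctx.is_training above
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        inp, weight = ctx.saved_tensors
        # A hypothetical place to branch on the flag, analogous to _prepare_backward.
        grad_inp = grad_out @ weight
        grad_weight = grad_out.t() @ inp
        # One return value per forward input: inp, weight, is_training.
        return grad_inp, grad_weight, None

# Usage: the module's forward would call _ExampleLinear.apply(x, w, self.training).
```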