27 changes: 21 additions & 6 deletions transformer_engine/pytorch/fp8.py
@@ -27,6 +27,7 @@
_amax_forward_global_reduce_func = None
_buffer_delete_key_fwd = None
_buffer_delete_key_bwd = None
_amax_reduce_handle_fwd = None
_is_fp8_available = None
_reason_for_no_fp8 = ""

@@ -73,6 +74,12 @@ def get_autocast_key(forward: bool = True) -> str:
return "autocast_id_bwd"


def get_amax_reduce_handle_fwd() -> Union[bool, None]:
"""Return AMAX reduction wait handle of forward prop."""
global _amax_reduce_handle_fwd
return _amax_reduce_handle_fwd


def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
"""Returns global fp8 buffer."""
return _global_fp8_buffer
Expand Down Expand Up @@ -264,6 +271,7 @@ def fp8_autocast(
global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
global _global_fp8_buffer, _buffer_delete_key_fwd
global _amax_reduce_handle_fwd
fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
try:
_FP8_ENABLED = enabled
@@ -287,7 +295,7 @@

if _FP8_AUTOCAST_DEPTH == 0:
if callable(_amax_forward_global_reduce_func):
_amax_forward_global_reduce_func()
_amax_reduce_handle_fwd = _amax_forward_global_reduce_func()
delete_key_from_amax_buffer(forward=True)
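
With this change, leaving the outermost fp8_autocast no longer blocks on the forward amax all-reduce: the handle returned by _amax_forward_global_reduce_func is stored in the module-level _amax_reduce_handle_fwd, and the wait is deferred until the first FP8 module of the next forward pass retrieves it via get_amax_reduce_handle_fwd (see prepare_forward in module.py below). A minimal sketch of that store-then-defer pattern, with illustrative names:

_pending_handle = None  # module-level state, analogous to _amax_reduce_handle_fwd

def launch_reduction(reduce_func):
    """Launch the collective at autocast exit and stash its handle."""
    global _pending_handle
    if callable(reduce_func):
        _pending_handle = reduce_func()  # a torch.distributed Work handle, or None

def wait_for_pending_reduction():
    """Run by the first FP8 module of the next iteration, before scales update."""
    global _pending_handle
    if _pending_handle is not None:
        _pending_handle.wait()
        _pending_handle = None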


@@ -521,16 +529,18 @@ def get_fp8_te_dtype(


def reduce_tensor_across_group_op_max(
tensor: torch.Tensor, group: dist_group_type
tensor: torch.Tensor, group: dist_group_type, async_op: bool
) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
wait_handle = torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=False,
async_op=async_op,
)
return wait_handle
return None
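
This relies on standard torch.distributed behavior: with async_op=True, all_reduce returns a Work handle immediately instead of blocking, and .wait() blocks until the collective has finished; with async_op=False the call blocks and returns None, so the later wait() is simply skipped. A self-contained sketch of launching and later consuming such a reduction (process-group setup assumed to happen elsewhere):

import torch
import torch.distributed as dist

def launch_max_reduce(tensor: torch.Tensor, group=None, async_op: bool = True):
    """Start a MAX all-reduce and return its handle instead of blocking."""
    if not dist.is_initialized():
        return None
    return dist.all_reduce(tensor, op=dist.ReduceOp.MAX, group=group, async_op=async_op)

# handle = launch_max_reduce(amax_tensor)  # other work can overlap with the collective
# if handle is not None:
#     handle.wait()                        # reduced values are valid only after this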


def global_amax_reduction(
@@ -543,14 +553,19 @@ def global_amax_reduction(

# Key already deleted.
if amax_buffer_key not in _global_fp8_buffer:
return
return None

chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])

reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
wait_handle = reduce_tensor_across_group_op_max(
contiguous_amax,
fp8_meta["fp8_group"],
fp8_meta["async_amax_reduction"],
)

_global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
return wait_handle
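
global_amax_reduction still issues a single collective for all buffered amax tensors: they are concatenated into one contiguous tensor, reduced once with MAX, and split back into the original chunk sizes; only the blocking behavior is now configurable. A small single-process sketch of that flatten/reduce/split bookkeeping (the reduction itself is a no-op without a process group):

import torch

def reduce_amax_chunks(chunks):
    """Concatenate, reduce once, then restore the original per-module chunks."""
    chunk_sizes = [c.numel() for c in chunks]
    flat = torch.cat(chunks)
    # In the real code: all_reduce(flat, MAX) across fp8_meta["fp8_group"].
    return list(flat.split(chunk_sizes))

print(reduce_amax_chunks([torch.tensor([0.5, 2.0]), torch.tensor([1.5])]))
# [tensor([0.5000, 2.0000]), tensor([1.5000])]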


def delete_key_from_amax_buffer(forward: bool = True) -> None:
16 changes: 15 additions & 1 deletion transformer_engine/pytorch/module.py
@@ -41,6 +41,7 @@
copy_forward_fp8_meta_tensors_for_recompute,
get_old_fp8_meta_tensors_for_recompute,
restore_fp8_meta_tensors,
get_amax_reduce_handle_fwd,
)
from .jit import (
bias_gelu_fused,
@@ -84,6 +85,7 @@
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_cublas_workspace = None
_amax_reduce_handle_bwd = None


def get_cublas_workspace_size_bytes() -> None:
@@ -106,6 +108,11 @@ def get_workspace() -> torch.Tensor:
def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:
"""Checks and prep for BWD."""
if fp8:
global _amax_reduce_handle_bwd
if _amax_reduce_handle_bwd is not None:
_amax_reduce_handle_bwd.wait()
_amax_reduce_handle_bwd = None

# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
@@ -125,7 +132,7 @@ def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:

if fp8 and fp8_meta["recipe"].reduce_amax:
if fp8_meta["first_module"]:
global_amax_reduction(fp8_meta, forward=False)
_amax_reduce_handle_bwd = global_amax_reduction(fp8_meta, forward=False)
delete_key_from_amax_buffer(forward=False)


@@ -184,6 +191,9 @@ def __init__(self) -> None:
self.sequence_parallel = False
self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "1"))
)
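
Each module reads NVTE_ASYNC_AMAX_REDUCTION once at construction time; the default of "1" enables the asynchronous path, and a zero value restores the previous blocking reduction. To opt out, set the variable before the modules are created, since the flag is cached in fp8_meta:

import os

# Disable asynchronous amax reduction; "1" (the default) keeps it enabled.
os.environ["NVTE_ASYNC_AMAX_REDUCTION"] = "0"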

def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd."""
@@ -497,6 +507,10 @@ def prepare_forward(
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
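
Taken together, the per-iteration timeline becomes: exiting the outermost fp8_autocast of step N launches the forward amax all-reduce asynchronously, the first FP8 module of step N+1 waits on it in prepare_forward before updating scales, and the backward pass mirrors this through _amax_reduce_handle_bwd. A hedged usage sketch (an FP8-capable GPU is assumed, and an initialized fp8_group is needed for an actual reduction; names follow Transformer Engine's public PyTorch API):

import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024)            # any TE module participating in FP8
inp = torch.randn(16, 1024, device="cuda")

for step in range(3):
    with te.fp8_autocast(enabled=True):  # exiting the outermost context launches
        out = model(inp)                 #   the async forward amax reduction
    out.float().sum().backward()         # backward launches its own reduction,
                                         #   waited on at the start of the next one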