Description
Hi,
we are looking into training some transformer models with FP8, and we see a lot of CPU-side overhead when te.Linear layers are scheduled in the forward pass of the network.
I'm using the following:
- H100 GPUs with CUDA 12.2 (V12.2.140)
- TE version git+https://github.com/NVIDIA/TransformerEngine.git@cf6fc898286e4ad347ff88925c88663324e2b87d
- PyTorch 2.1.0 with cuDNN 8906
Concretely, running a toy model we see the FP8 model being only slightly faster, at around 300 ms per iteration vs. 320 ms per iteration for the BF16 model. We use te.Linear layers in both cases, regardless of whether we run FP8 or BF16.
However, looking at the profiles, we see that the forward pass of the FP8 model (wall duration roughly 140 ms) is much slower than the forward pass of the BF16 model (wall duration roughly 77 ms), and the GPU is idle much of the time during the FP8 forward pass. GPU utilization is near 100% during the backward pass for both models.
On the CPU side, scheduling a te.Linear layer in FP8 appears to take more than twice as long as scheduling the same te.Linear layer in BF16.
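To isolate the CPU-side launch cost outside of the full model, a micro-benchmark along these lines can be used (a rough sketch, separate from the repro script below; the non-FP8 path here runs in plain FP32 rather than BF16 autocast, which should be close enough for comparing host-side scheduling cost, and `launch_time_us` is just an ad-hoc helper for this sketch):

```python
import time

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

# single te.Linear with the same hidden size / input shape as the repro script below
layer = te.Linear(2048, 2048, bias=False).cuda()
x = torch.rand(8, 1024, 2048, device="cuda")
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")


def launch_time_us(use_fp8: bool, iters: int = 100) -> float:
    # warmup, so that one-off initialization (e.g. FP8 scale setup) is excluded
    for _ in range(10):
        with te.fp8_autocast(enabled=use_fp8, fp8_recipe=recipe):
            layer(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        with te.fp8_autocast(enabled=use_fp8, fp8_recipe=recipe):
            layer(x)
    # no synchronize inside the timed region: the measurement mostly reflects
    # the CPU time spent scheduling the kernels, not GPU execution time
    return (time.perf_counter() - t0) / iters * 1e6


print(f"non-FP8 launch: {launch_time_us(False):.1f} us per forward call")
print(f"FP8 launch:     {launch_time_us(True):.1f} us per forward call")
```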
Attached is a screenshot of part of the forward pass of the FP8 model:


I think this is related to #445, which observed similar behavior.
Do you have any suggestions about how to optimize this?
Code to reproduce:
Call with
python fp8_minimal_example.py --dtype bf16
python fp8_minimal_example.py --dtype fp8
Add --profile to generate a PyTorch profile.
import argparse

import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


class TEBlock(nn.Module):
    def __init__(self, hidden_size: int, mlp_ratio: float):
        super().__init__()
        linear = te.Linear
        # timestep modulation predicts several parameters conditioned on the timestep
        self.timestep_modulation = linear(hidden_size, 6 * hidden_size, bias=True)
        # simulate self attention layer for getting qkv embedding
        self.self_attn = linear(hidden_size, 3 * hidden_size, bias=False)
        # simulate cross attention layer for getting qkv embedding
        self.cross_attn_q = linear(hidden_size, hidden_size, bias=False)
        self.cross_attn_kv = linear(hidden_size, 2 * hidden_size, bias=False)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            linear(hidden_size, mlp_hidden_dim),
            linear(mlp_hidden_dim, hidden_size),
        )

    def forward(self, x):
        # simulating predicting parameters for timestep modulation
        shift, scale, _, _, _, _ = self.timestep_modulation(x).chunk(6, dim=-1)
        x = shift * x / scale
        # simulating self attention
        sa_q, _, _ = self.self_attn(x).chunk(3, dim=-1)
        x = x + sa_q
        # simulating cross attention
        ca_q = self.cross_attn_q(x)
        ca_k, _ = self.cross_attn_kv(x).chunk(2, dim=-1)
        x = x + ca_q + ca_k
        # run MLP
        x = x + self.mlp(x)
        return x


class TEModel(nn.Module):
    def __init__(
        self,
        num_blocks: int,
        hidden_size: int,
        mlp_ratio: float,
    ):
        super().__init__()
        self.blocks = torch.nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(TEBlock(hidden_size, mlp_ratio))

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for block_idx, block in enumerate(self.blocks):
            with torch.autograd.profiler.record_function(f"block_{block_idx}"):
                x = block(x)
        return x


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Command line arguments")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--dtype", type=str, default="bf16", help="Data type")
    parser.add_argument("--hidden_size", type=int, default=2048, help="Hidden size")
    parser.add_argument("--depth", type=int, default=32, help="Depth")
    parser.add_argument("--seq_length", type=int, default=1024, help="Sequence length")
    parser.add_argument("--mlp_ratio", type=float, default=4, help="mlp_ratio")
    parser.add_argument("--profile", action="store_true", help="Run PyTorch profiler")
    args = parser.parse_args()

    if args.dtype == "bf16":
        dtype = torch.bfloat16
        cast_type = "bf16"
    elif args.dtype == "fp8":
        dtype = torch.float32
        cast_type = "fp8"
    else:
        print("Invalid data type, must be either bf16 or fp8")
        exit(0)

    # Define FP8 recipe
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(
        fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
    )

    # Generate random model input and target for MSE loss
    model_input = (
        torch.rand(args.batch_size, args.seq_length, args.hidden_size)
        .cuda()
        .to(dtype=dtype)
    )
    target = (
        torch.rand(args.batch_size, args.seq_length, args.hidden_size)
        .cuda()
        .to(dtype=dtype)
    )
    criterion = torch.nn.MSELoss()

    # Define the model and optimizer
    model = TEModel(args.depth, args.hidden_size, args.mlp_ratio)
    model.to(dtype=torch.float32).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Define autocast kwargs
    if cast_type == "fp8":
        autocast_args = {"enabled": True, "fp8_recipe": fp8_recipe}
        autocast = te.fp8_autocast
    elif cast_type == "bf16":
        autocast_args = {
            "device_type": "cuda",
            "enabled": True,
            "dtype": torch.bfloat16,
        }
        autocast = torch.autocast
    # Run PyTorch profile
    if args.profile:
        with autocast(**autocast_args):
            with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                schedule=torch.profiler.schedule(
                    skip_first=5, wait=10, warmup=5, active=3
                ),
            ) as prof:
                for _ in range(25):
                    with autocast(**autocast_args):
                        output = model(model_input)
                        loss = criterion(output, target)
                    loss.backward()
                    prof.step()
        profile_name = cast_type + "_bs_" + str(args.batch_size)
        profile_name = f"_profile_{profile_name}.json"
        prof.export_chrome_trace(profile_name)
        print(f"Saved profile as {profile_name}")
    # Time model iterations
    else:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        timing_iters = 50
        # warmup iterations
        for _ in range(10):
            with autocast(**autocast_args):
                output = model(model_input)
                loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        # estimate memory usage
        free, total = torch.cuda.mem_get_info()
        memory = (total - free) / 1024**3  # bytes -> GB
        # benchmark
        torch.cuda.synchronize()
        start.record()
        for _ in range(timing_iters):
            with autocast(**autocast_args):
                output = model(model_input)
                loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        end.record()
        # wait for the recorded events to complete before reading the elapsed time
        torch.cuda.synchronize()
        mean_time = start.elapsed_time(end) / timing_iters
        print(f"Mean time {mean_time:.1f} ms per iteration ({memory:.1f} GB used)")
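As a side note, the per-op CPU times can also be summarized directly from the profiler object instead of opening the Chrome trace, e.g. (assuming the prof object from the script above, after the profiling context has exited):

```python
# sort by total CPU time to see which host-side ops dominate the FP8 forward pass
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=25))
```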

