
CPU Overhead of te.Linear FP8 Layers #761

@tohinz


Hi,

we are looking into training some transformer models with FP8, and we see a lot of overhead on the CPU side when te.Linear layers are scheduled in the forward pass of the network.

I'm using the following:

  1. H100 GPUs with CUDA 12.2 (V12.2.140)
  2. TE version git+https://github.com/NVIDIA/TransformerEngine.git@cf6fc898286e4ad347ff88925c88663324e2b87d
  3. PyTorch 2.1.0 with cuDNN 8906

Concretely, running a toy model, we see the FP8 model being slightly faster, at around 300 ms per iteration vs. 320 ms per iteration for the BF16 model. We always use te.Linear layers, regardless of whether we run in FP8 or BF16.

However, looking at the profiles, we see that the forward pass of the FP8 model (wall duration roughly 140 ms) is much slower than that of the BF16 model (wall duration roughly 77 ms). The GPU is also idle much of the time during the FP8 forward pass, while GPU utilization is near 100% during the backward pass for both models.

On the CPU side, it looks like scheduling a te.Linear layer in FP8 takes more than twice as long as scheduling it in BF16.
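
One quick way to check this is to time how long the CPU needs just to enqueue a single te.Linear forward call, without synchronizing, so that kernel runtime is excluded. A rough sketch (shapes and recipe mirror the reproduction script below):

# Sketch: measure only how long the CPU needs to enqueue a te.Linear forward.
# The timed loop is not synchronized, so kernel execution time is excluded.
import time

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

layer = te.Linear(2048, 2048, bias=False).cuda()  # fp32 params, as in the script below
x = torch.randn(8, 1024, 2048, device="cuda")     # batch x seq x hidden, fp32
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")


def cpu_launch_ms(use_fp8: bool, iters: int = 50) -> float:
    if use_fp8:
        make_ctx = lambda: te.fp8_autocast(enabled=True, fp8_recipe=recipe)
    else:
        make_ctx = lambda: torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        for _ in range(10):  # warmup, so one-time FP8 setup is excluded
            with make_ctx():
                layer(x)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            with make_ctx():
                layer(x)
        elapsed = time.perf_counter() - t0  # deliberately no synchronize here
    return elapsed / iters * 1e3


print(f"BF16: {cpu_launch_ms(False):.3f} ms/call, FP8: {cpu_launch_ms(True):.3f} ms/call")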

Attached are screenshots of part of the forward pass of the FP8 model:
[screenshots: fwd_fp8, fwd_fp8_block]

and of the BF16 model:
[screenshots: fwd_bf16, fwd_bf16_block]

I think this is related to #445, which reports similar behavior.
Do you have any suggestions for how to optimize this?
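
One option that might help is capturing the forward pass with CUDA graphs, so the per-layer launch overhead is only paid once at capture time. The sketch below is untested and assumes the installed TE commit exposes te.make_graphed_callables with fp8_enabled / fp8_recipe arguments (a stack of te.Linear layers stands in for the real model):

# Untested sketch: capture the FP8 forward/backward with CUDA graphs so the
# per-layer Python/launch cost is not paid on every iteration. Assumes
# te.make_graphed_callables with fp8_enabled / fp8_recipe is available in the
# installed TE version.
import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

# stand-in for the real model
model = nn.Sequential(*[te.Linear(2048, 2048) for _ in range(8)]).cuda()
sample = torch.randn(8, 1024, 2048, device="cuda", requires_grad=True)

graphed_model = te.make_graphed_callables(
    model,
    (sample,),
    fp8_enabled=True,
    fp8_recipe=fp8_recipe,
)

# Subsequent calls replay the captured graphs instead of re-launching every
# kernel from Python.
out = graphed_model(sample)
out.sum().backward()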

Code to reproduce:
Call with

python fp8_minimal_example.py --dtype bf16
python fp8_minimal_example.py --dtype fp8

Add --profile to export a PyTorch profiler trace.

import argparse
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


class TEBlock(nn.Module):
    def __init__(self, hidden_size: int, mlp_ratio: float):
        super().__init__()
        linear = te.Linear

        # timestep modulation predicts several parameters conditioned on the timestep
        self.timestep_modulation = linear(hidden_size, 6 * hidden_size, bias=True)

        # simulate self attention layer for getting qkv embedding
        self.self_attn = linear(hidden_size, 3 * hidden_size, bias=False)

        # simulate cross attention layer for getting qkv embedding
        self.cross_attn_q = linear(hidden_size, hidden_size, bias=False)
        self.cross_attn_kv = linear(hidden_size, 2 * hidden_size, bias=False)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            linear(hidden_size, mlp_hidden_dim),
            linear(mlp_hidden_dim, hidden_size),
        )

    def forward(self, x):
        # simulating predicting parameters for timestep modulation
        shift, scale, _, _, _, _ = self.timestep_modulation(x).chunk(6, dim=-1)
        x = shift * x / scale

        # simulating self attention
        sa_q, _, _ = self.self_attn(x).chunk(3, dim=-1)
        x = x + sa_q

        # simulating cross attention
        ca_q = self.cross_attn_q(x)
        ca_k, _ = self.cross_attn_kv(x).chunk(2, dim=-1)
        x = x + ca_q + ca_k

        # run MLP
        x = x + self.mlp(x)
        return x


class TEModel(nn.Module):
    def __init__(
        self,
        num_blocks: int,
        hidden_size: int,
        mlp_ratio: float,
    ):
        super().__init__()

        self.blocks = torch.nn.ModuleList()

        for _ in range(num_blocks):
            self.blocks.append(TEBlock(hidden_size, mlp_ratio))

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for block_idx, block in enumerate(self.blocks):
            with torch.autograd.profiler.record_function(f"block_{block_idx}"):
                x = block(x)
        return x


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Command line arguments")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--dtype", type=str, default="bf16", help="Data type")
    parser.add_argument("--hidden_size", type=int, default=2048, help="Hidden size")
    parser.add_argument("--depth", type=int, default=32, help="Depth")
    parser.add_argument("--seq_length", type=int, default=1024, help="Sequence length")
    parser.add_argument("--mlp_ratio", type=float, default=4, help="mlp_ratio")
    parser.add_argument("--profile", action="store_true", help="Run PyTorch profiler")
    args = parser.parse_args()

    if args.dtype == "bf16":
        dtype = torch.bfloat16
        cast_type = "bf16"
    elif args.dtype == "fp8":
        dtype = torch.float32
        cast_type = "fp8"
    else:
        print("Invalid data type, must be either bf16 or fp8")
        exit(1)

    # Define FP8 recipe
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(
        fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
    )

    # Generate random model input and target for MSE loss
    model_input = (
        torch.rand(args.batch_size, args.seq_length, args.hidden_size)
        .cuda()
        .to(dtype=dtype)
    )
    target = (
        torch.rand(args.batch_size, args.seq_length, args.hidden_size)
        .cuda()
        .to(dtype=dtype)
    )
    criterion = torch.nn.MSELoss()

    # Define the model and optimizer
    model = TEModel(args.depth, args.hidden_size, args.mlp_ratio)
    model.to(dtype=torch.float32).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Define autocast kwargs
    if cast_type == "fp8":
        autocast_args = {"enabled": True, "fp8_recipe": fp8_recipe}
        autocast = te.fp8_autocast
    elif cast_type == "bf16":
        autocast_args = {
            "device_type": "cuda",
            "enabled": True,
            "dtype": torch.bfloat16,
        }
        autocast = torch.autocast

    # Run PyTorch profile
    if args.profile:
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(
                skip_first=5, wait=10, warmup=5, active=3
            ),
        ) as prof:
            for _ in range(25):
                # only the forward pass runs under autocast, matching the timing loop below
                with autocast(**autocast_args):
                    output = model(model_input)
                loss = criterion(output, target)
                loss.backward()
                prof.step()
        profile_name = cast_type + "_bs_" + str(args.batch_size)
        profile_name = f"_profile_{profile_name}.json"
        prof.export_chrome_trace(profile_name)
        print(f"Saved profile as {profile_name}")

    # Time model iterations
    else:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        timing_iters = 50

        # warmup iterations
        for _ in range(10):
            with autocast(**autocast_args):
                output = model(model_input)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # estimate memory usage
        free, total = torch.cuda.mem_get_info()
        memory = (total - free) / 1024**2

        # benchmark
        torch.cuda.synchronize()
        start.record()
        for _ in range(timing_iters):
            with autocast(**autocast_args):
                output = model(model_input)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        end.record()
        torch.cuda.synchronize()

        mean_time = start.elapsed_time(end) / timing_iters
        print(f"Mean time {mean_time} ms per iteration ({memory} GB used)")
