
[FP8] performance degradation in speed and memory without compile #685

@leeeizhang

Description


FP8 FFN performance degrades in both speed and peak GPU memory when the module is not compiled with torch.compile.

Config                   | Torch FP16        | AO FP8 (compile=False) | AO FP8 (compile=True)
bs=32, seq=512, dim=512  | 0.92 ms, 308 MB   | 3.14 ms, 594 MB        | 0.95 ms, 339 MB
bs=32, seq=512, dim=1024 | 3.14 ms, 664 MB   | 7.17 ms, 1.2 GB        | 2.61 ms, 724 MB
bs=32, seq=512, dim=2048 | 11.38 ms, 1.53 GB | 17.84 ms, 2.6 GB       | 7.84 ms, 1.6 GB
bs=32, seq=512, dim=4096 | 43.16 ms, 3.9 GB  | 49.25 ms, 6.1 GB       | 26.20 ms, 4.1 GB

Without compile, FP8 is consistently slower than FP16 (up to ~3.4x at dim=512) and peaks at roughly 1.6-1.9x the memory; with compile, FP8 memory is close to FP16 and speed pulls ahead at the larger dims.

Trace Logs (torch.profiler)

[Two profiler trace screenshots: compile=True (2.1ms) vs. compile=False (6.3ms)]
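
For reference, comparable traces can be captured with torch.profiler. A minimal sketch (the exact capture settings behind the screenshots are not shown in this issue, so the helper name and settings below are assumptions):

import torch
from torch.profiler import ProfilerActivity, profile

def trace_ffn(ffn, x, out_path="ffn_trace.json"):
    # Warm up first so one-time compile/autotune cost stays out of the trace.
    with torch.inference_mode():
        for _ in range(10):
            _ = ffn(x)
        torch.cuda.synchronize()
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            _ = ffn(x)
            torch.cuda.synchronize()
    prof.export_chrome_trace(out_path)  # open in chrome://tracing or Perfetto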

Testbed

  • Torch: 2.5.0.dev20240814+cu121
  • TorchAO: 2024.8.15+cu121
  • CUDA Version: 12.1 (NVIDIA L20, SM89)

Code to Reproduce the Issue

"""
usage: $ python3 test.py --bs 32 --seq 512 --dim 1024 --compile 0
"""
import time
import argparse

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training


class TorchFFN(nn.Module):
    def __init__(self, in_feature, hidden_feature, bias=True):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, hidden_feature, bias)
        self.fc2 = nn.Linear(hidden_feature, in_feature, bias)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        return x


if __name__ == "__main__":
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', type=int, required=False, default=32)
    parser.add_argument('--seq', type=int, required=False, default=512)
    parser.add_argument('--dim', type=int, required=False, default=1024)
    parser.add_argument('--compile', type=int, required=False, default=0)
    args = parser.parse_args()

    # Test fp8 linear
    BS, SQ, DIM = args.bs, args.seq, args.dim
    x = torch.randn((BS, SQ, DIM), device="cuda")

    torch_fp16_ffn = TorchFFN(DIM, 4 * DIM).to("cuda")

    torch_fp8_ffn = TorchFFN(DIM, 4 * DIM).to("cuda")
    torch_fp8_ffn.load_state_dict(torch_fp16_ffn.state_dict())  # Align weights

    convert_to_float8_training(torch_fp8_ffn)  # swap nn.Linear layers for Float8Linear
    if args.compile > 0:
        torch_fp8_ffn = torch.compile(torch_fp8_ffn)

    with torch.inference_mode():
        # Warmup
        for _ in range(10):
            _ = torch_fp8_ffn(x)
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                _ = torch_fp16_ffn(x)

        # Test torch fp8 speed (1000 iters, so elapsed seconds == average ms/iter)
        s = time.time()
        for _ in range(1000):
            torch_fp8_y = torch_fp8_ffn(x)
            torch.cuda.synchronize()
        e = time.time()
        print(f"torch fp8: {e - s:.2f}ms/iter")

        # Test torch fp16 speed (1000 iters, so elapsed seconds == average ms/iter)
        s = time.time()
        for _ in range(1000):
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch_fp16_y = torch_fp16_ffn(x)
            torch.cuda.synchronize()
        e = time.time()
        print(f"torch fp16: {e - s:.2f}ms/iter")

        # Profile peak memory of a single forward pass
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats("cuda:0")
        _ = torch_fp8_ffn(x)
        peak_memory = torch.cuda.max_memory_allocated("cuda:0")
        print(f"Torch FP8 Peak memory usage: {peak_memory / 1024 ** 2:.2f} MB")

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats("cuda:0")
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch_fp16_y = torch_fp16_ffn(x)
        peak_memory = torch.cuda.max_memory_allocated("cuda:0")
        print(f"Torch FP16 Peak memory usage: {peak_memory / 1024 ** 2:.2f} MB")

    print(f"[torch-fp8 v.s. torch-fp16] mse loss: {nn.functional.mse_loss(torch_fp16_y, torch_fp8_y)}")
