If the FFN is not compiled, the FP8 path regresses in both speed and GPU memory: uncompiled torchao FP8 runs slower and peaks higher in memory than the plain FP16 baseline. Each cell below lists average time per iteration and peak GPU memory.
| Variant | Torch FP16 | AO FP8 (compile=False) | AO FP8 (compile=True) |
| --- | --- | --- | --- |
| bs=32, seq=512, dim=512 | 0.92ms, 308MB | 3.14ms, 594MB | 0.95ms, 339MB |
| bs=32, seq=512, dim=1024 | 3.14ms, 664MB | 7.17ms, 1.2GB | 2.61ms, 724MB |
| bs=32, seq=512, dim=2048 | 11.38ms, 1.53GB | 17.84ms, 2.6GB | 7.84ms, 1.6GB |
| bs=32, seq=512, dim=4096 | 43.16ms, 3.9GB | 49.25ms, 6.1GB | 26.20ms, 4.1GB |
### Track Logs (torch.profile)

(trace screenshots: Compile=True, 2.1ms; Compile=False, 6.3ms)
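The screenshots above were captured with `torch.profiler`. For reference, a minimal sketch of how equivalent traces can be collected, assuming the `torch_fp8_ffn` module and input `x` from the repro script below (run once with `--compile 0` and once with `--compile 1`):

```python
# Profiling sketch: assumes `torch_fp8_ffn` and `x` are defined as in the
# repro script below.
import torch
from torch.profiler import profile, ProfilerActivity

with torch.inference_mode():
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(10):
            _ = torch_fp8_ffn(x)

# Per-op summary sorted by total CUDA time
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
# Timeline trace, viewable in chrome://tracing or Perfetto
prof.export_chrome_trace("fp8_ffn_trace.json")
```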
### Testbed

- Torch: 2.5.0.dev20240814+cu121
- TorchAO: 2024.8.15+cu121
- CUDA: 12.1 (NVIDIA L20, SM89)
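For completeness, these versions can be read back at runtime; a minimal sketch using only stock `torch` and `importlib.metadata` APIs (the comments show the values expected on this testbed):

```python
import importlib.metadata
import torch

print(torch.__version__)                      # 2.5.0.dev20240814+cu121
print(importlib.metadata.version("torchao"))  # 2024.8.15+cu121, as installed
print(torch.version.cuda)                     # 12.1
print(torch.cuda.get_device_name(0))          # NVIDIA L20
print(torch.cuda.get_device_capability(0))    # (8, 9) -> SM89
```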
### Code to Reproduce the Issue
"""
usage: $ python3 test.py --bs 32 --seq 512 --dim 1024 --compile 0
"""
import time
import argparse
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
class TorchFFN(nn.Module):
def __init__(self, in_feature, hidden_feature, bias=True):
super().__init__()
self.fc1 = nn.Linear(in_feature, hidden_feature, bias)
self.fc2 = nn.Linear(hidden_feature, in_feature, bias)
self.gelu = nn.GELU()
def forward(self, x):
x = self.fc1(x)
x = self.gelu(x)
x = self.fc2(x)
return x
if __name__ == "__main__":
torch.manual_seed(0)
torch.cuda.set_device(0)
parser = argparse.ArgumentParser()
parser.add_argument('--bs', type=int, required=False, default=32)
parser.add_argument('--seq', type=int, required=False, default=512)
parser.add_argument('--dim', type=int, required=False, default=1024)
parser.add_argument('--compile', type=int, required=False, default=0)
args = parser.parse_args()
# Test fp8 linear
BS, SQ, DIM = args.bs, args.seq, args.dim
x = torch.randn((BS, SQ, DIM), device="cuda")
torch_fp16_ffn = TorchFFN(DIM, 4 * DIM).to("cuda")
torch_fp8_ffn = TorchFFN(DIM, 4 * DIM).to("cuda")
torch_fp8_ffn.load_state_dict(torch_fp16_ffn.state_dict()) # Align weights
convert_to_float8_training(torch_fp8_ffn)
if args.compile > 0:
torch_fp8_ffn = torch.compile(torch_fp8_ffn)
with torch.inference_mode():
# Warmup
for _ in range(10):
_ = torch_fp8_ffn(x)
with torch.autocast(device_type="cuda", dtype=torch.float16):
_ = torch_fp16_ffn(x)
# Test torch fp8 speed
s = time.time()
for _ in range(1000):
torch_fp8_y = torch_fp8_ffn(x)
torch.cuda.synchronize()
e = time.time()
print(f"torch fp8: {e-s}ms")
# Test torch fp16 speed
s = time.time()
for _ in range(1000):
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch_fp16_y = torch_fp16_ffn(x)
torch.cuda.synchronize()
e = time.time()
print(f"torch fp16: {e-s}ms")
# Profile memory
torch.cuda.reset_peak_memory_stats("cuda:0"), torch.cuda.empty_cache()
_ = torch_fp8_ffn(x)
peak_memory = torch.cuda.max_memory_allocated("cuda:0")
print(f"Torch FP8 Peak memory usage: {peak_memory / 1024 ** 2:.2f} MB")
torch.cuda.reset_peak_memory_stats("cuda:0"), torch.cuda.empty_cache()
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch_fp16_y = torch_fp16_ffn(x)
peak_memory = torch.cuda.max_memory_allocated("cuda:0")
print(f"Torch FP16 Peak memory usage: {peak_memory / 1024 ** 2:.2f} MB")
print(f"[torch-fp8 v.s. torch-fp16] mse loss: {nn.functional.mse_loss(torch_fp16_y, torch_fp8_y)}")