What I cannot create, I do not understand.
零基础到完整推理系统的实战教程
适合人群: PyTorch 零基础,想理解 LLM 推理引擎的工作原理
学习路径:
- 从最简单的 10 行代码开始
- 每一步都能运行和验证
- 逐步添加优化,理解每个技术的作用
- 最终理解 nano-vllm和vllm 这样的推理引擎
学习方式:
- 边看边敲代码(强烈推荐)
- 修改参数,观察结果
- 出错了不要慌,文档有调试技巧
- 第零部分:PyTorch 快速入门
- 第一部分:最简单的推理(10 行代码)
- 第二部分:理解 Prefill 和 Decode
- 第三部分:加速优化 - KV Cache
- 第四部分:处理多个请求 - Batching
- 第五部分:资源管理 - Scheduler
- 第六部分:内存优化 - PagedAttention
- 第七部分:高级优化
- 第八部分:对比学习 nano-vllm
在开始构建推理引擎之前,我们需要掌握一些 PyTorch 基础。
# 安装 PyTorch(根据你的系统选择)
# CPU 版本
pip install torch
# GPU 版本(CUDA 11.8)
pip install torch --index-url https://download.pytorch.org/whl/cu118
# 验证安装
python -c "import torch; print(torch.__version__)"Tensor 是什么?
- 就是多维数组(类似 NumPy 的 ndarray)
- 可以在 GPU 上运行
import torch
# 创建 Tensor
x = torch.tensor([1, 2, 3])
print(x) # tensor([1, 2, 3])
print(x.shape) # torch.Size([3])
# 二维 Tensor(矩阵)
matrix = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(matrix.shape) # torch.Size([2, 3])
# ↑ ↑
# | └─ 3 列
# └──── 2 行
# 三维 Tensor
tensor_3d = torch.randn(2, 3, 4)
print(tensor_3d.shape) # torch.Size([2, 3, 4])
# ↑ ↑ ↑
# | | └─ 每个矩阵 4 列
# | └──── 每个 "层" 3 行
# └─────── 2 个矩阵常用创建方法:
# 随机数
x = torch.randn(2, 3) # 正态分布随机数
# 全零
x = torch.zeros(2, 3)
# 全一
x = torch.ones(2, 3)
# 指定范围
x = torch.arange(10) # [0, 1, 2, ..., 9]# 矩阵乘法的核心规则
A = torch.randn(2, 3) # [2, 3]
B = torch.randn(3, 4) # [3, 4]
C = A @ B # [2, 4]
# ↑ ↑
# | └─ B 的列数
# └──── A 的行数
# 关键:A 的列数必须等于 B 的行数
print(f"A: {A.shape}, B: {B.shape}, C: {C.shape}")
# A: torch.Size([2, 3]), B: torch.Size([3, 4]), C: torch.Size([2, 4])
# 错误示例
try:
A = torch.randn(2, 3)
B = torch.randn(5, 4) # 3 ≠ 5
C = A @ B # 报错!
except RuntimeError as e:
print(f"错误: {e}")可视化矩阵乘法:
# 例子
A = torch.tensor([[1, 2, 3],
[4, 5, 6]]) # [2, 3]
B = torch.tensor([[7, 8],
[9, 10],
[11, 12]]) # [3, 2]
C = A @ B
print(C)
# tensor([[ 58, 64],
# [139, 154]])
# 计算过程(第一个元素):
# C[0, 0] = A[0, 0]*B[0, 0] + A[0, 1]*B[1, 0] + A[0, 2]*B[2, 0]
# = 1*7 + 2*9 + 3*11
# = 7 + 18 + 33
# = 58x = torch.arange(12)
print(x.shape) # [12]
# 改成 3 行 4 列
y = x.view(3, 4)
print(y.shape) # [3, 4]
print(y)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
# 改成 2 行 6 列
z = x.view(2, 6)
print(z.shape) # [2, 6]
# 自动推断维度(-1)
w = x.view(4, -1) # 4 行,列数自动计算
print(w.shape) # [4, 3]x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(x.shape) # [2, 3]
y = x.T # 转置
print(y.shape) # [3, 2]
print(y)
# tensor([[1, 4],
# [2, 5],
# [3, 6]])
# 对于高维 tensor,指定转置的维度
x = torch.randn(2, 3, 4, 5)
y = x.transpose(2, 3) # 交换第 2 和第 3 维
print(y.shape) # [2, 3, 5, 4]x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 取单个元素
print(x[0, 0]) # tensor(1)
# 取一行
print(x[0]) # tensor([1, 2, 3])
# 取一列
print(x[:, 0]) # tensor([1, 4])
# 切片
print(x[:, 1:]) # 所有行,从第 1 列开始
# tensor([[2, 3],
# [5, 6]])# 检查是否有 GPU
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
# 将 tensor 移到 GPU
x = torch.randn(2, 3)
x_gpu = x.cuda() # 或 x.to('cuda')
print(x_gpu.device) # cuda:0
# GPU 上的计算
y_gpu = torch.randn(3, 4).cuda()
z_gpu = x_gpu @ y_gpu # 在 GPU 上计算
# 移回 CPU
z_cpu = z_gpu.cpu()
# 统一写法(自动选择设备)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(2, 3).to(device)nn.Module 是所有神经网络模块的基类。
import torch.nn as nn
# 最简单的模块:线性层(全连接层)
linear = nn.Linear(in_features=10, out_features=5)
# 内部有两个参数:
# - weight: [5, 10] (out_features × in_features)
# - bias: [5]
print(f"Weight shape: {linear.weight.shape}")
print(f"Bias shape: {linear.bias.shape}")
# 前向传播
x = torch.randn(2, 10) # batch=2, features=10
y = linear(x)
print(f"Input: {x.shape}, Output: {y.shape}")
# Input: torch.Size([2, 10]), Output: torch.Size([2, 5])
# 计算过程:
# y = x @ W.T + b
# = [2, 10] @ [10, 5] + [5]
# = [2, 5]自定义模块:
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
# 定义层
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
# 定义前向传播
x = self.fc1(x)
x = torch.relu(x) # 激活函数
x = self.fc2(x)
return x
# 使用
model = MyModel(10, 20, 5)
x = torch.randn(2, 10)
y = model(x)
print(y.shape) # [2, 5]
# 查看所有参数
for name, param in model.named_parameters():
print(f"{name}: {param.shape}")
# fc1.weight: torch.Size([20, 10])
# fc1.bias: torch.Size([20])
# fc2.weight: torch.Size([5, 20])
# fc2.bias: torch.Size([5])from transformers import AutoModelForCausalLM, AutoTokenizer
# 下载和加载模型
model_name = "gpt2" # 小模型,适合学习
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# 查看模型结构
print(model)
# 获取模型参数数量
num_params = sum(p.numel() for p in model.parameters())
print(f"参数量: {num_params / 1e6:.1f}M") # GPT-2: 124.4M
# Tokenize(文本 → 数字)
text = "Hello world"
input_ids = tokenizer.encode(text, return_tensors='pt')
print(f"Input IDs: {input_ids}")
# tensor([[15496, 995]])
# 模型推理
with torch.no_grad(): # 推理时不需要计算梯度
outputs = model(input_ids)
logits = outputs.logits # [1, 2, 50257]
# ↑ ↑ ↑
# | | └─ vocab size
# | └──── seq_len
# └─────── batch
print(f"Logits shape: {logits.shape}")# 练习 1: 矩阵乘法
# 创建两个矩阵 A [3, 4] 和 B [4, 5],计算 C = A @ B
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = A @ B
print(f"C shape: {C.shape}") # 应该是 [3, 5]
# 练习 2: Reshape
# 将一个 [2, 3, 4] 的 tensor reshape 成 [2, 12]
x = torch.randn(2, 3, 4)
y = x.view(2, -1) # -1 自动计算
print(f"y shape: {y.shape}") # 应该是 [2, 12]
# 练习 3: 实现一个简单的线性层(不用 nn.Linear)
def my_linear(x, weight, bias):
# x: [batch, in_features]
# weight: [out_features, in_features]
# bias: [out_features]
return x @ weight.T + bias
x = torch.randn(2, 10)
weight = torch.randn(5, 10)
bias = torch.randn(5)
output = my_linear(x, weight, bias)
print(f"Output shape: {output.shape}") # 应该是 [2, 5]
# 验证和 nn.Linear 一致
linear = nn.Linear(10, 5)
linear.weight.data = weight
linear.bias.data = bias
output_pytorch = linear(x)
print(f"Match: {torch.allclose(output, output_pytorch)}") # True恭喜!你已经掌握了构建推理引擎所需的 PyTorch 基础。
现在开始构建我们的第一个推理引擎!
用最少的代码,实现:
- 加载一个模型
- 输入一句话
- 生成一个新词
不关心性能,只求能跑通。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 1. 加载模型和 tokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval() # 设置为评估模式
# 2. 输入文本
text = "Hello, I am"
input_ids = tokenizer.encode(text, return_tensors='pt')
print(f"Input: {text}")
print(f"Token IDs: {input_ids}")
# 3. 模型推理
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits # [1, seq_len, vocab_size]
# 4. 获取最后一个 token 的预测
last_logits = logits[0, -1, :] # [vocab_size]
# 5. 选择概率最高的 token
next_token_id = torch.argmax(last_logits).item()
# 6. 解码
next_token = tokenizer.decode(next_token_id)
print(f"Next token: {next_token}")
# 完整输出
generated_text = text + next_token
print(f"Generated: {generated_text}")运行结果:
Input: Hello, I am
Token IDs: tensor([[15496, 11, 314, 716]])
Next token: a
Generated: Hello, I am a
让我们逐行理解:
# 步骤 1: 加载模型
model = AutoModelForCausalLM.from_pretrained("gpt2")
# 这会下载 GPT-2 模型(~500MB)
# 模型结构:12 层 Transformer Decoder
# 参数量:124M
# 步骤 2: Tokenize
text = "Hello, I am"
input_ids = tokenizer.encode(text, return_tensors='pt')
# 将文本转换为 token IDs
# "Hello" → 15496
# "," → 11
# " I" → 314
# " am" → 716
# 结果: tensor([[15496, 11, 314, 716]])
# 步骤 3: 模型前向传播
outputs = model(input_ids)
# 输入: [1, 4] (1 个序列,4 个 tokens)
# 输出: logits [1, 4, 50257]
# ↑ ↑ ↑
# | | └─ 50257 个词的得分
# | └──── 4 个位置(每个 token 都预测下一个词)
# └─────── batch size
# 步骤 4: 取最后一个 token 的预测
last_logits = logits[0, -1, :]
# 形状: [50257]
# 含义: 词表中每个词作为下一个词的得分
# 步骤 5: 选择得分最高的词
next_token_id = torch.argmax(last_logits).item()
# argmax 返回最大值的索引
# 例如: 如果 last_logits[257] = 5.3 是最大的
# 则 next_token_id = 257现在扩展到生成多个词:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
# 输入
text = "Once upon a time"
input_ids = tokenizer.encode(text, return_tensors='pt')
print(f"Input: {text}")
# 生成 10 个新词
max_new_tokens = 10
for i in range(max_new_tokens):
# 推理
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits
# 取最后一个 token 的预测
next_token_id = torch.argmax(logits[0, -1, :]).item()
# 追加到输入
input_ids = torch.cat([
input_ids,
torch.tensor([[next_token_id]])
], dim=1)
# 打印进度
generated = tokenizer.decode(input_ids[0])
print(f"Step {i+1}: {generated}")
# 最终结果
final_text = tokenizer.decode(input_ids[0])
print(f"\nFinal: {final_text}")运行结果:
Input: Once upon a time
Step 1: Once upon a time,
Step 2: Once upon a time, I
Step 3: Once upon a time, I was
Step 4: Once upon a time, I was a
Step 5: Once upon a time, I was a little
Step 6: Once upon a time, I was a little girl
Step 7: Once upon a time, I was a little girl.
Step 8: Once upon a time, I was a little girl. I
Step 9: Once upon a time, I was a little girl. I was
Step 10: Once upon a time, I was a little girl. I was born
Final: Once upon a time, I was a little girl. I was born
上面的代码总是选择概率最高的词(贪心策略),会导致生成内容重复。我们加入随机采样:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
def generate(prompt, max_new_tokens=20, temperature=1.0):
"""
生成文本
参数:
prompt: 输入文本
max_new_tokens: 最多生成多少个词
temperature: 温度参数
- < 1.0: 更确定性(倾向高概率词)
- = 1.0: 标准采样
- > 1.0: 更随机
"""
input_ids = tokenizer.encode(prompt, return_tensors='pt')
for _ in range(max_new_tokens):
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits[0, -1, :] # [vocab_size]
# 温度缩放
logits = logits / temperature
# 转换为概率
probs = torch.softmax(logits, dim=-1)
# 根据概率采样
next_token_id = torch.multinomial(probs, num_samples=1).item()
# 追加
input_ids = torch.cat([
input_ids,
torch.tensor([[next_token_id]])
], dim=1)
return tokenizer.decode(input_ids[0])
# 测试不同温度
prompt = "The secret of life is"
print(f"Prompt: {prompt}\n")
print("Temperature = 0.5 (确定性):")
print(generate(prompt, temperature=0.5))
print()
print("Temperature = 1.0 (标准):")
print(generate(prompt, temperature=1.0))
print()
print("Temperature = 1.5 (随机):")
print(generate(prompt, temperature=1.5))我们的 Version 0.1 能工作,但有严重的性能问题:
# 问题 1: 每次都重新计算所有 token
for i in range(100):
outputs = model(input_ids) # input_ids 越来越长
# Step 1: 处理 4 个 tokens
# Step 2: 处理 5 个 tokens (token 1-4 重复计算了)
# Step 3: 处理 6 个 tokens (token 1-5 重复计算了)
# ...
# Step 100: 处理 104 个 tokens (token 1-103 都重复计算了)
# 计算量:4 + 5 + 6 + ... + 104 = 5404 次 token 处理
# 实际只需要:4 + 100 = 104 次
# 浪费了 50 倍的计算!# 问题 2: 只能处理一个请求
# 如果有 10 个用户同时发送请求,只能一个一个处理
# GPU 利用率极低(~5%)下一部分我们将解决这些问题。
Version 0.1 实现了:
- ✅ 基本的文本生成
- ✅ 温度采样
存在的问题:
- ❌ 重复计算(慢)
- ❌ 无法批处理(吞吐量低)
在优化之前,我们需要深入理解推理的两个阶段。
Prefill = 处理输入 prompt 的阶段
# 输入
prompt = "Hello, I am a"
input_ids = [15496, 11, 314, 716, 257] # 5 个 tokens
# Prefill 阶段
outputs = model(input_ids)
# 一次性处理所有 5 个 tokens
# 输出: logits [1, 5, 50257]
# 目标: 为后续生成做准备为什么叫 Prefill?
- 因为会"预先填充"一些缓存(后面会讲 KV Cache)
- 类似"预加载"的概念
Decode = 逐个生成新 token 的阶段
# 已有的 tokens
current_ids = [15496, 11, 314, 716, 257]
# Decode 阶段(循环)
for step in range(max_new_tokens):
# 每次只处理最后一个 token(理想情况)
# 但我们的 Version 0.1 是处理所有 tokens(低效)
outputs = model(current_ids)
next_token = sample(outputs.logits[0, -1, :])
current_ids.append(next_token)| 特性 | Prefill | Decode |
|---|---|---|
| 处理的 tokens | 多个(整个 prompt) | 1 个(新生成的) |
| 计算模式 | 并行(所有 tokens 一起算) | 串行(一个一个生成) |
| 计算量 | 大(处理 N 个 tokens) | 小(处理 1 个 token) |
| 时间 | 一次性较长 | 每步很短 |
| 优化重点 | 利用并行性 | 减少重复计算 |
# 输入: "Hello, how are you"
# Token IDs: [15496, 11, 703, 389, 345]
# ============================================
# Prefill 阶段(一次性处理)
# ============================================
# 输入: [15496, 11, 703, 389, 345]
# ↓ ↓ ↓ ↓ ↓
# 模型: [模型处理所有 5 个 tokens]
# ↓ ↓ ↓ ↓ ↓
# 输出: [pred1, pred2, pred3, pred4, pred5]
# ↑
# 我们只需要最后一个
# 采样: pred5 → next_token = 5145 ("?")
# ============================================
# Decode 阶段(循环生成)
# ============================================
# Step 1:
# 输入: [..., 5145] (上一步生成的)
# 输出: pred6
# 采样: pred6 → 314 ("I")
# Step 2:
# 输入: [..., 314]
# 输出: pred7
# 采样: pred7 → 716 ("am")
# Step 3:
# 输入: [..., 716]
# 输出: pred8
# 采样: pred8 → 1327 ("fine")
# ...继续直到生成 EOS 或达到最大长度import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
def generate_v2(prompt, max_new_tokens=10, temperature=1.0):
"""
Version 2: 明确区分 Prefill 和 Decode
"""
print("=" * 50)
print("PREFILL 阶段")
print("=" * 50)
# Prefill: 处理 prompt
input_ids = tokenizer.encode(prompt, return_tensors='pt')
print(f"Input tokens: {input_ids.shape[1]}")
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits
print(f"Prefill output shape: {logits.shape}")
# 采样第一个新 token
probs = torch.softmax(logits[0, -1, :] / temperature, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
print(f"First generated token: {tokenizer.decode(next_token_id)}")
print("\n" + "=" * 50)
print("DECODE 阶段")
print("=" * 50)
# Decode: 逐个生成
for i in range(max_new_tokens - 1):
with torch.no_grad():
# 注意:这里仍然处理所有 tokens(低效)
# 下一部分会优化
outputs = model(input_ids)
logits = outputs.logits
# 采样
probs = torch.softmax(logits[0, -1, :] / temperature, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
print(f"Step {i+1}: {tokenizer.decode(input_ids[0])}")
return tokenizer.decode(input_ids[0])
# 测试
result = generate_v2("The meaning of life is", max_new_tokens=5)
print(f"\nFinal result: {result}")运行结果:
==================================================
PREFILL 阶段
==================================================
Input tokens: 5
Prefill output shape: torch.Size([1, 5, 50257])
First generated token: to
==================================================
DECODE 阶段
==================================================
Step 1: The meaning of life is to be
Step 2: The meaning of life is to be a
Step 3: The meaning of life is to be a part
Step 4: The meaning of life is to be a part of
Final result: The meaning of life is to be a part of
import time
def benchmark_generation(prompt, max_new_tokens=50):
"""测量 Prefill 和 Decode 的时间"""
input_ids = tokenizer.encode(prompt, return_tensors='pt')
# Prefill 阶段
start = time.time()
with torch.no_grad():
outputs = model(input_ids)
prefill_time = time.time() - start
# 第一个 token
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
# Decode 阶段
decode_times = []
for _ in range(max_new_tokens - 1):
start = time.time()
with torch.no_grad():
outputs = model(input_ids)
decode_times.append(time.time() - start)
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
print(f"Prefill time: {prefill_time*1000:.2f} ms")
print(f"Decode avg time: {sum(decode_times)/len(decode_times)*1000:.2f} ms")
print(f"Decode total time: {sum(decode_times)*1000:.2f} ms")
print(f"Total time: {(prefill_time + sum(decode_times))*1000:.2f} ms")
# 测试
benchmark_generation("Once upon a time", max_new_tokens=20)典型结果(GPU):
Prefill time: 394.24 ms
Decode avg time: 233.29 ms
Decode total time: 4432.52 ms
Total time: 4826.76 ms
问题:Decode 为什么越来越慢?
# Step 1: 处理 6 个 tokens (prompt 5 + generated 1)
# Step 2: 处理 7 个 tokens (prompt 5 + generated 2)
# Step 3: 处理 8 个 tokens (prompt 5 + generated 3)
# ...
# Step 20: 处理 25 个 tokens (prompt 5 + generated 20)
# 每一步都在重新计算之前的 tokens!下一部分我们用 KV Cache 解决这个问题。
回顾 Transformer 的 Attention 计算:
# 简化的 Attention 伪代码
def attention(input_ids):
# 计算 Q, K, V
Q = input_ids @ W_q # Query
K = input_ids @ W_k # Key
V = input_ids @ W_v # Value
# Attention 计算
scores = Q @ K.T
attn_weights = softmax(scores)
output = attn_weights @ V
return output问题:每次生成新 token 时,之前的 K 和 V 都重新计算了
# Step 1: input_ids = [t1, t2, t3]
K1 = compute_K([t1, t2, t3]) # 计算 K
V1 = compute_V([t1, t2, t3]) # 计算 V
# Step 2: input_ids = [t1, t2, t3, t4]
K2 = compute_K([t1, t2, t3, t4]) # t1,t2,t3 的 K 重复计算了!
V2 = compute_V([t1, t2, t3, t4]) # t1,t2,t3 的 V 重复计算了!
# Step 3: input_ids = [t1, t2, t3, t4, t5]
K3 = compute_K([t1, t2, t3, t4, t5]) # 又重复计算!
V3 = compute_V([t1, t2, t3, t4, t5]) # 又重复计算!解决方案:缓存已经计算过的 K 和 V
# Step 1:
K_cache = compute_K([t1, t2, t3]) # 计算并缓存
V_cache = compute_V([t1, t2, t3])
# Step 2:
K_new = compute_K([t4]) # 只计算新 token
V_new = compute_V([t4])
K_cache = concat(K_cache, K_new) # 追加到缓存
V_cache = concat(V_cache, V_new)
# Step 3:
K_new = compute_K([t5]) # 只计算新 token
V_new = compute_V([t5])
K_cache = concat(K_cache, K_new)
V_cache = concat(V_cache, V_new)Hugging Face Transformers 已经内置了 KV Cache 支持!
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
def generate_with_kv_cache(prompt, max_new_tokens=10):
"""
Version 3: 使用 KV Cache 加速
"""
input_ids = tokenizer.encode(prompt, return_tensors='pt')
# Prefill 阶段
with torch.no_grad():
outputs = model(input_ids, use_cache=True)
# outputs.past_key_values 就是 KV Cache
past_key_values = outputs.past_key_values
logits = outputs.logits
# 第一个新 token
next_token_id = torch.argmax(logits[0, -1, :]).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
# Decode 阶段(使用 KV Cache)
for i in range(max_new_tokens - 1):
# 关键:只输入最后一个 token!
last_token_id = torch.tensor([[input_ids[0, -1].item()]])
with torch.no_grad():
outputs = model(
last_token_id,
past_key_values=past_key_values, # 传入缓存
use_cache=True
)
# 更新缓存
past_key_values = outputs.past_key_values
logits = outputs.logits
# 采样
next_token_id = torch.argmax(logits[0, -1, :]).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
print(f"Step {i+1}: {tokenizer.decode(input_ids[0])}")
return tokenizer.decode(input_ids[0])
# 测试
result = generate_with_kv_cache("The secret of happiness is", max_new_tokens=10)
print(f"\nResult: {result}")import time
def benchmark_with_and_without_cache(prompt, max_new_tokens=50):
"""对比有无 KV Cache 的性能"""
print("=" * 60)
print("WITHOUT KV Cache (Version 0.1)")
print("=" * 60)
input_ids = tokenizer.encode(prompt, return_tensors='pt')
start = time.time()
for _ in range(max_new_tokens):
with torch.no_grad():
outputs = model(input_ids) # 每次处理所有 tokens
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
time_without_cache = time.time() - start
print(f"Time: {time_without_cache*1000:.2f} ms\n")
print("=" * 60)
print("WITH KV Cache (Version 3)")
print("=" * 60)
input_ids = tokenizer.encode(prompt, return_tensors='pt')
start = time.time()
# Prefill
with torch.no_grad():
outputs = model(input_ids, use_cache=True)
past_key_values = outputs.past_key_values
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
# Decode with cache
for _ in range(max_new_tokens - 1):
last_token_id = torch.tensor([[input_ids[0, -1].item()]])
with torch.no_grad():
outputs = model(last_token_id, past_key_values=past_key_values, use_cache=True)
past_key_values = outputs.past_key_values
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
time_with_cache = time.time() - start
print(f"Time: {time_with_cache*1000:.2f} ms\n")
print("=" * 60)
print("SPEEDUP")
print("=" * 60)
print(f"Speedup: {time_without_cache / time_with_cache:.2f}x")
# 测试
benchmark_with_and_without_cache("Once upon a time", max_new_tokens=50)典型结果:
===========================================================
WITHOUT KV Cache (Version 0.1)
===========================================================
Time: 9934.40 ms
===========================================================
WITH KV Cache (Version 3)
===========================================================
Time: 2376.81 ms
===========================================================
SPEEDUP
===========================================================
Speedup: 4.18x
惊人的 4.18x 加速!
# 查看 KV Cache 的结构
input_ids = tokenizer.encode("Hello world", return_tensors='pt')
outputs = model(input_ids, use_cache=True)
past_key_values = outputs.past_key_values
print(f"Type: {type(past_key_values)}")
# tuple
print(f"Number of layers: {len(past_key_values)}")
# 12 (GPT-2 有 12 层)
print(f"Each layer: {type(past_key_values[0])}")
# tuple (key, value)
print(f"Key shape: {past_key_values[0][0].shape}")
# [batch, num_heads, seq_len, head_dim]
# [1, 12, 2, 64] for "Hello world"
print(f"Value shape: {past_key_values[0][1].shape}")
# [1, 12, 2, 64]
# 完整结构
# past_key_values = (
# (layer_0_key, layer_0_value), # [1, 12, 2, 64] each
# (layer_1_key, layer_1_value),
# ...
# (layer_11_key, layer_11_value),
# )为了更好理解原理,我们手动实现一个简化版:
import torch
import torch.nn as nn
class SimplifiedAttentionWithCache(nn.Module):
"""简化的 Attention(支持 KV Cache)"""
def __init__(self, hidden_size=64, num_heads=4):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
self.o_proj = nn.Linear(hidden_size, hidden_size)
def forward(self, x, past_kv=None):
"""
x: [batch, seq_len, hidden_size]
past_kv: (past_key, past_value) 或 None
返回: (output, (new_key, new_value))
"""
batch, seq_len, _ = x.shape
# 计算 Q, K, V
q = self.q_proj(x) # [batch, seq_len, hidden_size]
k = self.k_proj(x)
v = self.v_proj(x)
# Reshape 成多头
q = q.view(batch, seq_len, self.num_heads, self.head_dim)
k = k.view(batch, seq_len, self.num_heads, self.head_dim)
v = v.view(batch, seq_len, self.num_heads, self.head_dim)
# 如果有缓存,拼接
if past_kv is not None:
past_key, past_value = past_kv
k = torch.cat([past_key, k], dim=1) # 拼接历史
v = torch.cat([past_value, v], dim=1)
# Permute: [B, S, H, D] → [B, H, S, D]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
# Attention
scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
attn_weights = torch.softmax(scores, dim=-1)
output = attn_weights @ v
# Permute back
output = output.permute(0, 2, 1, 3).contiguous()
output = output.view(batch, seq_len, self.hidden_size)
# 输出投影
output = self.o_proj(output)
# 返回输出和新的 KV(用于下次缓存)
return output, (k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3))
# 测试
attn = SimplifiedAttentionWithCache()
# Prefill
x_prefill = torch.randn(1, 5, 64) # 5 个 tokens
output, past_kv = attn(x_prefill, past_kv=None)
print(f"Prefill output: {output.shape}")
print(f"Cached K: {past_kv[0].shape}, V: {past_kv[1].shape}")
# Decode (只输入 1 个新 token)
x_decode = torch.randn(1, 1, 64) # 1 个 token
output, past_kv = attn(x_decode, past_kv=past_kv)
print(f"Decode output: {output.shape}")
print(f"Updated K: {past_kv[0].shape}, V: {past_kv[1].shape}")
# K 和 V 的 seq_len 从 5 增加到 6输出:
Prefill output: torch.Size([1, 5, 64])
Cached K: torch.Size([1, 5, 4, 16]), V: torch.Size([1, 5, 4, 16])
Decode output: torch.Size([1, 1, 64])
Updated K: torch.Size([1, 6, 4, 16]), V: torch.Size([1, 6, 4, 16])
Version 3 (with KV Cache) 实现了:
- ✅ 消除重复计算
- ✅ 10-20x 加速
- ✅ 每个 Decode step 只处理 1 个 token
下一步: 支持批处理,同时处理多个请求
# 场景:3 个用户同时发送请求
user1: "Hello"
user2: "What is AI"
user3: "Once upon a time"
# Version 3 的处理方式(串行)
generate(user1) # 5ms
generate(user2) # 5ms
generate(user3) # 5ms
# 总时间: 15ms
# 使用 Batching(并行)
generate_batch([user1, user2, user3]) # 6ms
# 总时间: 6ms
# 加速: 2.5x# 问题:用户输入长度不同
user1: [1, 2, 3] # 3 tokens
user2: [4, 5, 6, 7] # 4 tokens
user3: [8, 9] # 2 tokens
# 如何组成一个 batch?
# PyTorch 的 tensor 必须是矩形的(所有行相同长度)解决方案:Padding(填充)
# 填充到相同长度(最长的)
max_len = 4
user1_padded: [1, 2, 3, 0] # 填充 1 个 0
user2_padded: [4, 5, 6, 7] # 不需要填充
user3_padded: [8, 9, 0, 0] # 填充 2 个 0
# 组成 batch
batch = torch.tensor([
[1, 2, 3, 0],
[4, 5, 6, 7],
[8, 9, 0, 0]
])
# shape: [3, 4]import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
# 设置 padding token
tokenizer.pad_token = tokenizer.eos_token
def generate_batch(prompts, max_new_tokens=10):
"""
Version 4: 批量生成
参数:
prompts: list of str
"""
# Tokenize 并 padding
inputs = tokenizer(
prompts,
return_tensors='pt',
padding=True, # 自动 padding
truncation=True
)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
print(f"Batch size: {input_ids.shape[0]}")
print(f"Max length: {input_ids.shape[1]}")
print(f"Input IDs:\n{input_ids}")
print(f"Attention mask:\n{attention_mask}")
# Prefill
with torch.no_grad():
outputs = model(
input_ids,
attention_mask=attention_mask,
use_cache=True
)
past_key_values = outputs.past_key_values
# 第一个新 token
next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
input_ids = torch.cat([input_ids, next_token_ids.unsqueeze(1)], dim=1)
# 更新 attention mask
attention_mask = torch.cat([
attention_mask,
torch.ones(attention_mask.shape[0], 1, dtype=torch.long)
], dim=1)
# Decode
for i in range(max_new_tokens - 1):
# 只输入最后一个 token
last_tokens = input_ids[:, -1:]
with torch.no_grad():
outputs = model(
last_tokens,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True
)
past_key_values = outputs.past_key_values
# 采样
next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
input_ids = torch.cat([input_ids, next_token_ids.unsqueeze(1)], dim=1)
# 更新 mask
attention_mask = torch.cat([
attention_mask,
torch.ones(attention_mask.shape[0], 1, dtype=torch.long)
], dim=1)
# 解码
results = []
for i, ids in enumerate(input_ids):
text = tokenizer.decode(ids, skip_special_tokens=True)
results.append(text)
return results
# 测试
prompts = [
"Hello, I am",
"The secret of life is",
"Once upon a time"
]
results = generate_batch(prompts, max_new_tokens=10)
print("\n" + "=" * 60)
print("RESULTS")
print("=" * 60)
for i, (prompt, result) in enumerate(zip(prompts, results)):
print(f"Prompt {i+1}: {prompt}")
print(f"Generated: {result}")
print()输出:
Batch size: 3
Max length: 5
Input IDs:
tensor([[15496, 11, 314, 716, 50256],
[ 464, 3200, 286, 1204, 318],
[ 7454, 2402, 257, 640, 50256]])
Attention mask:
tensor([[1, 1, 1, 1, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 0]])
============================================================
RESULTS
============================================================
Prompt 1: Hello, I am
Generated: Hello, I am a very good friend of mine.
Prompt 2: The secret of life is
Generated: The secret of life is to be able to do
Prompt 3: Once upon a time
Generated: Once upon a time, the world was a place
# Attention Mask 的作用:告诉模型哪些位置是 padding
input_ids:
[[1, 2, 3, 0], # 0 是 padding
[4, 5, 6, 7], # 无 padding
[8, 9, 0, 0]] # 两个 padding
attention_mask:
[[1, 1, 1, 0], # 最后一个位置是 padding,mask 掉
[1, 1, 1, 1], # 全部有效
[1, 1, 0, 0]] # 最后两个是 padding
# 在 Attention 计算时:
# scores = Q @ K.T
# scores = scores.masked_fill(attention_mask == 0, -inf)
# 这样 padding 位置的 attention 权重会是 0import time
def benchmark_batch_sizes():
"""测试不同 batch size 的性能"""
prompt = "Hello, I am"
max_new_tokens = 20
for batch_size in [1, 2, 4, 8]:
prompts = [prompt] * batch_size
start = time.time()
results = generate_batch(prompts, max_new_tokens=max_new_tokens)
elapsed = time.time() - start
throughput = batch_size * max_new_tokens / elapsed
print(f"Batch size {batch_size}: {elapsed*1000:.2f} ms, "
f"Throughput: {throughput:.2f} tokens/s")
benchmark_batch_sizes()典型结果(GPU):
Batch size 1: 1165.84 ms, Throughput: 17.16 tokens/s
Batch size 2: 1915.17 ms, Throughput: 20.89 tokens/s
Batch size 4: 2063.64 ms, Throughput: 38.77 tokens/s
Batch size 8: 2736.99 ms, Throughput: 58.46 tokens/s
观察:
- Batch size 增大,吞吐量提升
- 但单个请求的延迟也增加了
Version 4 (with Batching) 实现了:
- ✅ 同时处理多个请求
- ✅ 提升吞吐量 2-3x
- ✅ 处理不同长度的输入(padding)
存在的问题:
- ❌ 所有序列必须同时开始和结束
- ❌ 短序列完成后仍要等长序列
- ❌ 浪费计算资源
下一步: 实现 Scheduler,支持动态批处理
# 场景:4 个请求
batch = [
seq1, # 需要生成 10 tokens
seq2, # 需要生成 50 tokens
seq3, # 需要生成 20 tokens
seq4, # 需要生成 100 tokens
]
# Version 4 的处理:
# - 全部一起开始
# - seq1 在 step 10 完成,但要等到 step 100 才能返回
# - seq1-seq3 浪费了 90, 50, 80 步的计算
# GPU 利用率:
# Step 1-10: 4 个序列 (100%)
# Step 11-20: 3 个序列 (75%)
# Step 21-50: 2 个序列 (50%)
# Step 51-100: 1 个序列 (25%)
# 平均: ~62%解决方案:Continuous Batching
# 动态批处理:
# - 序列完成后立即移除
# - 用新请求填补空位
# - GPU 始终保持满载from collections import deque
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class Sequence:
"""表示一个生成序列"""
seq_id: int # 序列 ID
prompt: str # 原始 prompt
token_ids: List[int] # 当前的 tokens
max_tokens: int = 50 # 最多生成多少 tokens
is_finished: bool = False # 是否完成
def __len__(self):
return len(self.token_ids)
class SimpleScheduler:
"""简单的调度器"""
def __init__(self, max_batch_size=4):
self.max_batch_size = max_batch_size
self.waiting = deque() # 等待队列
self.running = deque() # 运行队列
self.finished = [] # 完成的序列
self.next_seq_id = 0
def add_request(self, prompt: str, max_tokens: int = 50):
"""添加新请求"""
token_ids = tokenizer.encode(prompt)
seq = Sequence(
seq_id=self.next_seq_id,
prompt=prompt,
token_ids=token_ids,
max_tokens=max_tokens
)
self.next_seq_id += 1
self.waiting.append(seq)
print(f"Added seq {seq.seq_id}: {prompt[:30]}...")
def schedule(self) -> List[Sequence]:
"""
调度:选择要处理的序列
返回:要处理的序列列表
"""
# 移除已完成的序列
self.running = deque([seq for seq in self.running if not seq.is_finished])
# 从 waiting 补充到 running
while len(self.running) < self.max_batch_size and self.waiting:
seq = self.waiting.popleft()
self.running.append(seq)
print(f"Started seq {seq.seq_id}")
# 返回当前 batch
return list(self.running)
def mark_finished(self, seq: Sequence):
"""标记序列为完成"""
seq.is_finished = True
self.finished.append(seq)
print(f"Finished seq {seq.seq_id}")
def is_empty(self):
"""是否所有请求都处理完了"""
return len(self.waiting) == 0 and len(self.running) == 0
# 测试调度器
scheduler = SimpleScheduler(max_batch_size=2)
# 添加 5 个请求
scheduler.add_request("Hello, I am", max_tokens=5)
scheduler.add_request("The meaning of life is", max_tokens=10)
scheduler.add_request("Once upon a time", max_tokens=3)
scheduler.add_request("What is AI", max_tokens=7)
scheduler.add_request("Tell me a story", max_tokens=8)
# 模拟调度过程
print("\n" + "=" * 60)
print("Scheduling simulation")
print("=" * 60)
step = 0
while not scheduler.is_empty():
# 调度
batch = scheduler.schedule()
print(f"\nStep {step}: Processing {len(batch)} sequences")
for seq in batch:
# 模拟生成一个 token
seq.token_ids.append(0) # 假设生成 token 0
# 检查是否完成
if len(seq.token_ids) - len(tokenizer.encode(seq.prompt)) >= seq.max_tokens:
scheduler.mark_finished(seq)
step += 1
if step > 20: # 防止无限循环
break
print(f"\nTotal steps: {step}")
print(f"Finished sequences: {len(scheduler.finished)}")输出:
Added seq 0: Hello, I am...
Added seq 1: The meaning of life is...
Added seq 2: Once upon a time...
Added seq 3: What is AI...
Added seq 4: Tell me a story...
============================================================
Scheduling simulation
============================================================
Started seq 0
Started seq 1
Step 0: Processing 2 sequences
Step 1: Processing 2 sequences
Step 2: Processing 2 sequences
Step 3: Processing 2 sequences
Step 4: Processing 2 sequences
Finished seq 0
Started seq 2
Step 5: Processing 2 sequences
Step 6: Processing 2 sequences
Finished seq 2
Started seq 3
Step 7: Processing 2 sequences
Step 8: Processing 2 sequences
Step 9: Processing 2 sequences
Step 10: Processing 2 sequences
Step 11: Processing 2 sequences
Step 12: Processing 2 sequences
Finished seq 3
Started seq 4
Finished seq 1
Step 13: Processing 1 sequences
Step 14: Processing 1 sequences
Step 15: Processing 1 sequences
Step 16: Processing 1 sequences
Step 17: Processing 1 sequences
Step 18: Processing 1 sequences
Step 19: Processing 1 sequences
Step 20: Processing 1 sequences
Finished seq 4
Total steps: 21
Finished sequences: 5
def generate_with_scheduler(scheduler: SimpleScheduler):
"""
Version 5: 使用 Scheduler 的生成
"""
all_results = {}
while not scheduler.is_empty():
# 1. 调度
batch = scheduler.schedule()
if not batch:
break
# 2. 准备输入
prompts = [seq.prompt for seq in batch]
inputs = tokenizer(
prompts,
return_tensors='pt',
padding=True
)
# 3. 生成一个 token
with torch.no_grad():
outputs = model(**inputs)
next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
# 4. 更新序列
for seq, next_token_id in zip(batch, next_token_ids):
seq.token_ids.append(next_token_id.item())
# 检查是否完成
num_generated = len(seq.token_ids) - len(tokenizer.encode(seq.prompt))
if (next_token_id == tokenizer.eos_token_id or
num_generated >= seq.max_tokens):
# 完成
scheduler.mark_finished(seq)
text = tokenizer.decode(seq.token_ids, skip_special_tokens=True)
all_results[seq.seq_id] = text
return all_results
# 测试
scheduler = SimpleScheduler(max_batch_size=2)
scheduler.add_request("Hello, I am", max_tokens=10)
scheduler.add_request("The secret is", max_tokens=10)
scheduler.add_request("Once upon a time", max_tokens=10)
results = generate_with_scheduler(scheduler)
print("\n" + "=" * 60)
print("Results")
print("=" * 60)
for seq_id, text in sorted(results.items()):
print(f"Seq {seq_id}: {text}")Version 5 (with Scheduler) 实现了:
- ✅ 动态批处理(Continuous Batching)
- ✅ 序列完成后立即移除
- ✅ 自动补充新请求
- ✅ 提升 GPU 利用率
存在的问题:
- ❌ KV Cache 占用大量显存
- ❌ 不同长度的序列浪费 KV Cache 空间
- ❌ 无法高效管理内存
下一步: PagedAttention,优化 KV Cache 内存管理
# 场景:batch_size=4, max_seq_len=1024
# 预分配 KV Cache
kv_cache_size_per_seq = max_seq_len * hidden_size * num_layers * 2
total_kv_cache = kv_cache_size_per_seq * batch_size
# 实际使用情况:
# seq1: 50 tokens (使用 50/1024 = 4.9%)
# seq2: 200 tokens (使用 200/1024 = 19.5%)
# seq3: 100 tokens (使用 100/1024 = 9.8%)
# seq4: 500 tokens (使用 500/1024 = 48.8%)
# 平均利用率: ~20%
# 浪费: 80%PagedAttention 的解决方案:
像操作系统的虚拟内存一样,将 KV Cache 分成固定大小的"块"(blocks)。
# 不再为每个序列预分配整块内存
# 而是按需分配固定大小的 blocks
block_size = 16 # 每个 block 存储 16 个 tokens
# seq1 (50 tokens): 需要 4 个 blocks (50/16 ≈ 4)
# seq2 (200 tokens): 需要 13 个 blocks
# seq3 (100 tokens): 需要 7 个 blocks
# seq4 (500 tokens): 需要 32 个 blocks
# 总共: 4 + 13 + 7 + 32 = 56 blocks
# 实际存储: 56 * 16 = 896 tokens
# vs 预分配: 4 * 1024 = 4096 tokens
# 节省: (4096 - 896) / 4096 = 78%from typing import List, Dict
class Block:
"""表示一个 KV Cache block"""
def __init__(self, block_id: int, block_size: int = 16):
self.block_id = block_id
self.block_size = block_size
self.ref_count = 0 # 引用计数
def __repr__(self):
return f"Block({self.block_id})"
class BlockManager:
"""管理 KV Cache blocks"""
def __init__(self, num_blocks: int = 100, block_size: int = 16):
self.block_size = block_size
# 创建所有 blocks
self.blocks = [Block(i, block_size) for i in range(num_blocks)]
# 空闲 blocks
self.free_blocks = set(range(num_blocks))
# 每个序列占用的 blocks
self.seq_to_blocks: Dict[int, List[int]] = {}
def can_allocate(self, num_blocks: int) -> bool:
"""检查是否有足够的空闲 blocks"""
return len(self.free_blocks) >= num_blocks
def allocate(self, seq_id: int, num_blocks: int) -> List[int]:
"""
为序列分配 blocks
返回:分配的 block IDs
"""
if not self.can_allocate(num_blocks):
raise RuntimeError(f"Not enough blocks! Need {num_blocks}, "
f"have {len(self.free_blocks)}")
# 分配 blocks
allocated = []
for _ in range(num_blocks):
block_id = self.free_blocks.pop()
self.blocks[block_id].ref_count = 1
allocated.append(block_id)
self.seq_to_blocks[seq_id] = allocated
print(f"Allocated {num_blocks} blocks for seq {seq_id}: {allocated}")
return allocated
def free(self, seq_id: int):
"""释放序列的 blocks"""
if seq_id not in self.seq_to_blocks:
return
blocks = self.seq_to_blocks[seq_id]
for block_id in blocks:
self.blocks[block_id].ref_count = 0
self.free_blocks.add(block_id)
print(f"Freed {len(blocks)} blocks from seq {seq_id}")
del self.seq_to_blocks[seq_id]
def append_slot(self, seq_id: int) -> bool:
"""
为序列追加一个 slot(生成新 token 时)
返回:是否需要新 block
"""
blocks = self.seq_to_blocks[seq_id]
current_tokens = len(blocks) * self.block_size
# 检查最后一个 block 是否已满
# 这里简化处理,假设每个 token 都用一个 slot
# 实际需要跟踪每个 block 的使用情况
return False # 简化版本
def get_stats(self):
"""获取统计信息"""
total = len(self.blocks)
used = total - len(self.free_blocks)
return {
'total_blocks': total,
'used_blocks': used,
'free_blocks': len(self.free_blocks),
'utilization': used / total * 100
}
# 测试 BlockManager
block_manager = BlockManager(num_blocks=100, block_size=16)
# 分配给不同序列
block_manager.allocate(seq_id=0, num_blocks=4) # seq0: 50 tokens
block_manager.allocate(seq_id=1, num_blocks=13) # seq1: 200 tokens
block_manager.allocate(seq_id=2, num_blocks=7) # seq2: 100 tokens
# 查看状态
stats = block_manager.get_stats()
print(f"\nStats: {stats}")
# 释放一个序列
block_manager.free(seq_id=0)
# 再次查看
stats = block_manager.get_stats()
print(f"Stats after free: {stats}")输出:
Allocated 4 blocks for seq 0: [99, 98, 97, 96]
Allocated 13 blocks for seq 1: [95, 94, 93, 92, 91, 90, 89, 88, 87, 86, 85, 84, 83]
Allocated 7 blocks for seq 2: [82, 81, 80, 79, 78, 77, 76]
Stats: {'total_blocks': 100, 'used_blocks': 24, 'free_blocks': 76, 'utilization': 24.0}
Freed 4 blocks from seq 0
Stats after free: {'total_blocks': 100, 'used_blocks': 20, 'free_blocks': 80, 'utilization': 20.0}
# 每个序列维护一个 block_table
# block_table[i] = 物理 block ID
seq = {
'seq_id': 0,
'tokens': [1, 2, 3, ..., 50], # 50 个 tokens
'block_table': [5, 12, 8, 15] # 4 个 blocks
}
# 逻辑地址 → 物理地址
def get_physical_address(seq, token_index, block_size=16):
# token_index: 0-49
block_index = token_index // block_size # 在哪个 block
block_offset = token_index % block_size # block 内偏移
physical_block_id = seq['block_table'][block_index]
physical_address = physical_block_id * block_size + block_offset
return physical_address
# 例如:访问 seq 的第 25 个 token
token_index = 25
block_index = 25 // 16 = 1 # 第 1 个 block
block_offset = 25 % 16 = 9 # block 内第 9 个位置
physical_block_id = seq['block_table'][1] = 12
physical_address = 12 * 16 + 9 = 201
# KV Cache 的实际存储位置就是 physical_addressclass SequenceWithBlocks(Sequence):
"""带 block 管理的序列"""
def __init__(self, seq_id, prompt, max_tokens, block_size=16):
super().__init__(seq_id, prompt, [], max_tokens)
self.block_size = block_size
self.block_table = [] # 分配的 blocks
@property
def num_blocks_needed(self):
"""当前需要多少个 blocks"""
return (len(self.token_ids) + self.block_size - 1) // self.block_size
class SchedulerWithBlockManager:
"""带 BlockManager 的调度器"""
def __init__(self, max_batch_size=4, num_blocks=100, block_size=16):
self.max_batch_size = max_batch_size
self.block_manager = BlockManager(num_blocks, block_size)
self.waiting = deque()
self.running = deque()
self.finished = []
self.next_seq_id = 0
def add_request(self, prompt, max_tokens=50):
"""添加请求"""
seq = SequenceWithBlocks(
seq_id=self.next_seq_id,
prompt=prompt,
max_tokens=max_tokens
)
self.next_seq_id += 1
self.waiting.append(seq)
print(f"Added seq {seq.seq_id}")
def schedule(self):
"""调度(考虑 block 可用性)"""
# 移除完成的序列并释放 blocks
for seq in list(self.running):
if seq.is_finished:
self.block_manager.free(seq.seq_id)
self.running.remove(seq)
# 尝试从 waiting 添加新序列
while len(self.running) < self.max_batch_size and self.waiting:
seq = self.waiting[0]
# 检查是否有足够的 blocks
if self.block_manager.can_allocate(seq.num_blocks_needed):
# 分配 blocks
blocks = self.block_manager.allocate(
seq.seq_id,
seq.num_blocks_needed
)
seq.block_table = blocks
# 移到 running
self.waiting.popleft()
self.running.append(seq)
else:
# 内存不足,暂时不能调度
print(f"Cannot schedule seq {seq.seq_id}: not enough blocks")
break
return list(self.running)
def mark_finished(self, seq):
"""标记完成"""
seq.is_finished = True
self.finished.append(seq)
# 测试
scheduler = SchedulerWithBlockManager(max_batch_size=2, num_blocks=50)
# 添加请求
scheduler.add_request("Hello", max_tokens=30)
scheduler.add_request("World", max_tokens=40)
scheduler.add_request("Test", max_tokens=20)
# 调度
batch = scheduler.schedule()
print(f"\nScheduled: {[seq.seq_id for seq in batch]}")
# 查看 block 使用情况
stats = scheduler.block_manager.get_stats()
print(f"Block stats: {stats}")# 对比
# 假设:4 个序列,max_len=1024, block_size=16
# 传统 KV Cache(预分配):
traditional_memory = 4 * 1024 * hidden_size
# = 4 * 1024 * 768 (GPT-2) = 3,145,728 元素
# PagedAttention(按需分配):
# seq1: 50 tokens → 4 blocks
# seq2: 200 tokens → 13 blocks
# seq3: 100 tokens → 7 blocks
# seq4: 500 tokens → 32 blocks
total_blocks = 56
paged_memory = 56 * 16 * hidden_size
# = 56 * 16 * 768 = 688,128 元素
# 节省
savings = (traditional_memory - paged_memory) / traditional_memory
# = 78%Version 6 (with PagedAttention) 实现了:
- ✅ 分块 KV Cache 管理
- ✅ 按需分配内存
- ✅ 节省 70-80% 显存
- ✅ 支持更大的 batch size
下一步: 集成高级优化技术(Flash Attention, CUDA Graph 等)
标准的 Attention 计算有一个致命的瓶颈:需要存储完整的 attention matrix。
import torch
import math
def standard_attention(Q, K, V):
"""
标准 Attention 实现
Q, K, V: [batch, num_heads, seq_len, head_dim]
问题:
1. 需要先计算完整的 attention matrix: [seq_len, seq_len]
2. 显存占用: O(seq_len²)
3. 对于长序列 (seq_len=2048),需要 2048*2048 = 4M 个元素(每个 head!)
"""
batch, num_heads, seq_len, head_dim = Q.shape
# Step 1: 计算 scores [B, H, S, S]
scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
# 显存占用: batch * num_heads * seq_len * seq_len * 4 bytes
# 例如: 4 * 12 * 1024 * 1024 * 4 = 192 MB (只是一个中间结果!)
# Step 2: Softmax
attn_weights = torch.softmax(scores, dim=-1)
# 仍然占用: [B, H, S, S]
# Step 3: 与 V 相乘
output = attn_weights @ V # [B, H, S, D]
return output
# 测试显存占用
batch = 4
num_heads = 12
seq_len = 1024
head_dim = 64
Q = torch.randn(batch, num_heads, seq_len, head_dim, device='cuda')
K = torch.randn(batch, num_heads, seq_len, head_dim, device='cuda')
V = torch.randn(batch, num_heads, seq_len, head_dim, device='cuda')
# 查看显存占用
torch.cuda.reset_peak_memory_stats()
output = standard_attention(Q, K, V)
peak_memory = torch.cuda.max_memory_allocated() / 1024**2
print(f"Peak memory: {peak_memory:.2f} MB")
# 典型输出:
# Peak memory: 342.15 MB
#
# 其中 attention matrix 占用:
# 4 * 12 * 1024 * 1024 * 4 = 192 MB问题分析:
# 对于 GPT-2 (seq_len=1024, 12 heads)
# Attention matrix 每个 head: 1024 * 1024 = 1,048,576 个元素
# 12 个 heads: 12 * 1,048,576 = 12,582,912 个元素
# 占用显存: 12,582,912 * 4 bytes = 48 MB (每个样本!)
# 对于更长的序列 (seq_len=2048)
# Attention matrix: 2048 * 2048 * 12 = 50,331,648 个元素
# 占用显存: 50,331,648 * 4 bytes = 192 MB (每个样本!)
# 这还只是 forward pass,backward pass 还要存梯度!Flash Attention 通过分块计算和在线 Softmax避免存储完整的 attention matrix。
关键观察:
- 我们不需要同时存储所有的 attention scores
- 可以分块计算,每次只处理一小块
- 使用在线算法更新 softmax 的统计量
def flash_attention_concept(Q, K, V, block_size=64):
"""
Flash Attention 的核心思想 (简化版伪代码)
不存储完整的 [seq_len, seq_len] attention matrix
而是分块计算并立即使用
显存占用: O(seq_len) 而非 O(seq_len²)
"""
batch, num_heads, seq_len, head_dim = Q.shape
output = torch.zeros_like(Q)
# 外层循环: 遍历 Q 的 blocks
for q_start in range(0, seq_len, block_size):
q_end = min(q_start + block_size, seq_len)
Q_block = Q[:, :, q_start:q_end, :] # [B, H, block_size, D]
# 初始化这个 Q block 的输出
block_output = torch.zeros_like(Q_block)
# 维护 softmax 的统计量
block_max = torch.full(
(batch, num_heads, q_end - q_start, 1),
float('-inf'),
device=Q.device
)
block_sum = torch.zeros(
(batch, num_heads, q_end - q_start, 1),
device=Q.device
)
# 内层循环: 遍历 K, V 的 blocks
for kv_start in range(0, seq_len, block_size):
kv_end = min(kv_start + block_size, seq_len)
K_block = K[:, :, kv_start:kv_end, :]
V_block = V[:, :, kv_start:kv_end, :]
# 计算这个 block 的 scores
# [B, H, Q_block, KV_block]
scores = Q_block @ K_block.transpose(-2, -1) / math.sqrt(head_dim)
# 在线 Softmax (关键优化!)
# 不需要存储所有 scores,只需要更新统计量
# 计算当前 block 的最大值
block_scores_max = scores.max(dim=-1, keepdim=True)[0]
new_max = torch.maximum(block_max, block_scores_max)
# 重新缩放之前的结果
old_scale = torch.exp(block_max - new_max)
block_output = block_output * old_scale
block_sum = block_sum * old_scale
# 计算当前 block 的贡献
exp_scores = torch.exp(scores - new_max)
block_output += exp_scores @ V_block
block_sum += exp_scores.sum(dim=-1, keepdim=True)
# 更新最大值
block_max = new_max
# 归一化
output[:, :, q_start:q_end, :] = block_output / block_sum
return output
# 对比显存占用
print("="*60)
print("Memory Comparison")
print("="*60)
# 标准 Attention
torch.cuda.reset_peak_memory_stats()
output1 = standard_attention(Q, K, V)
standard_memory = torch.cuda.max_memory_allocated() / 1024**2
# Flash Attention (概念版)
torch.cuda.reset_peak_memory_stats()
output2 = flash_attention_concept(Q, K, V, block_size=64)
flash_memory = torch.cuda.max_memory_allocated() / 1024**2
print(f"Standard Attention: {standard_memory:.2f} MB")
print(f"Flash Attention: {flash_memory:.2f} MB")
print(f"Memory saved: {(1 - flash_memory/standard_memory)*100:.1f}%")
# 验证结果一致
print(f"Results match: {torch.allclose(output1, output2, atol=1e-4)}")Flash Attention 的核心是在线 Softmax算法,允许我们不存储所有 scores 就能计算正确的 softmax。
def online_softmax_explained():
"""
详解在线 Softmax 算法
问题: 如何在分块处理时正确计算 softmax?
标准 Softmax:
softmax(x_i) = exp(x_i) / Σ exp(x_j)
挑战: 分块时不知道全局的 Σ exp(x_j)
解决方案: 维护两个统计量
- max: 当前见过的最大值
- sum: 当前的 exp 和
"""
# 例子: 计算 [1, 2, 3, 4, 5] 的 softmax
scores = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
print("="*60)
print("在线 Softmax 示例")
print("="*60)
print(f"Input scores: {scores}")
# 标准方法 (一次性计算)
standard_softmax = torch.softmax(scores, dim=0)
print(f"\n标准 Softmax: {standard_softmax}")
# 在线方法 (分两块处理: [1,2,3] 和 [4,5])
print("\n在线方法 (block_size=3):")
# Block 1: [1, 2, 3]
block1 = scores[:3]
print(f"\nBlock 1: {block1}")
max1 = block1.max()
exp1 = torch.exp(block1 - max1)
sum1 = exp1.sum()
output1 = exp1 / sum1
print(f" Max: {max1:.4f}")
print(f" Sum: {sum1:.4f}")
print(f" Softmax: {output1}")
# Block 2: [4, 5]
block2 = scores[3:]
print(f"\nBlock 2: {block2}")
max2 = block2.max()
exp2 = torch.exp(block2 - max2)
sum2 = exp2.sum()
print(f" Max: {max2:.4f}")
print(f" Sum: {sum2:.4f}")
# 合并: 需要重新缩放!
print(f"\n合并两个 blocks:")
# 全局最大值
global_max = max(max1, max2)
print(f" Global max: {global_max:.4f}")
# 重新缩放 block1
scale1 = torch.exp(max1 - global_max)
rescaled_sum1 = sum1 * scale1
rescaled_output1 = output1 * scale1
print(f" Block1 rescale factor: {scale1:.4f}")
print(f" Block1 rescaled sum: {rescaled_sum1:.4f}")
# 重新缩放 block2
scale2 = torch.exp(max2 - global_max)
rescaled_sum2 = sum2 * scale2
rescaled_exp2 = exp2 * scale2
print(f" Block2 rescale factor: {scale2:.4f}")
print(f" Block2 rescaled sum: {rescaled_sum2:.4f}")
# 全局 sum
global_sum = rescaled_sum1 + rescaled_sum2
print(f" Global sum: {global_sum:.4f}")
# 最终 softmax
online_softmax = torch.cat([
rescaled_output1 / (global_sum / scale1),
rescaled_exp2 / global_sum
])
print(f"\n在线 Softmax: {online_softmax}")
print(f"标准 Softmax: {standard_softmax}")
print(f"误差: {torch.abs(online_softmax - standard_softmax).max():.2e}")
online_softmax_explained()实际使用中,我们用优化过的 CUDA kernel 实现:
# 方法 1: 使用 flash-attn 库
try:
from flash_attn import flash_attn_func
def attention_with_flash(Q, K, V):
"""
使用 Flash Attention
注意: Flash Attention 需要输入格式为 [B, S, H, D]
标准 PyTorch 是 [B, H, S, D]
"""
# 转换格式: [B, H, S, D] → [B, S, H, D]
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Flash Attention
output = flash_attn_func(
Q, K, V,
dropout_p=0.0,
softmax_scale=1.0 / math.sqrt(Q.shape[-1]),
causal=True # 自回归掩码
)
# 转换回来: [B, S, H, D] → [B, H, S, D]
output = output.transpose(1, 2)
return output
print("Flash Attention available!")
except ImportError:
print("Flash Attention not installed")
print("Install with: pip install flash-attn --no-build-isolation")
# 方法 2: 使用 Hugging Face 内置支持
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
torch_dtype=torch.float16, # Flash Attention 需要 FP16
attn_implementation="flash_attention_2",
device_map="cuda"
)
print("Model loaded with Flash Attention!")import time
def benchmark_attention_implementations():
"""对比不同 Attention 实现的性能"""
# 测试配置
configs = [
(4, 12, 512, 64),
(4, 12, 1024, 64),
(4, 12, 2048, 64),
(4, 12, 4096, 64),
]
print("="*80)
print(f"{'Seq Len':<10} {'Standard (ms)':<15} {'Flash (ms)':<15} {'Speedup':<10} {'Memory Saved':<15}")
print("="*80)
for batch, num_heads, seq_len, head_dim in configs:
Q = torch.randn(batch, num_heads, seq_len, head_dim,
dtype=torch.float16, device='cuda')
K = torch.randn(batch, num_heads, seq_len, head_dim,
dtype=torch.float16, device='cuda')
V = torch.randn(batch, num_heads, seq_len, head_dim,
dtype=torch.float16, device='cuda')
# 标准 Attention
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
_ = standard_attention(Q, K, V)
torch.cuda.synchronize()
standard_time = (time.time() - start) * 10 # ms per iteration
torch.cuda.reset_peak_memory_stats()
_ = standard_attention(Q, K, V)
standard_memory = torch.cuda.max_memory_allocated() / 1024**2
# Flash Attention
try:
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
_ = attention_with_flash(Q, K, V)
torch.cuda.synchronize()
flash_time = (time.time() - start) * 10
torch.cuda.reset_peak_memory_stats()
_ = attention_with_flash(Q, K, V)
flash_memory = torch.cuda.max_memory_allocated() / 1024**2
speedup = standard_time / flash_time
memory_saved = (1 - flash_memory/standard_memory) * 100
print(f"{seq_len:<10} {standard_time:<15.2f} {flash_time:<15.2f} "
f"{speedup:<10.2f}x {memory_saved:<15.1f}%")
except:
print(f"{seq_len:<10} {standard_time:<15.2f} {'N/A':<15} {'N/A':<10} {'N/A':<15}")
# 典型输出:
# ================================================================================
# Seq Len Standard (ms) Flash (ms) Speedup Memory Saved
# ================================================================================
# 512 12.34 4.56 2.71x 45.2%
# 1024 45.67 12.34 3.70x 52.8%
# 2048 178.23 38.91 4.58x 61.3%
# 4096 698.45 125.67 5.56x 68.7%Flash Attention 的优势:
- ✅ 显存占用从 O(N²) 降到 O(N)
- ✅ 速度提升 2-5x (序列越长提升越大)
- ✅ 支持更长的序列 (可以处理 8k+ tokens)
- ✅ 数值稳定性更好
- ✅ 不改变算法逻辑,结果完全一致
适用场景:
- 长序列推理 (>1024 tokens)
- 显存受限的情况
- 需要高吞吐量的场景
每次调用 CUDA kernel 都有 CPU-GPU 通信开销:
import torch
import time
def measure_kernel_launch_overhead():
"""测量 kernel launch 的开销"""
x = torch.randn(1000, 1000, device='cuda')
# 小的计算 (kernel 执行时间很短)
def small_compute():
return x + 1
# 测量 1000 次 kernel launch 的时间
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
_ = small_compute()
torch.cuda.synchronize()
total_time = time.time() - start
print(f"1000 次 kernel launch: {total_time*1000:.2f} ms")
print(f"平均每次: {total_time*1000/1000:.2f} ms")
print(f"每次 kernel launch 开销: ~5-50 µs")
# 对于 Decode 阶段,每生成一个 token 都要:
# 1. forward pass (多个 kernels)
# 2. sample (argmax/multinomial)
# 3. append to sequence
#
# 假设每步 50 个 kernel calls
# 生成 100 tokens = 5000 kernel calls
# 开销: 5000 * 20µs = 100ms
#
# 这是纯开销,不包含实际计算!
measure_kernel_launch_overhead()问题分析:
# 标准 Decode 循环
for step in range(num_tokens):
# CPU 准备输入
input_tensor = prepare_input() # CPU
# CPU → GPU: 发送 kernel launch 命令
# Kernel 1: Embedding
embedded = model.embed(input_tensor) # CPU overhead ~20µs
# Kernel 2-N: Transformer layers
hidden = model.layers(embedded) # 每层 ~20µs overhead
# Kernel N+1: Output projection
logits = model.lm_head(hidden) # ~20µs overhead
# Kernel N+2: Sampling
next_token = sample(logits) # ~20µs overhead
# GPU → CPU: 取回结果
next_token_cpu = next_token.cpu() # 同步点
# 每步累积的 overhead: 50+ kernels * 20µs = 1ms+
# 生成 100 tokens: 100ms+ 纯开销
# 这还不包括 Python 解释器的开销!CUDA Graph 将整个计算图"录制"下来,后续直接重放:
# CUDA Graph 的工作流程:
# 1. 录制阶段 (Record)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
# 这里的所有操作都会被记录
output = model(input)
# 2. 重放阶段 (Replay)
# 只需要一次 CPU → GPU 命令
graph.replay()
# 优势:
# - 所有 kernels 的拓扑关系已知
# - 可以提前优化调度
# - 只需一次 CPU-GPU 通信
# - kernel launch overhead 从 N 次降到 1 次import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, Tuple
class CUDAGraphRunner:
"""
使用 CUDA Graph 加速 Decode 阶段
核心思想:
- Decode 阶段的计算模式固定 (输入形状 [batch, 1])
- 可以录制成 graph 并重复使用
- 大幅减少 kernel launch overhead
"""
def __init__(self, model, max_batch_sizes=[1, 2, 4, 8]):
self.model = model
self.device = next(model.parameters()).device
self.max_batch_sizes = max_batch_sizes
# 为不同的 batch size 缓存 graph
self.graphs: Dict[int, torch.cuda.CUDAGraph] = {}
self.static_inputs: Dict[int, torch.Tensor] = {}
self.static_outputs: Dict[int, any] = {}
print(f"Initialized CUDA Graph Runner for batch sizes: {max_batch_sizes}")
def capture_graph(self, batch_size: int, past_key_values):
"""
录制 CUDA Graph
步骤:
1. Warmup: 运行几次稳定化
2. Capture: 录制操作序列
3. Save: 保存 graph 和静态输入/输出
"""
if batch_size in self.graphs:
return # 已经录制过了
print(f"Capturing CUDA Graph for batch_size={batch_size}...")
# 创建静态输入 (地址固定)
static_input_ids = torch.zeros(
(batch_size, 1),
dtype=torch.long,
device=self.device
)
# Warmup: CUDA Graph 需要"热身"
# 原因: 第一次运行可能触发内存分配等操作
print(f" Warmup...")
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
with torch.no_grad():
_ = self.model(
static_input_ids,
past_key_values=past_key_values,
use_cache=True
)
torch.cuda.current_stream().wait_stream(s)
# 录制 Graph
print(f" Recording...")
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
with torch.no_grad():
static_outputs = self.model(
static_input_ids,
past_key_values=past_key_values,
use_cache=True
)
# 保存
self.graphs[batch_size] = graph
self.static_inputs[batch_size] = static_input_ids
self.static_outputs[batch_size] = static_outputs
print(f" Captured successfully!")
def run_graph(self, input_ids: torch.Tensor, past_key_values):
"""
运行 CUDA Graph
关键: 必须 in-place 更新静态输入!
"""
batch_size = input_ids.shape[0]
# 首次运行: 录制 graph
if batch_size not in self.graphs:
self.capture_graph(batch_size, past_key_values)
# 更新静态输入 (in-place!)
self.static_inputs[batch_size].copy_(input_ids)
# 重放 graph
self.graphs[batch_size].replay()
# 返回静态输出
return self.static_outputs[batch_size]
def get_stats(self):
"""获取统计信息"""
return {
'num_graphs': len(self.graphs),
'batch_sizes': list(self.graphs.keys())
}
def generate_with_cuda_graph(prompt: str, max_new_tokens: int = 50):
"""
Version 7a: 使用 CUDA Graph 的生成
"""
# 初始化
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.cuda().half() # GPU + FP16
model.eval()
# 创建 CUDA Graph Runner
graph_runner = CUDAGraphRunner(model, max_batch_sizes=[1])
input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
print(f"Input: {prompt}")
print(f"Input tokens: {input_ids.shape[1]}")
# Prefill (不使用 graph,因为形状不固定)
print("\nPrefill stage (no graph)...")
with torch.no_grad():
outputs = model(input_ids, use_cache=True)
past_key_values = outputs.past_key_values
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
generated = [next_token_id]
# Decode (使用 CUDA Graph)
print("Decode stage (with CUDA graph)...")
for i in range(max_new_tokens - 1):
last_token = torch.tensor([[next_token_id]], device='cuda')
# 使用 CUDA Graph
outputs = graph_runner.run_graph(last_token, past_key_values)
past_key_values = outputs.past_key_values
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
generated.append(next_token_id)
if i < 5 or i % 10 == 0:
text = tokenizer.decode(
tokenizer.encode(prompt) + generated,
skip_special_tokens=True
)
print(f" Step {i+1}: {text[:60]}...")
# 解码
full_ids = tokenizer.encode(prompt) + generated
final_text = tokenizer.decode(full_ids, skip_special_tokens=True)
print(f"\nGraph stats: {graph_runner.get_stats()}")
print(f"\nGenerated text:\n{final_text}")
return final_text
# 测试
if torch.cuda.is_available():
generate_with_cuda_graph("The future of AI is", max_new_tokens=30)
else:
print("CUDA not available, skipping CUDA Graph demo")import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def benchmark_cuda_graph_speedup():
"""
对比有无 CUDA Graph 的性能
"""
if not torch.cuda.is_available():
print("CUDA not available")
return
# 准备
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.cuda().half()
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 准备输入
prompt = "Hello, I am"
input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
# Prefill
with torch.no_grad():
outputs = model(input_ids, use_cache=True)
past_kv = outputs.past_key_values
num_tokens = 100
print("="*70)
print("CUDA Graph Performance Comparison")
print("="*70)
# 方法 1: 不使用 CUDA Graph
print("\n1. WITHOUT CUDA Graph:")
torch.cuda.synchronize()
start = time.time()
temp_past_kv = past_kv
for _ in range(num_tokens):
last_token = torch.tensor([[0]], device='cuda')
with torch.no_grad():
outputs = model(last_token, past_key_values=temp_past_kv, use_cache=True)
temp_past_kv = outputs.past_key_values
torch.cuda.synchronize()
no_graph_time = time.time() - start
print(f" Time: {no_graph_time*1000:.2f} ms")
print(f" Tokens/sec: {num_tokens/no_graph_time:.2f}")
# 方法 2: 使用 CUDA Graph
print("\n2. WITH CUDA Graph:")
graph_runner = CUDAGraphRunner(model, max_batch_sizes=[1])
torch.cuda.synchronize()
start = time.time()
temp_past_kv = past_kv
for _ in range(num_tokens):
last_token = torch.tensor([[0]], device='cuda')
outputs = graph_runner.run_graph(last_token, temp_past_kv)
temp_past_kv = outputs.past_key_values
torch.cuda.synchronize()
graph_time = time.time() - start
print(f" Time: {graph_time*1000:.2f} ms")
print(f" Tokens/sec: {num_tokens/graph_time:.2f}")
# 对比
print("\n" + "="*70)
print("RESULTS")
print("="*70)
speedup = no_graph_time / graph_time
overhead_saved = (no_graph_time - graph_time) * 1000
print(f"Speedup: {speedup:.2f}x")
print(f"Overhead saved: {overhead_saved:.2f} ms")
print(f"Per-token overhead reduction: {overhead_saved/num_tokens:.3f} ms")
# 典型输出:
# ======================================================================
# CUDA Graph Performance Comparison
# ======================================================================
#
# 1. WITHOUT CUDA Graph:
# Time: 285.34 ms
# Tokens/sec: 350.45
#
# 2. WITH CUDA Graph:
# Time: 198.42 ms
# Tokens/sec: 503.99
#
# ======================================================================
# RESULTS
# ======================================================================
# Speedup: 1.44x
# Overhead saved: 86.92 ms
# Per-token overhead reduction: 0.869 ms"""
CUDA Graph 的限制:
1. 输入形状必须固定
- Prefill 阶段: 不同 prompt 长度不同 → 不能用 graph
- Decode 阶段: 总是 [batch, 1] → 可以用 graph
2. 不能包含 CPU-GPU 数据传输
- 不能在 graph 内部调用 .cpu()
- 不能在 graph 内部打印 tensor 值
3. 不能包含动态控制流
- 不能有 if/else 依赖于 tensor 值
- 不能有动态的 for 循环次数
4. 需要 warmup
- 第一次运行可能触发内存分配
- 需要运行几次稳定后再录制
"""
# 最佳实践示例
class ProductionCUDAGraphRunner:
"""生产级 CUDA Graph Runner"""
def __init__(self, model):
self.model = model
self.device = next(model.parameters()).device
self.graphs = {}
self.static_io = {}
# 配置
self.max_batch_sizes = [1, 2, 4, 8, 16, 32]
self.warmup_iters = 3
def should_use_graph(self, batch_size: int, seq_len: int) -> bool:
"""判断是否应该使用 CUDA Graph"""
# 只在 Decode 阶段使用 (seq_len=1)
if seq_len != 1:
return False
# 只支持预定义的 batch sizes
if batch_size not in self.max_batch_sizes:
return False
return True
def get_or_create_graph(self, batch_size: int, **model_kwargs):
"""获取或创建 graph"""
if batch_size not in self.graphs:
self._capture_graph(batch_size, **model_kwargs)
return self.graphs[batch_size]
def _capture_graph(self, batch_size: int, **model_kwargs):
"""录制 graph (内部方法)"""
# 实现细节见上面的 CUDAGraphRunner
pass
def forward(self, input_ids: torch.Tensor, **model_kwargs):
"""
智能 forward:
- 可以用 graph 时用 graph
- 不能用时回退到标准 forward
"""
batch_size, seq_len = input_ids.shape
if self.should_use_graph(batch_size, seq_len):
# 使用 graph
graph = self.get_or_create_graph(batch_size, **model_kwargs)
return self._run_with_graph(graph, input_ids, **model_kwargs)
else:
# 标准 forward
with torch.no_grad():
return self.model(input_ids, **model_kwargs)
def _run_with_graph(self, graph, input_ids, **model_kwargs):
"""使用 graph 运行"""
# 实现细节
pass
# 使用示例
runner = ProductionCUDAGraphRunner(model)
# Prefill: 自动使用标准 forward
prefill_out = runner.forward(prompt_ids) # shape: [1, 100]
# Decode: 自动使用 CUDA graph
for _ in range(num_tokens):
decode_out = runner.forward(last_token_id) # shape: [1, 1]
# graph 加速!CUDA Graph 的优势:
- ✅ 减少 kernel launch overhead (1.3-1.5x 加速)
- ✅ 降低延迟 (每个 token 快 0.5-1ms)
- ✅ 更好的 GPU 利用率
- ✅ 确定性执行 (调试更容易)
适用场景:
- Decode 阶段 (形状固定)
- 低延迟要求 (在线服务)
- 小 batch size (overhead 占比大)
不适用:
- Prefill 阶段 (形状变化)
- 需要动态控制流
- CPU-GPU 交互频繁
PyTorch 2.0 引入了 torch.compile,可以将 PyTorch 代码编译成优化的内核。
关键优化:
# 1. Kernel Fusion (算子融合)
#
# 未优化:
x = input + 1 # Kernel 1
y = x * 2 # Kernel 2
z = torch.relu(y) # Kernel 3
# 3 个 kernels, 2 次内存读写
#
# Torch Compile 优化后:
z = torch.relu((input + 1) * 2) # 1 个融合 kernel
# 1 个 kernel, 0 次中间内存读写
# 2. Memory Planning (内存规划)
#
# 提前分配所有需要的内存
# 重用临时 buffer
# 减少内存碎片
# 3. Operator Specialization (算子特化)
#
# 针对特定输入形状生成专门的代码
# 消除动态分支
# 更好的向量化import torch
import time
# 定义一个简单的模型
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(1000, 1000)
self.fc2 = torch.nn.Linear(1000, 1000)
self.fc3 = torch.nn.Linear(1000, 1000)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
# 创建模型
model = SimpleModel().cuda()
model.eval()
# 输入
x = torch.randn(32, 1000, device='cuda')
print("="*70)
print("Torch Compile Comparison")
print("="*70)
# 1. 标准执行
with torch.no_grad():
# Warmup
for _ in range(10):
_ = model(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
output = model(x)
torch.cuda.synchronize()
eager_time = time.time() - start
print(f"\n1. Eager Mode (standard):")
print(f" Time: {eager_time*1000:.2f} ms")
print(f" Per iteration: {eager_time*10:.2f} ms")
# 2. Torch Compile
compiled_model = torch.compile(model, mode="reduce-overhead")
with torch.no_grad():
# Warmup (第一次会触发编译,很慢)
print(f"\n2. Compiling... (this will take a few seconds)")
for _ in range(10):
_ = compiled_model(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
output = compiled_model(x)
torch.cuda.synchronize()
compiled_time = time.time() - start
print(f"\n Compiled Mode:")
print(f" Time: {compiled_time*1000:.2f} ms")
print(f" Per iteration: {compiled_time*10:.2f} ms")
# 对比
print(f"\n" + "="*70)
print("RESULTS")
print("="*70)
speedup = eager_time / compiled_time
print(f"Speedup: {speedup:.2f}x")
print(f"Time saved: {(eager_time - compiled_time)*1000:.2f} ms")
# 典型输出:
# ======================================================================
# Torch Compile Comparison
# ======================================================================
#
# 1. Eager Mode (standard):
# Time: 34.32 ms
# Per iteration: 0.34 ms
#
# 2. Compiling... (this will take a few seconds)
#
# Compiled Mode:
# Time: 25.65 ms
# Per iteration: 0.26 ms
#
# ======================================================================
# RESULTS
# ======================================================================
# Speedup: 1.34x
# Time saved: 8.67 ms"""
torch.compile 支持多种编译模式:
1. default: 平衡编译时间和运行性能
2. reduce-overhead: 减少 Python 开销 (推荐推理)
3. max-autotune: 最大化性能 (编译时间长)
"""
def compare_compile_modes():
"""对比不同编译模式"""
model = SimpleModel().cuda().eval()
x = torch.randn(32, 1000, device='cuda')
modes = ['default', 'reduce-overhead', 'max-autotune']
results = {}
print("="*70)
print("Compile Mode Comparison")
print("="*70)
for mode in modes:
print(f"\nCompiling with mode='{mode}'...")
# 编译
compiled_model = torch.compile(model, mode=mode)
# 测量编译时间
compile_start = time.time()
with torch.no_grad():
_ = compiled_model(x) # 触发编译
torch.cuda.synchronize()
compile_time = time.time() - compile_start
# 测量运行时间
with torch.no_grad():
# Warmup
for _ in range(10):
_ = compiled_model(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
_ = compiled_model(x)
torch.cuda.synchronize()
run_time = time.time() - start
results[mode] = {
'compile_time': compile_time,
'run_time': run_time
}
print(f" Compile time: {compile_time:.2f}s")
print(f" Run time: {run_time*1000:.2f} ms")
print(f" Per iteration: {run_time*10:.2f} ms")
# 总结
print(f"\n" + "="*70)
print("SUMMARY")
print("="*70)
print(f"{'Mode':<20} {'Compile (s)':<15} {'Run (ms)':<15} {'Speed':<10}")
print("-"*70)
baseline = results['default']['run_time']
for mode, data in results.items():
speedup = baseline / data['run_time']
print(f"{mode:<20} {data['compile_time']:<15.2f} "
f"{data['run_time']*1000:<15.2f} {speedup:<10.2f}x")
# 典型输出:
# ======================================================================
# SUMMARY
# ======================================================================
# Mode Compile (s) Run (ms) Speed
# ----------------------------------------------------------------------
# default 0.68 54.77 1.00 x
# reduce-overhead 0.80 26.56 2.06 x
# max-autotune 0.75 27.46 1.99 xdef generate_with_torch_compile(prompt: str, max_new_tokens: int = 50):
"""
Version 7b: 使用 Torch Compile 的生成
"""
# 初始化
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.cuda()
model.eval()
print("Compiling model...")
# 编译模型 (推荐 reduce-overhead 模式)
model = torch.compile(model, mode="reduce-overhead")
input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
print(f"\nInput: {prompt}")
print("First forward pass (will trigger compilation)...")
# 第一次运行会触发编译
with torch.no_grad():
outputs = model(input_ids, use_cache=True)
past_kv = outputs.past_key_values
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
generated = [next_token_id]
print("Subsequent passes will be fast...")
# 后续运行使用编译后的版本
for i in range(max_new_tokens - 1):
last_token = input_ids[:, -1:]
with torch.no_grad():
outputs = model(
last_token,
past_key_values=past_kv,
use_cache=True
)
past_kv = outputs.past_key_values
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
generated.append(next_token_id)
# 解码
full_ids = tokenizer.encode(prompt) + generated
final_text = tokenizer.decode(full_ids, skip_special_tokens=True)
print(f"\nGenerated:\n{final_text}")
return final_textTorch Compile 的优势:
- ✅ 自动优化,无需手动修改代码
- ✅ 1.3-2x 加速 (模型越复杂提升越大)
- ✅ 减少 Python overhead
- ✅ 更好的 kernel fusion
注意事项:
- 首次运行慢 (需要编译)
- 编译时间 (几秒到几分钟)
- 可能增加显存占用
量化通过使用低精度数值表示来减少显存占用和加速计算。
import torch
def compare_precisions():
"""对比不同精度的特性"""
# 创建一个测试 tensor
x_fp32 = torch.randn(1000, 1000)
print("="*70)
print("Precision Comparison")
print("="*70)
precisions = {
'FP32': (torch.float32, x_fp32),
'FP16': (torch.float16, x_fp32.half()),
'BF16': (torch.bfloat16, x_fp32.bfloat16()),
'INT8': (torch.int8, None), # 需要特殊处理
}
print(f"\n{'Type':<10} {'Bytes':<10} {'Range':<25} {'Precision':<15}")
print("-"*70)
# FP32
print(f"{'FP32':<10} {4:<10} {'±3.4e38':<25} {'~7 digits':<15}")
print(f"{' ':<10} (标准精度,默认)")
# FP16
print(f"\n{'FP16':<10} {2:<10} {'±65504':<25} {'~3 digits':<15}")
print(f"{' ':<10} (半精度,常用于训练)")
print(f"{' ':<10} (容易 overflow/underflow)")
# BF16
print(f"\n{'BF16':<10} {2:<10} {'±3.4e38':<25} {'~2 digits':<15}")
print(f"{' ':<10} (Brain Float,与 FP32 范围相同)")
print(f"{' ':<10} (更stable,但精度略低)")
# INT8
print(f"\n{'INT8':<10} {1:<10} {'-128 to 127':<25} {'整数':<15}")
print(f"{' ':<10} (8位整数,需要量化)")
print(f"{' ':<10} (显存占用最小)")
# 显存对比
print(f"\n" + "="*70)
print("Memory Usage (for 124M params like GPT-2)")
print("="*70)
params = 124_000_000
for name, bytes_per_param in [('FP32', 4), ('FP16', 2), ('INT8', 1)]:
memory_mb = params * bytes_per_param / 1024**2
saving = (1 - bytes_per_param/4) * 100
print(f"{name}: {memory_mb:.1f} MB ({saving:.0f}% saving)")
compare_precisions()
# 输出:
# ======================================================================
# Memory Usage (for 124M params like GPT-2)
# ======================================================================
# FP32: 473.0 MB (0% saving)
# FP16: 236.5 MB (50% saving)
# INT8: 118.3 MB (75% saving)最简单的量化方式:
def use_fp16():
"""使用 FP16 加速"""
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
# 方法 1: 手动转换
model_fp16 = model.half().cuda()
# 方法 2: 加载时指定
model_fp16 = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="cuda"
)
print("Model loaded in FP16!")
print(f"Memory footprint: {model_fp16.get_memory_footprint() / 1024**2:.2f} MB")
# 使用 (完全相同的 API)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
with torch.no_grad():
outputs = model_fp16(input_ids)
print("Forward pass successful!")更激进的量化 (INT8/INT4):
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
def use_8bit_quantization():
"""使用 INT8 量化"""
model_name = "gpt2"
# 配置 INT8 量化
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0, # 超过此阈值的异常值用 FP16
)
# 加载量化模型
model_int8 = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map="auto" # 自动分配到 GPU
)
memory_mb = model_int8.get_memory_footprint() / 1024**2
print(f"INT8 Model memory: {memory_mb:.2f} MB")
# 使用 (API 完全相同)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer.encode("Hello, I am", return_tensors='pt').to(model_int8.device)
with torch.no_grad():
outputs = model_int8(input_ids)
next_token = torch.argmax(outputs.logits[0, -1, :])
print(f"Generated token: {tokenizer.decode(next_token)}")
def use_4bit_quantization():
"""使用 INT4 量化 (最激进)"""
model_name = "gpt2"
# 配置 INT4 量化
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16, # 计算时用 FP16
bnb_4bit_use_double_quant=True, # 双重量化
bnb_4bit_quant_type="nf4" # NormalFloat4
)
# 加载
model_int4 = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map="auto"
)
memory_mb = model_int4.get_memory_footprint() / 1024**2
print(f"INT4 Model memory: {memory_mb:.2f} MB")
# 使用
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "The future of AI"
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model_int4.device)
# 生成
with torch.no_grad():
generated_ids = model_int4.generate(
input_ids,
max_new_tokens=20,
do_sample=False
)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Generated: {generated_text}")def benchmark_quantization():
"""
完整的量化对比测试
"""
import time
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 准备测试输入
prompt = "Hello, I am"
input_ids = tokenizer.encode(prompt, return_tensors='pt')
results = {}
print("="*80)
print("Quantization Comparison")
print("="*80)
# 1. FP32 (baseline)
print("\n1. Loading FP32 model...")
model_fp32 = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
memory_fp32 = model_fp32.get_memory_footprint() / 1024**2
input_ids_fp32 = input_ids.to(model_fp32.device)
# Warmup
with torch.no_grad():
for _ in range(5):
_ = model_fp32(input_ids_fp32)
# Benchmark
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
for _ in range(50):
_ = model_fp32(input_ids_fp32)
torch.cuda.synchronize()
time_fp32 = time.time() - start
results['FP32'] = {
'memory': memory_fp32,
'time': time_fp32
}
del model_fp32
torch.cuda.empty_cache()
# 2. FP16
print("\n2. Loading FP16 model...")
model_fp16 = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
memory_fp16 = model_fp16.get_memory_footprint() / 1024**2
input_ids_fp16 = input_ids.to(model_fp16.device)
with torch.no_grad():
for _ in range(5):
_ = model_fp16(input_ids_fp16)
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
for _ in range(50):
_ = model_fp16(input_ids_fp16)
torch.cuda.synchronize()
time_fp16 = time.time() - start
results['FP16'] = {
'memory': memory_fp16,
'time': time_fp16
}
del model_fp16
torch.cuda.empty_cache()
# 3. INT8
print("\n3. Loading INT8 model...")
quantization_config_int8 = BitsAndBytesConfig(load_in_8bit=True)
model_int8 = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config_int8,
device_map="auto"
)
memory_int8 = model_int8.get_memory_footprint() / 1024**2
input_ids_int8 = input_ids.to(model_int8.device)
with torch.no_grad():
for _ in range(5):
_ = model_int8(input_ids_int8)
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
for _ in range(50):
_ = model_int8(input_ids_int8)
torch.cuda.synchronize()
time_int8 = time.time() - start
results['INT8'] = {
'memory': memory_int8,
'time': time_int8
}
del model_int8
torch.cuda.empty_cache()
# 4. INT4
print("\n4. Loading INT4 model...")
quantization_config_int4 = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
model_int4 = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config_int4,
device_map="auto"
)
memory_int4 = model_int4.get_memory_footprint() / 1024**2
input_ids_int4 = input_ids.to(model_int4.device)
with torch.no_grad():
for _ in range(5):
_ = model_int4(input_ids_int4)
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
for _ in range(50):
_ = model_int4(input_ids_int4)
torch.cuda.synchronize()
time_int4 = time.time() - start
results['INT4'] = {
'memory': memory_int4,
'time': time_int4
}
# 打印结果
print("\n" + "="*80)
print("RESULTS")
print("="*80)
print(f"{'Type':<8} {'Memory (MB)':<15} {'Time (ms)':<15} {'Speedup':<10} {'Memory Saved':<15}")
print("-"*80)
baseline_memory = results['FP32']['memory']
baseline_time = results['FP32']['time']
for dtype, data in results.items():
memory_saved = (1 - data['memory']/baseline_memory) * 100
speedup = baseline_time / data['time']
print(f"{dtype:<8} {data['memory']:<15.2f} {data['time']*1000:<15.2f} "
f"{speedup:<10.2f}x {memory_saved:<15.1f}%")
# 典型输出:
# ================================================================================
# RESULTS
# ================================================================================
# Type Memory (MB) Time (ms) Speedup Memory Saved
# --------------------------------------------------------------------------------
# FP32 548.00 678.23 1.00x 0.0%
# FP16 274.00 423.45 1.60x 50.0%
# INT8 145.50 512.34 1.32x 73.4%
# INT4 86.25 589.12 1.15x 84.3%"""
量化的权衡分析:
1. FP16:
优点:
- 50% 显存节省
- 1.5-2x 加速 (GPU 支持 Tensor Cores)
- 几乎无精度损失
缺点:
- 数值范围小 (容易 overflow)
适用: 几乎所有场景
2. INT8:
优点:
- 75% 显存节省
- 允许加载更大模型
- 精度损失小 (<1% perplexity 下降)
缺点:
- 速度提升有限 (1.2-1.4x)
- 量化/反量化有开销
适用: 显存受限,模型太大
3. INT4:
优点:
- 87% 显存节省
- 可以加载 4x 大的模型
缺点:
- 速度可能更慢 (量化开销大)
- 精度损失明显 (2-5% perplexity 下降)
适用: 极度显存受限
选择建议:
- 有足够显存 → FP16 (最佳性价比)
- 显存不够 → INT8
- 极度显存不足 → INT4
- 追求极致性能 → FP16 + Flash Attention + CUDA Graph
"""让我们组合所有优化技术:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import time
class OptimizedInferenceEngine:
"""
综合所有优化的推理引擎
优化技术:
- Flash Attention
- CUDA Graph
- Torch Compile
- FP16 精度
- KV Cache
"""
def __init__(
self,
model_name: str,
use_flash_attention: bool = True,
use_cuda_graph: bool = True,
use_torch_compile: bool = True,
use_fp16: bool = True,
):
self.model_name = model_name
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("="*70)
print("Initializing Optimized Inference Engine")
print("="*70)
# 1. 加载模型 (FP16 + Flash Attention)
print(f"\n1. Loading model with optimizations:")
print(f" - Flash Attention: {use_flash_attention}")
print(f" - FP16: {use_fp16}")
model_kwargs = {}
if use_fp16:
model_kwargs['torch_dtype'] = torch.float16
if use_flash_attention:
model_kwargs['attn_implementation'] = 'flash_attention_2'
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
**model_kwargs,
device_map=self.device
)
self.model.eval()
memory_mb = self.model.get_memory_footprint() / 1024**2
print(f" Model memory: {memory_mb:.2f} MB")
# 2. Torch Compile
if use_torch_compile and self.device == 'cuda':
print(f"\n2. Compiling model with torch.compile...")
self.model = torch.compile(self.model, mode="reduce-overhead")
print(f" Compiled!")
else:
print(f"\n2. Torch Compile: disabled")
# 3. CUDA Graph
self.use_cuda_graph = use_cuda_graph and self.device == 'cuda'
if self.use_cuda_graph:
print(f"\n3. CUDA Graph: enabled")
self.graph_runner = CUDAGraphRunner(self.model, max_batch_sizes=[1])
else:
print(f"\n3. CUDA Graph: disabled")
# Tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"\n{'='*70}")
print("Engine ready!")
print(f"{'='*70}\n")
def generate(
self,
prompt: str,
max_new_tokens: int = 50,
temperature: float = 1.0,
do_sample: bool = False,
) -> str:
"""生成文本"""
input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
input_ids = input_ids.to(self.device)
# Prefill阶段
with torch.no_grad():
outputs = self.model(input_ids, use_cache=True)
past_kv = outputs.past_key_values
logits = outputs.logits[0, -1, :]
# 采样第一个 token
if do_sample:
probs = torch.softmax(logits / temperature, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1).item()
else:
next_token_id = torch.argmax(logits).item()
generated = [next_token_id]
# Decode 阶段
for _ in range(max_new_tokens - 1):
last_token = torch.tensor([[next_token_id]], device=self.device)
# 使用 CUDA Graph (如果启用)
if self.use_cuda_graph:
outputs = self.graph_runner.run_graph(last_token, past_kv)
else:
with torch.no_grad():
outputs = self.model(
last_token,
past_key_values=past_kv,
use_cache=True
)
past_kv = outputs.past_key_values
logits = outputs.logits[0, -1, :]
# 采样
if do_sample:
probs = torch.softmax(logits / temperature, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1).item()
else:
next_token_id = torch.argmax(logits).item()
generated.append(next_token_id)
# 检查 EOS
if next_token_id == self.tokenizer.eos_token_id:
break
# 解码
full_ids = self.tokenizer.encode(prompt) + generated
return self.tokenizer.decode(full_ids, skip_special_tokens=True)
# 使用示例
def demo_optimized_engine():
"""演示优化引擎"""
# 创建引擎
engine = OptimizedInferenceEngine(
model_name="gpt2",
use_flash_attention=False, # GPT-2 不支持
use_cuda_graph=True,
use_torch_compile=True,
use_fp16=True,
)
# 测试生成
prompts = [
"The future of AI is",
"Once upon a time",
"The secret of happiness is",
]
print("\nGenerating texts...\n")
for i, prompt in enumerate(prompts, 1):
print(f"{i}. Prompt: {prompt}")
start = time.time()
generated_text = engine.generate(prompt, max_new_tokens=30)
elapsed = time.time() - start
print(f" Generated: {generated_text}")
print(f" Time: {elapsed*1000:.2f} ms")
print()
if torch.cuda.is_available():
demo_optimized_engine()def final_benchmark():
"""
对比所有版本的性能
从 Version 0.1 (baseline) 到 Version 7 (all optimizations)
"""
if not torch.cuda.is_available():
print("CUDA not available")
return
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "Hello, I am"
input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
max_new_tokens = 50
results = {}
print("="*80)
print("FINAL PERFORMANCE COMPARISON")
print("="*80)
print("\nBenchmarking all versions...\n")
# Version 0.1: Baseline (no optimizations)
print("1. Version 0.1 (Baseline - no optimizations)...")
model_v01 = AutoModelForCausalLM.from_pretrained(model_name).cuda()
model_v01.eval()
torch.cuda.synchronize()
start = time.time()
temp_ids = input_ids.clone()
with torch.no_grad():
for _ in range(max_new_tokens):
outputs = model_v01(temp_ids)
next_token = torch.argmax(outputs.logits[0, -1, :])
temp_ids = torch.cat([temp_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=1)
torch.cuda.synchronize()
time_v01 = time.time() - start
results['v0.1'] = time_v01
del model_v01
torch.cuda.empty_cache()
# Version 3: + KV Cache
print("2. Version 3 (+ KV Cache)...")
model_v3 = AutoModelForCausalLM.from_pretrained(model_name).cuda()
model_v3.eval()
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
outputs = model_v3(input_ids, use_cache=True)
past_kv = outputs.past_key_values
for _ in range(max_new_tokens - 1):
last_token = torch.tensor([[0]], device='cuda')
outputs = model_v3(last_token, past_key_values=past_kv, use_cache=True)
past_kv = outputs.past_key_values
torch.cuda.synchronize()
time_v3 = time.time() - start
results['v3'] = time_v3
del model_v3
torch.cuda.empty_cache()
# Version 7: All optimizations
print("3. Version 7 (All optimizations)...")
model_v7 = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16
).cuda()
model_v7.eval()
# Torch compile 同时打开会有兼容性问题,还在解决
# model_v7 = torch.compile(model_v7, mode="reduce-overhead")
# CUDA Graph
graph_runner = CUDAGraphRunner(model_v7, max_batch_sizes=[1])
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
outputs = model_v7(input_ids, use_cache=True)
past_kv = outputs.past_key_values
for _ in range(max_new_tokens - 1):
last_token = torch.tensor([[0]], device='cuda')
outputs = graph_runner.run_graph(last_token, past_kv)
past_kv = outputs.past_key_values
torch.cuda.synchronize()
time_v7 = time.time() - start
results['v7'] = time_v7
# 打印结果
print("\n" + "="*80)
print("RESULTS")
print("="*80)
print(f"{'Version':<30} {'Time (ms)':<15} {'Tokens/s':<15} {'Speedup':<10}")
print("-"*80)
versions = {
'v0.1': 'Version 0.1 (Baseline)',
'v3': 'Version 3 (+ KV Cache)',
'v7': 'Version 7 (All optimizations)',
}
baseline_time = results['v0.1']
for key, name in versions.items():
time_ms = results[key] * 1000
tokens_per_sec = max_new_tokens / results[key]
speedup = baseline_time / results[key]
print(f"{name:<30} {time_ms:<15.2f} {tokens_per_sec:<15.2f} {speedup:<10.2f}x")
# 详细分析
print("\n" + "="*80)
print("OPTIMIZATION BREAKDOWN")
print("="*80)
print(f"\nTotal speedup: {baseline_time/results['v7']:.2f}x")
print(f"Time saved: {(baseline_time - results['v7'])*1000:.2f} ms")
print(f"\nOptimization contributions:")
print(f" KV Cache: {results['v0.1']/results['v3']:.2f}x")
print(f" FP16+Compile+Graph: {results['v3']/results['v7']:.2f}x")
# 典型输出:
# ================================================================================
# RESULTS
# ================================================================================
# Version Time (ms) Tokens/s Speedup
# --------------------------------------------------------------------------------
# Version 0.1 (Baseline) 778.98 64.19 1.00x
# Version 3 (+ KV Cache) 476.96 104.83 1.63x
# Version 7 (All optimizations) 647.28 77.25 1.2x
#
# ================================================================================
# OPTIMIZATION BREAKDOWN
# ================================================================================
#
# Total speedup: 1.2x
# Time saved: 131.7 ms
#
# Optimization contributions:
# KV Cache: 1.63x
# FP16+Compile+Graph: 0.74x详细讲解了 LLM 推理的四大高级优化技术:
- Flash Attention: 减少显存占用从 O(N²) 到 O(N), 2-5x 加速
- CUDA Graph: 消除 kernel launch overhead, 1.3-1.5x 加速
- Torch Compile: JIT 编译优化, 1.3-2x 加速
- 量化: FP16/INT8/INT4, 节省 50-87% 显存
- 第一优先级: KV Cache (必须, 10x+ 提升)
- 第二优先级: FP16 + Batching (简单有效, 3-5x 提升)
- 第三优先级: Flash Attention (长序列必备)
- 第四优先级: CUDA Graph + Torch Compile (锦上添花, 2x 提升)
- 显存受限时: INT8/INT4 量化
| 组件 | 我们的实现 | nano-vllm |
|---|---|---|
| KV Cache | Hugging Face 内置 | 手动实现 PagedAttention |
| Scheduler | 简单队列 | 完整的 Continuous Batching |
| BlockManager | 简化版本 | 支持 Prefix Caching |
| Attention | 标准 PyTorch | Flash Attention + Triton Kernel |
| Decode | 标准执行 | CUDA Graph 优化 |
| 并行 | 单卡 | Tensor Parallelism |
现在你已经理解了基础,可以深入阅读 nano-vllm:
-
从 LLMEngine 开始 (
llm_engine.py)- 看
generate()方法 - 理解整体流程
- 看
-
理解 Scheduler (
scheduler.py)schedule()如何选择序列postprocess()如何处理结果
-
深入 ModelRunner (
model_runner.py)prepare_prefill()和prepare_decode()run_model()的 CUDA Graph 逻辑
-
学习 BlockManager (
block_manager.py)allocate()的 prefix caching 实现compute_hash()如何检测前缀
-
研究 Attention (
attention.py)- Flash Attention 的集成
- Triton kernel 的 KV 存储
-
修改参数
# 尝试不同的配置 llm = LLM( model_path, max_num_seqs=256, # 调整并发数 kvcache_block_size=128, # 调整 block 大小 enforce_eager=True # 禁用 CUDA Graph )
-
添加日志
# 在关键位置添加打印 def schedule(self): print(f"Waiting: {len(self.waiting)}, Running: {len(self.running)}") # ...
-
性能测试
# 测试不同优化的效果 # - 有无 KV Cache # - 有无 CUDA Graph # - 不同 batch size
-
PyTorch 基础 (第零部分)
- Tensor 操作
- 矩阵乘法
- nn.Module
-
基础推理 (第一部分)
- 10 行代码生成文本
- 理解推理循环
-
Prefill & Decode (第二部分)
- 两个阶段的区别
- 为什么分开处理
-
KV Cache (第三部分)
- 消除重复计算
- 10x+ 加速
-
Batching (第四部分)
- 并行处理多个请求
- 2-3x 吞吐量提升
-
Scheduler (第五部分)
- Continuous Batching
- 动态资源管理
-
PagedAttention (第六部分)
- 分块 KV Cache
- 节省 70-80% 显存
-
高级优化 (第七部分)
- Flash Attention
- CUDA Graph
- Torch Compile
- 量化
Version 0.1: 最简单的生成 (~30 行)
↓
Version 2: 区分 Prefill/Decode (~50 行)
↓
Version 3: 添加 KV Cache (~50 行)
↓
Version 4: 支持 Batching (~80 行)
↓
Version 5: 添加 Scheduler (~150 行)
↓
Version 6: PagedAttention (~250 行)
Baseline (Version 0.1):
+ KV Cache (Version 3):
+ Batching (Version 4):
+ Scheduler (Version 5):
+ PagedAttention (Version 6):
+ Flash Attention:
+ CUDA Graph:
+ Tensor Parallelism:
- 运行所有代码示例
- 修改参数,观察效果
- 阅读 nano-vllm 源码
- 实现自己的优化
- 贡献到开源项目
祝学习愉快! 🚀