Closed
Description
A spin-off from #1502 (comment). Created for tracking progress.
Problem
Below is a common pattern in nanoGPT's backprop.
dQ, dK, dV = scaled_dot_product_attention_backprop(...) # bf16[16,12,128,64], bf16[16,12,128,64], bf16[16,12,128,64]
dQ = transpose(dQ, [0, 2, 1, 3]) # [16, 128, 12, 64]
dQ = reshape(dQ, [16, 128, 768]) # [16, 128, 768]
dK = reshape(transpose(dK, [0, 2, 1, 3]), [16, 128, 768]) # same transpose and reshape as dQ
dV = reshape(transpose(dV, [0, 2, 1, 3]), [16, 128, 768]) # same transpose and reshape as dQ
concatenated = cat([dQ, dK, dV], axis=-1)
dQKV_sum = sum(concatenated, ...) # omitting a round trip to float
dQKV_view = reshape(concatenated, [B*S, H*D*3])
dQKV_permute = transpose(dQKV_view, [1, 0])
return dQKV_sum, dQKV_view, dQKV_permute
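A minimal NumPy transcription of the pattern above, useful for checking the shapes. It assumes B=16, H=12, S=128, D=64 (read off the shape comments), uses random float32 arrays in place of the bf16 SDPA gradients (so the float round trip is implicit), and assumes the elided sum reduces over the batch and sequence axes, matching the [H*D*3] result described under Solution. `epilogue` is a hypothetical helper name, not part of Thunder or nvFuser.

```python
import numpy as np

# Sizes read off the shape comments: batch, heads, sequence length, head dim.
B, H, S, D = 16, 12, 128, 64


def epilogue(dQ, dK, dV):
    """Hypothetical NumPy version of the backward epilogue.

    Inputs are [B, H, S, D] gradients, as produced by SDPA backward.
    """
    def to_bsc(t):
        # [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D]
        return t.transpose(0, 2, 1, 3).reshape(B, S, H * D)

    concatenated = np.concatenate([to_bsc(dQ), to_bsc(dK), to_bsc(dV)], axis=-1)
    # Assumption: the elided reduction sums over the batch and sequence axes.
    dQKV_sum = concatenated.sum(axis=(0, 1))
    dQKV_view = concatenated.reshape(B * S, H * D * 3)
    dQKV_permute = dQKV_view.transpose(1, 0)
    return dQKV_sum, dQKV_view, dQKV_permute


rng = np.random.default_rng(0)
grads = [rng.standard_normal((B, H, S, D), dtype=np.float32) for _ in range(3)]
dQKV_sum, dQKV_view, dQKV_permute = epilogue(*grads)
print(dQKV_sum.shape, dQKV_view.shape, dQKV_permute.shape)
# (2304,) (2048, 2304) (2304, 2048)
```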
Because nvFuser doesn't take sdpa_backward, it sees dQ, dK, and dV as three unconnected input tensors. It cannot tell that their concatenation could reuse existing memory, so it has to materialize dQKV_view and dQKV_permute.
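As a rough analogy (NumPy standing in for nvFuser's aliasing analysis, with small made-up sizes), concatenating three independently allocated tensors always produces a fresh buffer. The view and permute outputs can then only alias that new buffer, never any of the inputs:

```python
import numpy as np

# Small illustrative sizes (assumption); C stands for H*D.
B, S, C = 2, 4, 6

# Three independently allocated gradients, like the unconnected dQ, dK, dV inputs.
dQ, dK, dV = (np.full((B, S, C), i, dtype=np.float32) for i in range(3))

concatenated = np.concatenate([dQ, dK, dV], axis=-1)
dQKV_view = concatenated.reshape(B * S, 3 * C)
dQKV_permute = dQKV_view.T

# The cat writes a new buffer that shares memory with none of the inputs...
print([bool(np.shares_memory(concatenated, t)) for t in (dQ, dK, dV)])  # [False, False, False]
# ...so the view and permute can only alias that new, materialized buffer.
print(bool(np.shares_memory(dQKV_view, concatenated)), bool(np.shares_memory(dQKV_permute, concatenated)))  # True True
```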
Solution
TL;DR: change Thunder's cudnnex to feed nvFuser a concatenated tensor that contains dQ, dK and dV, so nvFuser realizes that the existing cat is unnecessary and removes it.
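The cancellation the TL;DR relies on can be sketched in NumPy, assuming the three slices are adjacent and in order along the concatenation axis (sizes and the flat [B, S, 3*C] layout are illustrative assumptions, not the cuDNN output layout). The slices are views of one buffer, and concatenating them reproduces that buffer exactly, which is why a fuser may rewrite cat(slice, slice, slice) back to the source:

```python
import numpy as np

# Hypothetical small sizes; one buffer stands in for a single packed dQKV output.
B, S, C = 2, 4, 6
dQKV = np.arange(B * S * 3 * C, dtype=np.float32).reshape(B, S, 3 * C)

# The split: three slices that are views of dQKV, not copies.
dQ, dK, dV = np.split(dQKV, 3, axis=-1)
print(all(np.shares_memory(t, dQKV) for t in (dQ, dK, dV)))  # True

# Concatenating adjacent, in-order slices along the split axis gives back dQKV,
# so the cat is redundant and can be replaced by dQKV itself.
concatenated = np.concatenate([dQ, dK, dV], axis=-1)
print(np.array_equal(concatenated, dQKV))  # True
```

NumPy itself still copies in np.concatenate; the point is only that the result is provably identical to the source, which is what lets a compiler drop the cat.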
- cudnnex will convert the SDPA backward op into a cuDNN sdpa_backward kernel (which outputs one dQKV tensor) followed by a split.
- cudnnex will give that split to nvFuser, so nvFuser will see the following pattern:
    dQKV = fd.ops.define_tensor([B, S, H, D*3])
    dQ = fd.ops.slice(dQKV, ...)
    dQ = fd.ops.view(dQ, ...)
    dQ = fd.ops.permute(dQ, ...)
    dK = ...the same slice-view-permute pattern
    dV = ...the same slice-view-permute pattern
    concatenated = fd.ops.cat([dQ, dK, dV], axis=-1)
    dQKV_sum = fd.ops.sum(concatenated, ...) # omitting a round trip to float
    dQKV_view = fd.ops.view(concatenated, [B*S, H*D*3])
    dQKV_permute = fd.ops.permute(dQKV_view, [1, 0])
- nvFuser will cancel the slices and the cat and merge all the views and permutes between them, so the above will become:
    dQKV = fd.ops.define_tensor([B, S, H, D*3])
    concatenated = fd.ops.permute(fd.ops.view(dQKV, ...), ...)
    dQKV_sum = fd.ops.sum(concatenated, ...) # omitting a round trip to float
    dQKV_view = fd.ops.view(concatenated, [B*S, H*D*3])
    dQKV_permute = fd.ops.permute(dQKV_view, [1, 0])
  As a result, dQKV_view and dQKV_permute will become aliases of dQKV. The fusion will boil down to a ReduceSum kernel that sums [B,S,H,D*3] to [H*D*3].
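To illustrate what "alias" and the final reduction mean here, a NumPy sketch. It skips the intermediate permute(view(dQKV, ...)) step and assumes the merged view/permute chain ends up as a plain reshape and transpose of the contiguous [B, S, H, D*3] buffer (sizes from the Problem section). It shows metadata-only outputs plus one reduction, not nvFuser's actual analysis:

```python
import numpy as np

B, S, H, D = 16, 128, 12, 64  # sizes from the shape comments in the Problem section

# Stand-in for the packed dQKV tensor, contiguous with shape [B, S, H, D*3].
dQKV = np.ones((B, S, H, D * 3), dtype=np.float32)

# The only real work left: reduce [B, S, H, D*3] -> [H*D*3].
dQKV_sum = dQKV.sum(axis=(0, 1)).reshape(H * D * 3)

# Reshape of a contiguous buffer and its transpose only change shape/strides,
# so both outputs alias dQKV instead of needing their own output buffers.
dQKV_view = dQKV.reshape(B * S, H * D * 3)
dQKV_permute = dQKV_view.transpose(1, 0)

print(dQKV_sum.shape, dQKV_view.shape, dQKV_permute.shape)  # (2304,) (2048, 2304) (2304, 2048)
print(bool(np.shares_memory(dQKV_view, dQKV)), bool(np.shares_memory(dQKV_permute, dQKV)))  # True True
```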