Remove unnecessary concatenation using the zipping approach.  #1768

@wujingyue

A spin-off from #1502 (comment). Created for tracking progress.

Problem

Below is a common pattern in nanoGPT's backprop.

dQ, dK, dV = scaled_dot_product_attention_backprop(...)  # bf16[16,12,128,64], bf16[16,12,128,64], bf16[16,12,128,64]
dQ = transpose(dQ, [0, 2, 1, 3])  # [16, 128, 12, 64]
dQ = reshape(dQ, [16, 128, 768])  # [16, 128, 768]
dK = reshape(transpose(dK, [0, 2, 1, 3]), [16, 128, 768])  # same transpose and reshape as dQ
dV = reshape(transpose(dV, [0, 2, 1, 3]), [16, 128, 768])  # same transpose and reshape as dQ

concatenated = cat([dQ, dK, dV], axis=-1)

dQKV_sum = sum(concatenated, ...)  # omitting a round trip to float
dQKV_view = reshape(concatenated, [B*S, H*D*3])  # B=16, S=128, H=12, D=64, so [2048, 2304]
dQKV_permute = transpose(dQKV_view, [1, 0])

return dQKV_sum, dQKV_view, dQKV_permute

nvFuser is not given sdpa_backward, so it sees dQ, dK, and dV as three unconnected input tensors. As a result, it cannot elide the cat and has to materialize dQKV_view and dQKV_permute.
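
To make the shapes in the pattern above concrete, here is a minimal shape-tracking sketch in plain Python. It does not call Thunder or nvFuser; the helper names (transpose_shape, reshape_shape, cat_shape) are hypothetical and exist only for this illustration. The reduction axes of the sum are elided in the issue, so the sketch leaves the sum out.

```python
from math import prod


def transpose_shape(shape, perm):
    """Shape after permuting axes: output axis i takes input axis perm[i]."""
    return tuple(shape[p] for p in perm)


def reshape_shape(shape, new_shape):
    """Shape after a reshape; the element count must be preserved."""
    assert prod(shape) == prod(new_shape), "reshape must keep the element count"
    return tuple(new_shape)


def cat_shape(shapes, axis):
    """Shape after concatenating tensors that agree on every axis but `axis`."""
    rank = len(shapes[0])
    axis %= rank  # normalize negative axes such as -1
    for s in shapes:
        assert len(s) == rank
        assert all(s[i] == shapes[0][i] for i in range(rank) if i != axis)
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)
    return tuple(out)


# Sizes taken from the annotated shapes in the issue.
B, H, S, D = 16, 12, 128, 64

grad = (B, H, S, D)  # shape of each of dQ, dK, dV
per_tensor = reshape_shape(transpose_shape(grad, [0, 2, 1, 3]), [B, S, H * D])
concatenated = cat_shape([per_tensor] * 3, axis=-1)
dqkv_view = reshape_shape(concatenated, [B * S, H * D * 3])
dqkv_permute = transpose_shape(dqkv_view, [1, 0])

print(per_tensor, concatenated, dqkv_view, dqkv_permute)
```

Under these sizes, each of the two materialized outputs holds 2048 * 2304 elements, which is the copy the proposal below aims to avoid.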

Solution

TL;DR: change Thunder's cudnnex to feed nvFuser a concatenated tensor that contains dQ, dK and dV, so nvFuser realizes that the existing cat is unnecessary and removes it.

  1. cudnnex will convert the SDPA backward op into a cuDNN sdpa_backward kernel (which outputs one dQKV tensor) followed by a split.
  2. cudnnex will give that split to nvFuser, so nvFuser will see the following pattern:
    dQKV = fd.ops.define_tensor([B, S, H, D*3])
    dQ = fd.ops.slice(dQKV, ...)
    dQ = fd.ops.view(dQ, ...)
    dQ = fd.ops.permute(dQ, ...)
    dK = ...the same slice-view-permute pattern
    dV = ...the same slice-view-permute pattern
    
    concatenated = fd.ops.cat([dQ, dK, dV], axis=-1)
    
    dQKV_sum = fd.ops.sum(concatenated, ...)  # omitting a round trip to float
    dQKV_view = fd.ops.view(concatenated, [B*S, H*D*3])
    dQKV_permute = fd.ops.permute(dQKV_view, [1, 0])
    
  3. nvFuser will cancel the slices and the cat and merge all view and permute between them, so the above will become:
    dQKV = fd.ops.define_tensor([B, S, H, D*3])
    concatenated = fd.ops.permute(fd.ops.view(dQKV, ...), ...)
    dQKV_sum = fd.ops.sum(concatenated, ...)  # omitting a round trip to float
    dQKV_view = fd.ops.view(concatenated, [B*S, H*D*3])
    dQKV_permute = fd.ops.permute(dQKV_view, [1, 0])
    
    As a result, dQKV_view and dQKV_permute will become aliases of dQKV. The fusion will boil down to a ReduceSum kernel that sums [B,S,H,D*3] to [H*D*3].
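
The cancellation in step 3 can be illustrated with a toy rewrite rule in plain Python. This is only a sketch under simplifying assumptions, not nvFuser's actual pass: the Tensor, Slice, and Cat classes and the simplify_cat function are hypothetical, the slices and the cat act on the last axis only, and the view and permute ops that sit between them in the real pattern are left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tensor:
    """A fusion input, described only by the size of its last axis."""
    name: str
    last_dim: int


@dataclass(frozen=True)
class Slice:
    """source[..., start:stop]"""
    source: Tensor
    start: int
    stop: int


@dataclass(frozen=True)
class Cat:
    """Concatenation of `inputs` along the last axis."""
    inputs: tuple


def simplify_cat(node):
    """Return the source tensor if `node` concatenates in-order, contiguous
    slices that cover the source's whole last axis; otherwise return `node`."""
    if not isinstance(node, Cat) or not node.inputs:
        return node
    if not all(isinstance(piece, Slice) for piece in node.inputs):
        return node
    source = node.inputs[0].source
    expected_start = 0
    for piece in node.inputs:
        # Every slice must read the same tensor and start where the last ended.
        if piece.source != source or piece.start != expected_start:
            return node
        expected_start = piece.stop
    if expected_start != source.last_dim:
        return node  # the slices do not cover the whole axis
    return source


D = 64
dqkv = Tensor("dQKV", 3 * D)
pieces = tuple(Slice(dqkv, i * D, (i + 1) * D) for i in range(3))

in_order = simplify_cat(Cat(pieces))
swapped = Cat((pieces[1], pieces[0], pieces[2]))
out_of_order = simplify_cat(swapped)

print(in_order == dqkv)         # the cat and the slices cancel
print(out_of_order == swapped)  # reordered slices are left untouched
```

The rule only fires when the slices come from one tensor, which is why the proposal has cudnnex hand nvFuser the packed dQKV tensor plus a split rather than three separate inputs.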
