[KDA] Add Grouped Value Attention (GVA) support #833
Conversation
Support `HV > H` (`num_v_heads > num_qk_heads`) in KDA ops and layer, following the same pattern as `gated_delta_rule`.

Ops changes:
- All Triton kernels (`chunk_intra`, `token_parallel`, `wy_fast`, `chunk_bwd`, `chunk_gla_fwd_o_gk`) now accept an HV parameter and use `i_h = i_hv // (HV // H)` for qk-head mapping. Backward compatible when `HV == H`.
- `g` shape: `[B, T, HV, K]`, `beta`: `[B, T, HV]`, state: `[N, HV, K, V]`
- Backward pass reduces `dq`/`dk` from HV back to H via group sum.

Layer changes:
- `f_proj` outputs `gate_dim = HV * K` directly (no repeat needed)
- `b_proj` outputs `num_v_heads` directly
- `A_log` shape `[HV]`, `dt_bias` shape `[HV * K]`

Also adds GVA test cases to the existing `test_naive_chunk`, `test_fused_recurrent`, and `test_chunk`, and configures isort in `pyproject.toml`.
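The qk-head mapping and the backward group-sum reduction described above can be sketched in plain Python. This is a toy illustration with made-up sizes, not the repo's Triton kernel code; `qk_head_of` and `reduce_groups` are hypothetical names:

```python
# Toy sketch of Grouped Value Attention (GVA) head bookkeeping:
# - each of the HV value heads shares one of the H qk heads via
#   i_h = i_hv // (HV // H), the mapping the kernels use;
# - the backward pass sums per-value-head dq/dk gradients over each
#   group of G = HV // H heads to get back to H qk heads.

def qk_head_of(i_hv: int, H: int, HV: int) -> int:
    # Same mapping as the kernels: consecutive value heads share a qk head.
    assert HV % H == 0, "GVA requires HV divisible by H"
    return i_hv // (HV // H)

def reduce_groups(d_per_hv: list, H: int) -> list:
    # Reduce HV per-value-head gradients to H per-qk-head gradients
    # by summing each group of G consecutive heads.
    HV = len(d_per_hv)
    G = HV // H
    return [sum(d_per_hv[h * G:(h + 1) * G]) for h in range(H)]

H, HV = 2, 6
assert [qk_head_of(i, H, HV) for i in range(HV)] == [0, 0, 0, 1, 1, 1]
assert reduce_groups([1, 2, 3, 4, 5, 6], H) == [6, 15]
```

When `HV == H` the group size is 1, so both helpers degenerate to the identity, which is the backward-compatibility property the PR relies on.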
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
Walkthrough

Introduces Grouped Value Attention (GVA) by adding a value-head dimension (HV) throughout the KDA ops and layer.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Layer as Model Layer
    participant Host as Python wrappers
    participant Kernel as Triton Kernel
    participant GPU as GPU Memory
    rect rgba(100,150,240,0.5)
        Layer->>Host: prepare q, k, g, v, beta, A_log, dt_bias with HV
    end
    rect rgba(120,200,80,0.5)
        Host->>GPU: upload HV-aware tensors, launch kernel (grid B*HV)
        GPU->>Kernel: kernel invoked (maps program_id -> B, i_hv)
        Kernel->>GPU: load q/k/g/v using HV-strides
        Kernel->>GPU: compute forward/backward (Aqk, o, dA, dq/dk/dv)
        Kernel-->>Host: write HV-shaped outputs
    end
    rect rgba(200,120,160,0.5)
        Host->>Host: if HV>H, reduce dq/dk across groups (G=HV//H)
        Host-->>Layer: return outputs, final_state, gradients
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request adds support for Grouped Value Attention (GVA) to the KDA and GLA modules, enabling configurations where the number of value heads exceeds query/key heads. The update includes modifications to the Triton kernels to handle the expanded value-head dimension and updates to the naive implementations and test suites. Review feedback highlighted critical bugs in the wy_fast.py kernel related to the indexing and allocation of the qg and kg tensors. Specifically, these tensors must be managed at the value-head resolution to avoid race conditions and memory corruption during GVA operations.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
fla/ops/kda/chunk.py (2)
263-298: ⚠️ Potential issue | 🟡 Minor

The examples still construct `h0` with a rejected dtype.

Line 336 requires `initial_state.dtype == torch.float32`, but both new examples build `h0` as `torch.bfloat16`. Copy-pasting either snippet now fails immediately.

📝 Suggested doc fix

```diff
-    >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+    >>> h0 = torch.randn(B, H, K, V, dtype=torch.float32, device='cuda')
     ...
-    >>> h0 = torch.randn(B, HV, K, V, dtype=torch.bfloat16, device='cuda')
+    >>> h0 = torch.randn(B, HV, K, V, dtype=torch.float32, device='cuda')
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/chunk.py` around lines 263 - 298, The doc examples construct initial_state `h0` with dtype=torch.bfloat16 but `chunk_kda` requires `initial_state.dtype == torch.float32`; update the two example snippets to create `h0` with dtype=torch.float32 (e.g., torch.randn(..., dtype=torch.float32, device='cuda')) so the dtype check in chunk_kda passes while leaving other tensors (q, k, v, g, beta) as-is.
338-358: ⚠️ Potential issue | 🟠 Major

Validate the rest of the HV-shaped inputs before launching kernels.

This block now checks `g` and `beta`, but `v`, `initial_state`, `A_log`, and `dt_bias` can still arrive in stale pre-GVA shapes. A mismatched `[N, H, K, V]` state or `[H]`/`[H*K]` gate params will only fail later inside the Triton path, with much worse diagnostics.

🔍 Suggested guardrail

```diff
-    B, T, H, K, HV = *q.shape, v.shape[2]
+    B, T, H, K = q.shape
+    assert v.ndim == 4 and v.shape[:2] == (B, T), (
+        f"v must have shape [B, T, HV, V] with the same batch/seq dims as q, got {list(v.shape)}"
+    )
+    HV, V = v.shape[2], v.shape[3]
     assert q.shape == k.shape, f"q and k must have the same shape, got q={q.shape} vs k={k.shape}"
     assert K <= 256, f"Currently we only support key headdim <=256 for KDA, got {K}."
     assert HV % H == 0, (
         f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by num_qk_heads (H={H}), "
         f"but got HV % H = {HV % H}"
     )
     assert g.shape == (B, T, HV, K), f"g must have shape [B, T, HV, K]={[B, T, HV, K]}, got {list(g.shape)}"
     assert beta.shape == (B, T, HV), f"beta must have shape [B, T, HV]={[B, T, HV]}, got {list(beta.shape)}"
+    if initial_state is not None:
+        assert initial_state.shape[1:] == (HV, K, V), (
+            f"initial_state must have shape [N, HV, K, V], got {list(initial_state.shape)}"
+        )
+    if use_gate_in_kernel:
+        assert A_log.shape == (HV,), f"A_log must have shape [HV], got {list(A_log.shape)}"
+        if dt_bias is not None:
+            assert dt_bias.shape == (HV * K,), f"dt_bias must have shape [HV * K], got {list(dt_bias.shape)}"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/chunk.py` around lines 338 - 358, The code currently validates g and beta shapes but can still accept stale shapes for v, initial_state, A_log and dt_bias which will only error later; add upfront shape guards: verify v.ndim==4 and v.shape[0:2]==(B,T) and v.shape[2]==HV (and capture V=v.shape[3]); if initial_state is provided assert its leading batch dim matches B and that its remaining dims include H and K in the expected order (e.g. initial_state.shape == (B, H, K, V) or initial_state.shape[:3] == (B, H, K) if full V omitted); when use_gate_in_kernel is true assert A_log.shape is either (H,) or (H*K,) and dt_bias (if not None) has a matching shape; ensure these checks reference the existing symbols v, initial_state, A_log, dt_bias, H, K, HV so mismatches fail early with clear messages.
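The fail-fast validation this comment recommends can be exercised without torch by checking shape tuples directly. A minimal sketch; `validate_gva_shapes` is a hypothetical helper, not a function in the repo, and it checks only the shapes discussed above:

```python
# Toy fail-fast validator for the GVA shape contract described in the review:
# q: [B, T, H, K], v: [B, T, HV, V], g: [B, T, HV, K], beta: [B, T, HV],
# initial_state (optional): [N, HV, K, V]. Operating on plain shape tuples
# keeps the sketch runnable without torch.

def validate_gva_shapes(q_shape, v_shape, g_shape, beta_shape, state_shape=None):
    B, T, H, K = q_shape
    assert len(v_shape) == 4 and v_shape[:2] == (B, T), \
        f"v must be [B, T, HV, V], got {v_shape}"
    HV, V = v_shape[2], v_shape[3]
    assert HV % H == 0, f"HV={HV} must be divisible by H={H}"
    assert g_shape == (B, T, HV, K), f"g must be [B, T, HV, K], got {g_shape}"
    assert beta_shape == (B, T, HV), f"beta must be [B, T, HV], got {beta_shape}"
    if state_shape is not None:
        assert state_shape[1:] == (HV, K, V), \
            f"initial_state must be [N, HV, K, V], got {state_shape}"
    return HV, V

# A consistent GVA configuration (H=2, HV=4) passes and reports (HV, V).
assert validate_gva_shapes((2, 16, 2, 32), (2, 16, 4, 64),
                           (2, 16, 4, 32), (2, 16, 4)) == (4, 64)
```

The point of the guardrail is exactly this: a stale `[N, H, K, V]` state or `[B, T, H, K]` gate fails here with a shape message instead of deep inside a Triton launch.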
🧹 Nitpick comments (1)
tests/ops/test_kda.py (1)
19-35: Add at least one `HV > H` varlen case.

The new GVA coverage is all fixed-length. Most of the risky pointer/stride changes in this PR also touch the `IS_VARLEN` branches, so I'd add one `HV > H` case to `test_chunk_varlen`/`test_chunk_varlen_prefill` as well.

Also applies to: 84-98, 346-381
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/test_kda.py` around lines 19 - 35, Add at least one parameter tuple where HV > H to the pytest.parametrize lists used by test_chunk_varlen and test_chunk_varlen_prefill so a varlen case is exercised (e.g., change one tuple in the list of tests from (HV==H) to something like (H=1, HV=2, ...) or append a new pytest.param with HV > H); update the param blocks that define ("B","T","H","HV","D","scale","gate_logit_normalizer","dtype") (the same blocks around the existing tuples) so the test suite covers a variable-length case and exercises the IS_VARLEN branches.
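The kind of extra case the nitpick asks for can be sketched as a parameter list. The tuple order `(B, T, H, HV, D)` follows the review text; the concrete values here are made up and the actual parametrize blocks in `tests/ops/test_kda.py` may differ:

```python
# Hypothetical varlen parameter list extended with one HV > H tuple, so the
# IS_VARLEN branches run under GVA. Values are illustrative, not the repo's.
VARLEN_CASES = [
    # B, T,   H, HV, D
    (2, 500, 4, 4, 64),   # existing-style case, HV == H
    (2, 500, 2, 6, 64),   # new GVA case, HV > H (group size G = 3)
]

# Every case must satisfy the GVA contract checked by chunk_kda.
for B, T, H, HV, D in VARLEN_CASES:
    assert HV % H == 0, "GVA requires HV divisible by H"

# At least one case actually exercises the grouped-head code path.
assert any(HV > H for _, _, H, HV, _ in VARLEN_CASES)
```

In a real pytest file these tuples would go straight into the existing `pytest.mark.parametrize` block for `test_chunk_varlen`.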
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/kda/chunk_bwd.py`:
- Around line 442-443: The code computes G = HV // H using q.shape and v.shape
(variables H, HV) without validating assumptions; add a guard in chunk_bwd.py
before computing G and before the HV execution path (the branch that remaps
heads and does view(..., H, G, ...)) to verify HV >= H and HV % H == 0, and if
the check fails raise a clear exception (e.g., ValueError) or take a safe
fallback path; apply the same validation at the second occurrence around the
code at the other HV path (the lines corresponding to the block at ~575-578) so
invalid ratios cannot produce misroutes, divide-by-zero or invalid reshapes.
In `@fla/ops/kda/wy_fast.py`:
- Around line 267-271: recompute_w_u_fwd_kda_kernel is writing per-grouped-head
results into H-headed tensors qg and kg causing races when HV > H; allocate qg
and kg to be HV-headed (matching grouped heads) or change the kernel write index
to use i_hv (not i_h) so each (chunk, hv) writes to a unique slot, and update
downstream consumers (chunk_kda_bwd and the chunk_gated_delta_rule_* functions)
to read the HV-headed qg/kg (or to reduce from HV->H deterministically) so the
per-value-head gate information is preserved and deterministic.
- Line 267: The code unconditionally calls torch.empty_like(gk) while the
signature still allows gk: torch.Tensor | None; update the function to either
require a non-optional gk or add an explicit early check that raises a clear
ValueError if gk is None; locate the site using the symbols gk and w (the line
with w = torch.empty_like(gk)) and modify the function signature or insert a
guard like `if gk is None: raise ValueError("gk must be provided")` before
allocating w so callers get a clear contract violation instead of an opaque
allocation error.
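The race described in the first `wy_fast.py` inline comment above comes down to HV writers sharing H slots. A pure-Python toy (made-up sizes, string payloads standing in for tensor tiles, not Triton code) shows why indexing writes by `i_hv` instead of `i_h` fixes it:

```python
# Toy illustration of the qg/kg clobbering bug: when HV grouped value heads
# all write through the shared qk-head index i_h = i_hv // G into an H-sized
# buffer, later writers overwrite earlier ones and per-value-head gate
# information is lost. Indexing an HV-sized buffer by i_hv keeps every head.
H, HV = 2, 4
G = HV // H

# Buggy layout: H slots, indexed by the shared qk head -> heads collide.
buggy = [None] * H
for i_hv in range(HV):
    buggy[i_hv // G] = f"result_of_head_{i_hv}"

# Fixed layout: HV slots, indexed by the value head itself -> unique slot each.
fixed = [None] * HV
for i_hv in range(HV):
    fixed[i_hv] = f"result_of_head_{i_hv}"

assert buggy == ["result_of_head_1", "result_of_head_3"]  # heads 0 and 2 lost
assert fixed == [f"result_of_head_{i}" for i in range(HV)]
```

In the sequential toy the loss is deterministic last-writer-wins; on a GPU, concurrent programs writing the same slot make the surviving value nondeterministic as well, which is the race the review flags.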
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 8104be1b-8415-4078-bedb-387adb34381a
📒 Files selected for processing (11)
- fla/layers/kda.py
- fla/ops/gla/chunk.py
- fla/ops/kda/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_fwd.py
- fla/ops/kda/chunk_intra.py
- fla/ops/kda/chunk_intra_token_parallel.py
- fla/ops/kda/naive.py
- fla/ops/kda/wy_fast.py
- pyproject.toml
- tests/ops/test_kda.py
- Guard q/qg/kg pointer mutations with STORE_QG/STORE_KG to avoid Triton pointer arithmetic on None
- Fix qg/kg strides from H*K to HV*K in make_block_ptr
- Fix qg/kg pointer offsets from (bos*H+i_h)*K to (bos*HV+i_hv)*K
- Fix qg/kg allocation shape from empty_like(q) [B,T,H,K] to [B,T,HV,K]
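The stride and offset fixes above can be checked with toy pointer arithmetic in plain Python (hypothetical sizes, a flat list standing in for GPU memory):

```python
# Toy flat-index check for the qg/kg layout fix: with a [T, HV, K] layout,
# element (t, i_hv, 0) lives at flat offset (t * HV + i_hv) * K. The stale
# pre-GVA formula (t * H + i_h) * K lands on the wrong element whenever HV > H.
T, H, HV, K = 4, 2, 4, 8

def offset_hv(t: int, i_hv: int) -> int:
    # Correct offset for a [T, HV, K] row (the fixed kernels' formula).
    return (t * HV + i_hv) * K

def offset_h(t: int, i_hv: int) -> int:
    # Stale formula: strides by H and indexes by the shared qk head.
    i_h = i_hv // (HV // H)
    return (t * H + i_h) * K

# "Memory" as a flat list of unique (t, hv, k) coordinates.
flat = [(t, hv, k) for t in range(T) for hv in range(HV) for k in range(K)]

assert flat[offset_hv(1, 3)] == (1, 3, 0)   # correct element reached
assert flat[offset_h(1, 3)] != (1, 3, 0)    # stale formula misses it
```

With `HV == H` both formulas coincide, which is why the bug only surfaces once GVA configurations are exercised.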
♻️ Duplicate comments (1)
fla/ops/kda/wy_fast.py (1)
267-271: ⚠️ Potential issue | 🟡 Minor

`gk` is dereferenced unconditionally despite its optional signature.

Line 267 calls `torch.empty_like(gk)`, but the signature at line 253 declares `gk: torch.Tensor | None = None`. If `gk` is `None`, this raises an unclear error instead of a clear contract violation.

Suggested fix

```diff
 def recompute_w_u_fwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
     q: torch.Tensor | None = None,
-    gk: torch.Tensor | None = None,
+    gk: torch.Tensor,
     cu_seqlens: torch.LongTensor | None = None,
     chunk_indices: torch.LongTensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
```

Or add an explicit guard:

```diff
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    if gk is None:
+        raise ValueError("recompute_w_u_fwd requires `gk` to be provided")
     B, T, H, K, V = *k.shape, v.shape[-1]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/wy_fast.py` around lines 267 - 271, The code unconditionally calls torch.empty_like(gk) assigning to w even though gk: torch.Tensor | None; change this to allocate w conditionally (e.g. w = torch.empty_like(gk) if gk is not None else torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype)) so you don't dereference None, and similarly ensure any other allocations that relied on optional tensors (like kg) follow the same pattern before calling recompute_w_u_fwd_kda_kernel; keep the kernel call signature unchanged.
🧹 Nitpick comments (1)
fla/ops/kda/wy_fast.py (1)
257-258: Validation for `HV % H == 0` exists at the public API level and is not required in `recompute_w_u_fwd`.

The constraint is already validated in `chunk_kda()` in `chunk.py` (line ~353) with a clear error message. Since all calls to `recompute_w_u_fwd` go through this validated entry point, adding validation here would be redundant defensive programming rather than a requirement. This suggestion can be skipped unless `wy_fast.py` functions become public APIs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/wy_fast.py` around lines 257 - 258, Remove the redundant HV % H == 0 validation inside recompute_w_u_fwd in wy_fast.py since chunk_kda (in chunk.py) already enforces this invariant for all call sites; locate the HV and H usage in the recompute_w_u_fwd function (where B, T, H, K, V = *k.shape, v.shape[-1] and HV = v.shape[2]) and delete the defensive check and associated error handling so the function relies on the public API validation in chunk_kda.
⚠️ Outside diff range comments (1)
fla/ops/kda/wy_fast.py (1)
256-270: ⚠️ Potential issue | 🟡 Minor

Return type annotation is inconsistent with actual return values.

`kg` is always allocated (line 270) and never `None`, but the return type annotation on line 256 specifies `torch.Tensor | None` for the fourth element.

🔧 Suggested fix

```diff
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/wy_fast.py` around lines 256 - 270, The function return type annotation declares the fourth tuple element as "torch.Tensor | None" but the code always allocates kg (variable name kg) so the annotation is inaccurate; update the function signature return annotation to have the fourth element be torch.Tensor (remove the | None) or change the allocation of kg to be conditional (e.g., set kg = None when the case that should produce None occurs) so the declared tuple type matches the actual returned value — locate the function signature with the "-> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]" annotation and the kg allocation "kg = torch.empty(...)" to implement the fix.
| Environment | Value |
|---|---|
| GPU | NVIDIA H200 |
| CUDA | 12.8 |
| PyTorch | 2.7.1+cu128 |
| Base | 323bbf21e8 |
| Head | f92b40bb0c |
| Threshold | 5.0% |
| Op | Mode | B | T | H | D | Base (ms) | Head (ms) | Speedup | Change |
|---|---|---|---|---|---|---|---|---|---|
| chunk_gla | fwd | 1 | 8192 | 96 | 128 | 2.253 | 2.249 | 1.00x | -0.2% |
| chunk_gla | fwd | 2 | 16384 | 16 | 128 | 1.834 | 1.834 | 1.00x | -0.0% |
| chunk_gla | fwd | 4 | 2048 | 16 | 128 | 0.427 | 0.427 | 1.00x | -0.1% |
| chunk_gla | fwd | 4 | 4096 | 64 | 128 | 3.015 | 3.018 | 1.00x | +0.1% |
| chunk_gla | fwd | 8 | 1024 | 8 | 64 | 0.267 | 0.281 | 0.95x | +5.1% 🔴 |
| chunk_gla | fwd | 8 | 2048 | 32 | 256 | 3.694 | 3.698 | 1.00x | +0.1% |
| chunk_kda | fwd | 1 | 8192 | 96 | 128 | 2.725 | 2.718 | 1.00x | -0.2% |
| chunk_kda | fwd | 2 | 16384 | 16 | 128 | 1.932 | 1.923 | 1.00x | -0.5% |
| chunk_kda | fwd | 4 | 2048 | 16 | 128 | 0.593 | 0.610 | 0.97x | +3.0% |
| chunk_kda | fwd | 4 | 4096 | 64 | 128 | 3.479 | 3.472 | 1.00x | -0.2% |
| chunk_kda | fwd | 8 | 1024 | 8 | 64 | 1.260 | 0.611 | 2.06x | -51.5% 🟢 |
| chunk_kda | fwd | 8 | 2048 | 32 | 256 | 4.413 | 4.124 | 1.07x | -6.6% 🟢 |
| chunk_gla | fwdbwd | 1 | 8192 | 96 | 128 | 10.944 | 10.943 | 1.00x | -0.0% |
| chunk_gla | fwdbwd | 2 | 16384 | 16 | 128 | 8.487 | 8.508 | 1.00x | +0.2% |
| chunk_gla | fwdbwd | 4 | 2048 | 16 | 128 | 2.025 | 2.028 | 1.00x | +0.1% |
| chunk_gla | fwdbwd | 4 | 4096 | 64 | 128 | 15.186 | 15.232 | 1.00x | +0.3% |
| chunk_gla | fwdbwd | 8 | 1024 | 8 | 64 | 1.049 | 1.332 | 0.79x | +26.9% 🔴 |
| chunk_gla | fwdbwd | 8 | 2048 | 32 | 256 | 18.982 | 18.977 | 1.00x | -0.0% |
| chunk_kda | fwdbwd | 1 | 8192 | 96 | 128 | 12.437 | 12.570 | 0.99x | +1.1% |
| chunk_kda | fwdbwd | 2 | 16384 | 16 | 128 | 8.384 | 8.502 | 0.99x | +1.4% |
| chunk_kda | fwdbwd | 4 | 2048 | 16 | 128 | 2.192 | 2.213 | 0.99x | +1.0% |
| chunk_kda | fwdbwd | 4 | 4096 | 64 | 128 | 15.991 | 16.290 | 0.98x | +1.9% |
| chunk_kda | fwdbwd | 8 | 1024 | 8 | 64 | 1.981 | 1.881 | 1.05x | -5.1% 🟢 |
| chunk_kda | fwdbwd | 8 | 2048 | 32 | 256 | 18.604 | 18.598 | 1.00x | -0.0% |
This comment is automatically updated with the latest benchmark results.
Summary

- Support `HV > H` (`num_v_heads > num_qk_heads`) in KDA ops and layer, following the gated_delta_rule GVA pattern
- Kernels use `i_h = i_hv // (HV // H)` for qk-head mapping; backward compatible when `HV == H`
- Layer params (`f_proj`, `b_proj`, `A_log`, `dt_bias`) output at value-head dimension directly, no repeat needed
- GVA test cases added to `test_naive_chunk`, `test_fused_recurrent`, `test_chunk`

Changed files

- `fla/ops/kda/chunk_intra.py` — `inter_solve_fused`, `sub_chunk`, `bwd_intra` kernels: add HV, grid B*HV
- `fla/ops/kda/chunk_intra_token_parallel.py`
- `fla/ops/kda/wy_fast.py`
- `fla/ops/kda/chunk_bwd.py` — `dAv`, `wy_dqkg_fused` kernels: add HV; dq/dk reduce from HV→H
- `fla/ops/kda/chunk_fwd.py`
- `fla/ops/kda/chunk.py`
- `fla/ops/kda/naive.py`
- `fla/ops/gla/chunk.py` — `chunk_gla_fwd_o_gk` kernel: native GVA via HV from v.shape
- `fla/layers/kda.py`
- `tests/ops/test_kda.py`
- `pyproject.toml`

Test plan

- `pytest tests/ops/test_kda.py::test_gva_naive_chunk` — naive recurrent vs chunk (HV > H)
- `pytest tests/ops/test_kda.py::test_fused_recurrent` — includes HV > H cases
- `pytest tests/ops/test_kda.py::test_chunk` — includes HV > H cases with backward + use_gate_in_kernel
- `pytest tests/ops/test_kda.py` — full suite, verify no regressions