
[KDA] Add Grouped Value Attention (GVA) support #833

Merged
yzhangcs merged 8 commits into main from feat/kda-gva on Apr 16, 2026

Conversation

@yzhangcs
Member

@yzhangcs yzhangcs commented Apr 15, 2026

Summary

  • Support HV > H (num_v_heads > num_qk_heads) in KDA ops and layer, following the gated_delta_rule GVA pattern
  • All Triton kernels use i_h = i_hv // (HV // H) for qk-head mapping (see the sketch after this list); backward compatible when HV == H
  • Layer projections (f_proj, b_proj, A_log, dt_bias) output at value-head dimension directly, no repeat needed
  • GVA test cases added inline to existing test_naive_chunk, test_fused_recurrent, test_chunk
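The qk-head mapping works like this: each value head i_hv reads its q/k block from qk head i_hv // (HV // H). Below is a minimal PyTorch sketch of that indexing (illustrative shapes and names, not the actual kernel code); it is equivalent to a repeat_interleave over the head dimension, which is what the naive reference path uses.

```python
import torch

# Illustrative sizes only; HV must be a multiple of H.
B, T, H, HV, K = 2, 16, 2, 4, 8
G = HV // H                                 # value heads per qk head

q = torch.randn(B, T, H, K)

# Per-program view: value head i_hv loads q/k from qk head i_hv // G.
q_per_value_head = torch.stack([q[:, :, i_hv // G] for i_hv in range(HV)], dim=2)

# Equivalent bulk expansion, as used by the naive reference implementation.
q_expanded = q.repeat_interleave(G, dim=2)  # [B, T, HV, K]
assert torch.equal(q_per_value_head, q_expanded)
```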

Changed files

| File | Change |
|---|---|
| fla/ops/kda/chunk_intra.py | inter_solve_fused, sub_chunk, bwd_intra kernels: add HV, grid B*HV |
| fla/ops/kda/chunk_intra_token_parallel.py | Token-parallel kernel: add HV, manual q/k indexing |
| fla/ops/kda/wy_fast.py | Forward/backward WY kernels: add HV, ptr += offset style |
| fla/ops/kda/chunk_bwd.py | dAv, wy_dqkg_fused kernels: add HV; dq/dk reduce from HV→H |
| fla/ops/kda/chunk_fwd.py | Pass through to GVA-aware output kernel |
| fla/ops/kda/chunk.py | API validation for GVA shapes |
| fla/ops/kda/naive.py | Reference impl with GVA docstrings |
| fla/ops/gla/chunk.py | chunk_gla_fwd_o_gk kernel: native GVA via HV from v.shape |
| fla/layers/kda.py | Projections at HV dim, no repeat |
| tests/ops/test_kda.py | GVA cases in existing tests |
| pyproject.toml | isort config (line_length=127, hanging indent) |

Test plan

  • pytest tests/ops/test_kda.py::test_gva_naive_chunk — naive recurrent vs chunk (HV > H)
  • pytest tests/ops/test_kda.py::test_fused_recurrent — includes HV > H cases
  • pytest tests/ops/test_kda.py::test_chunk — includes HV > H cases with backward + use_gate_in_kernel
  • pytest tests/ops/test_kda.py — full suite, verify no regressions

Summary by CodeRabbit

  • New Features

    • Added support for grouped value attention (GVA), i.e. an expanded value-head dimension HV, across the KDA paths.
  • Refactor

    • Unified head/value-dimension handling, tensor layouts, and kernel indexing to operate in HV mode for correctness and performance.
    • Removed redundant runtime duplication logic for grouped value attention.
  • Documentation

    • Expanded and clarified KDA docstrings and usage examples for HV/GVA modes.
  • Tests

    • Extended tests for HV/GVA configurations and added conditional skips for certain CUDA E2E tests.

Support HV > H (num_v_heads > num_qk_heads) in KDA ops and layer,
following the same pattern as gated_delta_rule.

Ops changes:
- All Triton kernels (chunk_intra, token_parallel, wy_fast, chunk_bwd,
  chunk_gla_fwd_o_gk) now accept HV parameter and use i_h = i_hv // (HV // H)
  for qk-head mapping. Backward compatible when HV == H.
- g shape: [B, T, HV, K], beta: [B, T, HV], state: [N, HV, K, V]
- Backward pass reduces dq/dk from HV back to H via group sum (see the sketch below).
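A rough PyTorch sketch of that reduction (illustrative shapes; the real wrapper lives in fla/ops/kda/chunk_bwd.py): because each qk head is shared by G = HV // H value heads, its gradient is the sum of the per-value-head gradients within the group.

```python
import torch

B, T, H, HV, K = 2, 16, 2, 4, 8
G = HV // H

# dq as produced by the kernels, at value-head resolution.
dq_hv = torch.randn(B, T, HV, K)

# Group sum back to qk-head resolution (chain rule for a shared/repeated q).
dq = dq_hv.view(B, T, H, G, K).sum(dim=3)   # [B, T, H, K]
```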

Layer changes:
- f_proj outputs gate_dim = HV * K directly (no repeat needed; see the sketch after this list)
- b_proj outputs num_v_heads directly
- A_log shape [HV], dt_bias shape [HV * K]
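A small sketch of the resulting shape bookkeeping (hypothetical standalone module with illustrative sizes; the actual layer is fla/layers/kda.py): projecting straight to the value-head dimension means the gate and beta already match v, so no runtime repeat() is needed.

```python
import torch
import torch.nn as nn

B, T, hidden_size, HV, K = 2, 16, 512, 8, 64    # illustrative sizes
gate_dim = HV * K

f_proj = nn.Linear(hidden_size, gate_dim)       # gate at value-head resolution
b_proj = nn.Linear(hidden_size, HV)             # beta per value head
A_log = nn.Parameter(torch.zeros(HV))
dt_bias = nn.Parameter(torch.zeros(gate_dim))

x = torch.randn(B, T, hidden_size)
g = f_proj(x).view(B, T, HV, K)                 # [B, T, HV, K], no repeat needed
beta = b_proj(x)                                # [B, T, HV]
```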

Also adds GVA test cases to existing test_naive_chunk, test_fused_recurrent,
and test_chunk, and configures isort in pyproject.toml.
@coderabbitai
Contributor

coderabbitai Bot commented Apr 15, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Introduces Grouped Value Attention (GVA) by adding a value-head dimension HV (with HV % H == 0) and propagating HV semantics across layer code, Triton kernels, Python wrappers, reference implementations, and tests; removes runtime Q/K duplication and makes pointer/index/stride math HV-aware.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Layer definition<br>fla/layers/kda.py | Gate projection resized to gate_dim = num_v_heads * head_k_dim; A_log/dt_bias re-parameterized to per-num_v_heads/gate_dim; reshapes updated and runtime repeat() duplication removed. |
| GLA / Chunk forward<br>fla/ops/gla/chunk.py, fla/ops/kda/chunk_fwd.py, fla/ops/kda/chunk.py | Added HV constexpr to autotune keys/kernels; grid and program-id mapping changed from B*H to B*HV; pointer arithmetic and block-pointer extents made HV-aware; Python launchers now infer and pass HV. Minor import/comment tidy. |
| Chunk backward<br>fla/ops/kda/chunk_bwd.py | Backward Triton kernels gained HV constexpr; autotune keys and program-id mapping updated; pointer/stride math moved to HV; Python wrappers infer HV, adjust grid/allocations, and reduce dq/dk across grouped value heads when HV > H. |
| Intra-chunk kernels<br>fla/ops/kda/chunk_intra.py, fla/ops/kda/chunk_intra_token_parallel.py | Kernels/autotune keys extended with HV; program-id/head mapping updated to map HV → H (G = HV // H); block-pointer strides, masked/mapped loads, and outputs updated to HV semantics; wrappers derive/pass HV. |
| WY fast paths<br>fla/ops/kda/wy_fast.py | Kernels gained HV constexpr; head indexing/pointers refactored to use HV-aware base pointers and strides; host launches derive HV, use grid (NT, B*HV), and allocate w/qg with HV shapes. gk made required in host API. |
| Naive / reference<br>fla/ops/kda/naive.py | Docstrings added; derive HV from v.shape[2]; q/k expanded to value-head-aligned shapes via repeat_interleave (see the sketch after this table); state S and intermediates switched to HV layout; inner accumulation rewritten to use Aqk. |
| Tests & infra<br>tests/ops/test_kda.py, tests/ops/test_intracard_cache.py | test_kda.py updated parametrizations and tensors to include HV; tests construct/compare HV-shaped states. Two CUDA E2E tests are conditionally skipped when FLA_DISABLE_BACKEND_DISPATCH=1. |
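To make the HV layout of the reference path concrete, here is a condensed, delta-rule-style sketch in plain PyTorch (assumed shapes only; this is not the exact math of fla/ops/kda/naive.py): q/k are aligned to HV via repeat_interleave, g/beta/v are already at HV, and the state is carried per value head.

```python
import torch

def naive_gva_delta_rule_sketch(q, k, v, g, beta):
    # q, k: [B, T, H, K]; v: [B, T, HV, V]; g: [B, T, HV, K]; beta: [B, T, HV]
    B, T, H, K = q.shape
    HV, V = v.shape[2], v.shape[3]
    G = HV // H
    q = q.repeat_interleave(G, dim=2)            # align qk heads to value heads
    k = k.repeat_interleave(G, dim=2)
    S = q.new_zeros(B, HV, K, V)                 # per-value-head recurrent state
    o = torch.empty_like(v)
    for t in range(T):
        S = S * g[:, t].exp().unsqueeze(-1)      # per-key gated decay
        err = v[:, t] - torch.einsum('bhk,bhkv->bhv', k[:, t], S)
        S = S + torch.einsum('bhk,bhv->bhkv', k[:, t], beta[:, t, :, None] * err)
        o[:, t] = torch.einsum('bhk,bhkv->bhv', q[:, t], S)
    return o, S
```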

Sequence Diagram(s)

sequenceDiagram
    participant Layer as Model Layer
    participant Host as Python wrappers
    participant Kernel as Triton Kernel
    participant GPU as GPU Memory

    rect rgba(100,150,240,0.5)
    Layer->>Host: prepare q, k, g, v, beta, A_log, dt_bias with HV
    end

    rect rgba(120,200,80,0.5)
    Host->>GPU: upload HV-aware tensors, launch kernel (grid B*HV)
    GPU->>Kernel: kernel invoked (maps program_id -> B, i_hv)
    Kernel->>GPU: load q/k/g/v using HV-strides
    Kernel->>GPU: compute forward/backward (Aqk, o, dA, dq/dk/dv)
    Kernel-->>Host: write HV-shaped outputs
    end

    rect rgba(200,120,160,0.5)
    Host->>Host: if HV>H, reduce dq/dk across groups (G=HV//H)
    Host-->>Layer: return outputs, final_state, gradients
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • sustcsonglin
  • Nathancgy

"🐇 I hopped through heads and gates with glee,
HV multiplied the value-tree,
Kernels learned new strides to race,
Pointers aligned — a tidy place,
Hooray, attention grows in me!"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 24.24%, which is insufficient; the required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and concisely summarizes the main change: adding Grouped Value Attention (GVA) support to the KDA implementation. It is specific, accurate, and directly reflects the primary objective of the changeset. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for Grouped Value Attention (GVA) to the KDA and GLA modules, enabling configurations where the number of value heads exceeds query/key heads. The update includes modifications to the Triton kernels to handle the expanded value-head dimension and updates to the naive implementations and test suites. Review feedback highlighted critical bugs in the wy_fast.py kernel related to the indexing and allocation of the qg and kg tensors. Specifically, these tensors must be managed at the value-head resolution to avoid race conditions and memory corruption during GVA operations.

Comment thread fla/ops/kda/wy_fast.py Outdated
Comment thread fla/ops/kda/wy_fast.py Outdated
Comment thread fla/ops/kda/wy_fast.py
Comment thread fla/ops/kda/wy_fast.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
fla/ops/kda/chunk.py (2)

263-298: ⚠️ Potential issue | 🟡 Minor

The examples still construct h0 with a rejected dtype.

Line 336 requires initial_state.dtype == torch.float32, but both new examples build h0 as torch.bfloat16. Copy-pasting either snippet now fails immediately.

📝 Suggested doc fix
-        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+        >>> h0 = torch.randn(B, H, K, V, dtype=torch.float32, device='cuda')
...
-        >>> h0 = torch.randn(B, HV, K, V, dtype=torch.bfloat16, device='cuda')
+        >>> h0 = torch.randn(B, HV, K, V, dtype=torch.float32, device='cuda')
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/chunk.py` around lines 263 - 298, The doc examples construct
initial_state `h0` with dtype=torch.bfloat16 but `chunk_kda` requires
`initial_state.dtype == torch.float32`; update the two example snippets to
create `h0` with dtype=torch.float32 (e.g., torch.randn(...,
dtype=torch.float32, device='cuda')) so the dtype check in chunk_kda passes
while leaving other tensors (q, k, v, g, beta) as-is.

338-358: ⚠️ Potential issue | 🟠 Major

Validate the rest of the HV-shaped inputs before launching kernels.

This block now checks g and beta, but v, initial_state, A_log, and dt_bias can still arrive in stale pre-GVA shapes. A mismatched [N, H, K, V] state or [H] / [H*K] gate params will only fail later inside the Triton path, with much worse diagnostics.

🔍 Suggested guardrail
-    B, T, H, K, HV = *q.shape, v.shape[2]
+    B, T, H, K = q.shape
+    assert v.ndim == 4 and v.shape[:2] == (B, T), (
+        f"v must have shape [B, T, HV, V] with the same batch/seq dims as q, got {list(v.shape)}"
+    )
+    HV, V = v.shape[2], v.shape[3]
     assert q.shape == k.shape, f"q and k must have the same shape, got q={q.shape} vs k={k.shape}"
     assert K <= 256, f"Currently we only support key headdim <=256 for KDA, got {K}."
     assert HV % H == 0, (
         f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by num_qk_heads (H={H}), "
         f"but got HV % H = {HV % H}"
     )
     assert g.shape == (B, T, HV, K), f"g must have shape [B, T, HV, K]={[B, T, HV, K]}, got {list(g.shape)}"
     assert beta.shape == (B, T, HV), f"beta must have shape [B, T, HV]={[B, T, HV]}, got {list(beta.shape)}"
+    if initial_state is not None:
+        assert initial_state.shape[1:] == (HV, K, V), (
+            f"initial_state must have shape [N, HV, K, V], got {list(initial_state.shape)}"
+        )
+    if use_gate_in_kernel:
+        assert A_log.shape == (HV,), f"A_log must have shape [HV], got {list(A_log.shape)}"
+        if dt_bias is not None:
+            assert dt_bias.shape == (HV * K,), f"dt_bias must have shape [HV * K], got {list(dt_bias.shape)}"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/chunk.py` around lines 338 - 358, The code currently validates g
and beta shapes but can still accept stale shapes for v, initial_state, A_log
and dt_bias which will only error later; add upfront shape guards: verify
v.ndim==4 and v.shape[0:2]==(B,T) and v.shape[2]==HV (and capture V=v.shape[3]);
if initial_state is provided assert its leading batch dim matches B and that its
remaining dims include H and K in the expected order (e.g. initial_state.shape
== (B, H, K, V) or initial_state.shape[:3] == (B, H, K) if full V omitted); when
use_gate_in_kernel is true assert A_log.shape is either (H,) or (H*K,) and
dt_bias (if not None) has a matching shape; ensure these checks reference the
existing symbols v, initial_state, A_log, dt_bias, H, K, HV so mismatches fail
early with clear messages.
🧹 Nitpick comments (1)
tests/ops/test_kda.py (1)

19-35: Add at least one HV > H varlen case.

The new GVA coverage is all fixed-length. Most of the risky pointer/stride changes in this PR also touch the IS_VARLEN branches, so I’d add one HV > H case to test_chunk_varlen / test_chunk_varlen_prefill as well.

Also applies to: 84-98, 346-381

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_kda.py` around lines 19 - 35, Add at least one parameter tuple
where HV > H to the pytest.parametrize lists used by test_chunk_varlen and
test_chunk_varlen_prefill so a varlen case is exercised (e.g., change one tuple
in the list of tests from (HV==H) to something like (H=1, HV=2, ...) or append a
new pytest.param with HV > H); update the param blocks that define
("B","T","H","HV","D","scale","gate_logit_normalizer","dtype") (the same blocks
around the existing tuples) so the test suite covers a variable-length case and
exercises the IS_VARLEN branches.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/kda/chunk_bwd.py`:
- Around line 442-443: The code computes G = HV // H using q.shape and v.shape
(variables H, HV) without validating assumptions; add a guard in chunk_bwd.py
before computing G and before the HV execution path (the branch that remaps
heads and does view(..., H, G, ...)) to verify HV >= H and HV % H == 0, and if
the check fails raise a clear exception (e.g., ValueError) or take a safe
fallback path; apply the same validation at the second occurrence around the
code at the other HV path (the lines corresponding to the block at ~575-578) so
invalid ratios cannot produce misroutes, divide-by-zero or invalid reshapes.

In `@fla/ops/kda/wy_fast.py`:
- Around line 267-271: recompute_w_u_fwd_kda_kernel is writing per-grouped-head
results into H-headed tensors qg and kg causing races when HV > H; allocate qg
and kg to be HV-headed (matching grouped heads) or change the kernel write index
to use i_hv (not i_h) so each (chunk, hv) writes to a unique slot, and update
downstream consumers (chunk_kda_bwd and the chunk_gated_delta_rule_* functions)
to read the HV-headed qg/kg (or to reduce from HV->H deterministically) so the
per-value-head gate information is preserved and deterministic.
- Line 267: The code unconditionally calls torch.empty_like(gk) while the
signature still allows gk: torch.Tensor | None; update the function to either
require a non-optional gk or add an explicit early check that raises a clear
ValueError if gk is None; locate the site using the symbols gk and w (the line
with w = torch.empty_like(gk)) and modify the function signature or insert a
guard like `if gk is None: raise ValueError("gk must be provided")` before
allocating w so callers get a clear contract violation instead of an opaque
allocation error.

---

Outside diff comments:
In `@fla/ops/kda/chunk.py`:
- Around line 263-298: The doc examples construct initial_state `h0` with
dtype=torch.bfloat16 but `chunk_kda` requires `initial_state.dtype ==
torch.float32`; update the two example snippets to create `h0` with
dtype=torch.float32 (e.g., torch.randn(..., dtype=torch.float32, device='cuda'))
so the dtype check in chunk_kda passes while leaving other tensors (q, k, v, g,
beta) as-is.
- Around line 338-358: The code currently validates g and beta shapes but can
still accept stale shapes for v, initial_state, A_log and dt_bias which will
only error later; add upfront shape guards: verify v.ndim==4 and
v.shape[0:2]==(B,T) and v.shape[2]==HV (and capture V=v.shape[3]); if
initial_state is provided assert its leading batch dim matches B and that its
remaining dims include H and K in the expected order (e.g. initial_state.shape
== (B, H, K, V) or initial_state.shape[:3] == (B, H, K) if full V omitted); when
use_gate_in_kernel is true assert A_log.shape is either (H,) or (H*K,) and
dt_bias (if not None) has a matching shape; ensure these checks reference the
existing symbols v, initial_state, A_log, dt_bias, H, K, HV so mismatches fail
early with clear messages.

---

Nitpick comments:
In `@tests/ops/test_kda.py`:
- Around line 19-35: Add at least one parameter tuple where HV > H to the
pytest.parametrize lists used by test_chunk_varlen and test_chunk_varlen_prefill
so a varlen case is exercised (e.g., change one tuple in the list of tests from
(HV==H) to something like (H=1, HV=2, ...) or append a new pytest.param with HV
> H); update the param blocks that define
("B","T","H","HV","D","scale","gate_logit_normalizer","dtype") (the same blocks
around the existing tuples) so the test suite covers a variable-length case and
exercises the IS_VARLEN branches.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8104be1b-8415-4078-bedb-387adb34381a

📥 Commits

Reviewing files that changed from the base of the PR and between f6476ec and de949db.

📒 Files selected for processing (11)
  • fla/layers/kda.py
  • fla/ops/gla/chunk.py
  • fla/ops/kda/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_fwd.py
  • fla/ops/kda/chunk_intra.py
  • fla/ops/kda/chunk_intra_token_parallel.py
  • fla/ops/kda/naive.py
  • fla/ops/kda/wy_fast.py
  • pyproject.toml
  • tests/ops/test_kda.py

Comment thread fla/ops/kda/chunk_bwd.py
Comment thread fla/ops/kda/wy_fast.py Outdated
Comment thread fla/ops/kda/wy_fast.py Outdated
- Guard q/qg/kg pointer mutations with STORE_QG/STORE_KG to avoid
  Triton pointer arithmetic on None
- Fix qg/kg strides from H*K to HV*K in make_block_ptr
- Fix qg/kg pointer offsets from (bos*H+i_h)*K to (bos*HV+i_hv)*K
- Fix qg/kg allocation shape from empty_like(q) [B,T,H,K] to [B,T,HV,K] (see the sketch below)
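A tiny sketch (plain PyTorch, hypothetical variable names) of why those offsets and strides matter: for a contiguous [B, T, HV, K] tensor, the flat offset of element (b, t, i_hv, 0) is (bos * HV + i_hv) * K, where bos here stands for the flattened (batch, token) index (in the varlen kernels it comes from cu_seqlens); using H and i_h instead would alias different value heads onto the same memory.

```python
import torch

B, T, HV, K = 2, 16, 4, 8
qg = torch.randn(B, T, HV, K)            # allocated at value-head resolution

b, t, i_hv = 1, 3, 2
bos = b * T + t                          # flattened (batch, token) index

flat = qg.reshape(-1)
offset = (bos * HV + i_hv) * K           # HV-aware offset, as in the fixed kernels
assert torch.equal(flat[offset:offset + K], qg[b, t, i_hv])
```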
Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
fla/ops/kda/wy_fast.py (1)

267-271: ⚠️ Potential issue | 🟡 Minor

gk is dereferenced unconditionally despite optional signature.

Line 267 calls torch.empty_like(gk) but the signature at line 253 declares gk: torch.Tensor | None = None. If gk is None, this raises an AttributeError instead of a clear contract violation.

Suggested fix
 def recompute_w_u_fwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
     q: torch.Tensor | None = None,
-    gk: torch.Tensor | None = None,
+    gk: torch.Tensor,
     cu_seqlens: torch.LongTensor | None = None,
     chunk_indices: torch.LongTensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:

Or add an explicit guard:

 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    if gk is None:
+        raise ValueError("recompute_w_u_fwd requires `gk` to be provided")
     B, T, H, K, V = *k.shape, v.shape[-1]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/wy_fast.py` around lines 267 - 271, The code unconditionally
calls torch.empty_like(gk) assigning to w even though gk: torch.Tensor | None;
change this to allocate w conditionally (e.g. w = torch.empty_like(gk) if gk is
not None else torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype)) so you
don't dereference None, and similarly ensure any other allocations that relied
on optional tensors (like kg) follow the same pattern before calling
recompute_w_u_fwd_kda_kernel; keep the kernel call signature unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@fla/ops/kda/wy_fast.py`:
- Around line 267-271: The code unconditionally calls torch.empty_like(gk)
assigning to w even though gk: torch.Tensor | None; change this to allocate w
conditionally (e.g. w = torch.empty_like(gk) if gk is not None else
torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype)) so you don't
dereference None, and similarly ensure any other allocations that relied on
optional tensors (like kg) follow the same pattern before calling
recompute_w_u_fwd_kda_kernel; keep the kernel call signature unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 30779959-c8f3-4118-a44f-782dcb8cc6b2

📥 Commits

Reviewing files that changed from the base of the PR and between de949db and da02235.

📒 Files selected for processing (1)
  • fla/ops/kda/wy_fast.py

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
fla/ops/kda/wy_fast.py (1)

257-258: Validation for HV % H == 0 exists at the public API level and is not required in recompute_w_u_fwd.

The constraint is already validated in chunk_kda() in chunk.py (line ~353) with a clear error message. Since all calls to recompute_w_u_fwd go through this validated entry point, adding validation here would be redundant defensive programming rather than a requirement. This suggestion can be skipped unless wy_fast.py functions become public APIs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/wy_fast.py` around lines 257 - 258, Remove the redundant HV % H
== 0 validation inside recompute_w_u_fwd in wy_fast.py since chunk_kda (in
chunk.py) already enforces this invariant for all call sites; locate the HV and
H usage in the recompute_w_u_fwd function (where B, T, H, K, V = *k.shape,
v.shape[-1] and HV = v.shape[2]) and delete the defensive check and associated
error handling so the function relies on the public API validation in chunk_kda.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@fla/ops/kda/wy_fast.py`:
- Around line 257-258: Remove the redundant HV % H == 0 validation inside
recompute_w_u_fwd in wy_fast.py since chunk_kda (in chunk.py) already enforces
this invariant for all call sites; locate the HV and H usage in the
recompute_w_u_fwd function (where B, T, H, K, V = *k.shape, v.shape[-1] and HV =
v.shape[2]) and delete the defensive check and associated error handling so the
function relies on the public API validation in chunk_kda.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 227a3346-1e33-4f68-9218-47de4c7ba540

📥 Commits

Reviewing files that changed from the base of the PR and between 77ebfd0 and 5e9578d.

📒 Files selected for processing (1)
  • fla/ops/kda/wy_fast.py

Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
fla/ops/kda/wy_fast.py (1)

256-270: ⚠️ Potential issue | 🟡 Minor

Return type annotation is inconsistent with actual return values.

kg is always allocated (line 270) and never None, but the return type annotation on line 256 specifies torch.Tensor | None for the fourth element.

🔧 Suggested fix
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/wy_fast.py` around lines 256 - 270, The function return type
annotation declares the fourth tuple element as "torch.Tensor | None" but the
code always allocates kg (variable name kg) so the annotation is inaccurate;
update the function signature return annotation to have the fourth element be
torch.Tensor (remove the | None) or change the allocation of kg to be
conditional (e.g., set kg = None when the case that should produce None occurs)
so the declared tuple type matches the actual returned value — locate the
function signature with the "-> tuple[torch.Tensor, torch.Tensor, torch.Tensor |
None, torch.Tensor | None]" annotation and the kg allocation "kg =
torch.empty(...)" to implement the fix.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@fla/ops/kda/wy_fast.py`:
- Around line 256-270: The function return type annotation declares the fourth
tuple element as "torch.Tensor | None" but the code always allocates kg
(variable name kg) so the annotation is inaccurate; update the function
signature return annotation to have the fourth element be torch.Tensor (remove
the | None) or change the allocation of kg to be conditional (e.g., set kg =
None when the case that should produce None occurs) so the declared tuple type
matches the actual returned value — locate the function signature with the "->
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]"
annotation and the kg allocation "kg = torch.empty(...)" to implement the fix.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bf7924de-4e2c-43c3-9547-1b0ac7dd45da

📥 Commits

Reviewing files that changed from the base of the PR and between 5e9578d and 296d52f.

📒 Files selected for processing (1)
  • fla/ops/kda/wy_fast.py

@github-actions

⚠️ Benchmark Results (NVIDIA-H100-PT2-7)

Status: 2 regression(s) detected

GPU NVIDIA H200
CUDA 12.8
PyTorch 2.7.1+cu128
Base 323bbf21e8
Head f92b40bb0c
Threshold 5.0%
| Op | Mode | B | T | H | D | Base (ms) | Head (ms) | Speedup | Change |
|---|---|---|---|---|---|---|---|---|---|
| chunk_gla | fwd | 1 | 8192 | 96 | 128 | 2.253 | 2.249 | 1.00x | -0.2% |
| chunk_gla | fwd | 2 | 16384 | 16 | 128 | 1.834 | 1.834 | 1.00x | -0.0% |
| chunk_gla | fwd | 4 | 2048 | 16 | 128 | 0.427 | 0.427 | 1.00x | -0.1% |
| chunk_gla | fwd | 4 | 4096 | 64 | 128 | 3.015 | 3.018 | 1.00x | +0.1% |
| chunk_gla | fwd | 8 | 1024 | 8 | 64 | 0.267 | 0.281 | 0.95x | +5.1% 🔴 |
| chunk_gla | fwd | 8 | 2048 | 32 | 256 | 3.694 | 3.698 | 1.00x | +0.1% |
| chunk_kda | fwd | 1 | 8192 | 96 | 128 | 2.725 | 2.718 | 1.00x | -0.2% |
| chunk_kda | fwd | 2 | 16384 | 16 | 128 | 1.932 | 1.923 | 1.00x | -0.5% |
| chunk_kda | fwd | 4 | 2048 | 16 | 128 | 0.593 | 0.610 | 0.97x | +3.0% |
| chunk_kda | fwd | 4 | 4096 | 64 | 128 | 3.479 | 3.472 | 1.00x | -0.2% |
| chunk_kda | fwd | 8 | 1024 | 8 | 64 | 1.260 | 0.611 | 2.06x | -51.5% 🟢 |
| chunk_kda | fwd | 8 | 2048 | 32 | 256 | 4.413 | 4.124 | 1.07x | -6.6% 🟢 |
| chunk_gla | fwdbwd | 1 | 8192 | 96 | 128 | 10.944 | 10.943 | 1.00x | -0.0% |
| chunk_gla | fwdbwd | 2 | 16384 | 16 | 128 | 8.487 | 8.508 | 1.00x | +0.2% |
| chunk_gla | fwdbwd | 4 | 2048 | 16 | 128 | 2.025 | 2.028 | 1.00x | +0.1% |
| chunk_gla | fwdbwd | 4 | 4096 | 64 | 128 | 15.186 | 15.232 | 1.00x | +0.3% |
| chunk_gla | fwdbwd | 8 | 1024 | 8 | 64 | 1.049 | 1.332 | 0.79x | +26.9% 🔴 |
| chunk_gla | fwdbwd | 8 | 2048 | 32 | 256 | 18.982 | 18.977 | 1.00x | -0.0% |
| chunk_kda | fwdbwd | 1 | 8192 | 96 | 128 | 12.437 | 12.570 | 0.99x | +1.1% |
| chunk_kda | fwdbwd | 2 | 16384 | 16 | 128 | 8.384 | 8.502 | 0.99x | +1.4% |
| chunk_kda | fwdbwd | 4 | 2048 | 16 | 128 | 2.192 | 2.213 | 0.99x | +1.0% |
| chunk_kda | fwdbwd | 4 | 4096 | 64 | 128 | 15.991 | 16.290 | 0.98x | +1.9% |
| chunk_kda | fwdbwd | 8 | 1024 | 8 | 64 | 1.981 | 1.881 | 1.05x | -5.1% 🟢 |
| chunk_kda | fwdbwd | 8 | 2048 | 32 | 256 | 18.604 | 18.598 | 1.00x | -0.0% |
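For reference, the Speedup and Change columns are consistent with speedup = base / head and change = (head - base) / base; a quick check in plain Python against the chunk_kda fwd (B=8, T=1024) row:

```python
base, head = 1.260, 0.611                # chunk_kda fwd, B=8 T=1024 H=8 D=64
speedup = base / head                    # ≈ 2.06x, matching the table
change = (head - base) / base            # ≈ -51.5%, matching the table
print(f"{speedup:.2f}x, {change:+.1%}")
```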

This comment is automatically updated with the latest benchmark results.

@yzhangcs yzhangcs merged commit 14bef3d into main Apr 16, 2026
6 checks passed
@yzhangcs yzhangcs deleted the feat/kda-gva branch April 16, 2026 19:22
