Skip to content

feat(cpu): add inplace rmsnorm implementations for fp32 and fp16#483

Merged
chenghuaWang merged 2 commits intoUbiquitousLearning:v2from
chenghuaWang:v2
Oct 18, 2025
Merged

feat(cpu): add inplace rmsnorm implementations for fp32 and fp16#483
chenghuaWang merged 2 commits intoUbiquitousLearning:v2from
chenghuaWang:v2

Conversation

@chenghuaWang
Copy link
Copy Markdown
Collaborator

@chenghuaWang chenghuaWang commented Oct 18, 2025

  • Added rmsnorm_fp32_inplace and rmsnorm_fp16_inplace functions in ARM kernels
  • Updated RMSNormOp to support inplace operations using the new kernel functions
  • Modified LinearOp and related classes to support tensor redirection
  • Enhanced FlashAttention2Op with updated kernel includes and input handling
  • Added new test cases for FlashAttention2 with improved accuracy checks
  • Fixed contiguous tensor assertions in RMSNorm and RoPE operations
  • Extended Layer macros to support redirect attribute for ops
  • Updated StaticCache with new methods for KV cache management
  • Improved FA2 kernel tests with radix attention support and better validation

Summary by CodeRabbit

  • New Features

    • Added Qwen3 model with integrated Flash Attention 2 support for efficient generation.
    • Introduced in-place operations for RMSNorm and rotary position embeddings to reduce memory overhead.
    • Added ARM NEON-optimized Flash Attention 2 kernels for improved performance on ARM processors.
  • Infrastructure

    • Enhanced KV cache management for optimized inference workflows.
    • Expanded test coverage for attention mechanisms and kernel operations.

chenghuaWang and others added 2 commits October 18, 2025 20:19
- Added `rmsnorm_fp32_inplace` and `rmsnorm_fp16_inplace` functions in ARM kernels
- Updated RMSNormOp to support inplace operations using the new kernel functions
- Modified LinearOp and related classes to support tensor redirection
- Enhanced FlashAttention2Op with updated kernel includes and input handling
- Added new test cases for FlashAttention2 with improved accuracy checks
- Fixed contiguous tensor assertions in RMSNorm and RoPE operations
- Extended Layer macros to support redirect attribute for ops
- Updated StaticCache with new methods for KV cache management
- Improved FA2 kernel tests with radix attention support and better validation
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Oct 18, 2025

Walkthrough

This pull request introduces Flash Attention 2 (FA2) optimizations with architecture-specific implementations, adds in-place operation support to RMSNorm and RoPE kernels, implements redirect mode for Linear operations, adds the Qwen3 FA2 model, and refactors cache management to support efficient KV storage. Multiple operation APIs are updated to allow mutable option access.

Changes

Cohort / File(s) Summary
In-place RMSNorm kernels
mllm/backends/cpu/kernels/arm/rmsnorm.cpp, mllm/backends/cpu/kernels/arm/rmsnorm.hpp
Added rmsnorm_fp32_inplace and rmsnorm_fp16_inplace functions with 3-pass RMS normalization workflow (compute RMS, scale input, apply weight with optional offset). Public declarations and implementations for ARM architecture.
Flash Attention 2 architecture abstraction
mllm/backends/cpu/kernels/common/fa2_1/arch.hpp, impl-any.hpp, impl-arm.hpp, fwd_bshd.hpp
Introduced architecture tag system (__AnyArchTag, __ArmArchTag, __X86ArchTag) with template specializations for VectorDotProduct, MulFromConst, FMAConstArray, FilledWithConst. ARM NEON-optimized implementations for FP32. Generic fallback for any architecture. New fwd_bhsd template function for forward pass computation with batch, head, and sequence dimension handling.
Flash Attention 2 operation dispatch
mllm/backends/cpu/ops/FlashAttention2Op.cpp
Replaced FP16-centric path with FP32-focused implementation. Updated to use new fwd_bhsd with architecture selection (__ArmArchTag for ARM64/ARM, __X86ArchTag for x86). Removed K/V contiguity requirements and added runtime shape validations.
In-place RMSNorm operation
mllm/backends/cpu/ops/RMSNormOp.cpp, mllm/core/aops/RMSNormOp.cpp, mllm/core/aops/RMSNormOp.hpp
Added in-place execution path guarded by options_.isInplace() with architecture-specific kernel dispatch. Updated RMSNormOp::reshape() to return input tensor directly when in-place. Changed options() accessor to return non-const reference.
Linear operation redirect mode
mllm/backends/cpu/ops/LinearOp.cpp, mllm/core/aops/LinearOp.cpp, mllm/core/aops/LinearOp.hpp, mllm/nn/layers/Linear.hpp
Added early-return redirect path in reshape() and setup() when options_.isRedirect() is true. Changed options() accessor to non-const. Added MLLM_LAYER_ENABLE_REDIRECT_ATTRIBUTE(Linear) macro.
RoPE in-place and API updates
mllm/core/aops/RoPEOp.cpp, mllm/core/aops/RoPEOp.hpp, mllm/nn/layers/RoPE.hpp
Added conditional logic in reshape() and setup() for in-place operation. Changed options() to non-const accessor. Added MLLM_LAYER_ENABLE_INPLACE_ATTRIBUTE(RoPE) macro.
Base operation infrastructure
mllm/core/BaseOp.hpp, mllm/nn/Layer.hpp, mllm/core/aops/SiLUOp.hpp
Added setRedirect(bool) and isRedirect() methods to BaseOpOptions. Introduced MLLM_LAYER_ENABLE_REDIRECT_ATTRIBUTE macro for redirect propagation. Added options() non-const accessor to SiLUOp.
Static KV cache management
mllm/nn/lmcache/StaticCache.hpp, mllm/nn/lmcache/StaticCache.cpp
Replaced getKCache()/getVCache() with new getKVCache() returning array of 2 tensors. Added preGetKVWriteLocation() for FA2 mode KV reservation. Added buffer accessors getKCacheBuffer() and getVCacheBuffer().
Qwen3 FA2 model implementation
mllm/models/qwen3/modeling_qwen3_fa2.hpp
Complete Qwen3 model with FA2 support: makeRoPEInvFreq(), makeRotaryPosEmbedding() utilities; Qwen3MLP, Qwen3Attention, Qwen3Decoder, Qwen3Text, Qwen3ForCausalLM classes with generation flow, KV cache integration, and position-id management.
Utility and test updates
mllm/core/aops/ViewOp.cpp
Comment update (TODO → FIXME).
Flash Attention kernel tests
tests/cpu/FlashAttentionKernelTest.hpp, tests/cpu/KernelTest.cpp
Renamed FlashAttention2KernelTest to FlashAttn2KernelTest. Added FlashAttn2Module, FA2EagerModule test classes. Introduced testRadixAttn() and testRadixAttnOnce() test methods. Added ARM-specific FA2 test block.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant RMSNormOp
    participant BaseOp
    participant Kernel[ARM/x86 Kernel]
    
    Caller->>RMSNormOp: forward()
    
    alt In-place Mode (isInplace == true)
        RMSNormOp->>Kernel: rmsnorm_fp32_inplace(X, W, Y, ...)
        Kernel->>Kernel: 1. Compute RMS from X
        Kernel->>Kernel: 2. Scale X by RMS
        Kernel->>Kernel: 3. Apply weight W + offset
        Kernel-->>RMSNormOp: Result in Y
    else Non-in-place Mode
        RMSNormOp->>BaseOp: setup(inputs, outputs)
        RMSNormOp->>Kernel: rmsnorm_fp32(X, Y, W, ...)
        Kernel-->>RMSNormOp: Result
    end
    
    RMSNormOp-->>Caller: Output tensor
Loading
sequenceDiagram
    participant Caller
    participant FlashAttn2Op
    participant ArchDispatch
    participant ArmKernel[ARM NEON]
    participant AnyKernel[Generic]
    
    Caller->>FlashAttn2Op: forward(Q, K, V)
    FlashAttn2Op->>ArchDispatch: Detect architecture
    
    alt ARM64/ARM Architecture
        ArchDispatch->>ArmKernel: fwd_bhsd<__ArmArchTag>
        ArmKernel->>ArmKernel: VectorDotProduct (NEON 16-wide)
        ArmKernel->>ArmKernel: MulFromConst (NEON)
        ArmKernel->>ArmKernel: FMAConstArray (NEON)
        ArmKernel-->>FlashAttn2Op: Attention output
    else X86/Generic
        ArchDispatch->>AnyKernel: fwd_bhsd<__X86ArchTag>
        AnyKernel->>AnyKernel: VectorDotProduct (scalar loop)
        AnyKernel->>AnyKernel: MulFromConst (scalar loop)
        AnyKernel-->>FlashAttn2Op: Attention output
    end
    
    FlashAttn2Op-->>Caller: Output
Loading
sequenceDiagram
    participant Caller
    participant LinearOp
    participant Options
    
    Caller->>LinearOp: reshape()
    LinearOp->>Options: isRedirect()
    
    alt Redirect Mode
        Options-->>LinearOp: true
        LinearOp->>LinearOp: outputs.emplace_back(inputs[1])
        LinearOp-->>Caller: Return (bypass standard reshaping)
    else Normal Mode
        Options-->>LinearOp: false
        LinearOp->>LinearOp: Standard reshape logic
        LinearOp-->>Caller: Reshaped output
    end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Rationale: This PR encompasses substantial changes across multiple systems: (1) new Flash Attention 2 architecture-agnostic template framework with ARM NEON specializations requiring careful verification of SIMD correctness; (2) comprehensive in-place operation support touching RMSNorm, RoPE, and related infrastructure; (3) significant Qwen3 model implementation (~400+ lines of new model code); (4) API changes affecting option accessors across multiple operation classes; (5) cache management refactoring with dual-path logic (FA2 vs. eager). While many changes follow consistent patterns, the heterogeneity of concerns—kernel optimization, operation dispatch, model architecture, and cache management—plus the density of SIMD/template logic necessitates careful, multi-faceted review.

Poem

🐰 Flutter-flutter, what's this brew?
New kernels dance in NEON hue,
FA2 flows with redirection's glow,
RMSNorm in-place steals the show!
Qwen3 rises, cache songs play,
A rabbit's joy—optimization day! 🚀

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Title Check ⚠️ Warning The pull request title accurately describes in-place RMSNorm implementations for fp32 and fp16, which are indeed added to the codebase through new kernel functions in rmsnorm.cpp/rmsnorm.hpp and integrated into RMSNormOp. However, the PR contains substantial additional changes that are equally or more significant, including: new FlashAttention2 architecture-specific kernels with template infrastructure (arch.hpp, impl-any.hpp, impl-arm.hpp, fwd_bshd.hpp), a complete Qwen3 model implementation, LinearOp redirect support, RoPE in-place capabilities, and StaticCache enhancements. The title captures only one aspect of the broader changeset without conveying the scope of architectural enhancements and new model implementations. Consider revising the PR title to reflect the broader scope, such as "feat(cpu): optimize attention and normalization with inplace kernels and Qwen3 model support" or breaking this into multiple PRs: one for RMSNorm in-place implementations and another for FlashAttention2 infrastructure and Qwen3 model. This would better align the title with the changeset's actual scope and make review and navigation easier.
Docstring Coverage ⚠️ Warning Docstring coverage is 2.56% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Oct 18, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@chenghuaWang
Copy link
Copy Markdown
Collaborator Author

@coderabbitai review

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Oct 18, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 25

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
mllm/core/aops/ViewOp.cpp (2)

55-56: Avoid overflow when computing new_numel.

std::accumulate starts with an int seed (1), risking overflow on large shapes. Accumulate into int64.

Apply:

-  const int64_t new_numel = std::accumulate(actual_shape.begin(), actual_shape.end(), 1, std::multiplies<>());
+  const int64_t new_numel = std::accumulate(
+      actual_shape.begin(), actual_shape.end(), int64_t{1}, std::multiplies<int64_t>());

49-53: Handle zero‑element tensors and guard inferred dim overflow.

Current logic rejects inference when product==0 even if input_numel==0, and it doesn’t ensure the inferred size fits in int32.

-  if (infer_dim != -1) {
-    MLLM_RT_ASSERT(product != 0 && "cannot infer dimension for a shape with zero product");
-    MLLM_RT_ASSERT(input_numel % product == 0 && "input tensor size does not match inferred shape");
-    actual_shape[infer_dim] = static_cast<int32_t>(input_numel / product);
-  }
+  if (infer_dim != -1) {
+    if (product == 0) {
+      MLLM_RT_ASSERT(input_numel == 0 && "cannot infer -1 when other dims multiply to 0 unless numel is 0");
+      actual_shape[infer_dim] = 0;
+    } else {
+      MLLM_RT_ASSERT(input_numel % product == 0 && "input tensor size does not match inferred shape");
+      const int64_t inferred = input_numel / product;
+      MLLM_RT_ASSERT(inferred <= std::numeric_limits<int32_t>::max() &&
+                     "inferred dimension exceeds int32 range");
+      actual_shape[infer_dim] = static_cast<int32_t>(inferred);
+    }
+  }

Note: requires #include <limits>.

🧹 Nitpick comments (17)
mllm/core/aops/ViewOp.cpp (2)

71-80: Treat empty tensors as contiguous; prefer central contiguity API.

The manual check can misclassify edge cases; for numel()==0, mark contiguous up‑front, or defer to a canonical isContiguous() if available.

-  bool is_contiguous = true;
+  bool is_contiguous = true;
+  if (it.numel() == 0) {
+    is_contiguous = true;
+  } else
   int64_t current_stride = 1;

If TensorImpl has isContiguous(), use that instead of reimplementing the check.


91-93: Either implement view‑stride computation or improve the failure path.

Right now non‑contiguous inputs always hard‑fail. Two options:

  • Implement a compute_stride_for_view (PyTorch‑style) to allow more reshape patterns without copy.
  • If deferring, enrich the assert to aid debugging (print shape/stride), or fall back to a safe copy/redirect path when policy permits.
-    // FIXME: more stride logic such as `compute_stride_for_view` in PyTorch
-    MLLM_ASSERT_EXIT(is_contiguous, "ViewOp::reshape is only supported for contiguous tensors in this implementation");
+    // TODO: support non-contiguous view by computing new strides when feasible
+    MLLM_ASSERT_EXIT(
+        is_contiguous,
+        fmt::format("ViewOp::reshape requires contiguous tensor. orig_shape={}, orig_stride={}, requested_shape={}",
+                    fmt::join(orig_shape, ","), fmt::join(orig_stride, ","), fmt::join(actual_shape, ",")));

If a fallback copy is acceptable, consider redirecting to a contiguous materialization before reshaping.

mllm/nn/lmcache/StaticCache.hpp (1)

48-55: Harden KV accessors: add bounds checks and [[nodiscard]].

Prevent OOB on invalid layer_idx and encourage use of return values.

Apply:

-  std::array<Tensor, 2> getKVCache(int32_t layer_idx);
+  [[nodiscard]] std::array<Tensor, 2> getKVCache(int32_t layer_idx);

-  std::array<Tensor, 2> preGetKVWriteLocation(int32_t layer_idx, int32_t s);
+  [[nodiscard]] std::array<Tensor, 2> preGetKVWriteLocation(int32_t layer_idx, int32_t s);

-  [[nodiscard]] inline Tensor getKCacheBuffer(int32_t layer_idx) const { return k_cache_[layer_idx]; };
+  [[nodiscard]] inline Tensor getKCacheBuffer(int32_t layer_idx) const {
+    MLLM_RT_ASSERT(layer_idx >= 0 && layer_idx < static_cast<int32_t>(k_cache_.size()));
+    return k_cache_[layer_idx];
+  };

-  [[nodiscard]] inline Tensor getVCacheBuffer(int32_t layer_idx) const { return v_cache_[layer_idx]; };
+  [[nodiscard]] inline Tensor getVCacheBuffer(int32_t layer_idx) const {
+    MLLM_RT_ASSERT(layer_idx >= 0 && layer_idx < static_cast<int32_t>(v_cache_.size()));
+    return v_cache_[layer_idx];
+  };

Additionally, ensure <array> is explicitly included at the top of this header to avoid relying on transitive includes:

#include <array>
mllm/core/BaseOp.hpp (1)

70-77: Define redirect vs inplace semantics (preferably mutually exclusive).

If both are set, behavior is ambiguous. Either enforce mutual exclusion in setters or document precedence (e.g., redirect wins during reshape).

Example guard:

inline void setInplace(bool v) {
  // if both are true, decide policy or assert
  // MLLM_RT_ASSERT(!(v && redirect_), "inplace and redirect cannot both be true");
  inplace_ = v;
}
inline void setRedirect(bool v) {
  // MLLM_RT_ASSERT(!(v && inplace_), "redirect and inplace cannot both be true");
  redirect_ = v;
}
mllm/backends/cpu/ops/RMSNormOp.cpp (1)

31-41: Avoid nested parallelism (outer loop + threaded kernels).

You already parallelize across other_dim. Passing options_.getThreads() into per-chunk kernels can oversubscribe threads.

Set inner-kernel threads to 1 (or teach kernels to be thread-agnostic when called inside a parallel region).

Also applies to: 45-54

mllm/nn/lmcache/StaticCache.cpp (1)

145-159: Make getKVCache const for safer API.

This method does not mutate state; mark it const (and update the declaration in StaticCache.hpp) to prevent accidental writes via non-const access.

-std::array<Tensor, 2> StaticCache::getKVCache(int32_t layer_idx) {
+std::array<Tensor, 2> StaticCache::getKVCache(int32_t layer_idx) const {
mllm/models/qwen3/modeling_qwen3_fa2.hpp (2)

4-12: Include for std::pow/sin/cos.

Avoid relying on transitive includes.

 #include "mllm/models/ARGeneration.hpp"
+#include <cmath>

112-121: Silence unused member or use it.

num_key_value_groups_ is computed but unused; mark [[maybe_unused]] or use it.

-  int num_key_value_groups_;
+  [[maybe_unused]] int num_key_value_groups_;
mllm/core/aops/RMSNormOp.cpp (1)

42-47: Clear outputs before emplacing to avoid accumulation on repeated reshape calls.

Some executors call reshape multiple times; ensure outputs has exactly one tensor.

 void RMSNormOp::reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
-  if (options_.isInplace()) {
-    outputs.emplace_back(inputs[0]);
+  outputs.clear();
+  if (options_.isInplace()) {
+    outputs.emplace_back(inputs[0]);
   } else {
     const auto& i = inputs[0];
     outputs.emplace_back(Tensor::empty(i.shape(), i.dtype(), i.device()));
   }
 }

Confirm your executor never relies on outputs' previous capacity.

mllm/core/aops/SiLUOp.hpp (1)

27-28: Provide a const accessor too; document mutation semantics.

Keep mutable access but add a const overload to avoid unnecessary escapes of internal state.

   inline SiLUOpOptions& options() { return options_; }
+  inline const SiLUOpOptions& options() const { return options_; }

Please confirm other AOPs expose both overloads for consistency.

mllm/nn/Layer.hpp (1)

101-108: LGTM! Redirect attribute macro follows established pattern.

The redirect macro correctly mirrors the inplace macro's structure, propagating the redirect flag to both layer and op option levels. This ensures consistent state across the layer abstraction boundary.

The static analysis suggestion to use constexpr template functions instead of macros is worth considering for improved type safety and debugging, though it would require broader refactoring of the layer attribute system.

mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp (2)

66-82: K/V/Q pointer qualifiers and type names after rename.

Adjust calls to detail helpers to the new type names and const-correctness; no behavioral change.

-          __AccDType acc_s;
-          details::VectorDotProduct<__ArchTag, __QDType, __KDType, __AccDType>::run(q_token, k_token, &acc_s, D);
+          AccDType acc_s;
+          details::VectorDotProduct<ArchTag, QDType, KDType, AccDType>::run(q_token, k_token, &acc_s, D);
...
-          details::MulFromConst<__ArchTag, __AccDType, __AccDType>::run(acc_o, scores_scale, D);
+          details::MulFromConst<ArchTag, AccDType, AccDType>::run(acc_o, scores_scale, D);
...
-          details::FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>::run(acc_o, acc_s, v_token, D);
+          details::FMAConstArray<ArchTag, AccDType, AccDType, AccDType>::run(acc_o, acc_s, v_token, D);

14-18: Include path for generic SIMD looks correct; confirm file name.

Ensure impl-any-simd.hpp exists (not impl-any-simd.h/pp). If not, fix the include.

mllm/backends/cpu/kernels/common/fa2_1/impl-arm.hpp (3)

33-37: Dot-product ILP: use multiple accumulators.

Current single-accumulator FMA chain is dependency-heavy. Four independent accumulators improve ILP and throughput on many cores.

Example:

float32x4_t s0=vdupq_n_f32(0), s1=s0, s2=s0, s3=s0;
// inside 16-wide loop:
s0 = MLLM_NEON_FMA(s0, lhs_vec0, rhs_vec0);
s1 = MLLM_NEON_FMA(s1, lhs_vec1, rhs_vec1);
s2 = MLLM_NEON_FMA(s2, lhs_vec2, rhs_vec2);
s3 = MLLM_NEON_FMA(s3, lhs_vec3, rhs_vec3);
// after loops:
float result = MLLM_NEON_HADD(vaddq_f32(vaddq_f32(s0,s1), vaddq_f32(s2,s3)));

68-73: “FIXME” is misleading; MUL is appropriate here.

For pure scaling, vmulq_f32 is correct. Switching to FMA would add an unnecessary zero addend and may slow down on some µarches. Reword or remove the comment.

-      // FIXME: FMA may be faster than MUL
+      // NOTE: MUL is the appropriate choice for pure scaling by const_v.

14-15: Adjacent same-typed pointers are easy to swap in call sites.

(__lhs, __rhs, __out, len) are similar pointer types; easy to misuse at call sites.

Options:

  • Keep signature (to match base template) but add doxygen-style param docs and NOLINT for the static analyzer.
  • Alternatively, wrap inputs as struct { const T* lhs; const T* rhs; } to force named construction at call sites.

If desired, I can add parameter docs and a local // NOLINT(bugprone-easily-swappable-parameters) to silence the warning where needed.

mllm/backends/cpu/kernels/common/fa2_1/impl-any.hpp (1)

12-16: Same-type pointer adjacency; consider making swaps harder.

As in the ARM impl, (__lhs, __rhs, __out, len) invites accidental swaps at call sites.

  • Add param docs and a local NOLINT for the analyzer; or
  • Consider std::span<const T> for lhs/rhs in the generic path to increase type safety (if acceptable in this layer).
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 72f292e and 68f516e.

📒 Files selected for processing (26)
  • mllm/backends/cpu/kernels/arm/rmsnorm.cpp (2 hunks)
  • mllm/backends/cpu/kernels/arm/rmsnorm.hpp (1 hunks)
  • mllm/backends/cpu/kernels/common/fa2_1/arch.hpp (1 hunks)
  • mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp (1 hunks)
  • mllm/backends/cpu/kernels/common/fa2_1/impl-any.hpp (1 hunks)
  • mllm/backends/cpu/kernels/common/fa2_1/impl-arm.hpp (1 hunks)
  • mllm/backends/cpu/ops/FlashAttention2Op.cpp (2 hunks)
  • mllm/backends/cpu/ops/LinearOp.cpp (1 hunks)
  • mllm/backends/cpu/ops/RMSNormOp.cpp (1 hunks)
  • mllm/core/BaseOp.hpp (1 hunks)
  • mllm/core/aops/LinearOp.cpp (2 hunks)
  • mllm/core/aops/LinearOp.hpp (1 hunks)
  • mllm/core/aops/RMSNormOp.cpp (1 hunks)
  • mllm/core/aops/RMSNormOp.hpp (1 hunks)
  • mllm/core/aops/RoPEOp.cpp (1 hunks)
  • mllm/core/aops/RoPEOp.hpp (1 hunks)
  • mllm/core/aops/SiLUOp.hpp (1 hunks)
  • mllm/core/aops/ViewOp.cpp (1 hunks)
  • mllm/models/qwen3/modeling_qwen3_fa2.hpp (1 hunks)
  • mllm/nn/Layer.hpp (1 hunks)
  • mllm/nn/layers/Linear.hpp (1 hunks)
  • mllm/nn/layers/RoPE.hpp (1 hunks)
  • mllm/nn/lmcache/StaticCache.cpp (1 hunks)
  • mllm/nn/lmcache/StaticCache.hpp (2 hunks)
  • tests/cpu/FlashAttentionKernelTest.hpp (1 hunks)
  • tests/cpu/KernelTest.cpp (1 hunks)
🧰 Additional context used
🪛 Clang (14.0.6)
mllm/core/aops/SiLUOp.hpp

[error] 30-30: member variable 'options_' has protected visibility

(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)

mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp

[error] 6-6: 'cmath' file not found

(clang-diagnostic-error)


[error] 25-25: declaration uses identifier '__ArchTag', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__QDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__KDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__VDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__ODType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__AccDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 27-27: 6 adjacent parameters of 'fwd_bhsd' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 27-27: parameter name 'B' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 27-27: parameter name 'D' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 27-27: declaration uses identifier '__q', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 28-28: declaration uses identifier '__k', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 28-28: declaration uses identifier '__v', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 28-28: declaration uses identifier '__out', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 29-29: variable 'head_repeat_times' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)

mllm/backends/cpu/kernels/common/fa2_1/arch.hpp

[error] 5-5: 'cassert' file not found

(clang-diagnostic-error)


[error] 10-10: declaration uses identifier '__AnyArchTag', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 10-10: invalid case style for class '__AnyArchTag'

(readability-identifier-naming,-warnings-as-errors)


[error] 13-13: declaration uses identifier '__X86ArchTag', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 13-13: invalid case style for class '__X86ArchTag'

(readability-identifier-naming,-warnings-as-errors)


[error] 16-16: declaration uses identifier '__ArmArchTag', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 16-16: invalid case style for class '__ArmArchTag'

(readability-identifier-naming,-warnings-as-errors)


[error] 19-19: declaration uses identifier '__ArgTag', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 19-19: declaration uses identifier '__LhsDataType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 19-19: declaration uses identifier '__RhsDataType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 19-19: declaration uses identifier '__DstDataType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__ArgTag', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__FromDataType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__constDataType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)

mllm/backends/cpu/kernels/common/fa2_1/impl-arm.hpp

[error] 6-6: 'mllm/core/DataTypes.hpp' file not found

(clang-diagnostic-error)


[error] 14-14: 2 adjacent parameters of 'run' of similar type ('const int *__restrict') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 16-16: variable 'sum_vec' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 18-18: variable 'i' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 18-18: variable name 'i' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 19-19: variable 'block_size' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 20-20: variable 'len_aligned' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 45-45: variable 'result' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 56-56: variable 'const_vec' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 58-58: variable 'i' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 58-58: variable name 'i' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 59-59: variable 'block_size' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 60-60: variable 'len_aligned' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 94-94: variable 'acc_vec' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 96-96: variable 'i' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 96-96: variable name 'i' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 97-97: variable 'block_size' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 98-98: variable 'len_aligned' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 136-136: variable 'const_vec' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 138-138: variable 'i' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 138-138: variable name 'i' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 139-139: variable 'block_size' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 140-140: variable 'len_aligned' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)

mllm/core/aops/RMSNormOp.hpp

[error] 39-39: member variable 'weight_' has protected visibility

(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)


[error] 40-40: member variable 'options_' has protected visibility

(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)

mllm/nn/Layer.hpp

[error] 92-92: function-like macro 'MLLM_LAYER_ENABLE_INPLACE_ATTRIBUTE' used; consider a 'constexpr' template function

(cppcoreguidelines-macro-usage,-warnings-as-errors)


[error] 101-101: function-like macro 'MLLM_LAYER_ENABLE_REDIRECT_ATTRIBUTE' used; consider a 'constexpr' template function

(cppcoreguidelines-macro-usage,-warnings-as-errors)

mllm/backends/cpu/kernels/common/fa2_1/impl-any.hpp

[error] 6-6: 'mllm/core/DataTypes.hpp' file not found

(clang-diagnostic-error)


[error] 12-12: 2 adjacent parameters of 'run' of similar type ('const int *__restrict') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 14-14: variable 'ret' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)

mllm/models/qwen3/modeling_qwen3_fa2.hpp

[error] 4-4: 'mllm/mllm.hpp' file not found

(clang-diagnostic-error)


[error] 15-15: 2 adjacent parameters of 'makeRoPEInvFreq' of convertible types are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 22-22: 2 adjacent parameters of 'makeRotaryPosEmbedding' of convertible types are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 71-71: constructor does not initialize these fields: gate_proj_, up_proj_, down_proj_, silu_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 79-79: constructor does not initialize these fields: gate_proj_, up_proj_, down_proj_, silu_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 79-79: 2 adjacent parameters of 'Qwen3MLP' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 86-86: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 87-87: variable name 'x' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 89-89: variable name 'y' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 96-96: constructor does not initialize these fields: q_proj_, k_proj_, v_proj_, o_proj_, rms_norm_q_, rms_norm_k_, q_rope_, k_rope_, hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, num_key_value_groups_, layer_idx_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 115-115: constructor does not initialize these fields: q_proj_, k_proj_, v_proj_, o_proj_, rms_norm_q_, rms_norm_k_, q_rope_, k_rope_, hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, layer_idx_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 115-115: 2 adjacent parameters of 'Qwen3Attention' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 120-120: 'num_key_value_groups_' should be initialized in a member initializer of the constructor

(cppcoreguidelines-prefer-member-initializer,-warnings-as-errors)


[error] 138-138: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 139-139: variable name 'x' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 144-144: variable 'B' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 144-144: variable name 'B' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 144-144: invalid case style for variable 'B'

(readability-identifier-naming,-warnings-as-errors)


[error] 145-145: variable 'S' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 145-145: variable name 'S' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 145-145: invalid case style for variable 'S'

(readability-identifier-naming,-warnings-as-errors)


[error] 177-177: member variable 'layer_idx_' has public visibility

(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)


[error] 180-180: constructor does not initialize these fields: input_layer_norm_, post_attention_layer_norm_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 189-189: constructor does not initialize these fields: input_layer_norm_, post_attention_layer_norm_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 189-189: 2 adjacent parameters of 'Qwen3Decoder' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 196-196: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 201-201: variable name 'x' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 211-211: constructor does not initialize these fields: decode_blocks_, norm_, embedding_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 219-219: constructor does not initialize these fields: decode_blocks_, norm_, embedding_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 219-219: 2 adjacent parameters of 'Qwen3Text' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 226-226: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 230-230: variable name 'x' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 246-246: constructor does not initialize these fields: lm_head_, tie_word_embeddings_, kv_cache_

(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)


[error] 273-273: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 280-280: variable 'position_ids' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 307-307: variable name 'S' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 307-307: invalid case style for variable 'S'

(readability-identifier-naming,-warnings-as-errors)

tests/cpu/FlashAttentionKernelTest.hpp

[error] 4-4: 'limits' file not found

(clang-diagnostic-error)


[error] 20-20: 2 adjacent parameters of 'FlashAttn2Module' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 22-22: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 30-30: use '= default' to define a trivial default constructor

(modernize-use-equals-default,-warnings-as-errors)


[error] 32-32: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 35-35: variable name 'Q' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 35-35: invalid case style for variable 'Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 36-36: variable name 'K' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 36-36: invalid case style for variable 'K'

(readability-identifier-naming,-warnings-as-errors)


[error] 37-37: variable name 'V' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 37-37: invalid case style for variable 'V'

(readability-identifier-naming,-warnings-as-errors)


[error] 53-53: invalid case style for variable 'S_Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 54-54: invalid case style for variable 'S_KV'

(readability-identifier-naming,-warnings-as-errors)


[error] 58-58: declaration uses identifier '__delta', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 58-58: variable '__delta' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 58-58: invalid case style for variable '__delta'

(readability-identifier-naming,-warnings-as-errors)


[error] 77-77: class 'FlashAttn2KernelTest' defines a default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator

(cppcoreguidelines-special-member-functions,-warnings-as-errors)


[error] 82-82: method 'testRadixAttnOnce' can be made static

(readability-convert-member-functions-to-static,-warnings-as-errors)


[error] 83-83: variable name 'B' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 83-83: invalid case style for variable 'B'

(readability-identifier-naming,-warnings-as-errors)


[error] 84-84: variable 'H_Q' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 84-84: invalid case style for variable 'H_Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 85-85: variable 'H_KV' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 85-85: invalid case style for variable 'H_KV'

(readability-identifier-naming,-warnings-as-errors)


[error] 86-86: variable 'S_Q' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 86-86: invalid case style for variable 'S_Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 87-87: variable 'S_KV' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 87-87: invalid case style for variable 'S_KV'

(readability-identifier-naming,-warnings-as-errors)


[error] 88-88: variable 'D' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 88-88: variable name 'D' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 88-88: invalid case style for variable 'D'

(readability-identifier-naming,-warnings-as-errors)


[error] 94-94: variable name 'Q' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 94-94: invalid case style for variable 'Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 95-95: variable name 'K' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 95-95: invalid case style for variable 'K'

(readability-identifier-naming,-warnings-as-errors)


[error] 96-96: variable name 'V' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 96-96: invalid case style for variable 'V'

(readability-identifier-naming,-warnings-as-errors)


[error] 99-99: variable 'gt' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 99-99: variable name 'gt' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 100-100: variable 'predict' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 115-115: method 'testRadixAttn' can be made static

(readability-convert-member-functions-to-static,-warnings-as-errors)


[error] 115-115: all parameters should be named in a function

(readability-named-parameter,-warnings-as-errors)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-android
🔇 Additional comments (13)
mllm/nn/layers/RoPE.hpp (1)

19-19: In-place attribute enabled — confirm alias safety and coverage.

Looks good. Please verify RoPE forward/op paths are alias-safe when inplace=true and add tests for non-contiguous-last-dim and mixed dtypes.

mllm/nn/lmcache/StaticCache.hpp (1)

86-88: Validate sub-slice bounds for K/V buffers.

Ensure start_idx_ and len_ lie within the sequence dimension of k_cache_[i]/v_cache_[i] to avoid runtime slicing errors.

mllm/backends/cpu/ops/RMSNormOp.cpp (1)

28-59: Incorrect review comment - x86 path is actually safe for in-place operations.

The x86::rmsnorm_fp32 implementation safely handles alias conditions. The loop increments by hn::Lanes(d) (full vector width), not by 1, ensuring non-overlapping memory access even when input and output pointers are identical. This contrasts with the ARM non-inplace variant, which has overlapping reads due to unit-stride loop increments and therefore requires a separate inplace implementation. The x86 algorithm structure already prevents the aliasing hazard, making an explicit _inplace variant unnecessary.

Likely an incorrect or invalid review comment.

mllm/core/aops/RMSNormOp.cpp (1)

50-52: Setup skip is fine; ensure kernel preconditions are validated elsewhere.

Since BaseOp::setup is skipped for inplace, make sure contiguity/dtype checks happen in the backend kernel path.

mllm/nn/layers/Linear.hpp (1)

25-25: LGTM! Redirect attribute enabled for Linear layer.

The macro invocation correctly enables redirect functionality for the Linear layer, consistent with the broader redirect infrastructure introduced in this PR.

mllm/core/aops/LinearOp.cpp (1)

107-110: LGTM! Redirect path in setup is correctly implemented.

The early return when redirect is enabled is appropriate and consistent with the pattern used in other operations.

mllm/core/aops/RMSNormOp.hpp (1)

36-36: LGTM! Accessor change enables dynamic option configuration.

The change from a const accessor to a non-const accessor is intentional and necessary to support dynamic configuration of inplace/redirect behavior at runtime, consistent with similar changes across other operation classes in this PR.

mllm/core/aops/RoPEOp.hpp (1)

36-36: LGTM! Consistent accessor change for dynamic configuration.

The accessor change is consistent with similar updates across other operation classes and enables dynamic configuration of inplace behavior, as evidenced by the corresponding implementation changes in RoPEOp.cpp.

mllm/core/aops/LinearOp.hpp (1)

144-144: LGTM! Accessor change aligns with redirect implementation.

The change to a non-const accessor is necessary to support the redirect functionality implemented in LinearOp.cpp, where options are checked and potentially modified at runtime.

mllm/core/aops/RoPEOp.cpp (1)

43-45: LGTM! Setup correctly handles inplace mode.

The conditional setup logic is appropriate and consistent with the pattern used in other operations.

mllm/nn/Layer.hpp (1)

92-99: LGTM! Inplace attribute macro correctly propagates options.

The macro properly sets the inplace flag in both the layer's public options and the underlying op's options, ensuring consistency. The const_cast is necessary here to enable post-construction modification of options.

mllm/backends/cpu/kernels/arm/rmsnorm.hpp (1)

18-20: Fix __restrict aliasing; macro suggestion is incorrect and should not be applied.

The __restrict concern on both X and Y parameters is valid—declaring both as restrict violates the contract if Y==X (the in-place case). Remove __restrict from X (or Y, but X is preferred).

However, the macro "fix" is incorrect. The review comment suggests changing to MLLM_HOST_ARCH, which is never defined in mllm/utils/CPUArchHelper.hpp. The canonical ARM macros are MLLM_HOST_ARCH_ARM64 and MLLM_HOST_ARCH_ARM. The current rmsnorm.hpp pattern (MLLM_HOST_ARCH_ARM64 || MLLM_HOST_ARCH_ARM) is correct; other files like FlashAttention2Op.cpp (line 39) and RadixAttnOp.cpp (line 31) are the inconsistent ones—they incorrectly use the undefined MLLM_HOST_ARCH. Applying the suggested diff would introduce dead code and break ARM support.

Apply only the restrict fix:

-void rmsnorm_fp32_inplace(const mllm_fp32_t* __restrict X, const mllm_fp32_t* __restrict W, mllm_fp32_t* __restrict Y, int D,
+void rmsnorm_fp32_inplace(const mllm_fp32_t* X, const mllm_fp32_t* __restrict W, mllm_fp32_t* Y, int D,
                           float epsilon, bool add_unit_offset, int thread_count);

-void rmsnorm_fp16_inplace(const mllm_fp16_t* __restrict X, const mllm_fp16_t* __restrict W, mllm_fp16_t* __restrict Y, int D,
+void rmsnorm_fp16_inplace(const mllm_fp16_t* X, const mllm_fp16_t* __restrict W, mllm_fp16_t* Y, int D,
                           float epsilon, bool add_unit_offset, int thread_count);

Do not change the macro at line 8; instead, fix FlashAttention2Op.cpp and RadixAttnOp.cpp to match the correct pattern.

Likely an incorrect or invalid review comment.

mllm/backends/cpu/kernels/common/fa2_1/impl-arm.hpp (1)

92-100: Remove this review comment—the aliasing concern is invalid.

The analysis shows acc_o (from __out buffer) and v_token (from __v buffer) are entirely distinct tensors. The FA2 algorithm never performs in-place updates; v_token is a read-only input from the KV cache while acc_o is the output accumulator. Both __restrict__ annotations are appropriate and safe given this guarantee from the caller. Removing either would be a performance regression with no correctness benefit.

Likely an incorrect or invalid review comment.

Comment thread mllm/backends/cpu/kernels/arm/rmsnorm.cpp
Comment thread mllm/backends/cpu/kernels/arm/rmsnorm.cpp
Comment on lines +5 to +7
#include <cassert>
#include "mllm/utils/Common.hpp"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix reserved identifiers and add safe fallback; adjust includes.

  • Types and template params starting with “__” are reserved; rename.
  • Empty primary templates make calls no-op if specialization not included—dangerous.
  • Include /<type_traits>; drop unused .

Apply:

-#include <cassert>
+#include <cstddef>
+#include <type_traits>
 #include "mllm/utils/Common.hpp"

 namespace mllm::cpu::flash_attn2::details {

-struct __AnyArchTag {};
-using any_arch_tag = __AnyArchTag;
+struct AnyArchTag {};
+using any_arch_tag = AnyArchTag;

-struct __X86ArchTag {};
-using x86_arch_tag = __X86ArchTag;
+struct X86ArchTag {};
+using x86_arch_tag = X86ArchTag;

-struct __ArmArchTag {};
-using arm_arch_tag = __ArmArchTag;
+struct ArmArchTag {};
+using arm_arch_tag = ArmArchTag;

-template<typename __ArgTag, typename __LhsDataType, typename __RhsDataType, typename __DstDataType>
+template<typename ArgTag, typename LhsDataType, typename RhsDataType, typename DstDataType>
 struct VectorDotProduct {
-  static MLLM_FORCE_INLINE void run(const __LhsDataType* __restrict__ __lhs, const __RhsDataType* __restrict__ __rhs,
-                                    __DstDataType* __out, size_t len) {}
+  template <typename T> struct dependent_false : std::false_type {};
+  static MLLM_FORCE_INLINE void run(const LhsDataType* MLLM_RESTRICT lhs,
+                                    const RhsDataType* MLLM_RESTRICT rhs,
+                                    DstDataType* MLLM_RESTRICT out, std::size_t len) {
+    static_assert(dependent_false<ArgTag>::value, "VectorDotProduct: missing specialization for this arch/data type.");
+  }
 };

-template<typename __ArgTag, typename __FromDataType, typename __constDataType>
+template<typename ArgTag, typename FromDataType, typename ConstDataType>
 struct MulFromConst {
-  static MLLM_FORCE_INLINE void run(__FromDataType* __restrict__ __from, const __constDataType const_v, size_t len) {}
+  template <typename T> struct dependent_false : std::false_type {};
+  static MLLM_FORCE_INLINE void run(FromDataType* MLLM_RESTRICT from, const ConstDataType const_v, std::size_t len) {
+    static_assert(dependent_false<ArgTag>::value, "MulFromConst: missing specialization for this arch/data type.");
+  }
 };

-template<typename ArchTag, typename T, typename U, typename V>
+template<typename ArchTag, typename T, typename U, typename V>
 struct FMAConstArray {
-  static MLLM_FORCE_INLINE void run(T* __restrict__ acc_o, const U acc_s, const V* __restrict__ v_token, size_t len) {}
+  template <typename X> struct dependent_false : std::false_type {};
+  static MLLM_FORCE_INLINE void run(T* MLLM_RESTRICT acc_o, const U acc_s, const V* MLLM_RESTRICT v_token, std::size_t len) {
+    static_assert(dependent_false<ArchTag>::value, "FMAConstArray: missing specialization for this arch/data type.");
+  }
 };

-template<typename ArchTag, typename T>
+template<typename ArchTag, typename T>
 struct FilledWithConst {
-  static MLLM_FORCE_INLINE void run(T* __restrict__ a, const T v, size_t len) {}
+  template <typename X> struct dependent_false : std::false_type {};
+  static MLLM_FORCE_INLINE void run(T* MLLM_RESTRICT a, const T v, std::size_t len) {
+    static_assert(dependent_false<ArchTag>::value, "FilledWithConst: missing specialization for this arch/data type.");
+  }
 };

If you prefer a functional fallback over compile-time error, I can provide scalar implementations instead.

Also applies to: 10-18, 19-38

🧰 Tools
🪛 Clang (14.0.6)

[error] 5-5: 'cassert' file not found

(clang-diagnostic-error)

🤖 Prompt for AI Agents
In mllm/backends/cpu/kernels/common/fa2_1/arch.hpp around lines 5-7, replace the
reserved identifier usage (any types, template parameters, or identifiers
beginning with "__") with non-reserved names, remove the unused #include
<cassert>, and add standard headers <cstddef> and <type_traits>; then replace
empty primary templates with either a safe compile-time trap
(static_assert(false, ...) guarded by dependent condition) or provide a
functional scalar fallback implementation so calls without a specialization
don't silently become no-ops, and ensure all renamed identifiers are
consistently adjusted throughout the file.

Comment on lines +25 to +29
template<typename __ArchTag, typename __QDType, typename __KDType, typename __VDType, typename __ODType, typename __AccDType,
bool high_precession_exp = true>
void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
__KDType* __k, __VDType* __v, __ODType* __restrict__ __out, int32_t thread_count) {
int32_t head_repeat_times = H_Q / H_KV;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Reserved identifiers and K/V should be const; also assert divisibility H_Q % H_KV == 0.

  • Names starting with double underscore are reserved; rename template/param identifiers.
  • K/V are not modified; make them const to match call sites and avoid casting away const.
  • Add an explicit divisibility assertion; code uses integer division to map heads.
-template<typename __ArchTag, typename __QDType, typename __KDType, typename __VDType, typename __ODType, typename __AccDType,
-         bool high_precession_exp = true>
-void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
-              __KDType* __k, __VDType* __v, __ODType* __restrict__ __out, int32_t thread_count) {
-  int32_t head_repeat_times = H_Q / H_KV;
+template<typename ArchTag, typename QDType, typename KDType, typename VDType, typename ODType, typename AccDType,
+         bool high_precision_exp = true>
+void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D,
+              const QDType* __restrict q, const KDType* __restrict k, const VDType* __restrict v,
+              ODType* __restrict out, int32_t thread_count) {
+  MLLM_RT_ASSERT(H_KV > 0 && H_Q % H_KV == 0);
+  int32_t head_repeat_times = H_Q / H_KV;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
template<typename __ArchTag, typename __QDType, typename __KDType, typename __VDType, typename __ODType, typename __AccDType,
bool high_precession_exp = true>
void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
__KDType* __k, __VDType* __v, __ODType* __restrict__ __out, int32_t thread_count) {
int32_t head_repeat_times = H_Q / H_KV;
template<typename ArchTag, typename QDType, typename KDType, typename VDType, typename ODType, typename AccDType,
bool high_precision_exp = true>
void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D,
const QDType* __restrict q, const KDType* __restrict k, const VDType* __restrict v,
ODType* __restrict out, int32_t thread_count) {
MLLM_RT_ASSERT(H_KV > 0 && H_Q % H_KV == 0);
int32_t head_repeat_times = H_Q / H_KV;
🧰 Tools
🪛 Clang (14.0.6)

[error] 25-25: declaration uses identifier '__ArchTag', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__QDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__KDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__VDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__ODType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 25-25: declaration uses identifier '__AccDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 27-27: 6 adjacent parameters of 'fwd_bhsd' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 27-27: parameter name 'B' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 27-27: parameter name 'D' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 27-27: declaration uses identifier '__q', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 28-28: declaration uses identifier '__k', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 28-28: declaration uses identifier '__v', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 28-28: declaration uses identifier '__out', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 29-29: variable 'head_repeat_times' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)

🤖 Prompt for AI Agents
In mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp around lines 25-29,
rename any template type/parameter identifiers that begin with double
underscores to non-reserved names (e.g., __ArchTag -> ArchTag, __QDType ->
QDType, etc.), change the k and v pointer parameters to const (__KDType const*
k, __VDType const* v) to reflect they are not modified, and add an explicit
assertion (e.g., assert(H_Q % H_KV == 0)) before computing head_repeat_times to
ensure H_Q is divisible by H_KV. Ensure call sites and forward declarations are
updated to match the new parameter names and const correctness.

__KDType* __k, __VDType* __v, __ODType* __restrict__ __out, int32_t thread_count) {
int32_t head_repeat_times = H_Q / H_KV;

__AccDType scale = scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Compile-time bug: double assignment “scale = scale = …”.

Fix and keep types consistent.

-  __AccDType scale = scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
+  AccDType scale = static_cast<AccDType>(std::sqrt(1.0 / static_cast<double>(D))) *
+                   static_cast<AccDType>(std::numbers::log2e);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
__AccDType scale = scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
__AccDType scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
🤖 Prompt for AI Agents
In mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp around line 31, remove
the accidental double assignment ("scale = scale = ...") and make the arithmetic
use consistent types: compute the sqrt on an __AccDType value and cast
std::numbers::log2e to __AccDType before multiplying so the expression is a
single initialization of scale with matching types (e.g., cast the 1.0/D or the
result of std::numbers::log2e to __AccDType and avoid the duplicated "scale ="
token).

Comment on lines +20 to +21
FlashAttn2Module(int H_Q, int H_KV) : nn::Module() {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Constructor params unused.

Either remove the (H_Q, H_KV) ctor or store/validate them. Keeping unused params in tests is harmless but noisy.

-  FlashAttn2Module(int H_Q, int H_KV) : nn::Module() {}
+  explicit FlashAttn2Module() = default;
🧰 Tools
🪛 Clang (14.0.6)

[error] 20-20: 2 adjacent parameters of 'FlashAttn2Module' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)

🤖 Prompt for AI Agents
In tests/cpu/FlashAttentionKernelTest.hpp around lines 20-21, the
FlashAttn2Module constructor declares parameters (int H_Q, int H_KV) but does
not use or store them; remove the unused parameters from the ctor signature or
add member variables to store/validate H_Q and H_KV and initialize them in the
constructor, updating any test code that constructs FlashAttn2Module to match
the new signature if you remove params.

Comment on lines +45 to 47
K = K.transpose(1, 2).repeat(h_q / h_kv, 1);
V = V.transpose(1, 2).repeat(h_q / h_kv, 1);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Head expansion: verify repeat semantics target the head axis.

Ensure repeat(h_q / h_kv, 1) repeats along the head dimension, not batch. If API expects per-axis repeats, consider an explicit API:

If needed:

-    K = K.transpose(1, 2).repeat(h_q / h_kv, 1);
-    V = V.transpose(1, 2).repeat(h_q / h_kv, 1);
+    K = K.transpose(1, 2).repeat({1, h_q / h_kv, 1, 1});  // repeat on head axis
+    V = V.transpose(1, 2).repeat({1, h_q / h_kv, 1, 1});

🏁 Script executed:

#!/bin/bash
# Inspect repeat signatures and call sites
rg -nP 'repeat\s*\(' -C2 --type=cpp

Length of output: 3194


I need to verify the repeat implementation and confirm the axis semantics.

Now let me check the implementation to confirm the repeat semantics:

Let me inspect the actual repeat implementation to confirm the dimension semantics:

Based on my verification of the code and repeat semantics, I found a critical issue with the dimensions used in the repeat calls.

Incorrect repeat dimension: change dim=1 to dim=2 for head expansion.

The current code transposes K and V from [B, H, S, D] to [B, S, H, D], then calls .repeat(h_q / h_kv, 1). This repeats along dimension 1 (the sequence axis), not the head dimension. After transpose, the head dimension is at index 2, not 1.

The repeat API signature is repeat(multiplier, dim), not per-dimension repeats. The fix should be:

-    K = K.transpose(1, 2).repeat(h_q / h_kv, 1);
-    V = V.transpose(1, 2).repeat(h_q / h_kv, 1);
+    K = K.transpose(1, 2).repeat(h_q / h_kv, 2);
+    V = V.transpose(1, 2).repeat(h_q / h_kv, 2);

This bug also appears in tests/cpu/RadixAttnKernel.hpp (lines 47-48), mllm/nn/lmcache/DynamicCache.cpp (lines 23-24), and test files using the same pattern.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
K = K.transpose(1, 2).repeat(h_q / h_kv, 1);
V = V.transpose(1, 2).repeat(h_q / h_kv, 1);
K = K.transpose(1, 2).repeat(h_q / h_kv, 2);
V = V.transpose(1, 2).repeat(h_q / h_kv, 2);
🤖 Prompt for AI Agents
In tests/cpu/FlashAttentionKernelTest.hpp around lines 45-47, the repeat calls
after transposing K and V use dim=1 which repeats the sequence axis; change the
repeat axis to dim=2 so the head dimension is expanded (i.e., call repeat(h_q /
h_kv, 2) for both K and V). Apply the same correction in
tests/cpu/RadixAttnKernel.hpp (lines ~47-48), mllm/nn/lmcache/DynamicCache.cpp
(lines ~23-24), and any other tests/files using transpose(...).repeat(h_q /
h_kv, 1) to ensure the head dimension is the one being repeated.

Comment on lines +53 to +64
auto S_Q = Q.shape()[2];
auto S_KV = K.shape()[2];
auto mask = Tensor::zeros({1, 1, S_Q, S_KV});
{
auto weight = mllm::nn::functional::matmul(Q, K, false, true);
// TODO mask weight
weight = mllm::nn::functional::softmax(weight, -1);
ref_output = mllm::nn::functional::matmul(weight, V);
auto ptr = mask.ptr<float>();
int __delta = S_KV - S_Q;
for (int s_q_idx = 0; s_q_idx < S_Q; s_q_idx++) {
int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV);
for (int s_kv_idx = S_KV_BOUND; s_kv_idx < S_KV; s_kv_idx++) {
ptr[s_q_idx * S_KV + s_kv_idx] = -std::numeric_limits<float>::infinity();
}
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Causal mask indexing assumes B=H=1.

Safe for these tests, but brittle. Consider broadcasting to [B,H,S_Q,S_KV] or using a helper.

-    auto mask = Tensor::zeros({1, 1, S_Q, S_KV});
+    auto mask = Tensor::zeros({Q.shape()[0], Q.shape()[1], S_Q, S_KV}); // or a causalMask helper
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
auto S_Q = Q.shape()[2];
auto S_KV = K.shape()[2];
auto mask = Tensor::zeros({1, 1, S_Q, S_KV});
{
auto weight = mllm::nn::functional::matmul(Q, K, false, true);
// TODO mask weight
weight = mllm::nn::functional::softmax(weight, -1);
ref_output = mllm::nn::functional::matmul(weight, V);
auto ptr = mask.ptr<float>();
int __delta = S_KV - S_Q;
for (int s_q_idx = 0; s_q_idx < S_Q; s_q_idx++) {
int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV);
for (int s_kv_idx = S_KV_BOUND; s_kv_idx < S_KV; s_kv_idx++) {
ptr[s_q_idx * S_KV + s_kv_idx] = -std::numeric_limits<float>::infinity();
}
}
auto S_Q = Q.shape()[2];
auto S_KV = K.shape()[2];
auto mask = Tensor::zeros({Q.shape()[0], Q.shape()[1], S_Q, S_KV}); // or a causalMask helper
{
auto ptr = mask.ptr<float>();
int __delta = S_KV - S_Q;
for (int s_q_idx = 0; s_q_idx < S_Q; s_q_idx++) {
int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV);
for (int s_kv_idx = S_KV_BOUND; s_kv_idx < S_KV; s_kv_idx++) {
ptr[s_q_idx * S_KV + s_kv_idx] = -std::numeric_limits<float>::infinity();
}
}
🧰 Tools
🪛 Clang (14.0.6)

[error] 53-53: invalid case style for variable 'S_Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 54-54: invalid case style for variable 'S_KV'

(readability-identifier-naming,-warnings-as-errors)


[error] 58-58: declaration uses identifier '__delta', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 58-58: variable '__delta' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 58-58: invalid case style for variable '__delta'

(readability-identifier-naming,-warnings-as-errors)

Comment on lines +102 to +111
// Compare
// rtol and atol set to 1e-2f is because:
// 1. The eager softmax is approximate, but radix is not.
auto result = test::allClose(gt, predict, 1e-2f, 1e-2f);
if (!result) {
print(result);
print("S_Q and S_KV is", S_Q, S_KV);
print(predict);
return false;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Tighten failure diagnostics.

Add max abs/rel error to ease debugging.

-    if (!result) {
-      print(result);
+    if (!result) {
+      auto diff = (gt - predict).abs();
+      print("max_abs_err:", diff.max().item<float>());
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Compare
// rtol and atol set to 1e-2f is because:
// 1. The eager softmax is approximate, but radix is not.
auto result = test::allClose(gt, predict, 1e-2f, 1e-2f);
if (!result) {
print(result);
print("S_Q and S_KV is", S_Q, S_KV);
print(predict);
return false;
}
// Compare
// rtol and atol set to 1e-2f is because:
// 1. The eager softmax is approximate, but radix is not.
auto result = test::allClose(gt, predict, 1e-2f, 1e-2f);
if (!result) {
auto diff = (gt - predict).abs();
print("max_abs_err:", diff.max().item<float>());
print("S_Q and S_KV is", S_Q, S_KV);
print(predict);
return false;
}
🤖 Prompt for AI Agents
In tests/cpu/FlashAttentionKernelTest.hpp around lines 102 to 111, the failure
branch only prints the boolean result, S_Q/S_KV and the predicted tensor;
enhance diagnostics by computing and printing the maximum absolute error and
maximum relative error between gt and predict before returning false. Implement
this by iterating over elements (or using existing tensor utilities) to compute
abs_diff = fabs(gt[i] - predict[i]) and rel_diff = abs_diff / (fabs(gt[i]) +
eps) with a small eps (e.g., 1e-8), track the maximum of each, and include those
max values in the print output alongside S_Q, S_KV and the tensors so failures
show max abs and max rel errors for easier debugging.

Comment thread tests/cpu/KernelTest.cpp
Comment on lines +810 to +839
//===----------------------------------------------------------------------===//
// Radix Attn
//===----------------------------------------------------------------------===//
#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)
#include "FlashAttentionKernelTest.hpp"
TEST_F(FlashAttn2KernelTest, fwd_bshd) {
EXPECT_EQ(testRadixAttn({{
{"H_Q", 28},
{"H_KV", 2},
{"S_Q", 10},
{"S_KV", 10},
{"D", 128},
},
{
{"H_Q", 28},
{"H_KV", 2},
{"S_Q", 10},
{"S_KV", 20},
{"D", 128},
},
{
{"H_Q", 28},
{"H_KV", 2},
{"S_Q", 1},
{"S_KV", 20},
{"D", 128},
}}),
true);
}
#endif
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify test implementation - potential copy-paste issue.

The test block is named FlashAttn2KernelTest but:

  1. The comment still says "Radix Attn" instead of "FlashAttn2"
  2. It calls testRadixAttn rather than a FlashAttn2-specific test function

If testRadixAttn is intentionally reused to test FlashAttn2 functionality, this should be clarified with a comment explaining why. Otherwise, this appears to be a copy-paste error and should call an appropriate FlashAttn2 test function.

Run the following script to check if there's a FlashAttn2-specific test function available:


🏁 Script executed:

#!/bin/bash
# Search for FlashAttn2 or FA2 test functions in the test header
rg -n "test.*[Ff]lash.*[Aa]ttn2|test.*FA2" tests/cpu/FlashAttentionKernelTest.hpp

Length of output: 0


I need to broaden the investigation since the first search didn't produce output. Let me verify the file structure and look for the test functions more comprehensively.


🏁 Script executed:

#!/bin/bash
# Check if the header file exists and list its structure
echo "=== Checking file existence ==="
ls -la tests/cpu/FlashAttentionKernelTest.hpp 2>&1

echo ""
echo "=== Searching for all test function definitions ==="
rg -n "^[a-zA-Z_].*test.*\(" tests/cpu/FlashAttentionKernelTest.hpp 2>&1

echo ""
echo "=== Searching for testRadixAttn definition ==="
rg -n "testRadixAttn" tests/cpu/ 2>&1

echo ""
echo "=== Checking for FlashAttn2 references in the header ==="
rg -n "FlashAttn2\|Attn2" tests/cpu/FlashAttentionKernelTest.hpp 2>&1

Length of output: 1069


Good findings! The file exists and testRadixAttn is defined in the header. Let me examine the actual implementation to understand whether this is intentional reuse or a naming issue.


🏁 Script executed:

#!/bin/bash
# Examine testRadixAttn implementation in FlashAttentionKernelTest.hpp
echo "=== testRadixAttn implementation in FlashAttentionKernelTest.hpp ==="
rg -n -A 20 "bool testRadixAttn\(" tests/cpu/FlashAttentionKernelTest.hpp

echo ""
echo "=== Check for any comments or context around testRadixAttn ==="
rg -n -B 5 -A 15 "bool testRadixAttn\(" tests/cpu/FlashAttentionKernelTest.hpp

echo ""
echo "=== Check class definition and inheritance in FlashAttentionKernelTest.hpp ==="
rg -n -A 5 "class.*KernelTest\|class.*FlashAttn" tests/cpu/FlashAttentionKernelTest.hpp

echo ""
echo "=== List all test functions (TEST_F) in KernelTest.cpp around line 810 ==="
rg -n -B 2 -A 10 "TEST_F\(FlashAttn2KernelTest" tests/cpu/KernelTest.cpp

Length of output: 1410


I need to find where FlashAttn2KernelTest is actually defined to understand if it's intentionally reusing RadixAttn test functions.


🏁 Script executed:

#!/bin/bash
# Search for FlashAttn2KernelTest class definition
echo "=== Search for FlashAttn2KernelTest class definition ==="
rg -n "class FlashAttn2KernelTest" tests/cpu/

echo ""
echo "=== Search for any class that includes testRadixAttn ==="
rg -n -B 3 "testRadixAttn" tests/cpu/FlashAttentionKernelTest.hpp | head -30

echo ""
echo "=== Look at the entire FlashAttentionKernelTest.hpp to understand structure ==="
cat tests/cpu/FlashAttentionKernelTest.hpp

Length of output: 4120


Update comment to reflect correct functionality; clarify misleading naming.

The code is functionally correct—the test functions inside FlashAttn2KernelTest do properly test FlashAttn2 (see implementation at line 82-111 in FlashAttentionKernelTest.hpp, which creates FlashAttn2Module and compares against eager attention). However, the naming is misleading:

  • The comment "// Radix Attn" at line 810 should be "// FlashAttn2"
  • The function names testRadixAttn and testRadixAttnOnce are confusing since they test FlashAttn2, not Radix attention

Consider renaming these functions to testFlashAttn2 or adding a clarifying comment explaining why RadixAttn-named functions test FlashAttn2.

🤖 Prompt for AI Agents
In tests/cpu/KernelTest.cpp around lines 810 to 839, update the misleading
comment and naming: change the comment "// Radix Attn" to "// FlashAttn2" and
either rename the helper test functions testRadixAttn/testRadixAttnOnce to
testFlashAttn2/testFlashAttn2Once (and update all call sites) or add a short
clarifying comment above the helpers explaining that these Radix-named helpers
actually exercise FlashAttn2 (pointing to FlashAttentionKernelTest.hpp
implementation), so test names/comments accurately reflect that they verify
FlashAttn2 behavior.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant