feat(cpu): add inplace rmsnorm implementations for fp32 and fp16 (#483)
chenghuaWang merged 2 commits into UbiquitousLearning:v2
Conversation
- Added `rmsnorm_fp32_inplace` and `rmsnorm_fp16_inplace` functions in ARM kernels
- Updated RMSNormOp to support inplace operations using the new kernel functions
- Modified LinearOp and related classes to support tensor redirection
- Enhanced FlashAttention2Op with updated kernel includes and input handling
- Added new test cases for FlashAttention2 with improved accuracy checks
- Fixed contiguous tensor assertions in RMSNorm and RoPE operations
- Extended Layer macros to support redirect attribute for ops
- Updated StaticCache with new methods for KV cache management
- Improved FA2 kernel tests with radix attention support and better validation
Walkthrough
This pull request introduces Flash Attention 2 (FA2) optimizations with architecture-specific implementations, adds in-place operation support to RMSNorm and RoPE kernels, implements redirect mode for Linear operations, adds the Qwen3 FA2 model, and refactors cache management to support efficient KV storage. Multiple operation APIs are updated to allow mutable option access.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant RMSNormOp
    participant BaseOp
    participant Kernel as ARM/x86 Kernel
    Caller->>RMSNormOp: forward()
    alt In-place Mode (isInplace == true)
        RMSNormOp->>Kernel: rmsnorm_fp32_inplace(X, W, Y, ...)
        Kernel->>Kernel: 1. Compute RMS from X
        Kernel->>Kernel: 2. Scale X by RMS
        Kernel->>Kernel: 3. Apply weight W + offset
        Kernel-->>RMSNormOp: Result in Y
    else Non-in-place Mode
        RMSNormOp->>BaseOp: setup(inputs, outputs)
        RMSNormOp->>Kernel: rmsnorm_fp32(X, Y, W, ...)
        Kernel-->>RMSNormOp: Result
    end
    RMSNormOp-->>Caller: Output tensor
```
```mermaid
sequenceDiagram
    participant Caller
    participant FlashAttn2Op
    participant ArchDispatch
    participant ArmKernel as ARM NEON
    participant AnyKernel as Generic
    Caller->>FlashAttn2Op: forward(Q, K, V)
    FlashAttn2Op->>ArchDispatch: Detect architecture
    alt ARM64/ARM Architecture
        ArchDispatch->>ArmKernel: fwd_bhsd<__ArmArchTag>
        ArmKernel->>ArmKernel: VectorDotProduct (NEON 16-wide)
        ArmKernel->>ArmKernel: MulFromConst (NEON)
        ArmKernel->>ArmKernel: FMAConstArray (NEON)
        ArmKernel-->>FlashAttn2Op: Attention output
    else X86/Generic
        ArchDispatch->>AnyKernel: fwd_bhsd<__X86ArchTag>
        AnyKernel->>AnyKernel: VectorDotProduct (scalar loop)
        AnyKernel->>AnyKernel: MulFromConst (scalar loop)
        AnyKernel-->>FlashAttn2Op: Attention output
    end
    FlashAttn2Op-->>Caller: Output
```
```mermaid
sequenceDiagram
    participant Caller
    participant LinearOp
    participant Options
    Caller->>LinearOp: reshape()
    LinearOp->>Options: isRedirect()
    alt Redirect Mode
        Options-->>LinearOp: true
        LinearOp->>LinearOp: outputs.emplace_back(inputs[1])
        LinearOp-->>Caller: Return (bypass standard reshaping)
    else Normal Mode
        Options-->>LinearOp: false
        LinearOp->>LinearOp: Standard reshape logic
        LinearOp-->>Caller: Reshaped output
    end
```
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~60 minutes

Rationale: This PR encompasses substantial changes across multiple systems: (1) a new Flash Attention 2 architecture-agnostic template framework with ARM NEON specializations requiring careful verification of SIMD correctness; (2) comprehensive in-place operation support touching RMSNorm, RoPE, and related infrastructure; (3) a significant Qwen3 model implementation (~400+ lines of new model code); (4) API changes affecting option accessors across multiple operation classes; (5) cache management refactoring with dual-path logic (FA2 vs. eager). While many changes follow consistent patterns, the heterogeneity of concerns (kernel optimization, operation dispatch, model architecture, and cache management) plus the density of SIMD/template logic necessitates careful, multi-faceted review.
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 25
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
mllm/core/aops/ViewOp.cpp (2)
55-56: Avoid overflow when computing new_numel.
`std::accumulate` starts with an `int` seed (1), risking overflow on large shapes. Accumulate into int64. Apply:

```diff
- const int64_t new_numel = std::accumulate(actual_shape.begin(), actual_shape.end(), 1, std::multiplies<>());
+ const int64_t new_numel = std::accumulate(
+     actual_shape.begin(), actual_shape.end(), int64_t{1}, std::multiplies<int64_t>());
```
49-53: Handle zero‑element tensors and guard inferred dim overflow.
Current logic rejects inference when `product == 0` even if `input_numel == 0`, and it doesn't ensure the inferred size fits in int32.

```diff
- if (infer_dim != -1) {
-   MLLM_RT_ASSERT(product != 0 && "cannot infer dimension for a shape with zero product");
-   MLLM_RT_ASSERT(input_numel % product == 0 && "input tensor size does not match inferred shape");
-   actual_shape[infer_dim] = static_cast<int32_t>(input_numel / product);
- }
+ if (infer_dim != -1) {
+   if (product == 0) {
+     MLLM_RT_ASSERT(input_numel == 0 && "cannot infer -1 when other dims multiply to 0 unless numel is 0");
+     actual_shape[infer_dim] = 0;
+   } else {
+     MLLM_RT_ASSERT(input_numel % product == 0 && "input tensor size does not match inferred shape");
+     const int64_t inferred = input_numel / product;
+     MLLM_RT_ASSERT(inferred <= std::numeric_limits<int32_t>::max() &&
+                    "inferred dimension exceeds int32 range");
+     actual_shape[infer_dim] = static_cast<int32_t>(inferred);
+   }
+ }
```

Note: requires `#include <limits>`.
🧹 Nitpick comments (17)
mllm/core/aops/ViewOp.cpp (2)
71-80: Treat empty tensors as contiguous; prefer central contiguity API.
The manual check can misclassify edge cases; for `numel() == 0`, mark contiguous up‑front, or defer to a canonical `isContiguous()` if available.

```diff
- bool is_contiguous = true;
+ bool is_contiguous = true;
+ if (it.numel() == 0) {
+   is_contiguous = true;
+ } else
  int64_t current_stride = 1;
```

If `TensorImpl` has `isContiguous()`, use that instead of reimplementing the check.
91-93: Either implement view‑stride computation or improve the failure path.
Right now non‑contiguous inputs always hard‑fail. Two options:

- Implement a `compute_stride_for_view` (PyTorch‑style) to allow more reshape patterns without copy.
- If deferring, enrich the assert to aid debugging (print shape/stride), or fall back to a safe copy/redirect path when policy permits.

```diff
- // FIXME: more stride logic such as `compute_stride_for_view` in PyTorch
- MLLM_ASSERT_EXIT(is_contiguous, "ViewOp::reshape is only supported for contiguous tensors in this implementation");
+ // TODO: support non-contiguous view by computing new strides when feasible
+ MLLM_ASSERT_EXIT(
+     is_contiguous,
+     fmt::format("ViewOp::reshape requires contiguous tensor. orig_shape={}, orig_stride={}, requested_shape={}",
+                 fmt::join(orig_shape, ","), fmt::join(orig_stride, ","), fmt::join(actual_shape, ",")));
```

If a fallback copy is acceptable, consider redirecting to a contiguous materialization before reshaping.
mllm/nn/lmcache/StaticCache.hpp (1)
48-55: Harden KV accessors: add bounds checks and [[nodiscard]].
Prevent OOB on invalid `layer_idx` and encourage use of return values. Apply:

```diff
- std::array<Tensor, 2> getKVCache(int32_t layer_idx);
+ [[nodiscard]] std::array<Tensor, 2> getKVCache(int32_t layer_idx);
- std::array<Tensor, 2> preGetKVWriteLocation(int32_t layer_idx, int32_t s);
+ [[nodiscard]] std::array<Tensor, 2> preGetKVWriteLocation(int32_t layer_idx, int32_t s);
- [[nodiscard]] inline Tensor getKCacheBuffer(int32_t layer_idx) const { return k_cache_[layer_idx]; };
+ [[nodiscard]] inline Tensor getKCacheBuffer(int32_t layer_idx) const {
+   MLLM_RT_ASSERT(layer_idx >= 0 && layer_idx < static_cast<int32_t>(k_cache_.size()));
+   return k_cache_[layer_idx];
+ };
- [[nodiscard]] inline Tensor getVCacheBuffer(int32_t layer_idx) const { return v_cache_[layer_idx]; };
+ [[nodiscard]] inline Tensor getVCacheBuffer(int32_t layer_idx) const {
+   MLLM_RT_ASSERT(layer_idx >= 0 && layer_idx < static_cast<int32_t>(v_cache_.size()));
+   return v_cache_[layer_idx];
+ };
```

Additionally, ensure `<array>` is explicitly included at the top of this header to avoid relying on transitive includes: `#include <array>`

mllm/core/BaseOp.hpp (1)
70-77: Define `redirect` vs `inplace` semantics (preferably mutually exclusive).
If both are set, behavior is ambiguous. Either enforce mutual exclusion in setters or document precedence (e.g., `redirect` wins during reshape). Example guard:

```cpp
inline void setInplace(bool v) {
  // if both are true, decide policy or assert
  // MLLM_RT_ASSERT(!(v && redirect_), "inplace and redirect cannot both be true");
  inplace_ = v;
}
inline void setRedirect(bool v) {
  // MLLM_RT_ASSERT(!(v && inplace_), "redirect and inplace cannot both be true");
  redirect_ = v;
}
```

mllm/backends/cpu/ops/RMSNormOp.cpp (1)
31-41: Avoid nested parallelism (outer loop + threaded kernels).
You already parallelize across `other_dim`. Passing `options_.getThreads()` into per-chunk kernels can oversubscribe threads. Set inner-kernel threads to 1 (or teach kernels to be thread-agnostic when called inside a parallel region).

Also applies to: 45-54
mllm/nn/lmcache/StaticCache.cpp (1)
145-159: Make getKVCache const for safer API.
This method does not mutate state; mark it const (and update the declaration in StaticCache.hpp) to prevent accidental writes via non-const access.

```diff
-std::array<Tensor, 2> StaticCache::getKVCache(int32_t layer_idx) {
+std::array<Tensor, 2> StaticCache::getKVCache(int32_t layer_idx) const {
```
4-12: Include `<cmath>` for std::pow/sin/cos.
Avoid relying on transitive includes.

```diff
 #include "mllm/models/ARGeneration.hpp"
+#include <cmath>
```
112-121: Silence unused member or use it.
`num_key_value_groups_` is computed but unused; mark it `[[maybe_unused]]` or use it.

```diff
- int num_key_value_groups_;
+ [[maybe_unused]] int num_key_value_groups_;
```
42-47: Clear outputs before emplacing to avoid accumulation on repeated reshape calls.
Some executors call reshape multiple times; ensure outputs has exactly one tensor.

```diff
 void RMSNormOp::reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
-  if (options_.isInplace()) {
-    outputs.emplace_back(inputs[0]);
+  outputs.clear();
+  if (options_.isInplace()) {
+    outputs.emplace_back(inputs[0]);
   } else {
     const auto& i = inputs[0];
     outputs.emplace_back(Tensor::empty(i.shape(), i.dtype(), i.device()));
   }
 }
```

Confirm your executor never relies on outputs' previous capacity.
mllm/core/aops/SiLUOp.hpp (1)
27-28: Provide a const accessor too; document mutation semantics.
Keep mutable access but add a const overload to avoid unnecessary escapes of internal state.

```diff
 inline SiLUOpOptions& options() { return options_; }
+inline const SiLUOpOptions& options() const { return options_; }
```

Please confirm other AOPs expose both overloads for consistency.
mllm/nn/Layer.hpp (1)
101-108: LGTM! Redirect attribute macro follows established pattern.
The redirect macro correctly mirrors the inplace macro's structure, propagating the redirect flag to both layer and op option levels. This ensures consistent state across the layer abstraction boundary.
The static analysis suggestion to use constexpr template functions instead of macros is worth considering for improved type safety and debugging, though it would require broader refactoring of the layer attribute system.
mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp (2)
66-82: K/V/Q pointer qualifiers and type names after rename.
Adjust calls to detail helpers to the new type names and const-correctness; no behavioral change.

```diff
-  __AccDType acc_s;
-  details::VectorDotProduct<__ArchTag, __QDType, __KDType, __AccDType>::run(q_token, k_token, &acc_s, D);
+  AccDType acc_s;
+  details::VectorDotProduct<ArchTag, QDType, KDType, AccDType>::run(q_token, k_token, &acc_s, D);
   ...
-  details::MulFromConst<__ArchTag, __AccDType, __AccDType>::run(acc_o, scores_scale, D);
+  details::MulFromConst<ArchTag, AccDType, AccDType>::run(acc_o, scores_scale, D);
   ...
-  details::FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>::run(acc_o, acc_s, v_token, D);
+  details::FMAConstArray<ArchTag, AccDType, AccDType, AccDType>::run(acc_o, acc_s, v_token, D);
```
14-18: Include path for generic SIMD looks correct; confirm file name.
Ensure `impl-any-simd.hpp` exists with that exact extension (not `.h`); if not, fix the include.
mllm/backends/cpu/kernels/common/fa2_1/impl-arm.hpp (3)
33-37: Dot-product ILP: use multiple accumulators.
The current single-accumulator FMA chain is dependency-heavy. Four independent accumulators improve ILP and throughput on many cores.

Example:

```cpp
float32x4_t s0 = vdupq_n_f32(0), s1 = s0, s2 = s0, s3 = s0;
// inside 16-wide loop:
s0 = MLLM_NEON_FMA(s0, lhs_vec0, rhs_vec0);
s1 = MLLM_NEON_FMA(s1, lhs_vec1, rhs_vec1);
s2 = MLLM_NEON_FMA(s2, lhs_vec2, rhs_vec2);
s3 = MLLM_NEON_FMA(s3, lhs_vec3, rhs_vec3);
// after loops:
float result = MLLM_NEON_HADD(vaddq_f32(vaddq_f32(s0, s1), vaddq_f32(s2, s3)));
```
68-73: “FIXME” is misleading; MUL is appropriate here.
For pure scaling, `vmulq_f32` is correct. Switching to FMA would add an unnecessary zero addend and may slow down on some µarches. Reword or remove the comment.

```diff
- // FIXME: FMA may be faster than MUL
+ // NOTE: MUL is the appropriate choice for pure scaling by const_v.
```
14-15: Adjacent same-typed pointers are easy to swap at call sites.
`(__lhs, __rhs, __out, len)` are similar pointer types; easy to misuse at call sites. Options:

- Keep the signature (to match the base template) but add doxygen-style param docs and a NOLINT for the static analyzer.
- Alternatively, wrap inputs as `struct { const T* lhs; const T* rhs; }` to force named construction at call sites.

If desired, I can add parameter docs and a local `// NOLINT(bugprone-easily-swappable-parameters)` to silence the warning where needed.

mllm/backends/cpu/kernels/common/fa2_1/impl-any.hpp (1)
12-16: Same-type pointer adjacency; consider making swaps harder.
As in the ARM impl, `(__lhs, __rhs, __out, len)` invites accidental swaps at call sites.

- Add param docs and a local NOLINT for the analyzer; or
- Consider `std::span<const T>` for `lhs`/`rhs` in the generic path to increase type safety (if acceptable in this layer).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (26)
- mllm/backends/cpu/kernels/arm/rmsnorm.cpp (2 hunks)
- mllm/backends/cpu/kernels/arm/rmsnorm.hpp (1 hunks)
- mllm/backends/cpu/kernels/common/fa2_1/arch.hpp (1 hunks)
- mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp (1 hunks)
- mllm/backends/cpu/kernels/common/fa2_1/impl-any.hpp (1 hunks)
- mllm/backends/cpu/kernels/common/fa2_1/impl-arm.hpp (1 hunks)
- mllm/backends/cpu/ops/FlashAttention2Op.cpp (2 hunks)
- mllm/backends/cpu/ops/LinearOp.cpp (1 hunks)
- mllm/backends/cpu/ops/RMSNormOp.cpp (1 hunks)
- mllm/core/BaseOp.hpp (1 hunks)
- mllm/core/aops/LinearOp.cpp (2 hunks)
- mllm/core/aops/LinearOp.hpp (1 hunks)
- mllm/core/aops/RMSNormOp.cpp (1 hunks)
- mllm/core/aops/RMSNormOp.hpp (1 hunks)
- mllm/core/aops/RoPEOp.cpp (1 hunks)
- mllm/core/aops/RoPEOp.hpp (1 hunks)
- mllm/core/aops/SiLUOp.hpp (1 hunks)
- mllm/core/aops/ViewOp.cpp (1 hunks)
- mllm/models/qwen3/modeling_qwen3_fa2.hpp (1 hunks)
- mllm/nn/Layer.hpp (1 hunks)
- mllm/nn/layers/Linear.hpp (1 hunks)
- mllm/nn/layers/RoPE.hpp (1 hunks)
- mllm/nn/lmcache/StaticCache.cpp (1 hunks)
- mllm/nn/lmcache/StaticCache.hpp (2 hunks)
- tests/cpu/FlashAttentionKernelTest.hpp (1 hunks)
- tests/cpu/KernelTest.cpp (1 hunks)
🧰 Additional context used
🪛 Clang (14.0.6)
mllm/core/aops/SiLUOp.hpp
[error] 30-30: member variable 'options_' has protected visibility
(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)
mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp
[error] 6-6: 'cmath' file not found
(clang-diagnostic-error)
[error] 25-25: declaration uses identifier '__ArchTag', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__QDType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__KDType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__VDType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__ODType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__AccDType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 27-27: 6 adjacent parameters of 'fwd_bhsd' of similar type ('int') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 27-27: parameter name 'B' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 27-27: parameter name 'D' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 27-27: declaration uses identifier '__q', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 28-28: declaration uses identifier '__k', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 28-28: declaration uses identifier '__v', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 28-28: declaration uses identifier '__out', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 29-29: variable 'head_repeat_times' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
mllm/backends/cpu/kernels/common/fa2_1/arch.hpp
[error] 5-5: 'cassert' file not found
(clang-diagnostic-error)
[error] 10-10: declaration uses identifier '__AnyArchTag', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 10-10: invalid case style for class '__AnyArchTag'
(readability-identifier-naming,-warnings-as-errors)
[error] 13-13: declaration uses identifier '__X86ArchTag', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 13-13: invalid case style for class '__X86ArchTag'
(readability-identifier-naming,-warnings-as-errors)
[error] 16-16: declaration uses identifier '__ArmArchTag', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 16-16: invalid case style for class '__ArmArchTag'
(readability-identifier-naming,-warnings-as-errors)
[error] 19-19: declaration uses identifier '__ArgTag', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 19-19: declaration uses identifier '__LhsDataType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 19-19: declaration uses identifier '__RhsDataType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 19-19: declaration uses identifier '__DstDataType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__ArgTag', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__FromDataType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__constDataType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
mllm/backends/cpu/kernels/common/fa2_1/impl-arm.hpp
[error] 6-6: 'mllm/core/DataTypes.hpp' file not found
(clang-diagnostic-error)
[error] 14-14: 2 adjacent parameters of 'run' of similar type ('const int *__restrict') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 16-16: variable 'sum_vec' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 18-18: variable 'i' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 18-18: variable name 'i' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 19-19: variable 'block_size' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 20-20: variable 'len_aligned' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 45-45: variable 'result' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 56-56: variable 'const_vec' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 58-58: variable 'i' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 58-58: variable name 'i' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 59-59: variable 'block_size' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 60-60: variable 'len_aligned' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 94-94: variable 'acc_vec' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 96-96: variable 'i' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 96-96: variable name 'i' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 97-97: variable 'block_size' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 98-98: variable 'len_aligned' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 136-136: variable 'const_vec' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 138-138: variable 'i' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 138-138: variable name 'i' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 139-139: variable 'block_size' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 140-140: variable 'len_aligned' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
mllm/core/aops/RMSNormOp.hpp
[error] 39-39: member variable 'weight_' has protected visibility
(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)
[error] 40-40: member variable 'options_' has protected visibility
(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)
mllm/nn/Layer.hpp
[error] 92-92: function-like macro 'MLLM_LAYER_ENABLE_INPLACE_ATTRIBUTE' used; consider a 'constexpr' template function
(cppcoreguidelines-macro-usage,-warnings-as-errors)
[error] 101-101: function-like macro 'MLLM_LAYER_ENABLE_REDIRECT_ATTRIBUTE' used; consider a 'constexpr' template function
(cppcoreguidelines-macro-usage,-warnings-as-errors)
mllm/backends/cpu/kernels/common/fa2_1/impl-any.hpp
[error] 6-6: 'mllm/core/DataTypes.hpp' file not found
(clang-diagnostic-error)
[error] 12-12: 2 adjacent parameters of 'run' of similar type ('const int *__restrict') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 14-14: variable 'ret' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
mllm/models/qwen3/modeling_qwen3_fa2.hpp
[error] 4-4: 'mllm/mllm.hpp' file not found
(clang-diagnostic-error)
[error] 15-15: 2 adjacent parameters of 'makeRoPEInvFreq' of convertible types are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 22-22: 2 adjacent parameters of 'makeRotaryPosEmbedding' of convertible types are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 71-71: constructor does not initialize these fields: gate_proj_, up_proj_, down_proj_, silu_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 79-79: constructor does not initialize these fields: gate_proj_, up_proj_, down_proj_, silu_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 79-79: 2 adjacent parameters of 'Qwen3MLP' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 86-86: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 87-87: variable name 'x' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 89-89: variable name 'y' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 96-96: constructor does not initialize these fields: q_proj_, k_proj_, v_proj_, o_proj_, rms_norm_q_, rms_norm_k_, q_rope_, k_rope_, hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, num_key_value_groups_, layer_idx_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 115-115: constructor does not initialize these fields: q_proj_, k_proj_, v_proj_, o_proj_, rms_norm_q_, rms_norm_k_, q_rope_, k_rope_, hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, layer_idx_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 115-115: 2 adjacent parameters of 'Qwen3Attention' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 120-120: 'num_key_value_groups_' should be initialized in a member initializer of the constructor
(cppcoreguidelines-prefer-member-initializer,-warnings-as-errors)
[error] 138-138: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 139-139: variable name 'x' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 144-144: variable 'B' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 144-144: variable name 'B' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 144-144: invalid case style for variable 'B'
(readability-identifier-naming,-warnings-as-errors)
[error] 145-145: variable 'S' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 145-145: variable name 'S' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 145-145: invalid case style for variable 'S'
(readability-identifier-naming,-warnings-as-errors)
[error] 177-177: member variable 'layer_idx_' has public visibility
(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)
[error] 180-180: constructor does not initialize these fields: input_layer_norm_, post_attention_layer_norm_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 189-189: constructor does not initialize these fields: input_layer_norm_, post_attention_layer_norm_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 189-189: 2 adjacent parameters of 'Qwen3Decoder' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 196-196: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 201-201: variable name 'x' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 211-211: constructor does not initialize these fields: decode_blocks_, norm_, embedding_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 219-219: constructor does not initialize these fields: decode_blocks_, norm_, embedding_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 219-219: 2 adjacent parameters of 'Qwen3Text' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 226-226: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 230-230: variable name 'x' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 246-246: constructor does not initialize these fields: lm_head_, tie_word_embeddings_, kv_cache_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 273-273: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 280-280: variable 'position_ids' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 307-307: variable name 'S' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 307-307: invalid case style for variable 'S'
(readability-identifier-naming,-warnings-as-errors)
tests/cpu/FlashAttentionKernelTest.hpp
[error] 4-4: 'limits' file not found
(clang-diagnostic-error)
[error] 20-20: 2 adjacent parameters of 'FlashAttn2Module' of similar type ('int') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 22-22: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 30-30: use '= default' to define a trivial default constructor
(modernize-use-equals-default,-warnings-as-errors)
[error] 32-32: 2 adjacent parameters of 'forward' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 35-35: variable name 'Q' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 35-35: invalid case style for variable 'Q'
(readability-identifier-naming,-warnings-as-errors)
[error] 36-36: variable name 'K' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 36-36: invalid case style for variable 'K'
(readability-identifier-naming,-warnings-as-errors)
[error] 37-37: variable name 'V' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 37-37: invalid case style for variable 'V'
(readability-identifier-naming,-warnings-as-errors)
[error] 53-53: invalid case style for variable 'S_Q'
(readability-identifier-naming,-warnings-as-errors)
[error] 54-54: invalid case style for variable 'S_KV'
(readability-identifier-naming,-warnings-as-errors)
[error] 58-58: declaration uses identifier '__delta', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 58-58: variable '__delta' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 58-58: invalid case style for variable '__delta'
(readability-identifier-naming,-warnings-as-errors)
[error] 77-77: class 'FlashAttn2KernelTest' defines a default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator
(cppcoreguidelines-special-member-functions,-warnings-as-errors)
[error] 82-82: method 'testRadixAttnOnce' can be made static
(readability-convert-member-functions-to-static,-warnings-as-errors)
[error] 83-83: variable name 'B' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 83-83: invalid case style for variable 'B'
(readability-identifier-naming,-warnings-as-errors)
[error] 84-84: variable 'H_Q' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 84-84: invalid case style for variable 'H_Q'
(readability-identifier-naming,-warnings-as-errors)
[error] 85-85: variable 'H_KV' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 85-85: invalid case style for variable 'H_KV'
(readability-identifier-naming,-warnings-as-errors)
[error] 86-86: variable 'S_Q' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 86-86: invalid case style for variable 'S_Q'
(readability-identifier-naming,-warnings-as-errors)
[error] 87-87: variable 'S_KV' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 87-87: invalid case style for variable 'S_KV'
(readability-identifier-naming,-warnings-as-errors)
[error] 88-88: variable 'D' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 88-88: variable name 'D' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 88-88: invalid case style for variable 'D'
(readability-identifier-naming,-warnings-as-errors)
[error] 94-94: variable name 'Q' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 94-94: invalid case style for variable 'Q'
(readability-identifier-naming,-warnings-as-errors)
[error] 95-95: variable name 'K' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 95-95: invalid case style for variable 'K'
(readability-identifier-naming,-warnings-as-errors)
[error] 96-96: variable name 'V' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 96-96: invalid case style for variable 'V'
(readability-identifier-naming,-warnings-as-errors)
[error] 99-99: variable 'gt' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 99-99: variable name 'gt' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 100-100: variable 'predict' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 115-115: method 'testRadixAttn' can be made static
(readability-convert-member-functions-to-static,-warnings-as-errors)
[error] 115-115: all parameters should be named in a function
(readability-named-parameter,-warnings-as-errors)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-android
🔇 Additional comments (13)
mllm/nn/layers/RoPE.hpp (1)
19-19: In-place attribute enabled — confirm alias safety and coverage. Looks good. Please verify RoPE forward/op paths are alias-safe when inplace=true and add tests for non-contiguous-last-dim and mixed dtypes.
mllm/nn/lmcache/StaticCache.hpp (1)
86-88: Validate sub-slice bounds for K/V buffers. Ensure start_idx_ and len_ lie within the sequence dimension of k_cache_[i]/v_cache_[i] to avoid runtime slicing errors.
mllm/backends/cpu/ops/RMSNormOp.cpp (1)
28-59: Incorrect review comment — the x86 path is actually safe for in-place operations. The x86::rmsnorm_fp32 implementation safely handles alias conditions. The loop increments by hn::Lanes(d) (full vector width), not by 1, ensuring non-overlapping memory access even when input and output pointers are identical. This contrasts with the ARM non-inplace variant, which has overlapping reads due to unit-stride loop increments and therefore requires a separate inplace implementation. The x86 algorithm structure already prevents the aliasing hazard, making an explicit _inplace variant unnecessary. Likely an incorrect or invalid review comment.
mllm/core/aops/RMSNormOp.cpp (1)
50-52: Setup skip is fine; ensure kernel preconditions are validated elsewhere. Since BaseOp::setup is skipped for inplace, make sure contiguity/dtype checks happen in the backend kernel path.
mllm/nn/layers/Linear.hpp (1)
25-25: LGTM! Redirect attribute enabled for Linear layer. The macro invocation correctly enables redirect functionality for the Linear layer, consistent with the broader redirect infrastructure introduced in this PR.
mllm/core/aops/LinearOp.cpp (1)
107-110: LGTM! Redirect path in setup is correctly implemented. The early return when redirect is enabled is appropriate and consistent with the pattern used in other operations.
mllm/core/aops/RMSNormOp.hpp (1)
36-36: LGTM! Accessor change enables dynamic option configuration. The change from a const accessor to a non-const accessor is intentional and necessary to support dynamic configuration of inplace/redirect behavior at runtime, consistent with similar changes across other operation classes in this PR.
mllm/core/aops/RoPEOp.hpp (1)
36-36: LGTM! Consistent accessor change for dynamic configuration. The accessor change is consistent with similar updates across other operation classes and enables dynamic configuration of inplace behavior, as evidenced by the corresponding implementation changes in RoPEOp.cpp.
mllm/core/aops/LinearOp.hpp (1)
144-144: LGTM! Accessor change aligns with redirect implementation. The change to a non-const accessor is necessary to support the redirect functionality implemented in LinearOp.cpp, where options are checked and potentially modified at runtime.
mllm/core/aops/RoPEOp.cpp (1)
43-45: LGTM! Setup correctly handles inplace mode. The conditional setup logic is appropriate and consistent with the pattern used in other operations.
mllm/nn/Layer.hpp (1)
92-99: LGTM! Inplace attribute macro correctly propagates options. The macro properly sets the inplace flag in both the layer's public options and the underlying op's options, ensuring consistency. The const_cast is necessary here to enable post-construction modification of options.
mllm/backends/cpu/kernels/arm/rmsnorm.hpp (1)
18-20: Fix __restrict aliasing; macro suggestion is incorrect and should not be applied. The __restrict concern on both X and Y parameters is valid — declaring both as restrict violates the contract if Y==X (the in-place case). Remove __restrict from X (or Y, but X is preferred).
However, the macro "fix" is incorrect. The review comment suggests changing to MLLM_HOST_ARCH, which is never defined in mllm/utils/CPUArchHelper.hpp. The canonical ARM macros are MLLM_HOST_ARCH_ARM64 and MLLM_HOST_ARCH_ARM. The current rmsnorm.hpp pattern (MLLM_HOST_ARCH_ARM64 || MLLM_HOST_ARCH_ARM) is correct; other files like FlashAttention2Op.cpp (line 39) and RadixAttnOp.cpp (line 31) are the inconsistent ones — they incorrectly use the undefined MLLM_HOST_ARCH. Applying the suggested diff would introduce dead code and break ARM support.
Apply only the restrict fix:
-void rmsnorm_fp32_inplace(const mllm_fp32_t* __restrict X, const mllm_fp32_t* __restrict W, mllm_fp32_t* __restrict Y, int D,
+void rmsnorm_fp32_inplace(const mllm_fp32_t* X, const mllm_fp32_t* __restrict W, mllm_fp32_t* Y, int D,
                           float epsilon, bool add_unit_offset, int thread_count);
-void rmsnorm_fp16_inplace(const mllm_fp16_t* __restrict X, const mllm_fp16_t* __restrict W, mllm_fp16_t* __restrict Y, int D,
+void rmsnorm_fp16_inplace(const mllm_fp16_t* X, const mllm_fp16_t* __restrict W, mllm_fp16_t* Y, int D,
                           float epsilon, bool add_unit_offset, int thread_count);
Do not change the macro at line 8; instead, fix FlashAttention2Op.cpp and RadixAttnOp.cpp to match the correct pattern.
Likely an incorrect or invalid review comment.
mllm/backends/cpu/kernels/common/fa2_1/impl-arm.hpp (1)
92-100: Remove this review comment — the aliasing concern is invalid. The analysis shows acc_o (from the __out buffer) and v_token (from the __v buffer) are entirely distinct tensors. The FA2 algorithm never performs in-place updates; v_token is a read-only input from the KV cache while acc_o is the output accumulator. Both __restrict__ annotations are appropriate and safe given this guarantee from the caller. Removing either would be a performance regression with no correctness benefit.
Likely an incorrect or invalid review comment.
#include <cassert>
#include "mllm/utils/Common.hpp"
Fix reserved identifiers and add safe fallback; adjust includes.
- Types and template params starting with “__” are reserved; rename.
- Empty primary templates make calls no-op if specialization not included—dangerous.
- Include <cstddef> and <type_traits>; drop the unused <cassert>.
Apply:
-#include <cassert>
+#include <cstddef>
+#include <type_traits>
#include "mllm/utils/Common.hpp"
namespace mllm::cpu::flash_attn2::details {
-struct __AnyArchTag {};
-using any_arch_tag = __AnyArchTag;
+struct AnyArchTag {};
+using any_arch_tag = AnyArchTag;
-struct __X86ArchTag {};
-using x86_arch_tag = __X86ArchTag;
+struct X86ArchTag {};
+using x86_arch_tag = X86ArchTag;
-struct __ArmArchTag {};
-using arm_arch_tag = __ArmArchTag;
+struct ArmArchTag {};
+using arm_arch_tag = ArmArchTag;
-template<typename __ArgTag, typename __LhsDataType, typename __RhsDataType, typename __DstDataType>
+template<typename ArgTag, typename LhsDataType, typename RhsDataType, typename DstDataType>
struct VectorDotProduct {
- static MLLM_FORCE_INLINE void run(const __LhsDataType* __restrict__ __lhs, const __RhsDataType* __restrict__ __rhs,
- __DstDataType* __out, size_t len) {}
+ template <typename T> struct dependent_false : std::false_type {};
+ static MLLM_FORCE_INLINE void run(const LhsDataType* MLLM_RESTRICT lhs,
+ const RhsDataType* MLLM_RESTRICT rhs,
+ DstDataType* MLLM_RESTRICT out, std::size_t len) {
+ static_assert(dependent_false<ArgTag>::value, "VectorDotProduct: missing specialization for this arch/data type.");
+ }
};
-template<typename __ArgTag, typename __FromDataType, typename __constDataType>
+template<typename ArgTag, typename FromDataType, typename ConstDataType>
struct MulFromConst {
- static MLLM_FORCE_INLINE void run(__FromDataType* __restrict__ __from, const __constDataType const_v, size_t len) {}
+ template <typename T> struct dependent_false : std::false_type {};
+ static MLLM_FORCE_INLINE void run(FromDataType* MLLM_RESTRICT from, const ConstDataType const_v, std::size_t len) {
+ static_assert(dependent_false<ArgTag>::value, "MulFromConst: missing specialization for this arch/data type.");
+ }
};
-template<typename ArchTag, typename T, typename U, typename V>
+template<typename ArchTag, typename T, typename U, typename V>
struct FMAConstArray {
- static MLLM_FORCE_INLINE void run(T* __restrict__ acc_o, const U acc_s, const V* __restrict__ v_token, size_t len) {}
+ template <typename X> struct dependent_false : std::false_type {};
+ static MLLM_FORCE_INLINE void run(T* MLLM_RESTRICT acc_o, const U acc_s, const V* MLLM_RESTRICT v_token, std::size_t len) {
+ static_assert(dependent_false<ArchTag>::value, "FMAConstArray: missing specialization for this arch/data type.");
+ }
};
-template<typename ArchTag, typename T>
+template<typename ArchTag, typename T>
struct FilledWithConst {
- static MLLM_FORCE_INLINE void run(T* __restrict__ a, const T v, size_t len) {}
+ template <typename X> struct dependent_false : std::false_type {};
+ static MLLM_FORCE_INLINE void run(T* MLLM_RESTRICT a, const T v, std::size_t len) {
+ static_assert(dependent_false<ArchTag>::value, "FilledWithConst: missing specialization for this arch/data type.");
+ }
};
If you prefer a functional fallback over compile-time error, I can provide scalar implementations instead.
Also applies to: 10-18, 19-38
🧰 Tools
🪛 Clang (14.0.6)
[error] 5-5: 'cassert' file not found
(clang-diagnostic-error)
🤖 Prompt for AI Agents
In mllm/backends/cpu/kernels/common/fa2_1/arch.hpp around lines 5-7, replace the
reserved identifier usage (any types, template parameters, or identifiers
beginning with "__") with non-reserved names, remove the unused #include
<cassert>, and add standard headers <cstddef> and <type_traits>; then replace
empty primary templates with either a safe compile-time trap
(static_assert(false, ...) guarded by dependent condition) or provide a
functional scalar fallback implementation so calls without a specialization
don't silently become no-ops, and ensure all renamed identifiers are
consistently adjusted throughout the file.
template<typename __ArchTag, typename __QDType, typename __KDType, typename __VDType, typename __ODType, typename __AccDType,
         bool high_precession_exp = true>
void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
              __KDType* __k, __VDType* __v, __ODType* __restrict__ __out, int32_t thread_count) {
  int32_t head_repeat_times = H_Q / H_KV;
Reserved identifiers and K/V should be const; also assert divisibility H_Q % H_KV == 0.
- Names starting with double underscore are reserved; rename template/param identifiers.
- K/V are not modified; make them const to match call sites and avoid casting away const.
- Add an explicit divisibility assertion; code uses integer division to map heads.
-template<typename __ArchTag, typename __QDType, typename __KDType, typename __VDType, typename __ODType, typename __AccDType,
- bool high_precession_exp = true>
-void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
- __KDType* __k, __VDType* __v, __ODType* __restrict__ __out, int32_t thread_count) {
- int32_t head_repeat_times = H_Q / H_KV;
+template<typename ArchTag, typename QDType, typename KDType, typename VDType, typename ODType, typename AccDType,
+ bool high_precision_exp = true>
+void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D,
+ const QDType* __restrict q, const KDType* __restrict k, const VDType* __restrict v,
+ ODType* __restrict out, int32_t thread_count) {
+ MLLM_RT_ASSERT(H_KV > 0 && H_Q % H_KV == 0);
+ int32_t head_repeat_times = H_Q / H_KV;
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-template<typename __ArchTag, typename __QDType, typename __KDType, typename __VDType, typename __ODType, typename __AccDType,
-         bool high_precession_exp = true>
-void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
-              __KDType* __k, __VDType* __v, __ODType* __restrict__ __out, int32_t thread_count) {
-  int32_t head_repeat_times = H_Q / H_KV;
+template<typename ArchTag, typename QDType, typename KDType, typename VDType, typename ODType, typename AccDType,
+         bool high_precision_exp = true>
+void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D,
+              const QDType* __restrict q, const KDType* __restrict k, const VDType* __restrict v,
+              ODType* __restrict out, int32_t thread_count) {
+  MLLM_RT_ASSERT(H_KV > 0 && H_Q % H_KV == 0);
+  int32_t head_repeat_times = H_Q / H_KV;
🧰 Tools
🪛 Clang (14.0.6)
[error] 25-25: declaration uses identifier '__ArchTag', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__QDType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__KDType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__VDType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__ODType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 25-25: declaration uses identifier '__AccDType', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 27-27: 6 adjacent parameters of 'fwd_bhsd' of similar type ('int') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 27-27: parameter name 'B' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 27-27: parameter name 'D' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 27-27: declaration uses identifier '__q', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 28-28: declaration uses identifier '__k', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 28-28: declaration uses identifier '__v', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 28-28: declaration uses identifier '__out', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 29-29: variable 'head_repeat_times' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
🤖 Prompt for AI Agents
In mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp around lines 25-29,
rename any template type/parameter identifiers that begin with double
underscores to non-reserved names (e.g., __ArchTag -> ArchTag, __QDType ->
QDType, etc.), change the k and v pointer parameters to const (__KDType const*
k, __VDType const* v) to reflect they are not modified, and add an explicit
assertion (e.g., assert(H_Q % H_KV == 0)) before computing head_repeat_times to
ensure H_Q is divisible by H_KV. Ensure call sites and forward declarations are
updated to match the new parameter names and const correctness.
__KDType* __k, __VDType* __v, __ODType* __restrict__ __out, int32_t thread_count) {
  int32_t head_repeat_times = H_Q / H_KV;

  __AccDType scale = scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
Compile-time bug: double assignment “scale = scale = …”.
Fix and keep types consistent.
- __AccDType scale = scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
+ AccDType scale = static_cast<AccDType>(std::sqrt(1.0 / static_cast<double>(D))) *
+ static_cast<AccDType>(std::numbers::log2e);
📝 Committable suggestion
-__AccDType scale = scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
+__AccDType scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
🤖 Prompt for AI Agents
In mllm/backends/cpu/kernels/common/fa2_1/fwd_bshd.hpp around line 31, remove
the accidental double assignment ("scale = scale = ...") and make the arithmetic
use consistent types: compute the sqrt on an __AccDType value and cast
std::numbers::log2e to __AccDType before multiplying so the expression is a
single initialization of scale with matching types (e.g., cast the 1.0/D or the
result of std::numbers::log2e to __AccDType and avoid the duplicated "scale ="
token).
FlashAttn2Module(int H_Q, int H_KV) : nn::Module() {}
Constructor params unused.
Either remove the (H_Q, H_KV) ctor or store/validate them. Keeping unused params in tests is harmless but noisy.
- FlashAttn2Module(int H_Q, int H_KV) : nn::Module() {}
+ explicit FlashAttn2Module() = default;
🧰 Tools
🪛 Clang (14.0.6)
[error] 20-20: 2 adjacent parameters of 'FlashAttn2Module' of similar type ('int') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
🤖 Prompt for AI Agents
In tests/cpu/FlashAttentionKernelTest.hpp around lines 20-21, the
FlashAttn2Module constructor declares parameters (int H_Q, int H_KV) but does
not use or store them; remove the unused parameters from the ctor signature or
add member variables to store/validate H_Q and H_KV and initialize them in the
constructor, updating any test code that constructs FlashAttn2Module to match
the new signature if you remove params.
K = K.transpose(1, 2).repeat(h_q / h_kv, 1);
V = V.transpose(1, 2).repeat(h_q / h_kv, 1);
🧩 Analysis chain
Head expansion: verify repeat semantics target the head axis.
Ensure repeat(h_q / h_kv, 1) repeats along the head dimension, not batch. If API expects per-axis repeats, consider an explicit API:
If needed:
- K = K.transpose(1, 2).repeat(h_q / h_kv, 1);
- V = V.transpose(1, 2).repeat(h_q / h_kv, 1);
+ K = K.transpose(1, 2).repeat({1, h_q / h_kv, 1, 1}); // repeat on head axis
+ V = V.transpose(1, 2).repeat({1, h_q / h_kv, 1, 1});
🏁 Script executed:
#!/bin/bash
# Inspect repeat signatures and call sites
rg -nP 'repeat\s*\(' -C2 --type=cpp
Length of output: 3194
I need to verify the repeat implementation and confirm the axis semantics.
Now let me check the implementation to confirm the repeat semantics:
Let me inspect the actual repeat implementation to confirm the dimension semantics:
Based on my verification of the code and repeat semantics, I found a critical issue with the dimensions used in the repeat calls.
Incorrect repeat dimension: change dim=1 to dim=2 for head expansion.
The current code transposes K and V from [B, H, S, D] to [B, S, H, D], then calls .repeat(h_q / h_kv, 1). This repeats along dimension 1 (the sequence axis), not the head dimension. After transpose, the head dimension is at index 2, not 1.
The repeat API signature is repeat(multiplier, dim), not per-dimension repeats. The fix should be:
- K = K.transpose(1, 2).repeat(h_q / h_kv, 1);
- V = V.transpose(1, 2).repeat(h_q / h_kv, 1);
+ K = K.transpose(1, 2).repeat(h_q / h_kv, 2);
+ V = V.transpose(1, 2).repeat(h_q / h_kv, 2);
This bug also appears in tests/cpu/RadixAttnKernel.hpp (lines 47-48), mllm/nn/lmcache/DynamicCache.cpp (lines 23-24), and test files using the same pattern.
📝 Committable suggestion
-K = K.transpose(1, 2).repeat(h_q / h_kv, 1);
-V = V.transpose(1, 2).repeat(h_q / h_kv, 1);
+K = K.transpose(1, 2).repeat(h_q / h_kv, 2);
+V = V.transpose(1, 2).repeat(h_q / h_kv, 2);
🤖 Prompt for AI Agents
In tests/cpu/FlashAttentionKernelTest.hpp around lines 45-47, the repeat calls
after transposing K and V use dim=1 which repeats the sequence axis; change the
repeat axis to dim=2 so the head dimension is expanded (i.e., call repeat(h_q /
h_kv, 2) for both K and V). Apply the same correction in
tests/cpu/RadixAttnKernel.hpp (lines ~47-48), mllm/nn/lmcache/DynamicCache.cpp
(lines ~23-24), and any other tests/files using transpose(...).repeat(h_q /
h_kv, 1) to ensure the head dimension is the one being repeated.
auto S_Q = Q.shape()[2];
auto S_KV = K.shape()[2];
auto mask = Tensor::zeros({1, 1, S_Q, S_KV});
{
  auto weight = mllm::nn::functional::matmul(Q, K, false, true);
  // TODO mask weight
  weight = mllm::nn::functional::softmax(weight, -1);
  ref_output = mllm::nn::functional::matmul(weight, V);
  auto ptr = mask.ptr<float>();
  int __delta = S_KV - S_Q;
  for (int s_q_idx = 0; s_q_idx < S_Q; s_q_idx++) {
    int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV);
    for (int s_kv_idx = S_KV_BOUND; s_kv_idx < S_KV; s_kv_idx++) {
      ptr[s_q_idx * S_KV + s_kv_idx] = -std::numeric_limits<float>::infinity();
    }
  }
Causal mask indexing assumes B=H=1.
Safe for these tests, but brittle. Consider broadcasting to [B,H,S_Q,S_KV] or using a helper.
- auto mask = Tensor::zeros({1, 1, S_Q, S_KV});
+ auto mask = Tensor::zeros({Q.shape()[0], Q.shape()[1], S_Q, S_KV}); // or a causalMask helper
📝 Committable suggestion
 auto S_Q = Q.shape()[2];
 auto S_KV = K.shape()[2];
-auto mask = Tensor::zeros({1, 1, S_Q, S_KV});
+auto mask = Tensor::zeros({Q.shape()[0], Q.shape()[1], S_Q, S_KV}); // or a causalMask helper
 {
-  auto weight = mllm::nn::functional::matmul(Q, K, false, true);
-  // TODO mask weight
-  weight = mllm::nn::functional::softmax(weight, -1);
-  ref_output = mllm::nn::functional::matmul(weight, V);
   auto ptr = mask.ptr<float>();
   int __delta = S_KV - S_Q;
   for (int s_q_idx = 0; s_q_idx < S_Q; s_q_idx++) {
     int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV);
     for (int s_kv_idx = S_KV_BOUND; s_kv_idx < S_KV; s_kv_idx++) {
       ptr[s_q_idx * S_KV + s_kv_idx] = -std::numeric_limits<float>::infinity();
     }
   }
🧰 Tools
🪛 Clang (14.0.6)
[error] 53-53: invalid case style for variable 'S_Q'
(readability-identifier-naming,-warnings-as-errors)
[error] 54-54: invalid case style for variable 'S_KV'
(readability-identifier-naming,-warnings-as-errors)
[error] 58-58: declaration uses identifier '__delta', which is a reserved identifier
(bugprone-reserved-identifier,-warnings-as-errors)
[error] 58-58: variable '__delta' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 58-58: invalid case style for variable '__delta'
(readability-identifier-naming,-warnings-as-errors)
// Compare
// rtol and atol set to 1e-2f is because:
// 1. The eager softmax is approximate, but radix is not.
auto result = test::allClose(gt, predict, 1e-2f, 1e-2f);
if (!result) {
  print(result);
  print("S_Q and S_KV is", S_Q, S_KV);
  print(predict);
  return false;
}
Tighten failure diagnostics.
Add max abs/rel error to ease debugging.
- if (!result) {
- print(result);
+ if (!result) {
+ auto diff = (gt - predict).abs();
+ print("max_abs_err:", diff.max().item<float>());
📝 Committable suggestion
 // Compare
 // rtol and atol set to 1e-2f is because:
 // 1. The eager softmax is approximate, but radix is not.
 auto result = test::allClose(gt, predict, 1e-2f, 1e-2f);
 if (!result) {
-  print(result);
+  auto diff = (gt - predict).abs();
+  print("max_abs_err:", diff.max().item<float>());
   print("S_Q and S_KV is", S_Q, S_KV);
   print(predict);
   return false;
 }
🤖 Prompt for AI Agents
In tests/cpu/FlashAttentionKernelTest.hpp around lines 102 to 111, the failure
branch only prints the boolean result, S_Q/S_KV and the predicted tensor;
enhance diagnostics by computing and printing the maximum absolute error and
maximum relative error between gt and predict before returning false. Implement
this by iterating over elements (or using existing tensor utilities) to compute
abs_diff = fabs(gt[i] - predict[i]) and rel_diff = abs_diff / (fabs(gt[i]) +
eps) with a small eps (e.g., 1e-8), track the maximum of each, and include those
max values in the print output alongside S_Q, S_KV and the tensors so failures
show max abs and max rel errors for easier debugging.
//===----------------------------------------------------------------------===//
// Radix Attn
//===----------------------------------------------------------------------===//
#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)
#include "FlashAttentionKernelTest.hpp"
TEST_F(FlashAttn2KernelTest, fwd_bshd) {
  EXPECT_EQ(testRadixAttn({{
                              {"H_Q", 28},
                              {"H_KV", 2},
                              {"S_Q", 10},
                              {"S_KV", 10},
                              {"D", 128},
                          },
                          {
                              {"H_Q", 28},
                              {"H_KV", 2},
                              {"S_Q", 10},
                              {"S_KV", 20},
                              {"D", 128},
                          },
                          {
                              {"H_Q", 28},
                              {"H_KV", 2},
                              {"S_Q", 1},
                              {"S_KV", 20},
                              {"D", 128},
                          }}),
            true);
}
#endif
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Verify test implementation - potential copy-paste issue.
The test block is named FlashAttn2KernelTest but:
- The comment still says "Radix Attn" instead of "FlashAttn2"
- It calls testRadixAttn rather than a FlashAttn2-specific test function
If testRadixAttn is intentionally reused to test FlashAttn2 functionality, this should be clarified with a comment explaining why. Otherwise, this appears to be a copy-paste error and should call an appropriate FlashAttn2 test function.
Run the following script to check if there's a FlashAttn2-specific test function available:
🏁 Script executed:
#!/bin/bash
# Search for FlashAttn2 or FA2 test functions in the test header
rg -n "test.*[Ff]lash.*[Aa]ttn2|test.*FA2" tests/cpu/FlashAttentionKernelTest.hpp
Length of output: 0
I need to broaden the investigation since the first search didn't produce output. Let me verify the file structure and look for the test functions more comprehensively.
🏁 Script executed:
#!/bin/bash
# Check if the header file exists and list its structure
echo "=== Checking file existence ==="
ls -la tests/cpu/FlashAttentionKernelTest.hpp 2>&1
echo ""
echo "=== Searching for all test function definitions ==="
rg -n "^[a-zA-Z_].*test.*\(" tests/cpu/FlashAttentionKernelTest.hpp 2>&1
echo ""
echo "=== Searching for testRadixAttn definition ==="
rg -n "testRadixAttn" tests/cpu/ 2>&1
echo ""
echo "=== Checking for FlashAttn2 references in the header ==="
rg -n "FlashAttn2\|Attn2" tests/cpu/FlashAttentionKernelTest.hpp 2>&1
Length of output: 1069
Good findings! The file exists and testRadixAttn is defined in the header. Let me examine the actual implementation to understand whether this is intentional reuse or a naming issue.
🏁 Script executed:
#!/bin/bash
# Examine testRadixAttn implementation in FlashAttentionKernelTest.hpp
echo "=== testRadixAttn implementation in FlashAttentionKernelTest.hpp ==="
rg -n -A 20 "bool testRadixAttn\(" tests/cpu/FlashAttentionKernelTest.hpp
echo ""
echo "=== Check for any comments or context around testRadixAttn ==="
rg -n -B 5 -A 15 "bool testRadixAttn\(" tests/cpu/FlashAttentionKernelTest.hpp
echo ""
echo "=== Check class definition and inheritance in FlashAttentionKernelTest.hpp ==="
rg -n -A 5 "class.*KernelTest\|class.*FlashAttn" tests/cpu/FlashAttentionKernelTest.hpp
echo ""
echo "=== List all test functions (TEST_F) in KernelTest.cpp around line 810 ==="
rg -n -B 2 -A 10 "TEST_F\(FlashAttn2KernelTest" tests/cpu/KernelTest.cpp
Length of output: 1410
I need to find where FlashAttn2KernelTest is actually defined to understand if it's intentionally reusing RadixAttn test functions.
🏁 Script executed:
#!/bin/bash
# Search for FlashAttn2KernelTest class definition
echo "=== Search for FlashAttn2KernelTest class definition ==="
rg -n "class FlashAttn2KernelTest" tests/cpu/
echo ""
echo "=== Search for any class that includes testRadixAttn ==="
rg -n -B 3 "testRadixAttn" tests/cpu/FlashAttentionKernelTest.hpp | head -30
echo ""
echo "=== Look at the entire FlashAttentionKernelTest.hpp to understand structure ==="
cat tests/cpu/FlashAttentionKernelTest.hpp
Length of output: 4120
Update comment to reflect correct functionality; clarify misleading naming.
The code is functionally correct—the test functions inside FlashAttn2KernelTest do properly test FlashAttn2 (see implementation at line 82-111 in FlashAttentionKernelTest.hpp, which creates FlashAttn2Module and compares against eager attention). However, the naming is misleading:
- The comment "// Radix Attn" at line 810 should be "// FlashAttn2"
- The function names testRadixAttn and testRadixAttnOnce are confusing since they test FlashAttn2, not Radix attention
Consider renaming these functions to testFlashAttn2 or adding a clarifying comment explaining why RadixAttn-named functions test FlashAttn2.
🤖 Prompt for AI Agents
In tests/cpu/KernelTest.cpp around lines 810 to 839, update the misleading
comment and naming: change the comment "// Radix Attn" to "// FlashAttn2" and
either rename the helper test functions testRadixAttn/testRadixAttnOnce to
testFlashAttn2/testFlashAttn2Once (and update all call sites) or add a short
clarifying comment above the helpers explaining that these Radix-named helpers
actually exercise FlashAttn2 (pointing to FlashAttentionKernelTest.hpp
implementation), so test names/comments accurately reflect that they verify
FlashAttn2 behavior.
rmsnorm_fp32_inplace and rmsnorm_fp16_inplace functions in ARM kernels
Summary by CodeRabbit
New Features
Infrastructure