feat(cpu, smollm3-tokenizer): add KAI SGEMM NEON implementation for ARM #503
chenghuaWang merged 5 commits into UbiquitousLearning:v2
Conversation
- Introduce `KaiLinear_fp32_fp32_fp32p_mxk_kxn` kernel for fp32 GEMM on ARM NEON
- Add new linear implementation types: `kMllmBlas_KAI_SGEMM_NT_NT_NEON` and `kMllmBlas_KAI_SGEMM_NT_T_SME`
- Update CMake options with Android performance hints and profiling components
- Enhance ParameterFile loading with optional mmap support
- Refactor matmul tests to include manual reference computation
- Add Android performance hint headers for future optimizations

This commit enables high-performance fp32 linear operations on ARM CPUs using KAI kernels, provides better control over memory mapping during model loading, and improves test coverage for BLAS-like operations.
…support
- Add CMakeLists.txt for smollm3 example executable
- Implement main.cpp with SmolLM3Tokenizer usage
- Include tokenization logic with thinking/non-thinking templates
- Support dynamic date insertion in chat templates
- Enable BPE-based encoding/decoding workflows
Walkthrough
This PR adds SmolLM3 model support with a dedicated tokenizer implementation, introduces fp32 linear kernel implementations for ARM64 via Kai, enables memory-mapped (mmap) parameter file loading, integrates new Kai SGEMM NEON/SME backend implementations, and adds Android burst performance hints configuration across the build system and CPU backend infrastructure.
Sequence Diagram(s)
sequenceDiagram
participant User as Application
participant Loader as ParameterFile::load()
participant Reader as ParameterFileIOImpl::read()
participant File as Disk/Memory
participant Tensor as TensorStorage
User->>Loader: load(file_name, version, device, mmap=true)
Loader->>Reader: read(file_path, mmap=true)
alt mmap enabled
Reader->>File: mmap(file_path)
Reader->>File: validate header
Reader->>File: parse descriptors from mapped region
Reader->>Tensor: create tensor with MMAP memory type
Tensor-->>Reader: tensor view into mapped data
else mmap disabled
Reader->>File: open binary file
Reader->>File: read header
Reader->>File: read/allocate descriptors
Reader->>File: seek to data offset
Reader->>Tensor: allocate storage
Reader->>File: read data into storage
end
Reader-->>Loader: ParameterFile with tensors
Loader-->>User: ParameterFile::ptr_t
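The two branches of the diagram above can be sketched with plain POSIX calls. This is a minimal illustration, not mllm's `ParameterFileIOImpl`: the names `MappedFile`, `loadBuffered`, and `sameBytes` are hypothetical, header validation and descriptor parsing are omitted, and a POSIX platform is assumed.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

// Hypothetical read-only view of a memory-mapped file (mmap-enabled branch).
class MappedFile {
 public:
  explicit MappedFile(const std::string& path) {
    const int fd = ::open(path.c_str(), O_RDONLY);
    if (fd < 0) return;
    struct stat st {};
    if (::fstat(fd, &st) == 0 && st.st_size > 0) {
      void* ptr = ::mmap(nullptr, static_cast<std::size_t>(st.st_size), PROT_READ, MAP_PRIVATE, fd, 0);
      if (ptr != MAP_FAILED) {
        data_ = static_cast<const unsigned char*>(ptr);
        size_ = static_cast<std::size_t>(st.st_size);
      }
    }
    ::close(fd);  // The mapping remains valid after the descriptor is closed.
  }
  ~MappedFile() {
    if (data_ != nullptr) ::munmap(const_cast<unsigned char*>(data_), size_);
  }
  MappedFile(const MappedFile&) = delete;
  MappedFile& operator=(const MappedFile&) = delete;

  bool ok() const { return data_ != nullptr; }
  const unsigned char* data() const { return data_; }
  std::size_t size() const { return size_; }

 private:
  const unsigned char* data_ = nullptr;
  std::size_t size_ = 0;
};

// mmap-disabled branch: allocate storage and copy the file contents into it.
std::vector<unsigned char> loadBuffered(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  return std::vector<unsigned char>(std::istreambuf_iterator<char>(in), std::istreambuf_iterator<char>());
}

// Helper for comparing the two loading paths byte for byte.
bool sameBytes(const MappedFile& mapped, const std::vector<unsigned char>& copy) {
  return mapped.ok() && mapped.size() == copy.size() && std::equal(copy.begin(), copy.end(), mapped.data());
}
```

In the mapped case a tensor would point directly into `data()`, so the `MappedFile` object has to outlive every tensor view created from it; the buffered case owns its storage independently of the file.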
sequenceDiagram
participant App as Application
participant LinearOp as LinearOp
participant Loader as Load Phase
participant Forward as Forward Phase
participant Kai as KaiLinear_fp32
App->>LinearOp: load(impl_type=KAI_SGEMM_NT_NT_NEON)
LinearOp->>Loader: identify impl_type
Loader->>Loader: select BLAS or default backend
Loader->>Loader: pretranspose weight if needed
Loader->>Kai: quant_pack_rhs_offline(weight, bias)
Kai-->>Loader: packed_weight
Loader->>LinearOp: store packed_weight
App->>LinearOp: forward(lhs, weight, bias)
LinearOp->>Forward: route to KAI_SGEMM_NT_NT_NEON case
Forward->>Kai: matmul(dst, lhs, packed_weight, workspace, M, K, N, threads)
Kai->>Kai: tile matmul with M/N steps
Kai->>Kai: invoke ukernel for each tile
Kai-->>Forward: compute result into dst
Forward-->>LinearOp: dst
LinearOp-->>App: output
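The "tile matmul with M/N steps" and "invoke ukernel for each tile" steps can be illustrated as follows. This is a sketch under stated assumptions: `tileUkernel`, `tiledMatmul`, `kMStep`, and `kNStep` are hypothetical names, the weight is a plain row-major K x N matrix rather than a packed KAI buffer, and threading is left out. Real KAI micro-kernels have different signatures and report their own step sizes.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical tile sizes; a real micro-kernel would report its own steps.
constexpr int kMStep = 4;
constexpr int kNStep = 8;

// Stand-in micro-kernel: computes one (rows x cols) tile of dst = lhs * rhs.
// lhs is M x K, rhs is K x N, dst is M x N, all row-major.
void tileUkernel(float* dst, const float* lhs, const float* rhs, int row0, int col0, int rows, int cols, int K, int N) {
  for (int i = row0; i < row0 + rows; ++i) {
    for (int j = col0; j < col0 + cols; ++j) {
      float sum = 0.0F;
      for (int k = 0; k < K; ++k) { sum += lhs[i * K + k] * rhs[k * N + j]; }
      dst[i * N + j] = sum;
    }
  }
}

// Walks the output in M/N steps and calls the micro-kernel once per tile.
// Edge tiles are clamped so M and N need not be multiples of the step sizes.
void tiledMatmul(float* dst, const float* lhs, const float* rhs, int M, int K, int N) {
  assert(M >= 0 && K >= 0 && N >= 0);
  for (int row0 = 0; row0 < M; row0 += kMStep) {
    const int rows = std::min(kMStep, M - row0);
    for (int col0 = 0; col0 < N; col0 += kNStep) {
      const int cols = std::min(kNStep, N - col0);
      tileUkernel(dst, lhs, rhs, row0, col0, rows, cols, K, N);
    }
  }
}
```

Because every output tile is written by exactly one micro-kernel call, the two outer loops are the natural place to distribute work across threads.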
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (2 passed)
Change memory type assignment from kGlobal to kParamsNormal in ParameterFile.cpp to correctly handle parameter allocation.
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 9
🧹 Nitpick comments (6)
tests/cpu/MllmBlasArmSgemmKernelTest.hpp (3)
27-44: Optional: address static analysis style hints.
The static analyzer flags several style issues in this block:
- Variable naming convention: `DST`, `M`, `K`, `N` (expected snake_case or camelCase per project style)
- Variable name length: `M`, `K`, `N` are flagged as too short
- Float literal suffix: `0.0f` should be `0.0F` (uppercase)
These are not functional issues, but addressing them would improve consistency if the project enforces these rules.
Example diff (if adopting lowercase naming and uppercase suffix):
-  auto DST = mllm::Tensor::emptyLike(RefDST).alloc();
+  auto dst = mllm::Tensor::emptyLike(RefDST).alloc();
   // Calculate DST.
   {
-    auto dst_ptr = DST.ptr<float>();
+    auto dst_ptr = dst.ptr<float>();
     auto a_ptr = A.ptr<float>();
     auto b_ptr = B.ptr<float>();
-    const int M = S_Q;
-    const int K = S_KV;
-    const int N = D;
-    for (int i = 0; i < M; ++i) {
-      for (int j = 0; j < N; ++j) {
-        float sum = 0.0f;
-        for (int k = 0; k < K; ++k) { sum += a_ptr[i * K + k] * b_ptr[k * N + j]; }
-        dst_ptr[i * N + j] = sum;
+    const int num_rows = S_Q;
+    const int inner_dim = S_KV;
+    const int num_cols = D;
+    for (int i = 0; i < num_rows; ++i) {
+      for (int j = 0; j < num_cols; ++j) {
+        float sum = 0.0F;
+        for (int k = 0; k < inner_dim; ++k) { sum += a_ptr[i * inner_dim + k] * b_ptr[k * num_cols + j]; }
+        dst_ptr[i * num_cols + j] = sum;
       }
     }
   }
-  auto result = mllm::test::allClose(DST, RefDST);
+  auto result = mllm::test::allClose(dst, RefDST);
71-88: Optional: address static analysis style hints.
Similar to the first test, the static analyzer flags style issues:
- Variable naming: `DST`, `M`, `K`, `N`
- Variable name length: `M`, `K`, `N`
- Float literal suffix: `0.0f` vs `0.0F`
These are the same concerns as in the previous function. If you decide to address them, apply consistent changes across both test functions.
29-44: Consider extracting common matmul logic to reduce duplication.
Both test functions contain nearly identical nested-loop matrix multiplication logic, differing only in the inner loop's B matrix access pattern (transposed vs non-transposed). Consider extracting a helper function to reduce duplication.
Example helper approach:
// Helper to compute C = A * B (or A * B^T if transpose_b is true)
static void compute_matmul_cpu(
    float* c_ptr, const float* a_ptr, const float* b_ptr,
    int M, int K, int N, bool transpose_b) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float sum = 0.0f;
      for (int k = 0; k < K; ++k) {
        const float b_val = transpose_b ? b_ptr[j * K + k] : b_ptr[k * N + j];
        sum += a_ptr[i * K + k] * b_val;
      }
      c_ptr[i * N + j] = sum;
    }
  }
}
Then call it from both test functions, reducing maintenance burden and improving readability.
Also applies to: 73-88
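One way to sanity-check such a shared helper is to confirm that the transposed and non-transposed access patterns agree when B is supplied in both layouts. The sketch below is self-contained and uses hypothetical names (`matmulRef`, `transposeRowMajor`); it does not depend on `mllm::Tensor` or the project's test utilities.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference C = A * B, or C = A * B^T when transpose_b is true.
// A is M x K; B is K x N (or N x K when transpose_b is true); all row-major.
std::vector<float> matmulRef(const std::vector<float>& a, const std::vector<float>& b, int M, int K, int N, bool transpose_b) {
  assert(a.size() == static_cast<std::size_t>(M) * K);
  assert(b.size() == static_cast<std::size_t>(K) * N);
  std::vector<float> c(static_cast<std::size_t>(M) * N, 0.0F);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float sum = 0.0F;
      for (int k = 0; k < K; ++k) {
        const float b_val = transpose_b ? b[j * K + k] : b[k * N + j];
        sum += a[i * K + k] * b_val;
      }
      c[i * N + j] = sum;
    }
  }
  return c;
}

// Transposes a row-major (rows x cols) matrix into a row-major (cols x rows) one.
std::vector<float> transposeRowMajor(const std::vector<float>& src, int rows, int cols) {
  assert(src.size() == static_cast<std::size_t>(rows) * cols);
  std::vector<float> dst(src.size());
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) { dst[c * rows + r] = src[r * cols + c]; }
  }
  return dst;
}
```

Both paths multiply the same pairs of values in the same order, so the results should match exactly, not just within a tolerance.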
CMakeLists.txt (1)
45-46: Consider wrapping the long option description for better readability.
The new `MLLM_ANDROID_BURST_PERFORMANCE_HINTS` option is a good addition for Android performance optimization. However, the description string on Line 46 is quite long (over 120 characters). Consider wrapping it across multiple lines for better maintainability.
Apply this diff to improve readability:
-option(MLLM_ANDROID_BURST_PERFORMANCE_HINTS "If MLLM need use APerformanceHintManager to tell android we need best performance" OFF)
+option(MLLM_ANDROID_BURST_PERFORMANCE_HINTS
+  "If MLLM need use APerformanceHintManager to tell android we need best performance"
+  OFF)
Alternatively, shorten the description:
-option(MLLM_ANDROID_BURST_PERFORMANCE_HINTS "If MLLM need use APerformanceHintManager to tell android we need best performance" OFF)
+option(MLLM_ANDROID_BURST_PERFORMANCE_HINTS
+  "Enable Android APerformanceHintManager for burst performance hints"
+  OFF)
mllm/backends/cpu/ops/MatMulOp.cpp (1)
53-54: TODO comment indicates known kGGUF bug.
The comment on Line 53 flags that the kGGUF matmul type is still buggy, yet Line 54 still selects it under certain conditions. This could lead to incorrect results in production.
Do you want me to:
- Generate a verification script to check if there's an existing issue tracking this bug?
- Open a new issue to track this bug with details about when kGGUF is selected and what the expected fix timeline is?
- Suggest adding a runtime warning when kGGUF is selected to alert users of potential issues?
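The runtime-warning option could look roughly like the sketch below. It is an illustration only: `MatMulKind`, `warnIfExperimental`, and `g_warnings_emitted` are hypothetical names, not mllm's enum or logging API, and the warning goes to `stderr` instead of the project's logger.

```cpp
#include <atomic>
#include <cassert>
#include <cstdio>

// Hypothetical stand-in for the backend's matmul type selector.
enum class MatMulKind { kDefault, kGGUF };

// Counter kept only so the behaviour can be observed in a test.
std::atomic<int> g_warnings_emitted{0};

// Prints a warning the first time the known-buggy path is selected and
// stays silent afterwards, so hot loops are not flooded with log output.
void warnIfExperimental(MatMulKind kind) {
  static std::atomic<bool> warned{false};
  if (kind != MatMulKind::kGGUF) return;
  if (!warned.exchange(true)) {
    std::fprintf(stderr, "warning: kGGUF matmul path is flagged as buggy; results may be incorrect\n");
    g_warnings_emitted.fetch_add(1);
  }
}
```

Calling this helper at the point where the matmul type is chosen keeps the warning close to the decision it describes.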
examples/smollm3/main.cpp (1)
13-16: Consider adding error handling for invalid tokenizer path.
The example creates a `SmolLM3Tokenizer` without verifying that the provided path is valid. If the tokenizer file is missing or corrupted, this will likely throw an exception or crash. For a user-facing example, consider adding error handling to provide a helpful message.
 {
+  try {
   auto tokenizer = mllm::models::smollm3::SmolLM3Tokenizer(tokenizer_path.get());
   auto ids = tokenizer.encode(tokenizer.applyChatTemplate("Bonjour 😈", false));
   mllm::print(ids);
   mllm::print(tokenizer.decode(ids));
+  } catch (const std::exception& e) {
+    fmt::print(stderr, "Error loading tokenizer from '{}': {}\n", tokenizer_path.get(), e.what());
+    return 1;
+  }
 }
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (17)
- CMakeLists.txt (1 hunks)
- examples/CMakeLists.txt (1 hunks)
- examples/smollm3/CMakeLists.txt (1 hunks)
- examples/smollm3/main.cpp (1 hunks)
- mllm/backends/cpu/kernels/arm/linear/kai.cpp (2 hunks)
- mllm/backends/cpu/kernels/arm/linear/kai.hpp (2 hunks)
- mllm/backends/cpu/ops/LinearOp.cpp (4 hunks)
- mllm/backends/cpu/ops/MatMulOp.cpp (1 hunks)
- mllm/core/ParameterFile.cpp (2 hunks)
- mllm/core/ParameterFile.hpp (1 hunks)
- mllm/core/aops/LinearOp.hpp (1 hunks)
- mllm/engine/hints/Android.hpp (1 hunks)
- mllm/mllm.cpp (1 hunks)
- mllm/mllm.hpp (1 hunks)
- mllm/models/smollm3_3B/tokenization_smollm3.hpp (1 hunks)
- tests/cpu/MllmBlasArmSgemmKernelTest.hpp (2 hunks)
- tests/cpu/MllmBlasArmSgemvKernelTest.hpp (2 hunks)
🧰 Additional context used
🪛 Clang (14.0.6)
tests/cpu/MllmBlasArmSgemvKernelTest.hpp
[error] 22-22: variable name 'A' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 22-22: invalid case style for variable 'A'
(readability-identifier-naming,-warnings-as-errors)
[error] 23-23: variable name 'B' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 23-23: invalid case style for variable 'B'
(readability-identifier-naming,-warnings-as-errors)
[error] 24-24: variable name 'C' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 24-24: invalid case style for variable 'C'
(readability-identifier-naming,-warnings-as-errors)
[error] 25-25: invalid case style for variable 'DST'
(readability-identifier-naming,-warnings-as-errors)
[error] 35-35: invalid case style for variable 'DSTP'
(readability-identifier-naming,-warnings-as-errors)
mllm/engine/hints/Android.hpp
[error] 5-5: 'android/performance_hint.h' file not found
(clang-diagnostic-error)
tests/cpu/MllmBlasArmSgemmKernelTest.hpp
[error] 27-27: invalid case style for variable 'DST'
(readability-identifier-naming,-warnings-as-errors)
[error] 34-34: variable name 'M' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 34-34: invalid case style for variable 'M'
(readability-identifier-naming,-warnings-as-errors)
[error] 35-35: variable name 'K' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 35-35: invalid case style for variable 'K'
(readability-identifier-naming,-warnings-as-errors)
[error] 36-36: variable name 'N' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 36-36: invalid case style for variable 'N'
(readability-identifier-naming,-warnings-as-errors)
[error] 39-39: floating point literal has suffix 'f', which is not uppercase
(readability-uppercase-literal-suffix,-warnings-as-errors)
[error] 71-71: invalid case style for variable 'DST'
(readability-identifier-naming,-warnings-as-errors)
[error] 78-78: variable name 'M' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 78-78: invalid case style for variable 'M'
(readability-identifier-naming,-warnings-as-errors)
[error] 79-79: variable name 'K' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 79-79: invalid case style for variable 'K'
(readability-identifier-naming,-warnings-as-errors)
[error] 80-80: variable name 'N' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 80-80: invalid case style for variable 'N'
(readability-identifier-naming,-warnings-as-errors)
[error] 83-83: floating point literal has suffix 'f', which is not uppercase
(readability-uppercase-literal-suffix,-warnings-as-errors)
mllm/backends/cpu/kernels/arm/linear/kai.hpp
[error] 56-56: invalid case style for class 'KaiLinear_fp32_fp32_fp32p_mxk_kxn'
(readability-identifier-naming,-warnings-as-errors)
[error] 57-57: method 'need_pack_lhs' can be made static
(readability-convert-member-functions-to-static,-warnings-as-errors)
[error] 59-59: method 'need_pack_rhs' can be made static
(readability-convert-member-functions-to-static,-warnings-as-errors)
[error] 66-66: parameter name 'K' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 66-66: parameter name 'N' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 69-69: parameter name 'M' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 69-69: parameter name 'K' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 69-69: parameter name 'N' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 72-72: variable 'ukernel_' is non-const and globally accessible, consider making it const
(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)
[error] 72-72: invalid case style for variable 'ukernel_'
(readability-identifier-naming,-warnings-as-errors)
mllm/models/smollm3_3B/tokenization_smollm3.hpp
[error] 5-5: 'string' file not found
(clang-diagnostic-error)
[error] 21-21: do not declare C-style arrays, use std::array<> instead
(cppcoreguidelines-avoid-c-arrays,-warnings-as-errors)
[error] 25-25: constructor does not initialize these fields: prompt
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 29-29: variable 'no_think_template_str' is non-const and globally accessible, consider making it const
(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)
[error] 35-35: variable 'think_template_str' is non-const and globally accessible, consider making it const
(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)
[error] 53-53: constructor does not initialize these fields: bpe_
(cppcoreguidelines-pro-type-member-init,-warnings-as-errors)
[error] 69-69: method 'replaceAll' can be made static
(readability-convert-member-functions-to-static,-warnings-as-errors)
[error] 69-69: parameter name 's' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 69-69: 2 adjacent parameters of 'replaceAll' of similar type ('const int &') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 69-69: parameter name 'to' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 71-71: variable 'pos' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 79-79: variable 't' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 79-79: variable name 't' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 80-80: variable 'tm_' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 80-80: invalid case style for variable 'tm_'
(readability-identifier-naming,-warnings-as-errors)
[error] 82-82: variable 'oss' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 84-84: variable 'date_in_number' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 86-86: do not declare C-style arrays, use std::array<> instead
(cppcoreguidelines-avoid-c-arrays,-warnings-as-errors)
[error] 88-88: variable 'month' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 89-89: variable 'year' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 91-91: variable 'tpl' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 144-144: variable 'ret' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
examples/smollm3/main.cpp
[error] 1-1: 'fmt/core.h' file not found
(clang-diagnostic-error)
[error] 8-8: variable 'MLLM_MAIN' is non-const and globally accessible, consider making it const
(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)
[error] 8-8: invalid case style for variable 'MLLM_MAIN'
(readability-identifier-naming,-warnings-as-errors)
mllm/mllm.cpp
[error] 95-95: 3 adjacent parameters of 'load' of similar type are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 96-96: repeated branch in conditional chain
(bugprone-branch-clone,-warnings-as-errors)
mllm/backends/cpu/kernels/arm/linear/kai.cpp
[error] 4-4: 'limits' file not found
(clang-diagnostic-error)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-android
- GitHub Check: build-macos
🔇 Additional comments (7)
tests/cpu/MllmBlasArmSgemvKernelTest.hpp (2)
22-25: LGTM! Improved test coverage and corrected output shape.
The changes improve the test in two ways:
- Random tensor initialization with range [-1, 1] provides better coverage than fixed values
- The DST shape correction from `{1, D}` to `{1, S}` properly reflects the matrix multiplication result: `{1, D} @ {D, S} = {1, S}`
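The shape rule `{1, D} @ {D, S} = {1, S}` can be made concrete with a small reference GEMV. This is a sketch with a hypothetical helper name (`gemvRowVector`) operating on `std::vector<float>`; it is not the kernel under test.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Multiplies a row vector x of shape {1, D} by a row-major matrix w of
// shape {D, S}; the result has S elements, i.e. shape {1, S}.
std::vector<float> gemvRowVector(const std::vector<float>& x, const std::vector<float>& w, int D, int S) {
  assert(x.size() == static_cast<std::size_t>(D));
  assert(w.size() == static_cast<std::size_t>(D) * S);
  std::vector<float> y(static_cast<std::size_t>(S), 0.0F);
  for (int s = 0; s < S; ++s) {
    float sum = 0.0F;
    for (int d = 0; d < D; ++d) { sum += x[d] * w[d * S + s]; }
    y[s] = sum;
  }
  return y;
}
```

The inner dimension `D` is consumed by the sum, which is why an output buffer shaped `{1, D}` would only be correct when `D` happens to equal `S`.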
35-35: LGTM! Consistent shape correction.
The DSTP shape now correctly matches DST's `{1, S}` shape, ensuring both the baseline and optimized implementations are tested with the proper output dimensions.
tests/cpu/MllmBlasArmSgemmKernelTest.hpp (2)
27-44: Good approach: manual CPU reference for BLAS testing.
Replacing the BLAS matmul call with an explicit nested-loop implementation provides a clear, simple reference to validate the BLAS kernels against. The logic correctly computes the matrix product for row-major layout.
71-88: Correctly implements transposed matrix multiply.
The manual computation properly handles the transposed B matrix by accessing `b_ptr[j * K + k]` instead of `b_ptr[k * N + j]`, which correctly represents B^T[k,j] in row-major layout.
mllm/models/smollm3_3B/tokenization_smollm3.hpp (1)
127-128: Clarify the comment about regex processing.
Line 128 states "No need to Regex:" and then creates a single-element initializer list. This suggests regex processing was intentionally skipped or is a placeholder. Consider clarifying why regex processing is not needed here, or if this is temporary code that should be implemented.
Based on the GPT2_EXPR defined at Line 21 and the `byteLevelPreTokenizer` using `unicode_regex_split` at Line 164, it appears regex processing should occur. Please verify if this is intentional or if regex-based splitting should be applied here before byte-level pre-tokenization.
examples/CMakeLists.txt (1)
10-10: LGTM!
The addition of the `smollm3` subdirectory properly integrates the new SmolLM3 example into the build system.
mllm/mllm.cpp (1)
95-100: LGTM! The mmap parameter is properly propagated.
The addition of the `mmap` parameter to the `load` function signature and its propagation to the underlying `ParameterFileIOImpl::read` calls is clean and consistent. This enables optional memory-mapped file loading for better performance.
Note: The static analysis warning about "easily swappable parameters" is a minor concern. If you want to make the API more robust against parameter swapping, consider using an options struct or strong types in the future, but this is not critical for the current change.
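The options-struct idea mentioned in the note could be sketched as follows. All names here (`ModelFileVersion`, `DeviceKind`, `LoadOptions`, `LoadRequest`, `makeLoadRequest`) are hypothetical and only illustrate the pattern; they are not mllm's actual types or its `load` signature.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical strong types replacing adjacent, easily swapped parameters.
enum class ModelFileVersion { kV1, kV2 };
enum class DeviceKind { kCPU };

// Named fields with defaults: call sites set only what they need, and a
// version can no longer be passed where a device or a flag is expected.
struct LoadOptions {
  ModelFileVersion version = ModelFileVersion::kV1;
  DeviceKind device = DeviceKind::kCPU;
  bool use_mmap = false;
};

// What a loader would receive; stands in for the real load() call.
struct LoadRequest {
  std::string file_name;
  LoadOptions options;
};

LoadRequest makeLoadRequest(std::string file_name, const LoadOptions& options = LoadOptions{}) {
  return LoadRequest{std::move(file_name), options};
}
```

A caller would write `LoadOptions opts; opts.use_mmap = true;` and pass `opts`, which documents intent at the call site better than a bare trailing `true`.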
When bias is null, allocate and explicitly zero-initialize the bias array to ensure correct behavior during offline packing.

fix(cpu): use transposed weight dimensions for packing calculations
Corrected the dimension parameters passed to quant_pack_rhs_size and quant_pack_rhs_offline to use transposed weight tensor sizes.

feat(core): add new MllmBlas KAI SGEMM implementation types
Registered new linear implementation types for KAI-based SGEMM with NEON and SME backends in both NT/NT and NT/T configurations.

feat(engine): conditionally include Android performance hints
Wrapped Android-specific headers in preprocessor guards to avoid build errors on non-Android platforms.

refactor(models): mark template strings as const in SmolLM3 tokenizer
Changed static inline string variables to be explicitly const to enforce immutability and improve code clarity.
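The null-bias handling in the first commit message above can be pictured as a small helper that always produces a valid bias buffer before packing. This is a sketch with a hypothetical name (`biasOrZeros`); the actual change operates on the kernel's own buffers, not on `std::vector`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Returns a bias buffer of length n: a copy of `bias` when one is provided,
// otherwise explicit zeros, so a packing routine never reads uninitialized
// memory when the layer has no bias.
std::vector<float> biasOrZeros(const float* bias, std::size_t n) {
  std::vector<float> out(n, 0.0F);
  if (bias != nullptr) { std::copy(bias, bias + n, out.begin()); }
  return out;
}
```

Adding a zero bias leaves the matmul result unchanged, which is why zero-filling is a safe default for layers without one.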
feat(cpu): add KAI SGEMM NEON implementation for ARM
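- Introduce `KaiLinear_fp32_fp32_fp32p_mxk_kxn` kernel for fp32 GEMM on ARM NEON
- Add new linear implementation types: `kMllmBlas_KAI_SGEMM_NT_NT_NEON` and `kMllmBlas_KAI_SGEMM_NT_T_SME`
- Update CMake options with Android performance hints and profiling components
- Enhance ParameterFile loading with optional mmap support
- Refactor matmul tests to include manual reference computation
- Add Android performance hint headers for future optimizations

This commit enables high-performance fp32 linear operations on ARM CPUs using KAI kernels, provides better control over memory mapping during model loading, and improves test coverage for BLAS-like operations.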
feat(examples): add smollm3 example with tokenizer and chat template