Metal backend: SDPA metal implementation #16086
base: main
Conversation
Stack from ghstack (oldest at bottom):
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16086
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 2 Unrelated Failures as of commit a9108f8 with merge base c00d726
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
Pull request overview
This pull request replaces the MPSGraph-based implementation of Scaled Dot Product Attention (SDPA) with a custom Metal kernel implementation, ported from PyTorch and influenced by MLX.
Key Changes
- Custom Metal kernel: Implements a one-pass SDPA algorithm embedded as a 200+ line inline shader with template instantiations for float, half, and bfloat types across head dimensions of 64, 96, and 128
- Enhanced Metal API: Adds new setArg overloads for uint32_t, float, bool, and uint3 types, plus a new dispatchThreadgroups method for explicit threadgroup dispatch (a raw-Metal usage sketch follows below)
- Stride-aware computation: The new kernel handles transposed tensor layouts by decomposing batch and head indices and using explicit stride information
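The new wrappers presumably forward to Metal's compute-encoder API. A minimal sketch of the equivalent raw calls, assuming an already-configured MTLComputeCommandEncoder named encoder and host-side scalars chosen only for illustration; the actual wrapper signatures in this PR may differ:

```objc
// Illustrative only: raw Metal calls that setArg/dispatchThreadgroups-style wrappers
// would typically issue. Buffer indices mirror the shader's [[buffer(n)]] slots for
// gqa_factor (4), scale (8), and has_mask (11).
uint32_t gqa_factor = 1;    // assumed host-side values for the sketch
float scale = 0.125f;
bool has_mask = false;
[encoder setBytes:&gqa_factor length:sizeof(uint32_t) atIndex:4];
[encoder setBytes:&scale length:sizeof(float) atIndex:8];
[encoder setBytes:&has_mask length:sizeof(bool) atIndex:11];
// One threadgroup per (batch*head, query token); 32 simdgroups x 32 threads = 1024.
[encoder dispatchThreadgroups:MTLSizeMake(batchSize * num_heads, qSize, 1)
        threadsPerThreadgroup:MTLSizeMake(1024, 1, 1)];
```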
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| backends/apple/metal/runtime/shims/et_metal_ops.mm | Replaces ~400 lines of MPSGraph code with inline Metal shader source and direct kernel dispatch; adds shader library caching |
| backends/apple/metal/runtime/shims/et_metal.mm | Implements new setArg overloads for scalar types and uint3 structs; adds dispatchThreadgroups for explicit threadgroup control |
| backends/apple/metal/runtime/shims/et_metal.h | Declares new Metal kernel function methods for argument setting and threadgroup dispatch |
| allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); | ||
| void* attn_contents_ptr = nullptr; | ||
| allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); | ||
| // Use MLX-style Metal kernels instead of MPSGraph | ||
| ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MLX Metal kernels"); | ||
| // Get shader library | ||
| ETMetalShaderLibrary* library = get_sdpa_shader_library(); | ||
| if (!library) { | ||
| ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get shader library"); | ||
| throw std::runtime_error("Failed to get SDPA shader library"); | ||
| } | ||
| // Determine kernel name based on dtype and head_dim (PyTorch format) | ||
| std::string type_name; | ||
| if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) { | ||
| type_name = "float"; | ||
| } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) { |
Copilot
AI
Dec 4, 2025
Bug: Memory leak on error paths. If allocate_mtl_buffer succeeds for attn_contents_ptr (line 1290) but then an exception is thrown before the tensor handle is created (e.g., at line 1299-1300, 1309-1310, 1316-1317, or 1326-1327), the allocated memory is never freed.
Additionally, if attn_contents_ptr allocation succeeds but there's an error, both out_contents_ptr and attn_contents_ptr need to be freed. Consider using RAII or ensuring all error paths properly clean up both allocations.
| allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); | |
| void* attn_contents_ptr = nullptr; | |
| allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); | |
| // Use MLX-style Metal kernels instead of MPSGraph | |
| ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MLX Metal kernels"); | |
| // Get shader library | |
| ETMetalShaderLibrary* library = get_sdpa_shader_library(); | |
| if (!library) { | |
| ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get shader library"); | |
| throw std::runtime_error("Failed to get SDPA shader library"); | |
| } | |
| // Determine kernel name based on dtype and head_dim (PyTorch format) | |
| std::string type_name; | |
| if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) { | |
| type_name = "float"; | |
| } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) { | |
| void* attn_contents_ptr = nullptr; | |
| try { | |
| allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); | |
| allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); | |
| // Use MLX-style Metal kernels instead of MPSGraph | |
| ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MLX Metal kernels"); | |
| // Get shader library | |
| ETMetalShaderLibrary* library = get_sdpa_shader_library(); | |
| if (!library) { | |
| ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get shader library"); | |
| throw std::runtime_error("Failed to get SDPA shader library"); | |
| } | |
| // Determine kernel name based on dtype and head_dim (PyTorch format) | |
| std::string type_name; | |
| if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) { | |
| type_name = "float"; | |
| } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) { | |
| type_name = "bfloat16"; | |
| } | |
| // ... rest of the code ... | |
| } catch (...) { | |
| if (out_contents_ptr) { | |
| free_mtl_buffer(out_contents_ptr); | |
| } | |
| if (attn_contents_ptr) { | |
| free_mtl_buffer(attn_contents_ptr); | |
| } | |
| throw; | |
| } |
| if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) { | ||
| type_name = "float"; | ||
| } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) { | ||
| type_name = "bfloat"; | ||
| } else { | ||
| ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported dtype for Metal kernel"); | ||
| throw std::runtime_error("Unsupported dtype for Metal SDPA kernel"); | ||
| } |
Copilot
AI
Dec 4, 2025
Bug: Missing support for half/float16 dtype. The shader instantiates kernels for half type (line 337), but the host code only handles FLOAT32 and BFLOAT16 dtypes (lines 1304-1310). If a float16/half tensor is passed, it will throw an "Unsupported dtype" error even though the kernel exists.
Either add support for half/float16 in the host code (checking for the appropriate SupportedDTypes enum value and using type_name = "half"), or remove the half instantiation from the shader if it's not intended to be supported.
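A minimal sketch of the first option, assuming the shim's enum exposes a float16 member (the enumerator name SupportedDTypes::FLOAT16 is a guess and should be checked against the shim's definitions):

```cpp
// Sketch: extend the dtype dispatch so half tensors reach the existing half kernels.
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
  type_name = "float";
} else if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT16)) {
  type_name = "half";   // matches INSTANTIATE_SDPA_VECTOR_HEADS(half) in the shader
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
  type_name = "bfloat";
} else {
  ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported dtype for Metal kernel");
  throw std::runtime_error("Unsupported dtype for Metal SDPA kernel");
}
```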
| // Dispatch using threadgroups (PyTorch uses grid: [batch*heads, qSize, 1], group: [1024, 1, 1]) | ||
| // Note: We need to use dispatchThreadgroups, not dispatchThreads | ||
| // Each threadgroup processes one query token across all key-value tokens | ||
| kernel_func->dispatchThreadgroups( | ||
| batchSize * num_heads, // gridX | ||
| qSize, // gridY | ||
| 1, // gridZ | ||
| 1024, // threadsX | ||
| 1, // threadsY | ||
| 1); // threadsZ |
Copilot
AI
Dec 4, 2025
Missing validation: The dispatch uses a hardcoded threadgroup size of 1024 threads (line 1484), but doesn't verify that the Metal device supports this. Different Metal devices have different maximum threadgroup sizes (typically 512-1024). The kernel should either:
- Query the device's maxTotalThreadsPerThreadgroup and use min(1024, maxThreads)
- Document that this implementation requires devices with >= 1024 threads per threadgroup
- Add a runtime check and throw an error if the device doesn't support 1024 threads
The kernel assumes BN=32 (line 165) which means 32 simdgroups with 32 threads each = 1024 total. If the device doesn't support this, the kernel will fail or produce incorrect results.
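A sketch of the runtime-check option, querying the compiled pipeline's limit before dispatch (pipeline_state and encoder are assumed to be the already-created MTLComputePipelineState and compute encoder):

```objc
// Sketch: verify the device/pipeline supports the 1024-thread threadgroup that the
// kernel's BN = 32 simdgroups x 32 threads layout assumes.
NSUInteger max_threads = [pipeline_state maxTotalThreadsPerThreadgroup];
if (max_threads < 1024) {
  ET_LOG(Error, "SDPA Metal kernel requires 1024 threads per threadgroup, device supports %lu",
         (unsigned long)max_threads);
  throw std::runtime_error("Device does not support 1024 threads per threadgroup for SDPA");
}
[encoder dispatchThreadgroups:MTLSizeMake(batchSize * num_heads, qSize, 1)
        threadsPerThreadgroup:MTLSizeMake(1024, 1, 1)];
```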
| return R"( | ||
| // Ported from PyTorch's Attention.metal | ||
| // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mps/kernels/Attention.metal | ||
| // Largely influeneced by |
Copilot
AI
Dec 4, 2025
Spelling error: "influeneced" should be "influenced"
| // Largely influeneced by | |
| // Largely influenced by |
| auto* out_tensor = reinterpret_cast<Tensor*>(out_tensor_handle); | ||
| // Prepare kernel arguments (PyTorch format) |
Copilot
AI
Dec 4, 2025
Bug: Potential division by zero if key_tensor->sizes()[1] is 0. The gqa_factor calculation num_heads / key_tensor->sizes()[1] will cause a division by zero error. This should be validated before the division, especially since the code validates other dimensions like headSize.
| // Prepare kernel arguments (PyTorch format) | |
| // Prepare kernel arguments (PyTorch format) | |
| if (key_tensor->sizes()[1] == 0) { | |
| ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: key_tensor->sizes()[1] is zero, division by zero error"); | |
| aoti_torch_mps_free(out_contents_ptr); | |
| aoti_torch_mps_free(attn_contents_ptr); | |
| throw std::runtime_error("Division by zero: key_tensor->sizes()[1] is zero"); | |
| } |
| uint mask_kv_seq_stride = 0; | ||
| uint mask_q_seq_stride = 0; | ||
| if (has_mask_val) { | ||
| auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask); |
Copilot
AI
Dec 4, 2025
Inconsistent indentation: Line 1419 has extra indentation (appears to use spaces instead of the surrounding code's indentation style). This should match the indentation of the surrounding code for consistency.
| auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask); | |
| auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask); |
| static std::string get_sdpa_metal_source() { | ||
| return R"( | ||
| // Ported from PyTorch's Attention.metal | ||
| // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mps/kernels/Attention.metal | ||
| // Largely influeneced by | ||
| // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scaled_dot_product_attention.metal | ||
| // Modified to support floating point masks and transposed middle dimensions (dims 1 & 2) | ||
| #include <metal_stdlib> | ||
| #include <metal_simdgroup> | ||
| #include <metal_math> | ||
| using namespace metal; | ||
| typedef half float16_t; | ||
| typedef bfloat bfloat16_t; | ||
| // PyTorch's sdpa_vector kernel (one-pass variant) | ||
| template <typename T, int D, int V = D> | ||
| [[kernel]] void sdpa_vector( | ||
| const device T* queries [[buffer(0)]], | ||
| const device T* keys [[buffer(1)]], | ||
| const device T* values [[buffer(2)]], | ||
| device T* out [[buffer(3)]], | ||
| constant uint& gqa_factor [[buffer(4)]], | ||
| constant uint& N [[buffer(5)]], | ||
| constant uint3& qkv_head_strides [[buffer(6)]], | ||
| constant uint3& qkv_seq_strides [[buffer(7)]], | ||
| constant float& scale [[buffer(8)]], | ||
| const device T* mask [[buffer(9)]], // Changed from bool* to T* for floating point masks | ||
| constant uint3& mask_strides [[buffer(10)]], | ||
| constant bool& has_mask [[buffer(11)]], | ||
| constant uint3& qkv_batch_strides [[buffer(12)]], // NEW: batch strides for Q, K, V | ||
| constant uint& num_q_heads [[buffer(13)]], // NEW: number of query heads | ||
| uint3 tid [[threadgroup_position_in_grid]], | ||
| uint3 tpg [[threadgroups_per_grid]], | ||
| uint simd_gid [[simdgroup_index_in_threadgroup]], | ||
| uint simd_lid [[thread_index_in_simdgroup]]) { | ||
| constexpr uint BN = 32; | ||
| constexpr uint BD = 32; | ||
| constexpr uint qk_per_thread = D / BD; | ||
| constexpr uint v_per_thread = V / BD; | ||
| const uint q_head_stride = qkv_head_strides.x; | ||
| const uint q_seq_stride = qkv_seq_strides.x; | ||
| const uint q_batch_stride = qkv_batch_strides.x; | ||
| const uint k_head_stride = qkv_head_strides.y; | ||
| const uint k_seq_stride = qkv_seq_strides.y; | ||
| const uint k_batch_stride = qkv_batch_strides.y; | ||
| const uint v_head_stride = qkv_head_strides.z; | ||
| const uint v_seq_stride = qkv_seq_strides.z; | ||
| const uint v_batch_stride = qkv_batch_strides.z; | ||
| const uint mask_head_stride = mask_strides.x; | ||
| const uint mask_kv_seq_stride = mask_strides.y; | ||
| const uint mask_q_seq_stride = mask_strides.z; | ||
| uint inner_k_stride = BN * int(k_seq_stride); | ||
| uint inner_v_stride = BN * int(v_seq_stride); | ||
| typedef float U; | ||
| thread U q[qk_per_thread]; | ||
| thread U k[qk_per_thread]; | ||
| thread U o[v_per_thread]; | ||
| threadgroup U outputs[BN * BD]; | ||
| threadgroup U max_scores[BN]; | ||
| threadgroup U sum_exp_scores[BN]; | ||
| // Adjust positions | ||
| const int head_idx = tid.x; // Flattened batch*heads index | ||
| const int q_seq_idx = tid.y; | ||
| // Decompose flattened head_idx into batch and head indices | ||
| const int batch_idx = head_idx / num_q_heads; | ||
| const int head_in_batch = head_idx % num_q_heads; | ||
| const int kv_head_idx = head_in_batch / gqa_factor; | ||
| const int Q = tpg.y; | ||
| const int group_offset = head_idx * Q + q_seq_idx; | ||
| const int o_offset = group_offset; | ||
| // Use decomposed indices with separate batch and head strides | ||
| queries += batch_idx * q_batch_stride + head_in_batch * q_head_stride + q_seq_idx * q_seq_stride + | ||
| simd_lid * qk_per_thread; | ||
| keys += batch_idx * k_batch_stride + kv_head_idx * k_head_stride + simd_gid * k_seq_stride + | ||
| simd_lid * qk_per_thread; | ||
| values += batch_idx * v_batch_stride + kv_head_idx * v_head_stride + simd_gid * v_seq_stride + | ||
| simd_lid * v_per_thread; | ||
| if (has_mask) { | ||
| mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + | ||
| q_seq_idx * mask_q_seq_stride; | ||
| } | ||
| out += o_offset * V + simd_gid * v_per_thread; | ||
| // Read the query and 0 the output accumulator | ||
| for (uint i = 0; i < qk_per_thread; i++) { | ||
| q[i] = scale * static_cast<U>(queries[i]); | ||
| } | ||
| for (uint i = 0; i < v_per_thread; i++) { | ||
| o[i] = 0; | ||
| } | ||
| U max_score = -INFINITY; | ||
| U sum_exp_score = 0; | ||
| // For each key | ||
| for (uint i = simd_gid; i < N; i += BN) { | ||
| // Check mask: for floating point masks, values > -1e9 are considered valid (not masked) | ||
| // Masked positions typically have -inf or very negative values | ||
| const bool is_valid = !has_mask || (static_cast<U>(mask[0]) > -1e9f); | ||
| if (is_valid) { | ||
| // Read the key | ||
| for (uint j = 0; j < qk_per_thread; j++) { | ||
| k[j] = static_cast<U>(keys[j]); | ||
| } | ||
| // Compute the i-th score | ||
| U score = 0; | ||
| for (uint j = 0; j < qk_per_thread; j++) { | ||
| score += q[j] * k[j]; | ||
| } | ||
| score = simd_sum(score); | ||
| // Add mask value to score if mask is present | ||
| if (has_mask) { | ||
| score += static_cast<U>(mask[0]); | ||
| } | ||
| // Update the accumulators | ||
| U new_max = max(max_score, score); | ||
| U factor = metal::fast::exp(max_score - new_max); | ||
| U exp_score = metal::fast::exp(score - new_max); | ||
| max_score = new_max; | ||
| sum_exp_score = sum_exp_score * factor + exp_score; | ||
| // Update the output accumulator | ||
| for (uint j = 0; j < v_per_thread; j++) { | ||
| o[j] = o[j] * factor + exp_score * static_cast<U>(values[j]); | ||
| } | ||
| } | ||
| // Move the pointers to the next kv | ||
| keys += inner_k_stride; | ||
| values += inner_v_stride; | ||
| if (has_mask) { | ||
| mask += BN * mask_kv_seq_stride; | ||
| } | ||
| } | ||
| // Each thread has a partial part of the output so we need to combine them. | ||
| // First let's communicate the max and sum_exp | ||
| if (simd_lid == 0) { | ||
| max_scores[simd_gid] = max_score; | ||
| sum_exp_scores[simd_gid] = sum_exp_score; | ||
| } | ||
| threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| max_score = max_scores[simd_lid]; | ||
| U new_max = simd_max(max_score); | ||
| U factor = metal::fast::exp(max_score - new_max); | ||
| sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); | ||
| // Now we need to aggregate all the outputs | ||
| for (uint i = 0; i < v_per_thread; i++) { | ||
| outputs[simd_lid * BD + simd_gid] = o[i]; | ||
| threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| const U safe_sum = (sum_exp_score == 0 ? 1e-6f : sum_exp_score); | ||
| o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / safe_sum; | ||
| threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| } | ||
| // And write the output | ||
| if (simd_lid == 0) { | ||
| for (uint i = 0; i < v_per_thread; i++) { | ||
| out[i] = static_cast<T>(o[i]); | ||
| } | ||
| } | ||
| } | ||
| #define INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM) \ | ||
| template [[host_name("sdpa_vector_" #DTYPE "_" #QK_DIM \ | ||
| "_" #VALUE_DIM)]] kernel void \ | ||
| sdpa_vector<DTYPE, QK_DIM, VALUE_DIM>( \ | ||
| const device DTYPE* queries [[buffer(0)]], \ | ||
| const device DTYPE* keys [[buffer(1)]], \ | ||
| const device DTYPE* values [[buffer(2)]], \ | ||
| device DTYPE* out [[buffer(3)]], \ | ||
| constant uint& gqa_factor [[buffer(4)]], \ | ||
| constant uint& N [[buffer(5)]], \ | ||
| constant uint3& qkv_head_strides [[buffer(6)]], \ | ||
| constant uint3& qkv_seq_strides [[buffer(7)]], \ | ||
| constant float& scale [[buffer(8)]], \ | ||
| const device DTYPE* mask [[buffer(9)]], \ | ||
| constant uint3& mask_strides [[buffer(10)]], \ | ||
| constant bool& has_mask [[buffer(11)]], \ | ||
| constant uint3& qkv_batch_strides [[buffer(12)]], \ | ||
| constant uint& num_q_heads [[buffer(13)]], \ | ||
| uint3 tid [[threadgroup_position_in_grid]], \ | ||
| uint3 tpg [[threadgroups_per_grid]], \ | ||
| uint simd_gid [[simdgroup_index_in_threadgroup]], \ | ||
| uint simd_lid [[thread_index_in_simdgroup]]); | ||
| #define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \ | ||
| INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \ | ||
| INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \ | ||
| INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); | ||
| INSTANTIATE_SDPA_VECTOR_HEADS(float); | ||
| INSTANTIATE_SDPA_VECTOR_HEADS(half); | ||
| INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); | ||
| )"; | ||
| } |
Copilot
AI
Dec 4, 2025
Maintainability concern: The Metal shader source code is embedded as a 200+ line raw string literal within the C++ source file. This makes the shader code harder to maintain because:
- No syntax highlighting for Metal shading language
- Difficult to debug shader compilation errors (line numbers won't match)
- Increases compilation time for the C++ file
- Makes version control diffs harder to read
Consider moving the shader source to a separate .metal file and either:
- Using a build-time tool to embed it as a string constant
- Loading it at runtime from a resource file
- Using Metal's offline shader compilation if available
This is especially important for a 200+ line shader that may need future modifications.
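As one example of the offline route, a prebuilt .metallib could be loaded at runtime instead of compiling the inline string; a rough sketch, with the library path and packaging treated as assumptions:

```objc
// Sketch: load an offline-compiled Metal library rather than JIT-compiling source.
NSError* error = nil;
NSURL* lib_url = [NSURL fileURLWithPath:@"/path/to/sdpa_kernels.metallib"];  // assumed path
id<MTLLibrary> mtl_library = [device newLibraryWithURL:lib_url error:&error];
if (!mtl_library) {
  throw std::runtime_error("Failed to load precompiled SDPA metallib");
}
id<MTLFunction> fn = [mtl_library newFunctionWithName:@"sdpa_vector_float_128_128"];
```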
| // Check mask: for floating point masks, values > -1e9 are considered valid (not masked) | ||
| // Masked positions typically have -inf or very negative values | ||
| const bool is_valid = !has_mask || (static_cast<U>(mask[0]) > -1e9f); | ||
| if (is_valid) { | ||
| // Read the key | ||
| for (uint j = 0; j < qk_per_thread; j++) { | ||
| k[j] = static_cast<U>(keys[j]); | ||
| } | ||
| // Compute the i-th score | ||
| U score = 0; | ||
| for (uint j = 0; j < qk_per_thread; j++) { | ||
| score += q[j] * k[j]; | ||
| } | ||
| score = simd_sum(score); | ||
| // Add mask value to score if mask is present | ||
| if (has_mask) { | ||
| score += static_cast<U>(mask[0]); | ||
| } |
Copilot
AI
Dec 4, 2025
Missing documentation: The mask handling logic is subtle and could benefit from better documentation. The code:
- Checks if mask[0] > -1e9f to determine validity (line 236)
- If valid, adds mask[0] to the score (line 253)
This implements an optimization where extremely negative mask values (like -inf) skip processing entirely, while other mask values (including moderately negative ones) are added to scores. The comment should clarify:
- Why -1e9f is the threshold (vs -inf)
- That this is an optimization to skip computation for masked-out positions
- How this differs from standard additive masking where all mask values are added
This will help future maintainers understand the intended behavior and avoid "fixing" this seemingly redundant check.
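A sketch of how the shader comment might be expanded (wording is a suggestion, not the PR's text):

```metal
// Additive floating-point mask. Positions whose mask value is <= -1e9 are treated as
// fully masked out: their softmax contribution would be negligible, so the key read
// and score computation are skipped entirely as an optimization. The -1e9 threshold
// (rather than a strict -INFINITY check) also catches the large negative fill values
// frameworks commonly use in place of -inf. All other positions use standard
// additive masking: the mask value is simply added to the score.
const bool is_valid = !has_mask || (static_cast<U>(mask[0]) > -1e9f);
```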
| typedef half float16_t; | ||
| typedef bfloat bfloat16_t; |
Copilot
AI
Dec 4, 2025
Missing documentation: The type aliases float16_t and bfloat16_t are defined but only bfloat is actually used in the instantiations (line 338). The half type is used directly rather than float16_t. Consider either:
- Using the aliases consistently (use float16_t instead of half)
- Removing the unused float16_t typedef
- Adding a comment explaining why float16_t is defined but not used
This inconsistency could cause confusion for future maintainers.
| typedef half float16_t; | |
| typedef bfloat bfloat16_t; |
| void* out_contents_ptr = nullptr; | ||
| allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); | ||
| void* attn_contents_ptr = nullptr; | ||
| allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); |
Copilot
AI
Dec 4, 2025
Bug: Memory leak on error paths. If allocate_mtl_buffer succeeds for out_contents_ptr (line 1287) but then an exception is thrown before the tensor handle is created (e.g., at line 1299-1300, 1309-1310, 1316-1317, or 1326-1327), the allocated memory is never freed.
The code should free out_contents_ptr if any error occurs after its allocation. Consider adding proper cleanup using RAII (e.g., unique_ptr with custom deleter) or ensure all error paths after line 1287 call aoti_torch_mps_free(out_contents_ptr).
| void* out_contents_ptr = nullptr; | |
| allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); | |
| void* attn_contents_ptr = nullptr; | |
| allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); | |
| // Use RAII to manage allocated memory and avoid leaks on error paths | |
| struct MTLBufferDeleter { | |
| void operator()(void* ptr) const { | |
| if (ptr) { | |
| aoti_torch_mps_free(ptr); | |
| } | |
| } | |
| }; | |
| std::unique_ptr<void, MTLBufferDeleter> out_contents_ptr_raii; | |
| void* out_contents_ptr = nullptr; | |
| allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); | |
| out_contents_ptr_raii.reset(out_contents_ptr); | |
| std::unique_ptr<void, MTLBufferDeleter> attn_contents_ptr_raii; | |
| void* attn_contents_ptr = nullptr; | |
| allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); | |
| attn_contents_ptr_raii.reset(attn_contents_ptr); |
mergennachin left a comment:
- Are you ignoring "is_causal" altogether? (`int32_t is_causal,`)
- Can we compile the metal shader at build time? Isn't it JIT compiling?
| if (head_dim != 64 && head_dim != 96 && head_dim != 128) { | ||
| ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim); | ||
| throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128"); | ||
| } |
What's the main reason for only limiting to these head sizes?
| U factor = metal::fast::exp(max_score - new_max); | ||
| U exp_score = metal::fast::exp(score - new_max); |
Nitpick/micro-optimization: Can you save on exponentiation if max_score == score?
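One possible shape of that saving: new_max always equals either max_score or score, so one of the two exp() arguments is zero and that call can be skipped (a sketch only; whether the branch beats the extra exp on GPU would need measuring):

```metal
// Sketch: avoid one exp() per score, since exp(0) == 1 in one branch or the other.
U new_max, factor, exp_score;
if (score > max_score) {
  new_max = score;
  factor = metal::fast::exp(max_score - new_max);
  exp_score = 1.0f;                      // exp(score - new_max) == exp(0)
} else {
  new_max = max_score;
  factor = 1.0f;                         // exp(max_score - new_max) == exp(0)
  exp_score = metal::fast::exp(score - new_max);
}
```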
Replaces the MPSGraph-based SDPA implementation with a custom Metal kernel implementation (adapted from the MLX implementation, with several modifications to support transposed middle dimensions and floating-point attention masks).
Speeds up Voxtral/Whisper by 2-3x.
Fixes the BFloat16 issue on macOS 26.1.