Conversation

@manuelcandales (Contributor) commented Dec 4, 2025

Replaces the MPSGraph-based SDPA implementation with a Metal kernel implementation (adapted from the MLX implementation, with several modifications to support transposed middle dimensions and floating-point attention masks).

Speeds up voxtral/whisper by 2-3x

Fixes BFloat16 issue on macOS 26.1

[ghstack-poisoned]
@manuelcandales (Contributor, Author) commented Dec 4, 2025

Stack from ghstack (oldest at bottom):

Copilot AI review requested due to automatic review settings December 4, 2025 21:11
pytorch-bot bot commented Dec 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16086

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit a9108f8 with merge base c00d726:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

manuelcandales added a commit that referenced this pull request Dec 4, 2025
ghstack-source-id: aa77e4b
ghstack-comment-id: 3614336034
Pull-Request: #16086
@meta-cla bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Dec 4, 2025
github-actions bot commented Dec 4, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot finished reviewing on behalf of manuelcandales December 4, 2025 21:17
Copilot AI (Contributor) left a comment

Pull request overview

This pull request replaces the MPSGraph-based implementation of Scaled Dot Product Attention (SDPA) with a custom Metal kernel implementation, ported from PyTorch and influenced by MLX.

Key Changes

  • Custom Metal kernel: Implements a one-pass SDPA algorithm embedded as a 200+ line inline shader with template instantiations for float, half, and bfloat types across head dimensions of 64, 96, and 128
  • Enhanced Metal API: Adds new setArg overloads for uint32_t, float, bool, and uint3 types, plus a new dispatchThreadgroups method for explicit threadgroup dispatch (a sketch follows this list)
  • Stride-aware computation: The new kernel handles transposed tensor layouts by decomposing batch and head indices and using explicit stride information
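
A minimal sketch of what the new encoder helpers could wrap, assuming they sit on top of the standard Metal APIs setBytes:length:atIndex: and dispatchThreadgroups:threadsPerThreadgroup: (the helper names and signatures here are illustrative, not taken from the diff):

#import <Metal/Metal.h>

// Hypothetical free-function equivalents of the new ETMetalKernelFunction methods.
static void set_scalar_arg(id<MTLComputeCommandEncoder> enc, NSUInteger index, uint32_t value) {
  // Small POD arguments can be passed inline via setBytes; no MTLBuffer is required.
  [enc setBytes:&value length:sizeof(value) atIndex:index];
}

static void dispatch_threadgroups(id<MTLComputeCommandEncoder> enc,
                                  NSUInteger gx, NSUInteger gy, NSUInteger gz,
                                  NSUInteger tx, NSUInteger ty, NSUInteger tz) {
  // Explicit threadgroup dispatch: the caller fixes both the grid and the threadgroup
  // shape, unlike dispatchThreads, which lets Metal form partial threadgroups at the edge.
  [enc dispatchThreadgroups:MTLSizeMake(gx, gy, gz)
      threadsPerThreadgroup:MTLSizeMake(tx, ty, tz)];
}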

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 14 comments.

File and description of reviewed changes:

  • backends/apple/metal/runtime/shims/et_metal_ops.mm: Replaces ~400 lines of MPSGraph code with inline Metal shader source and direct kernel dispatch; adds shader library caching (a sketch of the caching helper follows this list)
  • backends/apple/metal/runtime/shims/et_metal.mm: Implements new setArg overloads for scalar types and uint3 structs; adds dispatchThreadgroups for explicit threadgroup control
  • backends/apple/metal/runtime/shims/et_metal.h: Declares new Metal kernel function methods for argument setting and threadgroup dispatch
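
The shader-library caching mentioned for et_metal_ops.mm might look roughly like the sketch below; the ETMetalShaderLibrary constructor taking a source string is an assumption for illustration, not copied from the diff.

#include <mutex>

// Hedged sketch: compile the inline SDPA shader source once and reuse the
// cached library on every subsequent SDPA call.
static ETMetalShaderLibrary* get_sdpa_shader_library() {
  static ETMetalShaderLibrary* library = nullptr;
  static std::once_flag once;
  std::call_once(once, [] {
    // get_sdpa_metal_source() returns the inline Metal source quoted later in this review.
    library = new ETMetalShaderLibrary(get_sdpa_metal_source());
  });
  return library;
}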


Comment on lines +1287 to +1306
allocate_mtl_buffer(&out_contents_ptr, out_size_bytes);

void* attn_contents_ptr = nullptr;
allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes);

// Use MLX-style Metal kernels instead of MPSGraph
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MLX Metal kernels");

// Get shader library
ETMetalShaderLibrary* library = get_sdpa_shader_library();
if (!library) {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get shader library");
throw std::runtime_error("Failed to get SDPA shader library");
}

// Determine kernel name based on dtype and head_dim (PyTorch format)
std::string type_name;
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
type_name = "float";
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
Copilot AI commented Dec 4, 2025

Bug: Memory leak on error paths. If allocate_mtl_buffer succeeds for attn_contents_ptr (line 1290) but then an exception is thrown before the tensor handle is created (e.g., at line 1299-1300, 1309-1310, 1316-1317, or 1326-1327), the allocated memory is never freed.

Additionally, if attn_contents_ptr allocation succeeds but there's an error, both out_contents_ptr and attn_contents_ptr need to be freed. Consider using RAII or ensuring all error paths properly clean up both allocations.

Suggested change
allocate_mtl_buffer(&out_contents_ptr, out_size_bytes);
void* attn_contents_ptr = nullptr;
allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes);
// Use MLX-style Metal kernels instead of MPSGraph
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MLX Metal kernels");
// Get shader library
ETMetalShaderLibrary* library = get_sdpa_shader_library();
if (!library) {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get shader library");
throw std::runtime_error("Failed to get SDPA shader library");
}
// Determine kernel name based on dtype and head_dim (PyTorch format)
std::string type_name;
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
type_name = "float";
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
void* attn_contents_ptr = nullptr;
try {
allocate_mtl_buffer(&out_contents_ptr, out_size_bytes);
allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes);
// Use MLX-style Metal kernels instead of MPSGraph
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MLX Metal kernels");
// Get shader library
ETMetalShaderLibrary* library = get_sdpa_shader_library();
if (!library) {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get shader library");
throw std::runtime_error("Failed to get SDPA shader library");
}
// Determine kernel name based on dtype and head_dim (PyTorch format)
std::string type_name;
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
type_name = "float";
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
type_name = "bfloat16";
}
// ... rest of the code ...
} catch (...) {
if (out_contents_ptr) {
free_mtl_buffer(out_contents_ptr);
}
if (attn_contents_ptr) {
free_mtl_buffer(attn_contents_ptr);
}
throw;
}

Comment on lines +1304 to +1311
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
type_name = "float";
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
type_name = "bfloat";
} else {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported dtype for Metal kernel");
throw std::runtime_error("Unsupported dtype for Metal SDPA kernel");
}
Copilot AI commented Dec 4, 2025

Bug: Missing support for half/float16 dtype. The shader instantiates kernels for half type (line 337), but the host code only handles FLOAT32 and BFLOAT16 dtypes (lines 1304-1310). If a float16/half tensor is passed, it will throw an "Unsupported dtype" error even though the kernel exists.

Either add support for half/float16 in the host code (checking for the appropriate SupportedDTypes enum value and using type_name = "half"), or remove the half instantiation from the shader if it's not intended to be supported.
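
If the half path is meant to be supported, the extra branch might look roughly like this (SupportedDTypes::FLOAT16 is an assumed enum member name, not confirmed by this diff):

std::string type_name;
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
  type_name = "float";
} else if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT16)) {  // assumed enum member
  type_name = "half";  // matches INSTANTIATE_SDPA_VECTOR_HEADS(half) in the shader
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
  type_name = "bfloat";
} else {
  throw std::runtime_error("Unsupported dtype for Metal SDPA kernel");
}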

Comment on lines +1477 to +1486
// Dispatch using threadgroups (PyTorch uses grid: [batch*heads, qSize, 1], group: [1024, 1, 1])
// Note: We need to use dispatchThreadgroups, not dispatchThreads
// Each threadgroup processes one query token across all key-value tokens
kernel_func->dispatchThreadgroups(
batchSize * num_heads, // gridX
qSize, // gridY
1, // gridZ
1024, // threadsX
1, // threadsY
1); // threadsZ
Copilot AI commented Dec 4, 2025

Missing validation: The dispatch uses a hardcoded threadgroup size of 1024 threads (line 1484), but doesn't verify that the Metal device supports this. Different Metal devices have different maximum threadgroup sizes (typically 512-1024). The kernel should either:

  1. Query the device's maxTotalThreadsPerThreadgroup and use min(1024, maxThreads)
  2. Document that this implementation requires devices with >= 1024 threads per threadgroup
  3. Add a runtime check and throw an error if the device doesn't support 1024 threads

The kernel assumes BN=32 (line 165) which means 32 simdgroups with 32 threads each = 1024 total. If the device doesn't support this, the kernel will fail or produce incorrect results.
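
A rough sketch of the runtime check (options 1 and 3 combined); `pipeline` stands for the compiled id<MTLComputePipelineState>, which this snippet assumes is reachable at the dispatch site:

// Hedged sketch: refuse to dispatch if the pipeline cannot run 1024-thread threadgroups.
NSUInteger max_threads = pipeline.maxTotalThreadsPerThreadgroup;
if (max_threads < 1024) {
  ET_LOG(Error,
         "sdpa_vector needs 1024 threads per threadgroup; this pipeline supports only %lu",
         (unsigned long)max_threads);
  throw std::runtime_error("Metal device cannot run the 1024-thread threadgroups required by the SDPA kernel");
}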

return R"(
// Ported from PyTorch's Attention.metal
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mps/kernels/Attention.metal
// Largely influeneced by
Copilot AI commented Dec 4, 2025

Spelling error: "influeneced" should be "influenced"

Suggested change
// Largely influeneced by
// Largely influenced by


auto* out_tensor = reinterpret_cast<Tensor*>(out_tensor_handle);

// Prepare kernel arguments (PyTorch format)
Copilot AI commented Dec 4, 2025

Bug: Potential division by zero if key_tensor->sizes()[1] is 0. The gqa_factor calculation num_heads / key_tensor->sizes()[1] will cause a division by zero error. This should be validated before the division, especially since the code validates other dimensions like headSize.

Suggested change
// Prepare kernel arguments (PyTorch format)
// Prepare kernel arguments (PyTorch format)
if (key_tensor->sizes()[1] == 0) {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: key_tensor->sizes()[1] is zero, division by zero error");
aoti_torch_mps_free(out_contents_ptr);
aoti_torch_mps_free(attn_contents_ptr);
throw std::runtime_error("Division by zero: key_tensor->sizes()[1] is zero");
}

uint mask_kv_seq_stride = 0;
uint mask_q_seq_stride = 0;
if (has_mask_val) {
auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);
Copilot AI commented Dec 4, 2025

Inconsistent indentation: Line 1419 has extra indentation (appears to use spaces instead of the surrounding code's indentation style). This should match the indentation of the surrounding code for consistency.

Suggested change
auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);
auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);

Comment on lines +127 to +340
static std::string get_sdpa_metal_source() {
return R"(
// Ported from PyTorch's Attention.metal
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mps/kernels/Attention.metal
// Largely influeneced by
// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
// Modified to support floating point masks and transposed middle dimensions (dims 1 & 2)
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_math>
using namespace metal;
typedef half float16_t;
typedef bfloat bfloat16_t;
// PyTorch's sdpa_vector kernel (one-pass variant)
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device T* out [[buffer(3)]],
constant uint& gqa_factor [[buffer(4)]],
constant uint& N [[buffer(5)]],
constant uint3& qkv_head_strides [[buffer(6)]],
constant uint3& qkv_seq_strides [[buffer(7)]],
constant float& scale [[buffer(8)]],
const device T* mask [[buffer(9)]], // Changed from bool* to T* for floating point masks
constant uint3& mask_strides [[buffer(10)]],
constant bool& has_mask [[buffer(11)]],
constant uint3& qkv_batch_strides [[buffer(12)]], // NEW: batch strides for Q, K, V
constant uint& num_q_heads [[buffer(13)]], // NEW: number of query heads
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr uint BN = 32;
constexpr uint BD = 32;
constexpr uint qk_per_thread = D / BD;
constexpr uint v_per_thread = V / BD;
const uint q_head_stride = qkv_head_strides.x;
const uint q_seq_stride = qkv_seq_strides.x;
const uint q_batch_stride = qkv_batch_strides.x;
const uint k_head_stride = qkv_head_strides.y;
const uint k_seq_stride = qkv_seq_strides.y;
const uint k_batch_stride = qkv_batch_strides.y;
const uint v_head_stride = qkv_head_strides.z;
const uint v_seq_stride = qkv_seq_strides.z;
const uint v_batch_stride = qkv_batch_strides.z;
const uint mask_head_stride = mask_strides.x;
const uint mask_kv_seq_stride = mask_strides.y;
const uint mask_q_seq_stride = mask_strides.z;
uint inner_k_stride = BN * int(k_seq_stride);
uint inner_v_stride = BN * int(v_seq_stride);
typedef float U;
thread U q[qk_per_thread];
thread U k[qk_per_thread];
thread U o[v_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.x; // Flattened batch*heads index
const int q_seq_idx = tid.y;
// Decompose flattened head_idx into batch and head indices
const int batch_idx = head_idx / num_q_heads;
const int head_in_batch = head_idx % num_q_heads;
const int kv_head_idx = head_in_batch / gqa_factor;
const int Q = tpg.y;
const int group_offset = head_idx * Q + q_seq_idx;
const int o_offset = group_offset;
// Use decomposed indices with separate batch and head strides
queries += batch_idx * q_batch_stride + head_in_batch * q_head_stride + q_seq_idx * q_seq_stride +
simd_lid * qk_per_thread;
keys += batch_idx * k_batch_stride + kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread;
values += batch_idx * v_batch_stride + kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
out += o_offset * V + simd_gid * v_per_thread;
// Read the query and 0 the output accumulator
for (uint i = 0; i < qk_per_thread; i++) {
q[i] = scale * static_cast<U>(queries[i]);
}
for (uint i = 0; i < v_per_thread; i++) {
o[i] = 0;
}
U max_score = -INFINITY;
U sum_exp_score = 0;
// For each key
for (uint i = simd_gid; i < N; i += BN) {
// Check mask: for floating point masks, values > -1e9 are considered valid (not masked)
// Masked positions typically have -inf or very negative values
const bool is_valid = !has_mask || (static_cast<U>(mask[0]) > -1e9f);
if (is_valid) {
// Read the key
for (uint j = 0; j < qk_per_thread; j++) {
k[j] = static_cast<U>(keys[j]);
}
// Compute the i-th score
U score = 0;
for (uint j = 0; j < qk_per_thread; j++) {
score += q[j] * k[j];
}
score = simd_sum(score);
// Add mask value to score if mask is present
if (has_mask) {
score += static_cast<U>(mask[0]);
}
// Update the accumulators
U new_max = max(max_score, score);
U factor = metal::fast::exp(max_score - new_max);
U exp_score = metal::fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (uint j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor + exp_score * static_cast<U>(values[j]);
}
}
// Move the pointers to the next kv
keys += inner_k_stride;
values += inner_v_stride;
if (has_mask) {
mask += BN * mask_kv_seq_stride;
}
}
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
if (simd_lid == 0) {
max_scores[simd_gid] = max_score;
sum_exp_scores[simd_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = max_scores[simd_lid];
U new_max = simd_max(max_score);
U factor = metal::fast::exp(max_score - new_max);
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
// Now we need to aggregate all the outputs
for (uint i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
const U safe_sum = (sum_exp_score == 0 ? 1e-6f : sum_exp_score);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / safe_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// And write the output
if (simd_lid == 0) {
for (uint i = 0; i < v_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}
#define INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM) \
template [[host_name("sdpa_vector_" #DTYPE "_" #QK_DIM \
"_" #VALUE_DIM)]] kernel void \
sdpa_vector<DTYPE, QK_DIM, VALUE_DIM>( \
const device DTYPE* queries [[buffer(0)]], \
const device DTYPE* keys [[buffer(1)]], \
const device DTYPE* values [[buffer(2)]], \
device DTYPE* out [[buffer(3)]], \
constant uint& gqa_factor [[buffer(4)]], \
constant uint& N [[buffer(5)]], \
constant uint3& qkv_head_strides [[buffer(6)]], \
constant uint3& qkv_seq_strides [[buffer(7)]], \
constant float& scale [[buffer(8)]], \
const device DTYPE* mask [[buffer(9)]], \
constant uint3& mask_strides [[buffer(10)]], \
constant bool& has_mask [[buffer(11)]], \
constant uint3& qkv_batch_strides [[buffer(12)]], \
constant uint& num_q_heads [[buffer(13)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 tpg [[threadgroups_per_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \
INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \
INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \
INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128);
INSTANTIATE_SDPA_VECTOR_HEADS(float);
INSTANTIATE_SDPA_VECTOR_HEADS(half);
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
)";
}
Copilot AI commented Dec 4, 2025

Maintainability concern: The Metal shader source code is embedded as a 200+ line raw string literal within the C++ source file. This makes the shader code harder to maintain because:

  1. No syntax highlighting for Metal shading language
  2. Difficult to debug shader compilation errors (line numbers won't match)
  3. Increases compilation time for the C++ file
  4. Makes version control diffs harder to read

Consider moving the shader source to a separate .metal file and either:

  • Using a build-time tool to embed it as a string constant
  • Loading it at runtime from a resource file
  • Using Metal's offline shader compilation if available

This is especially important for a 200+ line shader that may need future modifications.
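
One possible shape for the runtime-loading option, as a rough sketch (the file path and helper name are hypothetical, not part of this PR):

#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical alternative to the inline literal: read the shader source from a
// bundled .metal file at runtime and hand the string to the existing library setup.
static std::string load_sdpa_metal_source(const std::string& path /* e.g. "sdpa_vector.metal" */) {
  std::ifstream file(path);
  if (!file) {
    throw std::runtime_error("Failed to open Metal shader source: " + path);
  }
  std::ostringstream buffer;
  buffer << file.rdbuf();
  return buffer.str();
}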

Comment on lines +234 to +254
// Check mask: for floating point masks, values > -1e9 are considered valid (not masked)
// Masked positions typically have -inf or very negative values
const bool is_valid = !has_mask || (static_cast<U>(mask[0]) > -1e9f);
if (is_valid) {
// Read the key
for (uint j = 0; j < qk_per_thread; j++) {
k[j] = static_cast<U>(keys[j]);
}
// Compute the i-th score
U score = 0;
for (uint j = 0; j < qk_per_thread; j++) {
score += q[j] * k[j];
}
score = simd_sum(score);
// Add mask value to score if mask is present
if (has_mask) {
score += static_cast<U>(mask[0]);
}
Copilot AI commented Dec 4, 2025

Missing documentation: The mask handling logic is subtle and could benefit from better documentation. The code:

  1. Checks if mask[0] > -1e9f to determine validity (line 236)
  2. If valid, adds mask[0] to the score (line 253)

This implements an optimization where extremely negative mask values (like -inf) skip processing entirely, while other mask values (including moderately negative ones) are added to scores. The comment should clarify:

  • Why -1e9f is the threshold (vs -inf)
  • That this is an optimization to skip computation for masked-out positions
  • How this differs from standard additive masking where all mask values are added

This will help future maintainers understand the intended behavior and avoid "fixing" this seemingly redundant check.

Comment on lines +141 to +142
typedef half float16_t;
typedef bfloat bfloat16_t;
Copilot AI commented Dec 4, 2025

Missing documentation: The type aliases float16_t and bfloat16_t are defined but only bfloat is actually used in the instantiations (line 338). The half type is used directly rather than float16_t. Consider either:

  1. Using the aliases consistently (use float16_t instead of half)
  2. Removing the unused float16_t typedef
  3. Adding a comment explaining why float16_t is defined but not used

This inconsistency could cause confusion for future maintainers.

Suggested change
typedef half float16_t;
typedef bfloat bfloat16_t;

Comment on lines +1285 to +1290

void* out_contents_ptr = nullptr;
allocate_mtl_buffer(&out_contents_ptr, out_size_bytes);

void* attn_contents_ptr = nullptr;
allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes);
Copilot AI commented Dec 4, 2025

Bug: Memory leak on error paths. If allocate_mtl_buffer succeeds for out_contents_ptr (line 1287) but then an exception is thrown before the tensor handle is created (e.g., at line 1299-1300, 1309-1310, 1316-1317, or 1326-1327), the allocated memory is never freed.

The code should free out_contents_ptr if any error occurs after its allocation. Consider adding proper cleanup using RAII (e.g., unique_ptr with custom deleter) or ensure all error paths after line 1287 call aoti_torch_mps_free(out_contents_ptr).

Suggested change
void* out_contents_ptr = nullptr;
allocate_mtl_buffer(&out_contents_ptr, out_size_bytes);
void* attn_contents_ptr = nullptr;
allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes);
// Use RAII to manage allocated memory and avoid leaks on error paths
struct MTLBufferDeleter {
void operator()(void* ptr) const {
if (ptr) {
aoti_torch_mps_free(ptr);
}
}
};
std::unique_ptr<void, MTLBufferDeleter> out_contents_ptr_raii;
void* out_contents_ptr = nullptr;
allocate_mtl_buffer(&out_contents_ptr, out_size_bytes);
out_contents_ptr_raii.reset(out_contents_ptr);
std::unique_ptr<void, MTLBufferDeleter> attn_contents_ptr_raii;
void* attn_contents_ptr = nullptr;
allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes);
attn_contents_ptr_raii.reset(attn_contents_ptr);

@mergennachin (Contributor) left a comment

Comment on lines +1315 to +1318
if (head_dim != 64 && head_dim != 96 && head_dim != 128) {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim);
throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128");
}

What's the main reason for only limiting to these head sizes?

Comment on lines +258 to +259
U factor = metal::fast::exp(max_score - new_max);
U exp_score = metal::fast::exp(score - new_max);

Nitpick/micro optimization: Can you save on exponentiation if max_score == score?
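
One rough way to realize that idea, as an untested sketch of the kernel's accumulator update (relies on exp(0) == 1; whether the extra branch actually beats metal::fast::exp under SIMD execution would need profiling):

U new_max = max(max_score, score);
// exp(0) == 1, so either exponentiation can be skipped when its argument is zero.
U factor = (new_max == max_score) ? U(1) : metal::fast::exp(max_score - new_max);
U exp_score = (new_max == score) ? U(1) : metal::fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;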
