Metal backend: Implement the AOTI MPS shim #15022
Conversation
Stack from ghstack (oldest at bottom).

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15022. As of commit 6f6fd58 with merge base 6e0c9f6.
auto src_mtl_buffer = (id<MTLBuffer>)src_buffer;
auto dst_mtl_buffer = (id<MTLBuffer>)dst_buffer;

uint8_t* src_contents = static_cast<uint8_t*>([src_mtl_buffer contents]);
uint8_t* dst_contents = static_cast<uint8_t*>([dst_mtl_buffer contents]);

if (!src_contents || !dst_contents) {
  ET_LOG(Error, "aoti_torch_mps_copy_buffer: Failed to get buffer contents");
  return Error::Internal;
}

memcpy(dst_contents + dst_offset, src_contents + src_offset, data_size);
aoti_torch_mps_free and aoti_torch_mps_memcpy expect a contents pointer, but aoti_torch_mps_copy_buffer expects MTLBuffer objects. Is this intentional?
Shouldn't it be something like this?
auto src_it = ptr_to_mtl_buffer.find(src_buffer);
auto dst_it = ptr_to_mtl_buffer.find(dst_buffer);
if (src_it == ptr_to_mtl_buffer.end()) {
ET_LOG(Error, "aoti_torch_mps_copy_buffer: src_buffer %p not found", src_buffer);
return Error::InvalidArgument;
}
if (dst_it == ptr_to_mtl_buffer.end()) {
ET_LOG(Error, "aoti_torch_mps_copy_buffer: dst_buffer %p not found", dst_buffer);
return Error::InvalidArgument;
}
id<MTLBuffer> src_mtl_buffer = src_it->second;
id<MTLBuffer> dst_mtl_buffer = dst_it->second;
ETMetalStream* stream = getCurrentMetalStream();
stream->copy(src_mtl_buffer, dst_mtl_buffer, data_size, src_offset, dst_offset, SyncType::NONE);
It is intentional, to make things work in this first landing. It is a workaround: right now I can't create an ET tensor with a Metal buffer, I need the contents pointer. To avoid this workaround I think I need to make changes in executorch::extension::from_blob. I want to look into this later.
id<MTLBuffer> subBuffer = [device newBufferWithBytesNoCopy:buffer_pointer + constant_offset
                                                    length:data_size
                                                   options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared
                                               deallocator:nil];

if (constant_offset != 0) {
  ptr_to_mtl_buffer[buffer_pointer + constant_offset] = subBuffer; // Map contents to buffer
}
Why do you need this at all? subBuffer doesn't seem to be used anywhere
This is another workaround: right now AOTI MPS creates a single buffer with all the constants (data tensors). This is one key difference between AOTI MPS and AOTI CUDA.
However, when I call ops implemented with MPSGraph that take in data tensors (right now convolution and sdpa), I need to call initWithMTLBuffer to create an MPSGraphTensor from the buffer. But initWithMTLBuffer doesn't let me pass an offset, so I need each data tensor in its own buffer.
  }
}

AOTITorchError aoti_torch_mps_get_kernel_function(

  }
}

AOTITorchError aoti_torch_mps_start_encoding(

// Pure C dispatch functions - array versions
AOTITorchError aoti_torch_mps_dispatch_array(
    AOTIMetalKernelFunctionHandle func,
    const uint64_t* length,
Validate length != nullptr
Includes:
- Shader library management
- Kernel function handling
- Command buffer execution
- Metal memory operations

ghstack-source-id: cba3048
ghstack-comment-id: 3392300374
Pull-Request: pytorch#15022