Add support for AMX instructions #5818
This will only be included if LLVM >= 12 is used to build Halide
|
(Synced to head to fix some irrelevant LLVM build issues) |
That is a good excuse for me to set up LLVM 13 locally. I will try to see why it is not building with it. |
Recent changes in LLVM trunk deprecated the previous calling convention, so it now compiles with a warning or an error.
|
The OSX failure is unrelated (will be fixed by #5841), should be good to land |
|
You should sync this to master to force the bots to retry. |
I'm not sure if the buildbot is still running since there is a "cancelled" message there. |
|
Please try syncing to master once again; hopefully the buildbots will finally be clean. |
Thanks, I will do that, hopefully it will all be green now. |
|
Failures are the unrelated cuda-hang failure that we still haven't diagnosed; ok to land |
src/CodeGen_X86.cpp
Outdated
| {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, | ||
| {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, | ||
| {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, | ||
| {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, |
nit: irrelevant whitespace?
| if (Halide_LLVM_VERSION VERSION_GREATER_EQUAL 12.0) | ||
| # AMX instructions require LLVM 12 or newer | ||
| list(APPEND RUNTIME_LL x86_amx) | ||
| endif () | ||
|
|
Does including this fail at build time or at runtime only?
This fails at build time with the following message:
[1/7] Generating initmod.x86_amx.bc
FAILED: src/runtime/initmod.x86_amx.bc /home/frederik/projects/halide/build-11/src/runtime/initmod.x86_amx.bc
cd /home/frederik/projects/halide/build-11/src/runtime && /usr/lib/llvm-11/bin/llvm-as /home/frederik/projects/halide/src/runtime/x86_amx.ll -o initmod.x86_amx.bc
/usr/lib/llvm-11/bin/llvm-as: /home/frederik/projects/halide/src/runtime/x86_amx.ll:3:18: error: expected type
%2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly
^
|
Just letting you know we haven't lost track of this and the TensorCore PRs. We had some different priorities and annual leave. I look forward to getting this merged soon. |
aac8f78 to f0f9f3e
|
I don't think the test failures are related to anything in this PR |
steven-johnson left a comment
LGTM so far -- aside from style nits, I think it would be good to split the new test into correctness and performance tests, as Halide does for virtually all other features.
src/ExtractTileOperations.cpp
Outdated
| @@ -0,0 +1,414 @@ | |||
| #include "ExtractTileOperations.h" | |||
|
|
|||
| #include "IRMatch.h" // expr_match | |||
Nit: We don't usually add comments explaining why each header is included.
Speaking of which, it might be a good idea to run IWYU on our codebase...
src/ExtractTileOperations.cpp
Outdated
|
|
||
| enum class AMXOpType { | ||
| Int8, | ||
| Bf16, |
Nit: I assume this is bfloat16? If so, spelling it out (e.g. Bfloat16) would be preferable.
| case AMXOpType::Bf16: | ||
| return Float(32, 256); | ||
| default: | ||
| return Type(); |
I assume this is a should-never-happen case, so doing something like internal_error << "Unexpected"; would be appropriate.
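A minimal, self-contained sketch of that suggestion. It does not use Halide itself: `ResultType` and `result_type_for` are hypothetical stand-ins for `Type` and the helper under review, and `std::logic_error` stands in for Halide's `internal_error`. The `Int8` result type is an assumption for illustration, since only the `Bf16` case is visible in the diff.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for Halide's Type: element kind, bits, and lanes.
struct ResultType {
    std::string kind;
    int bits;
    int lanes;
};

enum class AMXOpType {
    Int8,
    Bfloat16,
};

// Returns the accumulator type for a tile operation.
ResultType result_type_for(AMXOpType op) {
    switch (op) {
    case AMXOpType::Int8:
        // Assumption for illustration; this case is not shown in the diff.
        return {"int", 32, 256};
    case AMXOpType::Bfloat16:
        // Mirrors the `Float(32, 256)` case in the diff.
        return {"float", 32, 256};
    }
    // Should never happen: fail loudly instead of returning an empty type.
    // In Halide this would be `internal_error << "Unexpected";`.
    throw std::logic_error("Unexpected AMXOpType");
}
```

Failing loudly means a newly added enum value that is not handled shows up as an error at the call site rather than as an empty type that propagates further.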
src/ExtractTileOperations.cpp
Outdated
| const auto wild_i32 = Variable::make(Int(32), "*"); | ||
| const auto wild_i32x = Variable::make(Int(32, 0), "*"); | ||
|
|
||
| Tile<2> is_2d_tile_index(const Expr &e) { |
Nit: I'd expect a function named "is_whatever" to return bool, but this returns a struct. Something like get_2d_tile_index would be better.
src/ExtractTileOperations.cpp
Outdated
| return {}; | ||
| } | ||
|
|
||
| Tile<3> is_3d_tile_index(const Expr &e) { |
src/ExtractTileOperations.cpp
Outdated
| // 4 bytes for i32, f32 | ||
| auto colbytes = tile_y * 4; | ||
| auto matmul = | ||
| Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); |
No need to split this into two lines
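For reference, the `colbytes` arithmetic in the snippet above can be sketched on its own. The helper names are hypothetical and not part of this PR; the limits of 16 rows and 64 bytes per row are taken from Intel's AMX tile definition and should be treated as an assumption here.

```cpp
// AMX tile limits (assumed from the Intel AMX architecture definition).
constexpr int kMaxTileRows = 16;
constexpr int kMaxTileColBytes = 64;

// Bytes per tile row for `cols` elements of `elem_bytes` bytes each.
// For i32 and f32 results, elem_bytes is 4, as in the diff.
constexpr int tile_colbytes(int cols, int elem_bytes) {
    return cols * elem_bytes;
}

// True if a rows x cols tile of the given element size fits in one tile register.
constexpr bool fits_in_tile(int rows, int cols, int elem_bytes) {
    return rows > 0 && cols > 0 &&
           rows <= kMaxTileRows &&
           tile_colbytes(cols, elem_bytes) <= kMaxTileColBytes;
}
```

Under these assumptions a 16x16 tile of 32-bit results uses the full 64 bytes per row, which is why the result tile width is multiplied by 4.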
src/ExtractTileOperations.cpp
Outdated
| op_type = AMXOpType::Bf16; | ||
| } | ||
|
|
||
| user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; |
Would it be helpful to append amx_name or tile_name to the error message, for debugging purposes?
src/ExtractTileOperations.cpp
Outdated
| } | ||
|
|
||
| auto alloc_type = amx_op_type_result_type(op_type); | ||
|
|
No need for this blank line
src/ExtractTileOperations.cpp
Outdated
| } | ||
|
|
||
| auto body = mutate(op->body); | ||
| return ProducerConsumer::make(amx_name, op->is_producer, body); |
Nit: std::move(body) ?
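A small sketch of what the `std::move` buys, using `std::shared_ptr` as a stand-in for Halide's reference-counted `Stmt` handle (an assumption; Halide uses its own pointer type). `Node`, `make_node`, `refs_after_copy`, and `refs_after_move` are hypothetical names for illustration only.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical IR node holding a reference-counted child.
struct Node {
    std::string name;
    std::shared_ptr<Node> body;
};

// Takes the handle by value, like a `make` factory, and stores it.
std::shared_ptr<Node> make_node(std::string name, std::shared_ptr<Node> body) {
    auto n = std::make_shared<Node>();
    n->name = std::move(name);
    n->body = std::move(body);
    return n;
}

// Copying keeps the local handle alive: the body has two owners.
long refs_after_copy() {
    auto body = std::make_shared<Node>();
    auto parent = make_node("amx", body);
    return parent->body.use_count();
}

// Moving hands the local handle over: the body has one owner.
long refs_after_move() {
    auto body = std::make_shared<Node>();
    auto parent = make_node("amx", std::move(body));
    return parent->body.use_count();
}
```

The saving is one reference-count increment and decrement per call, so this is a nit rather than a correctness issue, as the comment says.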
test/performance/tiled_matmul.cpp
Outdated
| .vectorize(mmyi); | ||
|
|
||
| Func result = mm.in(); | ||
| //result.print_loop_nest(); |
Don't check in commented-out code (unless there is a comment explaining why, as is done elsewhere in this file)
|
When converting to correctness tests there's a bit of a change in the pattern for the rhs load when using |
When using `Buffer` instead of `ImageParam`, the generated `Ramp` expression is 1D instead of 2D, so we recognize it with a special case. The lanes are still matched against the dimensions of the LHS 3D tile.
|
I think the recent commits addressed all comments. Is there anything else that needs to be addressed? |
|
lgtm. The pattern matching seems to be pretty ad-hoc and possibly brittle, but that can always be improved later. The checks for LLVM 12 will be removed pretty soon too. |
|
This is failing for LLVM11 for Makefile-based builds. I'll see if I can prep a patch. |
This pull request continues the work started by @jwlawson in #5780 with the objective of adding initial support for AMX instructions in Halide.
The main addition here is the fix for building Halide using LLVM 11. Support for AMX instructions requires LLVM 12 or newer, so when building with LLVM 11 the unsupported instructions are not included.
A new LLVM module (x86_amx.ll) was created to hold the intrinsics required for tile operations. This module is only included when LLVM >= 12 is present.