From 05d2751580d467fc3c64e2b5e8e586f4f78b48e1 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Wed, 17 Feb 2021 15:10:55 +0000 Subject: [PATCH 01/11] Add support for AMX tile instructions --- src/CMakeLists.txt | 2 + src/CodeGen_X86.cpp | 48 ++++- src/Expr.h | 4 + src/ExtractTileOperations.cpp | 293 ++++++++++++++++++++++++++++++ src/ExtractTileOperations.h | 20 ++ src/FuseGPUThreadLoops.cpp | 2 + src/IRPrinter.cpp | 3 + src/Lower.cpp | 10 + src/runtime/x86_avx512.ll | 38 ++++ test/performance/CMakeLists.txt | 1 + test/performance/tiled_matmul.cpp | 130 +++++++++++++ 11 files changed, 550 insertions(+), 1 deletion(-) create mode 100644 src/ExtractTileOperations.cpp create mode 100644 src/ExtractTileOperations.h create mode 100644 test/performance/tiled_matmul.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8bc68eeeb26c..84f71065f4ef 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -59,6 +59,7 @@ set(HEADER_FILES ExprUsesVar.h Extern.h ExternFuncArgument.h + ExtractTileOperations.h FastIntegerDivide.h FindCalls.h FindIntrinsics.h @@ -218,6 +219,7 @@ set(SOURCE_FILES EmulateFloat16Math.cpp Error.cpp Expr.cpp + ExtractTileOperations.cpp FastIntegerDivide.cpp FindCalls.cpp FindIntrinsics.cpp diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 2038dcce75c8..0aa64e089845 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -1,4 +1,5 @@ #include "CodeGen_Posix.h" + #include "ConciseCasts.h" #include "Debug.h" #include "IRMatch.h" @@ -79,15 +80,21 @@ class CodeGen_X86 : public CodeGen_Posix { void visit(const EQ *) override; void visit(const NE *) override; void visit(const Select *) override; + void visit(const Allocate *) override; + void visit(const Load *) override; + void visit(const Store *) override; void codegen_vector_reduce(const VectorReduce *, const Expr &init) override; // @} + +private: + Scope mem_type; }; CodeGen_X86::CodeGen_X86(Target t) : CodeGen_Posix(complete_x86_target(t)) { } -const int max_intrinsic_args = 
4; +const int max_intrinsic_args = 6; struct x86Intrinsic { const char *intrin_name; @@ -184,6 +191,13 @@ const x86Intrinsic intrinsic_defs[] = { {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, + + {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids}, + {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, + {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, + // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin + {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids}, + }; // clang-format on @@ -576,6 +590,38 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init CodeGen_Posix::codegen_vector_reduce(op, init); } +void CodeGen_X86::visit(const Allocate *op) { + ScopedBinding bind(mem_type, op->name, op->memory_type); + CodeGen_Posix::visit(op); +} + +void CodeGen_X86::visit(const Load *op) { + if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) { + const Ramp *ramp = op->index.as(); + internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; + Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base); + LoadInst *load = builder->CreateAlignedLoad(ptr, llvm::Align(op->type.bytes())); + add_tbaa_metadata(load, op->name, op->index); + value = load; + return; + } + CodeGen_Posix::visit(op); +} + +void CodeGen_X86::visit(const Store *op) { + if 
(mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) { + Value *val = codegen(op->value); + Halide::Type value_type = op->value.type(); + const Ramp *ramp = op->index.as(); + internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; + Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base); + StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes())); + add_tbaa_metadata(store, op->name, op->index); + return; + } + CodeGen_Posix::visit(op); +} + string CodeGen_X86::mcpu() const { if (target.has_feature(Target::AVX512_SapphireRapids)) { #if LLVM_VERSION >= 120 diff --git a/src/Expr.h b/src/Expr.h index c5472e766fa4..b70d608d290b 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -379,6 +379,10 @@ enum class MemoryType { * intermediate buffers. Necessary for vgather-vscatter instructions * on Hexagon */ VTCM, + + /** AMX Tile register for X86. Any data that would be used in an AMX matrix + * multiplication must first be loaded into an AMX tile register. 
*/ + AMXTile, }; namespace Internal { diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp new file mode 100644 index 000000000000..e8edc3532775 --- /dev/null +++ b/src/ExtractTileOperations.cpp @@ -0,0 +1,293 @@ +#include "ExtractTileOperations.h" + +#include "IRMatch.h" // expr_match +#include "IRMutator.h" +#include "IROperator.h" // Expr + Expr +#include "Util.h" // ScopedValue + +namespace Halide { +namespace Internal { + +namespace { + +template +struct Tile { + bool result; + Expr base; + Expr stride[Dim]; + int extent[Dim]; +}; + +const auto wild_i32 = Variable::make(Int(32), "*"); +const auto wild_i32x = Variable::make(Int(32, 0), "*"); + +Tile<2> is_2d_tile_index(const Expr &e) { + // ramp(ramp(base, 1, 4), x4(stride), 4) + std::vector matches; + if (const auto *r1 = e.as()) { + if (const auto *r2 = r1->base.as()) { + auto ramp_2d_pattern = Ramp::make(Ramp::make(wild_i32, wild_i32, r2->lanes), Broadcast::make(wild_i32, r2->lanes), r1->lanes); + if (expr_match(ramp_2d_pattern, e, matches)) { + return {true, std::move(matches[0]), {std::move(matches[2]), std::move(matches[1])}, {r1->lanes, r2->lanes}}; + } + } + } + return {}; +} + +Tile<3> is_3d_tile_index(const Expr &e) { + std::vector matches; + auto add_sub_pattern = (wild_i32x + wild_i32x) - wild_i32x; + if (!expr_match(add_sub_pattern, e, matches)) { return {}; } + // ramp(x16(base), x16(stride), 4) + x16(ramp(idx, 1, 4)) y: 4, x: 4, r: 4 + // ramp(x10(base), x10(stride), 3) + x6(ramp(idx, 1, 5)) y: 2, x: 3, r: 5 + Expr first = std::move(matches[0]); + Expr second = std::move(matches[1]); + Expr adj = std::move(matches[2]); + const auto *r1 = first.as(); + const auto *b2 = second.as(); + if (!r1 && !b2) { + // Try switching the order + r1 = second.as(); + b2 = first.as(); + } + if (!r1 || !b2) { return {}; } + + const auto *b1 = r1->base.as(); + const auto *r2 = b2->value.as(); + + if (!b1 || !r2) { return {}; } + + int x_tile = r1->lanes; + int r_tile = r2->lanes; + int y_tile = 
b1->lanes / r_tile; + if (y_tile != b2->lanes / x_tile) { return {}; } + + auto pattern1 = Ramp::make(Broadcast::make(wild_i32, b1->lanes), Broadcast::make(wild_i32, b1->lanes), r1->lanes); + if (!expr_match(pattern1, first, matches)) { return {}; } + Expr base = std::move(matches[0]); + Expr x_stride = std::move(matches[1]); + + auto pattern2 = Broadcast::make(Ramp::make(wild_i32, wild_i32, r2->lanes), b2->lanes); + if (!expr_match(pattern2, second, matches)) { return {}; } + base += std::move(matches[0]); + Expr r_stride = std::move(matches[1]); + + auto pattern3 = Broadcast::make(wild_i32, b1->lanes * r1->lanes); + if (!expr_match(pattern3, adj, matches)) { return {}; } + base -= std::move(matches[0]); + + return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}}; +} + +struct NewMatmul { + bool result = false; + Stmt stmt; + int tile_x; + int tile_y; + int tile_r; +}; + +NewMatmul +convert_to_matmul(const Store *op, const std::string &new_name) { + // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] + const auto wild_i8x = Variable::make(Int(8, 0), "*"); + const auto wild_i16x = Variable::make(Int(16, 0), "*"); + std::vector matches; + const auto pattern1 = wild_i32x + wild_i32x; + if (!expr_match(pattern1, op->value, matches)) { return {}; } + const auto *reduce = matches[0].as(); + const auto *load = matches[1].as(); + if (!reduce || reduce->op != VectorReduce::Add) { return {}; } + if (!load || load->name != op->name || !equal(load->index, op->index)) { return {}; } + + // FIXME: Add support for uint8 and bf16 for LLVM 13+ + auto pattern2 = cast(Int(32, 0), cast(Int(16, 0), wild_i8x) * wild_i16x); + if (!expr_match(pattern2, reduce->value, matches)) { return {}; } + const auto *lhs_load = matches[0].as(); + // FIXME: When tile_r is not 4 the broadcast is inside the index, not of the value + const auto *rhs_broadcast = matches[1].as(); + if (!lhs_load || !rhs_broadcast) { return {}; } + const auto *rhs_cast = 
rhs_broadcast->value.as(); + if (!rhs_cast || rhs_cast->value.type().element_of() != Int(8)) { return {}; } + const auto *rhs_load = rhs_cast->value.as(); + if (!rhs_load) { return {}; } + + const auto lhs_tile = is_3d_tile_index(lhs_load->index); + const auto rhs_tile = is_2d_tile_index(rhs_load->index); + // FIXME: When tile_r is not 4 the RHS load will be 4D (x, r/4, y, r%4) + if (!lhs_tile.result || !rhs_tile.result) { return {}; } + + const int tile_x = lhs_tile.extent[0]; + const int tile_y = lhs_tile.extent[1]; + const int tile_r = lhs_tile.extent[2]; + const int factor = reduce->value.type().lanes() / reduce->type.lanes(); + if (op->index.type().lanes() != tile_x * tile_y || + factor != tile_r || + tile_y != rhs_tile.extent[0] || + tile_r != rhs_tile.extent[1]) { + return {}; + } + + // {rows, colbytes, var, index} + auto lhs_var = Variable::make(Handle(), lhs_load->name); + auto lhs = Call::make(Int(8, 1024), "tile_load", {tile_x, tile_r, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic); + auto rhs_var = Variable::make(Handle(), rhs_load->name); + auto rhs = Call::make(Int(8, 1024), "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic); + + // {rows, colbytes, acc, out, lhs, rhs} + auto out = Load::make(Int(32, 256), new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto colbytes = tile_y * 32 / rhs_load->type.bits(); + auto matmul = Call::make(Int(32, 256), "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); + auto store = Store::make(new_name, matmul, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); + return {true, std::move(store), tile_x, tile_y, tile_r}; +} + +Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const std::string &new_name) { + if (const auto *ramp = op->index.as()) { + if (const auto *bcast = op->value.as()) { + if (is_const_one(ramp->stride) && + is_const_zero(bcast->value) && + (bcast->lanes == tile_x * 
tile_y)) { + auto rows = Cast::make(Int(16), tile_x); + auto bytes = op->value.type().bytes(); + auto colbytes = Cast::make(Int(16), tile_y * bytes); + auto val = Call::make(Int(32, 256), "tile_zero", {rows, colbytes}, Call::Intrinsic); + auto store = Store::make(new_name, val, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); + return store; + } + } + } + return {}; +} + +Stmt convert_to_tile_store(const Store *op, const std::string &amx_alloc, int tile_x, int tile_y) { + auto tile = is_2d_tile_index(op->index); + if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { + auto out = Variable::make(Handle(), op->name); + auto tile_val = Load::make(Int(32, 256), amx_alloc, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto bytes = op->value.type().bytes(); + internal_assert(bytes == 4) << "AMX store only supported for int32 and float32, not for " << op->value.type() << "\n"; + // {tile_x, tile_y, var, base, stride} + auto store = Call::make(Bool(2), "tile_store", {tile_x, tile_y * bytes, out, tile.base * bytes, tile.stride[0] * bytes, tile_val}, Call::Intrinsic); + return Evaluate::make(store); + } + return {}; +} + +class ExtractTileOperations : public IRMutator { + using IRMutator::visit; + + std::string tile_name; + std::string amx_alloc; + std::vector pending_stores; + bool is_valid = true; + bool in_allocate = false; + int found_tile_x = -1; + int found_tile_y = -1; + int found_tile_r = -1; + + Stmt visit(const Allocate *op) override { + if (op->type.is_int() && op->type.bits() == 32) { + if (in_allocate) { + // Found two possible tile allocations + // FIXME: Handle this better + is_valid = false; + return op; + } + amx_alloc = op->name + ".amx"; + tile_name = op->name; + ScopedValue old_in_alloc(in_allocate, true); + Stmt body = op->body; + + pending_stores.clear(); + body = mutate(body); + if (!is_valid) { + return op; + } + if (found_tile_x < 0 || found_tile_y < 0 || found_tile_r < 0) { + return op; + } + if 
(!pending_stores.empty()) { + // Really only need to go over the pending stores + body = mutate(body); + } + if (!is_valid) { + return op; + } + + return Allocate::make(amx_alloc, Int(32, 256), MemoryType::AMXTile, {1}, const_true(), body); + } + return IRMutator::visit(op); + } + + Stmt visit(const Free *op) override { + if (op->name != tile_name) { + return op; + } + return Free::make(amx_alloc); + } + + Stmt visit(const ProducerConsumer *op) override { + if (op->name != tile_name) { + return IRMutator::visit(op); + } + + auto body = mutate(op->body); + return ProducerConsumer::make(amx_alloc, op->is_producer, body); + } + + Stmt visit(const Store *op) override { + if (op->name != tile_name) { + const auto *load = op->value.as(); + if (!load || load->name != tile_name) { + return op; + } + auto store = convert_to_tile_store(op, amx_alloc, found_tile_x, found_tile_y); + if (store.defined()) { + return store; + } else { + // Found store of tile_name that is not a tile store. + is_valid = false; + return op; + } + } + + auto matmul = convert_to_matmul(op, amx_alloc); + if (matmul.result) { + if ((found_tile_x > 0 && matmul.tile_x != found_tile_x) || + (found_tile_r > 0 && matmul.tile_r != found_tile_r) || + (found_tile_y > 0 && matmul.tile_y != found_tile_y)) { + is_valid = false; + return op; + } + found_tile_x = matmul.tile_x; + found_tile_y = matmul.tile_y; + found_tile_r = matmul.tile_r; + return matmul.stmt; + } + + if (found_tile_x < 0 || found_tile_y < 0) { + pending_stores.emplace_back(op); + return op; + } + + auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_alloc); + if (zero.defined()) { + return zero; + } + + // Otherwise there is some other operation using the allocation, so we cannot use the AMX instructions + is_valid = false; + return op; + } +}; + +} // namespace + +Stmt extract_tile_operations(const Stmt &s) { + return ExtractTileOperations().mutate(s); +} + +} // namespace Internal +} // namespace Halide diff --git 
a/src/ExtractTileOperations.h b/src/ExtractTileOperations.h new file mode 100644 index 000000000000..d246bddc5a04 --- /dev/null +++ b/src/ExtractTileOperations.h @@ -0,0 +1,20 @@ +#ifndef HALIDE_EXTRACT_TILE_OPERATIONS_H +#define HALIDE_EXTRACT_TILE_OPERATIONS_H + +/** \file + * Defines the lowering pass that injects calls to tile intrinsics that support + * AMX instructions. + */ + +#include "Expr.h" + +namespace Halide { +namespace Internal { + +/** TODO */ +Stmt extract_tile_operations(const Stmt &s); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 6b1798b25528..5d070a051c8a 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1275,6 +1275,7 @@ class InjectThreadBarriers : public IRMutator { case MemoryType::Register: case MemoryType::LockedCache: case MemoryType::VTCM: + case MemoryType::AMXTile: break; } @@ -1299,6 +1300,7 @@ class InjectThreadBarriers : public IRMutator { case MemoryType::Register: case MemoryType::LockedCache: case MemoryType::VTCM: + case MemoryType::AMXTile: break; } diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 62fdb705997b..38e2eeefd511 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -135,6 +135,9 @@ std::ostream &operator<<(std::ostream &out, const MemoryType &t) { case MemoryType::VTCM: out << "VTCM"; break; + case MemoryType::AMXTile: + out << "AMXTile"; + break; } return out; } diff --git a/src/Lower.cpp b/src/Lower.cpp index 07cb56f3556c..0dff89e3d38e 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -22,6 +22,7 @@ #include "DebugToFile.h" #include "Deinterleave.h" #include "EarlyFree.h" +#include "ExtractTileOperations.h" #include "FindCalls.h" #include "FindIntrinsics.h" #include "FlattenNestedRamps.h" @@ -412,6 +413,15 @@ Module lower(const vector &output_funcs, debug(2) << "Lowering after lowering unsafe promises:\n" << s << "\n\n"; +#if LLVM_VERSION >= 12 + if 
(t.has_feature(Target::AVX512_SapphireRapids)) { + debug(1) << "Extracting tile operations...\n"; + s = extract_tile_operations(s); + debug(2) << "Lowering after extracting tile operations:\n" + << s << "\n\n"; + } +#endif + debug(1) << "Flattening nested ramps...\n"; s = flatten_nested_ramps(s); debug(2) << "Lowering after flattening nested ramps:\n" diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll index 904fabe9368e..5e9ace735bfd 100644 --- a/src/runtime/x86_avx512.ll +++ b/src/runtime/x86_avx512.ll @@ -90,3 +90,41 @@ define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b ret <4 x i32> %3 } declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) + %3 = bitcast x86_amx %2 to <1024 x i8> + ret <1024 x i8> %3 +} +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) + +define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind readnone alwaysinline { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = bitcast <256 x i32> %val to x86_amx + tail call void @llvm.x86.tilestored64.internal(i16 
%rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) + ret <2 x i1> zeroinitializer +} +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) + +; NB: Even though this should be readnone, that will cause LLVM to try to +; generate a single zero tile, and copy it each time it is used. However the AMX +; registers cannot be copied, so this causes compilation failures: +; LLVM ERROR: Cannot emit physreg copy instruction +; renamable $tmm1 = COPY renamable $tmm0 +define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alwaysinline { + %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) + %2 = bitcast x86_amx %1 to <256 x i32> + ret <256 x i32> %2 +} +declare x86_amx @llvm.x86.tilezero.internal(i16, i16) diff --git a/test/performance/CMakeLists.txt b/test/performance/CMakeLists.txt index 65aa41da00f3..80f58a16afae 100644 --- a/test/performance/CMakeLists.txt +++ b/test/performance/CMakeLists.txt @@ -1,5 +1,6 @@ tests(GROUPS performance SOURCES + tiled_matmul.cpp async_gpu.cpp block_transpose.cpp boundary_conditions.cpp diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp new file mode 100644 index 000000000000..b3be78721f10 --- /dev/null +++ b/test/performance/tiled_matmul.cpp @@ -0,0 +1,130 @@ +#include "Halide.h" +#include "halide_benchmark.h" +#include "halide_test_dirs.h" +#include +#include + +using namespace Halide; + +#define FUSE 0 + +int main(int argc, char **argv) { + const int row = 16; + const int col = 16; + const int acc = 16; + + Var x("x"), y("y"); + ImageParam A(Int(8), 2, "lhs"); + ImageParam B(Int(8), 3, "rhs"); + + RDom r(0, acc); + + Func mm("matmul"); + mm(y, x) = cast(0); + mm(y, x) += cast(A(r.x, x)) * B(r.x % 4, y, r.x / 4); + + // Ensure all (x, y) tile sizes are the same so that loops are fused. 
+ int tile_y = 8; + int tile_x = 6; + int tile_r = 4; + + // Schedule the reduction + Var rxi("rxi"), ryi("ryi"), rz("rz"); + RVar rri("rri"), rro("rro"); + mm.compute_at(mm.in(), y) + .update() + // Split into (x,y) tile + .tile(y, x, ryi, rxi, tile_y, tile_x, TailStrategy::GuardWithIf) + // Split reduction dim by tile_r + .split(r.x, rro, rri, tile_r) + // Reorder so that the (x,y) tile is inside the inner ro loop + .reorder({rri, ryi, rxi, rro, y, x}) + .atomic() + .vectorize(rri) + .vectorize(ryi) + .vectorize(rxi); + + // Schedule the initialization + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), y) + .tile(y, x, iyi, ixi, tile_y, tile_x) + .vectorize(iyi) + .vectorize(ixi); + + // Schedule the consumer + Var mmxi("mmxi"), mmyi("mmyi"), mmz("mmz"); + mm.in() + .tile(y, x, mmyi, mmxi, tile_y, tile_x) + .vectorize(mmyi) + .vectorize(mmxi); + + int count = 1; + Buffer a_buf(acc, row); + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + a_buf(ix, iy) = count++; //rand() % 256 - 128; + } + } + A.set(a_buf); + + Buffer b_buf(4, col, acc / 4); + count = 1; + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 4; ++ik) { + b_buf(ik, ix, iy) = count++; //rand() % 256 - 128; + } + } + } + B.set(b_buf); + + Buffer out(col, row); + + Func result = mm.in(); + + // Uncomment to check the asm + Target target = get_jit_target_from_environment(); + result.compile_to_llvm_assembly("matmul.ll", {A, B}, target); + //result.compile_to_assembly("matmul.s", {A, B}, target); + + auto time = Tools::benchmark(20, 20, [&]() { + result.realize(out); + }); + std::cout << "Exec time: " << time << "\n"; + + for (int i = 0; i < row; ++i) { + for (int j = 0; j < acc; ++j) { + std::cout << std::setw(4) << (int)a_buf(j, i) << " "; + } + std::cout << "\n"; + } + std::cout << "\n\n*\n\n"; + for (int i = 0; i < acc; ++i) { + for (int j = 0; j < col; ++j) { + std::cout << std::setw(4) << (int)b_buf(i % 4, j, i / 4) << " 
"; + } + std::cout << "\n"; + } + std::cout << "\n\n=\n\n"; + for (int i = 0; i < row; ++i) { + for (int j = 0; j < col; ++j) { + std::cout << std::setw(6) << out(j, i) << " "; + } + std::cout << "\n"; + } + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += a_buf(k, j) * b_buf(k % 4, i, k / 4); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n"; + return 1; + } + } + } + return 0; +} From 758ab0700105a24db741fba9bfd1fb8a1ce82513 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Tue, 9 Mar 2021 15:42:28 +0000 Subject: [PATCH 02/11] Make AMX transform opt-in with memory type --- src/ExtractTileOperations.cpp | 4 +++- test/performance/tiled_matmul.cpp | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index e8edc3532775..5d2799b82655 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -188,7 +188,9 @@ class ExtractTileOperations : public IRMutator { int found_tile_r = -1; Stmt visit(const Allocate *op) override { - if (op->type.is_int() && op->type.bits() == 32) { + if (op->memory_type == MemoryType::AMXTile && + op->type.is_int() && + op->type.bits() == 32) { if (in_allocate) { // Found two possible tile allocations // FIXME: Handle this better diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index b3be78721f10..fe250857417e 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -32,6 +32,7 @@ int main(int argc, char **argv) { Var rxi("rxi"), ryi("ryi"), rz("rz"); RVar rri("rri"), rro("rro"); mm.compute_at(mm.in(), y) + .store_in(MemoryType::AMXTile) .update() // Split into (x,y) tile .tile(y, x, ryi, rxi, tile_y, tile_x, TailStrategy::GuardWithIf) From 0f30587b350058d77ee73f396c2766337f5660c8 Mon Sep 17 00:00:00 2001 From: John Lawson 
Date: Wed, 10 Mar 2021 14:50:59 +0000 Subject: [PATCH 03/11] Clean up tiled_matmul test --- test/performance/tiled_matmul.cpp | 45 ++++++++----------------------- 1 file changed, 11 insertions(+), 34 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index fe250857417e..92d545ff486a 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -1,13 +1,9 @@ #include "Halide.h" #include "halide_benchmark.h" -#include "halide_test_dirs.h" #include -#include using namespace Halide; -#define FUSE 0 - int main(int argc, char **argv) { const int row = 16; const int col = 16; @@ -15,6 +11,10 @@ int main(int argc, char **argv) { Var x("x"), y("y"); ImageParam A(Int(8), 2, "lhs"); + // NB the RHS matrix in AMX instructions should be tiled in "VNNI format", + // where instead of being (cols, rows) where rows are adjacent in memory it + // should be (4, cols, rows / 4) for int8, or (2, cols, rows / 2) for bf16. + // This means that the rows must always be divisible by 4 (or 2 for bf16). 
ImageParam B(Int(8), 3, "rhs"); RDom r(0, acc); @@ -29,7 +29,7 @@ int main(int argc, char **argv) { int tile_r = 4; // Schedule the reduction - Var rxi("rxi"), ryi("ryi"), rz("rz"); + Var rxi("rxi"), ryi("ryi"); RVar rri("rri"), rro("rro"); mm.compute_at(mm.in(), y) .store_in(MemoryType::AMXTile) @@ -53,27 +53,25 @@ int main(int argc, char **argv) { .vectorize(ixi); // Schedule the consumer - Var mmxi("mmxi"), mmyi("mmyi"), mmz("mmz"); + Var mmxi("mmxi"), mmyi("mmyi"); mm.in() .tile(y, x, mmyi, mmxi, tile_y, tile_x) .vectorize(mmyi) .vectorize(mmxi); - int count = 1; Buffer a_buf(acc, row); for (int iy = 0; iy < row; iy++) { for (int ix = 0; ix < acc; ix++) { - a_buf(ix, iy) = count++; //rand() % 256 - 128; + a_buf(ix, iy) = rand() % 256 - 128; } } A.set(a_buf); Buffer b_buf(4, col, acc / 4); - count = 1; for (int iy = 0; iy < acc / 4; iy++) { for (int ix = 0; ix < col; ix++) { for (int ik = 0; ik < 4; ++ik) { - b_buf(ik, ix, iy) = count++; //rand() % 256 - 128; + b_buf(ik, ix, iy) = rand() % 256 - 128; } } } @@ -84,36 +82,15 @@ int main(int argc, char **argv) { Func result = mm.in(); // Uncomment to check the asm - Target target = get_jit_target_from_environment(); - result.compile_to_llvm_assembly("matmul.ll", {A, B}, target); - //result.compile_to_assembly("matmul.s", {A, B}, target); + //Target target = get_jit_target_from_environment(); + //result.compile_to_llvm_assembly("tiled_matmul.ll", {A, B}, target); + //result.compile_to_assembly("tiled_matmul.s", {A, B}, target); auto time = Tools::benchmark(20, 20, [&]() { result.realize(out); }); std::cout << "Exec time: " << time << "\n"; - for (int i = 0; i < row; ++i) { - for (int j = 0; j < acc; ++j) { - std::cout << std::setw(4) << (int)a_buf(j, i) << " "; - } - std::cout << "\n"; - } - std::cout << "\n\n*\n\n"; - for (int i = 0; i < acc; ++i) { - for (int j = 0; j < col; ++j) { - std::cout << std::setw(4) << (int)b_buf(i % 4, j, i / 4) << " "; - } - std::cout << "\n"; - } - std::cout << "\n\n=\n\n"; - for (int 
i = 0; i < row; ++i) { - for (int j = 0; j < col; ++j) { - std::cout << std::setw(6) << out(j, i) << " "; - } - std::cout << "\n"; - } - for (int j = 0; j < row; ++j) { for (int i = 0; i < col; ++i) { int32_t val = 0; From 89b5c9e8396eaaf398718205816b67a3f5c49909 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Tue, 9 Mar 2021 16:27:01 +0000 Subject: [PATCH 04/11] Handle AMX intrinsic attributes better --- src/CodeGen_X86.cpp | 12 +++++++++--- src/runtime/x86_avx512.ll | 12 ++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 0aa64e089845..e0e34d163f04 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -102,6 +102,10 @@ struct x86Intrinsic { const char *name; halide_type_t arg_types[max_intrinsic_args]; Target::Feature feature = Target::FeatureEnd; + uint32_t flags = 0; + enum Options { + AccessesMemory = 1 << 0, + }; }; // clang-format off @@ -192,11 +196,11 @@ const x86Intrinsic intrinsic_defs[] = { {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids}, {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids}, - {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids}, + {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin - {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, 
Target::AVX512_SapphireRapids}, + {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, }; // clang-format on @@ -220,7 +224,9 @@ void CodeGen_X86::init_module() { } auto *fn = declare_intrin_overload(i.name, ret_type, i.intrin_name, std::move(arg_types)); - fn->addFnAttr(llvm::Attribute::ReadNone); + if((i.flags & x86Intrinsic::AccessesMemory) == 0) { + fn->addFnAttr(llvm::Attribute::ReadNone); + } fn->addFnAttr(llvm::Attribute::NoUnwind); } } diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll index 5e9ace735bfd..7217caeb9f01 100644 --- a/src/runtime/x86_avx512.ll +++ b/src/runtime/x86_avx512.ll @@ -91,15 +91,15 @@ define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b } declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) -define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline { +define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline readonly { %1 = getelementptr i8, i8* %ptr, i64 %off - %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) + %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly %3 = bitcast x86_amx %2 to <1024 x i8> ret <1024 x i8> %3 } declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) -define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind readnone alwaysinline { +define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { %1 = bitcast <1024 x i8> %lhs to x86_amx %2 = bitcast <1024 x i8> %rhs to x86_amx %3 = bitcast <256 x i32> 
%out to x86_amx @@ -109,10 +109,10 @@ define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x } declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) -define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline { +define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { %1 = getelementptr i8, i8* %ptr, i64 %off %2 = bitcast <256 x i32> %val to x86_amx - tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) + tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly ret <2 x i1> zeroinitializer } declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) @@ -123,7 +123,7 @@ declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) ; LLVM ERROR: Cannot emit physreg copy instruction ; renamable $tmm1 = COPY renamable $tmm0 define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alwaysinline { - %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) + %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) nounwind %2 = bitcast x86_amx %1 to <256 x i32> ret <256 x i32> %2 } From 1d6e94ea57509f3e8b6895921a7f936f395c1c22 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Wed, 10 Mar 2021 15:16:24 +0000 Subject: [PATCH 05/11] Format --- src/CodeGen_X86.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index e0e34d163f04..a8c7cdb12f05 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -1,5 +1,4 @@ #include "CodeGen_Posix.h" - #include "ConciseCasts.h" #include "Debug.h" #include "IRMatch.h" @@ -224,8 +223,8 @@ void CodeGen_X86::init_module() { } auto *fn = 
declare_intrin_overload(i.name, ret_type, i.intrin_name, std::move(arg_types)); - if((i.flags & x86Intrinsic::AccessesMemory) == 0) { - fn->addFnAttr(llvm::Attribute::ReadNone); + if ((i.flags & x86Intrinsic::AccessesMemory) == 0) { + fn->addFnAttr(llvm::Attribute::ReadNone); } fn->addFnAttr(llvm::Attribute::NoUnwind); } From 427ff33bf26d992ec11987326a06697065065522 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Thu, 11 Mar 2021 12:30:17 +0000 Subject: [PATCH 06/11] Fix test to behave like other tests --- test/performance/tiled_matmul.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 92d545ff486a..7670ecb7d642 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -1,6 +1,9 @@ #include "Halide.h" #include "halide_benchmark.h" +#include "halide_test_dirs.h" + #include +#include using namespace Halide; @@ -83,8 +86,8 @@ int main(int argc, char **argv) { // Uncomment to check the asm //Target target = get_jit_target_from_environment(); - //result.compile_to_llvm_assembly("tiled_matmul.ll", {A, B}, target); - //result.compile_to_assembly("tiled_matmul.s", {A, B}, target); + //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); + //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); auto time = Tools::benchmark(20, 20, [&]() { result.realize(out); @@ -104,5 +107,6 @@ int main(int argc, char **argv) { } } } + std::cout << "Success!\n"; return 0; } From 644af1029eb56034fc8c387942962f4680257b5c Mon Sep 17 00:00:00 2001 From: John Lawson Date: Thu, 11 Mar 2021 13:21:27 +0000 Subject: [PATCH 07/11] Add doc and missing load check --- src/ExtractTileOperations.cpp | 9 +++++++++ src/ExtractTileOperations.h | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 
5d2799b82655..641cfa3623ca 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -239,6 +239,15 @@ class ExtractTileOperations : public IRMutator { return ProducerConsumer::make(amx_alloc, op->is_producer, body); } + Expr visit(const Load* op) override { + if (op->name == tile_name) { + // Any tile load will be matched elsewhere, so a load here means that + // the AMX tile is used outside of a tile instruction. + is_valid = false; + } + return IRMutator::visit(op); + } + Stmt visit(const Store *op) override { if (op->name != tile_name) { const auto *load = op->value.as(); diff --git a/src/ExtractTileOperations.h b/src/ExtractTileOperations.h index d246bddc5a04..918e3b1b9940 100644 --- a/src/ExtractTileOperations.h +++ b/src/ExtractTileOperations.h @@ -11,7 +11,8 @@ namespace Halide { namespace Internal { -/** TODO */ +/** Rewrite any AMX tile operations that have been stored in the AMXTile memory + * type as intrinsic calls, to be used in the X86 backend. */ Stmt extract_tile_operations(const Stmt &s); } // namespace Internal From c1315864f8acd2489363885e58c0a3a5848bd421 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Thu, 11 Mar 2021 14:50:58 +0000 Subject: [PATCH 08/11] Format --- src/ExtractTileOperations.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 641cfa3623ca..731f028c5835 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -239,11 +239,11 @@ class ExtractTileOperations : public IRMutator { return ProducerConsumer::make(amx_alloc, op->is_producer, body); } - Expr visit(const Load* op) override { + Expr visit(const Load *op) override { if (op->name == tile_name) { - // Any tile load will be matched elsewhere, so a load here means that - // the AMX tile is used outside of a tile instruction. 
- is_valid = false; + // Any tile load will be matched elsewhere, so a load here means that + // the AMX tile is used outside of a tile instruction. + is_valid = false; } return IRMutator::visit(op); } From d954996815d6c7631d2e86df3afaade898734299 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Fri, 12 Mar 2021 12:41:34 +0000 Subject: [PATCH 09/11] Throw error if user requests AMX for invalid operation --- src/ExtractTileOperations.cpp | 84 +++++++++++++------------------ test/performance/tiled_matmul.cpp | 6 ++- 2 files changed, 39 insertions(+), 51 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 731f028c5835..62d877aac354 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -8,6 +8,9 @@ namespace Halide { namespace Internal { +using std::string; +using std::vector; + namespace { template @@ -23,7 +26,7 @@ const auto wild_i32x = Variable::make(Int(32, 0), "*"); Tile<2> is_2d_tile_index(const Expr &e) { // ramp(ramp(base, 1, 4), x4(stride), 4) - std::vector matches; + vector matches; if (const auto *r1 = e.as()) { if (const auto *r2 = r1->base.as()) { auto ramp_2d_pattern = Ramp::make(Ramp::make(wild_i32, wild_i32, r2->lanes), Broadcast::make(wild_i32, r2->lanes), r1->lanes); @@ -36,7 +39,7 @@ Tile<2> is_2d_tile_index(const Expr &e) { } Tile<3> is_3d_tile_index(const Expr &e) { - std::vector matches; + vector matches; auto add_sub_pattern = (wild_i32x + wild_i32x) - wild_i32x; if (!expr_match(add_sub_pattern, e, matches)) { return {}; } // ramp(x16(base), x16(stride), 4) + x16(ramp(idx, 1, 4)) y: 4, x: 4, r: 4 @@ -89,11 +92,11 @@ struct NewMatmul { }; NewMatmul -convert_to_matmul(const Store *op, const std::string &new_name) { +convert_to_matmul(const Store *op, const string &new_name) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); const auto wild_i16x = Variable::make(Int(16, 0), "*"); - 
std::vector matches; + vector matches; const auto pattern1 = wild_i32x + wild_i32x; if (!expr_match(pattern1, op->value, matches)) { return {}; } const auto *reduce = matches[0].as(); @@ -143,7 +146,7 @@ convert_to_matmul(const Store *op, const std::string &new_name) { return {true, std::move(store), tile_x, tile_y, tile_r}; } -Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const std::string &new_name) { +Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) { if (const auto *ramp = op->index.as()) { if (const auto *bcast = op->value.as()) { if (is_const_one(ramp->stride) && @@ -161,11 +164,11 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const std::string return {}; } -Stmt convert_to_tile_store(const Store *op, const std::string &amx_alloc, int tile_x, int tile_y) { +Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { auto tile = is_2d_tile_index(op->index); if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { auto out = Variable::make(Handle(), op->name); - auto tile_val = Load::make(Int(32, 256), amx_alloc, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto tile_val = Load::make(Int(32, 256), amx_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); auto bytes = op->value.type().bytes(); internal_assert(bytes == 4) << "AMX store only supported for int32 and float32, not for " << op->value.type() << "\n"; // {tile_x, tile_y, var, base, stride} @@ -178,10 +181,9 @@ Stmt convert_to_tile_store(const Store *op, const std::string &amx_alloc, int ti class ExtractTileOperations : public IRMutator { using IRMutator::visit; - std::string tile_name; - std::string amx_alloc; - std::vector pending_stores; - bool is_valid = true; + string tile_name; + string amx_name; + vector pending_stores; bool in_allocate = false; int found_tile_x = -1; int found_tile_y = -1; @@ -191,22 +193,15 @@ class ExtractTileOperations : public IRMutator { 
if (op->memory_type == MemoryType::AMXTile && op->type.is_int() && op->type.bits() == 32) { - if (in_allocate) { - // Found two possible tile allocations - // FIXME: Handle this better - is_valid = false; - return op; - } - amx_alloc = op->name + ".amx"; - tile_name = op->name; + // FIXME: Handle nested allocations better + user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation"; + ScopedValue old_amx_name(amx_name, op->name + ".amx"); + ScopedValue old_tile_name(tile_name, op->name); ScopedValue old_in_alloc(in_allocate, true); Stmt body = op->body; pending_stores.clear(); body = mutate(body); - if (!is_valid) { - return op; - } if (found_tile_x < 0 || found_tile_y < 0 || found_tile_r < 0) { return op; } @@ -214,11 +209,8 @@ class ExtractTileOperations : public IRMutator { // Really only need to go over the pending stores body = mutate(body); } - if (!is_valid) { - return op; - } - return Allocate::make(amx_alloc, Int(32, 256), MemoryType::AMXTile, {1}, const_true(), body); + return Allocate::make(amx_name, Int(32, 256), MemoryType::AMXTile, {1}, const_true(), body); } return IRMutator::visit(op); } @@ -227,7 +219,7 @@ class ExtractTileOperations : public IRMutator { if (op->name != tile_name) { return op; } - return Free::make(amx_alloc); + return Free::make(amx_name); } Stmt visit(const ProducerConsumer *op) override { @@ -236,15 +228,13 @@ class ExtractTileOperations : public IRMutator { } auto body = mutate(op->body); - return ProducerConsumer::make(amx_alloc, op->is_producer, body); + return ProducerConsumer::make(amx_name, op->is_producer, body); } Expr visit(const Load *op) override { - if (op->name == tile_name) { - // Any tile load will be matched elsewhere, so a load here means that - // the AMX tile is used outside of a tile instruction. - is_valid = false; - } + // Any tile load will be matched elsewhere, so a load here means that + // the AMX tile is used outside of a tile instruction. 
+ user_assert(op->name != tile_name) << "AMX tile allocation used outside a tile instruction"; return IRMutator::visit(op); } @@ -254,24 +244,18 @@ class ExtractTileOperations : public IRMutator { if (!load || load->name != tile_name) { return op; } - auto store = convert_to_tile_store(op, amx_alloc, found_tile_x, found_tile_y); - if (store.defined()) { - return store; - } else { - // Found store of tile_name that is not a tile store. - is_valid = false; - return op; - } + auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y); + user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value"; + return store; } - auto matmul = convert_to_matmul(op, amx_alloc); + auto matmul = convert_to_matmul(op, amx_name); if (matmul.result) { - if ((found_tile_x > 0 && matmul.tile_x != found_tile_x) || - (found_tile_r > 0 && matmul.tile_r != found_tile_r) || - (found_tile_y > 0 && matmul.tile_y != found_tile_y)) { - is_valid = false; - return op; - } + user_assert( + (found_tile_x < 0 || matmul.tile_x == found_tile_x) && + (found_tile_y < 0 || matmul.tile_y == found_tile_y) && + (found_tile_r < 0 || matmul.tile_r == found_tile_r)) + << "Found different tile sizes for AMX tile allocation"; found_tile_x = matmul.tile_x; found_tile_y = matmul.tile_y; found_tile_r = matmul.tile_r; @@ -283,13 +267,13 @@ return op; } - auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_alloc); + auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name); if (zero.defined()) { return zero; } // Otherwise there is some other operation using the allocation, so we cannot use the AMX instructions - is_valid = false; + user_assert(false) << "Found non-tile operations for AMX tile allocation"; return op; } }; diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 7670ecb7d642..fbcd4b292942 100644 --- a/test/performance/tiled_matmul.cpp +++ 
b/test/performance/tiled_matmul.cpp @@ -8,6 +8,11 @@ using namespace Halide; int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + if (!target.has_feature(Target::AVX512_SapphireRapids)) { + std::cout << "[SKIP] The tiled matmul test is only designed to test AMX support.\n"; + return 0; + } const int row = 16; const int col = 16; const int acc = 16; @@ -85,7 +90,6 @@ int main(int argc, char **argv) { Func result = mm.in(); // Uncomment to check the asm - //Target target = get_jit_target_from_environment(); //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); From 06714a587dfa88d18042d4e6ca479fea04562825 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Fri, 12 Mar 2021 14:43:59 +0000 Subject: [PATCH 10/11] Add Tile lowering pass to makefile --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 1e80eba0a4a5..30fc4f531273 100644 --- a/Makefile +++ b/Makefile @@ -454,6 +454,7 @@ SOURCE_FILES = \ EmulateFloat16Math.cpp \ Error.cpp \ Expr.cpp \ + ExtractTileOperations.cpp \ FastIntegerDivide.cpp \ FindCalls.cpp \ FindIntrinsics.cpp \ @@ -626,6 +627,7 @@ HEADER_FILES = \ ExprUsesVar.h \ Extern.h \ ExternFuncArgument.h \ + ExtractTileOperations.h \ FastIntegerDivide.h \ FindCalls.h \ FindIntrinsics.h \ From 55db851bc50ba1ab2ca5a0e64818fc1f99b52be9 Mon Sep 17 00:00:00 2001 From: John Lawson Date: Fri, 12 Mar 2021 15:15:35 +0000 Subject: [PATCH 11/11] Use spaces in Makefile --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 30fc4f531273..ca4802069fa4 100644 --- a/Makefile +++ b/Makefile @@ -454,7 +454,7 @@ SOURCE_FILES = \ EmulateFloat16Math.cpp \ Error.cpp \ Expr.cpp \ - ExtractTileOperations.cpp \ + ExtractTileOperations.cpp \ FastIntegerDivide.cpp \ FindCalls.cpp \ FindIntrinsics.cpp \ @@ -627,7 +627,7 
@@ HEADER_FILES = \ ExprUsesVar.h \ Extern.h \ ExternFuncArgument.h \ - ExtractTileOperations.h \ + ExtractTileOperations.h \ FastIntegerDivide.h \ FindCalls.h \ FindIntrinsics.h \