diff --git a/Makefile b/Makefile
index 10bbed901d53..9ce0f38a7a42 100644
--- a/Makefile
+++ b/Makefile
@@ -454,6 +454,7 @@ SOURCE_FILES = \
   EmulateFloat16Math.cpp \
   Error.cpp \
   Expr.cpp \
+  ExtractTileOperations.cpp \
   FastIntegerDivide.cpp \
   FindCalls.cpp \
   FindIntrinsics.cpp \
@@ -626,6 +627,7 @@ HEADER_FILES = \
   ExprUsesVar.h \
   Extern.h \
   ExternFuncArgument.h \
+  ExtractTileOperations.h \
   FastIntegerDivide.h \
   FindCalls.h \
   FindIntrinsics.h \
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 2ffdf85a51d4..05d9d4ab2a51 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -59,6 +59,7 @@ set(HEADER_FILES
   ExprUsesVar.h
   Extern.h
   ExternFuncArgument.h
+  ExtractTileOperations.h
   FastIntegerDivide.h
   FindCalls.h
   FindIntrinsics.h
@@ -218,6 +219,7 @@ set(SOURCE_FILES
   EmulateFloat16Math.cpp
   Error.cpp
   Expr.cpp
+  ExtractTileOperations.cpp
   FastIntegerDivide.cpp
   FindCalls.cpp
   FindIntrinsics.cpp
diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp
index 2038dcce75c8..a8c7cdb12f05 100644
--- a/src/CodeGen_X86.cpp
+++ b/src/CodeGen_X86.cpp
@@ -79,15 +79,21 @@ class CodeGen_X86 : public CodeGen_Posix {
     void visit(const EQ *) override;
     void visit(const NE *) override;
     void visit(const Select *) override;
+    void visit(const Allocate *) override;
+    void visit(const Load *) override;
+    void visit(const Store *) override;
     void codegen_vector_reduce(const VectorReduce *, const Expr &init) override;
     // @}
+
+private:
+    // Tracks the MemoryType of every allocation in scope, so that loads and
+    // stores to MemoryType::AMXTile buffers can be special-cased below.
+    Scope<MemoryType> mem_type;
 };
 
 CodeGen_X86::CodeGen_X86(Target t)
     : CodeGen_Posix(complete_x86_target(t)) {
 }
 
-const int max_intrinsic_args = 4;
+const int max_intrinsic_args = 6;
 
 struct x86Intrinsic {
     const char *intrin_name;
     halide_type_t ret_type;
     const char *name;
     halide_type_t arg_types[max_intrinsic_args];
     Target::Feature feature = Target::FeatureEnd;
+    uint32_t flags = 0;
+    enum Options {
+        AccessesMemory = 1 << 0,
+    };
 };
 
 // clang-format off
@@ -184,6 +194,13 @@ const x86Intrinsic intrinsic_defs[] = {
     {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
     {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
     {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
+
+    {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
+    {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids},
+    {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids},
+    // CodeGen_LLVM cannot cope with returning Type() ie void*, and return type needs to be vector to trigger call_overloaded_intrin
+    {"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
+
 };
 // clang-format on
 
@@ -206,7 +223,9 @@ void CodeGen_X86::init_module() {
         }
 
         auto *fn = declare_intrin_overload(i.name, ret_type, i.intrin_name, std::move(arg_types));
-        fn->addFnAttr(llvm::Attribute::ReadNone);
+        // Intrinsics that touch memory (tile load/store) must not be marked
+        // ReadNone, or LLVM will happily reorder or delete them.
+        if ((i.flags & x86Intrinsic::AccessesMemory) == 0) {
+            fn->addFnAttr(llvm::Attribute::ReadNone);
+        }
         fn->addFnAttr(llvm::Attribute::NoUnwind);
     }
 }
@@ -576,6 +595,38 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
     CodeGen_Posix::codegen_vector_reduce(op, init);
 }
 
+void CodeGen_X86::visit(const Allocate *op) {
+    // Remember the memory type of this allocation while visiting its body.
+    ScopedBinding<MemoryType> bind(mem_type, op->name, op->memory_type);
+    CodeGen_Posix::visit(op);
+}
+
+void CodeGen_X86::visit(const Load *op) {
+    if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) {
+        const Ramp *ramp = op->index.as<Ramp>();
+        internal_assert(ramp) << "Expected AMXTile to have index ramp\n";
+        Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base);
+        LoadInst *load = builder->CreateAlignedLoad(ptr, llvm::Align(op->type.bytes()));
+        add_tbaa_metadata(load, op->name, op->index);
+        value = load;
+        return;
+    }
+    CodeGen_Posix::visit(op);
+}
+
+void CodeGen_X86::visit(const Store *op) {
+    if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) {
+        Value *val = codegen(op->value);
+        Halide::Type value_type = op->value.type();
+        const Ramp *ramp = op->index.as<Ramp>();
+        internal_assert(ramp) << "Expected AMXTile to have index ramp\n";
+        Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base);
+        StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes()));
+        add_tbaa_metadata(store, op->name, op->index);
+        return;
+    }
+    CodeGen_Posix::visit(op);
+}
+
 string CodeGen_X86::mcpu() const {
     if (target.has_feature(Target::AVX512_SapphireRapids)) {
 #if LLVM_VERSION >= 120
diff --git a/src/Expr.h b/src/Expr.h
index c5472e766fa4..b70d608d290b 100644
--- a/src/Expr.h
+++ b/src/Expr.h
@@ -379,6 +379,10 @@ enum class MemoryType {
      * intermediate buffers. Necessary for vgather-vscatter instructions
      * on Hexagon */
     VTCM,
+
+    /** AMX Tile register for X86. Any data that would be used in an AMX matrix
+     * multiplication must first be loaded into an AMX tile register.
+     */
+    AMXTile,
 };
 
 namespace Internal {
diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp
new file mode 100644
index 000000000000..62d877aac354
--- /dev/null
+++ b/src/ExtractTileOperations.cpp
@@ -0,0 +1,288 @@
+#include "ExtractTileOperations.h"
+
+#include "IRMatch.h"     // expr_match
+#include "IRMutator.h"
+#include "IROperator.h"  // Expr + Expr
+#include "Util.h"        // ScopedValue
+
+namespace Halide {
+namespace Internal {
+
+using std::string;
+using std::vector;
+
+namespace {
+
+// Result of matching an index expression against a Dim-dimensional tile
+// access pattern. `result` is false if the match failed.
+template<int Dim>
+struct Tile {
+    bool result;
+    Expr base;
+    Expr stride[Dim];
+    int extent[Dim];
+};
+
+const auto wild_i32 = Variable::make(Int(32), "*");
+const auto wild_i32x = Variable::make(Int(32, 0), "*");
+
+Tile<2> is_2d_tile_index(const Expr &e) {
+    // ramp(ramp(base, 1, 4), x4(stride), 4)
+    vector<Expr> matches;
+    if (const auto *r1 = e.as<Ramp>()) {
+        if (const auto *r2 = r1->base.as<Ramp>()) {
+            auto ramp_2d_pattern = Ramp::make(Ramp::make(wild_i32, wild_i32, r2->lanes), Broadcast::make(wild_i32, r2->lanes), r1->lanes);
+            if (expr_match(ramp_2d_pattern, e, matches)) {
+                return {true, std::move(matches[0]), {std::move(matches[2]), std::move(matches[1])}, {r1->lanes, r2->lanes}};
+            }
+        }
+    }
+    return {};
+}
+
+Tile<3> is_3d_tile_index(const Expr &e) {
+    vector<Expr> matches;
+    auto add_sub_pattern = (wild_i32x + wild_i32x) - wild_i32x;
+    if (!expr_match(add_sub_pattern, e, matches)) { return {}; }
+    // ramp(x16(base), x16(stride), 4) + x16(ramp(idx, 1, 4)) y: 4, x: 4, r: 4
+    // ramp(x10(base), x10(stride), 3) + x6(ramp(idx, 1, 5)) y: 2, x: 3, r: 5
+    Expr first = std::move(matches[0]);
+    Expr second = std::move(matches[1]);
+    Expr adj = std::move(matches[2]);
+    const auto *r1 = first.as<Ramp>();
+    const auto *b2 = second.as<Broadcast>();
+    if (!r1 && !b2) {
+        // Try switching the order
+        r1 = second.as<Ramp>();
+        b2 = first.as<Broadcast>();
+    }
+    if (!r1 || !b2) { return {}; }
+
+    const auto *b1 = r1->base.as<Broadcast>();
+    const auto *r2 = b2->value.as<Ramp>();
+
+    if (!b1 || !r2) { return {}; }
+
+    int x_tile = r1->lanes;
+    int r_tile = r2->lanes;
+    int y_tile = b1->lanes / r_tile;
+    if (y_tile != b2->lanes / x_tile) { return {}; }
+
+    auto pattern1 = Ramp::make(Broadcast::make(wild_i32, b1->lanes), Broadcast::make(wild_i32, b1->lanes), r1->lanes);
+    if (!expr_match(pattern1, first, matches)) { return {}; }
+    Expr base = std::move(matches[0]);
+    Expr x_stride = std::move(matches[1]);
+
+    auto pattern2 = Broadcast::make(Ramp::make(wild_i32, wild_i32, r2->lanes), b2->lanes);
+    if (!expr_match(pattern2, second, matches)) { return {}; }
+    base += std::move(matches[0]);
+    Expr r_stride = std::move(matches[1]);
+
+    auto pattern3 = Broadcast::make(wild_i32, b1->lanes * r1->lanes);
+    if (!expr_match(pattern3, adj, matches)) { return {}; }
+    base -= std::move(matches[0]);
+
+    return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}};
+}
+
+struct NewMatmul {
+    bool result = false;
+    Stmt stmt;
+    int tile_x;
+    int tile_y;
+    int tile_r;
+};
+
+NewMatmul
+convert_to_matmul(const Store *op, const string &new_name) {
+    // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)]
+    const auto wild_i8x = Variable::make(Int(8, 0), "*");
+    const auto wild_i16x = Variable::make(Int(16, 0), "*");
+    vector<Expr> matches;
+    const auto pattern1 = wild_i32x + wild_i32x;
+    if (!expr_match(pattern1, op->value, matches)) { return {}; }
+    const auto *reduce = matches[0].as<VectorReduce>();
+    const auto *load = matches[1].as<Load>();
+    if (!reduce || reduce->op != VectorReduce::Add) { return {}; }
+    if (!load || load->name != op->name || !equal(load->index, op->index)) { return {}; }
+
+    // FIXME: Add support for uint8 and bf16 for LLVM 13+
+    auto pattern2 = cast(Int(32, 0), cast(Int(16, 0), wild_i8x) * wild_i16x);
+    if (!expr_match(pattern2, reduce->value, matches)) { return {}; }
+    const auto *lhs_load = matches[0].as<Load>();
+    // FIXME: When tile_r is not 4 the broadcast is inside the index, not of the value
+    const auto *rhs_broadcast = matches[1].as<Broadcast>();
+    if (!lhs_load || !rhs_broadcast) { return {}; }
+    const auto *rhs_cast = rhs_broadcast->value.as<Cast>();
+    if (!rhs_cast || rhs_cast->value.type().element_of() != Int(8)) { return {}; }
+    const auto *rhs_load = rhs_cast->value.as<Load>();
+    if (!rhs_load) { return {}; }
+
+    const auto lhs_tile = is_3d_tile_index(lhs_load->index);
+    const auto rhs_tile = is_2d_tile_index(rhs_load->index);
+    // FIXME: When tile_r is not 4 the RHS load will be 4D (x, r/4, y, r%4)
+    if (!lhs_tile.result || !rhs_tile.result) { return {}; }
+
+    const int tile_x = lhs_tile.extent[0];
+    const int tile_y = lhs_tile.extent[1];
+    const int tile_r = lhs_tile.extent[2];
+    const int factor = reduce->value.type().lanes() / reduce->type.lanes();
+    if (op->index.type().lanes() != tile_x * tile_y ||
+        factor != tile_r ||
+        tile_y != rhs_tile.extent[0] ||
+        tile_r != rhs_tile.extent[1]) {
+        return {};
+    }
+
+    // {rows, colbytes, var, index}
+    auto lhs_var = Variable::make(Handle(), lhs_load->name);
+    auto lhs = Call::make(Int(8, 1024), "tile_load", {tile_x, tile_r, lhs_var, lhs_tile.base, lhs_tile.stride[0]}, Call::Intrinsic);
+    auto rhs_var = Variable::make(Handle(), rhs_load->name);
+    auto rhs = Call::make(Int(8, 1024), "tile_load", {1, tile_y * tile_r, rhs_var, rhs_tile.base, rhs_tile.stride[0]}, Call::Intrinsic);
+
+    // {rows, colbytes, acc, out, lhs, rhs}
+    auto out = Load::make(Int(32, 256), new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {});
+    auto colbytes = tile_y * 32 / rhs_load->type.bits();
+    auto matmul = Call::make(Int(32, 256), "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic);
+    auto store = Store::make(new_name, matmul, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder());
+    return {true, std::move(store), tile_x, tile_y, tile_r};
+}
+
+Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) {
+    if (const auto *ramp = op->index.as<Ramp>()) {
+        if (const auto *bcast = op->value.as<Broadcast>()) {
+            if (is_const_one(ramp->stride) &&
+                is_const_zero(bcast->value) &&
+                (bcast->lanes == tile_x * tile_y)) {
+                auto rows = Cast::make(Int(16), tile_x);
+                auto bytes = op->value.type().bytes();
+                auto colbytes = Cast::make(Int(16), tile_y * bytes);
+                auto val = Call::make(Int(32, 256), "tile_zero", {rows, colbytes}, Call::Intrinsic);
+                auto store = Store::make(new_name, val, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder());
+                return store;
+            }
+        }
+    }
+    return {};
+}
+
+Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) {
+    auto tile = is_2d_tile_index(op->index);
+    if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) {
+        auto out = Variable::make(Handle(), op->name);
+        auto tile_val = Load::make(Int(32, 256), amx_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {});
+        auto bytes = op->value.type().bytes();
+        internal_assert(bytes == 4) << "AMX store only supported for int32 and float32, not for " << op->value.type() << "\n";
+        // {tile_x, tile_y, var, base, stride}
+        auto store = Call::make(Bool(2), "tile_store", {tile_x, tile_y * bytes, out, tile.base * bytes, tile.stride[0] * bytes, tile_val}, Call::Intrinsic);
+        return Evaluate::make(store);
+    }
+    return {};
+}
+
+class ExtractTileOperations : public IRMutator {
+    using IRMutator::visit;
+
+    string tile_name;
+    string amx_name;
+    vector<Stmt> pending_stores;
+    bool in_allocate = false;
+    int found_tile_x = -1;
+    int found_tile_y = -1;
+    int found_tile_r = -1;
+
+    Stmt visit(const Allocate *op) override {
+        if (op->memory_type == MemoryType::AMXTile &&
+            op->type.is_int() &&
+            op->type.bits() == 32) {
+            // FIXME: Handle nested allocations better
+            user_assert(!in_allocate) << "Found two possible tile allocations for AMX allocation";
+            ScopedValue<string> old_amx_name(amx_name, op->name + ".amx");
+            ScopedValue<string> old_tile_name(tile_name, op->name);
+            ScopedValue<bool> old_in_alloc(in_allocate, true);
+            Stmt body = op->body;
+
+            pending_stores.clear();
+            body = mutate(body);
+            if (found_tile_x < 0 || found_tile_y < 0 || found_tile_r < 0) {
+                return op;
+            }
+            if (!pending_stores.empty()) {
+                // Really only need to go over the pending stores
+                body = mutate(body);
+            }
+
+            return Allocate::make(amx_name, Int(32, 256), MemoryType::AMXTile, {1}, const_true(), body);
+        }
+        return IRMutator::visit(op);
+    }
+
+    Stmt visit(const Free *op) override {
+        if (op->name != tile_name) {
+            return op;
+        }
+        return Free::make(amx_name);
+    }
+
+    Stmt visit(const ProducerConsumer *op) override {
+        if (op->name != tile_name) {
+            return IRMutator::visit(op);
+        }
+
+        auto body = mutate(op->body);
+        return ProducerConsumer::make(amx_name, op->is_producer, body);
+    }
+
+    Expr visit(const Load *op) override {
+        // Any tile load will be matched elsewhere, so a load here means that
+        // the AMX tile is used outside of a tile instruction.
+        user_assert(op->name != tile_name) << "AMX tile allocation used outside a tile instruction";
+        return IRMutator::visit(op);
+    }
+
+    Stmt visit(const Store *op) override {
+        if (op->name != tile_name) {
+            const auto *load = op->value.as<Load>();
+            if (!load || load->name != tile_name) {
+                return op;
+            }
+            auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y);
+            user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value";
+            return store;
+        }
+
+        auto matmul = convert_to_matmul(op, amx_name);
+        if (matmul.result) {
+            // All matmuls updating one accumulator must agree on every tile
+            // dimension (x, y and r), not just tile_x.
+            user_assert(
+                (found_tile_x < 0 || matmul.tile_x == found_tile_x) &&
+                (found_tile_y < 0 || matmul.tile_y == found_tile_y) &&
+                (found_tile_r < 0 || matmul.tile_r == found_tile_r))
+                << "Found different tile sizes for AMX tile allocation";
+            found_tile_x = matmul.tile_x;
+            found_tile_y = matmul.tile_y;
+            found_tile_r = matmul.tile_r;
+            return matmul.stmt;
+        }
+
+        if (found_tile_x < 0 || found_tile_y < 0) {
+            pending_stores.emplace_back(op);
+            return op;
+        }
+
+        auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name);
+        if (zero.defined()) {
+            return zero;
+        }
+
+        // Otherwise there is some other operation using the allocation, so we cannot use the AMX instructions
+        user_assert(false) << "Found non-tile operations for AMX tile allocation";
+        return op;
+    }
+};
+
+}  // namespace
+
+Stmt extract_tile_operations(const Stmt &s) {
+    return ExtractTileOperations().mutate(s);
+}
+
+}  // namespace Internal
+}  // namespace Halide
diff --git a/src/ExtractTileOperations.h b/src/ExtractTileOperations.h
new file mode 100644
index 000000000000..918e3b1b9940
--- /dev/null
+++ b/src/ExtractTileOperations.h
@@ -0,0 +1,21 @@
+#ifndef HALIDE_EXTRACT_TILE_OPERATIONS_H
+#define HALIDE_EXTRACT_TILE_OPERATIONS_H
+
+/** \file
+ * Defines the lowering pass that injects calls to tile intrinsics that support
+ * AMX instructions.
+ */
+
+#include "Expr.h"
+
+namespace Halide {
+namespace Internal {
+
+/** Rewrite any AMX tile operations that have been stored in the AMXTile memory
+ * type as intrinsic calls, to be used in the X86 backend. */
+Stmt extract_tile_operations(const Stmt &s);
+
+}  // namespace Internal
+}  // namespace Halide
+
+#endif
diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp
index 7fa67ac2192f..0c58e318a86c 100644
--- a/src/FuseGPUThreadLoops.cpp
+++ b/src/FuseGPUThreadLoops.cpp
@@ -1279,6 +1279,7 @@ class InjectThreadBarriers : public IRMutator {
         case MemoryType::Register:
         case MemoryType::LockedCache:
         case MemoryType::VTCM:
+        case MemoryType::AMXTile:
             break;
         }
 
@@ -1303,6 +1304,7 @@ class InjectThreadBarriers : public IRMutator {
         case MemoryType::Register:
         case MemoryType::LockedCache:
         case MemoryType::VTCM:
+        case MemoryType::AMXTile:
             break;
         }
 
diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp
index 62fdb705997b..38e2eeefd511 100644
--- a/src/IRPrinter.cpp
+++ b/src/IRPrinter.cpp
@@ -135,6 +135,9 @@ std::ostream &operator<<(std::ostream &out, const MemoryType &t) {
     case MemoryType::VTCM:
         out << "VTCM";
         break;
+    case MemoryType::AMXTile:
+        out << "AMXTile";
+        break;
     }
     return out;
 }
diff --git a/src/Lower.cpp b/src/Lower.cpp
index
a7227b6ec76d..4a8ec2df34b5 100644
--- a/src/Lower.cpp
+++ b/src/Lower.cpp
@@ -22,6 +22,7 @@
 #include "DebugToFile.h"
 #include "Deinterleave.h"
 #include "EarlyFree.h"
+#include "ExtractTileOperations.h"
 #include "FindCalls.h"
 #include "FindIntrinsics.h"
 #include "FlattenNestedRamps.h"
@@ -385,6 +386,15 @@ Module lower(const vector &output_funcs,
     s = lower_unsafe_promises(s, t);
     log("Lowering after lowering unsafe promises:", s);
 
+// Halide's LLVM_VERSION is major*10 (LLVM 12 -> 120, matching the
+// `LLVM_VERSION >= 120` guard in CodeGen_X86::mcpu); `>= 12` would have been
+// true for every supported LLVM.
+#if LLVM_VERSION >= 120
+    if (t.has_feature(Target::AVX512_SapphireRapids)) {
+        debug(1) << "Extracting tile operations...\n";
+        s = extract_tile_operations(s);
+        debug(2) << "Lowering after extracting tile operations:\n"
+                 << s << "\n\n";
+    }
+#endif
+
     debug(1) << "Flattening nested ramps...\n";
     s = flatten_nested_ramps(s);
     log("Lowering after flattening nested ramps:", s);
diff --git a/src/runtime/x86_avx512.ll b/src/runtime/x86_avx512.ll
index 904fabe9368e..7217caeb9f01 100644
--- a/src/runtime/x86_avx512.ll
+++ b/src/runtime/x86_avx512.ll
@@ -90,3 +90,41 @@ define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b
   ret <4 x i32> %3
 }
 declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>)
+
+define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline readonly {
+  %1 = getelementptr i8, i8* %ptr, i64 %off
+  %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly
+  %3 = bitcast x86_amx %2 to <1024 x i8>
+  ret <1024 x i8> %3
+}
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+
+define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone {
+  %1 = bitcast <1024 x i8> %lhs to x86_amx
+  %2 = bitcast <1024 x i8> %rhs to x86_amx
+  %3 = bitcast <256 x i32> %out to x86_amx
+  %4 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone
+  %5 = bitcast x86_amx %4 to <256 x i32>
+  ret <256 x i32> %5
+}
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+
+define weak_odr <2 x i1> @tilestored64(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly {
+  %1 = getelementptr i8, i8* %ptr, i64 %off
+  %2 = bitcast <256 x i32> %val to x86_amx
+  tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly
+  ret <2 x i1> zeroinitializer
+}
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+
+; NB: Even though this should be readnone, that will cause LLVM to try to
+; generate a single zero tile, and copy it each time it is used. However the AMX
+; registers cannot be copied, so this causes compilation failures:
+; LLVM ERROR: Cannot emit physreg copy instruction
+; renamable $tmm1 = COPY renamable $tmm0
+define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alwaysinline {
+  %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) nounwind
+  %2 = bitcast x86_amx %1 to <256 x i32>
+  ret <256 x i32> %2
+}
+declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
diff --git a/test/performance/CMakeLists.txt b/test/performance/CMakeLists.txt
index 65aa41da00f3..80f58a16afae 100644
--- a/test/performance/CMakeLists.txt
+++ b/test/performance/CMakeLists.txt
@@ -1,5 +1,6 @@
 tests(GROUPS performance
       SOURCES
+      tiled_matmul.cpp
       async_gpu.cpp
       block_transpose.cpp
       boundary_conditions.cpp
diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp
new file mode 100644
index 000000000000..fbcd4b292942
--- /dev/null
+++ b/test/performance/tiled_matmul.cpp
@@ -0,0 +1,116 @@
+#include "Halide.h"
+#include "halide_benchmark.h"
+#include "halide_test_dirs.h"
+
+// NOTE(review): the two standard headers were garbled in transit; the file
+// uses std::cout/std::cerr and rand(), so these are assumed -- TODO confirm.
+#include <cstdlib>
+#include <iostream>
+
+using namespace Halide;
+
+int main(int argc, char **argv) {
+    Target target = get_jit_target_from_environment();
+    if (!target.has_feature(Target::AVX512_SapphireRapids)) {
+        std::cout << "[SKIP] The tiled matmul test is only designed to test AMX support.\n";
+        return 0;
+    }
+    const int row = 16;
+    const int col = 16;
+    const int acc = 16;
+
+    Var x("x"), y("y");
+    ImageParam A(Int(8), 2, "lhs");
+    // NB the RHS matrix in AMX instructions should be tiled in "VNNI format",
+    // where instead of being (cols, rows) where rows are adjacent in memory it
+    // should be (4, cols, rows / 4) for int8, or (2, cols, rows / 2) for bf16.
+    // This means that the rows must always be divisible by 4 (or 2 for bf16).
+    ImageParam B(Int(8), 3, "rhs");
+
+    RDom r(0, acc);
+
+    Func mm("matmul");
+    mm(y, x) = cast<int32_t>(0);
+    mm(y, x) += cast<int32_t>(A(r.x, x)) * B(r.x % 4, y, r.x / 4);
+
+    // Ensure all (x, y) tile sizes are the same so that loops are fused.
+    int tile_y = 8;
+    int tile_x = 6;
+    int tile_r = 4;
+
+    // Schedule the reduction
+    Var rxi("rxi"), ryi("ryi");
+    RVar rri("rri"), rro("rro");
+    mm.compute_at(mm.in(), y)
+        .store_in(MemoryType::AMXTile)
+        .update()
+        // Split into (x,y) tile
+        .tile(y, x, ryi, rxi, tile_y, tile_x, TailStrategy::GuardWithIf)
+        // Split reduction dim by tile_r
+        .split(r.x, rro, rri, tile_r)
+        // Reorder so that the (x,y) tile is inside the inner ro loop
+        .reorder({rri, ryi, rxi, rro, y, x})
+        .atomic()
+        .vectorize(rri)
+        .vectorize(ryi)
+        .vectorize(rxi);
+
+    // Schedule the initialization
+    Var ixi("ixi"), iyi("iyi");
+    mm.compute_at(mm.in(), y)
+        .tile(y, x, iyi, ixi, tile_y, tile_x)
+        .vectorize(iyi)
+        .vectorize(ixi);
+
+    // Schedule the consumer
+    Var mmxi("mmxi"), mmyi("mmyi");
+    mm.in()
+        .tile(y, x, mmyi, mmxi, tile_y, tile_x)
+        .vectorize(mmyi)
+        .vectorize(mmxi);
+
+    Buffer<int8_t> a_buf(acc, row);
+    for (int iy = 0; iy < row; iy++) {
+        for (int ix = 0; ix < acc; ix++) {
+            a_buf(ix, iy) = rand() % 256 - 128;
+        }
+    }
+    A.set(a_buf);
+
+    Buffer<int8_t> b_buf(4, col, acc / 4);
+    for (int iy = 0; iy < acc / 4; iy++) {
+        for (int ix = 0; ix < col; ix++) {
+            for (int ik = 0; ik < 4; ++ik) {
+                b_buf(ik, ix, iy) = rand() % 256 - 128;
+            }
+        }
+    }
+    B.set(b_buf);
+
+    Buffer<int32_t> out(col, row);
+
+    Func result = mm.in();
+
+    // Uncomment to check the asm
+    //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target);
+    //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target);
+
+    auto time = Tools::benchmark(20, 20, [&]() {
+        result.realize(out);
+    });
+    std::cout << "Exec time: " << time << "\n";
+
+    for (int j = 0; j < row; ++j) {
+        for (int i = 0; i < col; ++i) {
+            int32_t val = 0;
+            for (int k = 0; k < acc; ++k) {
+                val += a_buf(k, j) * b_buf(k % 4, i, k / 4);
+            }
+            if (val != out(i, j)) {
+                std::cerr << "Invalid result at " << i << ", " << j << "\n"
+                          << out(i, j) << " != " << val << "\n";
+                return 1;
+            }
+        }
+    }
+    std::cout << "Success!\n";
+    return 0;
+}