Skip to content
Closed
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ SOURCE_FILES = \
EmulateFloat16Math.cpp \
Error.cpp \
Expr.cpp \
ExtractTileOperations.cpp \
FastIntegerDivide.cpp \
FindCalls.cpp \
FindIntrinsics.cpp \
Expand Down Expand Up @@ -626,6 +627,7 @@ HEADER_FILES = \
ExprUsesVar.h \
Extern.h \
ExternFuncArgument.h \
ExtractTileOperations.h \
FastIntegerDivide.h \
FindCalls.h \
FindIntrinsics.h \
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ set(HEADER_FILES
ExprUsesVar.h
Extern.h
ExternFuncArgument.h
ExtractTileOperations.h
FastIntegerDivide.h
FindCalls.h
FindIntrinsics.h
Expand Down Expand Up @@ -218,6 +219,7 @@ set(SOURCE_FILES
EmulateFloat16Math.cpp
Error.cpp
Expr.cpp
ExtractTileOperations.cpp
FastIntegerDivide.cpp
FindCalls.cpp
FindIntrinsics.cpp
Expand Down
55 changes: 53 additions & 2 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,32 @@ class CodeGen_X86 : public CodeGen_Posix {
void visit(const EQ *) override;
void visit(const NE *) override;
void visit(const Select *) override;
void visit(const Allocate *) override;
void visit(const Load *) override;
void visit(const Store *) override;
void codegen_vector_reduce(const VectorReduce *, const Expr &init) override;
// @}

private:
Scope<MemoryType> mem_type;
};

// Construct the x86 code generator. The requested target is first passed
// through complete_x86_target (defined outside this view) and the result is
// handed to the CodeGen_Posix base-class constructor; the constructor body
// itself does no further work.
CodeGen_X86::CodeGen_X86(Target t)
: CodeGen_Posix(complete_x86_target(t)) {
}

// Maximum number of arguments a single x86Intrinsic entry can describe; this
// sizes x86Intrinsic::arg_types. It must be at least 6 because the
// "tile_matmul" and "tile_store" entries in intrinsic_defs each list six
// argument types. Defined exactly once: a second definition with a different
// value is a redefinition error.
const int max_intrinsic_args = 6;

// One row of the intrinsic_defs table. Rows are written as brace-initialised
// aggregates, so the order of the data members below is part of the table
// format and must not be changed.
struct x86Intrinsic {
// Name of the implementing intrinsic (e.g. "dpwssdx16", "tileloadd64_i8");
// init_module passes it to declare_intrin_overload as the implementation name.
const char *intrin_name;
// Type the intrinsic returns.
halide_type_t ret_type;
// Overload name the intrinsic is registered under (e.g. "dot_product",
// "tile_load"); several rows may share the same name.
const char *name;
// Argument types, in call order. Rows may list fewer than
// max_intrinsic_args entries.
halide_type_t arg_types[max_intrinsic_args];
// Target feature associated with this row; defaults to Target::FeatureEnd
// when the row does not specify one.
// NOTE(review): presumably the feature that must be enabled for the row to
// be registered, with FeatureEnd meaning "no requirement" - confirm against
// init_module.
Target::Feature feature = Target::FeatureEnd;
// Bitwise OR of Options values; 0 (the default) means no option is set.
uint32_t flags = 0;
enum Options {
// The intrinsic reads or writes memory. init_module tests this bit before
// adding llvm::Attribute::ReadNone to the declared function.
AccessesMemory = 1 << 0,
};
};

// clang-format off
Expand Down Expand Up @@ -184,6 +194,13 @@ const x86Intrinsic intrinsic_defs[] = {
{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},

{"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
{"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids},
{"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids},
// CodeGen_LLVM cannot cope with a return type of Type() (i.e. void*), and the return type must be a vector to trigger call_overloaded_intrin, hence the Bool(2) return type below.
{"tilestored64", Bool(2), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},

};
// clang-format on

Expand All @@ -206,7 +223,9 @@ void CodeGen_X86::init_module() {
}

auto *fn = declare_intrin_overload(i.name, ret_type, i.intrin_name, std::move(arg_types));
fn->addFnAttr(llvm::Attribute::ReadNone);
if ((i.flags & x86Intrinsic::AccessesMemory) == 0) {
fn->addFnAttr(llvm::Attribute::ReadNone);
}
fn->addFnAttr(llvm::Attribute::NoUnwind);
}
}
Expand Down Expand Up @@ -576,6 +595,38 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
CodeGen_Posix::codegen_vector_reduce(op, init);
}

void CodeGen_X86::visit(const Allocate *op) {
ScopedBinding<MemoryType> bind(mem_type, op->name, op->memory_type);
CodeGen_Posix::visit(op);
}

void CodeGen_X86::visit(const Load *op) {
if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) {
const Ramp *ramp = op->index.as<Ramp>();
internal_assert(ramp) << "Expected AMXTile to have index ramp\n";
Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base);
LoadInst *load = builder->CreateAlignedLoad(ptr, llvm::Align(op->type.bytes()));
add_tbaa_metadata(load, op->name, op->index);
value = load;
return;
}
CodeGen_Posix::visit(op);
}

void CodeGen_X86::visit(const Store *op) {
if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) {
Value *val = codegen(op->value);
Halide::Type value_type = op->value.type();
const Ramp *ramp = op->index.as<Ramp>();
internal_assert(ramp) << "Expected AMXTile to have index ramp\n";
Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base);
StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes()));
add_tbaa_metadata(store, op->name, op->index);
return;
}
CodeGen_Posix::visit(op);
}

string CodeGen_X86::mcpu() const {
if (target.has_feature(Target::AVX512_SapphireRapids)) {
#if LLVM_VERSION >= 120
Expand Down
4 changes: 4 additions & 0 deletions src/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ enum class MemoryType {
* intermediate buffers. Necessary for vgather-vscatter instructions
* on Hexagon */
VTCM,

/** AMX Tile register for X86. Any data that would be used in an AMX matrix
* multiplication must first be loaded into an AMX tile register. */
AMXTile,
};

namespace Internal {
Expand Down
Loading