diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e8bcc028fc58..70526bde6639 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -797,6 +797,24 @@ TVM_DLL const Op& start_profile_intrinsic(); */ TVM_DLL const Op& end_profile_intrinsic(); +// Intrinsics for the VK_NV_cooperative_matrix Vulkan extension. + +/*! \brief The intrinsic corresponding to the OpCooperativeMatrixLoadNV instruction. */ +TVM_DLL const Op& cooperative_matrix_load_NV(); + +/*! \brief The intrinsic corresponding to the OpCooperativeMatrixStoreNV instruction. */ +TVM_DLL const Op& cooperative_matrix_store_NV(); + +/*! + * \brief Create a new cooperative matrix filled with the provided value. + * + * There is no such instruction in the extension, but it is added for convenience. + */ +TVM_DLL const Op& cooperative_matrix_fill_NV(); + +/*! \brief The intrinsic corresponding to the OpCooperativeMatrixMulAddNV instruction. */ +TVM_DLL const Op& cooperative_matrix_mad_NV(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c8285ccc52ce..19477a8ee780 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1820,6 +1820,10 @@ def wrapped(*args, **kwargs): TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic) +cooperative_matrix_load_NV = _op_wrapper(_tir_op.cooperative_matrix_load_NV) +cooperative_matrix_store_NV = _op_wrapper(_tir_op.cooperative_matrix_store_NV) +cooperative_matrix_fill_NV = _op_wrapper(_tir_op.cooperative_matrix_fill_NV) +cooperative_matrix_mad_NV = _op_wrapper(_tir_op.cooperative_matrix_mad_NV) def _dtype_forward(func): @@ -2144,4 +2148,8 @@ def wrapped(*args, **kwargs): "IterVar", "CommReducer", "Range", + "cooperative_matrix_load_NV", + "cooperative_matrix_store_NV", + "cooperative_matrix_fill_NV", + "cooperative_matrix_mad_NV", ] diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 419ab2275858..f0523dbe78b8 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3037,6 +3037,148 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) +# Intrinsics for the VK_NV_cooperative_matrix Vulkan extension. + + +def cooperative_matrix_load_NV(buffer_mat, offset, src, rows, cols, stride, column_major): + """The intrinsic corresponding to the OpCooperativeMatrixLoadNV instruction. + + Parameters + ---------- + buffer_mat: Var + The destination buffer with "cooperative_matrix_nv" scope. + + offset: IntImm + The element offset for the matrix to be loaded in the destination buffer. + + src: Expr + The source pointer expression. + + rows: int + The number of rows in the matrix. + + cols: int + The number of columns in the matrix. + + stride: Expr + The stride of the matrix in the number of elements. + + column_major: bool + Whether the matrix elements are stored in the column-major order. + + Returns + ------- + call : PrimExpr + The call expression. 
+ """ + return call_intrin( + "handle", + "tir.cooperative_matrix_load_NV", + buffer_mat, + offset, + src, + rows, + cols, + stride, + column_major, + ) + + +def cooperative_matrix_store_NV(dst, buffer_mat, offset, stride, column_major): + """The intrinsic corresponding to the OpCooperativeMatrixStoreNV instruction. + + Parameters + ---------- + dst: Expr + The destination pointer expression. + + buffer_mat: Var + The source buffer with "cooperative_matrix_nv" scope. + + offset: IntImm + The element offset for the matrix to be stored in the destination buffer. + + stride: Expr + The stride of the matrix in the number of elements. + + column_major: bool + Whether the matrix elements are stored in the column-major order. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", "tir.cooperative_matrix_store_NV", dst, buffer_mat, offset, stride, column_major + ) + + +def cooperative_matrix_fill_NV(buffer_mat, offset, rows, cols, value): + """Create a new cooperative matrix filled with the provided value. + + There is no such instruction in the extension, but it is added for convenience. + + Parameters + ---------- + + buffer_mat: Var + The buffer with "cooperative_matrix_nv" scope to be filled. + + offset: IntImm + The element offset for the matrix to be filled in `buffer_mat`. + + rows: int + The number of rows in the matrix. + + cols: int + The number of columns in the matrix. + + value: Expr + The value the matrix will be filled with. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", "tir.cooperative_matrix_fill_NV", buffer_mat, offset, rows, cols, value + ) + + +def cooperative_matrix_mad_NV(A, A_off, B, B_off, C, C_off): + """The intrinsic corresponding to the OpCooperativeMatrixMulAddNV instruction. + + Parameters + ---------- + + A : Var + The buffer with "cooperative_matrix_nv" scope corresponding to the "A" matrix. + + A_off : IntImm + The element offset of the "A" matrix in the buffer `A`. + + B : Var + The buffer with "cooperative_matrix_nv" scope corresponding to the "B" matrix. + + B_off : IntImm + The element offset of the "B" matrix in the buffer `B`. + + C : Var + The buffer with "cooperative_matrix_nv" scope corresponding to the "C" matrix. + + C_off : IntImm + The element offset of the "C" matrix in the buffer `C`. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.cooperative_matrix_mad_NV", A, A_off, B, B_off, C, C_off) + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 51dba038b6ac..cc40ab7592b0 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -64,6 +64,11 @@ enum class StorageRank { kTexture = 7, /*! \brief global scope amx tmm memory */ kAMXTMM = 8, + /*! + * \brief Scope representing the cooperative matrix in the Vulkan + * VK_NV_cooperative_matrix extension. + */ + kCooperativeMatrixNV = 9, }; /*! 
@@ -154,6 +159,9 @@ struct StorageScope {
     } else if (s.compare(0, 7, "amx.tmm") == 0) {
       r.rank = StorageRank::kAMXTMM;
       r.tag = s.substr(7, std::string::npos);
+    } else if (s == "cooperative_matrix_nv") {
+      r.rank = StorageRank::kCooperativeMatrixNV;
+      r.tag = "";
     } else {
       LOG(FATAL) << "unknown storage scope " << s;
     }
diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc
index b3e017d03418..51120f5ac30d 100644
--- a/src/runtime/vulkan/vulkan_device.cc
+++ b/src/runtime/vulkan/vulkan_device.cc
@@ -133,6 +133,7 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance,
       !support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION");
   supports_integer_dot_product = device.HasExtension("VK_KHR_shader_integer_dot_product");
+  supports_cooperative_matrix_nv = device.HasExtension("VK_NV_cooperative_matrix");

   // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically
   // needed, since it will be set so long at least one queue has
@@ -435,7 +436,8 @@ std::vector<const char*> VulkanDevice::SelectEnabledExtensions() const {
       "VK_KHR_get_memory_requirements2",
       "VK_KHR_dedicated_allocation",
       "VK_KHR_spirv_1_4",
-      "VK_KHR_shader_integer_dot_product"};
+      "VK_KHR_shader_integer_dot_product",
+      "VK_NV_cooperative_matrix"};

   uint32_t device_extension_prop_count;
   VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr,
diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h
index 59ebf430e6e6..3ba6de40d720 100644
--- a/src/runtime/vulkan/vulkan_device.h
+++ b/src/runtime/vulkan/vulkan_device.h
@@ -88,6 +88,7 @@ struct VulkanDeviceProperties {
   bool supports_push_descriptor{false};
   bool supports_dedicated_allocation{false};
   bool supports_integer_dot_product{false};
+  bool supports_cooperative_matrix_nv{false};
   uint32_t supported_subgroup_operations{0};
   uint32_t max_num_threads{1};
   uint32_t thread_warp_size{1};
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index 93f017a5aa66..f2ce95b42d35 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -241,6 +241,10 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property,
     *rv = prop.supports_integer_dot_product;
   }

+  if (property == "supports_cooperative_matrix_nv") {
+    *rv = prop.supports_cooperative_matrix_nv;
+  }
+
   if (property == "device_name") {
     *rv = prop.device_name;
   }
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index e3ef5acb8331..4896b9b0e8eb 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -28,6 +28,7 @@
 #include <tvm/runtime/container/adt.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>

 #include "../../runtime/pack_args.h"
 #include "../../runtime/vulkan/vulkan_common.h"
@@ -395,6 +396,74 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
     LOG(FATAL) << "SPIR-V shader cannot make extern calls.  Graph contains extern \""
Graph contains extern \"" << Downcast(op->args[0]) << "\""; return spirv::Value(); + } else if (op->op.same_as(builtin::address_of())) { + const BufferLoadNode* load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) << "CodeGenSPIRV only supports flat memory allocations."; + auto buffer_var = Downcast(load->buffer->data); + auto it = storage_info_.find(buffer_var.get()); + ICHECK(it != storage_info_.end()); + StorageInfo& info = it->second; + spirv::SType content_type = builder_->GetSType(info.element_type); + auto buffer = MakeValue(buffer_var); + spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); + return builder_->StructArrayAccess(ptr_type, buffer, MakeValue(load->indices[0])); + } else if (op->op.same_as(builtin::cooperative_matrix_load_NV())) { + auto ptr = Downcast(op->args[0]); + auto elem_offset = op->args[1].as(); + auto src_ptr = MakeValue(op->args[2]); + const int rows = op->args[3].as()->value; + const int cols = op->args[4].as()->value; + auto stride = MakeValue(op->args[5]); + auto column_major = MakeValue(op->args[6]); + auto elem_ty = builder_->GetBufferElementType(ptr); + auto mat_ty = builder_->GetCooperativeMatrixNVType(elem_ty, rows, cols); + auto mat = builder_->CallCooperativeMatrixLoadNV(mat_ty, src_ptr, stride, column_major); + ICHECK(elem_offset) << "Expects a constant element offset."; + builder_->SetCooperativeMatrix(ptr, elem_offset->value, mat); + return mat; + } else if (op->op.same_as(builtin::cooperative_matrix_store_NV())) { + auto dst_ptr = MakeValue(op->args[0]); + auto buffer_var_mat = Downcast(op->args[1]); + auto elem_offset = op->args[2].as(); + ICHECK(elem_offset) << "Expects a constant element offset."; + auto mat = builder_->GetCooperativeMatrix(buffer_var_mat, elem_offset->value); + auto stride = MakeValue(op->args[3]); + auto column_major = MakeValue(op->args[4]); + builder_->CallCooperativeMatrixStoreNV(dst_ptr, mat, stride, column_major); + return spirv::Value(); + } else if (op->op.same_as(builtin::cooperative_matrix_fill_NV())) { + auto ptr = Downcast(op->args[0]); + auto elem_offset = op->args[1].as(); + ICHECK(elem_offset) << "Expects a constant element offset."; + const int rows = op->args[2].as()->value; + const int cols = op->args[3].as()->value; + auto v = MakeValue(op->args[4]); + auto elem_ty = builder_->GetBufferElementType(ptr); + auto mat_ty = builder_->GetCooperativeMatrixNVType(elem_ty, rows, cols); + auto filled = builder_->CallCooperativeMatrixFillNV(mat_ty, v); + builder_->SetCooperativeMatrix(ptr, elem_offset->value, filled); + return filled; + } else if (op->op.same_as(builtin::cooperative_matrix_mad_NV())) { + auto A_elem_offset = op->args[1].as(); + auto B_elem_offset = op->args[3].as(); + auto C_elem_offset = op->args[5].as(); + ICHECK(A_elem_offset) << "Expects a constant element offset."; + ICHECK(B_elem_offset) << "Expects a constant element offset."; + ICHECK(C_elem_offset) << "Expects a constant element offset."; + + auto get_matrix = [this](PrimExpr arg, int offset) { + auto buffer_var_mat = Downcast(arg); + return builder_->GetCooperativeMatrix(buffer_var_mat, offset); + }; + + auto A = get_matrix(op->args[0], A_elem_offset->value); + auto B = get_matrix(op->args[2], B_elem_offset->value); + auto C = get_matrix(op->args[4], C_elem_offset->value); + auto acc = builder_->CallCooperativeMatrixMadNV(A, B, C); + auto C_ptr = Downcast(op->args[4]); + builder_->SetCooperativeMatrix(C_ptr, C_elem_offset->value, acc); + return acc; } 
    LOG(FATAL) << "Unresolved call " << op->op;
  }
@@ -561,6 +630,55 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
   // Loop head
   builder_->StartLabel(head_label);
+
+  // In a normal matmul, the update semantics c += a * b is implemented by a load from and a
+  // store to the memory location corresponding to the scalar c. But when C, A, and B are
+  // cooperative matrices, C += A * B cannot be implemented in the same way, since they cannot
+  // be loaded from or stored to a buffer.
+
+  // We leverage SPIR-V's SSA semantics to implement the multiply-add update on the C matrices.
+  // The original buffer is subdivided into fixed-size matrices, and the multiply-add on each
+  // individual matrix is unrolled. For example, if the buffer for the C matrix is of shape
+  // 64x64, we materialize 16 16x16 matrices and generate 16 multiply-add instructions. Each
+  // matrix is represented by a phi value, and each iteration of the reduction loop reads the
+  // phi values for the C matrices and updates them.
+
+  // A map from a buffer with kCooperativeMatrixNV storage scope (whose size can be arbitrary)
+  // to cooperative matrices "C" (whose sizes are fixed, e.g. 16x16).
+  // Each matrix is identified by an element offset into the buffer. This encoding lets us bridge
+  // the gap between a typical TIR matmul schedule (which uses
+  // `cache_read(..., "cooperative_matrix_nv")` from arbitrary-sized shared memory) and
+  // the fixed-size matrices as expected by the extension.
+  std::unordered_map<const VarNode*, std::set<int>> accum_matrices;
+
+  if (op->kind == ForKind::kSerial) {
+    // If this is a serial loop, record which C matrices are updated in this loop.
+    tir::PostOrderVisit(op->body, [&accum_matrices](const ObjectRef& obj) {
+      if (auto call = obj.as<CallNode>();
+          call && call->op.same_as(builtin::cooperative_matrix_mad_NV())) {
+        auto C_elem_offset = call->args[5].as<IntImmNode>();
+        ICHECK(C_elem_offset) << "Expects a constant element offset.";
+        auto buffer_var_C = Downcast<Var>(call->args[4]);
+        accum_matrices[buffer_var_C.get()].insert(C_elem_offset->value);
+      }
+    });
+  }
+
+  // Phi values for all cooperative matrices "C".
+  std::vector<spirv::PhiValue> cooperative_matrix_phis;
+
+  // Initialize phi values and use them as the matrix "C" in this loop.
+  for (const auto& [var, elem_offsets] : accum_matrices) {
+    Var buffer_var_mat = GetRef<Var>(var);
+    for (auto offset : elem_offsets) {
+      auto mat = builder_->GetCooperativeMatrix(buffer_var_mat, offset);
+      auto mat_phi = builder_->MakePhi(mat.stype, 2);
+      mat_phi.SetIncoming(0, mat, init_label);
+      cooperative_matrix_phis.push_back(mat_phi);
+      builder_->SetCooperativeMatrix(buffer_var_mat, offset, mat_phi);
+    }
+  }
+
   spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
   loop_var.SetIncoming(0, init_value, init_label);
   spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
@@ -581,8 +699,22 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
   spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
                                                    : builder_->UIntImm(loop_var.stype, 1);
   spirv::Value next_value = builder_->Add(loop_var, one);
-  loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
+  loop_var.SetIncoming(1, next_value, continue_label);
   builder_->MakeInst(spv::OpBranch, head_label);
+
+  // Update the phi values to the new value of the matrix "C" accumulated in this loop.
+  // The next iteration will read these values as the "C" matrix.
+  int phi_index = 0;
+  for (const auto& [var, elem_offsets] : accum_matrices) {
+    Var buffer_var_mat = GetRef<Var>(var);
+    for (auto offset : elem_offsets) {
+      auto mat_phi = cooperative_matrix_phis[phi_index++];
+      auto mat = builder_->GetCooperativeMatrix(buffer_var_mat, offset);
+      mat_phi.SetIncoming(1, mat, continue_label);
+      builder_->SetCooperativeMatrix(buffer_var_mat, offset, mat_phi);
+    }
+  }
+
   // loop merge
   builder_->StartLabel(merge_label);
 }
@@ -672,6 +804,9 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
       size_t num_bytes =
           op->dtype.bytes() * op->dtype.lanes() * static_cast<size_t>(constant_size);
       shared_memory_bytes_used_ += num_bytes;
+    } else if (storage_scope.rank == runtime::StorageRank::kCooperativeMatrixNV) {
+      this->VisitStmt(op->body);
+      return;
     } else {
       LOG(FATAL) << "Can only allocate shared or local memory inside kernel";
     }
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 46c9c5869c79..9dbdaa83d579 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -60,6 +60,11 @@ void IRBuilder::InitHeader() {
   }
 #endif

+  if (spirv_support_.supports_cooperative_matrix_nv) {
+    capabilities_used_.insert(spv::CapabilityCooperativeMatrixNV);
+    extensions_used_.insert("SPV_NV_cooperative_matrix");
+  }
+
   // memory model
   ib_.Begin(spv::OpMemoryModel)
       .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450)
@@ -810,6 +815,57 @@ Value IRBuilder::Select(Value cond, Value a, Value b) {
   return MakeValue(spv::OpSelect, a.stype, cond, a, b);
 }

+SType IRBuilder::GetCooperativeMatrixNVType(const SType& elem_ty, int rows, int cols) {
+  auto key = std::make_tuple(elem_ty.id, rows, cols);
+  auto entry = cooperative_matrix_type_tbl_.find(key);
+  if (entry != cooperative_matrix_type_tbl_.end()) {
+    return entry->second;
+  }
+
+  auto rows_spv = IntImm(t_int32_, rows);
+  auto cols_spv = IntImm(t_int32_, cols);
+  auto scope = IntImm(t_int32_, spv::Scope::ScopeSubgroup);
+
+  SType t;
+  t.id = id_counter_++;
+  t.element_type_id = elem_ty.id;
+  ib_.Begin(spv::Op::OpTypeCooperativeMatrixNV)
+      .AddSeq(t, elem_ty, scope, rows_spv, cols_spv)
+      .Commit(&global_);
+
+  cooperative_matrix_type_tbl_[key] = t;
+  return t;
+}
+
+Value IRBuilder::CallCooperativeMatrixLoadNV(const SType& mat_type, Value src, Value stride,
+                                             Value column_major) {
+  Value val = NewValue(mat_type, kNormal);
+
+  ib_.Begin(spv::Op::OpCooperativeMatrixLoadNV)
+      .AddSeq(mat_type, val, src, stride, column_major)
+      .Commit(&function_);
+  return val;
+}
+
+void IRBuilder::CallCooperativeMatrixStoreNV(Value dst, Value mat, Value stride,
+                                             Value column_major) {
+  ib_.Begin(spv::Op::OpCooperativeMatrixStoreNV)
+      .AddSeq(dst, mat, stride, column_major)
+      .Commit(&function_);
+}
+
+Value IRBuilder::CallCooperativeMatrixFillNV(const SType& mat_type, Value v) {
+  Value val = NewValue(mat_type, kNormal);
+  ib_.Begin(spv::OpCompositeConstruct).AddSeq(mat_type, val, v).Commit(&function_);
+  return val;
+}
+
+Value IRBuilder::CallCooperativeMatrixMadNV(Value A, Value B, Value C) {
+  Value val = NewValue(C.stype, kNormal);
+  ib_.Begin(spv::Op::OpCooperativeMatrixMulAddNV).AddSeq(C.stype, val, A, B, C).Commit(&function_);
+  return val;
+}
+
 }  // namespace spirv
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index d642484532f9..69934d20a7f7 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -593,6 +593,40 @@ class IRBuilder {
   Value GE(Value a, Value b);
   Value Select(Value cond, Value a, Value b);
+
+  // Intrinsics related to the VK_NV_cooperative_matrix extension; see the documentation for
+  // the SPIR-V SPV_NV_cooperative_matrix extension for details.
+  SType GetCooperativeMatrixNVType(const SType& elem_ty, int rows, int cols);
+  Value CallCooperativeMatrixLoadNV(const SType& mat_type, Value src, Value stride,
+                                    Value column_major);
+  void CallCooperativeMatrixStoreNV(Value dst, Value mat, Value stride, Value column_major);
+  Value CallCooperativeMatrixFillNV(const SType& mat_type, Value v);
+  Value CallCooperativeMatrixMadNV(Value A, Value B, Value C);
+
+  // Helper functions for cooperative matrix support
+
+  /*! \brief Return the pointer element type for the buffer. */
+  SType GetBufferElementType(const tir::Var& buffer) {
+    const auto* ptr = buffer->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr) << "Expects a pointer type.";
+    const auto* prim = ptr->element_type.as<PrimTypeNode>();
+    ICHECK(prim) << "Expects a primitive type.";
+    return GetSType(prim->dtype);
+  }
+
+  /*! \brief Associate a TIR buffer at the provided offset with the matrix. */
+  void SetCooperativeMatrix(const tir::Var& buffer_var_mat, int elem_offset, Value mat) {
+    auto key = std::make_pair(buffer_var_mat.get(), elem_offset);
+    cooperative_matrix_defs[key] = mat;
+  }
+
+  /*! \brief Retrieve the matrix corresponding to the provided offset in a TIR buffer. */
+  Value GetCooperativeMatrix(const tir::Var& buffer_var_mat, int elem_offset) {
+    auto key = std::make_pair(buffer_var_mat.get(), elem_offset);
+    auto entry = cooperative_matrix_defs.find(key);
+    ICHECK(entry != cooperative_matrix_defs.end());
+    return entry->second;
+  }
+
  private:
  /*!
   * \brief Create new value
@@ -703,8 +737,10 @@ class IRBuilder {
   std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
   /*! \brief map from constant int to its value */
   std::map<std::pair<uint32_t, uint64_t>, Value> const_tbl_;
-  /*! \brief map from name of a ExtInstImport to its value */
+  /*! \brief map from name of an ExtInstImport to its value */
   std::map<std::string, Value> ext_inst_tbl_;
+  /*! \brief map from (element-type code, rows, cols) to a cooperative matrix type */
+  std::map<std::tuple<uint32_t, int, int>, SType> cooperative_matrix_type_tbl_;
   /*! \brief Header segment
    *
@@ -745,6 +781,8 @@ class IRBuilder {
   std::vector<uint32_t> function_scope_vars_;
   /*! \brief Function segment */
   std::vector<uint32_t> function_;
+  /*! \brief map from (buffer variable, element offset) to a cooperative matrix */
+  std::map<std::pair<const tir::VarNode*, int>, Value> cooperative_matrix_defs;
 };

 }  // namespace spirv
diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc
index 81b5cd8b8a6a..02fab4dd59c4 100644
--- a/src/target/spirv/spirv_support.cc
+++ b/src/target/spirv/spirv_support.cc
@@ -92,6 +92,10 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) {
   if (target->GetAttr<Bool>("supports_integer_dot_product")) {
     supports_integer_dot_product = target->GetAttr<Bool>("supports_integer_dot_product").value();
   }
+  if (target->GetAttr<Bool>("supports_cooperative_matrix_nv")) {
+    supports_cooperative_matrix_nv =
+        target->GetAttr<Bool>("supports_cooperative_matrix_nv").value();
+  }
   // Check whether integer dot product is enabled in mattr.
   if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("mattr")) {
     for (const String& s : v.value()) {
diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h
index 6365e576b8cf..f670b1ae0833 100644
--- a/src/target/spirv/spirv_support.h
+++ b/src/target/spirv/spirv_support.h
@@ -276,6 +276,9 @@ struct SPIRVSupport {
    * attempting to perform integer dot product.
    */
   bool supports_integer_dot_product{false};
+
+  /*! \brief Whether the driver supports the VK_NV_cooperative_matrix extension. */
+  bool supports_cooperative_matrix_nv{false};
 };

 }  // namespace codegen
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 3a555e304cb0..3c28a48a6a2c 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -378,6 +378,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
     .add_attr_option<Bool>("supports_push_descriptor")
     .add_attr_option<Bool>("supports_dedicated_allocation")
     .add_attr_option<Bool>("supports_integer_dot_product")
+    .add_attr_option<Bool>("supports_cooperative_matrix_nv")
     .add_attr_option<Integer>("supported_subgroup_operations")
     // Physical device limits
     .add_attr_option<Integer>("max_num_threads", Integer(256))
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index c85590428450..aa2434fc2e06 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -355,6 +355,17 @@ TIR_DEFINE_BUILTIN_FUNC(start_profile_intrinsic)
 TIR_DEFINE_BUILTIN_FUNC(end_profile_intrinsic)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

+TIR_DEFINE_BUILTIN_FUNC(cooperative_matrix_load_NV)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_matrix_store_NV)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_matrix_fill_NV)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_matrix_mad_NV)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

 }  // namespace builtin
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py
index bfb10ca85a38..fceff2079438 100644
--- a/tests/python/unittest/test_target_codegen_vulkan.py
+++ b/tests/python/unittest/test_target_codegen_vulkan.py
@@ -29,6 +29,7 @@
 from tvm import relay, te
 from tvm.topi.math import cast
 from tvm.script import tir as T
+from tvm.tir import TensorIntrin, IntImm, Cast, Schedule

 dtype = tvm.testing.parameter("float32", "int32", "float16", "int8")

@@ -151,10 +152,10 @@ def build_f(f_ref):
             a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n,)))
             b = tvm.nd.empty((n,), B.dtype, dev).copyfrom(np.random.uniform(size=(n,)))
             cs = [tvm.nd.empty((n,), A.dtype, dev) for _ in fs]
-            for ((f, _), c) in zip(fs, cs):
+            for (f, _), c in zip(fs, cs):
                 f(a, b, c)

-            for ((_, ref), c) in zip(fs, cs):
+            for (_, ref), c in zip(fs, cs):
                 tvm.testing.assert_allclose(c.numpy(), ref(a.numpy(), b.numpy()))

         ts = [threading.Thread(target=worker) for _ in range(np.random.randint(1, 10))]
@@ -278,6 +279,7 @@ def test_unique(target, dev):
 vulkan_parameter_impl = tvm.testing.parameter("push_constants", "ubo")
 vulkan_parameter_dtype = tvm.testing.parameter("int32", "float32", "int64")

+
 # Only run on vulkan because extremely large numbers of input
 # parameters can crash cuda/llvm compiler.
@tvm.testing.parametrize_targets("vulkan -from_device=0") @@ -600,5 +602,383 @@ def func(A: T.Buffer((N, 2), "int32")): np.testing.assert_array_equal(a[:, 1], (np.arange(N) - offset) % divisor) +@pytest.mark.parametrize("out_dtype", ["float32", "float16"]) +def test_cooperative_matrix_nv(out_dtype): + @T.prim_func + def cooperative_matrix_load_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=64, offset_factor=8, scope="shared") + C = T.match_buffer( + c, (16, 16), "float16", align=64, offset_factor=8, scope="cooperative_matrix_nv" + ) + + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j in T.grid(16, 16): + with T.block("load"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + def get_load_impl(column_major): + @T.prim_func + def cooperative_matrix_load_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="shared", + strides=[s1, s0], + ) + C = T.match_buffer( + c, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_load_NV( + C.data, + C.elem_offset, + A.access_ptr("r"), + 16, + 16, + s1, + column_major, + dtype="handle", + ) + ) + + return cooperative_matrix_load_impl + + def get_store_desc(out_dtype="float32", out_scope="global"): + @T.prim_func + def cooperative_matrix_store_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer(c, (16, 16), out_dtype, align=64, offset_factor=8, scope=out_scope) + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j in T.grid(16, 16): + with T.block("store"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + return cooperative_matrix_store_desc + + def get_store_impl(out_dtype="float32", out_scope="global"): + @T.prim_func + def cooperative_matrix_store_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope=out_scope, + strides=[s1, s0], + ) + + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_store_NV( + C.access_ptr("w"), A.data, A.elem_offset, s1, False, dtype="handle" + ) + ) + + return cooperative_matrix_store_impl + + def get_fill_desc(out_dtype="float32"): + zero = IntImm("int32", 0).astype(out_dtype) + + @T.prim_func + def cooperative_matrix_fill_desc(c: T.handle) -> None: + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads() + T.writes(C[0:16, 0:16]) + for i, j in T.grid(16, 16): + with T.block("init"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = zero + + return cooperative_matrix_fill_desc + + def get_fill_impl(out_dtype="float32"): + zero = IntImm("int32", 0).astype(out_dtype) + + @T.prim_func + def cooperative_matrix_fill_impl(c: T.handle) -> None: + C = 
T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads() + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_fill_NV( + C.data, C.elem_offset, 16, 16, zero, dtype="handle" + ) + ) + + return cooperative_matrix_fill_impl + + def get_mad_desc(out_dtype="float32"): + def maybe_cast(v): + if out_dtype in ["float32", "int32"]: + return Cast(out_dtype, v) + return v + + @T.prim_func + def cooperative_matrix_mad_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + B = T.match_buffer( + b, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast( + B[vkk, vjj] + ) + + return cooperative_matrix_mad_desc + + def get_mad_impl(out_dtype="float32"): + @T.prim_func + def cooperative_matrix_mad_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + B = T.match_buffer( + b, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_mad_NV( + A.data, + A.elem_offset, + B.data, + B.elem_offset, + C.data, + C.elem_offset, + dtype="handle", + ) + ) + + return cooperative_matrix_mad_impl + + STORE_INTRIN = "cooperative_matrix_store_{}".format(out_dtype) + FILL_INTRIN = "cooperative_matrix_fill_{}".format(out_dtype) + MAD_INTRIN = "cooperative_matrix_mad_{}".format(out_dtype) + + TensorIntrin.register( + "cooperative_matrix_load", cooperative_matrix_load_desc, get_load_impl(False), override=True + ) + TensorIntrin.register( + STORE_INTRIN, + get_store_desc(out_dtype), + get_store_impl(out_dtype), + override=True, + ) + TensorIntrin.register( + FILL_INTRIN, get_fill_desc(out_dtype), get_fill_impl(out_dtype), override=True + ) + TensorIntrin.register( + MAD_INTRIN, get_mad_desc(out_dtype), get_mad_impl(out_dtype), override=True + ) + + def get_matmul(m, n, k, out_dtype="float32"): + X = te.placeholder((m, k), name="X", dtype="float16") + W = te.placeholder((k, n), name="W", dtype="float16") + ak = te.reduce_axis((0, k), name="k") + + if out_dtype == "float32": + matmul = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("float32") * W[ak, j].astype("float32"), + axis=ak, + ), + name="compute", + ) + else: + matmul = te.compute( + (m, n), + lambda i, j: te.sum(X[i, ak] * W[ak, j], axis=ak), + name="compute", + ) + + return te.create_prim_func([X, W, matmul]) + + M, N, K = 16, 16, 32 + func = get_matmul(M, N, K, out_dtype) + sch = Schedule(func) + block = sch.get_block("compute") + + i, j, k = sch.get_loops(block) + 
i_outer, i_inner = sch.split(i, factors=[None, 16]) + j_outer, j_inner = sch.split(j, factors=[None, 16]) + k_outer, k_inner = sch.split(k, factors=[None, 16]) + sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner) + fused_outer = sch.fuse(i_outer, j_outer) + sch.bind(fused_outer, "blockIdx.x") + + def fetch_to_shared(block, idx): + block_read = sch.cache_read(block, idx, "shared") + sch.compute_at(block_read, k_outer) + warp_size = 32 + + fused = sch.fuse(*sch.get_loops(block_read)[-2:]) + + vector_size = 4 + _, f_2, f_3 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f_2, "threadIdx.x") + sch.vectorize(f_3) + + def tensorize_load(block, dim, intrin): + loops = sch.get_loops(block) + i, j = loops[-dim : (len(loops) - dim + 2)] + + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin) + + fetch_to_shared(block, 0) + fetch_to_shared(block, 1) + + A_mat = sch.cache_read(block, 0, "cooperative_matrix_nv") + B_mat = sch.cache_read(block, 1, "cooperative_matrix_nv") + + tensorize_load(A_mat, 2, "cooperative_matrix_load") + tensorize_load(B_mat, 2, "cooperative_matrix_load") + + store = sch.cache_write(block, 0, "cooperative_matrix_nv") + sch.reverse_compute_at(store, fused_outer) + init = sch.decompose_reduction(block, sch.get_loops(block)[1]) + + sch.tensorize(sch.get_loops(init)[1], FILL_INTRIN) + sch.tensorize(sch.get_loops(store)[1], STORE_INTRIN) + sch.tensorize(sch.get_loops(block)[2], MAD_INTRIN) + + target = "vulkan -from_device=0" + tgt_attrs = tvm.target.Target(target).attrs + + if tgt_attrs.get("supports_cooperative_matrix_nv"): + f = tvm.build(sch.mod, target=target) + + dev = tvm.device(target, 0) + + A = tvm.nd.array(np.random.randn(M, K).astype("float16"), dev) + B = tvm.nd.array(np.random.randn(K, N).astype("float16"), dev) + C = tvm.nd.array(np.random.randn(M, N).astype(out_dtype), dev) + + f(A, B, C) + + A_np = A.numpy() + B_np = B.numpy() + ref = np.dot(A_np.astype("float32"), B_np.astype("float32")) + + tvm.testing.assert_allclose(C.numpy(), ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": tvm.testing.main()
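
Usage sketch (not part of the patch): the test above drives the four new builtins through
TensorIntrin registration and Schedule.tensorize, but the argument order established in
python/tvm/tir/op.py is easier to see when a single 16x16x16 tile is written out by hand.
The kernel below is only an illustration; the buffer names, the shared/global staging, the
fixed stride of 16, and the single 32-lane subgroup are assumptions made for the sketch, and
a real kernel would normally be produced by the schedule shown in the test rather than
written directly.

from tvm.script import tir as T


@T.prim_func
def one_tile(a: T.handle, b: T.handle, a_frag: T.handle, b_frag: T.handle,
             c_frag: T.handle, c: T.handle) -> None:
    # A and B are staged tiles; the *_frag buffers live in the new scope and are mapped by
    # the SPIR-V codegen to cooperative matrix values keyed by (buffer, element offset).
    A = T.match_buffer(a, (16, 16), "float16", scope="shared")
    B = T.match_buffer(b, (16, 16), "float16", scope="shared")
    A_frag = T.match_buffer(a_frag, (16, 16), "float16", scope="cooperative_matrix_nv")
    B_frag = T.match_buffer(b_frag, (16, 16), "float16", scope="cooperative_matrix_nv")
    C_frag = T.match_buffer(c_frag, (16, 16), "float32", scope="cooperative_matrix_nv")
    C = T.match_buffer(c, (16, 16), "float32")

    with T.block("root"):
        T.reads(A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)  # one subgroup cooperatively owns the whole tile
        # C_frag = 0; C_frag += A_frag * B_frag; C = C_frag
        T.evaluate(T.cooperative_matrix_fill_NV(C_frag.data, C_frag.elem_offset, 16, 16,
                                                T.float32(0), dtype="handle"))
        T.evaluate(T.cooperative_matrix_load_NV(A_frag.data, A_frag.elem_offset,
                                                A.access_ptr("r"), 16, 16, 16, False,
                                                dtype="handle"))
        T.evaluate(T.cooperative_matrix_load_NV(B_frag.data, B_frag.elem_offset,
                                                B.access_ptr("r"), 16, 16, 16, False,
                                                dtype="handle"))
        T.evaluate(T.cooperative_matrix_mad_NV(A_frag.data, A_frag.elem_offset,
                                               B_frag.data, B_frag.elem_offset,
                                               C_frag.data, C_frag.elem_offset,
                                               dtype="handle"))
        T.evaluate(T.cooperative_matrix_store_NV(C.access_ptr("w"), C_frag.data,
                                                 C_frag.elem_offset, 16, False,
                                                 dtype="handle"))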