From 486bdbe1c8876c44351ba4110a2cb6a4880fdc48 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 1 May 2023 19:44:04 +0900 Subject: [PATCH 01/21] add intrinsics --- include/tvm/tir/builtin.h | 5 +++++ python/tvm/script/ir_builder/tir/ir.py | 8 ++++++++ python/tvm/tir/op.py | 16 ++++++++++++++++ src/tir/op/builtin.cc | 11 +++++++++++ 4 files changed, 40 insertions(+) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e8bcc028fc58..848cce0f788c 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -797,6 +797,11 @@ TVM_DLL const Op& start_profile_intrinsic(); */ TVM_DLL const Op& end_profile_intrinsic(); +TVM_DLL const Op& cooperative_matrix_load_NV(); +TVM_DLL const Op& cooperative_matrix_store_NV(); +TVM_DLL const Op& cooperative_matrix_fill_NV(); +TVM_DLL const Op& cooperative_matrix_mad_NV(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c8285ccc52ce..19477a8ee780 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1820,6 +1820,10 @@ def wrapped(*args, **kwargs): TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic) +cooperative_matrix_load_NV = _op_wrapper(_tir_op.cooperative_matrix_load_NV) +cooperative_matrix_store_NV = _op_wrapper(_tir_op.cooperative_matrix_store_NV) +cooperative_matrix_fill_NV = _op_wrapper(_tir_op.cooperative_matrix_fill_NV) +cooperative_matrix_mad_NV = _op_wrapper(_tir_op.cooperative_matrix_mad_NV) def _dtype_forward(func): @@ -2144,4 +2148,8 @@ def wrapped(*args, **kwargs): "IterVar", "CommReducer", "Range", + "cooperative_matrix_load_NV", + "cooperative_matrix_store_NV", + "cooperative_matrix_fill_NV", + "cooperative_matrix_mad_NV", ] diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 419ab2275858..2a5aa4b15681 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3037,6 +3037,22 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) +def cooperative_matrix_load_NV(): + return call_intrin("handle", "tir.cooperative_matrix_load_NV") + + +def cooperative_matrix_store_NV(): + return call_intrin("handle", "tir.cooperative_matrix_store_NV") + + +def cooperative_matrix_fill_NV(): + return call_intrin("handle", "tir.cooperative_matrix_fill_NV") + + +def cooperative_matrix_mad_NV(): + return call_intrin("handle", "tir.cooperative_matrix_mad_NV") + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index c85590428450..aa2434fc2e06 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -355,6 +355,17 @@ TIR_DEFINE_BUILTIN_FUNC(start_profile_intrinsic) TIR_DEFINE_BUILTIN_FUNC(end_profile_intrinsic) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(cooperative_matrix_load_NV) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(cooperative_matrix_store_NV) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + 
+TIR_DEFINE_BUILTIN_FUNC(cooperative_matrix_fill_NV) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(cooperative_matrix_mad_NV) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace builtin } // namespace tir } // namespace tvm From 396cc67a23d181722fc6470299281e66349f1ce4 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 1 May 2023 19:50:57 +0900 Subject: [PATCH 02/21] add parameters to intrin --- python/tvm/tir/op.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 2a5aa4b15681..b3bb8520edc4 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3037,20 +3037,20 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) -def cooperative_matrix_load_NV(): - return call_intrin("handle", "tir.cooperative_matrix_load_NV") +def cooperative_matrix_load_NV(mat, src, stdride, column_major): + return call_intrin("handle", "tir.cooperative_matrix_load_NV", mat, src, stdride, column_major) -def cooperative_matrix_store_NV(): - return call_intrin("handle", "tir.cooperative_matrix_store_NV") +def cooperative_matrix_store_NV(dst, mat, stride, column_major): + return call_intrin("handle", "tir.cooperative_matrix_store_NV", dst, mat, stride, column_major) -def cooperative_matrix_fill_NV(): - return call_intrin("handle", "tir.cooperative_matrix_fill_NV") +def cooperative_matrix_fill_NV(mat, v): + return call_intrin("handle", "tir.cooperative_matrix_fill_NV", mat, v) -def cooperative_matrix_mad_NV(): - return call_intrin("handle", "tir.cooperative_matrix_mad_NV") +def cooperative_matrix_mad_NV(A, B, C): + return call_intrin("handle", "tir.cooperative_matrix_mad_NV", A, B, C) # pylint: disable=unnecessary-lambda From a4c243fc1897bd2da695cfdbb783e2e08b06195f Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 2 May 2023 07:49:05 +0900 Subject: [PATCH 03/21] add storage scope --- src/runtime/thread_storage_scope.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 51dba038b6ac..95bb4e370b37 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -64,6 +64,7 @@ enum class StorageRank { kTexture = 7, /*! \brief global scope amx tmm memory */ kAMXTMM = 8, + kCooperativeMatrixNV = 9, }; /*! 
@@ -154,6 +155,9 @@ struct StorageScope { } else if (s.compare(0, 7, "amx.tmm") == 0) { r.rank = StorageRank::kAMXTMM; r.tag = s.substr(7, std::string::npos); + } else if (s == "cooperative_matrix_nv") { + r.rank = StorageRank::kCooperativeMatrixNV; + r.tag = ""; } else { LOG(FATAL) << "unknown storage scope " << s; } From 8da1d9fad167d86389e64cedc4636d474ca246dc Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 3 May 2023 04:19:00 +0900 Subject: [PATCH 04/21] fix --- python/tvm/tir/op.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index b3bb8520edc4..3b567ab5089d 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3037,16 +3037,16 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) -def cooperative_matrix_load_NV(mat, src, stdride, column_major): - return call_intrin("handle", "tir.cooperative_matrix_load_NV", mat, src, stdride, column_major) +def cooperative_matrix_load_NV(mat, offset, src, stride, column_major): + return call_intrin("handle", "tir.cooperative_matrix_load_NV", mat, offset, src, stride, column_major) -def cooperative_matrix_store_NV(dst, mat, stride, column_major): - return call_intrin("handle", "tir.cooperative_matrix_store_NV", dst, mat, stride, column_major) +def cooperative_matrix_store_NV(dst, mat, offset, stride, column_major): + return call_intrin("handle", "tir.cooperative_matrix_store_NV", dst, mat, offset, stride, column_major) -def cooperative_matrix_fill_NV(mat, v): - return call_intrin("handle", "tir.cooperative_matrix_fill_NV", mat, v) +def cooperative_matrix_fill_NV(mat, offset, v): + return call_intrin("handle", "tir.cooperative_matrix_fill_NV", mat, offset, v) def cooperative_matrix_mad_NV(A, B, C): From b39cbdea73318159147b61478b4602489f6cd0ed Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 3 May 2023 04:46:11 +0900 Subject: [PATCH 05/21] wip --- src/target/spirv/ir_builder.cc | 3 +++ src/target/spirv/ir_builder.h | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 46c9c5869c79..428ffaeac9bf 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -60,6 +60,9 @@ void IRBuilder::InitHeader() { } #endif + // TODO + capabilities_used_.insert(spv::CapabilityCooperativeMatrixNV); + // memory model ib_.Begin(spv::OpMemoryModel) .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index d642484532f9..889a4847e23f 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -431,6 +431,27 @@ class IRBuilder { Value CallKHRIntegerDotProduct(const SType& ret_type, const std::vector& args, const DataType& dtype); + SType GetCooperativeMatrixNVType(const SType& elem_ty, int rows, int cols) { + auto key = std::make_tuple(elem_ty.id, rows, cols); + auto entry = cooperative_matrix_type_tbl_.find(key); + if (entry != cooperative_matrix_type_tbl_.end()) { + return entry->second; + } + + auto rows_spv = IntImm(t_int32_, rows); + auto cols_spv = IntImm(t_int32_, cols); + + SType t; + t.id = id_counter_++; + t.element_type_id = elem_ty.id; + ib_.Begin(spv::Op::OpTypeCooperativeMatrixNV) + .AddSeq(t, elem_ty, spv::Scope::ScopeSubgroup, rows_spv, cols_spv) + .Commit(&global_); + + cooperative_matrix_type_tbl_[key] = t; + return t; + } + /*! 
* \brief Build vector by concatenating components * @@ -745,6 +766,7 @@ class IRBuilder { std::vector function_scope_vars_; /*! \brief Function segment */ std::vector function_; + std::map, SType> cooperative_matrix_type_tbl_; }; } // namespace spirv From bb9699b92d5585eddcc1d38008f660b909a5191f Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 3 May 2023 05:32:04 +0900 Subject: [PATCH 06/21] add instruction --- src/target/spirv/ir_builder.h | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 889a4847e23f..bc27ec275ed4 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -452,6 +452,36 @@ class IRBuilder { return t; } + Value CallCooperativeMatrixLoadNV(const SType& mat_type, Value src, Value stride, + Value column_major) { + Value val = NewValue(mat_type, kNormal); + + ib_.Begin(spv::Op::OpCooperativeMatrixLoadNV) + .AddSeq(mat_type, val, src, stride, column_major) + .Commit(&function_); + return val; + } + + Value CallCooperativeMatrixLoadNV(const SType& mat_type, Value v) { + Value val = NewValue(mat_type, kNormal); + ib_.Begin(spv::OpCompositeConstruct).AddSeq(mat_type, val, v).Commit(&function_); + return val; + } + + Value CallJointMatrixMadIntel(const SType& mat_type, Value A, Value B, Value C) { + Value val = NewValue(mat_type, kNormal); + ib_.Begin(spv::Op::OpCooperativeMatrixMulAddNV) + .AddSeq(mat_type, val, A, B, C) + .Commit(&function_); + return val; + } + + void CallJointMatrixStoreIntel(Value dst, Value mat, Value stride, Value column_major) { + ib_.Begin(spv::Op::OpCooperativeMatrixStoreNV) + .AddSeq(dst, mat, stride, column_major) + .Commit(&function_); + } + /*! * \brief Build vector by concatenating components * From 9fcee17001bb2b41560356b66cd547f0b2ec06ee Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 3 May 2023 05:39:22 +0900 Subject: [PATCH 07/21] add handling of multiple coop matrices --- src/target/spirv/codegen_spirv.cc | 58 +++++++++++++++++++++++++++++-- src/target/spirv/ir_builder.h | 28 +++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index e3ef5acb8331..b7da950ba2b4 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -542,6 +542,22 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { } } +class AccumulatedJointMatrixCollector : public StmtExprVisitor { + public: + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::cooperative_matrix_mad_NV())) { + // TODO + LOG(FATAL) << "Not implemented"; + auto C_elem_offset = op->args[5]; + ICHECK(C_elem_offset->IsInstance()); + auto buffer_var_C = Downcast(op->args[4]); + joint_matrices[buffer_var_C.get()].insert(C_elem_offset.as()->value); + } + ExprVisitor::VisitExpr_(op); + } + std::unordered_map> joint_matrices; +}; + void CodeGenSPIRV::VisitStmt_(const ForNode* op) { ICHECK(is_zero(op->min)); analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); @@ -561,6 +577,26 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // Loop head builder_->StartLabel(head_label); + + AccumulatedJointMatrixCollector acc_mat_collector; + + if (op->kind == ForKind::kSerial) { + acc_mat_collector(op->body); + } + + std::vector joint_matrix_phis; + for (auto [var, elem_offsets] : acc_mat_collector.joint_matrices) { + Var buffer_var_mat = GetRef(var); + auto mat_ty = 
builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(buffer_var_mat), 8, 8); + for (auto offset : elem_offsets) { + spirv::PhiValue mat_phi = builder_->MakePhi(mat_ty, 2); + auto mat_def = builder_->GetJointMatrixDef(buffer_var_mat, offset); + mat_phi.SetIncoming(0, mat_def.cur_value, init_label); + joint_matrix_phis.push_back(mat_phi); + builder_->SetJointMatrixDef(buffer_var_mat, offset, mat_phi); + } + } + spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); loop_var.SetIncoming(0, init_value, init_label); spirv::Value loop_cond = builder_->LT(loop_var, extent_value); @@ -574,15 +610,28 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { builder_->StartLabel(body_label); var_map_[op->loop_var.get()] = spirv::Value(loop_var); this->VisitStmt(op->body); + spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) + : builder_->UIntImm(loop_var.stype, 1); builder_->MakeInst(spv::OpBranch, continue_label); // loop continue builder_->StartLabel(continue_label); - spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) - : builder_->UIntImm(loop_var.stype, 1); + spirv::Value next_value = builder_->Add(loop_var, one); - loop_var.SetIncoming(1, next_value, builder_->CurrentLabel()); + loop_var.SetIncoming(1, next_value, continue_label); + + builder_->MakeInst(spv::OpBranch, head_label); + + int phi_index = 0; + for (auto [var, elem_offsets] : acc_mat_collector.joint_matrices) { + for (auto offset : elem_offsets) { + spirv::PhiValue mat_phi = joint_matrix_phis[phi_index++]; + auto mat_def = builder_->GetJointMatrixDef(GetRef(var), offset); + mat_phi.SetIncoming(1, mat_def.cur_value, continue_label); + builder_->SetJointMatrixDef(GetRef(var), offset, mat_phi); + } + } // loop merge builder_->StartLabel(merge_label); } @@ -672,6 +721,9 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { size_t num_bytes = op->dtype.bytes() * op->dtype.lanes() * static_cast(constant_size); shared_memory_bytes_used_ += num_bytes; + } else if (storage_scope.rank == runtime::StorageRank::kCooperativeMatrixNV) { + this->VisitStmt(op->body); + return; } else { LOG(FATAL) << "Can only allocate shared or local memory inside kernel"; } diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index bc27ec275ed4..79387e41ac8f 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -644,6 +644,33 @@ class IRBuilder { Value GE(Value a, Value b); Value Select(Value cond, Value a, Value b); + struct JointMatrixDef { + Value cur_value; + Label defined_label; // TODO: remove it + }; + + void SetJointMatrixDef(const tir::Var& buffer_var_mat, int alloc_id, Value mat) { + auto key = std::make_pair(buffer_var_mat.get(), alloc_id); + joint_matrix_defs[key] = JointMatrixDef{mat, curr_label_}; + } + + JointMatrixDef GetJointMatrixDef(const tir::Var& buffer_var_mat, int alloc_id) { + auto key = std::make_pair(buffer_var_mat.get(), alloc_id); + auto entry = joint_matrix_defs.find(key); + ICHECK(entry != joint_matrix_defs.end()); + return entry->second; + } + + Value GetJointMatrix(const tir::Var& buffer_var_mat, int alloc_id) { + return GetJointMatrixDef(buffer_var_mat, alloc_id).cur_value; + } + + SType GetBufferElementType(const tir::Var& buffer) { + auto* ptr = buffer->type_annotation.as(); + auto* prim = ptr->element_type.as(); + return GetSType(prim->dtype); + }; + private: /*! * \brief Create new value @@ -797,6 +824,7 @@ class IRBuilder { /*! 
\brief Function segment */ std::vector function_; std::map, SType> cooperative_matrix_type_tbl_; + std::map, JointMatrixDef> joint_matrix_defs; }; } // namespace spirv From 144c5536735fdea8a06a528c245f7db43e6d2095 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 3 May 2023 05:57:22 +0900 Subject: [PATCH 08/21] add load and store --- src/target/spirv/codegen_spirv.cc | 44 +++++++++++++++++++++++++++++++ src/target/spirv/ir_builder.h | 16 +++++------ 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index b7da950ba2b4..b446ea85a341 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -395,6 +395,50 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { LOG(FATAL) << "SPIR-V shader cannot make extern calls. Graph contains extern \"" << Downcast(op->args[0]) << "\""; return spirv::Value(); + } else if (op->op.same_as(builtin::address_of())) { + const BufferLoadNode* load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) << "CodeGenSPIRV only supports flat memory allocations."; + auto buffer_var = Downcast(load->buffer->data); + ICHECK(buffer_var.defined()); + auto it = storage_info_.find(buffer_var.get()); + ICHECK(it != storage_info_.end()); + StorageInfo& info = it->second; + spirv::SType content_type = builder_->GetSType(info.element_type); + auto buffer = MakeValue(buffer_var); + spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); + return builder_->StructArrayAccess(ptr_type, buffer, MakeValue(load->indices[0])); + } else if (op->op.same_as(builtin::cooperative_matrix_load_NV())) { + auto ptr = Downcast(op->args[0]); + ICHECK(ptr.defined()); + auto elem_offset = op->args[1]; + spirv::Value src_ptr = MakeValue(op->args[2]); + auto stride = MakeValue(op->args[3]); + + auto column_major = MakeValue(op->args[4]); + + // todo + auto mat_ty = builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), + 16, 16); + auto mat = builder_->CallCooperativeMatrixLoadNV(mat_ty, src_ptr, stride, column_major); + ICHECK(elem_offset->IsInstance()); + builder_->SetJointMatrixDef(ptr, elem_offset.as()->value, mat); + return mat; + } else if (op->op.same_as(builtin::cooperative_matrix_store_NV())) { + auto buffer_var_mat = Downcast(op->args[1]); + ICHECK(buffer_var_mat.defined()); + + spirv::Value dst_ptr = MakeValue(op->args[0]); + + auto elem_offset = op->args[2]; + ICHECK(elem_offset->IsInstance()); + auto mat = builder_->GetJointMatrix(buffer_var_mat, elem_offset.as()->value); + spirv::Value stride = MakeValue(op->args[3]); + spirv::Value column_major = MakeValue(op->args[4]); + builder_->CallCooperativeMatrixStoreNV(dst_ptr, mat, stride, column_major); + return spirv::Value(); + } else if (op->op.same_as(builtin::cooperative_matrix_fill_NV())) { + } else if (op->op.same_as(builtin::cooperative_matrix_mad_NV())) { } else { LOG(FATAL) << "Unresolved call " << op->op; } diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 79387e41ac8f..9efe4224078e 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -462,13 +462,19 @@ class IRBuilder { return val; } - Value CallCooperativeMatrixLoadNV(const SType& mat_type, Value v) { + void CallCooperativeMatrixStoreNV(Value dst, Value mat, Value stride, Value column_major) { + ib_.Begin(spv::Op::OpCooperativeMatrixStoreNV) + .AddSeq(dst, mat, stride, column_major) + .Commit(&function_); + } 
+ + Value CallCooperativeMatrixFillNV(const SType& mat_type, Value v) { Value val = NewValue(mat_type, kNormal); ib_.Begin(spv::OpCompositeConstruct).AddSeq(mat_type, val, v).Commit(&function_); return val; } - Value CallJointMatrixMadIntel(const SType& mat_type, Value A, Value B, Value C) { + Value CallCooperativeMatrixMadNV(const SType& mat_type, Value A, Value B, Value C) { Value val = NewValue(mat_type, kNormal); ib_.Begin(spv::Op::OpCooperativeMatrixMulAddNV) .AddSeq(mat_type, val, A, B, C) @@ -476,12 +482,6 @@ class IRBuilder { return val; } - void CallJointMatrixStoreIntel(Value dst, Value mat, Value stride, Value column_major) { - ib_.Begin(spv::Op::OpCooperativeMatrixStoreNV) - .AddSeq(dst, mat, stride, column_major) - .Commit(&function_); - } - /*! * \brief Build vector by concatenating components * From ce1cf2799616aae07f5b2300d3ca794756e795a1 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 3 May 2023 05:59:33 +0900 Subject: [PATCH 09/21] add fill and store --- src/target/spirv/codegen_spirv.cc | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index b446ea85a341..fc91119a4495 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -438,7 +438,43 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { builder_->CallCooperativeMatrixStoreNV(dst_ptr, mat, stride, column_major); return spirv::Value(); } else if (op->op.same_as(builtin::cooperative_matrix_fill_NV())) { + auto ptr = Downcast(op->args[0]); + ICHECK(ptr.defined()); + auto mat_ty = + builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), 16, 16); + auto filled = builder_->CallCooperativeMatrixFillNV(mat_ty, MakeValue(op->args[2])); + + auto elem_offset = op->args[1]; + ICHECK(elem_offset->IsInstance()); + builder_->SetJointMatrixDef(ptr, elem_offset.as()->value, filled); + return filled; } else if (op->op.same_as(builtin::cooperative_matrix_mad_NV())) { + auto A_elem_offset = op->args[1]; + ICHECK(A_elem_offset->IsInstance()); + auto B_elem_offset = op->args[3]; + ICHECK(B_elem_offset->IsInstance()); + auto C_elem_offset = op->args[5]; + ICHECK(C_elem_offset->IsInstance()); + + auto C_ptr = Downcast(op->args[4]); + ICHECK(C_ptr.defined()); + auto mat_ty = + builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(C_ptr), 16, 16); + + auto get_matrix = [this](PrimExpr arg, int offset) { + auto buffer_var_mat = Downcast(arg); + ICHECK(buffer_var_mat.defined()); + return builder_->GetJointMatrix(buffer_var_mat, offset); + }; + + auto A = get_matrix(op->args[0], A_elem_offset.as()->value); + auto B = get_matrix(op->args[2], B_elem_offset.as()->value); + auto c_offset = C_elem_offset.as()->value; + auto C = get_matrix(op->args[4], c_offset); + + auto acc = builder_->CallCooperativeMatrixMadNV(mat_ty, A, B, C); + builder_->SetJointMatrixDef(C_ptr, c_offset, acc); + return acc; } else { LOG(FATAL) << "Unresolved call " << op->op; } From 2f32d56225c917c69995291b79f4d3a29c2f85ce Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 3 May 2023 06:16:17 +0900 Subject: [PATCH 10/21] build and validate work --- python/tvm/tir/op.py | 4 ++-- src/target/spirv/codegen_spirv.cc | 2 -- src/target/spirv/ir_builder.cc | 1 + src/target/spirv/ir_builder.h | 5 +++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 3b567ab5089d..bcd137987db1 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3049,8 +3049,8 
@@ def cooperative_matrix_fill_NV(mat, offset, v): return call_intrin("handle", "tir.cooperative_matrix_fill_NV", mat, offset, v) -def cooperative_matrix_mad_NV(A, B, C): - return call_intrin("handle", "tir.cooperative_matrix_mad_NV", A, B, C) +def cooperative_matrix_mad_NV(A, A_off, B, B_off, C, C_off): + return call_intrin("handle", "tir.cooperative_matrix_mad_NV", A, A_off, B, B_off, C, C_off) # pylint: disable=unnecessary-lambda diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index fc91119a4495..1be5c5916d09 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -626,8 +626,6 @@ class AccumulatedJointMatrixCollector : public StmtExprVisitor { public: void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::cooperative_matrix_mad_NV())) { - // TODO - LOG(FATAL) << "Not implemented"; auto C_elem_offset = op->args[5]; ICHECK(C_elem_offset->IsInstance()); auto buffer_var_C = Downcast(op->args[4]); diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 428ffaeac9bf..0e87efb5b11a 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -62,6 +62,7 @@ void IRBuilder::InitHeader() { // TODO capabilities_used_.insert(spv::CapabilityCooperativeMatrixNV); + extensions_used_.insert("SPV_NV_cooperative_matrix"); // memory model ib_.Begin(spv::OpMemoryModel) diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 9efe4224078e..fa2972c48b9c 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -440,13 +440,14 @@ class IRBuilder { auto rows_spv = IntImm(t_int32_, rows); auto cols_spv = IntImm(t_int32_, cols); + auto scope = IntImm(t_int32_, spv::Scope::ScopeSubgroup); SType t; t.id = id_counter_++; t.element_type_id = elem_ty.id; ib_.Begin(spv::Op::OpTypeCooperativeMatrixNV) - .AddSeq(t, elem_ty, spv::Scope::ScopeSubgroup, rows_spv, cols_spv) - .Commit(&global_); + .AddSeq(t, elem_ty, scope, rows_spv, cols_spv) + .Commit(&global_); cooperative_matrix_type_tbl_[key] = t; return t; From e9504e3828531486f1c619ffb6ca7467a04dca62 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 4 May 2023 05:17:04 +0900 Subject: [PATCH 11/21] pass rows and cols to load and fill --- python/tvm/tir/op.py | 8 ++++---- src/target/spirv/codegen_spirv.cc | 24 +++++++++++++++--------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index bcd137987db1..bf7bcbfea9a5 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3037,16 +3037,16 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) -def cooperative_matrix_load_NV(mat, offset, src, stride, column_major): - return call_intrin("handle", "tir.cooperative_matrix_load_NV", mat, offset, src, stride, column_major) +def cooperative_matrix_load_NV(mat, offset, src, rows, cols, stride, column_major): + return call_intrin("handle", "tir.cooperative_matrix_load_NV", mat, offset, src, rows, cols, stride, column_major) def cooperative_matrix_store_NV(dst, mat, offset, stride, column_major): return call_intrin("handle", "tir.cooperative_matrix_store_NV", dst, mat, offset, stride, column_major) -def cooperative_matrix_fill_NV(mat, offset, v): - return call_intrin("handle", "tir.cooperative_matrix_fill_NV", mat, offset, v) +def cooperative_matrix_fill_NV(mat, offset, rows, cols, v): + return call_intrin("handle", 
"tir.cooperative_matrix_fill_NV", mat, offset, rows, cols, v) def cooperative_matrix_mad_NV(A, A_off, B, B_off, C, C_off): diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 1be5c5916d09..be5366150ea7 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -413,13 +413,14 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { ICHECK(ptr.defined()); auto elem_offset = op->args[1]; spirv::Value src_ptr = MakeValue(op->args[2]); - auto stride = MakeValue(op->args[3]); + int rows = op->args[3].as()->value; + int cols = op->args[4].as()->value; + auto stride = MakeValue(op->args[5]); - auto column_major = MakeValue(op->args[4]); + auto column_major = MakeValue(op->args[6]); - // todo auto mat_ty = builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), - 16, 16); + rows, cols); auto mat = builder_->CallCooperativeMatrixLoadNV(mat_ty, src_ptr, stride, column_major); ICHECK(elem_offset->IsInstance()); builder_->SetJointMatrixDef(ptr, elem_offset.as()->value, mat); @@ -440,12 +441,17 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else if (op->op.same_as(builtin::cooperative_matrix_fill_NV())) { auto ptr = Downcast(op->args[0]); ICHECK(ptr.defined()); - auto mat_ty = - builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), 16, 16); - auto filled = builder_->CallCooperativeMatrixFillNV(mat_ty, MakeValue(op->args[2])); - auto elem_offset = op->args[1]; ICHECK(elem_offset->IsInstance()); + + int rows = op->args[2].as()->value; + int cols = op->args[3].as()->value; + auto v = MakeValue(op->args[4]); + + auto mat_ty = + builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), rows, cols); + auto filled = builder_->CallCooperativeMatrixFillNV(mat_ty, v); + builder_->SetJointMatrixDef(ptr, elem_offset.as()->value, filled); return filled; } else if (op->op.same_as(builtin::cooperative_matrix_mad_NV())) { @@ -665,7 +671,7 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { std::vector joint_matrix_phis; for (auto [var, elem_offsets] : acc_mat_collector.joint_matrices) { Var buffer_var_mat = GetRef(var); - auto mat_ty = builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(buffer_var_mat), 8, 8); + auto mat_ty = builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(buffer_var_mat), 16, 16); for (auto offset : elem_offsets) { spirv::PhiValue mat_phi = builder_->MakePhi(mat_ty, 2); auto mat_def = builder_->GetJointMatrixDef(buffer_var_mat, offset); From 123efdf912a2d48f8e6a80752813327c0fd81c0c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 4 May 2023 13:53:26 +0900 Subject: [PATCH 12/21] add test --- .../unittest/test_target_codegen_vulkan.py | 380 +++++++++++++++++- 1 file changed, 378 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index bfb10ca85a38..4638f13190fd 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -29,6 +29,7 @@ from tvm import relay, te from tvm.topi.math import cast from tvm.script import tir as T +from tvm.tir import TensorIntrin, IntImm, Cast, Schedule dtype = tvm.testing.parameter("float32", "int32", "float16", "int8") @@ -151,10 +152,10 @@ def build_f(f_ref): a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n,))) b = tvm.nd.empty((n,), B.dtype, dev).copyfrom(np.random.uniform(size=(n,))) cs = 
[tvm.nd.empty((n,), A.dtype, dev) for _ in fs] - for ((f, _), c) in zip(fs, cs): + for (f, _), c in zip(fs, cs): f(a, b, c) - for ((_, ref), c) in zip(fs, cs): + for (_, ref), c in zip(fs, cs): tvm.testing.assert_allclose(c.numpy(), ref(a.numpy(), b.numpy())) ts = [threading.Thread(target=worker) for _ in range(np.random.randint(1, 10))] @@ -278,6 +279,7 @@ def test_unique(target, dev): vulkan_parameter_impl = tvm.testing.parameter("push_constants", "ubo") vulkan_parameter_dtype = tvm.testing.parameter("int32", "float32", "int64") + # Only run on vulkan because extremely large numbers of input # parameters can crash cuda/llvm compiler. @tvm.testing.parametrize_targets("vulkan -from_device=0") @@ -600,5 +602,379 @@ def func(A: T.Buffer((N, 2), "int32")): np.testing.assert_array_equal(a[:, 1], (np.arange(N) - offset) % divisor) +@T.prim_func +def cooperative_matrix_load_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=64, offset_factor=8, scope="shared") + C = T.match_buffer( + c, (16, 16), "float16", align=64, offset_factor=8, scope="cooperative_matrix_nv" + ) + + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j in T.grid(16, 16): + with T.block("load"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + +def get_load_impl(column_major): + @T.prim_func + def cooperative_matrix_load_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="shared", + strides=[s1, s0], + ) + C = T.match_buffer( + c, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_load_NV( + C.data, + C.elem_offset, + A.access_ptr("r"), + 16, + 16, + s1, + column_major, + dtype="handle", + ) + ) + + return cooperative_matrix_load_impl + + +def get_store_desc(out_dtype="float32", out_scope="global"): + @T.prim_func + def cooperative_matrix_store_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer(c, (16, 16), out_dtype, align=64, offset_factor=8, scope=out_scope) + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j in T.grid(16, 16): + with T.block("store"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + return cooperative_matrix_store_desc + + +def get_store_impl(out_dtype="float32", out_scope="global"): + @T.prim_func + def cooperative_matrix_store_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope=out_scope, + strides=[s1, s0], + ) + + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_store_NV( + C.access_ptr("w"), A.data, A.elem_offset, s1, False, dtype="handle" + ) + ) + + return cooperative_matrix_store_impl + + +def get_fill_desc(out_dtype="float32"): + zero = IntImm("int32", 0).astype(out_dtype) + + @T.prim_func + def cooperative_matrix_fill_desc(c: 
T.handle) -> None: + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads() + T.writes(C[0:16, 0:16]) + for i, j in T.grid(16, 16): + with T.block("init"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = zero + + return cooperative_matrix_fill_desc + + +def get_fill_impl(out_dtype="float32"): + zero = IntImm("int32", 0).astype(out_dtype) + + @T.prim_func + def cooperative_matrix_fill_impl(c: T.handle) -> None: + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads() + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_fill_NV(C.data, C.elem_offset, 16, 16, zero, dtype="handle") + ) + + return cooperative_matrix_fill_impl + + +def get_mad_desc(out_dtype="float32"): + def maybe_cast(v): + if out_dtype in ["float32", "int32"]: + return Cast(out_dtype, v) + return v + + @T.prim_func + def cooperative_matrix_mad_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + B = T.match_buffer( + b, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast(B[vkk, vjj]) + + return cooperative_matrix_mad_desc + + +def get_mad_impl(out_dtype="float32"): + @T.prim_func + def cooperative_matrix_mad_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + B = T.match_buffer( + b, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + + with T.block("root"): + T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_mad_NV( + A.data, + A.elem_offset, + B.data, + B.elem_offset, + C.data, + C.elem_offset, + dtype="handle", + ) + ) + + return cooperative_matrix_mad_impl + + +TensorIntrin.register("cooperative_matrix_load", cooperative_matrix_load_desc, get_load_impl(False)) + + +@pytest.mark.parametrize("out_dtype", ["float32", "float16"]) +def test_cooperative_matrix_nv(out_dtype): + STORE_INTRIN = "cooperative_matrix_store_{}".format(out_dtype) + FILL_INTRIN = "cooperative_matrix_fill_{}".format(out_dtype) + MAD_INTRIN = "cooperative_matrix_mad_{}".format(out_dtype) + + TensorIntrin.register( + STORE_INTRIN, + get_store_desc(out_dtype), + get_store_impl(out_dtype), + ) + TensorIntrin.register(FILL_INTRIN, get_fill_desc(out_dtype), get_fill_impl(out_dtype)) + TensorIntrin.register(MAD_INTRIN, get_mad_desc(out_dtype), get_mad_impl(out_dtype)) + + def get_matmul(m, n, k, out_dtype="float32"): + X = te.placeholder((m, k), name="X", dtype="float16") + W = te.placeholder((k, n), name="W", 
dtype="float16") + ak = te.reduce_axis((0, k), name="k") + + if out_dtype == "float32": + matmul = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("float32") * W[ak, j].astype("float32"), + axis=ak, + ), + name="compute", + ) + else: + matmul = te.compute( + (m, n), + lambda i, j: te.sum(X[i, ak] * W[ak, j], axis=ak), + name="compute", + ) + + return te.create_prim_func([X, W, matmul]) + + M, N, K = 16, 16, 32 + func = get_matmul(M, N, K, out_dtype) + sch = Schedule(func) + block = sch.get_block("compute") + + i, j, k = sch.get_loops(block) + i_outer, i_inner = sch.split(i, factors=[None, 16]) + j_outer, j_inner = sch.split(j, factors=[None, 16]) + k_outer, k_inner = sch.split(k, factors=[None, 16]) + sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner) + fused_outer = sch.fuse(i_outer, j_outer) + sch.bind(fused_outer, "blockIdx.x") + + def fetch_to_shared(block, idx): + block_read = sch.cache_read(block, idx, "shared") + sch.compute_at(block_read, k_outer) + warp_size = 32 + + fused = sch.fuse(*sch.get_loops(block_read)[-2:]) + + vector_size = 4 + _, f_2, f_3 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f_2, "threadIdx.x") + sch.vectorize(f_3) + + def tensorize_load(block, dim, intrin): + loops = sch.get_loops(block) + i, j = loops[-dim : (len(loops) - dim + 2)] + + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin) + + fetch_to_shared(block, 0) + fetch_to_shared(block, 1) + + A_mat = sch.cache_read(block, 0, "cooperative_matrix_nv") + B_mat = sch.cache_read(block, 1, "cooperative_matrix_nv") + + tensorize_load(A_mat, 2, "cooperative_matrix_load") + tensorize_load(B_mat, 2, "cooperative_matrix_load") + + store = sch.cache_write(block, 0, "cooperative_matrix_nv") + sch.reverse_compute_at(store, fused_outer) + init = sch.decompose_reduction(block, sch.get_loops(block)[1]) + + sch.tensorize(sch.get_loops(init)[1], FILL_INTRIN) + sch.tensorize(sch.get_loops(store)[1], STORE_INTRIN) + sch.tensorize(sch.get_loops(block)[2], MAD_INTRIN) + + target = "vulkan -from_device=0" + f = tvm.build(sch.mod, target=target) + + dev = tvm.device(target, 0) + + A = tvm.nd.array(np.random.randn(M, K).astype("float16"), dev) + B = tvm.nd.array(np.random.randn(K, N).astype("float16"), dev) + C = tvm.nd.array(np.random.randn(M, N).astype(out_dtype), dev) + + f(A, B, C) + + A_np = A.numpy() + B_np = B.numpy() + ref = np.dot(A_np.astype("float32"), B_np.astype("float32")) + + tvm.testing.assert_allclose(C.numpy(), ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": tvm.testing.main() From 8f0cc1d014813ff447e5e92fbb88c035be1e8f20 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 4 May 2023 20:33:13 +0900 Subject: [PATCH 13/21] enable cooperative matrix from target attr --- src/runtime/vulkan/vulkan_device.cc | 4 +++- src/runtime/vulkan/vulkan_device.h | 1 + src/runtime/vulkan/vulkan_device_api.cc | 4 ++++ src/target/spirv/codegen_spirv.cc | 10 +++++----- src/target/spirv/ir_builder.cc | 7 ++++--- src/target/spirv/ir_builder.h | 6 +++--- src/target/spirv/spirv_support.cc | 4 ++++ src/target/spirv/spirv_support.h | 3 +++ src/target/target_kind.cc | 1 + 9 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index b3e017d03418..51120f5ac30d 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -133,6 +133,7 
@@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, !support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); supports_integer_dot_product = device.HasExtension("VK_KHR_shader_integer_dot_product"); + supports_cooperative_matrix_nv = device.HasExtension("VK_NV_cooperative_matrix"); // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically // needed, since it will be set so long at least one queue has @@ -435,7 +436,8 @@ std::vector VulkanDevice::SelectEnabledExtensions() const { "VK_KHR_get_memory_requirements2", "VK_KHR_dedicated_allocation", "VK_KHR_spirv_1_4", - "VK_KHR_shader_integer_dot_product"}; + "VK_KHR_shader_integer_dot_product", + "VK_NV_cooperative_matrix"}; uint32_t device_extension_prop_count; VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr, diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 59ebf430e6e6..3ba6de40d720 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -88,6 +88,7 @@ struct VulkanDeviceProperties { bool supports_push_descriptor{false}; bool supports_dedicated_allocation{false}; bool supports_integer_dot_product{false}; + bool supports_cooperative_matrix_nv{false}; uint32_t supported_subgroup_operations{0}; uint32_t max_num_threads{1}; uint32_t thread_warp_size{1}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 93f017a5aa66..f2ce95b42d35 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -241,6 +241,10 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, *rv = prop.supports_integer_dot_product; } + if (property == "supports_cooperative_matrix_nv") { + *rv = prop.supports_cooperative_matrix_nv; + } + if (property == "device_name") { *rv = prop.device_name; } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index be5366150ea7..279edae14e27 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -419,8 +419,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { auto column_major = MakeValue(op->args[6]); - auto mat_ty = builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), - rows, cols); + auto mat_ty = + builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), rows, cols); auto mat = builder_->CallCooperativeMatrixLoadNV(mat_ty, src_ptr, stride, column_major); ICHECK(elem_offset->IsInstance()); builder_->SetJointMatrixDef(ptr, elem_offset.as()->value, mat); @@ -465,7 +465,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { auto C_ptr = Downcast(op->args[4]); ICHECK(C_ptr.defined()); auto mat_ty = - builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(C_ptr), 16, 16); + builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(C_ptr), 16, 16); auto get_matrix = [this](PrimExpr arg, int offset) { auto buffer_var_mat = Downcast(arg); @@ -671,7 +671,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { std::vector joint_matrix_phis; for (auto [var, elem_offsets] : acc_mat_collector.joint_matrices) { Var buffer_var_mat = GetRef(var); - auto mat_ty = builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(buffer_var_mat), 16, 16); + auto mat_ty = builder_->GetCooperativeMatrixNVType( + builder_->GetBufferElementType(buffer_var_mat), 16, 16); for (auto offset : elem_offsets) { spirv::PhiValue mat_phi = 
builder_->MakePhi(mat_ty, 2); auto mat_def = builder_->GetJointMatrixDef(buffer_var_mat, offset); @@ -704,7 +705,6 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { spirv::Value next_value = builder_->Add(loop_var, one); loop_var.SetIncoming(1, next_value, continue_label); - builder_->MakeInst(spv::OpBranch, head_label); int phi_index = 0; diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 0e87efb5b11a..dd9e8fb33bb4 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -60,9 +60,10 @@ void IRBuilder::InitHeader() { } #endif - // TODO - capabilities_used_.insert(spv::CapabilityCooperativeMatrixNV); - extensions_used_.insert("SPV_NV_cooperative_matrix"); + if (spirv_support_.supports_cooperative_matrix_nv) { + capabilities_used_.insert(spv::CapabilityCooperativeMatrixNV); + extensions_used_.insert("SPV_NV_cooperative_matrix"); + } // memory model ib_.Begin(spv::OpMemoryModel) diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index fa2972c48b9c..7e0e0dcdd774 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -446,8 +446,8 @@ class IRBuilder { t.id = id_counter_++; t.element_type_id = elem_ty.id; ib_.Begin(spv::Op::OpTypeCooperativeMatrixNV) - .AddSeq(t, elem_ty, scope, rows_spv, cols_spv) - .Commit(&global_); + .AddSeq(t, elem_ty, scope, rows_spv, cols_spv) + .Commit(&global_); cooperative_matrix_type_tbl_[key] = t; return t; @@ -647,7 +647,7 @@ class IRBuilder { struct JointMatrixDef { Value cur_value; - Label defined_label; // TODO: remove it + Label defined_label; // TODO: remove it }; void SetJointMatrixDef(const tir::Var& buffer_var_mat, int alloc_id, Value mat) { diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index 81b5cd8b8a6a..02fab4dd59c4 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -92,6 +92,10 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { if (target->GetAttr("supports_integer_dot_product")) { supports_integer_dot_product = target->GetAttr("supports_integer_dot_product").value(); } + if (target->GetAttr("supports_cooperative_matrix_nv")) { + supports_cooperative_matrix_nv = + target->GetAttr("supports_cooperative_matrix_nv").value(); + } // Check whether integer dot product is enabled in mattr. if (const Optional>& v = target->GetAttr>("mattr")) { for (const String& s : v.value()) { diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h index 6365e576b8cf..5aef7d1344ff 100644 --- a/src/target/spirv/spirv_support.h +++ b/src/target/spirv/spirv_support.h @@ -276,6 +276,9 @@ struct SPIRVSupport { * attempting to perform integer dot product. */ bool supports_integer_dot_product{false}; + + /*! \brief Whether the driver supports VK_NV_cooperative_matrix extention. 
*/ + bool supports_cooperative_matrix_nv{false}; }; } // namespace codegen diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 3a555e304cb0..3c28a48a6a2c 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -378,6 +378,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("supports_push_descriptor") .add_attr_option("supports_dedicated_allocation") .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix_nv") .add_attr_option("supported_subgroup_operations") // Physical device limits .add_attr_option("max_num_threads", Integer(256)) From 35d914ebc62328803c0a390dac16accdca0926a5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 4 May 2023 20:46:20 +0900 Subject: [PATCH 14/21] clean --- src/target/spirv/codegen_spirv.cc | 144 ++++++++---------- src/target/spirv/ir_builder.cc | 51 +++++++ src/target/spirv/ir_builder.h | 101 ++++-------- .../unittest/test_target_codegen_vulkan.py | 22 +-- 4 files changed, 149 insertions(+), 169 deletions(-) diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 279edae14e27..2fd7b5463d46 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -400,7 +400,6 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { ICHECK(op->args.size() == 1 && load); ICHECK_EQ(load->indices.size(), 1) << "CodeGenSPIRV only supports flat memory allocations."; auto buffer_var = Downcast(load->buffer->data); - ICHECK(buffer_var.defined()); auto it = storage_info_.find(buffer_var.get()); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; @@ -410,76 +409,59 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { return builder_->StructArrayAccess(ptr_type, buffer, MakeValue(load->indices[0])); } else if (op->op.same_as(builtin::cooperative_matrix_load_NV())) { auto ptr = Downcast(op->args[0]); - ICHECK(ptr.defined()); - auto elem_offset = op->args[1]; - spirv::Value src_ptr = MakeValue(op->args[2]); - int rows = op->args[3].as()->value; - int cols = op->args[4].as()->value; + auto elem_offset = op->args[1].as(); + auto src_ptr = MakeValue(op->args[2]); + const int rows = op->args[3].as()->value; + const int cols = op->args[4].as()->value; auto stride = MakeValue(op->args[5]); - auto column_major = MakeValue(op->args[6]); - - auto mat_ty = - builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), rows, cols); + auto elem_ty = builder_->GetBufferElementType(ptr); + auto mat_ty = builder_->GetCooperativeMatrixNVType(elem_ty, rows, cols); auto mat = builder_->CallCooperativeMatrixLoadNV(mat_ty, src_ptr, stride, column_major); - ICHECK(elem_offset->IsInstance()); - builder_->SetJointMatrixDef(ptr, elem_offset.as()->value, mat); + ICHECK(elem_offset) << "Expects a constant element offset."; + builder_->SetCooperativeMatrix(ptr, elem_offset->value, mat); return mat; } else if (op->op.same_as(builtin::cooperative_matrix_store_NV())) { auto buffer_var_mat = Downcast(op->args[1]); - ICHECK(buffer_var_mat.defined()); - - spirv::Value dst_ptr = MakeValue(op->args[0]); - - auto elem_offset = op->args[2]; - ICHECK(elem_offset->IsInstance()); - auto mat = builder_->GetJointMatrix(buffer_var_mat, elem_offset.as()->value); - spirv::Value stride = MakeValue(op->args[3]); - spirv::Value column_major = MakeValue(op->args[4]); + auto dst_ptr = MakeValue(op->args[0]); + auto elem_offset = op->args[2].as(); + ICHECK(elem_offset) << "Expects a constant element offset."; + auto mat = 
builder_->GetCooperativeMatrix(buffer_var_mat, elem_offset->value); + auto stride = MakeValue(op->args[3]); + auto column_major = MakeValue(op->args[4]); builder_->CallCooperativeMatrixStoreNV(dst_ptr, mat, stride, column_major); return spirv::Value(); } else if (op->op.same_as(builtin::cooperative_matrix_fill_NV())) { auto ptr = Downcast(op->args[0]); - ICHECK(ptr.defined()); - auto elem_offset = op->args[1]; - ICHECK(elem_offset->IsInstance()); - - int rows = op->args[2].as()->value; - int cols = op->args[3].as()->value; + auto elem_offset = op->args[1].as(); + ICHECK(elem_offset) << "Expects a constant element offset."; + const int rows = op->args[2].as()->value; + const int cols = op->args[3].as()->value; auto v = MakeValue(op->args[4]); - - auto mat_ty = - builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(ptr), rows, cols); + auto elem_ty = builder_->GetBufferElementType(ptr); + auto mat_ty = builder_->GetCooperativeMatrixNVType(elem_ty, rows, cols); auto filled = builder_->CallCooperativeMatrixFillNV(mat_ty, v); - - builder_->SetJointMatrixDef(ptr, elem_offset.as()->value, filled); + builder_->SetCooperativeMatrix(ptr, elem_offset->value, filled); return filled; } else if (op->op.same_as(builtin::cooperative_matrix_mad_NV())) { - auto A_elem_offset = op->args[1]; - ICHECK(A_elem_offset->IsInstance()); - auto B_elem_offset = op->args[3]; - ICHECK(B_elem_offset->IsInstance()); - auto C_elem_offset = op->args[5]; - ICHECK(C_elem_offset->IsInstance()); - - auto C_ptr = Downcast(op->args[4]); - ICHECK(C_ptr.defined()); - auto mat_ty = - builder_->GetCooperativeMatrixNVType(builder_->GetBufferElementType(C_ptr), 16, 16); + auto A_elem_offset = op->args[1].as(); + auto B_elem_offset = op->args[3].as(); + auto C_elem_offset = op->args[5].as(); + ICHECK(A_elem_offset) << "Expects a constant element offset."; + ICHECK(B_elem_offset) << "Expects a constant element offset."; + ICHECK(C_elem_offset) << "Expects a constant element offset."; auto get_matrix = [this](PrimExpr arg, int offset) { auto buffer_var_mat = Downcast(arg); - ICHECK(buffer_var_mat.defined()); - return builder_->GetJointMatrix(buffer_var_mat, offset); + return builder_->GetCooperativeMatrix(buffer_var_mat, offset); }; - auto A = get_matrix(op->args[0], A_elem_offset.as()->value); - auto B = get_matrix(op->args[2], B_elem_offset.as()->value); - auto c_offset = C_elem_offset.as()->value; - auto C = get_matrix(op->args[4], c_offset); - - auto acc = builder_->CallCooperativeMatrixMadNV(mat_ty, A, B, C); - builder_->SetJointMatrixDef(C_ptr, c_offset, acc); + auto A = get_matrix(op->args[0], A_elem_offset->value); + auto B = get_matrix(op->args[2], B_elem_offset->value); + auto C = get_matrix(op->args[4], C_elem_offset->value); + auto acc = builder_->CallCooperativeMatrixMadNV(A, B, C); + auto C_ptr = Downcast(op->args[4]); + builder_->SetCooperativeMatrix(C_ptr, C_elem_offset->value, acc); return acc; } else { LOG(FATAL) << "Unresolved call " << op->op; @@ -628,20 +610,6 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { } } -class AccumulatedJointMatrixCollector : public StmtExprVisitor { - public: - void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::cooperative_matrix_mad_NV())) { - auto C_elem_offset = op->args[5]; - ICHECK(C_elem_offset->IsInstance()); - auto buffer_var_C = Downcast(op->args[4]); - joint_matrices[buffer_var_C.get()].insert(C_elem_offset.as()->value); - } - ExprVisitor::VisitExpr_(op); - } - std::unordered_map> joint_matrices; -}; - void 
CodeGenSPIRV::VisitStmt_(const ForNode* op) { ICHECK(is_zero(op->min)); analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); @@ -662,23 +630,29 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // Loop head builder_->StartLabel(head_label); - AccumulatedJointMatrixCollector acc_mat_collector; + std::unordered_map> accum_matrices; if (op->kind == ForKind::kSerial) { - acc_mat_collector(op->body); + tir::PostOrderVisit(op->body, [&accum_matrices](const ObjectRef& obj) { + auto call = obj.as(); + if (call && call->op.same_as(builtin::cooperative_matrix_mad_NV())) { + auto C_elem_offset = call->args[5].as(); + ICHECK(C_elem_offset) << "Expects a constant element offset."; + auto buffer_var_C = Downcast(call->args[4]); + accum_matrices[buffer_var_C.get()].insert(C_elem_offset->value); + } + }); } - std::vector joint_matrix_phis; - for (auto [var, elem_offsets] : acc_mat_collector.joint_matrices) { + std::vector cooperative_matrix_phis; + for (const auto& [var, elem_offsets] : accum_matrices) { Var buffer_var_mat = GetRef(var); - auto mat_ty = builder_->GetCooperativeMatrixNVType( - builder_->GetBufferElementType(buffer_var_mat), 16, 16); for (auto offset : elem_offsets) { - spirv::PhiValue mat_phi = builder_->MakePhi(mat_ty, 2); - auto mat_def = builder_->GetJointMatrixDef(buffer_var_mat, offset); - mat_phi.SetIncoming(0, mat_def.cur_value, init_label); - joint_matrix_phis.push_back(mat_phi); - builder_->SetJointMatrixDef(buffer_var_mat, offset, mat_phi); + auto mat = builder_->GetCooperativeMatrix(buffer_var_mat, offset); + auto mat_phi = builder_->MakePhi(mat.stype, 2); + mat_phi.SetIncoming(0, mat, init_label); + cooperative_matrix_phis.push_back(mat_phi); + builder_->SetCooperativeMatrix(buffer_var_mat, offset, mat_phi); } } @@ -695,12 +669,12 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { builder_->StartLabel(body_label); var_map_[op->loop_var.get()] = spirv::Value(loop_var); this->VisitStmt(op->body); - spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) - : builder_->UIntImm(loop_var.stype, 1); builder_->MakeInst(spv::OpBranch, continue_label); // loop continue builder_->StartLabel(continue_label); + spirv::Value one = op->loop_var.dtype().is_int() ? 
builder_->IntImm(loop_var.stype, 1) + : builder_->UIntImm(loop_var.stype, 1); spirv::Value next_value = builder_->Add(loop_var, one); loop_var.SetIncoming(1, next_value, continue_label); @@ -708,14 +682,16 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { builder_->MakeInst(spv::OpBranch, head_label); int phi_index = 0; - for (auto [var, elem_offsets] : acc_mat_collector.joint_matrices) { + for (const auto& [var, elem_offsets] : accum_matrices) { + Var buffer_var_mat = GetRef(var); for (auto offset : elem_offsets) { - spirv::PhiValue mat_phi = joint_matrix_phis[phi_index++]; - auto mat_def = builder_->GetJointMatrixDef(GetRef(var), offset); - mat_phi.SetIncoming(1, mat_def.cur_value, continue_label); - builder_->SetJointMatrixDef(GetRef(var), offset, mat_phi); + auto mat_phi = cooperative_matrix_phis[phi_index++]; + auto mat = builder_->GetCooperativeMatrix(buffer_var_mat, offset); + mat_phi.SetIncoming(1, mat, continue_label); + builder_->SetCooperativeMatrix(buffer_var_mat, offset, mat_phi); } } + // loop merge builder_->StartLabel(merge_label); } diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index dd9e8fb33bb4..9dbdaa83d579 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -815,6 +815,57 @@ Value IRBuilder::Select(Value cond, Value a, Value b) { return MakeValue(spv::OpSelect, a.stype, cond, a, b); } +SType IRBuilder::GetCooperativeMatrixNVType(const SType& elem_ty, int rows, int cols) { + auto key = std::make_tuple(elem_ty.id, rows, cols); + auto entry = cooperative_matrix_type_tbl_.find(key); + if (entry != cooperative_matrix_type_tbl_.end()) { + return entry->second; + } + + auto rows_spv = IntImm(t_int32_, rows); + auto cols_spv = IntImm(t_int32_, cols); + auto scope = IntImm(t_int32_, spv::Scope::ScopeSubgroup); + + SType t; + t.id = id_counter_++; + t.element_type_id = elem_ty.id; + ib_.Begin(spv::Op::OpTypeCooperativeMatrixNV) + .AddSeq(t, elem_ty, scope, rows_spv, cols_spv) + .Commit(&global_); + + cooperative_matrix_type_tbl_[key] = t; + return t; +} + +Value IRBuilder::CallCooperativeMatrixLoadNV(const SType& mat_type, Value src, Value stride, + Value column_major) { + Value val = NewValue(mat_type, kNormal); + + ib_.Begin(spv::Op::OpCooperativeMatrixLoadNV) + .AddSeq(mat_type, val, src, stride, column_major) + .Commit(&function_); + return val; +} + +void IRBuilder::CallCooperativeMatrixStoreNV(Value dst, Value mat, Value stride, + Value column_major) { + ib_.Begin(spv::Op::OpCooperativeMatrixStoreNV) + .AddSeq(dst, mat, stride, column_major) + .Commit(&function_); +} + +Value IRBuilder::CallCooperativeMatrixFillNV(const SType& mat_type, Value v) { + Value val = NewValue(mat_type, kNormal); + ib_.Begin(spv::OpCompositeConstruct).AddSeq(mat_type, val, v).Commit(&function_); + return val; +} + +Value IRBuilder::CallCooperativeMatrixMadNV(Value A, Value B, Value C) { + Value val = NewValue(C.stype, kNormal); + ib_.Begin(spv::Op::OpCooperativeMatrixMulAddNV).AddSeq(C.stype, val, A, B, C).Commit(&function_); + return val; +} + } // namespace spirv } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 7e0e0dcdd774..c19fe6ff39d8 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -431,58 +431,6 @@ class IRBuilder { Value CallKHRIntegerDotProduct(const SType& ret_type, const std::vector& args, const DataType& dtype); - SType GetCooperativeMatrixNVType(const SType& elem_ty, int rows, int cols) { - auto key = 
std::make_tuple(elem_ty.id, rows, cols); - auto entry = cooperative_matrix_type_tbl_.find(key); - if (entry != cooperative_matrix_type_tbl_.end()) { - return entry->second; - } - - auto rows_spv = IntImm(t_int32_, rows); - auto cols_spv = IntImm(t_int32_, cols); - auto scope = IntImm(t_int32_, spv::Scope::ScopeSubgroup); - - SType t; - t.id = id_counter_++; - t.element_type_id = elem_ty.id; - ib_.Begin(spv::Op::OpTypeCooperativeMatrixNV) - .AddSeq(t, elem_ty, scope, rows_spv, cols_spv) - .Commit(&global_); - - cooperative_matrix_type_tbl_[key] = t; - return t; - } - - Value CallCooperativeMatrixLoadNV(const SType& mat_type, Value src, Value stride, - Value column_major) { - Value val = NewValue(mat_type, kNormal); - - ib_.Begin(spv::Op::OpCooperativeMatrixLoadNV) - .AddSeq(mat_type, val, src, stride, column_major) - .Commit(&function_); - return val; - } - - void CallCooperativeMatrixStoreNV(Value dst, Value mat, Value stride, Value column_major) { - ib_.Begin(spv::Op::OpCooperativeMatrixStoreNV) - .AddSeq(dst, mat, stride, column_major) - .Commit(&function_); - } - - Value CallCooperativeMatrixFillNV(const SType& mat_type, Value v) { - Value val = NewValue(mat_type, kNormal); - ib_.Begin(spv::OpCompositeConstruct).AddSeq(mat_type, val, v).Commit(&function_); - return val; - } - - Value CallCooperativeMatrixMadNV(const SType& mat_type, Value A, Value B, Value C) { - Value val = NewValue(mat_type, kNormal); - ib_.Begin(spv::Op::OpCooperativeMatrixMulAddNV) - .AddSeq(mat_type, val, A, B, C) - .Commit(&function_); - return val; - } - /*! * \brief Build vector by concatenating components * @@ -645,33 +593,34 @@ class IRBuilder { Value GE(Value a, Value b); Value Select(Value cond, Value a, Value b); - struct JointMatrixDef { - Value cur_value; - Label defined_label; // TODO: remove it + // VK_NV_cooperative_matrix related + SType GetCooperativeMatrixNVType(const SType& elem_ty, int rows, int cols); + Value CallCooperativeMatrixLoadNV(const SType& mat_type, Value src, Value stride, + Value column_major); + void CallCooperativeMatrixStoreNV(Value dst, Value mat, Value stride, Value column_major); + Value CallCooperativeMatrixFillNV(const SType& mat_type, Value v); + Value CallCooperativeMatrixMadNV(Value A, Value B, Value C); + + SType GetBufferElementType(const tir::Var& buffer) { + const auto* ptr = buffer->type_annotation.as(); + ICHECK(ptr) << "Expects a pointer type."; + const auto* prim = ptr->element_type.as(); + ICHECK(prim) << "Expects a primitive type."; + return GetSType(prim->dtype); }; - void SetJointMatrixDef(const tir::Var& buffer_var_mat, int alloc_id, Value mat) { - auto key = std::make_pair(buffer_var_mat.get(), alloc_id); - joint_matrix_defs[key] = JointMatrixDef{mat, curr_label_}; + void SetCooperativeMatrix(const tir::Var& buffer_var_mat, int elem_offset, Value mat) { + auto key = std::make_pair(buffer_var_mat.get(), elem_offset); + cooperative_matrix_defs[key] = mat; } - JointMatrixDef GetJointMatrixDef(const tir::Var& buffer_var_mat, int alloc_id) { - auto key = std::make_pair(buffer_var_mat.get(), alloc_id); - auto entry = joint_matrix_defs.find(key); - ICHECK(entry != joint_matrix_defs.end()); + Value GetCooperativeMatrix(const tir::Var& buffer_var_mat, int elem_offset) { + auto key = std::make_pair(buffer_var_mat.get(), elem_offset); + auto entry = cooperative_matrix_defs.find(key); + ICHECK(entry != cooperative_matrix_defs.end()); return entry->second; } - Value GetJointMatrix(const tir::Var& buffer_var_mat, int alloc_id) { - return GetJointMatrixDef(buffer_var_mat, 
alloc_id).cur_value; - } - - SType GetBufferElementType(const tir::Var& buffer) { - auto* ptr = buffer->type_annotation.as(); - auto* prim = ptr->element_type.as(); - return GetSType(prim->dtype); - }; - private: /*! * \brief Create new value @@ -782,8 +731,10 @@ class IRBuilder { std::map, SType> pointer_type_tbl_; /*! \brief map from constant int to its value */ std::map, Value> const_tbl_; - /*! \brief map from name of a ExtInstImport to its value */ + /*! \brief map from name of an ExtInstImport to its value */ std::map ext_inst_tbl_; + /*! \brief map from (element-type code, rows, cols) to a Cooperative Matrix type */ + std::map, SType> cooperative_matrix_type_tbl_; /*! \brief Header segment * @@ -824,8 +775,8 @@ class IRBuilder { std::vector function_scope_vars_; /*! \brief Function segment */ std::vector function_; - std::map, SType> cooperative_matrix_type_tbl_; - std::map, JointMatrixDef> joint_matrix_defs; + /*! \brief map from (element-type code, rows, cols) to a Cooperative Matrix type */ + std::map, Value> cooperative_matrix_defs; }; } // namespace spirv diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 4638f13190fd..93cb535afef5 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -959,21 +959,23 @@ def tensorize_load(block, dim, intrin): sch.tensorize(sch.get_loops(block)[2], MAD_INTRIN) target = "vulkan -from_device=0" - f = tvm.build(sch.mod, target=target) - dev = tvm.device(target, 0) + if tvm.target.Target(target).attrs["supports_cooperative_matrix_nv"]: + f = tvm.build(sch.mod, target=target) - A = tvm.nd.array(np.random.randn(M, K).astype("float16"), dev) - B = tvm.nd.array(np.random.randn(K, N).astype("float16"), dev) - C = tvm.nd.array(np.random.randn(M, N).astype(out_dtype), dev) + dev = tvm.device(target, 0) - f(A, B, C) + A = tvm.nd.array(np.random.randn(M, K).astype("float16"), dev) + B = tvm.nd.array(np.random.randn(K, N).astype("float16"), dev) + C = tvm.nd.array(np.random.randn(M, N).astype(out_dtype), dev) - A_np = A.numpy() - B_np = B.numpy() - ref = np.dot(A_np.astype("float32"), B_np.astype("float32")) + f(A, B, C) - tvm.testing.assert_allclose(C.numpy(), ref, rtol=1e-2, atol=1e-2) + A_np = A.numpy() + B_np = B.numpy() + ref = np.dot(A_np.astype("float32"), B_np.astype("float32")) + + tvm.testing.assert_allclose(C.numpy(), ref, rtol=1e-2, atol=1e-2) if __name__ == "__main__": From 770b4ed2db35de0ae562eb56173afe358cfdad03 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 5 May 2023 04:28:18 +0900 Subject: [PATCH 15/21] add doc --- include/tvm/tir/builtin.h | 13 +++ python/tvm/tir/op.py | 140 +++++++++++++++++++++++++++-- src/runtime/thread_storage_scope.h | 4 + src/target/spirv/codegen_spirv.cc | 32 +++++-- src/target/spirv/ir_builder.h | 12 ++- 5 files changed, 186 insertions(+), 15 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 848cce0f788c..aa7d8a9eb6ba 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -797,9 +797,22 @@ TVM_DLL const Op& start_profile_intrinsic(); */ TVM_DLL const Op& end_profile_intrinsic(); +// Intrinsics for the VK_NV_cooperative_matrix Vulkan extention. + +/*! \brief The intrinsic corresponding to the OpCooperativeMatrixLoadNV instruction. */ TVM_DLL const Op& cooperative_matrix_load_NV(); + +/*! \brief The intrinsic corresponding to the OpCooperativeMatrixStoreNV instruction. 
*/ TVM_DLL const Op& cooperative_matrix_store_NV(); + +/*! + * \brief Create a new cooperative matrix filled with the provided value. + * + * There is no such instruction in the extention, but it is added for convenience. + */ TVM_DLL const Op& cooperative_matrix_fill_NV(); + +/*! \brief The intrinsic corresponding to the OpCooperativeMatrixMulAddNV instruction. */ TVM_DLL const Op& cooperative_matrix_mad_NV(); /*! \brief The kind of structure field info used in intrinsic */ diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index bf7bcbfea9a5..6cb42cca5a01 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3037,20 +3037,146 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) -def cooperative_matrix_load_NV(mat, offset, src, rows, cols, stride, column_major): - return call_intrin("handle", "tir.cooperative_matrix_load_NV", mat, offset, src, rows, cols, stride, column_major) +# Intrinsics for the VK_NV_cooperative_matrix Vulkan extention. -def cooperative_matrix_store_NV(dst, mat, offset, stride, column_major): - return call_intrin("handle", "tir.cooperative_matrix_store_NV", dst, mat, offset, stride, column_major) +def cooperative_matrix_load_NV(buffer_mat, offset, src, rows, cols, stride, column_major): + """The intrinsic corresponding to the OpCooperativeMatrixLoadNV instruction. + Parameters + ---------- + buffer_mat: Var + The destination buffer with "cooperative_matrix_nv" scope. + + offset: IntImm + The element offset for the matrix to be loaded in the destination buffer. + + src: Expr + The source pointer expression. + + rows: int + The number of rows in the matrix. -def cooperative_matrix_fill_NV(mat, offset, rows, cols, v): - return call_intrin("handle", "tir.cooperative_matrix_fill_NV", mat, offset, rows, cols, v) + cols: int + The number of columns in the matrix. + + stride: Expr + The stride of the matrix. + + column_major: bool + Whether the matrix elements are stored in the column-major order. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", + "tir.cooperative_matrix_load_NV", + buffer_mat, + offset, + src, + rows, + cols, + stride, + column_major, + ) + + +def cooperative_matrix_store_NV(dst, buffer_mat, offset, stride, column_major): + """The intrinsic corresponding to the OpCooperativeMatrixStoreNV instruction. + + Parameters + ---------- + dst: Expr + The destination pointer expression. + + buffer_mat: Var + The source buffer with "cooperative_matrix_nv" scope. + + offset: IntImm + The element offset for the matrix to be stored in the destination buffer. + + stride: Expr + The stride of the matrix. + + column_major: bool + Whether the matrix elements are stored in the column-major order. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", "tir.cooperative_matrix_store_NV", dst, buffer_mat, offset, stride, column_major + ) + + +def cooperative_matrix_fill_NV(buffer_mat, offset, rows, cols, value): + """Create a new cooperative matrix filled with the provided value. + + There is no such instruction in the extention, but it is added for convenience. + + Parameters + ---------- + + buffer_mat: Var + The buffer with "cooperative_matrix_nv" scope to be filled. + + offset: IntImm + The element offset for the matrix to be filled in `buffer_mat`. + + rows: int + The number of rows in the matrix. + + cols: int + The number of columns in the matrix. 
+
+    value: Expr
+        The value the matrix will be filled with.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "handle", "tir.cooperative_matrix_fill_NV", buffer_mat, offset, rows, cols, value
+    )
 
 
 def cooperative_matrix_mad_NV(A, A_off, B, B_off, C, C_off):
-    return call_intrin("handle", "tir.cooperative_matrix_mad_NV", A, A_off, B, B_off, C, C_off)
+    """The intrinsic corresponding to the OpCooperativeMatrixMulAddNV instruction.
+
+    Parameters
+    ----------
+
+    A : Var
+        The buffer with "cooperative_matrix_nv" scope corresponding to the "A" matrix.
+
+    A_off : IntImm
+        The element offset of the "A" matrix in the buffer `A`.
+
+    B : Var
+        The buffer with "cooperative_matrix_nv" scope corresponding to the "B" matrix.
+
+    B_off : IntImm
+        The element offset of the "B" matrix in the buffer `B`.
+
+    C : Var
+        The buffer with "cooperative_matrix_nv" scope corresponding to the "C" matrix.
+
+    C_off : IntImm
+        The element offset of the "C" matrix in the buffer `C`.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.cooperative_matrix_mad_NV", A, A_off, B, B_off, C, C_off)
 
 
 # pylint: disable=unnecessary-lambda
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 95bb4e370b37..bd17ddf82e4d 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -64,6 +64,10 @@ enum class StorageRank {
   kTexture = 7,
   /*! \brief global scope amx tmm memory */
   kAMXTMM = 8,
+  /*!
+   * \brief Scope representing the cooperative matrix in the Vulkan
+   * VK_NV_cooperative_matrix extention.
+   */
   kCooperativeMatrixNV = 9,
 };
 
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 2fd7b5463d46..e876e1588e69 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -422,8 +422,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
       builder_->SetCooperativeMatrix(ptr, elem_offset->value, mat);
       return mat;
     } else if (op->op.same_as(builtin::cooperative_matrix_store_NV())) {
-      auto buffer_var_mat = Downcast(op->args[1]);
       auto dst_ptr = MakeValue(op->args[0]);
+      auto buffer_var_mat = Downcast(op->args[1]);
       auto elem_offset = op->args[2].as();
       ICHECK(elem_offset) << "Expects a constant element offset.";
       auto mat = builder_->GetCooperativeMatrix(buffer_var_mat, elem_offset->value);
@@ -630,12 +630,31 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
   // Loop head
   builder_->StartLabel(head_label);
 
+  // In a normal matmul, the update semantics c += a * b is implemented by a load and a store to
+  // the memory location corresponding to the scalar c. But when C, A, and B are cooperative
+  // matrices, C += A * B cannot be implemented in the same way, since they cannot be loaded from
+  // or stored to a buffer.
+
+  // We leverage SPIRV's SSA semantics to implement the multiply-add update on the C matrices.
+  // The original buffer is subdivided into fixed-size matrices, and the multiply-add on each
+  // individual matrix is unrolled. For example, if the buffer for the C matrix is of shape 64x64,
+  // we materialize 16 16x16 matrices and generate 16 multiply-add instructions. Each matrix is
+  // represented by a phi value, and each iteration of the reduction loop reads the phi values for
+  // the C matrices and updates them.
+
+  // A map from a buffer with kCooperativeMatrixNV storage scope (whose size can be arbitrary)
+  // to cooperative matrices "C" (whose sizes are fixed, e.g. 16x16).
+ // Each matrix is identified by an element offset into the buffer. This encoding lets us bridge + // the gap between a typical TIR matmul schedule (which uses + // `cache_read(..., "cooperative_matrix_nv")` from arbitrary-sized shared memory) and + // the fixed-size matrices as expected by the extention. std::unordered_map> accum_matrices; if (op->kind == ForKind::kSerial) { + // If this is a serial loop, record which C matrices are updated in this loop. tir::PostOrderVisit(op->body, [&accum_matrices](const ObjectRef& obj) { - auto call = obj.as(); - if (call && call->op.same_as(builtin::cooperative_matrix_mad_NV())) { + if (auto call = obj.as(); + call && call->op.same_as(builtin::cooperative_matrix_mad_NV())) { auto C_elem_offset = call->args[5].as(); ICHECK(C_elem_offset) << "Expects a constant element offset."; auto buffer_var_C = Downcast(call->args[4]); @@ -644,7 +663,10 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { }); } + // Phi values for all cooperative matrices "C". std::vector cooperative_matrix_phis; + + // Initialize phi values and use them as the matrix "C" in this loop. for (const auto& [var, elem_offsets] : accum_matrices) { Var buffer_var_mat = GetRef(var); for (auto offset : elem_offsets) { @@ -675,12 +697,12 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { builder_->StartLabel(continue_label); spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) : builder_->UIntImm(loop_var.stype, 1); - spirv::Value next_value = builder_->Add(loop_var, one); loop_var.SetIncoming(1, next_value, continue_label); - builder_->MakeInst(spv::OpBranch, head_label); + // Update phi values to the new value of the matrix "C" accumulated in this loop. + // The next iteration will read these values as the "C" matrix. int phi_index = 0; for (const auto& [var, elem_offsets] : accum_matrices) { Var buffer_var_mat = GetRef(var); diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index c19fe6ff39d8..1f762a25616c 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -593,7 +593,8 @@ class IRBuilder { Value GE(Value a, Value b); Value Select(Value cond, Value a, Value b); - // VK_NV_cooperative_matrix related + // The VK_NV_cooperative_matrix extention related, see the documentation for the + // SPIRV SPV_NV_cooperative_matrix extention for details. SType GetCooperativeMatrixNVType(const SType& elem_ty, int rows, int cols); Value CallCooperativeMatrixLoadNV(const SType& mat_type, Value src, Value stride, Value column_major); @@ -601,6 +602,9 @@ class IRBuilder { Value CallCooperativeMatrixFillNV(const SType& mat_type, Value v); Value CallCooperativeMatrixMadNV(Value A, Value B, Value C); + // Helper functions for cooperative matrix support + + /*! \brief Return the pointer element type for the buffer.*/ SType GetBufferElementType(const tir::Var& buffer) { const auto* ptr = buffer->type_annotation.as(); ICHECK(ptr) << "Expects a pointer type."; @@ -609,11 +613,13 @@ class IRBuilder { return GetSType(prim->dtype); }; + /*! \brief Associate a TIR buffer at the provided offset with the matrix. */ void SetCooperativeMatrix(const tir::Var& buffer_var_mat, int elem_offset, Value mat) { auto key = std::make_pair(buffer_var_mat.get(), elem_offset); cooperative_matrix_defs[key] = mat; } + /*! \brief Retrieve the matrix corresponding to the provided offset in a TIR buffer. 
*/
   Value GetCooperativeMatrix(const tir::Var& buffer_var_mat, int elem_offset) {
     auto key = std::make_pair(buffer_var_mat.get(), elem_offset);
     auto entry = cooperative_matrix_defs.find(key);
     ICHECK(entry != cooperative_matrix_defs.end());
     return entry->second;
   }
 
@@ -733,7 +739,7 @@ class IRBuilder {
   std::map, Value> const_tbl_;
   /*! \brief map from name of an ExtInstImport to its value */
   std::map ext_inst_tbl_;
-  /*! \brief map from (element-type code, rows, cols) to a Cooperative Matrix type */
+  /*! \brief map from (element-type code, rows, cols) to a cooperative matrix type */
   std::map, SType> cooperative_matrix_type_tbl_;
 
   /*! \brief Header segment
@@ -775,7 +781,7 @@ class IRBuilder {
   std::vector function_scope_vars_;
   /*! \brief Function segment */
   std::vector function_;
-  /*! \brief map from (element-type code, rows, cols) to a Cooperative Matrix type */
+  /*! \brief map from (buffer variable, element offset) to its current cooperative matrix value */
   std::map, Value> cooperative_matrix_defs;
 };

From 933f0a3116660abe762f905276ae7c4b661ac377 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 5 May 2023 06:26:29 +0900
Subject: [PATCH 16/21] cpplint

---
 src/target/spirv/codegen_spirv.cc | 1 +
 src/target/spirv/ir_builder.h     | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index e876e1588e69..ce313336a2d1 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -28,6 +28,7 @@
 #include
 #include
+#include
 
 #include "../../runtime/pack_args.h"
 #include "../../runtime/vulkan/vulkan_common.h"
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index 1f762a25616c..b33e02371200 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -611,7 +611,7 @@ class IRBuilder {
     const auto* prim = ptr->element_type.as();
     ICHECK(prim) << "Expects a primitive type.";
     return GetSType(prim->dtype);
-  };
+  }
 
   /*! \brief Associate a TIR buffer at the provided offset with the matrix.
*/ void SetCooperativeMatrix(const tir::Var& buffer_var_mat, int elem_offset, Value mat) { From f596494149b0de5504774aa27e08ccba35bb9ccf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 5 May 2023 09:34:27 +0900 Subject: [PATCH 17/21] workaround for TensorIntrin register in test --- .../python/unittest/test_target_codegen_vulkan.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 93cb535afef5..d1a314c67566 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -865,22 +865,27 @@ def cooperative_matrix_mad_impl(a: T.handle, b: T.handle, c: T.handle) -> None: return cooperative_matrix_mad_impl -TensorIntrin.register("cooperative_matrix_load", cooperative_matrix_load_desc, get_load_impl(False)) - - @pytest.mark.parametrize("out_dtype", ["float32", "float16"]) def test_cooperative_matrix_nv(out_dtype): STORE_INTRIN = "cooperative_matrix_store_{}".format(out_dtype) FILL_INTRIN = "cooperative_matrix_fill_{}".format(out_dtype) MAD_INTRIN = "cooperative_matrix_mad_{}".format(out_dtype) + TensorIntrin.register( + "cooperative_matrix_load", cooperative_matrix_load_desc, get_load_impl(False), override=True + ) TensorIntrin.register( STORE_INTRIN, get_store_desc(out_dtype), get_store_impl(out_dtype), + override=True, + ) + TensorIntrin.register( + FILL_INTRIN, get_fill_desc(out_dtype), get_fill_impl(out_dtype), override=True + ) + TensorIntrin.register( + MAD_INTRIN, get_mad_desc(out_dtype), get_mad_impl(out_dtype), override=True ) - TensorIntrin.register(FILL_INTRIN, get_fill_desc(out_dtype), get_fill_impl(out_dtype)) - TensorIntrin.register(MAD_INTRIN, get_mad_desc(out_dtype), get_mad_impl(out_dtype)) def get_matmul(m, n, k, out_dtype="float32"): X = te.placeholder((m, k), name="X", dtype="float16") From 62b17d70d6135c16d08aedf37dd9ef8ceb830c19 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 5 May 2023 16:14:10 +0900 Subject: [PATCH 18/21] trying to fix test --- tests/python/unittest/test_target_codegen_vulkan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index d1a314c67566..1d99aaf6a5ee 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -964,8 +964,9 @@ def tensorize_load(block, dim, intrin): sch.tensorize(sch.get_loops(block)[2], MAD_INTRIN) target = "vulkan -from_device=0" + tgt_attrs = tvm.target.Target(target).attrs - if tvm.target.Target(target).attrs["supports_cooperative_matrix_nv"]: + if tgt_attrs.get("supports_cooperative_matrix_nv"): f = tvm.build(sch.mod, target=target) dev = tvm.device(target, 0) From dc3cc69920bd63d76ee00d3d3f70e7cf93bcc117 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 5 May 2023 16:58:17 +0900 Subject: [PATCH 19/21] typo --- include/tvm/tir/builtin.h | 4 ++-- python/tvm/tir/op.py | 4 ++-- src/runtime/thread_storage_scope.h | 2 +- src/target/spirv/codegen_spirv.cc | 2 +- src/target/spirv/ir_builder.h | 4 ++-- src/target/spirv/spirv_support.h | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index aa7d8a9eb6ba..70526bde6639 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -797,7 +797,7 @@ TVM_DLL const Op& 
start_profile_intrinsic(); */ TVM_DLL const Op& end_profile_intrinsic(); -// Intrinsics for the VK_NV_cooperative_matrix Vulkan extention. +// Intrinsics for the VK_NV_cooperative_matrix Vulkan extension. /*! \brief The intrinsic corresponding to the OpCooperativeMatrixLoadNV instruction. */ TVM_DLL const Op& cooperative_matrix_load_NV(); @@ -808,7 +808,7 @@ TVM_DLL const Op& cooperative_matrix_store_NV(); /*! * \brief Create a new cooperative matrix filled with the provided value. * - * There is no such instruction in the extention, but it is added for convenience. + * There is no such instruction in the extension, but it is added for convenience. */ TVM_DLL const Op& cooperative_matrix_fill_NV(); diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 6cb42cca5a01..7b9324adecb3 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3037,7 +3037,7 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) -# Intrinsics for the VK_NV_cooperative_matrix Vulkan extention. +# Intrinsics for the VK_NV_cooperative_matrix Vulkan extension. def cooperative_matrix_load_NV(buffer_mat, offset, src, rows, cols, stride, column_major): @@ -3117,7 +3117,7 @@ def cooperative_matrix_store_NV(dst, buffer_mat, offset, stride, column_major): def cooperative_matrix_fill_NV(buffer_mat, offset, rows, cols, value): """Create a new cooperative matrix filled with the provided value. - There is no such instruction in the extention, but it is added for convenience. + There is no such instruction in the extension, but it is added for convenience. Parameters ---------- diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index bd17ddf82e4d..cc40ab7592b0 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -66,7 +66,7 @@ enum class StorageRank { kAMXTMM = 8, /*! * \brief Scope representing the cooperative matrix in the Vulkan - * VK_NV_cooperative_matrix extention. + * VK_NV_cooperative_matrix extension. */ kCooperativeMatrixNV = 9, }; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index ce313336a2d1..4896b9b0e8eb 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -648,7 +648,7 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // Each matrix is identified by an element offset into the buffer. This encoding lets us bridge // the gap between a typical TIR matmul schedule (which uses // `cache_read(..., "cooperative_matrix_nv")` from arbitrary-sized shared memory) and - // the fixed-size matrices as expected by the extention. + // the fixed-size matrices as expected by the extension. std::unordered_map> accum_matrices; if (op->kind == ForKind::kSerial) { diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index b33e02371200..69934d20a7f7 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -593,8 +593,8 @@ class IRBuilder { Value GE(Value a, Value b); Value Select(Value cond, Value a, Value b); - // The VK_NV_cooperative_matrix extention related, see the documentation for the - // SPIRV SPV_NV_cooperative_matrix extention for details. + // The VK_NV_cooperative_matrix extension related, see the documentation for the + // SPIRV SPV_NV_cooperative_matrix extension for details. 
SType GetCooperativeMatrixNVType(const SType& elem_ty, int rows, int cols); Value CallCooperativeMatrixLoadNV(const SType& mat_type, Value src, Value stride, Value column_major); diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h index 5aef7d1344ff..f670b1ae0833 100644 --- a/src/target/spirv/spirv_support.h +++ b/src/target/spirv/spirv_support.h @@ -277,7 +277,7 @@ struct SPIRVSupport { */ bool supports_integer_dot_product{false}; - /*! \brief Whether the driver supports VK_NV_cooperative_matrix extention. */ + /*! \brief Whether the driver supports VK_NV_cooperative_matrix extension. */ bool supports_cooperative_matrix_nv{false}; }; From d40da59b8a30b201cf9c20c831de73a37218fe35 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 6 May 2023 16:36:23 +0900 Subject: [PATCH 20/21] move intrin definitions in test out side of top level --- .../unittest/test_target_codegen_vulkan.py | 472 +++++++++--------- 1 file changed, 234 insertions(+), 238 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 1d99aaf6a5ee..fceff2079438 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -602,271 +602,267 @@ def func(A: T.Buffer((N, 2), "int32")): np.testing.assert_array_equal(a[:, 1], (np.arange(N) - offset) % divisor) -@T.prim_func -def cooperative_matrix_load_desc(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float16", align=64, offset_factor=8, scope="shared") - C = T.match_buffer( - c, (16, 16), "float16", align=64, offset_factor=8, scope="cooperative_matrix_nv" - ) - - with T.block("root"): - T.reads(A[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - for i, j in T.grid(16, 16): - with T.block("load"): - vii, vjj = T.axis.remap("SS", [i, j]) - C[vii, vjj] = A[vii, vjj] - - -def get_load_impl(column_major): +@pytest.mark.parametrize("out_dtype", ["float32", "float16"]) +def test_cooperative_matrix_nv(out_dtype): @T.prim_func - def cooperative_matrix_load_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - A = T.match_buffer( - a, - (16, 16), - "float16", - align=64, - offset_factor=8, - scope="shared", - strides=[s1, s0], - ) + def cooperative_matrix_load_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=64, offset_factor=8, scope="shared") C = T.match_buffer( - c, - (16, 16), - "float16", - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", + c, (16, 16), "float16", align=64, offset_factor=8, scope="cooperative_matrix_nv" ) - with T.block("root"): - T.reads(A[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - T.evaluate( - T.cooperative_matrix_load_NV( - C.data, - C.elem_offset, - A.access_ptr("r"), - 16, - 16, - s1, - column_major, - dtype="handle", - ) - ) - - return cooperative_matrix_load_impl - - -def get_store_desc(out_dtype="float32", out_scope="global"): - @T.prim_func - def cooperative_matrix_store_desc(a: T.handle, c: T.handle) -> None: - A = T.match_buffer( - a, - (16, 16), - out_dtype, - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) - C = T.match_buffer(c, (16, 16), out_dtype, align=64, offset_factor=8, scope=out_scope) with T.block("root"): T.reads(A[0:16, 0:16]) T.writes(C[0:16, 0:16]) for i, j in T.grid(16, 16): - with T.block("store"): + with T.block("load"): vii, vjj = T.axis.remap("SS", [i, j]) C[vii, vjj] = A[vii, vjj] - 
return cooperative_matrix_store_desc - - -def get_store_impl(out_dtype="float32", out_scope="global"): - @T.prim_func - def cooperative_matrix_store_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - A = T.match_buffer( - a, - (16, 16), - out_dtype, - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) - C = T.match_buffer( - c, - (16, 16), - out_dtype, - align=64, - offset_factor=8, - scope=out_scope, - strides=[s1, s0], - ) - - with T.block("root"): - T.reads(A[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - T.evaluate( - T.cooperative_matrix_store_NV( - C.access_ptr("w"), A.data, A.elem_offset, s1, False, dtype="handle" - ) + def get_load_impl(column_major): + @T.prim_func + def cooperative_matrix_load_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="shared", + strides=[s1, s0], + ) + C = T.match_buffer( + c, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", ) - return cooperative_matrix_store_impl - - -def get_fill_desc(out_dtype="float32"): - zero = IntImm("int32", 0).astype(out_dtype) - - @T.prim_func - def cooperative_matrix_fill_desc(c: T.handle) -> None: - C = T.match_buffer( - c, - (16, 16), - out_dtype, - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) - - with T.block("root"): - T.reads() - T.writes(C[0:16, 0:16]) - for i, j in T.grid(16, 16): - with T.block("init"): - vii, vjj = T.axis.remap("SS", [i, j]) - C[vii, vjj] = zero - - return cooperative_matrix_fill_desc - - -def get_fill_impl(out_dtype="float32"): - zero = IntImm("int32", 0).astype(out_dtype) - - @T.prim_func - def cooperative_matrix_fill_impl(c: T.handle) -> None: - C = T.match_buffer( - c, - (16, 16), - out_dtype, - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_load_NV( + C.data, + C.elem_offset, + A.access_ptr("r"), + 16, + 16, + s1, + column_major, + dtype="handle", + ) + ) - with T.block("root"): - T.reads() - T.writes(C[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - T.evaluate( - T.cooperative_matrix_fill_NV(C.data, C.elem_offset, 16, 16, zero, dtype="handle") + return cooperative_matrix_load_impl + + def get_store_desc(out_dtype="float32", out_scope="global"): + @T.prim_func + def cooperative_matrix_store_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer(c, (16, 16), out_dtype, align=64, offset_factor=8, scope=out_scope) + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j in T.grid(16, 16): + with T.block("store"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + return cooperative_matrix_store_desc + + def get_store_impl(out_dtype="float32", out_scope="global"): + @T.prim_func + def cooperative_matrix_store_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope=out_scope, + strides=[s1, 
s0], ) - return cooperative_matrix_fill_impl - + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_store_NV( + C.access_ptr("w"), A.data, A.elem_offset, s1, False, dtype="handle" + ) + ) -def get_mad_desc(out_dtype="float32"): - def maybe_cast(v): - if out_dtype in ["float32", "int32"]: - return Cast(out_dtype, v) - return v + return cooperative_matrix_store_impl - @T.prim_func - def cooperative_matrix_mad_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer( - a, - (16, 16), - "float16", - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) - B = T.match_buffer( - b, - (16, 16), - "float16", - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) - C = T.match_buffer( - c, - (16, 16), - out_dtype, - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) + def get_fill_desc(out_dtype="float32"): + zero = IntImm("int32", 0).astype(out_dtype) - with T.block("root"): - T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - for i, j, k in T.grid(16, 16, 16): - with T.block("update"): - vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) - C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast(B[vkk, vjj]) + @T.prim_func + def cooperative_matrix_fill_desc(c: T.handle) -> None: + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) - return cooperative_matrix_mad_desc + with T.block("root"): + T.reads() + T.writes(C[0:16, 0:16]) + for i, j in T.grid(16, 16): + with T.block("init"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = zero + + return cooperative_matrix_fill_desc + + def get_fill_impl(out_dtype="float32"): + zero = IntImm("int32", 0).astype(out_dtype) + + @T.prim_func + def cooperative_matrix_fill_impl(c: T.handle) -> None: + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + with T.block("root"): + T.reads() + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_fill_NV( + C.data, C.elem_offset, 16, 16, zero, dtype="handle" + ) + ) -def get_mad_impl(out_dtype="float32"): - @T.prim_func - def cooperative_matrix_mad_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer( - a, - (16, 16), - "float16", - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) - B = T.match_buffer( - b, - (16, 16), - "float16", - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) - C = T.match_buffer( - c, - (16, 16), - out_dtype, - align=64, - offset_factor=8, - scope="cooperative_matrix_nv", - ) + return cooperative_matrix_fill_impl + + def get_mad_desc(out_dtype="float32"): + def maybe_cast(v): + if out_dtype in ["float32", "int32"]: + return Cast(out_dtype, v) + return v + + @T.prim_func + def cooperative_matrix_mad_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + B = T.match_buffer( + b, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) - with T.block("root"): - T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - tx = 
T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - T.evaluate( - T.cooperative_matrix_mad_NV( - A.data, - A.elem_offset, - B.data, - B.elem_offset, - C.data, - C.elem_offset, - dtype="handle", - ) + with T.block("root"): + T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast( + B[vkk, vjj] + ) + + return cooperative_matrix_mad_desc + + def get_mad_impl(out_dtype="float32"): + @T.prim_func + def cooperative_matrix_mad_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + B = T.match_buffer( + b, + (16, 16), + "float16", + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", + ) + C = T.match_buffer( + c, + (16, 16), + out_dtype, + align=64, + offset_factor=8, + scope="cooperative_matrix_nv", ) - return cooperative_matrix_mad_impl + with T.block("root"): + T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, 32) + T.evaluate( + T.cooperative_matrix_mad_NV( + A.data, + A.elem_offset, + B.data, + B.elem_offset, + C.data, + C.elem_offset, + dtype="handle", + ) + ) + return cooperative_matrix_mad_impl -@pytest.mark.parametrize("out_dtype", ["float32", "float16"]) -def test_cooperative_matrix_nv(out_dtype): STORE_INTRIN = "cooperative_matrix_store_{}".format(out_dtype) FILL_INTRIN = "cooperative_matrix_fill_{}".format(out_dtype) MAD_INTRIN = "cooperative_matrix_mad_{}".format(out_dtype) From 503ac349438af0abd33e943b047606968f9549e2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 6 May 2023 16:38:16 +0900 Subject: [PATCH 21/21] clarify that the stride is in the number of elements --- python/tvm/tir/op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 7b9324adecb3..f0523dbe78b8 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3061,7 +3061,7 @@ def cooperative_matrix_load_NV(buffer_mat, offset, src, rows, cols, stride, colu The number of columns in the matrix. stride: Expr - The stride of the matrix. + The stride of the matrix in the number of elements. column_major: bool Whether the matrix elements are stored in the column-major order. @@ -3099,7 +3099,7 @@ def cooperative_matrix_store_NV(dst, buffer_mat, offset, stride, column_major): The element offset for the matrix to be stored in the destination buffer. stride: Expr - The stride of the matrix. + The stride of the matrix in the number of elements. column_major: bool Whether the matrix elements are stored in the column-major order.
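
For reference, the calling convention that this series converges on looks roughly as follows inside
a tensor-intrinsic implementation body. This is an illustrative sketch, not code from the patches:
the buffer names (A_mat, B_mat, C_mat, A_shared, C_global) and the `stride` variable are assumed,
while the argument order, the "cooperative_matrix_nv" scope requirement, and the mapping to SPIR-V
instructions follow the patches above.

    # Inside the root block of a TensorIntrin implementation, after
    #   tx = T.env_thread("threadIdx.x"); T.launch_thread(tx, 32)
    # A_mat/B_mat hold fp16 tiles and C_mat holds the accumulator tile, all buffers declared with
    # scope="cooperative_matrix_nv". `stride` is in elements (patch 21), False means row-major.
    T.evaluate(
        T.cooperative_matrix_load_NV(  # -> OpCooperativeMatrixLoadNV
            A_mat.data, A_mat.elem_offset, A_shared.access_ptr("r"), 16, 16, stride, False,
            dtype="handle",
        )
    )
    # (B_mat is loaded the same way from its shared-memory staging buffer.)
    T.evaluate(
        T.cooperative_matrix_fill_NV(  # -> OpCompositeConstruct on the cooperative matrix type
            C_mat.data, C_mat.elem_offset, 16, 16, T.float32(0), dtype="handle"
        )
    )
    T.evaluate(
        T.cooperative_matrix_mad_NV(  # -> OpCooperativeMatrixMulAddNV
            A_mat.data, A_mat.elem_offset, B_mat.data, B_mat.elem_offset,
            C_mat.data, C_mat.elem_offset, dtype="handle",
        )
    )
    T.evaluate(
        T.cooperative_matrix_store_NV(  # -> OpCooperativeMatrixStoreNV
            C_global.access_ptr("w"), C_mat.data, C_mat.elem_offset, stride, False,
            dtype="handle",
        )
    )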