From 7a8a9f83d7135f21666e5ca9d06900ba048a8775 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Mon, 6 May 2024 09:06:13 +0800 Subject: [PATCH 1/6] [WebGPU] Support `__dp4a(int8x4, int8x4)` as a pure extern method This patch adds support for `__dp4a(int8x4, int8x4)` as a pure extern method of the WebGPU target. In the generated WGSL shader, `int8x4` will be translated into `u32`, and `__dp4a(int8x4, int8x4)` will be translated into the WGSL built-in function `dot4I8Packed(u32, u32)`. Here is an example of using `__dp4a` with the WebGPU target: ``` n = te.var("n") A = te.placeholder((n,), "int8x4", name="A") B = te.placeholder((n,), "int8x4", name="B") C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]), name="C") s = te.create_schedule(C.op) bx, tx = s[C].split(C.op.axis[0], factor=64) s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest") ``` Issue: #16627 --- src/target/source/codegen_webgpu.cc | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index a9a23fb999d8..d14123d0dc4c 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -298,6 +298,11 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (lanes != 1) { ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; + + if (t.is_int() && t.bits() == 8 && lanes == 4) { + os << "u32"; + return; + } os << "vec" << lanes << "<"; } @@ -405,6 +410,15 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN this->EndScope(else_scope); } os << result; + } else if (op->op.same_as(builtin::call_pure_extern())) { + ICHECK_GE(op->args.size(), 1U); + const std::string& func_name = op->args[0].as()->value; + if (func_name == "__dp4a") { + os << "dot4I8Packed(" << PrintExpr(op->args[1]) << ", " 
<< PrintExpr(op->args[2]) << ")"; + } else { + LOG(FATAL) << "WGSL shader cannot make extern calls. Graph contains extern \"" + << Downcast(op->args[0]) << "\""; + } } else { CodeGenC::VisitExpr_(op, os); } From 868b7209f9bea1aff0e81d93c1be792f302e2113 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 8 May 2024 16:11:47 +0800 Subject: [PATCH 2/6] Add validation --- src/target/source/codegen_webgpu.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index d14123d0dc4c..ea9f8274f62a 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -414,7 +414,11 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN ICHECK_GE(op->args.size(), 1U); const std::string& func_name = op->args[0].as()->value; if (func_name == "__dp4a") { - os << "dot4I8Packed(" << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")"; + if (op->args.size() != 3) { + LOG(FATAL) << "__dp4a can only accept 2 parameters (now: " << op->args.size() - 1 << ")"; + } else { + os << "dot4I8Packed(" << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")"; + } } else { LOG(FATAL) << "WGSL shader cannot make extern calls. 
Graph contains extern \"" << Downcast(op->args[0]) << "\""; From 7743caf19b5a2701a08656e3be7cea1b7959dbcb Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Thu, 9 May 2024 16:46:43 +0800 Subject: [PATCH 3/6] Add `dot4I8Packed` to WebGPU lower intrinsic --- src/target/source/codegen_webgpu.cc | 13 ------------- src/target/source/intrin_rule_webgpu.cc | 2 ++ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index ea9f8274f62a..b6a6fe716ea6 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -410,19 +410,6 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN this->EndScope(else_scope); } os << result; - } else if (op->op.same_as(builtin::call_pure_extern())) { - ICHECK_GE(op->args.size(), 1U); - const std::string& func_name = op->args[0].as()->value; - if (func_name == "__dp4a") { - if (op->args.size() != 3) { - LOG(FATAL) << "__dp4a can only accept 2 parameters (now: " << op->args.size() - 1 << ")"; - } else { - os << "dot4I8Packed(" << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")"; - } - } else { - LOG(FATAL) << "WGSL shader cannot make extern calls. 
Graph contains extern \"" - << Downcast(op->args[0]) << "\""; - } } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc index f3e561f71477..bed453d00701 100644 --- a/src/target/source/intrin_rule_webgpu.cc +++ b/src/target/source/intrin_rule_webgpu.cc @@ -113,6 +113,8 @@ TVM_REGISTER_OP("tir.trunc") // extra dispatch TVM_REGISTER_OP("tir.erf").set_attr("webgpu.FLowerIntrinsic", DispatchFastErf); +TVM_REGISTER_OP("tir.dot4I8Packed").set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + } // namespace intrin } // namespace codegen } // namespace tvm From e736abc1986e8ab485e76a736293650656dee134 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 2 Jul 2024 10:44:54 +0800 Subject: [PATCH 4/6] Implement builtin `dp4a` with `dot4I8Packed` --- src/target/source/codegen_webgpu.cc | 8 ++++++++ src/target/source/intrin_rule_webgpu.cc | 2 -- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index cf0e06faabf3..f0d8e2bc4908 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -410,6 +410,14 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN this->EndScope(else_scope); } os << result; + } else if (op->op.same_as(builtin::dp4a())) { + // generate `dot4I8Packed(vec1, vec2) + acc` for the builtin `dp4a` + os << "dot4I8Packed("; + this->PrintExpr(op->args[0], os); + os << ", "; + this->PrintExpr(op->args[1], os); + os << ") + "; + this->PrintExpr(op->args[2], os); } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc index bed453d00701..f3e561f71477 100644 --- a/src/target/source/intrin_rule_webgpu.cc +++ b/src/target/source/intrin_rule_webgpu.cc @@ -113,8 +113,6 @@ TVM_REGISTER_OP("tir.trunc") // extra dispatch 
TVM_REGISTER_OP("tir.erf").set_attr("webgpu.FLowerIntrinsic", DispatchFastErf); -TVM_REGISTER_OP("tir.dot4I8Packed").set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); - } // namespace intrin } // namespace codegen } // namespace tvm From dd3b6d5527d44d4e9c5a6771bbd6ad8f5a8a2fc9 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 2 Jul 2024 11:24:21 +0800 Subject: [PATCH 5/6] Small fix --- src/target/source/codegen_webgpu.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index f0d8e2bc4908..4d7782a3af4c 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -298,7 +298,6 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (lanes != 1) { ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; - if (t.is_int() && t.bits() == 8 && lanes == 4) { os << "u32"; return; From 056e18dbef9d9a2c083450d7c6c25e1a12e61f5b Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 2 Jul 2024 11:25:12 +0800 Subject: [PATCH 6/6] Add missing comment --- src/target/source/codegen_webgpu.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 4d7782a3af4c..b76b05470d5d 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -298,6 +298,7 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (lanes != 1) { ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; + // Currently WebGPU doesn't support `i8` and an `int8x4` is represented as a `u32`. if (t.is_int() && t.bits() == 8 && lanes == 4) { os << "u32"; return;