-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[WebGPU] Implement tir.dp4a with WGSL built-in function dot4I8Packed
#16976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch adds the support of `__dp4a(int8x4, int8x4)` as a pure
extern method of WebGPU target. In the generated WGSL shader,
`int8x4` will be translated into `u32`, and `__dp4a(int8x4, int8x4)`
will be translated into the WGSL built-in function
`dot4I8Packed(u32, u32)`.
Here is an example to use `__dp4a` in WebGPU target:
```
n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]), name="C")
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```
Issue: apache#16627
__dp4a(int8x4, int8x4) as a pure extern methoddot4I8Packed(int8x4, int8x4) as a pure extern method
| // extra dispatch | ||
| TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf); | ||
|
|
||
| TVM_REGISTER_OP("tir.dot4I8Packed").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry i was not being clear, for tir, it is better to have a common name dp4a (as this intrinsic shared across backends)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can add tir.dp4a intrinsic, and use it to lower to various places
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Jiawei-Shao do u mind followup
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh sorry I was busy on some other urgent stuffs these days. I will go back to work on this next week. I will follow the steps to add tir.dp4a first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @tqchen,
Sorry for my late response. I've updated this PR. PTAL, thanks!
dot4I8Packed(int8x4, int8x4) as a pure extern methodbuiltin::dp4a with WGSL built-in function dot4I8Packed
builtin::dp4a with WGSL built-in function dot4I8Packedtir.dp4a with WGSL built-in function dot4I8Packed
|
@tqchen Now the PR has passed all the tests. PTAL, thanks! |
This patch implements
tir.dp4awith WGSL built-in functiondot4I8Packed()on WebGPU backend.Here is an example to use
tir.dp4ain WebGPU target:Issue: #16627