
Conversation

@Jiawei-Shao (Contributor) commented May 8, 2024

This patch implements tir.dp4a with the WGSL built-in function dot4I8Packed() on the WebGPU backend.

Here is an example of using tir.dp4a with the WebGPU target:

```
n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(A.shape, lambda i: tvm.tir.dp4a(A[i], B[i]), name="C")
s = te.create_schedule(C.op)

bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```
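
For context, `tgt` above is assumed to be a WebGPU target. A hypothetical way to build and confirm that the dp4a call was lowered to dot4I8Packed (assuming the imported device module exposes its WGSL through `get_source()`):

```
# Hypothetical follow-up: build for WebGPU and check the generated WGSL.
tgt = tvm.target.Target("webgpu", host="llvm")   # assumed WebGPU target string
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
wgsl_src = mod.imported_modules[0].get_source()  # device-side (WGSL) module
assert "dot4I8Packed" in wgsl_src                # the dp4a call should lower to this built-in
```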

Issue: #16627

This patch adds support for `__dp4a(int8x4, int8x4)` as a pure extern
method of the WebGPU target. In the generated WGSL shader,
`int8x4` is translated into `u32`, and `__dp4a(int8x4, int8x4)`
is translated into the WGSL built-in function
`dot4I8Packed(u32, u32)`.

Here is an example of using `__dp4a` with the WebGPU target:

```
n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]), name="C")
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```
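
For reference, `dot4I8Packed` interprets each `u32` operand as four packed signed 8-bit lanes and returns their dot product as a signed 32-bit integer. A hypothetical pure-Python model of that semantics (numpy and the helper name `dot4_i8_packed` are only for illustration):

```
import numpy as np

def dot4_i8_packed(a: int, b: int) -> int:
    # Reinterpret each 32-bit word as four signed 8-bit lanes (low byte = lane 0).
    a_lanes = np.frombuffer(np.uint32(a).tobytes(), dtype=np.int8).astype(np.int32)
    b_lanes = np.frombuffer(np.uint32(b).tobytes(), dtype=np.int8).astype(np.int32)
    return int(np.dot(a_lanes, b_lanes))

# 1*5 + 2*6 + 3*7 + 4*8 == 70
assert dot4_i8_packed(0x04030201, 0x08070605) == 70
```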

Issue: apache#16627
@Jiawei-Shao changed the title from "[WebGPU] Support __dp4a(int8x4, int8x4) as a pure extern method" to "[WebGPU] Support dot4I8Packed(int8x4, int8x4) as a pure extern method" on May 9, 2024

```
// extra dispatch
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);

TVM_REGISTER_OP("tir.dot4I8Packed").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>);
```
@tqchen (Member) commented May 9, 2024

Sorry, I was not being clear: for TIR it is better to use a common name, dp4a, since this intrinsic is shared across backends.

@Jiawei-Shao (Contributor, Author)

I find there is no tir.dp4a in TVM right now; dp4a is currently always invoked through call_pure_extern() (see the Vulkan and CUDA code paths).

Do you mean we should add tir.dp4a to TVM, or keep supporting dp4a as a pure extern call, the way dp4a is handled in codegen_spirv.cc?
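
For reference, the two options under discussion look roughly like this at the TE level (a sketch only; `tir.dp4a` did not exist in TVM when this comment was written):

```
# Option 1: lower dp4a as a pure extern call, as the CUDA/Vulkan backends currently do.
C = te.compute(A.shape,
               lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]),
               name="C")

# Option 2: a dedicated tir.dp4a intrinsic that each backend lowers itself.
C = te.compute(A.shape, lambda i: tvm.tir.dp4a(A[i], B[i]), name="C")
```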

@tqchen (Member) commented May 10, 2024

We can add a tir.dp4a intrinsic and use it to lower to the various backends.

@tqchen (Member)

@Jiawei-Shao do you mind following up?

@Jiawei-Shao (Contributor, Author)

Sorry, I have been busy with some other urgent work these days. I will get back to this next week and will start by adding tir.dp4a.

@Jiawei-Shao (Contributor, Author)

Hi @tqchen,

Sorry for my late response. I've updated this PR. PTAL, thanks!

@Jiawei-Shao changed the title from "[WebGPU] Support dot4I8Packed(int8x4, int8x4) as a pure extern method" to "[WebGPU] Implement builtin::dp4a with WGSL built-in function dot4I8Packed" on Jul 2, 2024
@Jiawei-Shao changed the title from "[WebGPU] Implement builtin::dp4a with WGSL built-in function dot4I8Packed" to "[WebGPU] Implement tir.dp4a with WGSL built-in function dot4I8Packed" on Jul 2, 2024
@Jiawei-Shao (Contributor, Author)

@tqchen Now the PR has passed all the tests. PTAL, thanks!
