-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Closed
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
This bug is from mlc-llm, where we are trying to fuse the decode and take kernels. Things work well after we run FuseOpsByPattern, with the two kernels gathered into one R.function:
@R.function(private=True)
def fused_fused_decode1_take1(lv744: R.Tensor(("vocab_size", 512), dtype="uint32"), lv745: R.Tensor(("vocab_size", 128), dtype="float16"), input_ids1: R.Tensor((1,), dtype="int32")) -> R.Tensor((1, 4096), dtype="float16"):
vocab_size = T.int64()
R.func_attr({"Composite": "decode_take", "Primitive": 1})
cls = Module
with R.dataflow():
lv746 = R.call_tir(cls.fused_decode1, (lv744, lv745), out_sinfo=R.Tensor((vocab_size, 4096), dtype="float16"))
gv = R.call_tir(cls.take1, (lv746, input_ids1), out_sinfo=R.Tensor((1, 4096), dtype="float16"))
R.output(gv)
return gv

But after we run FuseTIR, a stand-alone B = T.Buffer(...) remains in the T_take block:
@T.prim_func(private=True)
def fused_fused_decode1_take1(p_lv744: T.handle, p_lv745: T.handle, input_ids1: T.Buffer((T.int64(1),), "int32"), T_take_handle_intermediate: T.Buffer((T.int64(1), T.int64(4096)), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
vocab_size = T.int64()
lv744 = T.match_buffer(p_lv744, (vocab_size, T.int64(512)), "uint32")
lv745 = T.match_buffer(p_lv745, (vocab_size, T.int64(128)), "float16")
# with T.block("root"):
p_output0_intermediate = T.alloc_buffer((vocab_size, T.int64(4096)), "float16")
for i, j in T.grid(vocab_size, T.int64(4096)):
with T.block("decode"):
v_i, v_j = T.axis.remap("SS", [i, j])
T.reads(lv744[v_i, v_j // T.int64(8)], lv745[v_i, v_j // T.int64(32)])
T.writes(p_output0_intermediate[v_i, v_j])
p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv744[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv745[v_i, v_j // T.int64(32)]
for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
with T.block("T_take"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
B = T.Buffer((T.int64(1),), "int32")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Here
T.reads(p_output0_intermediate[B[v_ax0], v_ax1], input_ids1[v_ax0])
T.writes(T_take_handle_intermediate[v_ax0, v_ax1])
T_take_handle_intermediate[v_ax0, v_ax1] = p_output0_intermediate[input_ids1[v_ax0], v_ax1]

And the stand-alone B is the original argument of the take1 prim function.
reproduce script:
https://gist.github.com/cyx-6/0e4facbdd603436f94104d0eb039c67f#file-script-py
Metadata
Metadata
Assignees
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug