Skip to content

[Bug] FuseTIR does not handle arguments properly #16045

@cyx-6

Description

@cyx-6

This bug comes from mlc-llm, where we are trying to fuse the decode and take kernels. Things work well after we run FuseOpsByPattern: the two kernels are gathered into one R.function:

@R.function(private=True)
def fused_fused_decode1_take1(lv744: R.Tensor(("vocab_size", 512), dtype="uint32"), lv745: R.Tensor(("vocab_size", 128), dtype="float16"), input_ids1: R.Tensor((1,), dtype="int32")) -> R.Tensor((1, 4096), dtype="float16"):
  vocab_size = T.int64()
  R.func_attr({"Composite": "decode_take", "Primitive": 1})
  cls = Module
  with R.dataflow():
    lv746 = R.call_tir(cls.fused_decode1, (lv744, lv745), out_sinfo=R.Tensor((vocab_size, 4096), dtype="float16"))
    gv = R.call_tir(cls.take1, (lv746, input_ids1), out_sinfo=R.Tensor((1, 4096), dtype="float16"))
    R.output(gv)
  return gv

But after we run FuseTIR, a stand-alone B = T.Buffer(...) remains in the T_take block:

@T.prim_func(private=True)
def fused_fused_decode1_take1(p_lv744: T.handle, p_lv745: T.handle, input_ids1: T.Buffer((T.int64(1),), "int32"), T_take_handle_intermediate: T.Buffer((T.int64(1), T.int64(4096)), "float16")):
  T.func_attr({"tir.noalias": T.bool(True)})
  vocab_size = T.int64()
  lv744 = T.match_buffer(p_lv744, (vocab_size, T.int64(512)), "uint32")
  lv745 = T.match_buffer(p_lv745, (vocab_size, T.int64(128)), "float16")
  # with T.block("root"):
  p_output0_intermediate = T.alloc_buffer((vocab_size, T.int64(4096)), "float16")
  for i, j in T.grid(vocab_size, T.int64(4096)):
    with T.block("decode"):
      v_i, v_j = T.axis.remap("SS", [i, j])
      T.reads(lv744[v_i, v_j // T.int64(8)], lv745[v_i, v_j // T.int64(32)])
      T.writes(p_output0_intermediate[v_i, v_j])
      p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv744[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv745[v_i, v_j // T.int64(32)]
      for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
        with T.block("T_take"):
          v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])

          B = T.Buffer((T.int64(1),), "int32")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Here

          T.reads(p_output0_intermediate[B[v_ax0], v_ax1], input_ids1[v_ax0])
          T.writes(T_take_handle_intermediate[v_ax0, v_ax1])
          T_take_handle_intermediate[v_ax0, v_ax1] = p_output0_intermediate[input_ids1[v_ax0], v_ax1]

The stand-alone B is the original argument of the take1 prim function.
Reproduction script:
https://gist.github.com/cyx-6/0e4facbdd603436f94104d0eb039c67f#file-script-py

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address it; type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions