
[Bug] [MetaScheduler] [TensorIR] Stride values are not inferred within _impl() block. #15522

@cbalint13

Description


Hi folks!

I am facing the following issue when using .access_ptr() within auto-tensorization with a custom ISA: the backend fails to infer the .strides values from the ancestor buffer that the accessed buffer is derived from, i.e. the buffer through which .access_ptr() views the data.
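For context, here is a minimal pure-Python sketch (no TVM involved) of why the strides matter. The tile handed to the intrinsic is a view into a larger parent buffer, so its row stride is the parent's row length, not the tile's own width. The helpers element_offset and row_major_strides, and the shapes used, are hypothetical and chosen only for illustration. Strides and offsets are counted in elements.

```python
def element_offset(elem_offset, strides, index):
    """Flat element offset of `index` in a strided view: elem_offset + sum(i * s)."""
    return elem_offset + sum(i * s for i, s in zip(index, strides))


def row_major_strides(shape):
    """Strides (in elements) of a compact row-major buffer with the given shape."""
    strides = []
    acc = 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return list(reversed(strides))


# Hypothetical parent buffer of shape (64, 32); the intrinsic only sees a 4x8 tile.
parent_strides = row_major_strides((64, 32))  # [32, 1]

# The tile starts at parent row 8, column 16 and inherits the parent's strides.
tile_elem_offset = element_offset(0, parent_strides, (8, 16))  # 8*32 + 16 = 272

# Tile element (1, 0) sits one *parent* row further on, not 8 elements further on.
print(element_offset(tile_elem_offset, parent_strides, (1, 0)))             # 304
print(element_offset(tile_elem_offset, row_major_strides((4, 8)), (1, 0)))  # 280
```

The base pointer (what .access_ptr() provides) is the same in both cases; only the stride tells the callee where the next row starts.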

The tvm-ms-testcase.py.gz script is attached.
Tested on the main branch as of 20230804, at git hash 6085534.


Here are the declarations for the block description and implementation:

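from tvm.script import tir as T
from tvm.tir import TensorIntrin

# INT8_MACS, INT32_LANES and VEC_MAC_impl() are assumed to be defined elsewhere
# (e.g. in the attached tvm-ms-testcase.py script).

@T.prim_func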
def vec_u8_i8_s32_desc(
    A: T.Buffer((INT8_MACS,), "uint8", offset_factor=1, align=INT8_MACS, scope="global"),
    B: T.Buffer((INT32_LANES, INT8_MACS), "int8", offset_factor=1, align=INT8_MACS, scope="global"),
    C: T.Buffer((INT32_LANES,), "int32", offset_factor=1, align=INT8_MACS, scope="global"),
) -> None:
    with T.block("root"):
        T.reads(C[0:INT32_LANES], A[0:INT8_MACS], B[0:INT32_LANES, 0:INT8_MACS])
        T.writes(C[0:INT32_LANES])
        for i in T.serial(0, INT32_LANES):
            for k in T.serial(0, INT8_MACS):
                with T.block("update"):
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")


@T.prim_func
def vec_u8_i8_s32_impl(
    A: T.Buffer((INT8_MACS,), "uint8", offset_factor=1, align=INT8_MACS, scope="global"),
    B: T.Buffer((INT32_LANES, INT8_MACS), "int8", offset_factor=1, strides=[T.int32(), T.int32()], scope="global"),
    C: T.Buffer((INT32_LANES,), "int32", offset_factor=1, align=INT32_LANES, scope="global"),
) -> None:
    with T.block("root"):
        T.block_attr({"pragma_import_llvm": VEC_MAC_impl()})
        T.reads(C[0:INT32_LANES], A[0:INT8_MACS], B[0:INT32_LANES, 0:INT8_MACS])
        T.writes(C[0:INT32_LANES])
        with T.block("update"):
            T.call_extern(
                "VEC_MACC",
                C.access_ptr("w"),
                A.access_ptr("r"),
                B.access_ptr("r"),
                B.strides[0],   # BUG?! Does not work: no tensorized schedule is produced.
                # 8,            # MANUAL (cbalint): a hard-coded stride works.
                dtype="int32")

VEC_MACC_INTRIN = "vec_macc"

TensorIntrin.register(
    VEC_MACC_INTRIN, vec_u8_i8_s32_desc, vec_u8_i8_s32_impl
)

Actual results:

In the context of T.call_extern("VEC_MACC", {...}):

  • Using B.strides[0] produces no tensorized schedules at all; every candidate fails.
  • Using the dummy hard-coded value 8 (the # MANUAL (cbalint) line) works just fine.
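A hard-coded stride is only correct when the parent buffer's row stride happens to equal it. The pure-Python sketch below models what VEC_MACC is presumably expected to compute, following the vec_u8_i8_s32_desc loop nest, and compares the true stride with a hard-coded 8 for a hypothetical 16-column parent. The argument order (C, A, B, B row stride) mirrors the T.call_extern call above. The function names vec_macc_ref and desc_ref, the flat-list "pointers", and all sizes are assumptions for illustration, not the real ISA routine.

```python
def vec_macc_ref(c, c_off, a, a_off, b, b_off, b_stride0, lanes, macs):
    """Model of C[i] += sum_k A[k] * B[i, k] over flat arrays.

    The *_off arguments stand in for access_ptr() results (element offsets into
    flat storage); b_stride0 stands in for B.strides[0], the row stride of B.
    """
    for i in range(lanes):
        acc = c[c_off + i]
        for k in range(macs):
            acc += a[a_off + k] * b[b_off + i * b_stride0 + k]
        c[c_off + i] = acc


def desc_ref(c, a, b_rows):
    """The same computation written against a 2-D B, like the desc block."""
    for i, row in enumerate(b_rows):
        c[i] += sum(a[k] * row[k] for k in range(len(a)))


# Hypothetical sizes: 4 lanes, 8 MACs, B tile taken from a parent with 16 columns.
LANES, MACS, PARENT_COLS = 4, 8, 16
a = [k + 1 for k in range(MACS)]                 # stands in for uint8 data
parent_b = [(r * PARENT_COLS + col) % 7 - 3      # stands in for int8 data
            for r in range(LANES) for col in range(PARENT_COLS)]
tile_rows = [parent_b[r * PARENT_COLS:r * PARENT_COLS + MACS] for r in range(LANES)]

expected = [0] * LANES
desc_ref(expected, a, tile_rows)

good = [0] * LANES
vec_macc_ref(good, 0, a, 0, parent_b, 0, PARENT_COLS, LANES, MACS)  # true stride
bad = [0] * LANES
vec_macc_ref(bad, 0, a, 0, parent_b, 0, MACS, LANES, MACS)          # hard-coded 8

print(good == expected, bad == expected)  # → True False
```

With the hard-coded 8, rows 1 to 3 are read from the wrong part of the parent, which is why the real stride has to reach the extern call.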

In contrast, here are two examples from the repo that apparently do infer strides this way:

strides=[s0, s1],

strides=[s0, s1],

  • Do these work because of some context difference between MultiLevelTilingWithIntrin and MultiLevelTilingTensorCore?

Desired results:

The stride values should be properly inferred in this case.

  • There is no way to use .access_ptr correctly without knowing the buffer's strides.
  • The application needs .access_ptr to pass the pointer value (plus the stride) to the special ISA instruction.

Thank you!
~Cristian.


  • tune:meta_schedule

Metadata


Assignees

No one assigned

    Labels

    needs-triage, type: bug
