Description
Hi folks!
I am facing the following issue when using .access_ptr() within auto-tensorization with a custom ISA: the backend fails to infer the .strides values from the derived ancestor buffer through which .access_ptr() views the data buffer.
The tvm-ms-testcase.py.gz script is attached here.
Using the main branch (20230804) at git hash 6085534.
Here are the declarations for the block description and implementation:
from tvm.script import tir as T
from tvm.tir import TensorIntrin

# INT8_MACS, INT32_LANES and VEC_MAC_impl() are defined in the attached script.

@T.prim_func
def vec_u8_i8_s32_desc(
    A: T.Buffer((INT8_MACS,), "uint8", offset_factor=1, align=INT8_MACS, scope="global"),
    B: T.Buffer((INT32_LANES, INT8_MACS), "int8", offset_factor=1, align=INT8_MACS, scope="global"),
    C: T.Buffer((INT32_LANES,), "int32", offset_factor=1, align=INT8_MACS, scope="global"),
) -> None:
    with T.block("root"):
        T.reads(C[0:INT32_LANES], A[0:INT8_MACS], B[0:INT32_LANES, 0:INT8_MACS])
        T.writes(C[0:INT32_LANES])
        for i in T.serial(0, INT32_LANES):
            for k in T.serial(0, INT8_MACS):
                with T.block("update"):
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
@T.prim_func
def vec_u8_i8_s32_impl(
    A: T.Buffer((INT8_MACS,), "uint8", offset_factor=1, align=INT8_MACS, scope="global"),
    B: T.Buffer((INT32_LANES, INT8_MACS), "int8", offset_factor=1, strides=[T.int32(), T.int32()], scope="global"),
    C: T.Buffer((INT32_LANES,), "int32", offset_factor=1, align=INT32_LANES, scope="global"),
) -> None:
    with T.block("root"):
        T.block_attr({"pragma_import_llvm": VEC_MAC_impl()})
        T.reads(C[0:INT32_LANES], A[0:INT8_MACS], B[0:INT32_LANES, 0:INT8_MACS])
        T.writes(C[0:INT32_LANES])
        with T.block("update"):
            T.call_extern(
                "VEC_MACC",
                C.access_ptr("w"),
                A.access_ptr("r"),
                B.access_ptr("r"),
                B.strides[0],  # <<< BUG?! (does not work)
                # 8,  # MANUAL (cbalint): a hard-coded value works
                dtype="int32",
            )
VEC_MACC_INTRIN = "vec_macc"
TensorIntrin.register(
    VEC_MACC_INTRIN, vec_u8_i8_s32_desc, vec_u8_i8_s32_impl
)
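For checking any VEC_MACC implementation against the block description, here is a minimal pure-Python sketch (no TVM needed) of what vec_u8_i8_s32_desc computes: C[i] += int32(A[k]) * int32(B[i, k]) over a spatial axis i and a reduction axis k. The function name is hypothetical, the lane and MAC counts are taken from the input lengths rather than from the script's INT8_MACS / INT32_LANES, and int32 overflow wraparound is ignored.

```python
# Hypothetical pure-Python reference for vec_u8_i8_s32_desc (no TVM needed).
# Sizes come from the inputs; int32 wraparound is NOT modelled.


def vec_u8_i8_s32_ref(a, b, c):
    """Return C after C[i] += int32(A[k]) * int32(B[i][k]) for all i, k."""
    lanes, macs = len(c), len(a)
    assert all(0 <= x <= 255 for x in a), "A holds uint8 values"
    assert len(b) == lanes and all(len(row) == macs for row in b)
    assert all(-128 <= x <= 127 for row in b for x in row), "B holds int8 values"
    out = list(c)
    for i in range(lanes):        # spatial axis ("S" in T.axis.remap)
        for k in range(macs):     # reduction axis ("R" in T.axis.remap)
            out[i] += a[k] * b[i][k]
    return out


if __name__ == "__main__":
    a = [1, 2, 3, 4]
    b = [[1, -1, 2, 0], [0, 1, 0, 1]]
    c = [10, 20]
    print(vec_u8_i8_s32_ref(a, b, c))  # prints [15, 26]
```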
Actual results:
In the context of T.call_extern("VEC_MACC", ...):
- Using B.strides[0] leads to no tensorized schedules at all; they all fail.
- Using the dummy hard-coded value 8 (the "# MANUAL (cbalint)" line) works just fine.
In contrast, here are two examples from the repo that apparently do perform such stride inference:
- tvm/python/tvm/tir/tensor_intrin/cuda.py, line 175 (at 907b29e): strides=[s0, s1],
- tvm/python/tvm/tir/tensor_intrin/rocm.py, line 240 (at 907b29e): strides=[s0, s1],
- Do these work because of some context difference between MultiLevelTilingWithIntrin and MultiLevelTilingTensorCore?
Desired results:
Properly infer the stride values in such cases.
- There is no way to work with .access_ptr without knowing the stride used to access the buffer.
- The application requires .access_ptr to pass the pointer value (together with the stride) to the special ISA instruction.
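To illustrate why the pointer alone is not enough, here is a minimal pure-Python model, under stated assumptions, of an extern that reads B through a flat pointer plus a row stride. Everything here is hypothetical: mem stands in for the underlying data buffer, b_off for the element offset that B.access_ptr("r") points at, and row_stride for the value B.strides[0] should supply; the inner stride is assumed to be 1. When B is a tile of a wider parent buffer, the row stride is the parent's row width, so a hard-coded 8 only reads the right elements when the actual row stride happens to be 8.

```python
# Hypothetical model of an extern like VEC_MACC that consumes B as
# (pointer, row stride). Names and semantics are assumptions, not TVM APIs.


def vec_macc_model(c, a, mem, b_off, row_stride, lanes, macs):
    """C[i] += A[k] * mem[b_off + i * row_stride + k] (unit inner stride)."""
    out = list(c)
    for i in range(lanes):
        row_base = b_off + i * row_stride  # first element of row i of the B tile
        for k in range(macs):
            out[i] += a[k] * mem[row_base + k]
    return out


def tile_of(mem, b_off, row_stride, lanes, macs):
    """Extract the logical (lanes x macs) tile visible through the pointer."""
    return [[mem[b_off + i * row_stride + k] for k in range(macs)]
            for i in range(lanes)]


if __name__ == "__main__":
    # 4 x 16 row-major int8 parent buffer, so its row stride is 16.
    parent_cols = 16
    mem = [(r * parent_cols + col) % 7 - 3
           for r in range(4) for col in range(parent_cols)]
    lanes, macs = 2, 8
    b_off = 1 * parent_cols + 8  # the B tile starts at parent[1, 8]
    a = list(range(1, macs + 1))
    c = [0] * lanes

    tile = tile_of(mem, b_off, parent_cols, lanes, macs)
    ref = [sum(a[k] * tile[i][k] for k in range(macs)) for i in range(lanes)]
    good = vec_macc_model(c, a, mem, b_off, parent_cols, lanes, macs)
    bad = vec_macc_model(c, a, mem, b_off, 8, lanes, macs)  # hard-coded 8
    print(good == ref, bad == ref)  # prints True False
```

With the real stride (16) the model matches the logical tile; with the hard-coded 8, row 1 silently reads from the wrong part of the parent buffer.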
Thank you!
~Cristian.