diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index dfed020853e9..1c27fdfb133a 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -98,7 +98,14 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe for iter_var in block_stmt.iter_vars if isinstance(iter_var.dom.extent, tir.IntImm) ) - if len(const_iter_vars) == len(block_stmt.iter_vars): + if len(block_stmt.iter_vars) - len(const_iter_vars) != 1: + return None + symbolic_iter_var = list( + iter_var + for iter_var in block_stmt.iter_vars + if not isinstance(iter_var.dom.extent, tir.IntImm) + )[0] + if symbolic_iter_var.iter_type != tir.stmt.IterVar.DataPar: return None ret = [ read.buffer @@ -220,7 +227,8 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- return None sch = tir.Schedule(func) block_infos = normalize_prim_func(sch) - + if block_infos is None: + return None reduction_block_infos = [ block_info for block_info in block_infos if block_info.is_reduction() ] diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py index 5827b7b81077..d3e635ddaa4e 100644 --- a/tests/python/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/dlight/test_gpu_low_batch_gemv.py @@ -251,5 +251,29 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float tvm.ir.assert_structural_equal(mod["main"], expected) +def test_reduction_symbolic_var(): + # fmt: off + @T.prim_func(private=True) + def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + kv_seq_len = T.int64() + A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len)) + B = T.match_buffer(var_B, (T.int64(1), T.int64(32), kv_seq_len, T.int64(128))) + # with T.block("root"): + for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), kv_seq_len): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) + T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) + matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] + # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("metal"): + mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod) + tvm.ir.assert_structural_equal(mod["main"], before) + + if __name__ == "__main__": tvm.testing.main()