diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 618345d0a5d2..18f3e19909f6 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -1579,6 +1579,12 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], d, 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), ) + original_tile_y = tile_y + original_tile_z = tile_z + while (tile_x * tile_z) % (bdx * num_warps) != 0: + tile_z += original_tile_z + while (tile_x * tile_y) % (bdx * num_warps) != 0: + tile_y += original_tile_y # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -1907,7 +1913,6 @@ def apply_to_gemm( # pylint: disable=unused-argument sch.unroll(yio) sch.vectorize(yiv) sch.unroll(xi) - sch.unroll(ki) sch.decompose_reduction(block, ty) def apply_to_md(sch, block):