From 57a983d64dc4ddf800b62f54ddbc2c149a70b037 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 16 Jan 2025 11:27:35 -0500 Subject: [PATCH] [Fix][KVCache] Fix incorrect tile size calculation This PR fixes the tile size calculation in the TIR attention kernels, where the computed tile sizes may not divide the total loop extent. --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index f60c40efa21c..399e418c464b 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -855,7 +855,7 @@ def get_tile_size(x, y, t): cnt = (x * y) // t assert (x * y) % t == 0 tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: tile_y += 1 assert tile_y <= cnt tile_x = cnt // tile_y @@ -1509,7 +1509,7 @@ def get_tile_size(x, y, t): cnt = (x * y) // t assert (x * y) % t == 0 tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: tile_y += 1 assert tile_y <= cnt tile_x = cnt // tile_y @@ -1867,7 +1867,7 @@ def get_tile_size(x, y, t): cnt = (x * y) // t assert (x * y) % t == 0 tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: tile_y += 1 assert tile_y <= cnt tile_x = cnt // tile_y