diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index c96b8d6006a..8ab8283ff51 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -349,40 +349,80 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { * 6| | | * 7| | | * +----------+----------+ + * + * We can consider each repeated_pattern_size rows as a gigarow, and each + * repeated_pattern_size megabanks as a gigabank. Note that megabank is a + * contiguous chunk of banks, but gigabank is not contiguous. Indeed, + * nearby megabanks in a gigabank has a distance of `g` megabanks */ + TORCH_INTERNAL_ASSERT( + ldmatrix_rows % repeated_pattern_size == 0, + "Can not partition matrix into megarows"); + int64_t num_gigarows = ldmatrix_rows / repeated_pattern_size; + int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size + // -2 -1 // [row, col] - TORCH_INTERNAL_ASSERT( - tile_size_x % ldmatrix_rows == 0, "Partial matrices not supported"); - shared_mem_tv->split(-2, ldmatrix_rows); - TORCH_INTERNAL_ASSERT( - tile_size_y % ldmatrix_cols == 0, "Partial matrices not supported"); + shared_mem_tv->split(-2, repeated_pattern_size); shared_mem_tv->split(-1, ldmatrix_cols); - // -4 -3 -2 -1 - // [matrix id, matrix, matrix id, matrix] - TORCH_INTERNAL_ASSERT( - ldmatrix_rows % repeated_pattern_size == 0, - "ldmatrix_rows is assumed to be a multiple of repeated_pattern_size"); - shared_mem_tv->split(-3, repeated_pattern_size); - // -5 -4 -3 -2 -1 - // [matrix id, repeat, pattern, matrix id, matrix] - int64_t swizzle_period = ldmatrix_rows / repeated_pattern_size; + // -4 -3 -2 -1 + // [gigarow id, gigarow, matrix id, matrix] + shared_mem_tv->split(-2, num_gigabanks); + // -5 -4 -3 -2 -1 + // [gigarow id, gigarow, y outer, gigabank id, matrix] + // Note that megabanks inside a gigabank are not contiguous, so the gigabank + // id is -2 instead of -3 + + /* We want to evenly distribute gigarows across gigabanks, for example, if + * we have 7 gigarows and 3 gigabanks, then we might distribute them as: + * +---+ + * |x | + * | x | + * | x| + * |x | + * | x | + * | x| + * |x | + * +---+ + * considering all matrices, this is a swizzle function like: + * +---+ + * |012| + * |201| + * |120| + * |012| + * |201| + * |120| + * |012| + * +---+ + * which is a cyclic shift. + * + * Note that because num_gigabanks (a.k.a. g) divide num_megabanks and + * row_stride_znz (which is row_stride % num_megabanks), g should also + * divide row_stride, because according to the fundamental + * division-with-remainder property (see comment in expr_simplifier.h): + * row_stride = q * num_megabanks + row_stride_znz + * which means, we can just consider each num_gigabanks matrices as a group, + * and we always have complete groups (i.e. no group has less than + * num_gigabanks matrices). Interleaving the memory of matrices within each + * group should be enough to fully remove bank conflict. + */ + + /* To further simplify the problem, if we assume: */ TORCH_INTERNAL_ASSERT( - tile_size_y % (swizzle_period * ldmatrix_cols) == 0, - "need aperiodic swizzle config for tile size ", - tile_size_x, - "x", - tile_size_y, - "with units ", - ldmatrix_rows, - "x", - ldmatrix_cols); - shared_mem_tv->split(-2, swizzle_period); - // -6 -5 -4 -3 -2 -1 - // [matrix id, repeat, pattern, matrix id outer, pattern id, matrix] - // swizzle repeat with pattern id to make repeat no longer repeat - if (isPowOf2(swizzle_period)) { + num_gigarows % num_gigabanks == 0, + "Requires non-square swizzle, which is not supported yet"); + /* Then we can partition gigarows into full waves, each wave has + * num_gigabanks gigarows. This partition creates square dimensions, making + * the swizzle implementation easier */ + + // -5 -4 -3 -2 -1 + // [gigarow id, gigarow, y outer, gigabank id, matrix] + shared_mem_tv->split(-5, num_gigabanks); + // -6 -5 -4 -3 -2 -1 + // [wave id, wave, gigarow, y outer, gigabank id, matrix] + + if (isPowOf2(num_gigabanks)) { shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2); } else { shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2);