Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 68 additions & 28 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,40 +349,80 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) {
* 6| | |
* 7| | |
* +----------+----------+
*
* We can consider each repeated_pattern_size rows as a gigarow, and each
* repeated_pattern_size megabanks as a gigabank. Note that megabank is a
* contiguous chunk of banks, but gigabank is not contiguous. Indeed,
* nearby megabanks in a gigabank has a distance of `g` megabanks
*/

TORCH_INTERNAL_ASSERT(
ldmatrix_rows % repeated_pattern_size == 0,
"Can not partition matrix into megarows");
int64_t num_gigarows = ldmatrix_rows / repeated_pattern_size;
int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size

// -2 -1
// [row, col]
TORCH_INTERNAL_ASSERT(
tile_size_x % ldmatrix_rows == 0, "Partial matrices not supported");
shared_mem_tv->split(-2, ldmatrix_rows);
TORCH_INTERNAL_ASSERT(
tile_size_y % ldmatrix_cols == 0, "Partial matrices not supported");
shared_mem_tv->split(-2, repeated_pattern_size);
shared_mem_tv->split(-1, ldmatrix_cols);
// -4 -3 -2 -1
// [matrix id, matrix, matrix id, matrix]
TORCH_INTERNAL_ASSERT(
ldmatrix_rows % repeated_pattern_size == 0,
"ldmatrix_rows is assumed to be a multiple of repeated_pattern_size");
shared_mem_tv->split(-3, repeated_pattern_size);
// -5 -4 -3 -2 -1
// [matrix id, repeat, pattern, matrix id, matrix]
int64_t swizzle_period = ldmatrix_rows / repeated_pattern_size;
// -4 -3 -2 -1
// [gigarow id, gigarow, matrix id, matrix]
shared_mem_tv->split(-2, num_gigabanks);
// -5 -4 -3 -2 -1
// [gigarow id, gigarow, y outer, gigabank id, matrix]
// Note that megabanks inside a gigabank are not contiguous, so the gigabank
// id is -2 instead of -3

/* We want to evenly distribute gigarows across gigabanks, for example, if
* we have 7 gigarows and 3 gigabanks, then we might distribute them as:
* +---+
* |x |
* | x |
* | x|
* |x |
* | x |
* | x|
* |x |
* +---+
* considering all matrices, this is a swizzle function like:
* +---+
* |012|
* |201|
* |120|
* |012|
* |201|
* |120|
* |012|
* +---+
* which is a cyclic shift.
*
* Note that because num_gigabanks (a.k.a. g) divide num_megabanks and
* row_stride_znz (which is row_stride % num_megabanks), g should also
* divide row_stride, because according to the fundamental
* division-with-remainder property (see comment in expr_simplifier.h):
* row_stride = q * num_megabanks + row_stride_znz
* which means, we can just consider each num_gigabanks matrices as a group,
* and we always have complete groups (i.e. no group has less than
* num_gigabanks matrices). Interleaving the memory of matrices within each
* group should be enough to fully remove bank conflict.
*/

/* To further simplify the problem, if we assume: */
TORCH_INTERNAL_ASSERT(
tile_size_y % (swizzle_period * ldmatrix_cols) == 0,
"need aperiodic swizzle config for tile size ",
tile_size_x,
"x",
tile_size_y,
"with units ",
ldmatrix_rows,
"x",
ldmatrix_cols);
shared_mem_tv->split(-2, swizzle_period);
// -6 -5 -4 -3 -2 -1
// [matrix id, repeat, pattern, matrix id outer, pattern id, matrix]
// swizzle repeat with pattern id to make repeat no longer repeat
if (isPowOf2(swizzle_period)) {
num_gigarows % num_gigabanks == 0,
"Requires non-square swizzle, which is not supported yet");
/* Then we can partition gigarows into full waves, each wave has
* num_gigabanks gigarows. This partition creates square dimensions, making
* the swizzle implementation easier */

// -5 -4 -3 -2 -1
// [gigarow id, gigarow, y outer, gigabank id, matrix]
shared_mem_tv->split(-5, num_gigabanks);
// -6 -5 -4 -3 -2 -1
// [wave id, wave, gigarow, y outer, gigabank id, matrix]

if (isPowOf2(num_gigabanks)) {
shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2);
} else {
shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2);
Expand Down