Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
7ef15e0
[MatMul] Prolog build out, adding automatic swizzle generator
zasdfgbnm Mar 20, 2023
6afb375
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 20, 2023
0ee396c
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 22, 2023
a28293a
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 23, 2023
0700c9e
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 23, 2023
87c16d4
Merge branch 'main' into matmul_swizzle_gen
xwang233 Mar 24, 2023
f7be625
Merge branch 'main' into matmul_swizzle_gen
xwang233 Mar 24, 2023
4abdc49
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_gen
zasdfgbnm Mar 29, 2023
98cfadc
Merge branch 'matmul_swizzle_gen' of github.com:NVIDIA/Fuser into mat…
zasdfgbnm Mar 29, 2023
4fda654
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 29, 2023
91852d2
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Mar 31, 2023
6589e06
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_gen
zasdfgbnm Apr 3, 2023
7b8e75a
fix
zasdfgbnm Apr 3, 2023
2a84c52
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Apr 5, 2023
7a83199
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Apr 5, 2023
e38a9a6
Merge branch 'main' into matmul_swizzle_gen
zasdfgbnm Apr 8, 2023
80e0c2e
Merge branch 'main' of github.com:NVIDIA/Fuser into matmul_swizzle_gen
zasdfgbnm Apr 10, 2023
ed2205f
Merge branch 'matmul_swizzle_gen' of github.com:NVIDIA/Fuser into mat…
zasdfgbnm Apr 10, 2023
9ac4f51
test bank conflict
zasdfgbnm Apr 10, 2023
d1c15f1
cleanup
zasdfgbnm Apr 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 216 additions & 35 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
namespace nvfuser {

namespace {

// Returns true if the given number is a power of 2 strictly greater
// than 1 (i.e. 2, 4, 8, ...). Note that 1 deliberately returns false:
// a swizzle period of 1 is degenerate and is handled by the caller's
// non-XOR path.
bool isPowOf2(int x) {
  if (x <= 1) {
    return false;
  }
  // A power of two has exactly one bit set, so clearing the lowest set
  // bit must yield zero.
  return (x & (x - 1)) == 0;
}

// Move the broadcast axes to the left on the specified number of inner
// dimensions e.g. (when number_of_inner_pos == 3):
// [... I0, B, I1] -> [... B, I0, I1]
Expand Down Expand Up @@ -51,6 +57,213 @@ void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) {
tv->reorder(order_map);
}

//! Automatically generates the shared memory swizzled data layout
//! for the matmul mainloop.
//! The shared mem data layout is always 2D currently, and this utility
//! function assumes that the innermost 2 dimensions on shared_mem_tv
//! are the ones being swizzled.
void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) {
  // Check that the innermost 2 dimensions are concrete and statically
  // sized so that the swizzle function can be defined.

  // Utility to check a single dim for concrete static size:
  auto check_concrete_static_dim = [](IterDomain* id) {
    TORCH_INTERNAL_ASSERT(
        !id->isBroadcast() && !id->isReduction(),
        "no support on reduction or broadcast dims, but got ",
        id->toString());
    TORCH_INTERNAL_ASSERT(
        id->extent()->isConstInt(),
        "swizzled dimensions need to be statically sized, but got ",
        id->toString());
  };

  TORCH_INTERNAL_ASSERT(
      shared_mem_tv->nDims() >= 2,
      "At least 2D input needed for swizzling, but got ",
      shared_mem_tv->toString());
  check_concrete_static_dim(shared_mem_tv->axis(-2));
  check_concrete_static_dim(shared_mem_tv->axis(-1));

  // Extract the constant sizes of the swizzled tile
  const auto tile_size_x = shared_mem_tv->axis(-2)->extent()->evaluateInt();
  const auto tile_size_y = shared_mem_tv->axis(-1)->extent()->evaluateInt();

  if (isTuring(params.mma_op) || isAmpere(params.mma_op)) {
    // TODO: right now, we are assuming ldmatrix access, which only supports
    // 16bit load according to official doc:
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix
    // In the future, when we start adding support for tf32(different macro),
    // fp32(ffma), double, int8, fp8, etc. we need to update this function.
    TORCH_INTERNAL_ASSERT(dataTypeSize(*shared_mem_tv->getDataType()) == 2);

    // Each ldmatrix access is an 8x8 tile of 16-bit elements.
    int row_unit = 8;
    int col_unit = 8;

    // Row and column sizes of the tile need to be multiples of 8 for
    // ldmatrix to work.
    TORCH_INTERNAL_ASSERT(
        tile_size_x >= row_unit && tile_size_x % row_unit == 0 &&
            tile_size_y >= col_unit && tile_size_y % col_unit == 0,
        "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle: ",
        tile_size_x,
        "x",
        tile_size_y);

    // Number of 8-wide column units in one row of the tile.
    int units_per_row = tile_size_y / col_unit;

    // Number of column units that can fit in a conflict free shared mem wave
    // with memory width = 128 Byte assumed.
    const int units_per_memory_row =
        128 / dataTypeSize(DataType::Half) / col_unit;

    // Calculate swizzle period:
    int residue_unit_count = units_per_row % units_per_memory_row;

    // In the case where the tile row is a multiple of the memory row, the
    // whole memory row is the repeated pattern of swizzle. In the case where
    // the tile row is not divisible, the residual part is the repeated
    // pattern.
    int repeated_pattern_size_in_units =
        residue_unit_count == 0 ? units_per_memory_row : residue_unit_count;

    // Calculate the row multiplier, which is defined as the minimum number
    // of rows to look down from an element until the same bank index is
    // observed.
    c10::optional<int> maybe_row_multiplier = c10::nullopt;

    if (units_per_memory_row % repeated_pattern_size_in_units == 0) {
      maybe_row_multiplier =
          units_per_memory_row / repeated_pattern_size_in_units;
    } else if (
        units_per_memory_row > repeated_pattern_size_in_units &&
        units_per_memory_row %
                (units_per_memory_row - repeated_pattern_size_in_units) ==
            0) {
      maybe_row_multiplier = units_per_memory_row /
          (units_per_memory_row - repeated_pattern_size_in_units);
    }

    // The case where the row multiplier cannot be an integer would be where
    // fractional tiling support is needed. Would gradually build out support
    // on this one.
    if (!maybe_row_multiplier.has_value()) {
      // calculate effective row_period = lcm(row_period, repeated_pattern) /
      // repeated_pattern_size, which is the same as below
      int row_period = units_per_memory_row /
          std::gcd(units_per_memory_row, repeated_pattern_size_in_units);

      if (row_period < row_unit) {
        TORCH_WARN_ONCE(
            "Fractional pattern not yet implemented for swizzling memory row of size: ",
            units_per_memory_row,
            " and tile row of size: ",
            repeated_pattern_size_in_units);
        // This would not lead to a functional issue but just a perf
        // regression, so just do not swizzle anything yet.
        // TODO: add support for swizzles with different row and col periods
        // to enable this case.
        return;
      } else {
        // This case would not need swizzling at all as the period of
        // memory bank index over the row is wider than the access window.
        return;
      }
    } else if (maybe_row_multiplier.value() >= row_unit) {
      // No need to swizzle in this case.
      return;
    }

    // Calculate swizzle period, only equal row/col periods at the moment:
    // TODO: aperiodic swizzle could also be supported in a follow up:
    int max_swizzle_period = repeated_pattern_size_in_units;

    int swizzle_period = max_swizzle_period;

    // Do not have to use the max_swizzle_period if we already had
    // enough swizzle to permute a row_unit. This would encourage
    // usage of power of 2 swizzle periods.
    if (row_unit % maybe_row_multiplier.value() == 0) {
      swizzle_period =
          std::min(swizzle_period, row_unit / maybe_row_multiplier.value());
    }

    int row_multiplier = maybe_row_multiplier.value();

    TORCH_INTERNAL_ASSERT(
        tile_size_x % (swizzle_period * row_multiplier) == 0 &&
            tile_size_y % (swizzle_period * col_unit) == 0,
        "need aperiodic swizzle config for tile size ",
        tile_size_x,
        "x",
        tile_size_y,
        " with units ",
        row_unit,
        "x",
        col_unit);

    // Add the swizzling op:
    shared_mem_tv->split(-2, row_multiplier * swizzle_period);
    shared_mem_tv->split(-2, row_multiplier);

    shared_mem_tv->split(-1, col_unit * swizzle_period);
    shared_mem_tv->split(-1, col_unit);

    //  -6     -5          -4              -3     -2          -1
    // [..., row_o, row_period, row_multiplier, col_o, col_period, col_unit]
    // XOR requires a power-of-2 period; fall back to cyclic shift otherwise.
    if (isPowOf2(swizzle_period)) {
      shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2);
    } else {
      shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2);
    }

    // Merge back the tile for subsequent vectorization scheduling
    // TODO: could potentially simplify away the merges
    shared_mem_tv->merge(-6);
    shared_mem_tv->merge(-5);
    shared_mem_tv->merge(-3);
    shared_mem_tv->merge(-2);
  } else if (isVolta(params.mma_op)) {
    // TODO: Volta is slightly more complex, and a fixed recipe would
    // not scale. In a follow up this would be inferred from the mma
    // macro layout themselves as we already have them registered in
    // the utils.
    return;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Prolog swizzle: unsupported mma macro");
  }
}

//! Generates the prolog schedule on the shared memory buffer
//! tensor. The scheduling performs two steps:
//!
//! 1. Swizzle the shared mem data layout.
//! 2. Coalesce and vectorize the read write schedule.
void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) {
  shared_mem_tv->setMemoryType(MemoryType::Shared);

  // Reorder ids so the tiled concrete ids match root order before the
  // swizzle and vectorization transforms are applied.
  scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(shared_mem_tv);

  // Swizzle the shared memory data layout
  prologSwizzle(shared_mem_tv, params);

  // Assuming we are always vectorizing smem write by 128b at the moment:
  // TODO: would need a data-type and alignment dependent interface
  // to support non-vectorizable shapes.
  // The vectorizable width logic would be in a separate PR as the
  // current effort tries to focus on generating swizzles.
  shared_mem_tv->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      shared_mem_tv, params.tile_sizes, 8, true);

  // Propagate prolog tensors
  // propagate up the DAG, and propagate parallel type.
  // NOTE: the empty boundary set means the backward propagation is
  // bounded by the fusion inputs of shared_mem_tv.
  scheduler_utils::BoundedDirectionalTransformPropagator::backward(
      shared_mem_tv,
      -1,
      {},
      scheduler_utils::BoundedDirectionalTransformPropagator::Options()
          .propagateParallelType());
}

} // namespace

void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
Expand Down Expand Up @@ -237,35 +450,10 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
cc, -1, {acw_smem, bcw_smem}, {c});

// Schedule prolog:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// TODO: this section needs more configurability.
// ------------------------------------------------------------------
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(acw_smem);
// [... M, K]
acw_smem->merge(-2);
scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
acw_smem, gemm_tile, 8, false);

// [... N, K]
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(bcw_smem);
bcw_smem->merge(-2);
scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
bcw_smem, gemm_tile, 8, false);

// Propagate prolog tensors
// propagate up the DAG, and propagate parallel type.
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
acw_smem,
-1,
{a},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
bcw_smem,
-1,
{b},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
scheduleProlog(acw_smem, params);
scheduleProlog(bcw_smem, params);

// Set computeAt, setup the loop nesting structure on the kernel.
// TODO: this section goes to a separate matmul util,
Expand Down Expand Up @@ -314,19 +502,12 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
cc->applyMmaSwizzle(
mma_builder.operand(MmaOptions::Operand::Accumulator).build());

// Set memory type:
acw_smem->setMemoryType(MemoryType::Shared);
bcw_smem->setMemoryType(MemoryType::Shared);

// Set parallelization:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------

// Vectorize smem stores/loads:
acw_smem->axis(-1)->parallelize(ParallelType::Vectorize);
bcw_smem->axis(-1)->parallelize(ParallelType::Vectorize);

acr->axis(-1)->parallelize(ParallelType::Vectorize);
bcr->axis(-1)->parallelize(ParallelType::Vectorize);

Expand Down
8 changes: 4 additions & 4 deletions csrc/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1949,7 +1949,7 @@ bool isFakeBoundaryTensorview(
//! transform to by BoundedDirectionalTransformPropagator.
std::unordered_set<TensorView*> getDirectionalPropagatePathSet(
TensorView* from_tv,
std::vector<TensorView*> boundary_tvs,
const std::vector<TensorView*>& boundary_tvs,
BoundedDirectionalTransformPropagator::Options options,
PropagateDirection direction) {
// Prepare to collect all candidate tensorviews
Expand Down Expand Up @@ -2061,9 +2061,9 @@ void BoundedDirectionalTransformPropagator::backward(
if (!options.has_value()) {
options = Options();
}
TORCH_INTERNAL_ASSERT(
!to.empty(),
"Propagation needs to be bounded, so no support for empty boundary.");
if (to.empty()) {
to = ir_utils::inputTvsOf(from);
}

// Collect all tvs to included on the backward path as specified
// by boundary and options.
Expand Down
Loading