diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 2e07fa0a5d6..5049e99e8a8 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -19,6 +19,12 @@ namespace nvfuser { namespace { + +// Returns true if given number is power of 2 +bool isPowOf2(int x) { + return x > 1 && (x & (x - 1)) == 0; +} + // Move the broadcast axes to the left on the specified number of inner // dimensions e.g. (when number_of_inner_pos == 3): // [... I0, B, I1] -> [... B, I0, I1] @@ -51,6 +57,213 @@ void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) { tv->reorder(order_map); } +//! Automatically generates the shared memory swizzled data layout +//! for matmul mainloop. +//! The shared mem datalayout is always 2D currently, and this utility +//! function assumes that the innermost 2 dimensions on shared_mem_tv +//! are the ones being swizzled. +void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) { + // Check that the innermost 2 dimensions are concrete and static + // sized so that the swizzle function can be defined. 
+ + // Utility to check concrete static size: + auto check_concrete_static_dim = [](IterDomain* id) { + TORCH_INTERNAL_ASSERT( + !id->isBroadcast() && !id->isReduction(), + "no support on reduction or broadcast dims, but get ", + id->toString()); + TORCH_INTERNAL_ASSERT( + id->extent()->isConstInt(), + "swizzled dimensions need to be statically, but get ", + id->toString()); + }; + + TORCH_INTERNAL_ASSERT( + shared_mem_tv->nDims() >= 2, + "At least 2D input needed for swizzling, but get ", + shared_mem_tv->toString()); + check_concrete_static_dim(shared_mem_tv->axis(-2)); + check_concrete_static_dim(shared_mem_tv->axis(-1)); + + // Extract the constant sizes of the swizzled tile + const auto tile_size_x = shared_mem_tv->axis(-2)->extent()->evaluateInt(); + const auto tile_size_y = shared_mem_tv->axis(-1)->extent()->evaluateInt(); + + if (isTuring(params.mma_op) || isAmpere(params.mma_op)) { + // TODO: right now, we are assuming ldmatrix access, which only supports + // 16bit load according to official doc: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix + // In the future, when we start adding support for tf32(different macro), + // fp32(ffma), double, int8, fp8, etc. we need to update this function. + TORCH_INTERNAL_ASSERT(dataTypeSize(*shared_mem_tv->getDataType()) == 2); + + // Each ldmatrix access is 8x8 + int row_unit = 8; + int col_unit = 8; + + // Column size of the tile needs to be multiples of 8 for ldmatrix to work. + TORCH_INTERNAL_ASSERT( + tile_size_x >= row_unit && tile_size_x % row_unit == 0 && + tile_size_y >= col_unit && tile_size_y % col_unit == 0, + "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", + tile_size_x, + "x", + tile_size_y); + + int units_per_row = tile_size_y / col_unit; + + // Number of column units that can fit in a conflict free shared mem wave + // with memory width = 128 Byte assumed. 
const int units_per_memory_row = + 128 / dataTypeSize(DataType::Half) / col_unit; + + // Calculate swizzle period: + int residue_unit_count = units_per_row % units_per_memory_row; + + // In the case where tile row is a multiple of memory row, the whole memory + // row is the repeated pattern of swizzle. In the case where tile row is + // not divisible, the residual part is the repeated pattern. + int repeated_pattern_size_in_units = + residue_unit_count == 0 ? units_per_memory_row : residue_unit_count; + + // Calculate row multiplier, which is defined as minimum number of rows + // to look down from an element until the same bank index is observed. + c10::optional<int> maybe_row_multiplier = c10::nullopt; + + if (units_per_memory_row % repeated_pattern_size_in_units == 0) { + maybe_row_multiplier = + units_per_memory_row / repeated_pattern_size_in_units; + } else if ( + units_per_memory_row > repeated_pattern_size_in_units && + units_per_memory_row % + (units_per_memory_row - repeated_pattern_size_in_units) == + 0) { + maybe_row_multiplier = units_per_memory_row / + (units_per_memory_row - repeated_pattern_size_in_units); + } + + // The case where the row multiplier cannot be an integer would be where + // fractional tiling support is needed. Would gradually build out support + // on this one. + if (!maybe_row_multiplier.has_value()) { + // calculate effective row_period = lcm(row_period, repeated_pattern) / + // repeated_pattern_size which is the same as below + int row_period = units_per_memory_row / + std::gcd(units_per_memory_row, repeated_pattern_size_in_units); + + if (row_period < row_unit) { + TORCH_WARN_ONCE( + "Fractional pattern not yet implemented for swizzling memory row of size :", + units_per_memory_row, + " and tile row of size: ", + repeated_pattern_size_in_units); + // This would not lead to functional issue but just perf regression, so + // just do not swizzle anything yet. 
+ // TODO: add support for swizzles with different row and col periods to + // enable this case. + return; + } else { + // This case would not need swizzling at all as the period of + // memory bank index over the row is wider than the access window. + return; + } + } else if (maybe_row_multiplier.value() >= row_unit) { + // No need to swizzle in this case. + return; + } + + // Calculate swizzle period, only equal row/col periods at the moment: + // TODO: aperiodic swizzle could also be supported in a follow up: + int max_swizzle_period = repeated_pattern_size_in_units; + + int swizzle_period = max_swizzle_period; + + // Do not have to use the max_swizzle period if we already had + // enough swizzle to permute a row_unit. This would encourage + // usage of power of 2 swizzle periods. + if (row_unit % maybe_row_multiplier.value() == 0) { + swizzle_period = + std::min(swizzle_period, row_unit / maybe_row_multiplier.value()); + } + + int row_multiplier = maybe_row_multiplier.value(); + + TORCH_INTERNAL_ASSERT( + tile_size_x % (swizzle_period * row_multiplier) == 0 && + tile_size_y % (swizzle_period * col_unit) == 0, + "need aperiodic swizzle config for tile size ", + tile_size_x, + "x", + tile_size_y, + "with units ", + row_unit, + "x", + col_unit); + + // add the swizzling op: + shared_mem_tv->split(-2, row_multiplier * swizzle_period); + shared_mem_tv->split(-2, row_multiplier); + + shared_mem_tv->split(-1, col_unit * swizzle_period); + shared_mem_tv->split(-1, col_unit); + + // -6 -5 -4 -3 -2 -1 + // [..., row_o, row_period, row_multiplier, col_o, col_period, col_unit] + if (isPowOf2(swizzle_period)) { + shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2); + } else { + shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2); + } + + // Merge back the tile for subsequent vectorization scheduling + // TODO: could potentially simplify away the merges + shared_mem_tv->merge(-6); + shared_mem_tv->merge(-5); + shared_mem_tv->merge(-3); + shared_mem_tv->merge(-2); + } 
else if (isVolta(params.mma_op)) { + // TODO: Volta is slightly more complex, and a fixed recipe would + // not scale. In a follow up this would be inferred from the mma + // macro layout themselves as we already have them registered in + // the utils. + return; + } else { + TORCH_INTERNAL_ASSERT(false, "Prolog swizzle: unsupported mma macro"); + } +} + +//! Generates the prolog schedule on the shared memory buffer +//! tensor. The scheduling performs two steps: +//! +//! 1. Swizzle the shared mem data layout. +//! 2. Coalesce and vectorize the read write schedule. +void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { + shared_mem_tv->setMemoryType(MemoryType::Shared); + + scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(shared_mem_tv); + + // Swizzle the shared memory data layout + prologSwizzle(shared_mem_tv, params); + + // Assuming we are always vectorizing smem write by 128b at the moment: + // TODO: would need a data-type and alignment dependent interface + // to support non-vectorizable shapes. + // The vectorizable width logic would be in a separate PR as the + // current effort tries to focus on generating swizzles. + shared_mem_tv->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + shared_mem_tv, params.tile_sizes, 8, true); + + // Propagate prolog tensors + // propagate up the DAG, and propagate parallel type. + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + shared_mem_tv, + -1, + {}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType()); +} + } // namespace void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { @@ -237,35 +450,10 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { cc, -1, {acw_smem, bcw_smem}, {c}); // Schedule prolog: - // TODO: this section goes to a separate matmul util, - // and needs more configurability. + // TODO: this section needs more configurability. 
// ------------------------------------------------------------------ - scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(acw_smem); - // [... M, K] - acw_smem->merge(-2); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - acw_smem, gemm_tile, 8, false); - - // [... N, K] - scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(bcw_smem); - bcw_smem->merge(-2); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - bcw_smem, gemm_tile, 8, false); - - // Propagate prolog tensors - // propagate up the DAG, and propagate parallel type. - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - acw_smem, - -1, - {a}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType()); - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - bcw_smem, - -1, - {b}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType()); + scheduleProlog(acw_smem, params); + scheduleProlog(bcw_smem, params); // Set computeAt, setup the loop nesting structure on the kernel. // TODO: this section goes to a separate matmul util, @@ -314,19 +502,12 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { cc->applyMmaSwizzle( mma_builder.operand(MmaOptions::Operand::Accumulator).build()); - // Set memory type: - acw_smem->setMemoryType(MemoryType::Shared); - bcw_smem->setMemoryType(MemoryType::Shared); - // Set parallelization: // TODO: this section goes to a separate matmul util, // and needs more configurability. 
// ------------------------------------------------------------------ // Vectorize smem stores/loads: - acw_smem->axis(-1)->parallelize(ParallelType::Vectorize); - bcw_smem->axis(-1)->parallelize(ParallelType::Vectorize); - acr->axis(-1)->parallelize(ParallelType::Vectorize); bcr->axis(-1)->parallelize(ParallelType::Vectorize); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 1697ef988b5..35dd51b96ab 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1949,7 +1949,7 @@ bool isFakeBoundaryTensorview( //! transform to by BoundedDirectionalTransformPropagator. std::unordered_set<TensorView*> getDirectionalPropagatePathSet( TensorView* from_tv, - std::vector<TensorView*> boundary_tvs, + const std::vector<TensorView*>& boundary_tvs, BoundedDirectionalTransformPropagator::Options options, PropagateDirection direction) { // Prepare to collect all candidate tensorviews @@ -2061,9 +2061,9 @@ void BoundedDirectionalTransformPropagator::backward( if (!options.has_value()) { options = Options(); } - TORCH_INTERNAL_ASSERT( - !to.empty(), - "Propagation needs to be bounded, so no support for empty boundary."); + if (to.empty()) { + to = ir_utils::inputTvsOf(from); + } // Collect all tvs to included on the backward path as specified // by boundary and options. 
diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index cd61bce84a1..999decb2888 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -317,6 +317,9 @@ TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) { params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); + // prologSwizzle on Volta is not supported yet + // ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -366,6 +369,9 @@ TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) { params.double_buffer_options.double_buffer_smem_read = true; scheduleMatmul(&fusion, params); + // prologSwizzle on Volta is not supported yet + // ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -650,6 +656,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -705,6 +713,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { params.double_buffer_options.smem_double_buffer_stage = stage; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -771,6 +781,8 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -878,6 +890,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { params.double_buffer_options.double_buffer_smem_read = true; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -1819,6 +1833,8 @@ TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) 
{ params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2859,6 +2875,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 3; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2907,6 +2925,8 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) { params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2962,6 +2982,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3027,6 +3049,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3086,6 +3110,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3138,6 +3164,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 3; scheduleMatmul(&fusion, params); + ASSERT_TRUE(fusion.bankConflictInfo().empty()); + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout);