diff --git a/csrc/executor.cpp b/csrc/executor.cpp index 3c0d9dc8687..6315a622639 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -177,7 +177,8 @@ void FusionExecutor::debugCompileFusionFromStr( } std::tie(compiled_kernel_, last_compiler_log_, last_compiled_binary_) = - executor_utils::nvrtcCompile(c10::nullopt, code, name, fusion_id_); + executor_utils::nvrtcCompile( + c10::nullopt, code, name, fusion_id_, swizzle_factor); TORCH_INTERNAL_ASSERT( fusion_id_ > 0, "assign a fusion_id_ <= 0 is not accepted."); } @@ -236,6 +237,11 @@ void FusionExecutor::compileFusion( // TODO: refactor the options_ passed through options_.device = c10::Device(c10::DeviceType::CUDA, args.getDeviceIndex()); + // Set if before compilation and launchParams + if (compile_params.swizzle_factor) { + swizzle_factor = std::max(1, *compile_params.swizzle_factor); + } + // Set the index type of compile params if not already set. If set, // make sure the compile param type is valid with the given kernel // arguments. @@ -362,6 +368,7 @@ void FusionExecutor::compileFusion( structured_code, (kernelNamespace() + "::" + kernelName()).c_str(), fusion_id_, + swizzle_factor, block_size, maxrregcount_high_water_mark, save_compiled_binary_ || isDebugDumpEnabled(DebugDumpOption::Sass)); @@ -778,6 +785,12 @@ LaunchParams FusionExecutor::computeLaunchParams( launch_params.setSmem(dynamic_smem_size); + const int gdimx = launch_params.gdimx(); + const int gdimy = launch_params.gdimy(); + + launch_params.bind(gdimx * swizzle_factor, ParallelType::BIDx, true); + launch_params.bind( + (gdimy + swizzle_factor - 1) / swizzle_factor, ParallelType::BIDy, true); return launch_params; } @@ -1161,8 +1174,21 @@ std::vector FusionExecutor::runFusion( // Recompile the kernel if the number of threads in the block has increased // or maxrregcount has changed - if (launch_params_.nThreads() > block_size_high_water_mark || - compile_params.maxrregcount != maxrregcount_high_water_mark) { + + bool need_to_recompile = false; + + need_to_recompile |= launch_params_.nThreads() > block_size_high_water_mark; + + need_to_recompile |= + compile_params.maxrregcount != maxrregcount_high_water_mark; + + if (compile_params.swizzle_factor) { + need_to_recompile |= + swizzle_factor != std::max(1, *compile_params.swizzle_factor); + swizzle_factor = std::max(1, *compile_params.swizzle_factor); + } + + if (need_to_recompile) { const auto kernel = lowered_->kernel(); kernel_code_ = codegen::generateCudaKernel(kernel, kernelName()); const auto structured_code = @@ -1176,6 +1202,7 @@ std::vector FusionExecutor::runFusion( structured_code, (kernelNamespace() + "::" + kernelName()).c_str(), fusion_id_, + swizzle_factor, block_size_high_water_mark, maxrregcount_high_water_mark, save_compiled_binary_); @@ -1446,7 +1473,8 @@ void FusionExecutor::compileRtc( fusion_id_ = 1; std::tie(compiled_kernel_, last_compiler_log_, last_compiled_binary_) = - executor_utils::nvrtcCompile(c10::nullopt, scode, name, fusion_id_); + executor_utils::nvrtcCompile( + c10::nullopt, scode, name, fusion_id_, swizzle_factor); } float FusionExecutor::runRtc( diff --git a/csrc/executor.h b/csrc/executor.h index 473084b62f3..e417f84ff81 100644 --- a/csrc/executor.h +++ b/csrc/executor.h @@ -346,6 +346,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { // increases, recompile to adjust maxregister count. int64_t block_size_high_water_mark = 1; int maxrregcount_high_water_mark = 255; + int swizzle_factor = 1; // lookup table to take short cut to retrieve recorded information in order to // launch kernels without re-inference parameters. diff --git a/csrc/executor_params.cpp b/csrc/executor_params.cpp index 07f067cb423..9c0bd6db260 100644 --- a/csrc/executor_params.cpp +++ b/csrc/executor_params.cpp @@ -33,25 +33,25 @@ void LaunchParams::assertValid() { gdimz()); } -void LaunchParams::bind(int64_t val, ParallelType p_type) { +void LaunchParams::bind(int64_t val, ParallelType p_type, bool allow_rebind) { switch (p_type) { case ParallelType::TIDx: - checkAndSet(val, bdimx_, "blockDim.x"); + checkAndSet(val, bdimx_, "blockDim.x", allow_rebind); break; case ParallelType::BIDx: - checkAndSet(val, gdimx_, "gridDim.x"); + checkAndSet(val, gdimx_, "gridDim.x", allow_rebind); break; case ParallelType::TIDy: - checkAndSet(val, bdimy_, "blockDim.y"); + checkAndSet(val, bdimy_, "blockDim.y", allow_rebind); break; case ParallelType::BIDy: - checkAndSet(val, gdimy_, "gridDim.y"); + checkAndSet(val, gdimy_, "gridDim.y", allow_rebind); break; case ParallelType::TIDz: - checkAndSet(val, bdimz_, "blockdim.z"); + checkAndSet(val, bdimz_, "blockdim.z", allow_rebind); break; case ParallelType::BIDz: - checkAndSet(val, gdimz_, "gridDim.z"); + checkAndSet(val, gdimz_, "gridDim.z", allow_rebind); break; default: TORCH_INTERNAL_ASSERT( diff --git a/csrc/executor_params.h b/csrc/executor_params.h index b67ea81c8c8..71e3307f4a7 100644 --- a/csrc/executor_params.h +++ b/csrc/executor_params.h @@ -16,18 +16,20 @@ struct TORCH_CUDA_CU_API CompileParams { std::optional index_type = std::nullopt; int maxrregcount = 255; bool enable_magic_zero = true; + std::optional swizzle_factor = std::nullopt; bool operator==(const CompileParams& other) const { // Disallow comparison if the index type is nullopt TORCH_INTERNAL_ASSERT( - index_type.has_value(), + index_type && swizzle_factor, "cannot compare as the index type is not defined"); TORCH_INTERNAL_ASSERT( - other.index_type.has_value(), + other.index_type && other.swizzle_factor, "cannot compare as the other index type is not defined"); return index_type == other.index_type && maxrregcount == other.maxrregcount && - enable_magic_zero == other.enable_magic_zero; + enable_magic_zero == other.enable_magic_zero && + swizzle_factor == other.swizzle_factor; } bool operator!=(const CompileParams& other) const { @@ -100,17 +102,20 @@ class TORCH_CUDA_CU_API LaunchParams { void checkAndSet( const int64_t incoming_val, int64_t& class_val, - std::string val) { - TORCH_INTERNAL_ASSERT( - class_val == UNINITIALIZED_VAL || incoming_val == class_val, - "Tried to set ", - val, - " from ", - class_val, - " to ", - incoming_val, - ", but it was already set and new value does not match.", - " Thread dims all have to be bound to the same value."); + std::string val, + bool allow_rebind = false) { + if (!allow_rebind) { + TORCH_INTERNAL_ASSERT( + class_val == UNINITIALIZED_VAL || incoming_val == class_val, + "Tried to set ", + val, + " from ", + class_val, + " to ", + incoming_val, + ", but it was already set and new value does not match.", + " Thread dims all have to be bound to the same value."); + } TORCH_CHECK( incoming_val > 0, "Received a thread binding on ", @@ -118,14 +123,14 @@ class TORCH_CUDA_CU_API LaunchParams { " that is ", incoming_val, ". Cannot create negative threads."); - if (class_val == UNINITIALIZED_VAL) { + if (class_val == UNINITIALIZED_VAL || allow_rebind) { class_val = incoming_val; } assertValid(); } // Binds dim assocaited with p_type to val - void bind(int64_t val, ParallelType p_type); + void bind(int64_t val, ParallelType p_type, bool allow_rebind = false); // Adjusted value based on get functions above for each value int64_t getDim(ParallelType p_type) const; diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index 8d41f6972f7..2afee31a571 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -1028,6 +1028,7 @@ std::tuple> nvrtcCompile( const std::string& code, const std::string& func_name, int id, + const int swizzle_factor, c10::optional opt_block_size, const int max_register_heuristic, bool return_compiled_binary) { @@ -1092,6 +1093,10 @@ std::tuple> nvrtcCompile( args.push_back("-DPYTORCH_NVFUSER_PROFILE_KERNEL"); } + const std::string swizzle_factor_str = + "-DSWIZZLE_FACTOR=" + std::to_string(std::max(swizzle_factor, 1)); + args.push_back(swizzle_factor_str.c_str()); + const char* ptxas_opt_level = getenv("PYTORCH_NVFUSER_JIT_OPT_LEVEL"); std::string jit_opt_level = "-O"; diff --git a/csrc/executor_utils.h b/csrc/executor_utils.h index 112611fa4f5..4a263b215e9 100644 --- a/csrc/executor_utils.h +++ b/csrc/executor_utils.h @@ -63,6 +63,7 @@ std::tuple> nvrtcCompile( const std::string& code, const std::string& func_name, int id, + const int swizzle_factor, c10::optional opt_block_size = c10::nullopt, const int max_register_heuristic = 255, bool return_compiled_binary = false); diff --git a/csrc/lower_bank_conflict.cpp b/csrc/lower_bank_conflict.cpp index d4d846370bf..30deba482cc 100644 --- a/csrc/lower_bank_conflict.cpp +++ b/csrc/lower_bank_conflict.cpp @@ -50,13 +50,32 @@ inline int64_t getPhaseSize(int64_t word_size_bytes) { return 32; } +ParallelType getParallelType(const std::string& name) { + if (name == "threadIdx.x") { + return ParallelType::TIDx; + } else if (name == "threadIdx.y") { + return ParallelType::TIDy; + } else if (name == "threadIdx.z") { + return ParallelType::TIDz; + } else if (name == "getBlockIdX()") { + return ParallelType::BIDx; + } else if (name == "getBlockIdY()") { + return ParallelType::BIDy; + } else if (name == "getBlockIdZ()") { + return ParallelType::BIDz; + } + TORCH_INTERNAL_ASSERT(false, "Not a parallel type"); +} + bool isThreadIdx(const std::string& name) { return name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z"; } bool isBlockIdx(const std::string& name) { - return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z"; + auto parallelType = getParallelType(name); + return parallelType == ParallelType::BIDx || + parallelType == ParallelType::BIDy || parallelType == ParallelType::BIDz; } bool isBlockDim(const std::string& name) { @@ -67,23 +86,6 @@ bool isGridDim(const std::string& name) { return name == "gridDim.x" && name == "gridDim.y" && name == "gridDim.z"; } -ParallelType getParallelType(const std::string& name) { - if (name == "threadIdx.x") { - return ParallelType::TIDx; - } else if (name == "threadIdx.y") { - return ParallelType::TIDy; - } else if (name == "threadIdx.z") { - return ParallelType::TIDz; - } else if (name == "blockIdx.x") { - return ParallelType::BIDx; - } else if (name == "blockIdx.y") { - return ParallelType::BIDy; - } else if (name == "blockIdx.z") { - return ParallelType::BIDz; - } - TORCH_INTERNAL_ASSERT(false, "Not a parallel type"); -} - std::vector evaluateAddressesOnFirstPhase( kir::TensorIndex* ti, const std::vector& for_loops, diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 38fb7a7a050..6d52252cef8 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -377,6 +377,17 @@ void scheduleMatmul( // [... M,N,K] scheduler_utils::matmul_utils::makeTile(cc, gemm_tile.cta_tile.toVector()); + // int factor = + // getenv("SWIZZLE_FACTOR") != nullptr ? + // std::atoi(getenv("SWIZZLE_FACTOR")) : 1; + // if (factor != 1) { + // cc->split(1, factor, false); // outer split + // cc->split(0, factor, true); // inner + // // Mo, Mi, Ni, Mo + // cc->merge(1, 2); + // cc->merge(2, 1); + // } + // [Mo, No, Ko, Mi, Ni, Ki] // Propagate tiling globally scheduler_utils::transformPropagateToAllFrom(cc, -1); diff --git a/csrc/type.cpp b/csrc/type.cpp index f2420c44a3c..6c01c969086 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -607,11 +607,11 @@ static const char* rng_op_type2string(RNGOpType t) { static const char* parallel_type2string(ParallelType t) { switch (t) { case ParallelType::BIDz: - return "blockIdx.z"; + return "getBlockIdZ()"; case ParallelType::BIDy: - return "blockIdx.y"; + return "getBlockIdY()"; case ParallelType::BIDx: - return "blockIdx.x"; + return "getBlockIdX()"; case ParallelType::TIDz: return "threadIdx.z"; case ParallelType::TIDy: diff --git a/runtime/helpers.cu b/runtime/helpers.cu index 45edd368f0e..9f9faf424a3 100644 --- a/runtime/helpers.cu +++ b/runtime/helpers.cu @@ -22,6 +22,30 @@ #include #endif // __NVCC__ +constexpr unsigned swizzle_factor = SWIZZLE_FACTOR; + +__device__ unsigned getBlockIdX() { + static_assert(swizzle_factor >= 1); + if constexpr (swizzle_factor == 1) { + return blockIdx.x; + } else { + return blockIdx.x / swizzle_factor; + } +} + +__device__ unsigned getBlockIdY() { + static_assert(swizzle_factor >= 1); + if constexpr (swizzle_factor == 1) { + return blockIdx.y; + } else { + return blockIdx.y * swizzle_factor + (blockIdx.x % swizzle_factor); + } +} + +__device__ unsigned getBlockIdZ() { + return blockIdx.z; +} + __device__ constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b; } diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 061ee418a93..a7f5cc3100a 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -621,54 +621,64 @@ TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) { // Matmul test for Ampere MMA: across supported layouts TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { // Keep multiples of 8 to keep vectorizable. - int M = 504, N = 136, K = 248; + int M = 8192, N = 8192, K = 8192; - for (auto layout : kAllSupportedMatmulLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); + int swizzle_factor = getenv("SWIZZLE_FACTOR") + ? std::max(1, std::atoi(getenv("SWIZZLE_FACTOR"))) + : 1; - fusion.addInput(tv0); - fusion.addInput(tv1); + CompileParams cparams{ + DataType::Int32, 255, false, .swizzle_factor = swizzle_factor}; - auto tv2 = matmul(tv0, tv1, layout); + for (auto order : + {MatmulParam::TileRasterizationOrder::RowMajor, + MatmulParam::TileRasterizationOrder::ColumnMajor}) + for (auto layout : kAllSupportedMatmulLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addOutput(tv2); + fusion.addInput(tv0); + fusion.addInput(tv1); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + auto tv2 = matmul(tv0, tv1, layout); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(layout); + fusion.addOutput(tv2); - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.smem_double_buffer_stage = 4; - scheduleMatmul(tv2, tv0, tv1, params); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); - at::manual_seed(0); - auto inputs = fp16MatmulAtInput(M, N, K, layout); + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(layout); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, - {inputs.first, inputs.second}, - LaunchParams(), - matmul_cparams)); - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); - } + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.smem_double_buffer_stage = 4; + params.rasterization_order = order; + scheduleMatmul(tv2, tv0, tv1, params); + + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); + + FusionExecutor fe; + fe.setMeasureKernelTimeFlag(true); + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, {inputs.first, inputs.second}, LaunchParams(), cparams)); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + std::cout << fe.kernelTimeMs() << std::endl; + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.01, 0.01)); + } } // Matmul test for Ampere MMA: with pipelined gmem load