From 6d5dfdb016f7c4e3e390d6d238adc1a0615e56cb Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 3 Apr 2023 09:45:37 -0700 Subject: [PATCH 1/2] patching ABI compatibility (#112) Fixes #103 --- CMakeLists.txt | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5044a890bcb..9041f812501 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,13 +14,14 @@ endif () set(NVFUSER_ROOT ${PROJECT_SOURCE_DIR}) set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc") if (PROJECT_IS_TOP_LEVEL) - message(STATUS "top-level build") find_package(Torch REQUIRED) find_package(Python REQUIRED Development Interpreter) find_package(pybind11 REQUIRED) # need this since the pytorch execution uses a different name set(PYTHON_EXECUTABLE ${Python_EXECUTABLE}) set(ATEN_CUDA_ROOT "${TORCH_INSTALL_PREFIX}/include/ATen") + # CXX flags is necessary since https://github.com/pytorch/pytorch/issues/98093 + string(APPEND CMAKE_CXX_FLAGS ${TORCH_CXX_FLAGS}) if(BUILD_NVFUSER_BENCHMARK) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/googletest) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/benchmark) @@ -198,8 +199,6 @@ target_include_directories(${NVFUSER_CODEGEN} PUBLIC ) set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD 17) -message(STATUS "install_lib_dir: ${TORCH_INSTALL_LIB_DIR}") -message(STATUS "cmake_install_includedir: ${CMAKE_INSTALL_INCLUDEDIR}") if (PROJECT_IS_TOP_LEVEL) target_link_libraries(${NVFUSER_CODEGEN} PRIVATE torch ${TORCH_LIBRARIES}) # TODO: setup header and lib installation @@ -272,7 +271,6 @@ if(BUILD_PYTHON) target_compile_definitions(${NVFUSER} PRIVATE EXTENSION_NAME=_C) if (PROJECT_IS_TOP_LEVEL) - message(STATUS "skipping the rest of the libs") target_compile_options(${NVFUSER} PRIVATE -Wall -Wno-unused-function) target_link_libraries(${NVFUSER} PRIVATE ${TORCH_LIBRARIES}) target_link_libraries(${NVFUSER} PRIVATE "${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so") @@ -359,7 +357,6 @@ if(BUILD_TEST) 
target_include_directories(${NVFUSER_TESTS} PRIVATE "${NVFUSER_ROOT}") if (PROJECT_IS_TOP_LEVEL) - message(STATUS "skipping the rest of the libs") target_compile_options(${NVFUSER_TESTS} PRIVATE -Wall -Wno-unused-function) target_link_libraries(${NVFUSER_TESTS} PRIVATE ${TORCH_LIBRARIES}) else() # PROJECT_IS_TOP_LEVEL From c0cfc8fe54c605fd2e1399ebe9226fa4b319f68e Mon Sep 17 00:00:00 2001 From: Andrzej Bekas <118676880+drzejan2@users.noreply.github.com> Date: Thu, 16 Mar 2023 04:37:09 -0700 Subject: [PATCH 2/2] Segmenter - support for matmul scheduler (#23) - connect matmul scheduler in segmenter with implementation of matmul scheduling structures, - update matmul params to store key items needed for matmul scheduling (based on prototype param structure), - add matmul compile/runtime checks in separete source file, - apply improvement in matmul instruction scheduling with loop rotation (changes from #2488), - add initial heuristicis for matmul scheduler in segmenter, - add implementation of helper functions for matmul heuristics, - add dedicated debug logger, - add tests for checking matmul schedule integration with segmenter, - fix documentation - add code for calculating index mode, - add code for calculating problem shape, - fix clang-tidy warnings in modified source files, - add guards for MmaOps in schedulers other than matmul, --- CMakeLists.txt | 3 +- benchmark/matmul.cpp | 18 +- csrc/ir_nodes.cpp | 11 + csrc/ir_utils.cpp | 11 + csrc/ir_utils.h | 2 + csrc/iter_visitor.h | 16 +- csrc/mma_type.cpp | 83 +++- csrc/mma_type.h | 20 +- csrc/scheduler/all_schedulers.h | 4 +- csrc/scheduler/matmul.cpp | 87 ++-- csrc/scheduler/matmul.h | 66 +--- csrc/scheduler/matmul_heuristic.h | 160 ++++++++ csrc/scheduler/matmul_utils.cpp | 634 ++++++++++++++++++++++++++++++ csrc/scheduler/matmul_utils.h | 37 ++ csrc/scheduler/registry.cpp | 134 ++++++- csrc/scheduler/registry.h | 8 + csrc/utils.cpp | 3 +- csrc/utils.h | 2 + test/test_gpu_matmul_sass.cpp | 11 +- 
test/test_gpu_tensorcore.cpp | 233 +++++++---- test/test_utils.cpp | 92 ++++- test/test_utils.h | 24 ++ 22 files changed, 1420 insertions(+), 239 deletions(-) create mode 100644 csrc/scheduler/matmul_heuristic.h create mode 100644 csrc/scheduler/matmul_utils.cpp create mode 100644 csrc/scheduler/matmul_utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 9041f812501..b792e0b565a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -152,10 +152,11 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp ${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp + ${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp + ${NVFUSER_SRCS_DIR}/scheduler/matmul_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp - ${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp ${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/registry.cpp ${NVFUSER_SRCS_DIR}/scheduler/utils.cpp diff --git a/benchmark/matmul.cpp b/benchmark/matmul.cpp index e12ce0ab751..10c32fecdaf 100644 --- a/benchmark/matmul.cpp +++ b/benchmark/matmul.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include @@ -128,7 +129,7 @@ std::pair fp16MatmulAtInput( // TODO: separate compute and schedule definition once the can schedule // logic and pattern matching is ready. 
-void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParam params) { +void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParams params) { // Only hgemm on the initial setup auto a = makeContigTensor(2, DataType::Half); auto b = makeContigTensor(2, DataType::Half); @@ -139,7 +140,7 @@ void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParam params) { fusion->addInput(b); fusion->addOutput(c); - scheduleMatmul(c, a, b, params); + scheduleMatmul(fusion, params); } void checkMatch(at::Tensor expect, at::Tensor result, int64_t k) { @@ -197,7 +198,7 @@ void checkMatch(at::Tensor expect, at::Tensor result, int64_t k) { static void SingleMatmulBase( benchmark::State& benchmark_state, MatmulLayout layout, - MatmulParam params) { + MatmulParams params) { std::vector input_mnk{ benchmark_state.range(0), benchmark_state.range(1), @@ -288,7 +289,7 @@ size_t getSmemSize(GemmTile cta_tile, int stage_number) { } // TODO: this part eventually will be automated by heuristics -MatmulParam getMatmulParams( +MatmulParams getMatmulParams( GemmTile cta_tile, int stage_number, MatmulLayout layout) { @@ -298,12 +299,9 @@ MatmulParam getMatmulParams( gemm_tile.warp_tile = GemmTile(64, 64, cta_tile.k); gemm_tile.instruction_tile = GemmTile(16, 16, 16); - // Collect mma swizzle info - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_16_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; diff --git a/csrc/ir_nodes.cpp b/csrc/ir_nodes.cpp index 48c668d816b..8a3c8aee7de 100644 --- a/csrc/ir_nodes.cpp +++ b/csrc/ir_nodes.cpp @@ -1351,6 +1351,17 @@ MmaOp::MmaOp( in_b->getValType().value() == ValType::TensorIndex, in_b->getValType().value()); + const auto isBroadcastIn = [](const Val* val) { + if 
(val->getValType().value() == ValType::TensorView) { + const auto* tv = val->as(); + return tv->hasBroadcast(); + } + return true; + }; + + TORCH_INTERNAL_ASSERT(isBroadcastIn(in_a)); + TORCH_INTERNAL_ASSERT(isBroadcastIn(in_b)); + addOutput(out); addInput(in_a); addInput(in_b); diff --git a/csrc/ir_utils.cpp b/csrc/ir_utils.cpp index 562da480afa..cba2c1557b8 100644 --- a/csrc/ir_utils.cpp +++ b/csrc/ir_utils.cpp @@ -438,6 +438,17 @@ std::vector getSelectOps(Fusion* fusion) { return select_ops; } +std::vector getMmaOps(Fusion* fusion) { + std::vector mma_ops; + for (auto expr : fusion->exprs()) { + if (expr->isA()) { + mma_ops.push_back(expr->as()); + } + } + + return mma_ops; +} + namespace { class ValReplacementMutator : private OptOutMutator { diff --git a/csrc/ir_utils.h b/csrc/ir_utils.h index 9f037d8a7be..d2a5ddf10b5 100644 --- a/csrc/ir_utils.h +++ b/csrc/ir_utils.h @@ -325,6 +325,8 @@ TORCH_CUDA_CU_API std::vector getIndexSelectOps(Fusion* fusion); TORCH_CUDA_CU_API std::vector getTorchGatherOps(Fusion* fusion); +TORCH_CUDA_CU_API std::vector getMmaOps(Fusion* fusion); + TORCH_CUDA_CU_API std::vector getSelectOps(Fusion* fusion); // Returns the initialization value of tv or nullptr if not initialized. diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 4e961088197..b17a8a52be8 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -132,7 +132,7 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { // registered outputs. void traverse(Fusion* fusion); - // Same as traverse put it traverses every edge, meaning it will traverse + // Same as traverse but it traverses every edge, meaning it will traverse // values more than once. 
void traverseAllPaths(Fusion* fusion); @@ -147,8 +147,8 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { }; /* - * Backward visitor IterVisitor calls handle in reverse order from outputs - * to inputs It would be really nice to unify this with IterVisitor, however, + * Backward visitor calls handle in reverse order from outputs to inputs. + * It would be really nice to unify this with IterVisitor, however, * the challenge there is that we specify traversal from outputs towards inputs * because it implicitly provides DCE. However, if users are not careful, they * could miss necessary outputs to do a backward traversal. @@ -163,10 +163,10 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { * outputs of some exprs, example being the `N` output of welford ops. * `must_cover_all_expr_outputs` is added to disable the check, and in * this case the visitor pass need be aware - * 1. Exprs with any output that has a use chain that ends with a final - * consumer in the `from` list `will be` visited. - * 2. Vals that doesn't have a use chain that ends with a final - * consumer in the `from` list `will not be` visited, even though its + * 1. Exprs in the `from` list with any output that has a use chain that + * ends with a final consumer `will be` visited. + * 2. Vals in the `from` list that doesn't have a use chain that ends with + * a final consumer `will not be` visited, even though its * definition expr might be visited. An example is if the `N` output * of an welford op is unused, but other outputs are, the welford op * will be visited but the `N` output will not. @@ -302,7 +302,7 @@ class StmtSort : public IterVisitor { bool traverse_members = false, bool traverse_attributes = false); - // Returns ordered Statements required to produce from, including from. + // Returns ordered Statements required to produce 'to', including 'to'. 
static std::vector getStmts( Fusion* fusion, const std::vector& to, diff --git a/csrc/mma_type.cpp b/csrc/mma_type.cpp index 8367233c07e..5b290f30d4f 100644 --- a/csrc/mma_type.cpp +++ b/csrc/mma_type.cpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace nvfuser { @@ -124,7 +125,8 @@ bool isTuring(MmaOptions::MacroType macro) { } bool isAmpere(MmaOptions::MacroType macro) { - return macro == MmaOptions::MacroType::Ampere_16_8_16 || + return macro == MmaOptions::MacroType::Ampere_16_8_8 || + macro == MmaOptions::MacroType::Ampere_16_8_16 || macro == MmaOptions::MacroType::Ampere_16_16_16; } @@ -134,11 +136,9 @@ int getOutputRegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Ampere_16_16_16: case MmaOptions::MacroType::Turing_16_16_16: return 8; - break; case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: return 4; - break; default: TORCH_INTERNAL_ASSERT(false, "unknown macro"); break; @@ -150,13 +150,11 @@ int getInputARegisterSize(MmaOptions::MacroType macro) { switch (macro) { case MmaOptions::MacroType::Volta_16_16_4: return 4; - break; case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Turing_16_16_16: case MmaOptions::MacroType::Ampere_16_8_16: case MmaOptions::MacroType::Ampere_16_16_16: return 8; - break; default: TORCH_INTERNAL_ASSERT(false, "unknown macro"); break; @@ -168,7 +166,6 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) { switch (macro) { case MmaOptions::MacroType::Volta_16_16_4: return 4; - break; case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: return 4; @@ -196,6 +193,25 @@ bool isOperandTransposed(MmaOptions options) { return false; } +GemmTile getMmaOpShape(MmaOptions::MacroType macro) { + switch (macro) { + case MmaOptions::MacroType::Volta_16_16_4: + return {16, 16, 4}; + case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: + return {16, 8, 16}; + case 
MmaOptions::MacroType::Turing_16_16_16: + case MmaOptions::MacroType::Ampere_16_16_16: + return {16, 16, 16}; + case MmaOptions::MacroType::Ampere_16_8_8: + return {16, 8, 8}; + case MmaOptions::MacroType::NoMMA: + return {1, 1, 1}; + } + + TORCH_INTERNAL_ASSERT(false, "unknown MMA macro"); +} + std::string toString(MmaOptions::MmaInputLayout input_layout) { std::stringstream ss; switch (input_layout) { @@ -238,4 +254,59 @@ std::string toString(MmaOptions::MacroType mt) { return ss.str(); } +std::string toString(const GemmTile& tile) { + std::stringstream ss; + ss << "[" << tile.m << ", " << tile.n << ", " << tile.k << "]"; + return ss.str(); +} + +std::string toString(const MatMulTileOptions& opts) { + std::stringstream ss; + ss << "MatMulTileOptions: " + << "instruction tile " << toString(opts.instruction_tile) << ", " + << "warp tile " << toString(opts.warp_tile) << ", " + << "CTA tile " << toString(opts.cta_tile); + return ss.str(); +} + +std::string toString(MmaOptions::MacroType mt, bool) { + switch (mt) { + case MmaOptions::MacroType::Ampere_16_8_8: + return "Ampere_16_8_8"; + case MmaOptions::MacroType::Ampere_16_8_16: + return "Ampere_16_8_16"; + case MmaOptions::MacroType::Ampere_16_16_16: + return "Ampere_16_16_16"; + case MmaOptions::MacroType::NoMMA: + return "NoOp"; + case MmaOptions::MacroType::Turing_16_8_16: + return "Turing_16_8_16"; + case MmaOptions::MacroType::Turing_16_16_16: + return "Turing_16_16_16"; + case MmaOptions::MacroType::Volta_16_16_4: + return "Volta_16_16_4"; + } + TORCH_INTERNAL_ASSERT(false, "Unsupported mma type"); + return "Unsupported"; +} + +size_t hash(MmaOptions::MacroType macro) { + return std::hash{}(static_cast(macro)); +} + +size_t hash(MmaOptions::MmaInputLayout input_layout) { + return std::hash{}(static_cast(input_layout)); +} + +size_t hash(const GemmTile& tile) { + return std::hash{}( + (static_cast(tile.m) << 32) + + (static_cast(tile.n) << 16) + (static_cast(tile.k))); +} + +size_t hash(const MatMulTileOptions& 
opts) { + return (hash(opts.instruction_tile) << 0) ^ (hash(opts.warp_tile) << 1) ^ + (hash(opts.cta_tile) << 2); +} + } // namespace nvfuser diff --git a/csrc/mma_type.h b/csrc/mma_type.h index 74fdf9d8a75..a642166ef39 100644 --- a/csrc/mma_type.h +++ b/csrc/mma_type.h @@ -16,15 +16,15 @@ struct GemmTile { int m, n, k; GemmTile(int m_, int n_, int k_) : m(m_), n(n_), k(k_) {} - bool operator==(const GemmTile& other) { + bool operator==(const GemmTile& other) const { return m == other.m && n == other.n && k == other.k; } - GemmTile operator/(const GemmTile& other) { + GemmTile operator/(const GemmTile& other) const { return GemmTile(m / other.m, n / other.n, k / other.k); } - std::vector toVector() { + std::vector toVector() const { return {m, n, k}; } }; @@ -186,9 +186,19 @@ int getOutputRegisterSize(MmaOptions::MacroType macro); int getInputARegisterSize(MmaOptions::MacroType macro); int getInputBRegisterSize(MmaOptions::MacroType macro); +// Unpack MMA op shape +GemmTile getMmaOpShape(MmaOptions::MacroType macro); + // MMA stringify utils std::string toString(MmaOptions::MacroType macro); std::string toString(MmaOptions::MmaInputLayout input_layout); -std::string toString(MmaOptions::MacroType mt); - +std::string toString(const GemmTile& tile); +std::string toString(const MatMulTileOptions& opts); +std::string toString(MmaOptions::MacroType macro, bool); + +// MMA hash utils +size_t hash(MmaOptions::MacroType macro); +size_t hash(MmaOptions::MmaInputLayout input_layout); +size_t hash(const GemmTile& tile); +size_t hash(const MatMulTileOptions& opts); } // namespace nvfuser diff --git a/csrc/scheduler/all_schedulers.h b/csrc/scheduler/all_schedulers.h index 9122c749117..b9774065280 100644 --- a/csrc/scheduler/all_schedulers.h +++ b/csrc/scheduler/all_schedulers.h @@ -6,6 +6,7 @@ */ // clang-format on #pragma once +#include #include #include #include @@ -19,7 +20,8 @@ enum class TORCH_CUDA_CU_API ScheduleHeuristic { PointWise, Reduction, Persistent, - Transpose + 
Transpose, + Matmul }; } // namespace nvfuser diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 660b4a291a7..c6575b2ff34 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -7,8 +7,15 @@ // clang-format on #include #include +#include #include +// NOTE: included to avoid compilation error caused by missing destructor in +// 'SchedulerRuntimeInfo' +#include + +#include + namespace nvfuser { namespace { @@ -46,14 +53,35 @@ void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) { } // namespace -void scheduleMatmul( - TensorView* c, - TensorView* a, - TensorView* b, - MatmulParam& params) { - // Unpack from params. - auto& mma_builder = params.mma_builder; - auto& gemm_tile = params.tile_sizes; +void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { + const auto& inputs = fusion->inputs(); + const auto& outputs = fusion->outputs(); + + TORCH_INTERNAL_ASSERT( + inputs.size() == 2, + "scheduleMatmul supports only fusions with two inputs"); + TORCH_INTERNAL_ASSERT( + outputs.size() == 1, + "scheduleMatmul supports only fusions with single output"); + + TORCH_INTERNAL_ASSERT( + inputs[0]->isA(), + "fusion's first inpus is not an instance of TensorView class"); + TORCH_INTERNAL_ASSERT( + inputs[1]->isA(), + "fusion's second inpus is not an instance of TensorView class"); + TORCH_INTERNAL_ASSERT( + outputs[0]->isA(), + "fusion's output is not an instance of TensorView class"); + + TensorView* a = inputs[0]->as(); + TensorView* b = inputs[1]->as(); + TensorView* c = outputs[0]->as(); + + // Collect mma swizzle info + auto mma_builder = + MmaBuilder(params.mma_op, params.tile_sizes).layout(params.layout); + const auto& gemm_tile = params.tile_sizes; // Including current tensor naming convention for reference, // this is very temporary and will change over time and @@ -83,16 +111,12 @@ void scheduleMatmul( // Currently only support a, b, c as fusion inputs/outputs // aka. no prolog and epilog fusion yet. 
- TORCH_CHECK( - c->isFusionOutput() && a->isFusionInput() && b->isFusionInput(), - "not supporting matmul fusion yet"); - TORCH_CHECK(c->definition() && c->definition()->isA()); mma_builder.configureMma(c); // TODO: // Beyond this point, mma_builder really just becomes a populated - // list of parameters to describes the mma swizzles that should + // list of parameters to describe the mma swizzles that should // be annotated on the tensor domain. Conceptually the mma builder // object should be separated to 2 parts, one as scheduler utility // and the other as matmul heuristic parameters, which we are @@ -106,8 +130,7 @@ void scheduleMatmul( auto cc = c->cacheBefore(); // Get the input to the mma op. - auto mma = dynamic_cast(cc->definition()); - TORCH_INTERNAL_ASSERT(mma != nullptr); + auto mma = cc->definition()->as(); auto ab = mma->inA()->as(); auto bb = mma->inB()->as(); @@ -184,8 +207,7 @@ void scheduleMatmul( // Swizzle block tiles: if (params.grid_swizzle_factor != 1) { int factor = std::max(1, params.grid_swizzle_factor); // must be >=1 - if (params.rasterization_order == - MatmulParam::TileRasterizationOrder::RowMajor) { + if (params.cta_order == MatmulParams::TileRasterizationOrder::RowMajor) { cc->split(1, factor); // [I1, I2/factor, factor] cc->reorder({{1, 2}}); @@ -193,8 +215,7 @@ void scheduleMatmul( cc->merge(0); // [I1*factor, I2/factor] } else if ( - params.rasterization_order == - MatmulParam::TileRasterizationOrder::ColumnMajor) { + params.cta_order == MatmulParams::TileRasterizationOrder::ColumnMajor) { cc->split(0, factor); // [I1/factor, factor, I2] cc->reorder({{1, 2}}); @@ -311,18 +332,18 @@ void scheduleMatmul( // 0 1 2 3 4 5 6 7 8 9 10 // [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)] - if (params.rasterization_order == - MatmulParam::TileRasterizationOrder::RowMajor) { - cc->axis(0)->parallelize(ParallelType::BIDx); - cc->axis(1)->parallelize(ParallelType::BIDy); - } else if ( - params.rasterization_order == - 
MatmulParam::TileRasterizationOrder::ColumnMajor) { - cc->axis(0)->parallelize(ParallelType::BIDy); - cc->axis(1)->parallelize(ParallelType::BIDx); - } else { - TORCH_CHECK( - false, "Invalid TileRasterizationOrder passed to Matmul scheduler"); + switch (params.cta_order) { + case MatmulParams::TileRasterizationOrder::RowMajor: + cc->axis(0)->parallelize(ParallelType::BIDx); + cc->axis(1)->parallelize(ParallelType::BIDy); + break; + case MatmulParams::TileRasterizationOrder::ColumnMajor: + cc->axis(0)->parallelize(ParallelType::BIDy); + cc->axis(1)->parallelize(ParallelType::BIDx); + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Invalid TileRasterizationOrder passed to Matmul scheduler"); } cc->axis(4)->parallelize(ParallelType::TIDz); @@ -330,11 +351,11 @@ void scheduleMatmul( // Propagate mma output swizzle and parallelization down the DAG if (params.double_buffer_options.double_buffer_smem_write) { - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( params.double_buffer_options.smem_double_buffer_stage > 1, "Invalid buffer stage config") if (params.double_buffer_options.smem_double_buffer_stage > 2) { - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( params.async_gmem_load_operands, "Circular buffer only supports async load"); } diff --git a/csrc/scheduler/matmul.h b/csrc/scheduler/matmul.h index e2e50e4333b..6e81abb185e 100644 --- a/csrc/scheduler/matmul.h +++ b/csrc/scheduler/matmul.h @@ -11,72 +11,12 @@ #include #include +#include namespace nvfuser { -//! Starting point for a matmul scheduler parameters: -class MatmulParam { - public: - MatmulParam(MmaBuilder builder) : mma_builder(builder) {} - - struct DoubleBufferOptions { - bool double_buffer_smem_write = false; - bool double_buffer_smem_read = false; - int smem_double_buffer_stage = 2; - }; - - //! Whether to rotate the ldmatrix out of the main loop - bool rotate_ldmatrix_out_of_main_loop = true; - - //! (Ampere+) Use cp.async to load operands. - bool async_gmem_load_operands = false; - - //! 
Specifies the tiling hierarchy on block, - //! warp, and instruction levels. - MatMulTileOptions tile_sizes; - - //! Parameters for configuring mma ops. - MmaBuilder mma_builder; - - //! Specify which tensor we double buffer. - DoubleBufferOptions double_buffer_options; - - //! Configurable rasterization/parallelization order. - //! Depending on the problem shape, switching blockIdx.x and blockIdx.y can - //! help improve L2 hit rate. - enum class TileRasterizationOrder { - RowMajor = 0, - ColumnMajor = 1 - } rasterization_order = TileRasterizationOrder::RowMajor; - - //! Swizzle factor is used to increase L2 hit rate. - //! It horizontally squeezes the grid so that gridDim.x is larger and - //! gridDim.y is smaller. - //! We rely on the observation that the CTAs are scheduled by the GPU by - //! iterating on gridDim.x first. As a result, as blocks are launched, they - //! will more likely be forming sub-tiles of the C matrix. This will increase - //! L2 hit rate/data reuse of A and B. - //! - //! Eg for grid_swizzle_factor=2: - //! A1 A2 B1 B2 --> A1 A2 A3 A4 B1 B2 B3 B4 - //! A3 A4 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4 - //! C1 C2 D1 D2 - //! C3 C4 D3 D4 - int grid_swizzle_factor = 1; -}; - -//! Prototype auto scheduling function. -//! Currently only support a pure matmul with no -//! fused prolog or epilog. -//! -//! TODO: -//! - will support a range of fusions in a follow up -//! - will formalize scheduling decisions into -//! matmul params data structure. TORCH_CUDA_CU_API void scheduleMatmul( - TensorView* c_tv, - TensorView* a_tv, - TensorView* b_tv, - MatmulParam& params); + Fusion* fusion, + const MatmulParams& params); } // namespace nvfuser diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h new file mode 100644 index 00000000000..fd452630e4b --- /dev/null +++ b/csrc/scheduler/matmul_heuristic.h @@ -0,0 +1,160 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. 
+ * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include +#include +#include + +#include +#include "type.h" + +namespace nvfuser { + +// Parameters of the matmul heuristic to describe the optimial schedule. +class MatmulParams : public HeuristicParams { + public: + //! A list of possible strategies used to define along which axis + //! parallelization will be done. + enum class TileRasterizationOrder { RowMajor = 0, ColumnMajor = 1 }; + + //! A wrapper for double buffering config pieces + struct DoubleBufferOptions { + bool double_buffer_smem_write = false; + bool double_buffer_smem_read = false; + int smem_double_buffer_stage = 2; + + bool operator==(const DoubleBufferOptions& other) const { + return other.double_buffer_smem_write == double_buffer_smem_write && + other.double_buffer_smem_read == double_buffer_smem_read && + other.smem_double_buffer_stage == smem_double_buffer_stage; + } + + std::string toString() const { + std::stringstream ss; + ss << "DoubleBufferOptions:\n" + << " double_buffer_smem_write: " + << (double_buffer_smem_write ? "true" : "false") << "\n" + << " double_buffer_smem_read: " + << (double_buffer_smem_read ? "true" : "false") << "\n" + << " smem_double_buffer_stage: " << smem_double_buffer_stage; + return ss.str(); + } + + size_t hash() const { + return std::hash{}( + (static_cast(smem_double_buffer_stage) << 2) | + (static_cast(double_buffer_smem_write)) << 1) | + (static_cast(double_buffer_smem_read)); + } + }; + + //! Whether to rotate the ldmatrix out of the main loop + bool rotate_ldmatrix_out_of_main_loop = true; + + //! (Ampere+) Use cp.async to load operands. + bool async_gmem_load_operands = false; + + //! Specifies the tiling hierarchy on block, + //! warp, and instruction levels. + MatMulTileOptions tile_sizes = {}; + + //! Specify the type of MMA op to be used in generated kernel. 
+ MmaOptions::MacroType mma_op = MmaOptions::MacroType::NoMMA; + + //! Specify the input layout of input tensors. + MmaOptions::MmaInputLayout layout = + static_cast(-1); + + //! Specify CTA rastrization order. + TileRasterizationOrder cta_order = TileRasterizationOrder::RowMajor; + + //! Specify which tensor we double buffer. + DoubleBufferOptions double_buffer_options = {}; + + //! Swizzle factor is used to increase L2 hit rate. + //! It horizontally squeezes the grid so that gridDim.x is larger and + //! gridDim.y is smaller. + //! We rely on the observation that the CTAs are scheduled by the GPU by + //! iterating on gridDim.x first. As a result, as blocks are launched, they + //! will more likely be forming sub-tiles of the C matrix. This will increase + //! L2 hit rate/data reuse of A and B. + //! + //! Eg for grid_swizzle_factor=2: + //! A1 A2 B1 B2 --> A1 A2 A3 A4 B1 B2 B3 B4 + //! A3 A4 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4 + //! C1 C2 D1 D2 + //! C3 C4 D3 D4 + int grid_swizzle_factor = 1; + + std::string toString() const override { + std::stringstream ss; + ss << "\n===== Matmul Parameters ========\n" + << (tag.empty() ? "" : "Tag: ") << tag << "\n" + << "MMA op: " << nvfuser::toString(mma_op, true) << "\n" + << "Layout: " << nvfuser::toString(layout) << "\n" + << double_buffer_options.toString() << "\n" + << nvfuser::toString(tile_sizes) << "\n" + << "Rotate ldmatrix out of main loop: " + << (rotate_ldmatrix_out_of_main_loop ? "true" : "false") << "\n" + << "Async global mem load: " + << (async_gmem_load_operands ? "true" : "false") << "\n" + << "Indexing mode: " + << "Tile rastrization order: " + << ((cta_order == TileRasterizationOrder::RowMajor) ? "row-major" + : "column-major") + << "Grid swizzle factor: " << grid_swizzle_factor + << (cparams.index_type.has_value() + ? (cparams.index_type.value() == PrimDataType::Int ? 
"int64_t" + : "int32_t") + : "unavailable") + << "\n" + << "====================================\n"; + return ss.str(); + } + + size_t hash() const override { + // combine boolean flags for hashing + size_t attr_hash = + (static_cast(rotate_ldmatrix_out_of_main_loop) << 1) | + (static_cast(async_gmem_load_operands)); + + // combined hash + attr_hash = std::hash{}(attr_hash) ^ (nvfuser::hash(mma_op) << 1) ^ + (nvfuser::hash(layout) << 2) ^ (double_buffer_options.hash() << 3) ^ + (nvfuser::hash(tile_sizes) << 4) ^ + (std::hash{}(static_cast(cta_order)) << 5) ^ + (std::hash{}(grid_swizzle_factor) << 6); + return attr_hash; + } + + bool sameAs( + const std::shared_ptr& other_base) const override { + auto other_casted = std::dynamic_pointer_cast(other_base); + if (other_casted == nullptr) { + return false; + } + + return other_casted->layout == layout && other_casted->mma_op == mma_op && + other_casted->async_gmem_load_operands == async_gmem_load_operands && + other_casted->rotate_ldmatrix_out_of_main_loop == + rotate_ldmatrix_out_of_main_loop && + other_casted->tile_sizes == tile_sizes && + other_casted->double_buffer_options == double_buffer_options && + other_casted->cta_order == cta_order && + other_casted->grid_swizzle_factor == grid_swizzle_factor; + } + + std::shared_ptr clone() const override { + return std::make_shared(*this); + } +}; + +} // namespace nvfuser diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp new file mode 100644 index 00000000000..068ef4fad22 --- /dev/null +++ b/csrc/scheduler/matmul_utils.cpp @@ -0,0 +1,634 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +// NOTE: included to avoid compilation error caused by missing destructor in +// 'SchedulerRuntimeInfo' +#include +#include +#include +#include +#include +#include +#include +#include +#include "ATen/cuda/CUDAContext.h" +#include "c10/util/Optional.h" +#include "ir_base_nodes.h" +#include "ir_interface_nodes.h" +#include "ir_internal_nodes.h" +#include "ir_utils.h" +#include "mma_type.h" +#include "type.h" +#include "utils.h" + +namespace nvfuser { +namespace { + +using MatmulLayout = MmaOptions::MmaInputLayout; +using LayoutData = + std::pair, c10::optional>; +using TensorShape = std::vector; +using ProblemShape = TensorShape; + +//! A constant with position of M value (a number of columns in A tensor for TT +//! layout) in problem in ProblemShape type. +constexpr size_t M_POS = 0; +//! A constant with position of N value (a number of rows in B tensor for TT +//! layout) in problem in ProblemShape type. +constexpr size_t N_POS = 1; +//! A constant with position of K value (a number of rows in A tensor for TT +//! layout) in problem in ProblemShape type. +constexpr size_t K_POS = 2; +//! A constant with expected number of dimensions in ProblemShape type. +constexpr size_t PROBLEM_DIMS = 3; + +// TODO: helpers to be moved to 'iter_visitor.h' +std::deque> getAllDepndencyChains( + const std::vector& producers, + const std::vector& consumers) { + std::deque> all_paths; + for (auto* consumer : consumers) { + for (auto* producer : producers) { + auto paths = DependencyCheck::getAllDependencyChains(producer, consumer); + if (paths.empty()) { + continue; + } + all_paths.insert( + all_paths.end(), + std::make_move_iterator(paths.begin()), + std::make_move_iterator(paths.end())); + } + } + + return all_paths; +} + +//! A wrapper for printing debug details. +void printMsg(const std::string& msg) { + std::cout << msg << std::endl; +} + +//! 
A helper for deciding what kernel indexing mode use (int32_t or int64_t). +//! TODO: add strides to handle non-continous tensors +PrimDataType getIndexType(const ProblemShape& problem_shape) { + // based on collectIndexMode function + constexpr int64_t most_positive_int32_index = + std::numeric_limits::max() / 2; + + const auto m = static_cast(problem_shape[M_POS]); + const auto n = static_cast(problem_shape[N_POS]); + const auto k = static_cast(problem_shape[K_POS]); + + const bool use_i64_index = m * k > most_positive_int32_index || // tensor A + k * n > most_positive_int32_index || // tensor B + m * n > most_positive_int32_index; // output tensor + + return use_i64_index ? PrimDataType::Int : PrimDataType::Int32; +} + +//! A helper for deciding the type of MMA op for given fusion and problem shape. +inline c10::optional getMmaOp( + const int dev_version, + const ProblemShape& problem) { + using MacroType = MmaOptions::MacroType; + + TORCH_INTERNAL_ASSERT( + problem.size() == PROBLEM_DIMS, + "Invalid size of problem shape (number of dimensions)"); + + // NOTE: A temp condition + const bool use_small_n = + ((problem[N_POS] % 8) == 0) && ((problem[N_POS] % 16) != 0); + + switch (dev_version) { + case 70: + return MacroType::Volta_16_16_4; + case 75: + return (use_small_n) ? MacroType::Turing_16_8_16 + : MacroType::Turing_16_16_16; + case 80: + return (use_small_n) ? MacroType::Ampere_16_8_16 + : MacroType::Ampere_16_16_16; + default: + break; + } + return c10::nullopt; +} + +//! A helper for checking if layout of MMA op's inputs. It will return optional +//! message if check fails. 
+LayoutData getInputsLayout(const MmaOp* mma_expr) { + std::stringstream ss; + const auto& mmaExprInputs = mma_expr->inputs(); + + const auto* in_A = mmaExprInputs[0]->as(); + const auto* in_B = mmaExprInputs[1]->as(); + + // The number of IterDomains of MMA inputs must be the same + if (in_A->nDims() != in_B->nDims()) { + ss << "Mma op inputs don't have the same number of IterDomains, 1st input(" + << std::to_string(in_A->nDims()) << "), 2nd input(" + << std::to_string(in_B->nDims()) + ")"; + return {c10::nullopt, ss.str()}; + } + + // The currently supported number of IterDomains per MMA op input is 3 + constexpr size_t supportedDims = 3; + if (in_A->nDims() != supportedDims) { + ss << "Mma op inputs have unsupported number of IterDomains, got: " + << std::to_string(in_A->nDims()) << ", expected " + << std::to_string(supportedDims); + return {c10::nullopt, ss.str()}; + } + + using AxisPos = decltype(std::declval().nDims()); + constexpr AxisPos unInitPos = -1; + AxisPos bcastInApos = unInitPos; + AxisPos bcastInBpos = unInitPos; + + // The first and the second input of MMA have the same number of + // IterDomains + for (AxisPos pos = 0; pos < in_A->nDims(); ++pos) { + if (in_A->axis(static_cast(pos))->isBroadcast()) { + if (bcastInApos != unInitPos) { + ss << "Mma op first input has more than one broadcast IterDomain: " + << std::to_string(bcastInApos) << " and " << std::to_string(pos); + return {c10::nullopt, ss.str()}; + } + bcastInApos = pos; + } + if (in_B->axis(static_cast(pos))->isBroadcast()) { + if (bcastInBpos != unInitPos) { + ss << "Mma op second input has more than one broadcast IterDomain: " + << std::to_string(bcastInBpos) << " and " << std::to_string(pos); + return {c10::nullopt, ss.str()}; + } + bcastInBpos = pos; + } + } + + // MMA inputs need to have broadcast IterDomains + if (bcastInApos == unInitPos || bcastInBpos == unInitPos) { + ss << "The " << (bcastInApos == unInitPos ? 
"first" : "second") + << " mma op has no broadcast IterDomain"; + return {c10::nullopt, ss.str()}; + } + + // MMA inputs must have supported data layout, defined in MatmulLayout + // MatmulLayout::TT + if (bcastInApos == static_cast(2) && + bcastInBpos == static_cast(0)) { + return {MatmulLayout::TT, c10::nullopt}; + } + // MatmulLayout::TN + if (bcastInApos == static_cast(1) && + bcastInBpos == static_cast(0)) { + return {MatmulLayout::TN, c10::nullopt}; + } + // MatmulLayout::NT + if (bcastInApos == static_cast(2) && + bcastInBpos == static_cast(1)) { + return {MatmulLayout::NT, c10::nullopt}; + } + + ss << "Unsupported layout, broadcasts: inputA(" << bcastInApos << "), inputB(" + << bcastInBpos << ")"; + return {c10::nullopt, ss.str()}; +} + +//! A wrapper for core heuristics initialization +inline bool initCoreHeuristics( + std::shared_ptr params, + const MmaOptions::MacroType& mma_op, + const MatmulLayout& layout, + const ProblemShape& problem_shape) { + const GemmTile instruction_tile = getMmaOpShape(mma_op); + GemmTile warp_tile = {-1, -1, -1}; + GemmTile cta_tile = {-1, -1, -1}; + + using DimType = decltype(GemmTile::m); + + // warp tile shape + { + if (isAmpere(mma_op)) { + // Initial target: + // - 1 MMA ops per thread in a warp (32 threads), warp tile should be + // then 32x bigger than instruction tile, + // - start with [4, 4, 2] shape, later it should depend on problem + // shape and have bigger impact on CTA tile shape + + const DimType m_ratio = 4; + const DimType n_ratio = 4; + const DimType k_ratio = 2; + + warp_tile = { + instruction_tile.m * m_ratio, + instruction_tile.n * n_ratio, + instruction_tile.k * k_ratio}; + } else { + // No support for Volta and Turing + return false; + } + } + + // cta tile shape + { + // Initial target: + // - 4 warp tiles per CTA + // - CTA k-dim should be same as warp tile k-dim + + DimType m_ratio = 2; + DimType n_ratio = 2; + + const auto mn_ratio = + (double)problem_shape[M_POS] / (double)problem_shape[N_POS]; + 
if (mn_ratio < 0.5) { + m_ratio = 1; + n_ratio = 4; + } else if (mn_ratio > 2) { + m_ratio = 4; + n_ratio = 1; + } + + cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k}; + } + + params->mma_op = mma_op; + params->layout = layout; + params->tile_sizes = {cta_tile, warp_tile, instruction_tile}; + + return true; +} + +//! A wrapper for additional heuristics initialization +inline bool initExtraHeuristics( + std::shared_ptr params, + const ProblemShape& problem_shape) { + // TODO: add logic to calculate efficient number of stages + constexpr int stages = 3; + + params->async_gmem_load_operands = true; + params->double_buffer_options.double_buffer_smem_write = true; + params->double_buffer_options.double_buffer_smem_read = true; + params->double_buffer_options.smem_double_buffer_stage = stages; + + return true; +} + +//! A helper for getting problem shape from fusion and runtime info. Operation +//! can fail and nullopt object is returned. +c10::optional getProblemShape( + Fusion* fusion, + const MmaOp* mma_expr, + SchedulerRuntimeInfo& runtime_info, + const MatmulLayout matmul_layout) { + const auto& fusion_inputs = fusion->inputs(); + const auto& fusion_outputs = fusion->outputs(); + const auto& mma_inputs = mma_expr->inputs(); + const auto& mma_outputs = mma_expr->outputs(); + + // It is an unsupported fusion if + // - there are more than one fusion input TensorViews (producers) + // for MMA op input + // - there are more than one fusion output TensorViews (consumers) + // MMA op output + const auto getKeyTvFromPathBetween = + [](const std::vector& producers, + const std::vector& consumers) -> Val* { + const auto paths = getAllDepndencyChains(producers, consumers); + + if (paths.empty()) { + return nullptr; + } + + std::vector tvs; + for (const auto& path : paths) { + if (path.empty()) { + continue; + } + if (path.front()->isA()) { + tvs.push_back(path.front()); + } + } + return (tvs.size() == 1) ? 
tvs[0] : nullptr; + }; + + const auto* tv_input_A = + getKeyTvFromPathBetween(fusion_inputs, {mma_inputs[0]}); + if (nullptr == tv_input_A) { + return c10::nullopt; + } + + const auto* tv_input_B = + getKeyTvFromPathBetween(fusion_inputs, {mma_inputs[1]}); + if (nullptr == tv_input_B) { + return c10::nullopt; + } + + const auto* tv_output = + getKeyTvFromPathBetween({mma_outputs[0]}, fusion_outputs); + if (nullptr == tv_output) { + return c10::nullopt; + } + + // A helper for populating concrete domains from TensorView + const auto getShape = [&runtime_info](const TensorView* tv) { + TensorShape tv_shape; + const auto concrete_domains = TensorDomain::noReductions( + TensorDomain::noBroadcasts(tv->as()->domain()->domain())); + for (const auto* domain : concrete_domains) { + const auto domain_extend = + runtime_info.expressionEvaluator().evaluate(domain->extent()); + if (domain_extend) { + tv_shape.push_back(domain_extend->as()); + } + } + return tv_shape; + }; + + const auto& in_A = getShape(tv_input_A->as()); + const auto& in_B = getShape(tv_input_B->as()); + const auto& output = getShape(tv_output->as()); + + constexpr size_t expected_dims = 2; + if (in_A.size() != expected_dims || // + in_B.size() != expected_dims || // + output.size() != expected_dims) { + return c10::nullopt; + } + + switch (matmul_layout) { + case MatmulLayout::TT: { + // in_A := [M, K] + // in_B := [K, N] + // output := [M, N] + const bool check_k = in_A[1] == in_B[0]; + const bool check_m = in_A[0] == output[0]; + const bool check_n = in_B[1] == output[1]; + if (!(check_k && check_m && check_n)) { + return c10::nullopt; + } + // [M, N, K] + return TensorShape{output[0], output[1], in_A[1]}; + } + case MatmulLayout::NT: { + // in_A := [K, M] + // in_B := [K, N] + // output := [M, N] + const bool check_k = in_A[0] == in_B[0]; + const bool check_m = in_A[1] == output[0]; + const bool check_n = in_B[1] == output[1]; + if (!(check_k && check_m && check_n)) { + return c10::nullopt; + } + // [M, N, 
K] + return TensorShape{output[0], output[1], in_A[0]}; + } + case MatmulLayout::TN: { + // in_A := [M, K] + // in_B := [N, K] + // output := [M, N] + const bool check_k = in_A[1] == in_B[1]; + const bool check_m = in_A[0] == output[0]; + const bool check_n = in_B[0] == output[1]; + if (!(check_k && check_m && check_n)) { + return c10::nullopt; + } + // [M, N, K] + return TensorShape{output[0], output[1], in_A[1]}; + } + default: + return c10::nullopt; + } + return c10::nullopt; +} + +std::string checkMatmulType(Fusion* fusion, const MmaOp* mma_expr) { + const auto& fusion_inputs = fusion->inputs(); + const auto& fusion_outputs = fusion->outputs(); + const auto& mma_inputs = mma_expr->inputs(); + const auto& mma_outputs = mma_expr->outputs(); + + const auto fusion_inputs_tvs = + ir_utils::filterByType(fusion_inputs).vector(); + const auto fusion_outputs_tvs = + ir_utils::filterByType(fusion_outputs).vector(); + + using DimSizeType = std::decay::type::size_type; + + static_assert( + std::is_same< + DimSizeType, + std::decay::type::size_type>::value, + "The type used to define the number of dimension in input and output TV must be the same."); + + constexpr DimSizeType expected_gemm_dims = static_cast(2); + constexpr size_t expected_number_of_inputs = 2; + constexpr size_t expected_number_of_outputs = 1; + + // Quick checks + { + // Fusion can only have two TV inputs + if (fusion_inputs.size() != fusion_inputs_tvs.size()) { + return "Fusion inputs contain at least one non-TensorView object"; + } + if (expected_number_of_inputs != fusion_inputs.size()) { + return "Fusion inputs contain at least one non-TensorView object"; + } + + // Fusion can only have TVs as outputs, and there can be only one output + if (fusion_outputs_tvs.size() != fusion_outputs.size()) { + return "Fusion has output which is not a TensorView object"; + } + if ((expected_number_of_outputs != fusion_outputs_tvs.size())) { + return "Fusion has more than a single TensorView object in outputs"; + } + 
+ // Each of fusion input TVs must have: + // - 2 concrete domains, + // - no broadcasts domain, + for (const auto tv : fusion_inputs_tvs) { + if (tv->hasBroadcast()) { + return "Fusion input TV has broadcast domain"; + } + const auto result = + TensorDomain::noReductions( + TensorDomain::noBroadcasts(tv->domain()->domain())) + .size(); + if (result != expected_gemm_dims) { + return "Fusion input TV has unsupported number of domains"; + } + } + + // Each of fusion output TVs must have: + // - 2 concrete domains, + // - reduction domain, + // - no broadcast domain, + for (const auto tv : fusion_outputs_tvs) { + if (tv->hasBroadcast()) { + return "Fusion output TV has broadcast domain"; + } + if (!tv->hasReduction()) { + return "Fusion output TV has no reduction domain"; + } + const auto result = + TensorDomain::noReductions( + TensorDomain::noBroadcasts(tv->domain()->domain())) + .size(); + if (result != expected_gemm_dims) { + return "Fusion output TV has unsupported number of domains"; + } + } + } + + // MmaOp inputs/outputs dependencies check + { + // Check the expected path between MmaOp input and fusion inputs + const auto areMmaOpInputDependeciesValid = [](const Val* val) { + if (val->definition()->isA()) { + const auto& bcast_inputs = val->definition()->inputs(); + // BroadcastOp has single input/output, not need to check other things + return bcast_inputs.front()->isFusionInput(); + } + return false; + }; + + // MmaOp input is a result of broadcast op with input being fusion input + for (const auto* mma_in : mma_inputs) { + if (!areMmaOpInputDependeciesValid(mma_in)) { + return "MmaOp input has unsupported dependency"; + } + } + + // MmaOp output must be a fusion output + if (!mma_outputs.front()->isFusionOutput()) { + return "Mma op output does not belong to fusion outputs"; + } + } + + return ""; +} + +} // anonymous namespace + +std::string getMatmulRunTimeRejectReason( + Fusion* fusion, + HeuristicSummary* data_cache, + SchedulerRuntimeInfo& 
runtime_info) { + // TODO: add proper set of checks + return ""; +} + +std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { + // The plan: + // 1. check if there is exactly one MmaOp defined in the fusion + // 2. check if MmaOp inputs match any of supported inputs layout + // 3. check if fusion represents expressions that are recognized by matmul + // scheduler + + // #1 + auto mma_exprs = ir_utils::getMmaOps(fusion); + if (mma_exprs.size() != 1) { + std::stringstream ss; + ss << "Matmul scheduler supports fusions only with a single MMA op, got: " + << mma_exprs.size(); + return ss.str(); + } + + // #2 + { + for (const auto* mma_expr : mma_exprs) { + const auto layout_data = getInputsLayout(mma_expr); + if (layout_data.second) { + return layout_data.second.value(); + } + } + } + + // #3 + { + for (auto mma_expr : mma_exprs) { + auto matmul_status = checkMatmulType(fusion, mma_expr); + if (!matmul_status.empty()) { + return matmul_status; + } + } + } + + return ""; +} + +std::shared_ptr getMatmulHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) { + FusionGuard fg(fusion); + (void)data_cache; + (void)runtime_info; + auto params = std::make_shared(); + + // Check initial conditions + const auto fusion_exprs = fusion->exprs(); + auto mma_exprs = ir_utils::filterByType(fusion_exprs).vector(); + if (mma_exprs.size() != 1) { + // Support only for fusion with a single mma op + return nullptr; + } + + const auto layout = getInputsLayout(mma_exprs.front()); + if (layout.second) { + // Layout check returned an error message + if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) { + printMsg(layout.second.value()); + } + return nullptr; + } + + const auto problem_shape = getProblemShape( + fusion, mma_exprs[0]->as(), runtime_info, layout.first.value()); + if (!problem_shape) { + // Failed to acquire problem shape + return nullptr; + } + + const auto device_prop = at::cuda::getCurrentDeviceProperties(); + const auto 
mma_op = getMmaOp( + device_prop->major * 10 + device_prop->minor, problem_shape.value()); + if (!mma_op) { + // No heuristics can be prepared if mma op request is empty + return nullptr; + } + + // Populate heuristic details + auto status = initCoreHeuristics( + params, mma_op.value(), layout.first.value(), problem_shape.value()); + if (!status) { + // Core part of heuristics failed to initialize + return nullptr; + } + + status = initExtraHeuristics(params, problem_shape.value()); + if (!status) { + // Additional pieces of heuristics failed to initialize + return nullptr; + } + + // set kernel index mode + params->cparams.index_type = getIndexType(problem_shape.value()); + + if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) { + printMsg(params->toString()); + } + + return params; +} + +} // namespace nvfuser diff --git a/csrc/scheduler/matmul_utils.h b/csrc/scheduler/matmul_utils.h new file mode 100644 index 00000000000..627c5fc7558 --- /dev/null +++ b/csrc/scheduler/matmul_utils.h @@ -0,0 +1,37 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser { + +class SchedulerRuntimeInfo; +class HeuristicSummary; +class MatmulParams; + +//! An implementation of functionality that will prepare heuristics for fusion +//! that represents matmul. May return empty object if any of conditions are +//! not met. +TORCH_CUDA_CU_API std::shared_ptr getMatmulHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr); + +//! An implementation of compile time checks. Returns messasge if given fusion +//! does not represent matmul, otherwise an empty string is returned. +TORCH_CUDA_CU_API std::string getMatmulCompileTimeRejectReason(Fusion* fusion); + +//! An implementation of runtime time checks. 
Returns messasge if given fusion +//! does not represent matmul, otherwise an empty string is returned. +TORCH_CUDA_CU_API std::string getMatmulRunTimeRejectReason( + Fusion* fusion, + HeuristicSummary* data_cache, + SchedulerRuntimeInfo& runtime_info); + +} // namespace nvfuser diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 9d0af53e6f2..01bd2c3ee76 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -164,7 +165,7 @@ class SchedulerTopologyChecker { auto p_id = entry.first; auto c_id = entry.second; if (p_id->isBroadcast() && !c_id->isBroadcast()) { - ids_to_resolve.emplace_back(std::make_pair(c_id, c_id)); + ids_to_resolve.emplace_back(c_id, c_id); } } @@ -264,7 +265,7 @@ class SchedulerTopologyChecker { // if all ids were not resolved, then we've found an instance of a // bad broadcast resolution after reduction - if (ids_to_resolve.size()) { + if (!ids_to_resolve.empty()) { return true; } @@ -322,7 +323,7 @@ class SchedulerTopologyChecker { static bool supportedPostReductionFusion( Fusion* fusion, std::vector reduction_tvs) { - TORCH_INTERNAL_ASSERT(reduction_tvs.size()); + TORCH_INTERNAL_ASSERT(!reduction_tvs.empty()); bool fastest_dim_reduction = true; auto red_root_dom = reduction_tvs[0]->getRootDomain(); for (size_t i = red_root_dom.size(); i > 0; i--) { @@ -683,7 +684,7 @@ bool reductionInterferingView( } // Don't add empty group (would happen if it's a 2D scheduler not 3D) - if (current_dims.size() > 0) { + if (!current_dims.empty()) { groups.push_back(current_dims); dims = remove_dims(dims, processed); } @@ -701,7 +702,7 @@ bool reductionInterferingView( // Convert id's in groups to disjoint_set_ids of disjoint_set_information std::vector> disjoint_groups; - for (auto group : groups) { + for (const auto& group : groups) { std::vector disjoint_id_sets; for (auto id : group) { auto find_it = std::find( @@ -1275,9 +1276,16 
@@ class ReductionScheduler : public SchedulerEntry { return false; } + // Fusions handled by reduction scheduler cannot have MmaOp. + if (!ir_utils::getMmaOps(fusion).empty()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Reduction, "no support for mma ops."); + return false; + } + auto reduction_tvs = scheduler_utils::getReductionTvs(fusion); - if (reduction_tvs.size() == 0) { + if (reduction_tvs.empty()) { // Use pointwise logic return false; } @@ -1289,7 +1297,7 @@ class ReductionScheduler : public SchedulerEntry { return false; } - if (ir_utils::getViewOps(fusion).size() > 0) { + if (!ir_utils::getViewOps(fusion).empty()) { ComputeAtMap ca_map(fusion); if (requiresForwardViewReplay(fusion, ca_map)) { scheduler_debug_utils::canScheduleRejectReason( @@ -1365,7 +1373,7 @@ class ReductionScheduler : public SchedulerEntry { // Doesn't allow persistent kernels in this scheduler auto persistent_buffer_info = scheduler_utils::persistentBuffers(fusion); - if (persistent_buffer_info.persistent_buffers.size() > 0) { + if (!persistent_buffer_info.persistent_buffers.empty()) { scheduler_debug_utils::canScheduleRejectReason( ScheduleHeuristic::Reduction, "need persistent buffers that reduction scheduler doesn't handle"); @@ -1420,7 +1428,7 @@ class TransposeScheduler : public SchedulerEntry { // Temporarily disallow view in transpose scheduler // TODO Add more testing before enabling auto view_tvs = scheduler_utils::getViewTVs(fusion); - if (view_tvs.size() > 0) { + if (!view_tvs.empty()) { scheduler_debug_utils::canScheduleRejectReason( ScheduleHeuristic::Transpose, "No support for view op"); return false; @@ -1431,6 +1439,13 @@ class TransposeScheduler : public SchedulerEntry { return false; } + // Fusions handled by transpose scheduler cannot have MmaOp. 
+ if (!ir_utils::getMmaOps(fusion).empty()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, "no support for mma ops."); + return false; + } + for (auto select : ir_utils::getSelectOps(fusion)) { auto root = TensorDomain::noReductions( select->input(0)->as()->getMaybeRFactorDomain()); @@ -1546,7 +1561,14 @@ class PointWiseScheduler : public SchedulerEntry { return false; } - if (ir_utils::getViewOps(fusion).size() > 0) { + // Fusions handled by pointwise scheduler cannot have MmaOp. + if (!ir_utils::getMmaOps(fusion).empty()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::PointWise, "no support for mma ops."); + return false; + } + + if (!ir_utils::getViewOps(fusion).empty()) { ComputeAtMap ca_map(fusion); if (requiresForwardViewReplay(fusion, ca_map)) { scheduler_debug_utils::canScheduleRejectReason( @@ -1636,6 +1658,13 @@ class PersistentKernelScheduler : public SchedulerEntry { return false; } + // Fusions handled by persistent kernel scheduler cannot have MmaOp. 
+ if (!ir_utils::getMmaOps(fusion).empty()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, "no support for mma ops."); + return false; + } + if (hasNonUniqueBcast(fusion)) { scheduler_debug_utils::canScheduleRejectReason( ScheduleHeuristic::Persistent, @@ -1645,14 +1674,14 @@ class PersistentKernelScheduler : public SchedulerEntry { auto reduction_tvs = scheduler_utils::getReductionTvs(fusion); - if (reduction_tvs.size() == 0) { + if (reduction_tvs.empty()) { // Use pointwise logic scheduler_debug_utils::canScheduleRejectReason( ScheduleHeuristic::Persistent, "no reduction tv"); return false; } - if (ir_utils::getViewOps(fusion).size() > 0) { + if (!ir_utils::getViewOps(fusion).empty()) { ComputeAtMap ca_map(fusion); if (requiresForwardViewReplay(fusion, ca_map)) { scheduler_debug_utils::canScheduleRejectReason( @@ -1721,7 +1750,7 @@ class PersistentKernelScheduler : public SchedulerEntry { // Only accept persistent kernels auto persistent_buffer_info = scheduler_utils::persistentBuffers(fusion); - if (persistent_buffer_info.persistent_buffers.size() == 0) { + if (persistent_buffer_info.persistent_buffers.empty()) { scheduler_debug_utils::canScheduleRejectReason( ScheduleHeuristic::Persistent, "no persistent buffer identified"); return false; @@ -2054,6 +2083,57 @@ class PersistentKernelScheduler : public SchedulerEntry { } }; +class MatmulScheduler : public SchedulerEntry { + public: + explicit MatmulScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) + : SchedulerEntry(ScheduleHeuristic::Matmul) { + computeHeuristics(fusion, runtime_info); + } + + void schedule(Fusion* fusion) override { + FUSER_PERF_SCOPE("Schedule Matmul Fusion"); + scheduleMatmul(fusion, matmulParams()); + } + + static bool canScheduleCompileTime(Fusion* fusion) { + const auto msg = getMatmulCompileTimeRejectReason(fusion); + if (!msg.empty()) { + scheduler_debug_utils::canScheduleRejectReason( + 
ScheduleHeuristic::Matmul, msg); + return false; + } + + return true; + } + + static bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) { + FUSER_PERF_SCOPE("MatmulScheduler::canSchedule"); + auto reason = + getMatmulRunTimeRejectReason(fusion, data_cache, runtime_info); + if (!reason.empty()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Matmul, reason); + return false; + } + return true; + } + + private: + void computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) { + params_ = getMatmulHeuristics(fusion, runtime_info, data_cache); + TORCH_INTERNAL_ASSERT(params_ != nullptr); + } +}; + // Schedule Table const std::vector& all_heuristics() { static const std::vector hlist = { @@ -2061,7 +2141,8 @@ const std::vector& all_heuristics() { ScheduleHeuristic::Reduction, ScheduleHeuristic::Transpose, ScheduleHeuristic::PointWise, - ScheduleHeuristic::Persistent}; + ScheduleHeuristic::Persistent, + ScheduleHeuristic::Matmul}; return hlist; } @@ -2114,6 +2195,9 @@ bool SchedulerEntry::canSchedule( case ScheduleHeuristic::Transpose: return checkCanSchedule( fusion, runtime_info, data_cache); + case ScheduleHeuristic::Matmul: + return checkCanSchedule( + fusion, runtime_info, data_cache); default: TORCH_INTERNAL_ASSERT(false, "unreachable"); return false; @@ -2148,6 +2232,10 @@ std::unique_ptr SchedulerEntry::makeEntry( scheduler_entry = std::make_unique( fusion, runtime_info, data_cache); break; + case ScheduleHeuristic::Matmul: + scheduler_entry = + std::make_unique(fusion, runtime_info, data_cache); + break; default: TORCH_INTERNAL_ASSERT(false, "unreachable"); } @@ -2184,6 +2272,8 @@ std::string toString(ScheduleHeuristic sh) { return "persistent"; case ScheduleHeuristic::Transpose: return "transpose"; + case ScheduleHeuristic::Matmul: + return "matmul"; default: TORCH_INTERNAL_ASSERT(false, "undefined 
schedule"); } @@ -2221,8 +2311,7 @@ HeuristicSummary::HeuristicSummary( Fusion* fusion, ScheduleHeuristic heuristic, SchedulerRuntimeInfo& runtime_info) - : heuristic_(heuristic) { - recording_ = true; + : heuristic_(heuristic), recording_(true) { switch (heuristic) { case ScheduleHeuristic::NoOp: NoOpScheduler::canScheduleRunTime(fusion, runtime_info, this); @@ -2243,6 +2332,15 @@ HeuristicSummary::HeuristicSummary( getTransposeHeuristics(fusion, runtime_info, this); TransposeScheduler::canScheduleRunTime(fusion, runtime_info, this); break; + case ScheduleHeuristic::Matmul: { + const auto heuristics = getMatmulHeuristics(fusion, runtime_info, this); + TORCH_INTERNAL_ASSERT(heuristics, "Failed to get matmul heuristics"); + const auto canSchedule = + MatmulScheduler::canScheduleRunTime(fusion, runtime_info, this); + TORCH_INTERNAL_ASSERT( + canSchedule, "Could not schedule matmul (run time)"); + break; + } default: TORCH_INTERNAL_ASSERT(false, "unknown heuristic"); } @@ -2317,6 +2415,10 @@ void HeuristicSummary::validate() const { entry_type_map_.count(EntryType::SCOPE_PERSISTENT_FACTOR_INFO)); break; } + case ScheduleHeuristic::Matmul: { + // TODO: add a proper set of checks + break; + } default: TORCH_INTERNAL_ASSERT(false, "unknown heuristic"); } diff --git a/csrc/scheduler/registry.h b/csrc/scheduler/registry.h index 93b4ba5058d..9b3d416514f 100644 --- a/csrc/scheduler/registry.h +++ b/csrc/scheduler/registry.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -205,6 +206,13 @@ class TORCH_CUDA_CU_API SchedulerEntry { return *tparams; } + const MatmulParams& matmulParams() const { + auto mparams = std::dynamic_pointer_cast(params_); + TORCH_INTERNAL_ASSERT( + mparams != nullptr, "Heuristic parameter is not a matmul parameter"); + return *mparams; + } + void updateLaunchConstraint(const LaunchParams& launch_params) { params_->lparams = launch_params; } diff --git a/csrc/utils.cpp b/csrc/utils.cpp index 4aa7534c8c6..72ffa6e168d 
100644 --- a/csrc/utils.cpp +++ b/csrc/utils.cpp @@ -137,7 +137,8 @@ auto parseDebugDumpOptions() { {"lower_verbose", DebugDumpOption::LowerVerbose}, {"expr_simplify", DebugDumpOption::ExprSimplification}, {"expr_sort", DebugDumpOption::ExprSort}, - {"loop_rotation", DebugDumpOption::LoopRotation}}; + {"loop_rotation", DebugDumpOption::LoopRotation}, + {"matmul_checks", DebugDumpOption::MatmulChecks}}; return parseEnvOptions("PYTORCH_NVFUSER_DUMP", available_options); } diff --git a/csrc/utils.h b/csrc/utils.h index 8f95ab6243c..42061664dd4 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -80,6 +80,8 @@ enum class DebugDumpOption { ExprSimplification, //! Print all passes' transform in simplifyExpr ExprSort, //! Print merging decisions on expression sorting LoopRotation, //! Print loop rotation log + MatmulChecks, //! Print logs from tools around matmul scheduler used in + //! segmenter EndOfOption //! Placeholder for counting the number of elements }; diff --git a/test/test_gpu_matmul_sass.cpp b/test/test_gpu_matmul_sass.cpp index f579c1711ad..ff5370a3831 100644 --- a/test/test_gpu_matmul_sass.cpp +++ b/test/test_gpu_matmul_sass.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -58,15 +59,15 @@ sass::Container getSASSFor( gemm_tile.warp_tile = warp_tile; gemm_tile.instruction_tile = instruction_tile; - auto mma_builder = MmaBuilder(macro, gemm_tile).layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = macro; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 4; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -288,7 +289,7 @@ TEST_F(NVFuserTest, 
FusionAmpereMatmulSASSRegisterUsageLDSM_CUDA) { std::string_view view(smem_address); // example: [R0+UR0+0x200] view = view.substr(1, view.size() - 2); // example: R0+UR0+0x200 std::string_view base; - int offset; + int offset = 0; using namespace std::literals; auto last = view.find_last_of("+"sv); if (last == std::string::npos || diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 7f5e7a7e87b..ef1179987f0 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -44,6 +44,10 @@ #include #include +#include "dispatch.h" +#include "ir_builder.h" +#include "ops/arith.h" +#include "type.h" namespace nvfuser { @@ -307,13 +311,11 @@ TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 16, 4); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Volta_16_16_4; + params.layout = layout; params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -357,14 +359,12 @@ TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 16, 4); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Volta_16_16_4; + params.layout = layout; params.tile_sizes = gemm_tile; params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -641,16 +641,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { gemm_tile.warp_tile = 
GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_8_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.smem_double_buffer_stage = 4; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -697,17 +695,15 @@ TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_8_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.smem_double_buffer_stage = stage; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -735,13 +731,13 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { int dim = 8192; int M = dim, N = dim, K = dim; const auto all_orders = { - MatmulParam::TileRasterizationOrder::RowMajor, - MatmulParam::TileRasterizationOrder::ColumnMajor}; + MatmulParams::TileRasterizationOrder::RowMajor, + MatmulParams::TileRasterizationOrder::ColumnMajor}; REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); auto test = [&](MatmulLayout layout, - MatmulParam::TileRasterizationOrder order, + MatmulParams::TileRasterizationOrder order, int swizzle, float& runtime) { 
Fusion fusion; @@ -761,21 +757,19 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_8_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 3; - params.rasterization_order = order; + params.cta_order = order; params.grid_swizzle_factor = swizzle; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -874,18 +868,15 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_8_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.smem_double_buffer_stage = stage; params.double_buffer_options.double_buffer_smem_read = true; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -1822,13 +1813,11 @@ TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 8, 16); - auto mma_builder = - 
MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Turing_16_8_16; + params.layout = layout; params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2860,17 +2849,15 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 64); gemm_tile.instruction_tile = GemmTile(16, 16, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_16_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 3; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2914,13 +2901,11 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 16, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Turing_16_16_16; + params.layout = layout; params.tile_sizes = gemm_tile; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -2969,15 +2954,13 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { gemm_tile.warp_tile = GemmTile(mn_size / 2, mn_size / 2, k_size); gemm_tile.instruction_tile 
= GemmTile(16, 16, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_16_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3033,18 +3016,16 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { gemm_tile.warp_tile = GemmTile(m_size / 4, n_size / 2, k_size); gemm_tile.instruction_tile = GemmTile(16, 16, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_16_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3094,18 +3075,16 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, k_size); gemm_tile.instruction_tile = GemmTile(16, 16, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_16_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; 
params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3149,17 +3128,15 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) { gemm_tile.warp_tile = GemmTile(64, 64, 64); gemm_tile.instruction_tile = GemmTile(16, 16, 16); - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); + MatmulParams params; + params.mma_op = MmaOptions::MacroType::Ampere_16_16_16; + params.layout = layout; params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 3; - scheduleMatmul(tv2, tv0, tv1, params); + scheduleMatmul(&fusion, params); at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -3180,6 +3157,96 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) { } } +// Matmul test on Ampere relying on segmenter for 'C = A x B' fusion, +// with strict ref check hence single layout check +TEST_F(NVFuserTest, FusionMatmulSegmenterBasicMatmulStrictCheckTT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + const int M = 128, N = 256, K = 512; + const auto layout = MatmulLayout::TT; + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + auto tv2 = matmul(tv0, tv1, layout); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv2); + + at::manual_seed(0); + + at::Tensor t0 = matmulAtInput(M, N, K, layout, TensorMatmulPos::A, at::kHalf); + at::Tensor t1 = matmulAtInput(M, N, K, layout, TensorMatmulPos::B, at::kHalf); + at::Tensor tref = 
atMatmul(t0, t1, layout); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + TORCH_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "fusion got segmented, expected to match whole fusion with single segment"); + + TORCH_CHECK( + isSchedulerInUse( + executor_cache.getMostRecentKernelRuntime(), + ScheduleHeuristic::Matmul), + "matmul scheduler was not used to handle prepared fusion"); + + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +// Matmul test on Ampere relying on segmenter for 'C = A x B' fusion, +// with relaxed result verification +TEST_F(NVFuserTest, FusionMatmulSegmenterBasicMatmulRelaxedCheck_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + const int M = 504, N = 136, K = 2048; + for (auto layout : kAllSupportedMatmulLayout) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + auto tv2 = matmul(tv0, tv1, layout); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv2); + + at::manual_seed(0); + + at::Tensor t0 = + matmulAtInput(M, N, K, layout, TensorMatmulPos::A, at::kHalf); + at::Tensor t1 = + matmulAtInput(M, N, K, layout, TensorMatmulPos::B, at::kHalf); + at::Tensor tref = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + TORCH_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "fusion got segmented, expected to match whole fusion with single segment"); + + TORCH_CHECK( + isSchedulerInUse( + executor_cache.getMostRecentKernelRuntime(), + ScheduleHeuristic::Matmul), + "matmul scheduler was not used to handle prepared fusion"); + + // NOTE: checking with lower expectations for relative/absolute error +#if 1 + 
TORCH_CHECK(outputs[0].allclose(tref, 0.001, 0.001)); +#else + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +#endif + } +} + #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace nvfuser diff --git a/test/test_utils.cpp b/test/test_utils.cpp index 2d86b15eec8..ca4132fcc99 100644 --- a/test/test_utils.cpp +++ b/test/test_utils.cpp @@ -107,7 +107,7 @@ bool starts_with(std::string_view self, std::string_view __s) noexcept { std::string Instruction::predicate() { if (str[0] == '@') { std::stringstream ss(str); - char ignore_at; + char ignore_at = '\0'; std::string result; ss >> ignore_at >> result; return result; @@ -157,7 +157,7 @@ std::vector Instruction::args() { auto comma_pos = args_view.find_first_of(','); auto token = args_view.substr(0, comma_pos); token = trim(token); - result.push_back(std::string(token)); + result.emplace_back(token); args_view = (comma_pos != std::string_view::npos) ? args_view.substr(comma_pos + 1) @@ -221,21 +221,21 @@ Container parse(const std::string& nvdisasm_output) { if (line[0] == '.') { std::stringstream ss(line); Label l; - char ignore_dot; + char ignore_dot = '\0'; ss >> ignore_dot >> l.name; l.name.resize(l.name.size() - 1); // remove trailing : - result.code.push_back(l); + result.code.emplace_back(l); } else { Instruction i; std::stringstream ss(line); - char ignore; + char ignore = '\0'; // parse /*address*/ ss >> ignore >> ignore >> std::hex >> i.address >> ignore >> ignore; std::getline(ss, i.str); i.str = trim(i.str); i.str.resize(i.str.size() - 1); // remove trailing ; i.str = trim(i.str); - result.code.push_back(i); + result.code.emplace_back(i); } } else { if (line == header) { @@ -243,7 +243,7 @@ Container parse(const std::string& nvdisasm_output) { } else if (line[0] == '.') { std::stringstream ss(line); std::string key, value; - char ignore; + char ignore = '\0'; ss >> ignore >> key >> value; result.attributes[key] = value; if (key == "global") { @@ -320,4 +320,82 @@ std::pair 
fp16MatmulAtInput( return std::make_pair(at::Tensor(), at::Tensor()); } +at::Tensor matmulAtInput( + const int M, + const int N, + const int K, + const MatmulLayout layout, + const TensorMatmulPos tensor, + const c10::ScalarType dType, + const int device) { + const auto options = + at::TensorOptions().dtype(dType).device(at::kCUDA, device); + + // handle C and D tensors, layout does not impact shape + switch (tensor) { + case TensorMatmulPos::C: + case TensorMatmulPos::D: + return at::randn({M, N}, options); + default: + break; + } + + switch (layout) { + case MatmulLayout::TT: + switch (tensor) { + case TensorMatmulPos::A: + return at::randn({M, K}, options); + case TensorMatmulPos::B: + return at::randn({K, N}, options); + default: + break; + } + break; + case MatmulLayout::TN: + switch (tensor) { + case TensorMatmulPos::A: + return at::randn({M, K}, options); + case TensorMatmulPos::B: + return at::randn({N, K}, options); + default: + break; + } + break; + case MatmulLayout::NT: + switch (tensor) { + case TensorMatmulPos::A: + return at::randn({K, M}, options); + case TensorMatmulPos::B: + return at::randn({K, N}, options); + default: + break; + } + break; + default: + TORCH_CHECK(false, "unsupported data layout."); + } + TORCH_CHECK(false, "unsupported tensor position."); +} + +bool isSchedulerInUse( + nvfuser::FusionKernelRuntime* kernel_rt, + const ScheduleHeuristic& scheduler) { + if (nullptr == kernel_rt) { + return false; + } + const auto scheduler_heurs = kernel_rt->schedulerHeuristics(); + if (nullptr == scheduler_heurs) { + return false; + } + const auto& heurs = scheduler_heurs->heuristicsList(); + + for (const auto& heur_entry : heurs) { + if (heur_entry && (scheduler == heur_entry->heuristic())) { + return true; + } + } + + return false; +} + } // namespace nvfuser diff --git a/test/test_utils.h b/test/test_utils.h index 37e05e5e674..af99b387564 100644 --- a/test/test_utils.h +++ b/test/test_utils.h @@ -27,6 +27,7 @@ #include #include #include 
+#include "kernel_cache.h" namespace nvfuser { @@ -475,12 +476,35 @@ std::pair fp16MatmulAtInput( int K, MatmulLayout layout); +// Labels to describe tensor position in matmul: +// A, B - input +// C - input if beta is provided, shape must be the same as output (D) +// D - output +enum class TensorMatmulPos { A, B, C, D }; + +// Utility to generate buffers based on given problem, layout and tensor +// position in matmul +at::Tensor matmulAtInput( + const int M, + const int N, + const int K, + const MatmulLayout layout, + const TensorMatmulPos tensor, + const c10::ScalarType dType = at::kHalf, + const int device = 0); + #define REQUIRE_DEVICE_SMEM_SIZE(required_size, device_idx) \ if (at::cuda::getDeviceProperties(device_idx)->sharedMemPerBlockOptin < \ required_size) { \ GTEST_SKIP() << "not enough shared memory space on device to run test"; \ } + +// Utility to check whether the expected scheduler has been used +// for a given kernel +bool isSchedulerInUse( + nvfuser::FusionKernelRuntime* kernel_rt, + const ScheduleHeuristic& scheduler); + +// Disable magic zero constexpr CompileParams matmul_cparams{DataType::Int32, 255, false};