From 58687ada8e99700cdbf664b050900512dd0b7c40 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 20 May 2024 17:01:28 +0000 Subject: [PATCH 01/28] Generalize CombineMulSum as MatmulPatterns This also uses IdModel to find IterDomain and Tensor roles, and checks allocation domain to find problem layout. We guard the matmul tensor to reject problems that have non-trivial input allocation domains. --- csrc/ir/utils.cpp | 25 ++ csrc/ir/utils.h | 2 + csrc/mma_type.h | 2 +- csrc/scheduler/expr_eval_sched.cpp | 18 +- csrc/scheduler/matmul.cpp | 43 +- csrc/scheduler/matmul_utils.cpp | 174 ++++---- csrc/scheduler/mma_utils.cpp | 629 ++++++++++++++++------------ csrc/scheduler/mma_utils.h | 126 +++--- tests/cpp/test_combine_mul_sum.cpp | 54 +-- tests/cpp/test_gpu_tensorcore.cpp | 2 +- tests/cpp/test_matmul_scheduler.cpp | 52 +-- 11 files changed, 613 insertions(+), 514 deletions(-) diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 296cff951b8..265b2b1a202 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1099,6 +1099,31 @@ int64_t getVectorizeSize(const TensorView* tv) { return 1; } +bool hasTrivialAllocationDomain(const TensorView* tv) { + if (!tv->hasAllocation()) { + return true; + } + const std::vector& alloc = tv->getMaybeAllocationDomain(); + const std::vector& rf = tv->getMaybeRFactorDomain(); + size_t i = 0, j = 0; + while (i < alloc.size() && j < rf.size()) { + if (alloc[i]->isBroadcast() || alloc[i]->isReduction()) { + i++; + continue; + } + if (rf[j]->isBroadcast() || rf[j]->isReduction()) { + j++; + continue; + } + if (!alloc[i]->sameAs(rf[j])) { + return false; + } + i++; + j++; + } + return true; +} + } // namespace nvfuser::ir_utils namespace nvfuser::MmaOpUtils { diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index a9b5c95c5e0..498a22b57c4 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -649,4 +649,6 @@ std::optional> computePermutation( return permutation; } +bool hasTrivialAllocationDomain(const TensorView* tv); + } // namespace nvfuser::ir_utils diff --git a/csrc/mma_type.h b/csrc/mma_type.h index e38619576e2..52c27a63e7e 100644 --- a/csrc/mma_type.h +++ b/csrc/mma_type.h @@ -26,7 +26,7 @@ namespace nvfuser { constexpr std::string_view MATMUL_LOG_PREFIX = "[MATMUL DEBUG] "; //! Named descriptors of domains in matmul -enum class MatmulDomain { M = 0, N, K }; +enum class MatmulDomain { M = 0, N, K, Batch }; //! Named descriptors of TensorView roles in fusion //! INPUT_A - a producer of MMA input A diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index b25a290ce99..8a138c92026 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -16,13 +16,19 @@ namespace nvfuser { // Check if the fusion has a single MatmulOp/LinearOp node bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { auto exprs = fusion->exprs(); - if (exprs.size() == 1 && - (exprs.front()->isA() || exprs.front()->isA())) { - return true; + if (!isOptionDisabled(DisableOption::MatmulExprEval)) { + if (exprs.size() == 1 && + (exprs.front()->isA() || exprs.front()->isA())) { + return true; + } + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), + "Fusion must contain a single expression of type MatmulOp or LinearOp"); + } else { + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), + "Matmul ATen evaluation was disabled by NVFUSER_DISABLE=matmul_expr_eval"); } - scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), - "Fusion must contain a single expression of type MatmulOp or LinearOp"); return false; } diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index f29b2fd4617..4a636a596ff 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -749,38 +749,33 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Cache and fork outputs auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); - mma_utils::CombineMulSum combiner(fusion); - auto mma_ops = ir_utils::getOpsOfType(fusion); - if (combiner.isValid() && mma_ops.empty()) { - combiner.replaceWithMmaOp(); - mma_ops = ir_utils::getOpsOfType(fusion); - } - + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + NVF_ERROR(!patterns.empty(), "No matmul patterns were found"); NVF_ERROR( - mma_ops.size() == 1, - "scheduleMatmul supports fusion with single mma op in definition, got ", - mma_ops.size()); - - // Skip scheduling if Matmul will be expression evaluated. - if (!isOptionDisabled(DisableOption::MatmulExprEval)) { - NVF_CHECK(fusion->outputs().size() == 1) - fusion->aliasOutputToInput( - fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); - scheduler_debug_utils::log( - __FILE__, - ":", - __LINE__, - ", Matmul output to be computed through expression evaluator. Skipping codegen."); - return; + patterns.size() == 1, + "Only a single matmul pattern can currently be fused"); + std::vector mma_ops; + mma_ops.reserve(patterns.size()); + for (mma_utils::MatmulPattern& pattern : patterns) { + mma_ops.push_back(pattern.translateToMmaOp()); } - const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion); + IdModel id_model(fusion); + std::unordered_map id_roles = + patterns.front().getDimRoles(id_model); + const auto& roles_map_opt = + mma_utils::getTensorsRoles(fusion, id_model, id_roles); // NOTE: the contents of roles_map have been already validated during // compute-time checks NVF_ERROR(roles_map_opt.isValid(), roles_map_opt.getErrorMsg()); const auto roles_map = roles_map_opt.getData(); + const mma_utils::MatmulProblemLayoutOpt fusion_layout = + mma_utils::getProblemLayout(id_model, id_roles, roles_map); + NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg()); + // Core roles: there can be only one... TV with assigned core role TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); @@ -791,8 +786,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { NVF_ERROR( mma_layout_opt.has_value(), "fusion mma op has undefined input layout"); const auto mma_layout = mma_layout_opt.value(); - const auto fusion_layout = mma_utils::getMmaLayout(fusion); - NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg()); const auto& gemm_tile = params.tile_sizes; const bool has_epilogue = !mma->out()->isFusionOutput(); diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 06f427788bd..998706673d9 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -15,11 +15,13 @@ // 'SchedulerRuntimeInfo' #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -38,9 +40,8 @@ namespace nvfuser { namespace { -//! Access to the structure should be done with labels defined in -//! MmaOptions::MmaDomains. -using ProblemShape = std::array; +//! Access to the structure should be done with labels defined in MatmulDomain. +using ProblemShape = std::array; //! A helper for deciding the type of MMA op for given fusion and problem shape. inline std::optional getMmaOp( @@ -158,43 +159,37 @@ inline bool initCoreHeuristics( } //! A helper for getting problem shape from fusion and runtime info. +//! +//! For a given domain, try to find the size by evaluating the extent of an +//! IterDomain in each group of that domain type. For example, if there are +//! multiple Batch dimensions, we find all ValGroups that are mapped as +//! MatmulDomain::Batch, we evaluate the extent of each, then we multiply those +//! dimensions together to get the overall batch size. ProblemShape getProblemShape( - const mma_utils::MulSumProperties::InputsOutputs& props, + const std::unordered_map& group_to_domain, SchedulerRuntimeInfo& runtime_info) { - const auto mma_output_domains = mma_utils::getProblemIterDomains({props}); - if (!mma_output_domains.isValid()) { - NVF_ERROR(false, mma_output_domains.getErrorMsg()); - } - - const auto [m, n, k] = mma_output_domains.getData(); - - auto m_extend = runtime_info.expressionEvaluator().evaluate(m->extent()); - auto n_extend = runtime_info.expressionEvaluator().evaluate(n->extent()); - auto k_extend = runtime_info.expressionEvaluator().evaluate(k->extent()); - - if (!(m_extend && n_extend && k_extend)) { + ProblemShape shape{1, 1, 1, 1}; + for (const auto& [g, dom] : group_to_domain) { + NVF_ERROR(!g->empty()); + IterDomain* id = g->front()->as(); + const PolymorphicValue extent = + runtime_info.expressionEvaluator().evaluate(id->extent()); NVF_ERROR( - false, - "Failed to acquire one of problem dimensions, M(", - m_extend.hasValue(), - "), N(", - n_extend.hasValue(), - " K(", - k_extend.hasValue(), - ")"); + extent.hasValue(), "Could not evaluate extent of ", id->toString()); + shape[(size_t)dom] *= extent.as(); } - - return ProblemShape{ - m_extend.as(), n_extend.as(), k_extend.as()}; + return shape; } std::string isMatmulFusionDefinitionSupported( Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { + const mma_utils::MatmulPattern& pattern, + const mma_utils::RolesMap& roles_map, + const std::unordered_map& id_roles) { const auto& fusion_inputs = fusion->inputs(); const auto& fusion_outputs = fusion->outputs(); - std::vector mma_inputs = {props.a, props.b}; - const auto mma_output = props.out; + std::vector mma_inputs = {pattern.A, pattern.B}; + const auto mma_output = pattern.output; const auto fusion_inputs_tvs = ir_utils::filterByType(fusion_inputs).vector(); @@ -202,18 +197,20 @@ std::string isMatmulFusionDefinitionSupported( ir_utils::filterByType(fusion_outputs).vector(); constexpr size_t minimal_number_of_inputs = 2; - MmaOpUtils::MmaOpDetails mma_details = - MmaOpUtils::getMmaOpDetails(props.out, props.a, props.b); // Quick checks - MmaOp { // Check if MmaOp represents gemm (requires M/N/K == 1, B == 0) // or bgemm (requires M/N/K/B == 1) + std::array num_axes{}; + for (const auto& [g, dom] : id_roles) { + num_axes[(size_t)dom]++; + } constexpr size_t expected_axes_numbers = 1; - if (mma_details.m_axes.size() != expected_axes_numbers || - mma_details.n_axes.size() != expected_axes_numbers || - mma_details.k_axes.size() != expected_axes_numbers || - mma_details.batch_axes.size() > expected_axes_numbers) { + if (num_axes[(size_t)MatmulDomain::M] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::N] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::K] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::Batch] > expected_axes_numbers) { return "MmaOp has unsupported number of one of M/N/K/Batch axes"; } @@ -232,12 +229,6 @@ std::string isMatmulFusionDefinitionSupported( // Fusion topology check { - const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion, props); - if (!roles_map_opt.isValid()) { - return roles_map_opt.getErrorMsg(); - } - - const auto& roles_map = roles_map_opt.getData(); auto entry = roles_map.find(MatmulRole::INPUT_A); std::set tvs_with_roles; @@ -288,6 +279,21 @@ std::string isMatmulFusionDefinitionSupported( } } + // Check that no non-trivial allocation domains are set on inputs or outputs. + // TODO: Lift this requirement once we have proper allocation domain support + for (Val* inp : fusion->inputs()) { + if (auto tv = dynamic_cast(inp); + tv && !ir_utils::hasTrivialAllocationDomain(tv)) { + return "detected input TV with non-trivial allocation domain"; + } + } + for (Val* outp : fusion->outputs()) { + if (auto tv = dynamic_cast(outp); + tv && !ir_utils::hasTrivialAllocationDomain(tv)) { + return "detected output TV with non-trivial allocation domain"; + } + } + return ""; } @@ -450,8 +456,7 @@ std::string getMatmulRunTimeRejectReason( std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // The plan: // 0. Check if the current CUDA device is supported - // 1. Check if there is exactly one MmaOp or suitable mul sum pair - // defined in the fusion. + // 1. Check if there is exactly one matmul pattern defined in the fusion. // 2. Check if inputs to the mma op or mul sum pair match any of // supported inputs layout // 3. Check if fusion represents expressions that are recognized by matmul @@ -463,7 +468,7 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // Use a dummy problem shape to determine whether this is a supported // device. const auto mma_op = - getMmaOp(device_prop->major * 10 + device_prop->minor, {128, 128, 128}); + getMmaOp(device_prop->major * 10 + device_prop->minor, {128, 128, 128, 1}); if (!mma_op.has_value()) { return "Unsupported device compute capability"; } @@ -472,29 +477,37 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // #1 // Initializing the machinery to check if there's a Mul-Sum pair // can be replaced by a Mma Op. - mma_utils::CombineMulSum combiner(fusion); - if (!combiner.isValid()) { - std::stringstream ss; - ss << "Matmul scheduler supports fusions only with a single mma op" - << "or supports a mul-sum pair which can be replaced with a mma op"; - return ss.str(); + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + if (patterns.empty()) { + return "No matmul patterns were found"; } - - const std::vector& mma_from_mul_sums = - combiner.getMulSumCanidates(); - // #2 - { - const auto input_layout_opt = - mma_utils::getMmaLayout(fusion, mma_from_mul_sums.front().insouts); - if (!input_layout_opt.isValid()) { - return input_layout_opt.getErrorMsg(); - } + if (patterns.size() > 1) { + return "Only a single matmul pattern can currently be fused"; } // #3 + // Prepare an IdModel which will be reused to check remaining conditions + IdModel id_model(fusion); + const auto id_roles = patterns.front().getDimRoles(id_model); + const mma_utils::RolesMapOpt roles_map_opt = + mma_utils::getTensorsRoles(fusion, id_model, id_roles); + if (!roles_map_opt.isValid()) { + return {roles_map_opt.getErrorMsg()}; + } + mma_utils::RolesMap roles_map = roles_map_opt.getData(); + + // #4 + const auto input_layout_opt = + mma_utils::getProblemLayout(id_model, id_roles, roles_map); + if (!input_layout_opt.isValid()) { + return input_layout_opt.getErrorMsg(); + } + + // #5 { auto support_status = isMatmulFusionDefinitionSupported( - fusion, mma_from_mul_sums.front().insouts); + fusion, patterns.front(), roles_map, id_roles); if (!support_status.empty()) { return support_status; } @@ -532,21 +545,22 @@ std::shared_ptr getMatmulHeuristics( // Set kernel index mode params->cparams.index_type = runtime_info.getIndexType(); - if (!isOptionDisabled(DisableOption::MatmulExprEval)) { - return params; - } - // Check initial conditions - auto mma_exprs = ir_utils::getOpsOfType(fusion); - mma_utils::CombineMulSum combiner(fusion); + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + NVF_ERROR(!patterns.empty(), "No matmul patterns were found"); NVF_ERROR( - combiner.isValid(), - "There's no (single) mma op or mul-sum op which mma op can replace") + patterns.size() == 1, + "Only a single matmul pattern can currently be fused"); + mma_utils::MatmulPattern& pattern = patterns.front(); + + // IdModel is used to analyze problem shape & layout + IdModel id_model(fusion); + + const std::unordered_map id_roles = + pattern.getDimRoles(id_model); - const std::vector& mulSum = - combiner.getMulSumCanidates(); - const auto problem_shape = - getProblemShape(mulSum.front().insouts, runtime_info); + const auto problem_shape = getProblemShape(id_roles, runtime_info); const auto device_prop = at::cuda::getCurrentDeviceProperties(); const auto mma_op = @@ -556,7 +570,7 @@ std::shared_ptr getMatmulHeuristics( params->mma_macro = mma_op.value(); const auto& roles_map_opt = - mma_utils::getTensorsRoles(fusion, mulSum.front().insouts); + mma_utils::getTensorsRoles(fusion, id_model, id_roles); NVF_ERROR(roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); const auto roles_map = roles_map_opt.getData(); @@ -565,17 +579,17 @@ std::shared_ptr getMatmulHeuristics( if (matmul_heuristic_plugin::hasPlugin()) { const mma_utils::MatmulProblemLayoutOpt layout_opt = - mma_utils::getMmaLayout(fusion, mulSum.front().insouts); + mma_utils::getProblemLayout(id_model, id_roles, roles_map); NVF_ERROR(layout_opt.isValid(), layout_opt.getErrorMsg()); const MmaLayout layout = layout_opt.getData(); // Fill in proper values using plugin matmul_heuristic_plugin::updateMatmulParams( *params, - /*M=*/problem_shape[0], - /*N=*/problem_shape[1], - /*K=*/problem_shape[2], - /*batch_size=*/1, // TODO: extract actual batch size + problem_shape[(size_t)MatmulDomain::M], + problem_shape[(size_t)MatmulDomain::N], + problem_shape[(size_t)MatmulDomain::K], + problem_shape[(size_t)MatmulDomain::Batch], layout, roles_map); } else { diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index ee7be460059..5014f30f194 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -9,10 +9,13 @@ #include #include #include +#include #include +#include #include #include #include +#include #include #include "mma_type.h" namespace nvfuser { @@ -1051,14 +1054,13 @@ inline void resolveTvToMatmulDomainsMapping( } // anonymous namespace -ProblemIterDomainsOpt getProblemIterDomains( - const mma_utils::MulSumProperties::InputsOutputs& props) { +ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern) { // NOTE: the iter domains of MMA output should be [...,M,K,N] IterDomain* m = nullptr; IterDomain* n = nullptr; IterDomain* k = nullptr; - const auto& leaf_domains = props.out->getLeafDomain(); + const auto& leaf_domains = pattern.output->getLeafDomain(); const auto concrete = TensorDomain::noDevices( TensorDomain::noReductions(TensorDomain::noBroadcasts(leaf_domains))); if (concrete.size() < MIN_MATMUL_INPUTS_NUMBER) { @@ -1100,105 +1102,103 @@ ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion) { static_cast(mma_exprs.front()->out())}); } -MatmulProblemLayoutOpt getMmaLayout( - Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { - ComputeAtMap ca_map(fusion); - const auto mma_input_candidates = - ir_utils::filterByType(fusion->inputs()).vector(); - if (mma_input_candidates.empty()) { - return {"Failed to find any TV that is fusion input"}; - } - - const auto mma_output_domains = getProblemIterDomains(props); - if (!mma_output_domains.isValid()) { - return mma_output_domains.getErrorMsg(); +MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion) { + const std::vector patterns = findMatmulPatterns(fusion); + if (patterns.size() != 1) { + std::stringstream ss; + ss << "Invalid number of MmaOp instances in fusion, expected 1, got " + << patterns.size(); + return ss.str(); } + return getProblemLayout(fusion, patterns.front()); +} - const auto domains_data = mma_output_domains.getData(); - const auto m = domains_data[(size_t)MatmulDomain::M]; - const auto n = domains_data[(size_t)MatmulDomain::N]; - const auto k = domains_data[(size_t)MatmulDomain::K]; - - DependenciesMap deps_map; - resolveTvToMatmulDomainsMapping( - deps_map, mma_input_candidates, m, n, k, ca_map); - - bool mk_found = false; - bool km_found = false; - bool nk_found = false; - bool kn_found = false; - const static DomainsDesc mk_desc = {MatmulDomain::M, MatmulDomain::K}; - const static DomainsDesc km_desc = {MatmulDomain::K, MatmulDomain::M}; - const static DomainsDesc nk_desc = {MatmulDomain::N, MatmulDomain::K}; - const static DomainsDesc kn_desc = {MatmulDomain::K, MatmulDomain::N}; - - for (const auto& item : deps_map) { - if (item.second == mk_desc) { - if (mk_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., M, ..., K, ...] iter domains"}; - } - mk_found = true; +MatmulProblemLayoutOpt getProblemLayout( + const IdModel& id_model, + const std::unordered_map& group_to_domain, + const RolesMap& roles_map) { + // Assumes the exact graph has already been built, since we've been provided + // group_to_domain + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + + // Note: using DataWrapperOpt would be preferable here. However, + // using DataWrapperOpt(std::move(dom)) leads to a clang-tidy + // warning because MatmulDomain is trivially movable. There is only a move + // constructor for DataWrapperOpt to prevent inadvertent copying. To avoid + // this complication I'm using a simple pair for the lambda's result type. + using InnerDomResult = std::pair; + const auto innerDomain = [&roles_map, &group_to_domain, &exact_graph]( + MatmulRole role) -> InnerDomResult { + const auto role_it = roles_map.find(role); + if (role_it == roles_map.end()) { + return {MatmulDomain::M, "Could not find role in roles_map"}; } - if (item.second == km_desc) { - if (km_found) { + std::optional group_inner_dom = std::nullopt; + for (TensorView* tv : role_it->second) { + IterDomain* inner_id = + TensorDomain::noReductions(tv->getMaybeAllocationDomain()).back(); + const ValGroup& g = exact_graph.toGroup(inner_id); + auto g_it = group_to_domain.find(g); + if (g_it == group_to_domain.end()) { return { - "Failed to find MMA input, more than one fusion input has [..., K, ..., M, ...] iter domains"}; + MatmulDomain::M, + "Inner domain of tensor was not mapped to a MatmulDomain"}; } - km_found = true; - } - if (item.second == nk_desc) { - if (nk_found) { + if (!group_inner_dom.has_value()) { + group_inner_dom = g_it->second; + } else if (group_inner_dom.value() != g_it->second) { return { - "Failed to find MMA input, more than one fusion input has [..., N, ..., K, ...] iter domains"}; + MatmulDomain::M, "Group contains multiple inner dimension domains"}; } - nk_found = true; } - if (item.second == kn_desc) { - if (kn_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., K, ..., N, ...] iter domains"}; - } - kn_found = true; + if (!group_inner_dom.has_value()) { + return {MatmulDomain::M, "No tensor found in role"}; } - } + return {group_inner_dom.value(), ""}; + }; - if ((mk_found && kn_found) && !(km_found || nk_found)) { - return MmaLayout::TT; + const InnerDomResult a_dom_res = innerDomain(MatmulRole::INPUT_A); + if (!a_dom_res.second.empty()) { + std::string err = a_dom_res.second; + return err; } - if ((km_found && kn_found) && !(mk_found || nk_found)) { - return MmaLayout::NT; + const bool kinner_a = a_dom_res.first == MatmulDomain::K; + + const InnerDomResult b_dom_res = innerDomain(MatmulRole::INPUT_B); + if (!b_dom_res.second.empty()) { + std::string err = b_dom_res.second; + return err; } - if ((mk_found && nk_found) && !(km_found || kn_found)) { + const bool kinner_b = b_dom_res.first == MatmulDomain::K; + + if (kinner_a && kinner_b) { return MmaLayout::TN; - } - if ((km_found && nk_found) && !(mk_found || kn_found)) { + } else if (kinner_a && !kinner_b) { + return MmaLayout::TT; + } else if (!kinner_a && !kinner_b) { + return MmaLayout::NT; + } else if (!kinner_a && kinner_b) { return MmaLayout::NN; } - - return {"Failed to decide fusion inputs' data layout."}; + NVF_ERROR(false, "Reached unreachable section of getProblemLayout"); } -MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion) { - auto mma_exprs = ir_utils::getOpsOfType(fusion); - if (mma_exprs.size() != 1) { - std::stringstream ss; - ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << mma_exprs.size(); - return ss.str(); +MatmulProblemLayoutOpt getProblemLayout( + Fusion* fusion, + const MatmulPattern& pattern) { + IdModel id_model(fusion); + const auto id_roles = pattern.getDimRoles(id_model); + const auto roles_map_opt = getTensorsRoles(fusion, id_model, id_roles); + if (!roles_map_opt.isValid()) { + return {roles_map_opt.getErrorMsg()}; } - return getMmaLayout( - fusion, - {static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}); + return getProblemLayout(id_model, id_roles, roles_map_opt.getData()); } RolesMapOpt getTensorsRoles( Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { - ComputeAtMap ca_map(fusion); + const IdModel& id_model, + const std::unordered_map& group_to_domain) { const auto mma_input_candidates = ir_utils::filterByType(fusion->inputs()).vector(); if (mma_input_candidates.empty()) { @@ -1210,143 +1210,118 @@ RolesMapOpt getTensorsRoles( return {"Failed to find any TV that is fusion output"}; } - const auto mma_output_domains = getProblemIterDomains(props); - if (!mma_output_domains.isValid()) { - return mma_output_domains.getErrorMsg(); - } - - const auto findInputRolesByDomains = [](const DependenciesMap& deps_map, - RolesMap& roles_map) { - for (const auto& entry : deps_map) { - const auto& domains = entry.second; - const auto begin = domains.begin(); - const auto end = domains.end(); + RolesMap roles_map; - bool has_m = (end != std::find(begin, end, MatmulDomain::M)); - bool has_n = (end != std::find(begin, end, MatmulDomain::N)); - bool has_k = (end != std::find(begin, end, MatmulDomain::K)); + // Assumes the exact graph has already been built, since we've been provided + // group_to_domain + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - if (has_m && has_k && !has_n) { - roles_map[MatmulRole::INPUT_A].push_back(entry.first); + for (TensorView* tv : mma_input_candidates) { + bool has_m = false, has_n = false, has_k = false, has_unmapped = false; + for (IterDomain* id : + TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { + if (id->isBroadcast()) { + // Broadcast domains won't exact map to concrete domains so skip them continue; } - if (has_n && has_k && !has_m) { - roles_map[MatmulRole::INPUT_B].push_back(entry.first); + const ValGroup& g = exact_graph.toGroup(id); + auto it = group_to_domain.find(g); + if (it == group_to_domain.end()) { + // tv has an unmapped non-broadcast and non-reduction dimension + has_unmapped = true; continue; } - // Bias vectors are assigned to INPUT_C role - if (!has_k) { - roles_map[MatmulRole::INPUT_C].push_back(entry.first); + has_m |= it->second == MatmulDomain::M; + has_n |= it->second == MatmulDomain::N; + has_k |= it->second == MatmulDomain::K; + } + if (has_unmapped) { + // Don't map TVs to roles if they have unmapped dims + continue; + } + if (has_m && has_k && !has_n) { + roles_map[MatmulRole::INPUT_A].push_back(tv); + continue; + } + if (has_n && has_k && !has_m) { + roles_map[MatmulRole::INPUT_B].push_back(tv); + continue; + } + // Bias vectors are assigned to INPUT_C role + if (!has_k) { + roles_map[MatmulRole::INPUT_C].push_back(tv); + continue; + } + } + + std::vector storage; + for (TensorView* tv : mma_output_candidates) { + bool has_m = false, has_n = false, has_k = false, has_unmapped = false; + for (IterDomain* id : + TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { + const ValGroup& g = exact_graph.toGroup(id); + auto it = group_to_domain.find(g); + if (it == group_to_domain.end()) { + // tv has an unmapped dimension + has_unmapped = true; continue; } + has_m |= it->second == MatmulDomain::M; + has_n |= it->second == MatmulDomain::N; + has_k |= it->second == MatmulDomain::K; } - - for (auto& [role, tvs] : roles_map) { - // NOTE: sort input roles in descending order by uses() size, and - // if equal then by name() to ensure the stable ordering of tensor - // views in collections assigned to the supported roles - std::sort(tvs.begin(), tvs.end(), [](TensorView* a, TensorView* b) { - return (a->uses().size() == b->uses().size()) - ? (a->name() < b->name()) - : (a->uses().size() > b->uses().size()); - }); + // NOTE: depending on fusion definition k domain may appear in the output: + // - for mma_output == fusion output k domain is present + // - for mma_output != fusion output (fusion with epilogue) k domain + // is not present + if (has_k || has_unmapped) { + // Don't map TVs to output roles if they have unmapped dims, or if they + // have K dimension + continue; } - }; - - const auto findOutputRolesByDomains = [](const DependenciesMap& deps_map, - RolesMap& roles_map) { - std::vector storage; - storage.reserve(deps_map.size()); - - for (const auto& entry : deps_map) { - const auto& domains = entry.second; - const auto begin = domains.begin(); - const auto end = domains.end(); - bool has_m = (end != std::find(begin, end, MatmulDomain::M)); - bool has_n = (end != std::find(begin, end, MatmulDomain::N)); - - // NOTE: depending on fusion definition k domain may appear in the output: - // - for mma_output == fusion output k domain is present - // - for mma_output != fusion output (fusion with epilogue) k domain - // is not present + // NOTE: the core fusion output tensors are the ones with m and n + // domains + if (has_m && has_n) { + storage.push_back(tv); + } + } - // NOTE: the core fusion output tensors are the ones with m and n - // domains - if (has_m && has_n) { - storage.push_back(entry.first); - } + // NOTE: sort output roles in descending order by uses() size, and + // if equal then by name() to ensure the stable ordering of tensor + // views in collections assigned to the supported roles + std::sort(storage.begin(), storage.end(), [](TensorView* a, TensorView* b) { + return (a->uses().size() == b->uses().size()) + ? (a->name() < b->name()) + : (a->uses().size() > b->uses().size()); + }); + + if (!storage.empty()) { + // NOTE: currently, we pick as a reference tensor one with `m` and `n` + // IterDomains and the most uses + auto pos = storage.begin(); + roles_map[MatmulRole::OUTPUT_D].push_back(*pos); + for (++pos; pos != storage.end(); ++pos) { + roles_map[MatmulRole::OUTPUT_AUX].push_back(*pos); } + } - // NOTE: sort output roles in descending order by uses() size, and + for (auto& [role, tvs] : roles_map) { + // NOTE: sort input roles in descending order by uses() size, and // if equal then by name() to ensure the stable ordering of tensor // views in collections assigned to the supported roles - std::sort(storage.begin(), storage.end(), [](TensorView* a, TensorView* b) { + std::sort(tvs.begin(), tvs.end(), [](TensorView* a, TensorView* b) { return (a->uses().size() == b->uses().size()) ? (a->name() < b->name()) : (a->uses().size() > b->uses().size()); }); - - if (!storage.empty()) { - // NOTE: currently, we pick as a reference tensor one with `m` and `n` - // IterDomains and the most uses - auto pos = storage.begin(); - roles_map[MatmulRole::OUTPUT_D].push_back(*pos); - for (++pos; pos != storage.end(); ++pos) { - roles_map[MatmulRole::OUTPUT_AUX].push_back(*pos); - } - } - }; - - const auto domains_data = mma_output_domains.getData(); - const auto m = domains_data[(size_t)MatmulDomain::M]; - const auto n = domains_data[(size_t)MatmulDomain::N]; - const auto k = domains_data[(size_t)MatmulDomain::K]; - - DependenciesMap deps_map; - RolesMap roles_map; - - // Handle fusion input TensorView objects - resolveTvToMatmulDomainsMapping( - deps_map, mma_input_candidates, m, n, k, ca_map); - findInputRolesByDomains(deps_map, roles_map); - - deps_map.clear(); - - // Handle fusion output TensorView objects - resolveTvToMatmulDomainsMapping( - deps_map, mma_output_candidates, m, n, k, ca_map); - findOutputRolesByDomains(deps_map, roles_map); + } return roles_map; } -RolesMapOpt getTensorsRoles(Fusion* fusion) { - auto mma_exprs = ir_utils::getOpsOfType(fusion); - if (mma_exprs.size() != 1) { - std::stringstream ss; - ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << mma_exprs.size(); - return ss.str(); - } - return getTensorsRoles( - fusion, - {static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}); -} - namespace { -void addMMAOp(Fusion* fusion_, std::vector& props) { - for (auto prop : props) { - auto* init = - IrBuilder::create(0.0, prop.insouts.out->getDataType().value()); - IrBuilder::create( - prop.insouts.out, prop.insouts.a, prop.insouts.b, init); - } -} - // Check the val (in) is the output of broadcast. // Then check the output of the broadcast is 3D (4D for bmm). bool hasValidBroadcastOp(TensorView* bcast_out) { @@ -1437,103 +1412,215 @@ TensorView* getTensorviewPriorToCast(TensorView* in) { return in; } -// Check if the Mul-Sum pair represents a matmul. If so, add the properties -// of the mma op which can be a tentatice substitue. This checks that the output -// of sum has on reduction axis, and the inputs to mul are valid broadcasts. -std::optional getMulSumInsOutsBcasts( - BinaryOp* mop, - ReductionOp* redop) { - auto a = getTensorviewPriorToCast(static_cast(mop->lhs())); - auto b = getTensorviewPriorToCast(static_cast(mop->rhs())); - - // Get the dimension of the reduction in the output. If not present, bail. - // Also ensure there is only only reduction axis. - auto red_axis = static_cast(redop->out())->getReductionAxis(); - auto num_reduction_dims = - static_cast(redop->out())->domain()->nDims() - - static_cast(redop->out())->domain()->noReductions().size(); - if (!red_axis.has_value() || num_reduction_dims > 1) { - return std::nullopt; - } +} // namespace - if (broadcastsAreValid(a, b, *red_axis)) { - MulSumProperties props = { - {mop, redop}, - {a, b, static_cast(redop->output(0))}, - {dynamic_cast(a->definition()), - dynamic_cast(b->definition())}}; - return props; +char dtypeToChar(const DataType& dtype) { + if (dtype == DataType::Half) { + return 'H'; + } else if (dtype == DataType::BFloat16) { + return 'T'; + } else if (dtype == DataType::Float) { + return 'S'; + } else if (dtype == DataType::Double) { + return 'D'; } - return std::nullopt; + NVF_ERROR(false, "Unsupported dtype for matmul: ", dtype); + return 0; } -} // namespace -void CombineMulSum::handle(ReductionOp* stmt) { - // Check if operation is a sum. - if (stmt->getReductionOpType() == BinaryOpType::Add) { - auto* inputOfSum = stmt->in(); - if (inputOfSum != nullptr) { - auto* expr = inputOfSum->definition(); - // Then check if the prodcer of the sum is a mul. - if (auto bOp = dynamic_cast(expr)) { - // If it'a mul followed by a sum, put this in a list. - if (bOp->getBinaryOpType() == BinaryOpType::Mul) { - // If the Mul-Sum is a valid representation of a matmul, - // then get the properties of the replacement Mma op. - auto props = getMulSumInsOutsBcasts(bOp, stmt); - if (props.has_value()) { - mul_sum_props_.push_back(*props); +namespace { + +class MatmulPatternMatcher : IterVisitor { + public: + static std::vector run(Fusion* fusion) { + MatmulPatternMatcher matcher; + matcher.traverse(fusion); + return matcher.patterns_; + } + + private: + using IterVisitor::handle; + + void handle(MatmulOp* mop) override { + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = mop->inA()->as(); + pattern.B = mop->inB()->as(); + pattern.output = mop->out()->as(); + } + + // Handle the case when no translation is needed. + void handle(MmaOp* mop) override { + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = mop->inA()->as(); + pattern.B = mop->inB()->as(); + pattern.output = mop->out()->as(); + } + + void handle(ReductionOp* rop) override { + // Check if operation is a sum. + if (rop->getReductionOpType() != BinaryOpType::Add) { + return; + } + // Then check if the producer of the sum is a mul. + if (auto bop = dynamic_cast(rop->in()->definition())) { + if (bop->getBinaryOpType() != BinaryOpType::Mul) { + return; + } + // Remember that we are just gathering the immediate inputs to the + // matmul, so there should be no prologue between a, b and the mul/sum. + + // Check that the inputs have broadcasts that are not all in common, i.e. + // that there is at least one M and at least one N dimension. + TensorView* ltv = getTensorviewPriorToCast(bop->lhs()->as()); + TensorView* rtv = getTensorviewPriorToCast(bop->rhs()->as()); + + std::vector lrf = + TensorDomain::noReductions(ltv->getMaybeRFactorDomain()); + std::vector rrf = + TensorDomain::noReductions(rtv->getMaybeRFactorDomain()); + + // These sizes should match since ops::maybeBroadcast places BroadcastOps + // for implicit broadcasting. + NVF_ERROR(lrf.size() == rrf.size()); + const std::vector& red_root = + rop->out()->as()->getRootDomain(); + NVF_ERROR(red_root.size() == lrf.size()); + bool has_m = false, has_n = false; + for (size_t i : c10::irange(lrf.size())) { + if (lrf[i]->isBroadcast() && !rrf[i]->isBroadcast()) { + has_m = true; + } else if (!lrf[i]->isBroadcast() && rrf[i]->isBroadcast()) { + has_n = true; + } + if (red_root[i]->isReduction()) { + // matmul must be contraction of non-broadcast dimensions + if (!lrf[i]->isIteration() || !rrf[i]->isIteration()) { + return; } } } + if (!has_m || !has_n) { + // This is an ordinary reduction, not a matmul + return; + } + + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = ltv; + pattern.B = rtv; + pattern.output = rop->out()->as(); } } + + private: + std::vector patterns_; }; -void CombineMulSum::generateMulSumCanidates() { - auto mma_exprs = ir_utils::getOpsOfType(fusion_); - if (mma_exprs.size() == 1) { - mma_utils::MulSumProperties props; - props.insouts = { - static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}; - mul_sum_props_.push_back(props); - } else { - traverse(fusion_); - } - is_valid_ = (mul_sum_props_.size() == 1) ? true : false; -} +} // namespace -const std::vector& CombineMulSum::getMulSumCanidates( - const bool refresh_data) { - if (refresh_data) { - mul_sum_props_.clear(); - generateMulSumCanidates(); - } - return mul_sum_props_; +std::vector findMatmulPatterns(Fusion* fusion) { + return MatmulPatternMatcher::run(fusion); } -void CombineMulSum::replaceWithMmaOp() { - // Recreate the mul-sum pairs since someone - // may run this function more than once. - generateMulSumCanidates(); - addMMAOp(fusion_, mul_sum_props_); - return; +MmaOp* MatmulPattern::translateToMmaOp() { + if (auto mma_op = dynamic_cast(output->definition())) { + // No translation needed + return mma_op; + } else if (auto mop = dynamic_cast(output->definition())) { + // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we + // must transpose B then broadcast both operands before creating the final + // op. + // + // Also note that the output of MatmulOp is a tensor of shape [..., M, N] + // whose matches that of the inputs. We will most commonly then also need to + // cast the output of the MmaOp to produce the output TensorView. + TensorView* Btrans = transpose(B); + TensorView* Abcast = unsqueeze(A, -2); + TensorView* Bbcast = unsqueeze(Btrans, -3); + TensorView* fms = fusedMultiplySum(Abcast, Bbcast, {-1}); + auto mma_op = fms->definition()->as(); + // Update operands to keep the pattern minimal + A = Abcast; + B = Bbcast; + // TODO: skip downcasting if the only uses of `output` are casts back to + // higher precision in order avoid the round trip cast in defining an + // epilogue that starts with MatmulOp. + if (output->dtype() != fms->dtype()) { + // Redefine output as cast of MmaOp->out() + IrBuilder::create(UnaryOpType::Cast, output, fms); + // Update output so that cast is part of the epilogue + output = fms; + } else { + // No cast needed, for example the inputs might be Float + ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); + } + return mma_op; + } else if (auto rop = dynamic_cast(output->definition())) { + Val* init = IrBuilder::create(0.0, output->dtype()); + // This replaces the mul and sum by overwriting output->definition() + return IrBuilder::create(output, A, B, init); + } + NVF_ERROR( + false, + "Could not translate matmul pattern with output ", + output->toString(), + " to MmaOp"); } -char dtypeToChar(const DataType& dtype) { - if (dtype == DataType::Half) { - return 'H'; - } else if (dtype == DataType::BFloat16) { - return 'T'; - } else if (dtype == DataType::Float) { - return 'S'; - } else if (dtype == DataType::Double) { - return 'D'; +std::unordered_map MatmulPattern::getDimRoles( + IdModel& id_model) const { + id_model.maybeBuildGraph(IdMappingMode::EXACT); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + + // There are four types of ValGroup involved in a MatmulPattern: M, N, K, and + // Batch. These are enumerated in the MatmulDomain enum class. They are + // defined by their membership as follows: + // M: present in A and output, but not B + // N: present in B and output, but not A + // K: present in A and B, but not output + // Batch: present in all A, B, and Batch + // If there are other membership patterns, for example a ValGroup present in + // only A, then we should raise an exception here. + + // Indicates whether a ValGroup is present in A (bit 0), B (bit 1), or output + // (bit 2) + using ValGroupPresence = std::bitset<3>; + + std::unordered_map present_flags; + const auto recordPresence = [&exact_graph, &present_flags]( + TensorView* tv, size_t tensor_num) { + for (IterDomain* id : tv->getMaybeRFactorDomain()) { + if (id->isReduction() || id->isBroadcast()) { + // ignore reductions and broadcasts since they don't exact map to + // problem dims + continue; + } + const ValGroup& g = exact_graph.toGroup(id); + present_flags[g].set(tensor_num); + } + }; + recordPresence(A, 0); + recordPresence(B, 1); + recordPresence(output, 2); + + std::unordered_map dim_to_domain; + for (const auto& [g, flags] : present_flags) { + if (flags.all()) { + dim_to_domain[g] = MatmulDomain::Batch; + } else if (flags.test(0) && flags.test(1)) { + dim_to_domain[g] = MatmulDomain::K; + } else if (flags.test(0) && flags.test(2)) { + dim_to_domain[g] = MatmulDomain::M; + } else if (flags.test(1) && flags.test(2)) { + dim_to_domain[g] = MatmulDomain::N; + } else { + NVF_ERROR( + false, + "IterDomain ValGroup should be present in at least two of A, B, and output. flags: ", + flags); + } } - NVF_ERROR(false, "Unsupported dtype for matmul: ", dtype); - return 0; + + return dim_to_domain; } } // namespace mma_utils diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 8b2a6861bdf..937072a7f30 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -9,8 +9,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -242,34 +244,39 @@ class DataWrapperOpt { } }; -// This struct hold properties of a Mul and Sum pair -// which can possibly be replaced a Mma op. This struct -// can be be created (partially) from a Mma op. -struct MulSumProperties { - // The Mul amd Sum op which can be replaced by a Mma op. - struct MulAndSumOps { - BinaryOp* mop = nullptr; - ReductionOp* redop = nullptr; - }; - - // The inputs/ouputs to the possible Mma Op or the actual Mma op. - struct InputsOutputs { - TensorView* a = nullptr; - TensorView* b = nullptr; - TensorView* out = nullptr; - }; - - // The broadcasts which feed the Mma op/Mul-Sum pair. - struct Broadcasts { - BroadcastOp* bcast_a = nullptr; - BroadcastOp* bcast_b = nullptr; - }; - - MulAndSumOps mulsumops; - InputsOutputs insouts; - Broadcasts bcasts; +//! This represents a single matmul operation, without a prologue or epilogue. +//! Each matmul has two inputs which might not be fusion inputs: A and B. It +//! also has one output, which can be Float or reduced precision. For MatmulOp +//! and LinearOp, the output is the same dtype as the inputs; so output does not +//! necessarily correspond to the output of a translated MmaOp and it might not +//! be a fusion output. +struct MatmulPattern { + TensorView* A; + TensorView* B; + // This is not necessarily a Fusion output, but rather is the immediate output + // representing a matmul in the current Fusion. The definition of this tensor + // determines what kind of translation is needed, if any. Possible definition + // Expr types are: MmaOp, ReductionOp (for mul-sum patterns), MatmulOp, and + // LinearOp. + TensorView* output; + + //! If the pattern is not already represented by an MmaOp, for example if + //! there is a MatmulOp instead, this function modifies the fusion to insert + //! an MmaOp. TensorViews A and B are unchanged, but this->output might be + //! updated to reflect the replacement tensor. + MmaOp* translateToMmaOp(); + + //! Given an IdModel, map groups of IterDomains to dimension roles + //! (MatmulDomain). Note that ValGroup is a shared_ptr to a + //! VectorOfUniqueEntries. We copy these as keys so that the returned + //! object can safely outlive id_model. + std::unordered_map getDimRoles( + IdModel& id_model) const; }; +//! Traverse the fusion to find supported matmul patterns +std::vector findMatmulPatterns(Fusion* fusion); + using MatmulProblemLayoutOpt = DataWrapperOpt; using ProblemIterDomainsOpt = DataWrapperOpt; using RolesMapOpt = DataWrapperOpt; @@ -290,13 +297,19 @@ using DependenciesMap = std::map; //! transposition of inputs in mma instructions, while other (e.g. Turing, //! Ampere) the only supported transposition is TN which means that mma //! instruction first input is transposed, the second input is non-transposed. -NVF_API MatmulProblemLayoutOpt getMmaLayout( - Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props); +NVF_API MatmulProblemLayoutOpt +getProblemLayout(Fusion* fusion, const MatmulPattern& pattern); //! This overloaded version is just a wrapper on the above function, where -//! the mma_utils::MulSumProperties::InputsOutputs is extracted from the fusion. -NVF_API MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion); +//! the MatmulPattern is extracted from the fusion. +NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); + +//! Determine the problem layout based on allocation domain of inputs. This is +//! called by the above overloads. +NVF_API MatmulProblemLayoutOpt getProblemLayout( + const IdModel& id_model, + const std::unordered_map& group_to_domain, + const RolesMap& roles_map); //! Returns wrapped collection of IterDomains that can be used to get //! problem shape with runtime info. @@ -306,16 +319,15 @@ NVF_API MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion); //! be gathered. //! TODO: 4th domain must be added for batch gemm support. ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); -ProblemIterDomainsOpt getProblemIterDomains( - const mma_utils::MulSumProperties::InputsOutputs& props); +ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern); //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot //! be gathered. RolesMapOpt getTensorsRoles( Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props); -RolesMapOpt getTensorsRoles(Fusion* fusion); + const IdModel& id_model, + const std::unordered_map& group_to_domain); //! Return pair of whether use shared memory epilogue or not and whether to //! reuse shared memory for the prologue at the expense of an additional block @@ -346,50 +358,6 @@ NVF_API std::pair generateSharedMemoryEpilogueHeuristics( bool smem_b_reuse_guaranteed = false, bool ignore_occupancy_drop = false); -//! Go through the fusion IR to find combinations of mul-sum -//! which can be replaced with a mma op. This class operates -//! in two phases. It can go through the graph and find the mul-sum -//! pairs which can be replaced by a mma op. This phase returns a vector -//! of properties of the mma op (MulSumAsMmaProps) which would replace the -//! the mul-sum pair. It then exposes a function to replace with mma ops. -class CombineMulSum : public IterVisitor { - public: - CombineMulSum(Fusion* fusion) : IterVisitor(), fusion_(fusion) { - generateMulSumCanidates(); - }; - - const std::vector& getMulSumCanidates( - const bool refresh_data = false); - - //! Goes through the fusion to find mul-sum pairs. - //! If user sets the caching flags and properties have been previously - //! computed, then just return cached results. - void generateMulSumCanidates(); - - //! Replaces the candidate mul-sum pairs with mma ops. - //! Please not this will run generateMulSumCandidates again. - void replaceWithMmaOp(); - - //! Check if the fusion has a mma-op or a mul-sum pair - //! that can be replaced by a mma op. - bool isValid() { - return is_valid_; - } - - protected: - void handle(ReductionOp* stmt) override; - - private: - Fusion* fusion_; - //! This is the list of mul-sum pairs and the properties - //! of the mma op which can replace it. This is only populated - //! if the mul-sum pair is a valid replacement candidate. - std::vector mul_sum_props_ = {}; - //! This variable tracks if the fusion has a mul-sum pair - //! than can be replaced by a mma op, or has a single mma op. - bool is_valid_ = false; -}; - //! Compute the amount of shared memory we expect to need. The actual amount //! allocated will be determined by aliasing (see alias_memory.cpp). This //! function is useful for testing that we provide accurate information to our diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index d5aa7995816..638858075eb 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -67,6 +67,24 @@ class CombineMulSumAsMmaTest : public NVFuserTest { DisableOptionsGuard opt_guard_; }; +void performSubstitution(Fusion* fusion, bool should_not_find = false) { + EXPECT_TRUE(ir_utils::getOpsOfType(fusion).empty()); + + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + if (should_not_find) { + EXPECT_TRUE(patterns.empty()); + return; + } + + ASSERT_FALSE(patterns.empty()); + EXPECT_EQ(patterns.size(), 1); + + patterns.front().translateToMmaOp(); + + ASSERT_FALSE(ir_utils::getOpsOfType(fusion).empty()); +} + // Test checks to see that the combiner can correctly replace // the mul-sum pair with a mma op. TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Pass) { @@ -86,17 +104,13 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Pass) { fusion.addOutput(tv3); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); - - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - - ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion); } } -// This test checks that the combiner does not incorrectly -// replace this mul-sum pair, and the mul is not fed by broadcasts ops. +// This test checks that the pattern matcher does not incorrectly identify +// this mul-sum pair, as the mul is not fed by broadcasts ops; i.e. it is +// not a matmul. TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) { Fusion fusion; FusionGuard fg(&fusion); @@ -110,15 +124,13 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) { auto tv3 = sum(tv2, {-1}); fusion.addOutput(tv3); - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion, /*should_not_find=*/true); } -// This test checks to see that the mul-sum combiner does not -// combine a mul-sum which does not have appropriate broadcasts. -TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2) { +// This fusion has more than one broadcasted dimension for each operand, so it +// is currently rejected isMatmulFusionDefinitionSupported. Still, it is a valid +// MatmulPattern so we check that it is found. +TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) { // Assumes layout is kAllSupportedMmaLayout::NT; Fusion fusion; FusionGuard fg(&fusion); @@ -145,12 +157,7 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2) { auto tv3 = sum(tv2, {-1}); fusion.addOutput(tv3); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); - - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion, /*should_not_find=*/false); } // As a sanity check we test that after replacing a mul-sum @@ -174,11 +181,8 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Schedule) { auto tv2 = sum(mul(tv0, tv1), {-1}); fusion.addOutput(tv2); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion); MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); diff --git a/tests/cpp/test_gpu_tensorcore.cpp b/tests/cpp/test_gpu_tensorcore.cpp index 44907f102c6..24dff7fe504 100644 --- a/tests/cpp/test_gpu_tensorcore.cpp +++ b/tests/cpp/test_gpu_tensorcore.cpp @@ -3160,7 +3160,7 @@ TEST_F(GPUTTensorCoreTest, MisalignedVectorization) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index ea49711aa67..63d15e21e3f 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -188,7 +188,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBias) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -289,7 +289,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueRelu) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -396,7 +396,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasRelu) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -499,7 +499,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueReluAux) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -614,7 +614,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasReluAux) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -717,7 +717,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueGelu) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -813,7 +813,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueGeluAux) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -922,7 +922,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasGelu) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1039,7 +1039,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasGeluAux) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1163,7 +1163,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1226,7 +1226,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulRelaxedCheck) { .value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1285,7 +1285,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulInputShuffledTT) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1342,7 +1342,7 @@ TEST_F(MatmulSchedulerTest, EpilogueOutputCast) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1405,7 +1405,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlpha) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1470,7 +1470,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaOutputCast) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1544,7 +1544,7 @@ TEST_F(MatmulSchedulerTest, EpilogueBeta) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1626,7 +1626,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1714,7 +1714,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1806,7 +1806,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaBias) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1889,7 +1889,7 @@ TEST_F(MatmulSchedulerTest, StridedBatch) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1975,7 +1975,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaBeta) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2074,7 +2074,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaSingleBeta) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2161,7 +2161,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueBias) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2241,7 +2241,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueSingleBias) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2310,7 +2310,7 @@ TEST_F(MatmulSchedulerTest, MisalignedVectorization) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2501,7 +2501,7 @@ TEST_F(MatmulSchedulerTest, StridedInputs) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); From a583d40d5dc904076c04212ea75a0f77a5731af2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 20 May 2024 17:16:12 +0000 Subject: [PATCH 02/28] Remove MatmulOp stuff. This will go in another PR --- csrc/scheduler/matmul_utils.cpp | 7 +++---- csrc/scheduler/mma_utils.cpp | 36 --------------------------------- 2 files changed, 3 insertions(+), 40 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 998706673d9..2a8a96bab81 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -457,8 +457,7 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // The plan: // 0. Check if the current CUDA device is supported // 1. Check if there is exactly one matmul pattern defined in the fusion. - // 2. Check if inputs to the mma op or mul sum pair match any of - // supported inputs layout + // 2. Check if the input layout for the matmul pattern can be determined // 3. Check if fusion represents expressions that are recognized by matmul // scheduler. @@ -467,8 +466,8 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { const auto device_prop = at::cuda::getCurrentDeviceProperties(); // Use a dummy problem shape to determine whether this is a supported // device. - const auto mma_op = - getMmaOp(device_prop->major * 10 + device_prop->minor, {128, 128, 128, 1}); + const auto mma_op = getMmaOp( + device_prop->major * 10 + device_prop->minor, {128, 128, 128, 1}); if (!mma_op.has_value()) { return "Unsupported device compute capability"; } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 5014f30f194..332ae41cb55 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1441,13 +1441,6 @@ class MatmulPatternMatcher : IterVisitor { private: using IterVisitor::handle; - void handle(MatmulOp* mop) override { - MatmulPattern& pattern = patterns_.emplace_back(); - pattern.A = mop->inA()->as(); - pattern.B = mop->inB()->as(); - pattern.output = mop->out()->as(); - } - // Handle the case when no translation is needed. void handle(MmaOp* mop) override { MatmulPattern& pattern = patterns_.emplace_back(); @@ -1525,35 +1518,6 @@ MmaOp* MatmulPattern::translateToMmaOp() { if (auto mma_op = dynamic_cast(output->definition())) { // No translation needed return mma_op; - } else if (auto mop = dynamic_cast(output->definition())) { - // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we - // must transpose B then broadcast both operands before creating the final - // op. - // - // Also note that the output of MatmulOp is a tensor of shape [..., M, N] - // whose matches that of the inputs. We will most commonly then also need to - // cast the output of the MmaOp to produce the output TensorView. - TensorView* Btrans = transpose(B); - TensorView* Abcast = unsqueeze(A, -2); - TensorView* Bbcast = unsqueeze(Btrans, -3); - TensorView* fms = fusedMultiplySum(Abcast, Bbcast, {-1}); - auto mma_op = fms->definition()->as(); - // Update operands to keep the pattern minimal - A = Abcast; - B = Bbcast; - // TODO: skip downcasting if the only uses of `output` are casts back to - // higher precision in order avoid the round trip cast in defining an - // epilogue that starts with MatmulOp. - if (output->dtype() != fms->dtype()) { - // Redefine output as cast of MmaOp->out() - IrBuilder::create(UnaryOpType::Cast, output, fms); - // Update output so that cast is part of the epilogue - output = fms; - } else { - // No cast needed, for example the inputs might be Float - ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); - } - return mma_op; } else if (auto rop = dynamic_cast(output->definition())) { Val* init = IrBuilder::create(0.0, output->dtype()); // This replaces the mul and sum by overwriting output->definition() From f5ec534f2bb5eeaa0f713dcd0449ece858de47d3 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 20 May 2024 19:24:01 +0000 Subject: [PATCH 03/28] Fix multiple-broadcasts test. This fixes #2273 --- csrc/ir/utils.cpp | 2 +- csrc/scheduler/mma_utils.cpp | 38 ++++++++++++++++++++++++++++-- csrc/scheduler/mma_utils.h | 2 ++ tests/cpp/test_combine_mul_sum.cpp | 32 ++++++++++++++++++------- 4 files changed, 62 insertions(+), 12 deletions(-) diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 265b2b1a202..aa47e88e29d 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1290,7 +1290,7 @@ MmaOpDetails getMmaOpDetails( const auto validateOutputDetails = [](const TensorViewDetails& details, const std::string& desc) { // TODO: revise rules when add support for batch gemms - NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains."); + // NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains."); NVF_ERROR(!details.rdomains.empty(), desc, ": has no reduction domains."); NVF_ERROR( (details.cdomains.size() >= expected_gemm_cdomains), diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 332ae41cb55..ece39683446 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1259,10 +1259,14 @@ RolesMapOpt getTensorsRoles( bool has_m = false, has_n = false, has_k = false, has_unmapped = false; for (IterDomain* id : TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { + if (id->isBroadcast()) { + // Ignore broadcasts in output + continue; + } const ValGroup& g = exact_graph.toGroup(id); auto it = group_to_domain.find(g); if (it == group_to_domain.end()) { - // tv has an unmapped dimension + // output tv has an unmapped non-broadcast dimension has_unmapped = true; continue; } @@ -1459,6 +1463,19 @@ class MatmulPatternMatcher : IterVisitor { if (bop->getBinaryOpType() != BinaryOpType::Mul) { return; } + // TODO: Allow multiple K dimensions + // Check that there's a single K dimension + bool has_k = false; + for (IterDomain* id : + rop->out()->as()->getMaybeRFactorDomain()) { + if (id->isReduction()) { + if (has_k) { + return; + } + has_k = true; + } + } + // Remember that we are just gathering the immediate inputs to the // matmul, so there should be no prologue between a, b and the mul/sum. @@ -1481,8 +1498,16 @@ class MatmulPatternMatcher : IterVisitor { bool has_m = false, has_n = false; for (size_t i : c10::irange(lrf.size())) { if (lrf[i]->isBroadcast() && !rrf[i]->isBroadcast()) { + if (has_m) { + // TODO: Handle multiple M dimensions + return; + } has_m = true; } else if (!lrf[i]->isBroadcast() && rrf[i]->isBroadcast()) { + if (has_n) { + // TODO: Handle multiple N dimensions + return; + } has_n = true; } if (red_root[i]->isReduction()) { @@ -1493,7 +1518,7 @@ class MatmulPatternMatcher : IterVisitor { } } if (!has_m || !has_n) { - // This is an ordinary reduction, not a matmul + // This is an ordinary reduction or mat-vec, not a matmul return; } @@ -1514,6 +1539,15 @@ std::vector findMatmulPatterns(Fusion* fusion) { return MatmulPatternMatcher::run(fusion); } +std::string MatmulPattern::toString() const { + std::stringstream ss; + ss << "MatmulPattern{"; + ss << "\n A=" << A->toString(); + ss << "\n B=" << B->toString(); + ss << "\n output=" << output->toString() << "\n}"; + return ss.str(); +} + MmaOp* MatmulPattern::translateToMmaOp() { if (auto mma_op = dynamic_cast(output->definition())) { // No translation needed diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 937072a7f30..7e5b2f774ce 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -272,6 +272,8 @@ struct MatmulPattern { //! object can safely outlive id_model. std::unordered_map getDimRoles( IdModel& id_model) const; + + std::string toString() const; }; //! Traverse the fusion to find supported matmul patterns diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 638858075eb..2311d1d1f2a 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -127,18 +127,17 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) { performSubstitution(&fusion, /*should_not_find=*/true); } -// This fusion has more than one broadcasted dimension for each operand, so it -// is currently rejected isMatmulFusionDefinitionSupported. Still, it is a valid -// MatmulPattern so we check that it is found. +// This fusion has Broadcast batch axes in each operand. TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) { // Assumes layout is kAllSupportedMmaLayout::NT; - Fusion fusion; - FusionGuard fg(&fusion); + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + fusion->addInput(tv0); + fusion->addInput(tv1); auto tv0t = transpose(tv0, 0, 1); auto tv1t = transpose(tv1, 0, 1); @@ -155,9 +154,24 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) { auto tv1b = broadcast(tv1t, bcast_dims); auto tv2 = mul(tv0b, tv1b); auto tv3 = sum(tv2, {-1}); - fusion.addOutput(tv3); + fusion->addOutput(tv3); + + performSubstitution(fusion, /*should_not_find=*/false); + + // We test running this fusion also to verify that the broadcast batch + // dimension does not cause unforeseen issues + + int64_t M = 256, N = 128, K = 64; + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({K, M}, options); + auto t1 = at::randn({K, N}, options); + auto tref = at::linear(t0.t(), t1.t()).unsqueeze(1); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); - performSubstitution(&fusion, /*should_not_find=*/false); + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } // As a sanity check we test that after replacing a mul-sum From 1a0fdb9bb7deaaa6fbbddcf954b9f9a7c886358f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 20 May 2024 19:35:43 +0000 Subject: [PATCH 04/28] Remove bcast output test --- csrc/ir/utils.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index aa47e88e29d..eb89e323f98 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1290,7 +1290,6 @@ MmaOpDetails getMmaOpDetails( const auto validateOutputDetails = [](const TensorViewDetails& details, const std::string& desc) { // TODO: revise rules when add support for batch gemms - // NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains."); NVF_ERROR(!details.rdomains.empty(), desc, ": has no reduction domains."); NVF_ERROR( (details.cdomains.size() >= expected_gemm_cdomains), From 9e6447d69cb94ca4c90cbd62086d3ce4acbc5c57 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 20 May 2024 19:37:38 +0000 Subject: [PATCH 05/28] Remove canScheduleCompileTime check for ExprEval --- csrc/scheduler/expr_eval_sched.cpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 8a138c92026..b25a290ce99 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -16,19 +16,13 @@ namespace nvfuser { // Check if the fusion has a single MatmulOp/LinearOp node bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { auto exprs = fusion->exprs(); - if (!isOptionDisabled(DisableOption::MatmulExprEval)) { - if (exprs.size() == 1 && - (exprs.front()->isA() || exprs.front()->isA())) { - return true; - } - scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), - "Fusion must contain a single expression of type MatmulOp or LinearOp"); - } else { - scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), - "Matmul ATen evaluation was disabled by NVFUSER_DISABLE=matmul_expr_eval"); + if (exprs.size() == 1 && + (exprs.front()->isA() || exprs.front()->isA())) { + return true; } + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), + "Fusion must contain a single expression of type MatmulOp or LinearOp"); return false; } From 1cfac85f1d267a3f1585e0f33cf632f13a20ef2a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 00:20:12 +0000 Subject: [PATCH 06/28] Fix gcc build failure due to unused variable --- csrc/scheduler/mma_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index ece39683446..e67916dad35 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1552,7 +1552,7 @@ MmaOp* MatmulPattern::translateToMmaOp() { if (auto mma_op = dynamic_cast(output->definition())) { // No translation needed return mma_op; - } else if (auto rop = dynamic_cast(output->definition())) { + } else if (output->definition()->isA()) { Val* init = IrBuilder::create(0.0, output->dtype()); // This replaces the mul and sum by overwriting output->definition() return IrBuilder::create(output, A, B, init); From 7aad3879ac9b72cd4b2ed58ec49970d97b7a22c5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 00:21:45 +0000 Subject: [PATCH 07/28] Fix signed/unsigned compare --- csrc/scheduler/matmul_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 2a8a96bab81..a19ee5f45d6 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -206,7 +206,7 @@ std::string isMatmulFusionDefinitionSupported( for (const auto& [g, dom] : id_roles) { num_axes[(size_t)dom]++; } - constexpr size_t expected_axes_numbers = 1; + constexpr int64_t expected_axes_numbers = 1; if (num_axes[(size_t)MatmulDomain::M] != expected_axes_numbers || num_axes[(size_t)MatmulDomain::N] != expected_axes_numbers || num_axes[(size_t)MatmulDomain::K] != expected_axes_numbers || From 36a72dc12f0a52800683f10e02e7722e496ce516 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 13:57:48 +0000 Subject: [PATCH 08/28] Allow multiple M, N, or K dims in pattern match We can still refuse to schedule, but these are valid patterns --- csrc/scheduler/mma_utils.cpp | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index e67916dad35..b566f52e812 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1463,19 +1463,6 @@ class MatmulPatternMatcher : IterVisitor { if (bop->getBinaryOpType() != BinaryOpType::Mul) { return; } - // TODO: Allow multiple K dimensions - // Check that there's a single K dimension - bool has_k = false; - for (IterDomain* id : - rop->out()->as()->getMaybeRFactorDomain()) { - if (id->isReduction()) { - if (has_k) { - return; - } - has_k = true; - } - } - // Remember that we are just gathering the immediate inputs to the // matmul, so there should be no prologue between a, b and the mul/sum. @@ -1498,16 +1485,8 @@ class MatmulPatternMatcher : IterVisitor { bool has_m = false, has_n = false; for (size_t i : c10::irange(lrf.size())) { if (lrf[i]->isBroadcast() && !rrf[i]->isBroadcast()) { - if (has_m) { - // TODO: Handle multiple M dimensions - return; - } has_m = true; } else if (!lrf[i]->isBroadcast() && rrf[i]->isBroadcast()) { - if (has_n) { - // TODO: Handle multiple N dimensions - return; - } has_n = true; } if (red_root[i]->isReduction()) { From 1f713d38cd3d45f35ac8ac4c79e8bbca3ef9a990 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Tue, 21 May 2024 10:04:39 -0400 Subject: [PATCH 09/28] Update csrc/scheduler/mma_utils.cpp Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com> --- csrc/scheduler/mma_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index b566f52e812..4f50ebaf3a7 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1554,7 +1554,7 @@ std::unordered_map MatmulPattern::getDimRoles( // M: present in A and output, but not B // N: present in B and output, but not A // K: present in A and B, but not output - // Batch: present in all A, B, and Batch + // Batch: present in all A, B, and output // If there are other membership patterns, for example a ValGroup present in // only A, then we should raise an exception here. From 6937ff85f9aa036408de910c4caa8c7be693db87 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 14:31:10 +0000 Subject: [PATCH 10/28] Add comment about why casts are often present --- csrc/scheduler/mma_utils.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 4f50ebaf3a7..234996a1b3f 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1468,6 +1468,13 @@ class MatmulPatternMatcher : IterVisitor { // Check that the inputs have broadcasts that are not all in common, i.e. // that there is at least one M and at least one N dimension. + + // Note that there might be a cast to Float just before the multiply. This + // happens when using the `mul` op with reduced precision inputs. It can + // also happen if the inputs to `mul` in the definition were Float, but + // the Fusion was segmented and casts to half precision were inserted at + // the segmentation edge (see castInputOutputToLowerPrecision in + // fusion_segmenter.cpp). TensorView* ltv = getTensorviewPriorToCast(bop->lhs()->as()); TensorView* rtv = getTensorviewPriorToCast(bop->rhs()->as()); From 345b025a4bd9c77d629e55e8d45c63e9d20864bf Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 19:35:48 +0000 Subject: [PATCH 11/28] Remove getProblemIterDomains --- csrc/scheduler/mma_utils.cpp | 48 ------------------------------------ csrc/scheduler/mma_utils.h | 10 -------- 2 files changed, 58 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 234996a1b3f..32e508505bb 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1054,54 +1054,6 @@ inline void resolveTvToMatmulDomainsMapping( } // anonymous namespace -ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern) { - // NOTE: the iter domains of MMA output should be [...,M,K,N] - IterDomain* m = nullptr; - IterDomain* n = nullptr; - IterDomain* k = nullptr; - - const auto& leaf_domains = pattern.output->getLeafDomain(); - const auto concrete = TensorDomain::noDevices( - TensorDomain::noReductions(TensorDomain::noBroadcasts(leaf_domains))); - if (concrete.size() < MIN_MATMUL_INPUTS_NUMBER) { - std::stringstream ss; - ss << "Failed to find the minimum number of MMA input candidates, expected " - << MIN_MATMUL_INPUTS_NUMBER << ", got " << concrete.size(); - return ss.str(); - } - - // M,N are inner most concrete iter domains - m = concrete.rbegin()[1]; - n = concrete.rbegin()[0]; - - // K is a reduction domain, search for the inner most reduction domain - for (auto iter_domain = leaf_domains.rbegin(); - iter_domain != leaf_domains.rend(); - ++iter_domain) { - if ((*iter_domain)->isReduction()) { - k = *iter_domain; - break; - } - } - NVF_ERROR(k != nullptr, "Failed to find K domain in MMA output"); - - return ProblemIterDomains{m, n, k}; -} - -ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion) { - auto mma_exprs = ir_utils::getOpsOfType(fusion); - if (mma_exprs.size() != 1) { - std::stringstream ss; - ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << mma_exprs.size(); - return ss.str(); - } - return getProblemIterDomains( - {static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}); -} - MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion) { const std::vector patterns = findMatmulPatterns(fusion); if (patterns.size() != 1) { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 7e5b2f774ce..947f002f3a5 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -313,16 +313,6 @@ NVF_API MatmulProblemLayoutOpt getProblemLayout( const std::unordered_map& group_to_domain, const RolesMap& roles_map); -//! Returns wrapped collection of IterDomains that can be used to get -//! problem shape with runtime info. -//! Data is stored in the order in which lables are defined in MatmulDomain -//! enum class, that is in the following order: m, n, k. -//! An error message is stored in retruned object if valid data cannot -//! be gathered. -//! TODO: 4th domain must be added for batch gemm support. -ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); -ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern); - //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot //! be gathered. From 142f3662677506f7200b74efc86ec144c96033fd Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 19:37:07 +0000 Subject: [PATCH 12/28] Rename group_to_domain -> dim_roles --- csrc/scheduler/matmul_utils.cpp | 4 ++-- csrc/scheduler/mma_utils.cpp | 22 +++++++++++----------- csrc/scheduler/mma_utils.h | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index a19ee5f45d6..3c2e6bea62f 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -166,10 +166,10 @@ inline bool initCoreHeuristics( //! MatmulDomain::Batch, we evaluate the extent of each, then we multiply those //! dimensions together to get the overall batch size. ProblemShape getProblemShape( - const std::unordered_map& group_to_domain, + const std::unordered_map& dim_roles, SchedulerRuntimeInfo& runtime_info) { ProblemShape shape{1, 1, 1, 1}; - for (const auto& [g, dom] : group_to_domain) { + for (const auto& [g, dom] : dim_roles) { NVF_ERROR(!g->empty()); IterDomain* id = g->front()->as(); const PolymorphicValue extent = diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 32e508505bb..9debfd1b20b 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1067,10 +1067,10 @@ MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion) { MatmulProblemLayoutOpt getProblemLayout( const IdModel& id_model, - const std::unordered_map& group_to_domain, + const std::unordered_map& dim_roles, const RolesMap& roles_map) { // Assumes the exact graph has already been built, since we've been provided - // group_to_domain + // dim_roles const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); // Note: using DataWrapperOpt would be preferable here. However, @@ -1079,7 +1079,7 @@ MatmulProblemLayoutOpt getProblemLayout( // constructor for DataWrapperOpt to prevent inadvertent copying. To avoid // this complication I'm using a simple pair for the lambda's result type. using InnerDomResult = std::pair; - const auto innerDomain = [&roles_map, &group_to_domain, &exact_graph]( + const auto innerDomain = [&roles_map, &dim_roles, &exact_graph]( MatmulRole role) -> InnerDomResult { const auto role_it = roles_map.find(role); if (role_it == roles_map.end()) { @@ -1090,8 +1090,8 @@ MatmulProblemLayoutOpt getProblemLayout( IterDomain* inner_id = TensorDomain::noReductions(tv->getMaybeAllocationDomain()).back(); const ValGroup& g = exact_graph.toGroup(inner_id); - auto g_it = group_to_domain.find(g); - if (g_it == group_to_domain.end()) { + auto g_it = dim_roles.find(g); + if (g_it == dim_roles.end()) { return { MatmulDomain::M, "Inner domain of tensor was not mapped to a MatmulDomain"}; @@ -1150,7 +1150,7 @@ MatmulProblemLayoutOpt getProblemLayout( RolesMapOpt getTensorsRoles( Fusion* fusion, const IdModel& id_model, - const std::unordered_map& group_to_domain) { + const std::unordered_map& dim_roles) { const auto mma_input_candidates = ir_utils::filterByType(fusion->inputs()).vector(); if (mma_input_candidates.empty()) { @@ -1165,7 +1165,7 @@ RolesMapOpt getTensorsRoles( RolesMap roles_map; // Assumes the exact graph has already been built, since we've been provided - // group_to_domain + // dim_roles const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); for (TensorView* tv : mma_input_candidates) { @@ -1177,8 +1177,8 @@ RolesMapOpt getTensorsRoles( continue; } const ValGroup& g = exact_graph.toGroup(id); - auto it = group_to_domain.find(g); - if (it == group_to_domain.end()) { + auto it = dim_roles.find(g); + if (it == dim_roles.end()) { // tv has an unmapped non-broadcast and non-reduction dimension has_unmapped = true; continue; @@ -1216,8 +1216,8 @@ RolesMapOpt getTensorsRoles( continue; } const ValGroup& g = exact_graph.toGroup(id); - auto it = group_to_domain.find(g); - if (it == group_to_domain.end()) { + auto it = dim_roles.find(g); + if (it == dim_roles.end()) { // output tv has an unmapped non-broadcast dimension has_unmapped = true; continue; diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 947f002f3a5..908590cdd7e 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -310,7 +310,7 @@ NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); //! called by the above overloads. NVF_API MatmulProblemLayoutOpt getProblemLayout( const IdModel& id_model, - const std::unordered_map& group_to_domain, + const std::unordered_map& dim_roles, const RolesMap& roles_map); //! Returns wrapped collection of TensorView roles in fusion. @@ -319,7 +319,7 @@ NVF_API MatmulProblemLayoutOpt getProblemLayout( RolesMapOpt getTensorsRoles( Fusion* fusion, const IdModel& id_model, - const std::unordered_map& group_to_domain); + const std::unordered_map& dim_roles); //! Return pair of whether use shared memory epilogue or not and whether to //! reuse shared memory for the prologue at the expense of an additional block From 8bfd292c0d5c42cbf424a19c61c4e57ac0be009c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Tue, 21 May 2024 19:05:28 -0400 Subject: [PATCH 13/28] Update csrc/scheduler/matmul.cpp Co-authored-by: Gao, Xiang --- csrc/scheduler/matmul.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 4a636a596ff..60ed2d5c719 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -764,7 +764,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { IdModel id_model(fusion); std::unordered_map id_roles = patterns.front().getDimRoles(id_model); - const auto& roles_map_opt = + const auto& tensor_roles_opt = mma_utils::getTensorsRoles(fusion, id_model, id_roles); // NOTE: the contents of roles_map have been already validated during From 57974a3e0043e49dbc1a89f471a475af9d7215ad Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Tue, 21 May 2024 19:05:46 -0400 Subject: [PATCH 14/28] Update csrc/scheduler/matmul.cpp Co-authored-by: Gao, Xiang --- csrc/scheduler/matmul.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 60ed2d5c719..96dfcc60927 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -770,7 +770,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // NOTE: the contents of roles_map have been already validated during // compute-time checks NVF_ERROR(roles_map_opt.isValid(), roles_map_opt.getErrorMsg()); - const auto roles_map = roles_map_opt.getData(); + const auto tensor_roles = tensor_roles_opt.getData(); const mma_utils::MatmulProblemLayoutOpt fusion_layout = mma_utils::getProblemLayout(id_model, id_roles, roles_map); From e0619cf10bce013ccaa5d46ec47922cb35b28909 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 23:32:10 +0000 Subject: [PATCH 15/28] Rename most occurences of roles_map -> tensor_roles --- csrc/scheduler/matmul.cpp | 15 ++++++++------- csrc/scheduler/mma_utils.cpp | 26 +++++++++++++------------- csrc/scheduler/mma_utils.h | 2 +- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 96dfcc60927..72c0b766af1 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -767,18 +767,18 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { const auto& tensor_roles_opt = mma_utils::getTensorsRoles(fusion, id_model, id_roles); - // NOTE: the contents of roles_map have been already validated during + // NOTE: the contents of tensor_roles have been already validated during // compute-time checks - NVF_ERROR(roles_map_opt.isValid(), roles_map_opt.getErrorMsg()); + NVF_ERROR(tensor_roles_opt.isValid(), tensor_roles_opt.getErrorMsg()); const auto tensor_roles = tensor_roles_opt.getData(); const mma_utils::MatmulProblemLayoutOpt fusion_layout = - mma_utils::getProblemLayout(id_model, id_roles, roles_map); + mma_utils::getProblemLayout(id_model, id_roles, tensor_roles); NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg()); // Core roles: there can be only one... TV with assigned core role - TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); - TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); + TensorView* a = tensor_roles.at(MatmulRole::INPUT_A).front(); + TensorView* b = tensor_roles.at(MatmulRole::INPUT_B).front(); // Collect mma swizzle info auto mma = mma_ops.front(); @@ -790,7 +790,8 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { const auto& gemm_tile = params.tile_sizes; const bool has_epilogue = !mma->out()->isFusionOutput(); - const bool has_fusion_c_roles = (0 != roles_map.count(MatmulRole::INPUT_C)); + const bool has_fusion_c_roles = + (0 != tensor_roles.count(MatmulRole::INPUT_C)); const bool has_non_mma_input_tvs = has_epilogue && has_fusion_c_roles; // Including current tensor naming convention for reference, @@ -1243,7 +1244,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // operations, input tvs with non-core roles // core roles: essential for matmul, for example mma inputs' producers if (has_non_mma_input_tvs) { - scheduleFusionInputsForEpilogue(roles_map, params.use_smem_epilogue); + scheduleFusionInputsForEpilogue(tensor_roles, params.use_smem_epilogue); } scheduleSplitKSum( diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 9debfd1b20b..a4b600a5c2a 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1068,7 +1068,7 @@ MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion) { MatmulProblemLayoutOpt getProblemLayout( const IdModel& id_model, const std::unordered_map& dim_roles, - const RolesMap& roles_map) { + const RolesMap& tensor_roles) { // Assumes the exact graph has already been built, since we've been provided // dim_roles const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); @@ -1079,11 +1079,11 @@ MatmulProblemLayoutOpt getProblemLayout( // constructor for DataWrapperOpt to prevent inadvertent copying. To avoid // this complication I'm using a simple pair for the lambda's result type. using InnerDomResult = std::pair; - const auto innerDomain = [&roles_map, &dim_roles, &exact_graph]( + const auto innerDomain = [&tensor_roles, &dim_roles, &exact_graph]( MatmulRole role) -> InnerDomResult { - const auto role_it = roles_map.find(role); - if (role_it == roles_map.end()) { - return {MatmulDomain::M, "Could not find role in roles_map"}; + const auto role_it = tensor_roles.find(role); + if (role_it == tensor_roles.end()) { + return {MatmulDomain::M, "Could not find role in tensor_roles"}; } std::optional group_inner_dom = std::nullopt; for (TensorView* tv : role_it->second) { @@ -1162,7 +1162,7 @@ RolesMapOpt getTensorsRoles( return {"Failed to find any TV that is fusion output"}; } - RolesMap roles_map; + RolesMap tensor_roles; // Assumes the exact graph has already been built, since we've been provided // dim_roles @@ -1192,16 +1192,16 @@ RolesMapOpt getTensorsRoles( continue; } if (has_m && has_k && !has_n) { - roles_map[MatmulRole::INPUT_A].push_back(tv); + tensor_roles[MatmulRole::INPUT_A].push_back(tv); continue; } if (has_n && has_k && !has_m) { - roles_map[MatmulRole::INPUT_B].push_back(tv); + tensor_roles[MatmulRole::INPUT_B].push_back(tv); continue; } // Bias vectors are assigned to INPUT_C role if (!has_k) { - roles_map[MatmulRole::INPUT_C].push_back(tv); + tensor_roles[MatmulRole::INPUT_C].push_back(tv); continue; } } @@ -1256,13 +1256,13 @@ RolesMapOpt getTensorsRoles( // NOTE: currently, we pick as a reference tensor one with `m` and `n` // IterDomains and the most uses auto pos = storage.begin(); - roles_map[MatmulRole::OUTPUT_D].push_back(*pos); + tensor_roles[MatmulRole::OUTPUT_D].push_back(*pos); for (++pos; pos != storage.end(); ++pos) { - roles_map[MatmulRole::OUTPUT_AUX].push_back(*pos); + tensor_roles[MatmulRole::OUTPUT_AUX].push_back(*pos); } } - for (auto& [role, tvs] : roles_map) { + for (auto& [role, tvs] : tensor_roles) { // NOTE: sort input roles in descending order by uses() size, and // if equal then by name() to ensure the stable ordering of tensor // views in collections assigned to the supported roles @@ -1273,7 +1273,7 @@ RolesMapOpt getTensorsRoles( }); } - return roles_map; + return tensor_roles; } namespace { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 908590cdd7e..8f399bb338f 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -311,7 +311,7 @@ NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); NVF_API MatmulProblemLayoutOpt getProblemLayout( const IdModel& id_model, const std::unordered_map& dim_roles, - const RolesMap& roles_map); + const RolesMap& tensor_roles); //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot From b0fa936c8148b7f5d4ef506334595b3e560ca29d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 23:32:30 +0000 Subject: [PATCH 16/28] Remove getProblemLayout(Fusion*, const MatmulPattern&) --- csrc/scheduler/mma_utils.cpp | 21 ++++++++------------- csrc/scheduler/mma_utils.h | 13 ++++--------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index a4b600a5c2a..2fb68e95f77 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1062,7 +1062,14 @@ MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion) { << patterns.size(); return ss.str(); } - return getProblemLayout(fusion, patterns.front()); + const MatmulPattern& pattern = patterns[0]; + IdModel id_model(fusion); + const auto id_roles = pattern.getDimRoles(id_model); + const auto tensor_roles_opt = getTensorsRoles(fusion, id_model, id_roles); + if (!tensor_roles_opt.isValid()) { + return {tensor_roles_opt.getErrorMsg()}; + } + return getProblemLayout(id_model, id_roles, tensor_roles_opt.getData()); } MatmulProblemLayoutOpt getProblemLayout( @@ -1135,18 +1142,6 @@ MatmulProblemLayoutOpt getProblemLayout( NVF_ERROR(false, "Reached unreachable section of getProblemLayout"); } -MatmulProblemLayoutOpt getProblemLayout( - Fusion* fusion, - const MatmulPattern& pattern) { - IdModel id_model(fusion); - const auto id_roles = pattern.getDimRoles(id_model); - const auto roles_map_opt = getTensorsRoles(fusion, id_model, id_roles); - if (!roles_map_opt.isValid()) { - return {roles_map_opt.getErrorMsg()}; - } - return getProblemLayout(id_model, id_roles, roles_map_opt.getData()); -} - RolesMapOpt getTensorsRoles( Fusion* fusion, const IdModel& id_model, diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 8f399bb338f..84d295afd7a 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -299,20 +299,15 @@ using DependenciesMap = std::map; //! transposition of inputs in mma instructions, while other (e.g. Turing, //! Ampere) the only supported transposition is TN which means that mma //! instruction first input is transposed, the second input is non-transposed. -NVF_API MatmulProblemLayoutOpt -getProblemLayout(Fusion* fusion, const MatmulPattern& pattern); - -//! This overloaded version is just a wrapper on the above function, where -//! the MatmulPattern is extracted from the fusion. -NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); - -//! Determine the problem layout based on allocation domain of inputs. This is -//! called by the above overloads. NVF_API MatmulProblemLayoutOpt getProblemLayout( const IdModel& id_model, const std::unordered_map& dim_roles, const RolesMap& tensor_roles); +//! This version assumes the Fusion contains a single MatmulPattern, then builds +//! an IdModel and infers dim roles then calls the above function. +// NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); + //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot //! be gathered. From 76eb037c955e8d691ca425fe7b4270916b5dac87 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 23:34:10 +0000 Subject: [PATCH 17/28] Uncomment getProblemLayout --- csrc/scheduler/mma_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 84d295afd7a..1b15ff76c22 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -306,7 +306,7 @@ NVF_API MatmulProblemLayoutOpt getProblemLayout( //! This version assumes the Fusion contains a single MatmulPattern, then builds //! an IdModel and infers dim roles then calls the above function. -// NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); +NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot From e1044d9b2053aad20db932944b15b27101b7d860 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 23:45:34 +0000 Subject: [PATCH 18/28] Replace bitwise assignment ops --- csrc/scheduler/mma_utils.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 2fb68e95f77..af688f16571 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1178,9 +1178,9 @@ RolesMapOpt getTensorsRoles( has_unmapped = true; continue; } - has_m |= it->second == MatmulDomain::M; - has_n |= it->second == MatmulDomain::N; - has_k |= it->second == MatmulDomain::K; + has_m = has_m || it->second == MatmulDomain::M; + has_n = has_n || it->second == MatmulDomain::N; + has_k = has_k || it->second == MatmulDomain::K; } if (has_unmapped) { // Don't map TVs to roles if they have unmapped dims @@ -1217,9 +1217,9 @@ RolesMapOpt getTensorsRoles( has_unmapped = true; continue; } - has_m |= it->second == MatmulDomain::M; - has_n |= it->second == MatmulDomain::N; - has_k |= it->second == MatmulDomain::K; + has_m = has_m || it->second == MatmulDomain::M; + has_n = has_n || it->second == MatmulDomain::N; + has_k = has_k || it->second == MatmulDomain::K; } // NOTE: depending on fusion definition k domain may appear in the output: // - for mma_output == fusion output k domain is present From 990f2be67c683cfa2305c07047b3b33dbc5b43f5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 21 May 2024 23:51:29 +0000 Subject: [PATCH 19/28] Rename dim_to_domain -> dim_roles --- csrc/scheduler/mma_utils.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index af688f16571..9ff4d74b01e 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1533,16 +1533,16 @@ std::unordered_map MatmulPattern::getDimRoles( recordPresence(B, 1); recordPresence(output, 2); - std::unordered_map dim_to_domain; + std::unordered_map dim_roles; for (const auto& [g, flags] : present_flags) { if (flags.all()) { - dim_to_domain[g] = MatmulDomain::Batch; + dim_roles[g] = MatmulDomain::Batch; } else if (flags.test(0) && flags.test(1)) { - dim_to_domain[g] = MatmulDomain::K; + dim_roles[g] = MatmulDomain::K; } else if (flags.test(0) && flags.test(2)) { - dim_to_domain[g] = MatmulDomain::M; + dim_roles[g] = MatmulDomain::M; } else if (flags.test(1) && flags.test(2)) { - dim_to_domain[g] = MatmulDomain::N; + dim_roles[g] = MatmulDomain::N; } else { NVF_ERROR( false, @@ -1551,7 +1551,7 @@ std::unordered_map MatmulPattern::getDimRoles( } } - return dim_to_domain; + return dim_roles; } } // namespace mma_utils From c3753cb3176f066402d5b1ca03aec7f6cf2cdc1d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 12:14:40 +0000 Subject: [PATCH 20/28] Add comment describing isMatmulFusionDefinitionSupported --- csrc/scheduler/matmul_utils.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 3c2e6bea62f..c80676acb3c 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -181,6 +181,13 @@ ProblemShape getProblemShape( return shape; } +// Checks that this pattern: +// - is a GEMM or batch GEMM +// - has at least two inputs i.e. not A @ A.T +// - has a single A and a single B operand i.e not A @ (B1 * B2) +// - has a fusion output with OUTPUT_D role i.e. that has M, N dims +// - includes all fusion inputs/outputs in its tensor roles +// - has no fusion inputs with non-trivial allocation domain std::string isMatmulFusionDefinitionSupported( Fusion* fusion, const mma_utils::MatmulPattern& pattern, From 4fac4fbb56a9ea9620dba10b7b4bce4d484f8bd1 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 12:26:22 +0000 Subject: [PATCH 21/28] Use lambda to simplify getTensorsRoles --- csrc/scheduler/mma_utils.cpp | 55 ++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 9ff4d74b01e..8915aa80a33 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1163,8 +1163,15 @@ RolesMapOpt getTensorsRoles( // dim_roles const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - for (TensorView* tv : mma_input_candidates) { - bool has_m = false, has_n = false, has_k = false, has_unmapped = false; + struct DimPresence { + bool m = false; + bool n = false; + bool k = false; + bool unmapped = false; + }; + + const auto findDims = [&dim_roles, &exact_graph](TensorView* tv) { + DimPresence has; for (IterDomain* id : TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { if (id->isBroadcast()) { @@ -1175,27 +1182,32 @@ RolesMapOpt getTensorsRoles( auto it = dim_roles.find(g); if (it == dim_roles.end()) { // tv has an unmapped non-broadcast and non-reduction dimension - has_unmapped = true; + has.unmapped = true; continue; } - has_m = has_m || it->second == MatmulDomain::M; - has_n = has_n || it->second == MatmulDomain::N; - has_k = has_k || it->second == MatmulDomain::K; + has.m = has.m || it->second == MatmulDomain::M; + has.n = has.n || it->second == MatmulDomain::N; + has.k = has.k || it->second == MatmulDomain::K; } - if (has_unmapped) { + return has; + }; + + for (TensorView* tv : mma_input_candidates) { + DimPresence has = findDims(tv); + if (has.unmapped) { // Don't map TVs to roles if they have unmapped dims continue; } - if (has_m && has_k && !has_n) { + if (has.m && has.k && !has.n) { tensor_roles[MatmulRole::INPUT_A].push_back(tv); continue; } - if (has_n && has_k && !has_m) { + if (has.n && has.k && !has.m) { tensor_roles[MatmulRole::INPUT_B].push_back(tv); continue; } // Bias vectors are assigned to INPUT_C role - if (!has_k) { + if (!has.k) { tensor_roles[MatmulRole::INPUT_C].push_back(tv); continue; } @@ -1203,29 +1215,12 @@ RolesMapOpt getTensorsRoles( std::vector storage; for (TensorView* tv : mma_output_candidates) { - bool has_m = false, has_n = false, has_k = false, has_unmapped = false; - for (IterDomain* id : - TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { - if (id->isBroadcast()) { - // Ignore broadcasts in output - continue; - } - const ValGroup& g = exact_graph.toGroup(id); - auto it = dim_roles.find(g); - if (it == dim_roles.end()) { - // output tv has an unmapped non-broadcast dimension - has_unmapped = true; - continue; - } - has_m = has_m || it->second == MatmulDomain::M; - has_n = has_n || it->second == MatmulDomain::N; - has_k = has_k || it->second == MatmulDomain::K; - } + DimPresence has = findDims(tv); // NOTE: depending on fusion definition k domain may appear in the output: // - for mma_output == fusion output k domain is present // - for mma_output != fusion output (fusion with epilogue) k domain // is not present - if (has_k || has_unmapped) { + if (has.k || has.unmapped) { // Don't map TVs to output roles if they have unmapped dims, or if they // have K dimension continue; @@ -1233,7 +1228,7 @@ RolesMapOpt getTensorsRoles( // NOTE: the core fusion output tensors are the ones with m and n // domains - if (has_m && has_n) { + if (has.m && has.n) { storage.push_back(tv); } } From 9d31a900f81b25c21deb12183eae691cde005fd0 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 12:28:15 +0000 Subject: [PATCH 22/28] Rename getTensorsRoles -> getTensorRoles --- csrc/scheduler/matmul.cpp | 2 +- csrc/scheduler/matmul_utils.cpp | 4 ++-- csrc/scheduler/mma_utils.cpp | 4 ++-- csrc/scheduler/mma_utils.h | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index b8ced3d1bb7..d43357e5fd0 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -765,7 +765,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { std::unordered_map id_roles = patterns.front().getDimRoles(id_model); const auto& tensor_roles_opt = - mma_utils::getTensorsRoles(fusion, id_model, id_roles); + mma_utils::getTensorRoles(fusion, id_model, id_roles); // NOTE: the contents of tensor_roles have been already validated during // compute-time checks diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index c80676acb3c..015937046a0 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -497,7 +497,7 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { IdModel id_model(fusion); const auto id_roles = patterns.front().getDimRoles(id_model); const mma_utils::RolesMapOpt roles_map_opt = - mma_utils::getTensorsRoles(fusion, id_model, id_roles); + mma_utils::getTensorRoles(fusion, id_model, id_roles); if (!roles_map_opt.isValid()) { return {roles_map_opt.getErrorMsg()}; } @@ -576,7 +576,7 @@ std::shared_ptr getMatmulHeuristics( params->mma_macro = mma_op.value(); const auto& roles_map_opt = - mma_utils::getTensorsRoles(fusion, id_model, id_roles); + mma_utils::getTensorRoles(fusion, id_model, id_roles); NVF_ERROR(roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); const auto roles_map = roles_map_opt.getData(); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 8915aa80a33..e4d64b1d493 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1065,7 +1065,7 @@ MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion) { const MatmulPattern& pattern = patterns[0]; IdModel id_model(fusion); const auto id_roles = pattern.getDimRoles(id_model); - const auto tensor_roles_opt = getTensorsRoles(fusion, id_model, id_roles); + const auto tensor_roles_opt = getTensorRoles(fusion, id_model, id_roles); if (!tensor_roles_opt.isValid()) { return {tensor_roles_opt.getErrorMsg()}; } @@ -1142,7 +1142,7 @@ MatmulProblemLayoutOpt getProblemLayout( NVF_ERROR(false, "Reached unreachable section of getProblemLayout"); } -RolesMapOpt getTensorsRoles( +RolesMapOpt getTensorRoles( Fusion* fusion, const IdModel& id_model, const std::unordered_map& dim_roles) { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 1b15ff76c22..68ff171eeb9 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -311,7 +311,7 @@ NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot //! be gathered. -RolesMapOpt getTensorsRoles( +RolesMapOpt getTensorRoles( Fusion* fusion, const IdModel& id_model, const std::unordered_map& dim_roles); From a99c9ae407cc537fa8af44cb33641951eab0de22 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 13:42:31 +0000 Subject: [PATCH 23/28] Assume alloc=rfactor to determine M, N and A, B --- csrc/scheduler/mma_utils.cpp | 54 ++++++++++++++++++++++++------ tests/cpp/test_combine_mul_sum.cpp | 37 ++++++++++++++++++++ 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index e4d64b1d493..2fdd8a663d6 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1387,6 +1387,11 @@ class MatmulPatternMatcher : IterVisitor { private: using IterVisitor::handle; + // TODO: These methods currently assume the output will have allocation domain + // equal to its rfactor. However, if the rfactor domain is specified, or if + // there is a transpose operation in the epilogue, then this assumption will + // be violated. In such cases we should actually swap and transpose A and B. + // Handle the case when no translation is needed. void handle(MmaOp* mop) override { MatmulPattern& pattern = patterns_.emplace_back(); @@ -1431,18 +1436,47 @@ class MatmulPatternMatcher : IterVisitor { const std::vector& red_root = rop->out()->as()->getRootDomain(); NVF_ERROR(red_root.size() == lrf.size()); + // Find innermost M or N dimension in output + // We will assume for now that the output rfactor domain matches the + // fusion output's allocation domain; in particular that the innermost + // dimension is an N dimension. This allows us to determine which of lhs + // and rhs is A and B. + // TODO: analyze fusion outputs to determine N dimensions + bool lhs_is_A = true; bool has_m = false, has_n = false; - for (size_t i : c10::irange(lrf.size())) { - if (lrf[i]->isBroadcast() && !rrf[i]->isBroadcast()) { - has_m = true; - } else if (!lrf[i]->isBroadcast() && rrf[i]->isBroadcast()) { - has_n = true; - } - if (red_root[i]->isReduction()) { + // Loop backwards to find inner-most Iteration domain in output + for (int64_t i = red_root.size() - 1; i >= 0; --i) { + IterDomain* lhs_id = lrf[(size_t)i]; + IterDomain* rhs_id = rrf[(size_t)i]; + IterDomain* out_id = red_root[(size_t)i]; + if (out_id->isIteration()) { + if (lhs_id->isBroadcast() != rhs_id->isBroadcast()) { + // This is either an M or N dimension + + // Operand domains must be Broadcast and Iteration + NVF_ERROR(lhs_id->isIteration() || rhs_id->isIteration()); + + if (!has_n) { + // This is the inner-most output non-batch dim, so it is N + has_n = true; + // rhs is B if it has this dimension + lhs_is_A = rhs_id->isIteration(); + continue; + } + // We have found the inner-most N dim, so we can now use lhs_is_A to + // tell whether this is M or N + has_m = has_m || (lhs_is_A && lhs_id->isIteration()) || + (!lhs_is_A && (rhs_id->isIteration())); + } + // out_id could also be a batch dim + } else if (out_id->isReduction()) { // matmul must be contraction of non-broadcast dimensions - if (!lrf[i]->isIteration() || !rrf[i]->isIteration()) { + if (!lhs_id->isIteration() || !rhs_id->isIteration()) { return; } + } else if (!out_id->isBroadcast()) { + // Reduction output ID should be iteration, reduction, or broadcast + return; } } if (!has_m || !has_n) { @@ -1451,8 +1485,8 @@ class MatmulPatternMatcher : IterVisitor { } MatmulPattern& pattern = patterns_.emplace_back(); - pattern.A = ltv; - pattern.B = rtv; + pattern.A = lhs_is_A ? ltv : rtv; + pattern.B = lhs_is_A ? rtv : ltv; pattern.output = rop->out()->as(); } } diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index d7fd2c91eb6..ab88bf4c031 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -269,4 +269,41 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { } } +// Check that we determine A and B properly when they are swapped as inputs to +// mul +TEST_F(CombineMulSumAsMmaTest, SwapAandB) { + for (auto layout : kAllSupportedMmaLayout) { + for (bool swap : {false, true}) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); + tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); + // We should identify tv0 as A and tv1 as B regardless of the order here + auto tv2 = swap ? mul(tv1, tv0) : mul(tv0, tv1); + auto tv3 = sum(tv2, {-1}); + + fusion.addOutput(tv3); + + std::vector patterns = + mma_utils::findMatmulPatterns(&fusion); + + ASSERT_FALSE(patterns.empty()); + EXPECT_EQ(patterns.size(), 1); + + EXPECT_EQ(patterns.front().A, tv0); + EXPECT_EQ(patterns.front().B, tv1); + + patterns.front().translateToMmaOp(); + + ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); + } + } +} + } // namespace nvfuser From bd6517ce8083bbc5389ea397aa5a78da32b1482f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 14:05:12 +0000 Subject: [PATCH 24/28] Test that dim role mapping survives swap --- tests/cpp/test_combine_mul_sum.cpp | 31 ++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index ab88bf4c031..cb85281f22a 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -296,10 +296,33 @@ TEST_F(CombineMulSumAsMmaTest, SwapAandB) { ASSERT_FALSE(patterns.empty()); EXPECT_EQ(patterns.size(), 1); - EXPECT_EQ(patterns.front().A, tv0); - EXPECT_EQ(patterns.front().B, tv1); - - patterns.front().translateToMmaOp(); + mma_utils::MatmulPattern& pattern = patterns.front(); + + EXPECT_EQ(pattern.A, tv0); + EXPECT_EQ(pattern.B, tv1); + EXPECT_EQ(pattern.output, tv3); + + pattern.translateToMmaOp(); + + // Check that we didn't modify the pattern roles + EXPECT_EQ(pattern.A, tv0); + EXPECT_EQ(pattern.B, tv1); + EXPECT_EQ(pattern.output, tv3); + + // Check that we properly map M and N to their roles even with swap + IdModel id_model(&fusion); + std::unordered_map dim_roles = + pattern.getDimRoles(id_model); + ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + const ValGroup& m_gp = exact_graph.toGroup(tv0->axis(-3)); + auto m_it = dim_roles.find(m_gp); + ASSERT_NE(m_it, dim_roles.end()); + EXPECT_EQ(m_it->second, MatmulDomain::M); + + const ValGroup& n_gp = exact_graph.toGroup(tv1->axis(-2)); + auto n_it = dim_roles.find(n_gp); + ASSERT_NE(n_it, dim_roles.end()); + EXPECT_EQ(n_it->second, MatmulDomain::N); ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); } From 8dc69487eaf1b119e805c9a14799646ec940786f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 14:05:27 +0000 Subject: [PATCH 25/28] Rename ValGroupPresence to DimPresence --- csrc/scheduler/mma_utils.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 2fdd8a663d6..745b2b7e484 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1538,14 +1538,14 @@ std::unordered_map MatmulPattern::getDimRoles( // N: present in B and output, but not A // K: present in A and B, but not output // Batch: present in all A, B, and output - // If there are other membership patterns, for example a ValGroup present in - // only A, then we should raise an exception here. + // If there are other patterns, for example a ValGroup present in only A, then + // we should raise an exception here. // Indicates whether a ValGroup is present in A (bit 0), B (bit 1), or output // (bit 2) - using ValGroupPresence = std::bitset<3>; + using DimPresence = std::bitset<3>; - std::unordered_map present_flags; + std::unordered_map present_flags; const auto recordPresence = [&exact_graph, &present_flags]( TensorView* tv, size_t tensor_num) { for (IterDomain* id : tv->getMaybeRFactorDomain()) { From 71128e49a725310b05344be552a03baf309053a4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 14:29:37 +0000 Subject: [PATCH 26/28] Use std::variant for error This also combines more code. --- csrc/scheduler/mma_utils.cpp | 48 +++++++++++++++++------------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 745b2b7e484..c575d8ae259 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1084,13 +1084,14 @@ MatmulProblemLayoutOpt getProblemLayout( // using DataWrapperOpt(std::move(dom)) leads to a clang-tidy // warning because MatmulDomain is trivially movable. There is only a move // constructor for DataWrapperOpt to prevent inadvertent copying. To avoid - // this complication I'm using a simple pair for the lambda's result type. - using InnerDomResult = std::pair; - const auto innerDomain = [&tensor_roles, &dim_roles, &exact_graph]( - MatmulRole role) -> InnerDomResult { + // this complication I'm using an unwrapped variant for the lambda's result + // type. + using UnitDimOpt = std::variant; + const auto findUnitDim = + [&tensor_roles, &dim_roles, &exact_graph](MatmulRole role) -> UnitDimOpt { const auto role_it = tensor_roles.find(role); if (role_it == tensor_roles.end()) { - return {MatmulDomain::M, "Could not find role in tensor_roles"}; + return "Could not find role in tensor_roles"; } std::optional group_inner_dom = std::nullopt; for (TensorView* tv : role_it->second) { @@ -1099,44 +1100,41 @@ MatmulProblemLayoutOpt getProblemLayout( const ValGroup& g = exact_graph.toGroup(inner_id); auto g_it = dim_roles.find(g); if (g_it == dim_roles.end()) { - return { - MatmulDomain::M, - "Inner domain of tensor was not mapped to a MatmulDomain"}; + return "Inner domain of tensor was not mapped to a MatmulDomain"; } if (!group_inner_dom.has_value()) { group_inner_dom = g_it->second; } else if (group_inner_dom.value() != g_it->second) { - return { - MatmulDomain::M, "Group contains multiple inner dimension domains"}; + return "Group contains multiple inner dimension domains"; } } if (!group_inner_dom.has_value()) { - return {MatmulDomain::M, "No tensor found in role"}; + return "No tensor found in role"; } - return {group_inner_dom.value(), ""}; + return group_inner_dom.value() == MatmulDomain::K ? UnitDim::K + : UnitDim::M_or_N; }; - const InnerDomResult a_dom_res = innerDomain(MatmulRole::INPUT_A); - if (!a_dom_res.second.empty()) { - std::string err = a_dom_res.second; + const UnitDimOpt unitdim_a_opt = findUnitDim(MatmulRole::INPUT_A); + if (std::holds_alternative(unitdim_a_opt)) { + std::string err = std::get(unitdim_a_opt); return err; } - const bool kinner_a = a_dom_res.first == MatmulDomain::K; - - const InnerDomResult b_dom_res = innerDomain(MatmulRole::INPUT_B); - if (!b_dom_res.second.empty()) { - std::string err = b_dom_res.second; + const UnitDimOpt unitdim_b_opt = findUnitDim(MatmulRole::INPUT_B); + if (std::holds_alternative(unitdim_b_opt)) { + std::string err = std::get(unitdim_b_opt); return err; } - const bool kinner_b = b_dom_res.first == MatmulDomain::K; + const UnitDim unitdim_a = std::get(unitdim_a_opt); + const UnitDim unitdim_b = std::get(unitdim_b_opt); - if (kinner_a && kinner_b) { + if (unitdim_a == UnitDim::K && unitdim_b == UnitDim::K) { return MmaLayout::TN; - } else if (kinner_a && !kinner_b) { + } else if (unitdim_a == UnitDim::K && unitdim_b == UnitDim::M_or_N) { return MmaLayout::TT; - } else if (!kinner_a && !kinner_b) { + } else if (unitdim_a == UnitDim::M_or_N && unitdim_b == UnitDim::M_or_N) { return MmaLayout::NT; - } else if (!kinner_a && kinner_b) { + } else if (unitdim_a == UnitDim::M_or_N && unitdim_b == UnitDim::K) { return MmaLayout::NN; } NVF_ERROR(false, "Reached unreachable section of getProblemLayout"); From 71ba8a272adaca781827591054d6b2da699e555b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 14:31:32 +0000 Subject: [PATCH 27/28] clang-tidy fix --- csrc/scheduler/mma_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index c575d8ae259..1b05d80bccf 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1443,7 +1443,7 @@ class MatmulPatternMatcher : IterVisitor { bool lhs_is_A = true; bool has_m = false, has_n = false; // Loop backwards to find inner-most Iteration domain in output - for (int64_t i = red_root.size() - 1; i >= 0; --i) { + for (int64_t i = (int64_t)red_root.size() - 1; i >= 0; --i) { IterDomain* lhs_id = lrf[(size_t)i]; IterDomain* rhs_id = rrf[(size_t)i]; IterDomain* out_id = red_root[(size_t)i]; From 0deb0bfe4981ff7071113973eebb8281fdcfab8c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 22 May 2024 23:15:29 +0000 Subject: [PATCH 28/28] Use noReductions/noBroadcasts to simplify hasTrivialAllocationDomain --- csrc/ir/utils.cpp | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index e3cda0b863d..864b7162781 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1109,23 +1109,8 @@ bool hasTrivialAllocationDomain(const TensorView* tv) { } const std::vector& alloc = tv->getMaybeAllocationDomain(); const std::vector& rf = tv->getMaybeRFactorDomain(); - size_t i = 0, j = 0; - while (i < alloc.size() && j < rf.size()) { - if (alloc[i]->isBroadcast() || alloc[i]->isReduction()) { - i++; - continue; - } - if (rf[j]->isBroadcast() || rf[j]->isReduction()) { - j++; - continue; - } - if (!alloc[i]->sameAs(rf[j])) { - return false; - } - i++; - j++; - } - return true; + return TensorDomain::noBroadcasts(TensorDomain::noReductions(rf)) == + TensorDomain::noBroadcasts(TensorDomain::noReductions(alloc)); } } // namespace nvfuser::ir_utils