diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 9bdd05ef6f5..864b7162781 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1103,6 +1103,16 @@ int64_t getVectorizeSize(const TensorView* tv) { return 1; } +bool hasTrivialAllocationDomain(const TensorView* tv) { + if (!tv->hasAllocation()) { + return true; + } + const std::vector& alloc = tv->getMaybeAllocationDomain(); + const std::vector& rf = tv->getMaybeRFactorDomain(); + return TensorDomain::noBroadcasts(TensorDomain::noReductions(rf)) == + TensorDomain::noBroadcasts(TensorDomain::noReductions(alloc)); +} + } // namespace nvfuser::ir_utils namespace nvfuser::MmaOpUtils { @@ -1269,7 +1279,6 @@ MmaOpDetails getMmaOpDetails( const auto validateOutputDetails = [](const TensorViewDetails& details, const std::string& desc) { // TODO: revise rules when add support for batch gemms - NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains."); NVF_ERROR(!details.rdomains.empty(), desc, ": has no reduction domains."); NVF_ERROR( (details.cdomains.size() >= expected_gemm_cdomains), diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 4df8bda664d..a4b5454fa2c 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -652,4 +652,6 @@ std::optional> computePermutation( return permutation; } +bool hasTrivialAllocationDomain(const TensorView* tv); + } // namespace nvfuser::ir_utils diff --git a/csrc/mma_type.h b/csrc/mma_type.h index 921996c062b..60bf608df4b 100644 --- a/csrc/mma_type.h +++ b/csrc/mma_type.h @@ -26,7 +26,7 @@ namespace nvfuser { constexpr std::string_view MATMUL_LOG_PREFIX = "[MATMUL DEBUG] "; //! Named descriptors of domains in matmul -enum class MatmulDomain { M = 0, N, K }; +enum class MatmulDomain { M = 0, N, K, Batch }; //! Named descriptors of TensorView roles in fusion //! INPUT_A - a producer of MMA input A diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 9b4090a59be..d43357e5fd0 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -749,38 +749,45 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Cache and fork outputs auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); - mma_utils::CombineMulSum combiner(fusion); - auto mma_ops = ir_utils::getOpsOfType(fusion); - if (combiner.isValid() && mma_ops.empty()) { - combiner.replaceWithMmaOp(); - mma_ops = ir_utils::getOpsOfType(fusion); - } - + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + NVF_ERROR(!patterns.empty(), "No matmul patterns were found"); NVF_ERROR( - mma_ops.size() == 1, - "scheduleMatmul supports fusion with single mma op in definition, got ", - mma_ops.size()); + patterns.size() == 1, + "Only a single matmul pattern can currently be fused"); + std::vector mma_ops; + mma_ops.reserve(patterns.size()); + for (mma_utils::MatmulPattern& pattern : patterns) { + mma_ops.push_back(pattern.translateToMmaOp()); + } - const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion); + IdModel id_model(fusion); + std::unordered_map id_roles = + patterns.front().getDimRoles(id_model); + const auto& tensor_roles_opt = + mma_utils::getTensorRoles(fusion, id_model, id_roles); - // NOTE: the contents of roles_map have been already validated during + // NOTE: the contents of tensor_roles have been already validated during // compute-time checks - NVF_ERROR(roles_map_opt.isValid(), roles_map_opt.getErrorMsg()); - const auto roles_map = roles_map_opt.getData(); + NVF_ERROR(tensor_roles_opt.isValid(), tensor_roles_opt.getErrorMsg()); + const auto tensor_roles = tensor_roles_opt.getData(); + + const mma_utils::MatmulProblemLayoutOpt fusion_layout = + mma_utils::getProblemLayout(id_model, id_roles, tensor_roles); + NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg()); // Core roles: there can be only one... TV with assigned core role - TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); - TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); + TensorView* a = tensor_roles.at(MatmulRole::INPUT_A).front(); + TensorView* b = tensor_roles.at(MatmulRole::INPUT_B).front(); + + const auto& gemm_tile = params.tile_sizes; // Collect mma swizzle info auto mma = mma_ops.front(); - const auto fusion_layout = mma_utils::getMmaLayout(fusion); - NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg()); - - const auto& gemm_tile = params.tile_sizes; const bool has_epilogue = !mma->out()->isFusionOutput(); - const bool has_fusion_c_roles = (0 != roles_map.count(MatmulRole::INPUT_C)); + const bool has_fusion_c_roles = + (0 != tensor_roles.count(MatmulRole::INPUT_C)); const bool has_non_mma_input_tvs = has_epilogue && has_fusion_c_roles; // Including current tensor naming convention for reference, @@ -1227,7 +1234,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // operations, input tvs with non-core roles // core roles: essential for matmul, for example mma inputs' producers if (has_non_mma_input_tvs) { - scheduleFusionInputsForEpilogue(roles_map, params.use_smem_epilogue); + scheduleFusionInputsForEpilogue(tensor_roles, params.use_smem_epilogue); } scheduleSplitKSum( diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 00fccec35b4..015937046a0 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -15,11 +15,13 @@ // 'SchedulerRuntimeInfo' #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -38,9 +40,8 @@ namespace nvfuser { namespace { -//! Access to the structure should be done with labels defined in -//! MmaOptions::MmaDomains. -using ProblemShape = std::array; +//! Access to the structure should be done with labels defined in MatmulDomain. +using ProblemShape = std::array; //! A helper for deciding the type of MMA op for given fusion and problem shape. inline std::optional getMmaOp( @@ -158,43 +159,44 @@ inline bool initCoreHeuristics( } //! A helper for getting problem shape from fusion and runtime info. +//! +//! For a given domain, try to find the size by evaluating the extent of an +//! IterDomain in each group of that domain type. For example, if there are +//! multiple Batch dimensions, we find all ValGroups that are mapped as +//! MatmulDomain::Batch, we evaluate the extent of each, then we multiply those +//! dimensions together to get the overall batch size. ProblemShape getProblemShape( - const mma_utils::MulSumProperties::InputsOutputs& props, + const std::unordered_map& dim_roles, SchedulerRuntimeInfo& runtime_info) { - const auto mma_output_domains = mma_utils::getProblemIterDomains({props}); - if (!mma_output_domains.isValid()) { - NVF_ERROR(false, mma_output_domains.getErrorMsg()); - } - - const auto [m, n, k] = mma_output_domains.getData(); - - auto m_extend = runtime_info.expressionEvaluator().evaluate(m->extent()); - auto n_extend = runtime_info.expressionEvaluator().evaluate(n->extent()); - auto k_extend = runtime_info.expressionEvaluator().evaluate(k->extent()); - - if (!(m_extend && n_extend && k_extend)) { + ProblemShape shape{1, 1, 1, 1}; + for (const auto& [g, dom] : dim_roles) { + NVF_ERROR(!g->empty()); + IterDomain* id = g->front()->as(); + const PolymorphicValue extent = + runtime_info.expressionEvaluator().evaluate(id->extent()); NVF_ERROR( - false, - "Failed to acquire one of problem dimensions, M(", - m_extend.hasValue(), - "), N(", - n_extend.hasValue(), - " K(", - k_extend.hasValue(), - ")"); + extent.hasValue(), "Could not evaluate extent of ", id->toString()); + shape[(size_t)dom] *= extent.as(); } - - return ProblemShape{ - m_extend.as(), n_extend.as(), k_extend.as()}; + return shape; } +// Checks that this pattern: +// - is a GEMM or batch GEMM +// - has at least two inputs i.e. not A @ A.T +// - has a single A and a single B operand i.e not A @ (B1 * B2) +// - has a fusion output with OUTPUT_D role i.e. that has M, N dims +// - includes all fusion inputs/outputs in its tensor roles +// - has no fusion inputs with non-trivial allocation domain std::string isMatmulFusionDefinitionSupported( Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { + const mma_utils::MatmulPattern& pattern, + const mma_utils::RolesMap& roles_map, + const std::unordered_map& id_roles) { const auto& fusion_inputs = fusion->inputs(); const auto& fusion_outputs = fusion->outputs(); - std::vector mma_inputs = {props.a, props.b}; - const auto mma_output = props.out; + std::vector mma_inputs = {pattern.A, pattern.B}; + const auto mma_output = pattern.output; const auto fusion_inputs_tvs = ir_utils::filterByType(fusion_inputs).vector(); @@ -202,18 +204,20 @@ std::string isMatmulFusionDefinitionSupported( ir_utils::filterByType(fusion_outputs).vector(); constexpr size_t minimal_number_of_inputs = 2; - MmaOpUtils::MmaOpDetails mma_details = - MmaOpUtils::getMmaOpDetails(props.out, props.a, props.b); // Quick checks - MmaOp { // Check if MmaOp represents gemm (requires M/N/K == 1, B == 0) // or bgemm (requires M/N/K/B == 1) - constexpr size_t expected_axes_numbers = 1; - if (mma_details.m_axes.size() != expected_axes_numbers || - mma_details.n_axes.size() != expected_axes_numbers || - mma_details.k_axes.size() != expected_axes_numbers || - mma_details.batch_axes.size() > expected_axes_numbers) { + std::array num_axes{}; + for (const auto& [g, dom] : id_roles) { + num_axes[(size_t)dom]++; + } + constexpr int64_t expected_axes_numbers = 1; + if (num_axes[(size_t)MatmulDomain::M] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::N] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::K] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::Batch] > expected_axes_numbers) { return "MmaOp has unsupported number of one of M/N/K/Batch axes"; } @@ -232,12 +236,6 @@ std::string isMatmulFusionDefinitionSupported( // Fusion topology check { - const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion, props); - if (!roles_map_opt.isValid()) { - return roles_map_opt.getErrorMsg(); - } - - const auto& roles_map = roles_map_opt.getData(); auto entry = roles_map.find(MatmulRole::INPUT_A); std::set tvs_with_roles; @@ -288,6 +286,21 @@ std::string isMatmulFusionDefinitionSupported( } } + // Check that no non-trivial allocation domains are set on inputs or outputs. + // TODO: Lift this requirement once we have proper allocation domain support + for (Val* inp : fusion->inputs()) { + if (auto tv = dynamic_cast(inp); + tv && !ir_utils::hasTrivialAllocationDomain(tv)) { + return "detected input TV with non-trivial allocation domain"; + } + } + for (Val* outp : fusion->outputs()) { + if (auto tv = dynamic_cast(outp); + tv && !ir_utils::hasTrivialAllocationDomain(tv)) { + return "detected output TV with non-trivial allocation domain"; + } + } + return ""; } @@ -450,10 +463,8 @@ std::string getMatmulRunTimeRejectReason( std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // The plan: // 0. Check if the current CUDA device is supported - // 1. Check if there is exactly one MmaOp or suitable mul sum pair - // defined in the fusion. - // 2. Check if inputs to the mma op or mul sum pair match any of - // supported inputs layout + // 1. Check if there is exactly one matmul pattern defined in the fusion. + // 2. Check if the input layout for the matmul pattern can be determined // 3. Check if fusion represents expressions that are recognized by matmul // scheduler. @@ -462,8 +473,8 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { const auto device_prop = at::cuda::getCurrentDeviceProperties(); // Use a dummy problem shape to determine whether this is a supported // device. - const auto mma_op = - getMmaOp(device_prop->major * 10 + device_prop->minor, {128, 128, 128}); + const auto mma_op = getMmaOp( + device_prop->major * 10 + device_prop->minor, {128, 128, 128, 1}); if (!mma_op.has_value()) { return "Unsupported device compute capability"; } @@ -472,29 +483,37 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // #1 // Initializing the machinery to check if there's a Mul-Sum pair // can be replaced by a Mma Op. - mma_utils::CombineMulSum combiner(fusion); - if (!combiner.isValid()) { - std::stringstream ss; - ss << "Matmul scheduler supports fusions only with a single mma op" - << "or supports a mul-sum pair which can be replaced with a mma op"; - return ss.str(); + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + if (patterns.empty()) { + return "No matmul patterns were found"; } - - const std::vector& mma_from_mul_sums = - combiner.getMulSumCanidates(); - // #2 - { - const auto input_layout_opt = - mma_utils::getMmaLayout(fusion, mma_from_mul_sums.front().insouts); - if (!input_layout_opt.isValid()) { - return input_layout_opt.getErrorMsg(); - } + if (patterns.size() > 1) { + return "Only a single matmul pattern can currently be fused"; } // #3 + // Prepare an IdModel which will be reused to check remaining conditions + IdModel id_model(fusion); + const auto id_roles = patterns.front().getDimRoles(id_model); + const mma_utils::RolesMapOpt roles_map_opt = + mma_utils::getTensorRoles(fusion, id_model, id_roles); + if (!roles_map_opt.isValid()) { + return {roles_map_opt.getErrorMsg()}; + } + mma_utils::RolesMap roles_map = roles_map_opt.getData(); + + // #4 + const auto input_layout_opt = + mma_utils::getProblemLayout(id_model, id_roles, roles_map); + if (!input_layout_opt.isValid()) { + return input_layout_opt.getErrorMsg(); + } + + // #5 { auto support_status = isMatmulFusionDefinitionSupported( - fusion, mma_from_mul_sums.front().insouts); + fusion, patterns.front(), roles_map, id_roles); if (!support_status.empty()) { return support_status; } @@ -533,16 +552,21 @@ std::shared_ptr getMatmulHeuristics( params->cparams.index_type = runtime_info.getIndexType(); // Check initial conditions - auto mma_exprs = ir_utils::getOpsOfType(fusion); - mma_utils::CombineMulSum combiner(fusion); + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + NVF_ERROR(!patterns.empty(), "No matmul patterns were found"); NVF_ERROR( - combiner.isValid(), - "There's no (single) mma op or mul-sum op which mma op can replace") + patterns.size() == 1, + "Only a single matmul pattern can currently be fused"); + mma_utils::MatmulPattern& pattern = patterns.front(); + + // IdModel is used to analyze problem shape & layout + IdModel id_model(fusion); + + const std::unordered_map id_roles = + pattern.getDimRoles(id_model); - const std::vector& mulSum = - combiner.getMulSumCanidates(); - const auto problem_shape = - getProblemShape(mulSum.front().insouts, runtime_info); + const auto problem_shape = getProblemShape(id_roles, runtime_info); const auto device_prop = at::cuda::getCurrentDeviceProperties(); const auto mma_op = @@ -552,7 +576,7 @@ std::shared_ptr getMatmulHeuristics( params->mma_macro = mma_op.value(); const auto& roles_map_opt = - mma_utils::getTensorsRoles(fusion, mulSum.front().insouts); + mma_utils::getTensorRoles(fusion, id_model, id_roles); NVF_ERROR(roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); const auto roles_map = roles_map_opt.getData(); @@ -561,17 +585,17 @@ std::shared_ptr getMatmulHeuristics( if (matmul_heuristic_plugin::hasPlugin()) { const mma_utils::MatmulProblemLayoutOpt layout_opt = - mma_utils::getMmaLayout(fusion, mulSum.front().insouts); + mma_utils::getProblemLayout(id_model, id_roles, roles_map); NVF_ERROR(layout_opt.isValid(), layout_opt.getErrorMsg()); const MmaLayout layout = layout_opt.getData(); // Fill in proper values using plugin matmul_heuristic_plugin::updateMatmulParams( *params, - /*M=*/problem_shape[0], - /*N=*/problem_shape[1], - /*K=*/problem_shape[2], - /*batch_size=*/1, // TODO: extract actual batch size + problem_shape[(size_t)MatmulDomain::M], + problem_shape[(size_t)MatmulDomain::N], + problem_shape[(size_t)MatmulDomain::K], + problem_shape[(size_t)MatmulDomain::Batch], layout, roles_map); } else { diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index ee7be460059..1b05d80bccf 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -9,10 +9,13 @@ #include #include #include +#include #include +#include #include #include #include +#include #include #include "mma_type.h" namespace nvfuser { @@ -1051,154 +1054,96 @@ inline void resolveTvToMatmulDomainsMapping( } // anonymous namespace -ProblemIterDomainsOpt getProblemIterDomains( - const mma_utils::MulSumProperties::InputsOutputs& props) { - // NOTE: the iter domains of MMA output should be [...,M,K,N] - IterDomain* m = nullptr; - IterDomain* n = nullptr; - IterDomain* k = nullptr; - - const auto& leaf_domains = props.out->getLeafDomain(); - const auto concrete = TensorDomain::noDevices( - TensorDomain::noReductions(TensorDomain::noBroadcasts(leaf_domains))); - if (concrete.size() < MIN_MATMUL_INPUTS_NUMBER) { - std::stringstream ss; - ss << "Failed to find the minimum number of MMA input candidates, expected " - << MIN_MATMUL_INPUTS_NUMBER << ", got " << concrete.size(); - return ss.str(); - } - - // M,N are inner most concrete iter domains - m = concrete.rbegin()[1]; - n = concrete.rbegin()[0]; - - // K is a reduction domain, search for the inner most reduction domain - for (auto iter_domain = leaf_domains.rbegin(); - iter_domain != leaf_domains.rend(); - ++iter_domain) { - if ((*iter_domain)->isReduction()) { - k = *iter_domain; - break; - } - } - NVF_ERROR(k != nullptr, "Failed to find K domain in MMA output"); - - return ProblemIterDomains{m, n, k}; -} - -ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion) { - auto mma_exprs = ir_utils::getOpsOfType(fusion); - if (mma_exprs.size() != 1) { +MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion) { + const std::vector patterns = findMatmulPatterns(fusion); + if (patterns.size() != 1) { std::stringstream ss; ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << mma_exprs.size(); + << patterns.size(); return ss.str(); } - return getProblemIterDomains( - {static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}); -} - -MatmulProblemLayoutOpt getMmaLayout( - Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { - ComputeAtMap ca_map(fusion); - const auto mma_input_candidates = - ir_utils::filterByType(fusion->inputs()).vector(); - if (mma_input_candidates.empty()) { - return {"Failed to find any TV that is fusion input"}; - } - - const auto mma_output_domains = getProblemIterDomains(props); - if (!mma_output_domains.isValid()) { - return mma_output_domains.getErrorMsg(); + const MatmulPattern& pattern = patterns[0]; + IdModel id_model(fusion); + const auto id_roles = pattern.getDimRoles(id_model); + const auto tensor_roles_opt = getTensorRoles(fusion, id_model, id_roles); + if (!tensor_roles_opt.isValid()) { + return {tensor_roles_opt.getErrorMsg()}; } + return getProblemLayout(id_model, id_roles, tensor_roles_opt.getData()); +} - const auto domains_data = mma_output_domains.getData(); - const auto m = domains_data[(size_t)MatmulDomain::M]; - const auto n = domains_data[(size_t)MatmulDomain::N]; - const auto k = domains_data[(size_t)MatmulDomain::K]; - - DependenciesMap deps_map; - resolveTvToMatmulDomainsMapping( - deps_map, mma_input_candidates, m, n, k, ca_map); - - bool mk_found = false; - bool km_found = false; - bool nk_found = false; - bool kn_found = false; - const static DomainsDesc mk_desc = {MatmulDomain::M, MatmulDomain::K}; - const static DomainsDesc km_desc = {MatmulDomain::K, MatmulDomain::M}; - const static DomainsDesc nk_desc = {MatmulDomain::N, MatmulDomain::K}; - const static DomainsDesc kn_desc = {MatmulDomain::K, MatmulDomain::N}; - - for (const auto& item : deps_map) { - if (item.second == mk_desc) { - if (mk_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., M, ..., K, ...] iter domains"}; - } - mk_found = true; +MatmulProblemLayoutOpt getProblemLayout( + const IdModel& id_model, + const std::unordered_map& dim_roles, + const RolesMap& tensor_roles) { + // Assumes the exact graph has already been built, since we've been provided + // dim_roles + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + + // Note: using DataWrapperOpt would be preferable here. However, + // using DataWrapperOpt(std::move(dom)) leads to a clang-tidy + // warning because MatmulDomain is trivially movable. There is only a move + // constructor for DataWrapperOpt to prevent inadvertent copying. To avoid + // this complication I'm using an unwrapped variant for the lambda's result + // type. + using UnitDimOpt = std::variant; + const auto findUnitDim = + [&tensor_roles, &dim_roles, &exact_graph](MatmulRole role) -> UnitDimOpt { + const auto role_it = tensor_roles.find(role); + if (role_it == tensor_roles.end()) { + return "Could not find role in tensor_roles"; } - if (item.second == km_desc) { - if (km_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., K, ..., M, ...] iter domains"}; + std::optional group_inner_dom = std::nullopt; + for (TensorView* tv : role_it->second) { + IterDomain* inner_id = + TensorDomain::noReductions(tv->getMaybeAllocationDomain()).back(); + const ValGroup& g = exact_graph.toGroup(inner_id); + auto g_it = dim_roles.find(g); + if (g_it == dim_roles.end()) { + return "Inner domain of tensor was not mapped to a MatmulDomain"; } - km_found = true; - } - if (item.second == nk_desc) { - if (nk_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., N, ..., K, ...] iter domains"}; + if (!group_inner_dom.has_value()) { + group_inner_dom = g_it->second; + } else if (group_inner_dom.value() != g_it->second) { + return "Group contains multiple inner dimension domains"; } - nk_found = true; } - if (item.second == kn_desc) { - if (kn_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., K, ..., N, ...] iter domains"}; - } - kn_found = true; + if (!group_inner_dom.has_value()) { + return "No tensor found in role"; } - } + return group_inner_dom.value() == MatmulDomain::K ? UnitDim::K + : UnitDim::M_or_N; + }; - if ((mk_found && kn_found) && !(km_found || nk_found)) { - return MmaLayout::TT; + const UnitDimOpt unitdim_a_opt = findUnitDim(MatmulRole::INPUT_A); + if (std::holds_alternative(unitdim_a_opt)) { + std::string err = std::get(unitdim_a_opt); + return err; } - if ((km_found && kn_found) && !(mk_found || nk_found)) { - return MmaLayout::NT; + const UnitDimOpt unitdim_b_opt = findUnitDim(MatmulRole::INPUT_B); + if (std::holds_alternative(unitdim_b_opt)) { + std::string err = std::get(unitdim_b_opt); + return err; } - if ((mk_found && nk_found) && !(km_found || kn_found)) { + const UnitDim unitdim_a = std::get(unitdim_a_opt); + const UnitDim unitdim_b = std::get(unitdim_b_opt); + + if (unitdim_a == UnitDim::K && unitdim_b == UnitDim::K) { return MmaLayout::TN; - } - if ((km_found && nk_found) && !(mk_found || kn_found)) { + } else if (unitdim_a == UnitDim::K && unitdim_b == UnitDim::M_or_N) { + return MmaLayout::TT; + } else if (unitdim_a == UnitDim::M_or_N && unitdim_b == UnitDim::M_or_N) { + return MmaLayout::NT; + } else if (unitdim_a == UnitDim::M_or_N && unitdim_b == UnitDim::K) { return MmaLayout::NN; } - - return {"Failed to decide fusion inputs' data layout."}; -} - -MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion) { - auto mma_exprs = ir_utils::getOpsOfType(fusion); - if (mma_exprs.size() != 1) { - std::stringstream ss; - ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << mma_exprs.size(); - return ss.str(); - } - return getMmaLayout( - fusion, - {static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}); + NVF_ERROR(false, "Reached unreachable section of getProblemLayout"); } -RolesMapOpt getTensorsRoles( +RolesMapOpt getTensorRoles( Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { - ComputeAtMap ca_map(fusion); + const IdModel& id_model, + const std::unordered_map& dim_roles) { const auto mma_input_candidates = ir_utils::filterByType(fusion->inputs()).vector(); if (mma_input_candidates.empty()) { @@ -1210,143 +1155,117 @@ RolesMapOpt getTensorsRoles( return {"Failed to find any TV that is fusion output"}; } - const auto mma_output_domains = getProblemIterDomains(props); - if (!mma_output_domains.isValid()) { - return mma_output_domains.getErrorMsg(); - } + RolesMap tensor_roles; - const auto findInputRolesByDomains = [](const DependenciesMap& deps_map, - RolesMap& roles_map) { - for (const auto& entry : deps_map) { - const auto& domains = entry.second; - const auto begin = domains.begin(); - const auto end = domains.end(); + // Assumes the exact graph has already been built, since we've been provided + // dim_roles + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - bool has_m = (end != std::find(begin, end, MatmulDomain::M)); - bool has_n = (end != std::find(begin, end, MatmulDomain::N)); - bool has_k = (end != std::find(begin, end, MatmulDomain::K)); + struct DimPresence { + bool m = false; + bool n = false; + bool k = false; + bool unmapped = false; + }; - if (has_m && has_k && !has_n) { - roles_map[MatmulRole::INPUT_A].push_back(entry.first); - continue; - } - if (has_n && has_k && !has_m) { - roles_map[MatmulRole::INPUT_B].push_back(entry.first); + const auto findDims = [&dim_roles, &exact_graph](TensorView* tv) { + DimPresence has; + for (IterDomain* id : + TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { + if (id->isBroadcast()) { + // Broadcast domains won't exact map to concrete domains so skip them continue; } - // Bias vectors are assigned to INPUT_C role - if (!has_k) { - roles_map[MatmulRole::INPUT_C].push_back(entry.first); + const ValGroup& g = exact_graph.toGroup(id); + auto it = dim_roles.find(g); + if (it == dim_roles.end()) { + // tv has an unmapped non-broadcast and non-reduction dimension + has.unmapped = true; continue; } + has.m = has.m || it->second == MatmulDomain::M; + has.n = has.n || it->second == MatmulDomain::N; + has.k = has.k || it->second == MatmulDomain::K; } - - for (auto& [role, tvs] : roles_map) { - // NOTE: sort input roles in descending order by uses() size, and - // if equal then by name() to ensure the stable ordering of tensor - // views in collections assigned to the supported roles - std::sort(tvs.begin(), tvs.end(), [](TensorView* a, TensorView* b) { - return (a->uses().size() == b->uses().size()) - ? (a->name() < b->name()) - : (a->uses().size() > b->uses().size()); - }); - } + return has; }; - const auto findOutputRolesByDomains = [](const DependenciesMap& deps_map, - RolesMap& roles_map) { - std::vector storage; - storage.reserve(deps_map.size()); - - for (const auto& entry : deps_map) { - const auto& domains = entry.second; - const auto begin = domains.begin(); - const auto end = domains.end(); + for (TensorView* tv : mma_input_candidates) { + DimPresence has = findDims(tv); + if (has.unmapped) { + // Don't map TVs to roles if they have unmapped dims + continue; + } + if (has.m && has.k && !has.n) { + tensor_roles[MatmulRole::INPUT_A].push_back(tv); + continue; + } + if (has.n && has.k && !has.m) { + tensor_roles[MatmulRole::INPUT_B].push_back(tv); + continue; + } + // Bias vectors are assigned to INPUT_C role + if (!has.k) { + tensor_roles[MatmulRole::INPUT_C].push_back(tv); + continue; + } + } - bool has_m = (end != std::find(begin, end, MatmulDomain::M)); - bool has_n = (end != std::find(begin, end, MatmulDomain::N)); + std::vector storage; + for (TensorView* tv : mma_output_candidates) { + DimPresence has = findDims(tv); + // NOTE: depending on fusion definition k domain may appear in the output: + // - for mma_output == fusion output k domain is present + // - for mma_output != fusion output (fusion with epilogue) k domain + // is not present + if (has.k || has.unmapped) { + // Don't map TVs to output roles if they have unmapped dims, or if they + // have K dimension + continue; + } - // NOTE: depending on fusion definition k domain may appear in the output: - // - for mma_output == fusion output k domain is present - // - for mma_output != fusion output (fusion with epilogue) k domain - // is not present + // NOTE: the core fusion output tensors are the ones with m and n + // domains + if (has.m && has.n) { + storage.push_back(tv); + } + } - // NOTE: the core fusion output tensors are the ones with m and n - // domains - if (has_m && has_n) { - storage.push_back(entry.first); - } + // NOTE: sort output roles in descending order by uses() size, and + // if equal then by name() to ensure the stable ordering of tensor + // views in collections assigned to the supported roles + std::sort(storage.begin(), storage.end(), [](TensorView* a, TensorView* b) { + return (a->uses().size() == b->uses().size()) + ? (a->name() < b->name()) + : (a->uses().size() > b->uses().size()); + }); + + if (!storage.empty()) { + // NOTE: currently, we pick as a reference tensor one with `m` and `n` + // IterDomains and the most uses + auto pos = storage.begin(); + tensor_roles[MatmulRole::OUTPUT_D].push_back(*pos); + for (++pos; pos != storage.end(); ++pos) { + tensor_roles[MatmulRole::OUTPUT_AUX].push_back(*pos); } + } - // NOTE: sort output roles in descending order by uses() size, and + for (auto& [role, tvs] : tensor_roles) { + // NOTE: sort input roles in descending order by uses() size, and // if equal then by name() to ensure the stable ordering of tensor // views in collections assigned to the supported roles - std::sort(storage.begin(), storage.end(), [](TensorView* a, TensorView* b) { + std::sort(tvs.begin(), tvs.end(), [](TensorView* a, TensorView* b) { return (a->uses().size() == b->uses().size()) ? (a->name() < b->name()) : (a->uses().size() > b->uses().size()); }); - - if (!storage.empty()) { - // NOTE: currently, we pick as a reference tensor one with `m` and `n` - // IterDomains and the most uses - auto pos = storage.begin(); - roles_map[MatmulRole::OUTPUT_D].push_back(*pos); - for (++pos; pos != storage.end(); ++pos) { - roles_map[MatmulRole::OUTPUT_AUX].push_back(*pos); - } - } - }; - - const auto domains_data = mma_output_domains.getData(); - const auto m = domains_data[(size_t)MatmulDomain::M]; - const auto n = domains_data[(size_t)MatmulDomain::N]; - const auto k = domains_data[(size_t)MatmulDomain::K]; - - DependenciesMap deps_map; - RolesMap roles_map; - - // Handle fusion input TensorView objects - resolveTvToMatmulDomainsMapping( - deps_map, mma_input_candidates, m, n, k, ca_map); - findInputRolesByDomains(deps_map, roles_map); - - deps_map.clear(); - - // Handle fusion output TensorView objects - resolveTvToMatmulDomainsMapping( - deps_map, mma_output_candidates, m, n, k, ca_map); - findOutputRolesByDomains(deps_map, roles_map); - - return roles_map; -} - -RolesMapOpt getTensorsRoles(Fusion* fusion) { - auto mma_exprs = ir_utils::getOpsOfType(fusion); - if (mma_exprs.size() != 1) { - std::stringstream ss; - ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << mma_exprs.size(); - return ss.str(); } - return getTensorsRoles( - fusion, - {static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}); + + return tensor_roles; } namespace { -void addMMAOp(Fusion* fusion_, std::vector& props) { - for (auto prop : props) { - auto* init = - IrBuilder::create(0.0, prop.insouts.out->getDataType().value()); - IrBuilder::create( - prop.insouts.out, prop.insouts.a, prop.insouts.b, init); - } -} - // Check the val (in) is the output of broadcast. // Then check the output of the broadcast is 3D (4D for bmm). bool hasValidBroadcastOp(TensorView* bcast_out) { @@ -1437,103 +1356,229 @@ TensorView* getTensorviewPriorToCast(TensorView* in) { return in; } -// Check if the Mul-Sum pair represents a matmul. If so, add the properties -// of the mma op which can be a tentatice substitue. This checks that the output -// of sum has on reduction axis, and the inputs to mul are valid broadcasts. -std::optional getMulSumInsOutsBcasts( - BinaryOp* mop, - ReductionOp* redop) { - auto a = getTensorviewPriorToCast(static_cast(mop->lhs())); - auto b = getTensorviewPriorToCast(static_cast(mop->rhs())); - - // Get the dimension of the reduction in the output. If not present, bail. - // Also ensure there is only only reduction axis. - auto red_axis = static_cast(redop->out())->getReductionAxis(); - auto num_reduction_dims = - static_cast(redop->out())->domain()->nDims() - - static_cast(redop->out())->domain()->noReductions().size(); - if (!red_axis.has_value() || num_reduction_dims > 1) { - return std::nullopt; - } +} // namespace - if (broadcastsAreValid(a, b, *red_axis)) { - MulSumProperties props = { - {mop, redop}, - {a, b, static_cast(redop->output(0))}, - {dynamic_cast(a->definition()), - dynamic_cast(b->definition())}}; - return props; +char dtypeToChar(const DataType& dtype) { + if (dtype == DataType::Half) { + return 'H'; + } else if (dtype == DataType::BFloat16) { + return 'T'; + } else if (dtype == DataType::Float) { + return 'S'; + } else if (dtype == DataType::Double) { + return 'D'; } - return std::nullopt; + NVF_ERROR(false, "Unsupported dtype for matmul: ", dtype); + return 0; } -} // namespace -void CombineMulSum::handle(ReductionOp* stmt) { - // Check if operation is a sum. - if (stmt->getReductionOpType() == BinaryOpType::Add) { - auto* inputOfSum = stmt->in(); - if (inputOfSum != nullptr) { - auto* expr = inputOfSum->definition(); - // Then check if the prodcer of the sum is a mul. - if (auto bOp = dynamic_cast(expr)) { - // If it'a mul followed by a sum, put this in a list. - if (bOp->getBinaryOpType() == BinaryOpType::Mul) { - // If the Mul-Sum is a valid representation of a matmul, - // then get the properties of the replacement Mma op. - auto props = getMulSumInsOutsBcasts(bOp, stmt); - if (props.has_value()) { - mul_sum_props_.push_back(*props); +namespace { + +class MatmulPatternMatcher : IterVisitor { + public: + static std::vector run(Fusion* fusion) { + MatmulPatternMatcher matcher; + matcher.traverse(fusion); + return matcher.patterns_; + } + + private: + using IterVisitor::handle; + + // TODO: These methods currently assume the output will have allocation domain + // equal to its rfactor. However, if the rfactor domain is specified, or if + // there is a transpose operation in the epilogue, then this assumption will + // be violated. In such cases we should actually swap and transpose A and B. + + // Handle the case when no translation is needed. + void handle(MmaOp* mop) override { + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = mop->inA()->as(); + pattern.B = mop->inB()->as(); + pattern.output = mop->out()->as(); + } + + void handle(ReductionOp* rop) override { + // Check if operation is a sum. + if (rop->getReductionOpType() != BinaryOpType::Add) { + return; + } + // Then check if the producer of the sum is a mul. + if (auto bop = dynamic_cast(rop->in()->definition())) { + if (bop->getBinaryOpType() != BinaryOpType::Mul) { + return; + } + // Remember that we are just gathering the immediate inputs to the + // matmul, so there should be no prologue between a, b and the mul/sum. + + // Check that the inputs have broadcasts that are not all in common, i.e. + // that there is at least one M and at least one N dimension. + + // Note that there might be a cast to Float just before the multiply. This + // happens when using the `mul` op with reduced precision inputs. It can + // also happen if the inputs to `mul` in the definition were Float, but + // the Fusion was segmented and casts to half precision were inserted at + // the segmentation edge (see castInputOutputToLowerPrecision in + // fusion_segmenter.cpp). + TensorView* ltv = getTensorviewPriorToCast(bop->lhs()->as()); + TensorView* rtv = getTensorviewPriorToCast(bop->rhs()->as()); + + std::vector lrf = + TensorDomain::noReductions(ltv->getMaybeRFactorDomain()); + std::vector rrf = + TensorDomain::noReductions(rtv->getMaybeRFactorDomain()); + + // These sizes should match since ops::maybeBroadcast places BroadcastOps + // for implicit broadcasting. + NVF_ERROR(lrf.size() == rrf.size()); + const std::vector& red_root = + rop->out()->as()->getRootDomain(); + NVF_ERROR(red_root.size() == lrf.size()); + // Find innermost M or N dimension in output + // We will assume for now that the output rfactor domain matches the + // fusion output's allocation domain; in particular that the innermost + // dimension is an N dimension. This allows us to determine which of lhs + // and rhs is A and B. + // TODO: analyze fusion outputs to determine N dimensions + bool lhs_is_A = true; + bool has_m = false, has_n = false; + // Loop backwards to find inner-most Iteration domain in output + for (int64_t i = (int64_t)red_root.size() - 1; i >= 0; --i) { + IterDomain* lhs_id = lrf[(size_t)i]; + IterDomain* rhs_id = rrf[(size_t)i]; + IterDomain* out_id = red_root[(size_t)i]; + if (out_id->isIteration()) { + if (lhs_id->isBroadcast() != rhs_id->isBroadcast()) { + // This is either an M or N dimension + + // Operand domains must be Broadcast and Iteration + NVF_ERROR(lhs_id->isIteration() || rhs_id->isIteration()); + + if (!has_n) { + // This is the inner-most output non-batch dim, so it is N + has_n = true; + // rhs is B if it has this dimension + lhs_is_A = rhs_id->isIteration(); + continue; + } + // We have found the inner-most N dim, so we can now use lhs_is_A to + // tell whether this is M or N + has_m = has_m || (lhs_is_A && lhs_id->isIteration()) || + (!lhs_is_A && (rhs_id->isIteration())); } + // out_id could also be a batch dim + } else if (out_id->isReduction()) { + // matmul must be contraction of non-broadcast dimensions + if (!lhs_id->isIteration() || !rhs_id->isIteration()) { + return; + } + } else if (!out_id->isBroadcast()) { + // Reduction output ID should be iteration, reduction, or broadcast + return; } } + if (!has_m || !has_n) { + // This is an ordinary reduction or mat-vec, not a matmul + return; + } + + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = lhs_is_A ? ltv : rtv; + pattern.B = lhs_is_A ? rtv : ltv; + pattern.output = rop->out()->as(); } } + + private: + std::vector patterns_; }; -void CombineMulSum::generateMulSumCanidates() { - auto mma_exprs = ir_utils::getOpsOfType(fusion_); - if (mma_exprs.size() == 1) { - mma_utils::MulSumProperties props; - props.insouts = { - static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}; - mul_sum_props_.push_back(props); - } else { - traverse(fusion_); - } - is_valid_ = (mul_sum_props_.size() == 1) ? true : false; +} // namespace + +std::vector findMatmulPatterns(Fusion* fusion) { + return MatmulPatternMatcher::run(fusion); } -const std::vector& CombineMulSum::getMulSumCanidates( - const bool refresh_data) { - if (refresh_data) { - mul_sum_props_.clear(); - generateMulSumCanidates(); - } - return mul_sum_props_; +std::string MatmulPattern::toString() const { + std::stringstream ss; + ss << "MatmulPattern{"; + ss << "\n A=" << A->toString(); + ss << "\n B=" << B->toString(); + ss << "\n output=" << output->toString() << "\n}"; + return ss.str(); } -void CombineMulSum::replaceWithMmaOp() { - // Recreate the mul-sum pairs since someone - // may run this function more than once. - generateMulSumCanidates(); - addMMAOp(fusion_, mul_sum_props_); - return; +MmaOp* MatmulPattern::translateToMmaOp() { + if (auto mma_op = dynamic_cast(output->definition())) { + // No translation needed + return mma_op; + } else if (output->definition()->isA()) { + Val* init = IrBuilder::create(0.0, output->dtype()); + // This replaces the mul and sum by overwriting output->definition() + return IrBuilder::create(output, A, B, init); + } + NVF_ERROR( + false, + "Could not translate matmul pattern with output ", + output->toString(), + " to MmaOp"); } -char dtypeToChar(const DataType& dtype) { - if (dtype == DataType::Half) { - return 'H'; - } else if (dtype == DataType::BFloat16) { - return 'T'; - } else if (dtype == DataType::Float) { - return 'S'; - } else if (dtype == DataType::Double) { - return 'D'; +std::unordered_map MatmulPattern::getDimRoles( + IdModel& id_model) const { + id_model.maybeBuildGraph(IdMappingMode::EXACT); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + + // There are four types of ValGroup involved in a MatmulPattern: M, N, K, and + // Batch. These are enumerated in the MatmulDomain enum class. They are + // defined by their membership as follows: + // M: present in A and output, but not B + // N: present in B and output, but not A + // K: present in A and B, but not output + // Batch: present in all A, B, and output + // If there are other patterns, for example a ValGroup present in only A, then + // we should raise an exception here. + + // Indicates whether a ValGroup is present in A (bit 0), B (bit 1), or output + // (bit 2) + using DimPresence = std::bitset<3>; + + std::unordered_map present_flags; + const auto recordPresence = [&exact_graph, &present_flags]( + TensorView* tv, size_t tensor_num) { + for (IterDomain* id : tv->getMaybeRFactorDomain()) { + if (id->isReduction() || id->isBroadcast()) { + // ignore reductions and broadcasts since they don't exact map to + // problem dims + continue; + } + const ValGroup& g = exact_graph.toGroup(id); + present_flags[g].set(tensor_num); + } + }; + recordPresence(A, 0); + recordPresence(B, 1); + recordPresence(output, 2); + + std::unordered_map dim_roles; + for (const auto& [g, flags] : present_flags) { + if (flags.all()) { + dim_roles[g] = MatmulDomain::Batch; + } else if (flags.test(0) && flags.test(1)) { + dim_roles[g] = MatmulDomain::K; + } else if (flags.test(0) && flags.test(2)) { + dim_roles[g] = MatmulDomain::M; + } else if (flags.test(1) && flags.test(2)) { + dim_roles[g] = MatmulDomain::N; + } else { + NVF_ERROR( + false, + "IterDomain ValGroup should be present in at least two of A, B, and output. flags: ", + flags); + } } - NVF_ERROR(false, "Unsupported dtype for matmul: ", dtype); - return 0; + + return dim_roles; } } // namespace mma_utils diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 8b2a6861bdf..68ff171eeb9 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -9,8 +9,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -242,34 +244,41 @@ class DataWrapperOpt { } }; -// This struct hold properties of a Mul and Sum pair -// which can possibly be replaced a Mma op. This struct -// can be be created (partially) from a Mma op. -struct MulSumProperties { - // The Mul amd Sum op which can be replaced by a Mma op. - struct MulAndSumOps { - BinaryOp* mop = nullptr; - ReductionOp* redop = nullptr; - }; - - // The inputs/ouputs to the possible Mma Op or the actual Mma op. - struct InputsOutputs { - TensorView* a = nullptr; - TensorView* b = nullptr; - TensorView* out = nullptr; - }; - - // The broadcasts which feed the Mma op/Mul-Sum pair. - struct Broadcasts { - BroadcastOp* bcast_a = nullptr; - BroadcastOp* bcast_b = nullptr; - }; - - MulAndSumOps mulsumops; - InputsOutputs insouts; - Broadcasts bcasts; +//! This represents a single matmul operation, without a prologue or epilogue. +//! Each matmul has two inputs which might not be fusion inputs: A and B. It +//! also has one output, which can be Float or reduced precision. For MatmulOp +//! and LinearOp, the output is the same dtype as the inputs; so output does not +//! necessarily correspond to the output of a translated MmaOp and it might not +//! be a fusion output. +struct MatmulPattern { + TensorView* A; + TensorView* B; + // This is not necessarily a Fusion output, but rather is the immediate output + // representing a matmul in the current Fusion. The definition of this tensor + // determines what kind of translation is needed, if any. Possible definition + // Expr types are: MmaOp, ReductionOp (for mul-sum patterns), MatmulOp, and + // LinearOp. + TensorView* output; + + //! If the pattern is not already represented by an MmaOp, for example if + //! there is a MatmulOp instead, this function modifies the fusion to insert + //! an MmaOp. TensorViews A and B are unchanged, but this->output might be + //! updated to reflect the replacement tensor. + MmaOp* translateToMmaOp(); + + //! Given an IdModel, map groups of IterDomains to dimension roles + //! (MatmulDomain). Note that ValGroup is a shared_ptr to a + //! VectorOfUniqueEntries. We copy these as keys so that the returned + //! object can safely outlive id_model. + std::unordered_map getDimRoles( + IdModel& id_model) const; + + std::string toString() const; }; +//! Traverse the fusion to find supported matmul patterns +std::vector findMatmulPatterns(Fusion* fusion); + using MatmulProblemLayoutOpt = DataWrapperOpt; using ProblemIterDomainsOpt = DataWrapperOpt; using RolesMapOpt = DataWrapperOpt; @@ -290,32 +299,22 @@ using DependenciesMap = std::map; //! transposition of inputs in mma instructions, while other (e.g. Turing, //! Ampere) the only supported transposition is TN which means that mma //! instruction first input is transposed, the second input is non-transposed. -NVF_API MatmulProblemLayoutOpt getMmaLayout( - Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props); - -//! This overloaded version is just a wrapper on the above function, where -//! the mma_utils::MulSumProperties::InputsOutputs is extracted from the fusion. -NVF_API MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion); +NVF_API MatmulProblemLayoutOpt getProblemLayout( + const IdModel& id_model, + const std::unordered_map& dim_roles, + const RolesMap& tensor_roles); -//! Returns wrapped collection of IterDomains that can be used to get -//! problem shape with runtime info. -//! Data is stored in the order in which lables are defined in MatmulDomain -//! enum class, that is in the following order: m, n, k. -//! An error message is stored in retruned object if valid data cannot -//! be gathered. -//! TODO: 4th domain must be added for batch gemm support. -ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); -ProblemIterDomainsOpt getProblemIterDomains( - const mma_utils::MulSumProperties::InputsOutputs& props); +//! This version assumes the Fusion contains a single MatmulPattern, then builds +//! an IdModel and infers dim roles then calls the above function. +NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot //! be gathered. -RolesMapOpt getTensorsRoles( +RolesMapOpt getTensorRoles( Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props); -RolesMapOpt getTensorsRoles(Fusion* fusion); + const IdModel& id_model, + const std::unordered_map& dim_roles); //! Return pair of whether use shared memory epilogue or not and whether to //! reuse shared memory for the prologue at the expense of an additional block @@ -346,50 +345,6 @@ NVF_API std::pair generateSharedMemoryEpilogueHeuristics( bool smem_b_reuse_guaranteed = false, bool ignore_occupancy_drop = false); -//! Go through the fusion IR to find combinations of mul-sum -//! which can be replaced with a mma op. This class operates -//! in two phases. It can go through the graph and find the mul-sum -//! pairs which can be replaced by a mma op. This phase returns a vector -//! of properties of the mma op (MulSumAsMmaProps) which would replace the -//! the mul-sum pair. It then exposes a function to replace with mma ops. -class CombineMulSum : public IterVisitor { - public: - CombineMulSum(Fusion* fusion) : IterVisitor(), fusion_(fusion) { - generateMulSumCanidates(); - }; - - const std::vector& getMulSumCanidates( - const bool refresh_data = false); - - //! Goes through the fusion to find mul-sum pairs. - //! If user sets the caching flags and properties have been previously - //! computed, then just return cached results. - void generateMulSumCanidates(); - - //! Replaces the candidate mul-sum pairs with mma ops. - //! Please not this will run generateMulSumCandidates again. - void replaceWithMmaOp(); - - //! Check if the fusion has a mma-op or a mul-sum pair - //! that can be replaced by a mma op. - bool isValid() { - return is_valid_; - } - - protected: - void handle(ReductionOp* stmt) override; - - private: - Fusion* fusion_; - //! This is the list of mul-sum pairs and the properties - //! of the mma op which can replace it. This is only populated - //! if the mul-sum pair is a valid replacement candidate. - std::vector mul_sum_props_ = {}; - //! This variable tracks if the fusion has a mul-sum pair - //! than can be replaced by a mma op, or has a single mma op. - bool is_valid_ = false; -}; - //! Compute the amount of shared memory we expect to need. The actual amount //! allocated will be determined by aliasing (see alias_memory.cpp). This //! function is useful for testing that we provide accurate information to our diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index d3fdd9bf106..cb85281f22a 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -57,6 +57,24 @@ class CombineMulSumAsMmaTest : public NVFuserTest { } }; +void performSubstitution(Fusion* fusion, bool should_not_find = false) { + EXPECT_TRUE(ir_utils::getOpsOfType(fusion).empty()); + + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + if (should_not_find) { + EXPECT_TRUE(patterns.empty()); + return; + } + + ASSERT_FALSE(patterns.empty()); + EXPECT_EQ(patterns.size(), 1); + + patterns.front().translateToMmaOp(); + + ASSERT_FALSE(ir_utils::getOpsOfType(fusion).empty()); +} + // Test checks to see that the combiner can correctly replace // the mul-sum pair with a mma op. TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Pass) { @@ -76,17 +94,13 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Pass) { fusion.addOutput(tv3); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); - - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - - ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion); } } -// This test checks that the combiner does not incorrectly -// replace this mul-sum pair, and the mul is not fed by broadcasts ops. +// This test checks that the pattern matcher does not incorrectly identify +// this mul-sum pair, as the mul is not fed by broadcasts ops; i.e. it is +// not a matmul. TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) { Fusion fusion; FusionGuard fg(&fusion); @@ -100,23 +114,20 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) { auto tv3 = sum(tv2, {-1}); fusion.addOutput(tv3); - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion, /*should_not_find=*/true); } -// This test checks to see that the mul-sum combiner does not -// combine a mul-sum which does not have appropriate broadcasts. -TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2) { +// This fusion has Broadcast batch axes in each operand. +TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) { // Assumes layout is kAllSupportedMmaLayout::NT; - Fusion fusion; - FusionGuard fg(&fusion); + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + fusion->addInput(tv0); + fusion->addInput(tv1); auto tv0t = transpose(tv0, 0, 1); auto tv1t = transpose(tv1, 0, 1); @@ -133,14 +144,24 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2) { auto tv1b = broadcast(tv1t, bcast_dims); auto tv2 = mul(tv0b, tv1b); auto tv3 = sum(tv2, {-1}); - fusion.addOutput(tv3); + fusion->addOutput(tv3); + + performSubstitution(fusion, /*should_not_find=*/false); + + // We test running this fusion also to verify that the broadcast batch + // dimension does not cause unforeseen issues - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); + int64_t M = 256, N = 128, K = 64; + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({K, M}, options); + auto t1 = at::randn({K, N}, options); + auto tref = at::linear(t0.t(), t1.t()).unsqueeze(1); - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } // As a sanity check we test that after replacing a mul-sum @@ -164,11 +185,8 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Schedule) { auto tv2 = sum(mul(tv0, tv1), {-1}); fusion.addOutput(tv2); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion); MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); @@ -251,4 +269,64 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { } } +// Check that we determine A and B properly when they are swapped as inputs to +// mul +TEST_F(CombineMulSumAsMmaTest, SwapAandB) { + for (auto layout : kAllSupportedMmaLayout) { + for (bool swap : {false, true}) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); + tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); + // We should identify tv0 as A and tv1 as B regardless of the order here + auto tv2 = swap ? mul(tv1, tv0) : mul(tv0, tv1); + auto tv3 = sum(tv2, {-1}); + + fusion.addOutput(tv3); + + std::vector patterns = + mma_utils::findMatmulPatterns(&fusion); + + ASSERT_FALSE(patterns.empty()); + EXPECT_EQ(patterns.size(), 1); + + mma_utils::MatmulPattern& pattern = patterns.front(); + + EXPECT_EQ(pattern.A, tv0); + EXPECT_EQ(pattern.B, tv1); + EXPECT_EQ(pattern.output, tv3); + + pattern.translateToMmaOp(); + + // Check that we didn't modify the pattern roles + EXPECT_EQ(pattern.A, tv0); + EXPECT_EQ(pattern.B, tv1); + EXPECT_EQ(pattern.output, tv3); + + // Check that we properly map M and N to their roles even with swap + IdModel id_model(&fusion); + std::unordered_map dim_roles = + pattern.getDimRoles(id_model); + ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + const ValGroup& m_gp = exact_graph.toGroup(tv0->axis(-3)); + auto m_it = dim_roles.find(m_gp); + ASSERT_NE(m_it, dim_roles.end()); + EXPECT_EQ(m_it->second, MatmulDomain::M); + + const ValGroup& n_gp = exact_graph.toGroup(tv1->axis(-2)); + auto n_it = dim_roles.find(n_gp); + ASSERT_NE(n_it, dim_roles.end()); + EXPECT_EQ(n_it->second, MatmulDomain::N); + + ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); + } + } +} + } // namespace nvfuser diff --git a/tests/cpp/test_gpu_tensorcore.cpp b/tests/cpp/test_gpu_tensorcore.cpp index 7c125ba590d..4d4523e1cc1 100644 --- a/tests/cpp/test_gpu_tensorcore.cpp +++ b/tests/cpp/test_gpu_tensorcore.cpp @@ -3149,7 +3149,7 @@ TEST_F(GPUTTensorCoreTest, MisalignedVectorization) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index d199b958a55..5ffb26d40e9 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -172,7 +172,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBias) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -266,7 +266,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueRelu) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -366,7 +366,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasRelu) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -462,7 +462,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueReluAux) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -570,7 +570,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasReluAux) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -666,7 +666,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueGelu) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -755,7 +755,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueGeluAux) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -857,7 +857,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasGelu) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -967,7 +967,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasGeluAux) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1084,7 +1084,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1134,7 +1134,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulRelaxedCheck) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1186,7 +1186,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulInputShuffledTT) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1236,7 +1236,7 @@ TEST_F(MatmulSchedulerTest, EpilogueOutputCast) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1292,7 +1292,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlpha) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1350,7 +1350,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaOutputCast) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1417,7 +1417,7 @@ TEST_F(MatmulSchedulerTest, EpilogueBeta) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1492,7 +1492,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1573,7 +1573,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1658,7 +1658,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaBias) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1728,7 +1728,7 @@ TEST_F(MatmulSchedulerTest, StridedBatch) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1801,7 +1801,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaBeta) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1887,7 +1887,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaSingleBeta) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1961,7 +1961,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueBias) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2028,7 +2028,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueSingleBias) { 1 == ir_utils::getOpsOfType(fusion.get()).size(), "matmul fusion must have at least one MmaOp"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2097,7 +2097,7 @@ TEST_F(MatmulSchedulerTest, MisalignedVectorization) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2288,7 +2288,7 @@ TEST_F(MatmulSchedulerTest, StridedInputs) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition");