From 85504a08d2006b51353df87045f62aee36a62f5e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 13 May 2024 18:02:24 +0000 Subject: [PATCH 01/52] Introduce MatmulPattern and enable it in scheduler --- csrc/scheduler/matmul.cpp | 14 ++-- csrc/scheduler/mma_utils.cpp | 121 +++++++++++++++++++++++++++++++++++ csrc/scheduler/mma_utils.h | 26 ++++++++ 3 files changed, 156 insertions(+), 5 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index f29b2fd4617..77ad15ff6dd 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -749,13 +749,17 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Cache and fork outputs auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); - mma_utils::CombineMulSum combiner(fusion); - auto mma_ops = ir_utils::getOpsOfType(fusion); - if (combiner.isValid() && mma_ops.empty()) { - combiner.replaceWithMmaOp(); - mma_ops = ir_utils::getOpsOfType(fusion); + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + NVF_ERROR( + patterns.size() == 1, + "Only a single matmul pattern can currently be fused"); + for (mma_utils::MatmulPattern& pattern : patterns) { + pattern.translateToMmaOp(); } + auto mma_ops = ir_utils::getOpsOfType(fusion); + NVF_ERROR( mma_ops.size() == 1, "scheduleMatmul supports fusion with single mma op in definition, got ", diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index ee7be460059..c3e143a42eb 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1536,6 +1536,127 @@ char dtypeToChar(const DataType& dtype) { return 0; } +namespace { + +class MatmulPatternMatcher : IterVisitor { + public: + static std::vector run(Fusion* fusion) { + MatmulPatternMatcher matcher; + matcher.traverse(fusion); + return matcher.patterns_; + } + + private: + using IterVisitor::handle; + + void handle(MatmulOp* mop) { + patterns_.emplace_back( + {mop->inA()->as(), + mop->inB()->as(), + 
mop->out()->as()}); + } + + // Handle the case when no translation is needed. + void handle(MmaOp* mop) { + patterns_.emplace_back( + {mop->inA()->as(), + mop->inB()->as(), + mop->out()->as()}); + } + + void handle(ReductionOp* rop) { + // Check if operation is a sum. + if (rop->getReductionOpType() != BinaryOpType::Add) { + return; + } + // Then check if the producer of the sum is a mul. + if (auto bop = dynamic_cast(rop->in()->definition())) { + if (bop->getBinaryOpType() != BinaryOpType::Mul) { + return; + } + // Remember that we are just gathering the immediate inputs to the + // matmul, so there should be no prologue between a, b and the mul/sum. + + // Check that the inputs have broadcasts that are not all in common, i.e. + // that there is at least one M and at least one N dimension. + TensorView* ltv = bop->lhs()->as(); + TensorView* rtv = bop->rhs()->as(); + std::vector lrf = + TensorDomain::noReductions(ltv->getMaybeRFactorDomain()); + std::vector rrf = + TensorDomain::noReductions(rtv->getMaybeRFactorDomain()); + + // These sizes should match since ops::maybeBroadcast places BroadcastOps + // for implicit broadcasting. 
+ NVF_ERROR(lrf.size() == rrf.size()); + bool has_m=false, has_n=false; + for (size_t i : c10::irange(lrf.size())) { + if (lrf[i]->isBroadcast() && !rrf[i]->isBroadcast()) { + has_m = true; + } else if (!lrf[i]->isBroadcast() && rrf[i]->isBroadcast()) { + has_n = true; + } + } + if (!has_m || !has_n) { + // This is an ordinary reduction, not a matmul + return; + } + + MatmulPattern& pattern = + patterns_.emplace_back({ltv, rtv, bop->out()->as()}); + } + } + + private: + std::vector patterns_; +}; + +} + +std::vector findMatmulPatterns(Fusion* fusion) { + return MatmulPatternMatcher::run(fusion); +} + +void MatmulPattern::translateToMmaOp() { + if (dynamic_cast(output->definition())) { + // No translation needed + return; + } else if (auto mop = dynamic_cast(output->definition())) { + // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we + // must transpose B then broadcast both operands before creating the final + // op. + // + // Also note that the output of MatmulOp is a tensor of shape [..., M, N] + // whose matches that of the inputs. We will most commonly then also need to + // cast the output of the MmaOp to produce the output TensorView. 
+ TensorView* Btrans = transpose(B); + std::vector bcast_dims(A.size() + 1, false); + bcast_dims[A.size() - 2] = true; + TensorView* Abcast = broadcast(A, bcast_dims); + bcast_dims[A.size() - 2] = false; + bcast_dims[A.size() - 3] = true; + TensorView* Bbcast = broadcast(Btrans, bcast_dims); + TensorView* fms = fusedMultiplySum(Abcast, Bbcast, {-1}); + // Update operands to keep the pattern minimal + A = Abcast; + B = Bbcast; + if (output->dtype() != fms->dtype()) { + // Redefine output as cast of MmaOp->out() + IrBuilder::create(UnaryOpType::Cast, output, fms->out()); + // Update output so that cast is part of the epilogue + output = fms->out(); + } else { + // No cast needed, for example the inputs might be Float + ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); + } + } else if (auto rop = dynamic_cast(output->definition())) { + Val* init = + IrBuilder::create(0.0, output->dtype()); + // This replaces the mul and sum by overwriting output->definition() + IrBuilder::create(output, A, B, init); + } +} + } // namespace mma_utils } // namespace nvfuser diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 8b2a6861bdf..d365bb63440 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -390,6 +390,32 @@ class CombineMulSum : public IterVisitor { bool is_valid_ = false; }; +//! This represents a single matmul operation, without a prologue or epilogue. +//! Each matmul has two inputs which might not be fusion inputs: A and B. It +//! also has one output, which can be Float or reduced precision. For MatmulOp +//! and LinearOp, the output is the same dtype as the inputs; so output does not +//! necessarily correspond to the output of a translated MmaOp and it might not +//! be a fusion output. +struct MatmulPattern { + TensorView* A; + TensorView* B; + // This is not necessarily a Fusion output, but rather is the immediate output + // representing a matmul in the current Fusion. 
The definition of this tensor + // determines what kind of translation is needed, if any. Possible definition + // Expr types are: MmaOp, ReductionOp (for mul-sum patterns), MatmulOp, and + // LinearOp. + TensorView* output; + + //! If the pattern is not already represented by an MmaOp, for example if + //! there is a MatmulOp instead, this function modifies the fusion to insert + //! an MmaOp. TensorViews A and B are unchanged, but this->output might be + //! updated to reflect the replacement tensor. + void translateToMmaOp(); +}; + +//! Traverse the fusion to find supported matmul patterns +std::vector findMatmulPatterns(Fusion* fusion); + //! Compute the amount of shared memory we expect to need. The actual amount //! allocated will be determined by aliasing (see alias_memory.cpp). This //! function is useful for testing that we provide accurate information to our From d174a8c1570d42426938f17c3bc578632d3d9950 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 13 May 2024 18:51:56 +0000 Subject: [PATCH 02/52] Fixes --- csrc/scheduler/matmul.cpp | 3 ++ csrc/scheduler/mma_utils.cpp | 54 +++++++++++++++++------------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 77ad15ff6dd..9bc897b0c1f 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -751,12 +751,15 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { std::vector patterns = mma_utils::findMatmulPatterns(fusion); + NVF_ERROR(!patterns.empty(), "No matmul patterns were found"); NVF_ERROR( patterns.size() == 1, "Only a single matmul pattern can currently be fused"); + fusion->printMath(); for (mma_utils::MatmulPattern& pattern : patterns) { pattern.translateToMmaOp(); } + fusion->printMath(); auto mma_ops = ir_utils::getOpsOfType(fusion); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index c3e143a42eb..2f689e80090 100644 --- a/csrc/scheduler/mma_utils.cpp +++ 
b/csrc/scheduler/mma_utils.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -1549,22 +1550,22 @@ class MatmulPatternMatcher : IterVisitor { private: using IterVisitor::handle; - void handle(MatmulOp* mop) { - patterns_.emplace_back( - {mop->inA()->as(), - mop->inB()->as(), - mop->out()->as()}); + void handle(MatmulOp* mop) override { + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = mop->inA()->as(); + pattern.B = mop->inB()->as(); + pattern.output = mop->out()->as(); } // Handle the case when no translation is needed. - void handle(MmaOp* mop) { - patterns_.emplace_back( - {mop->inA()->as(), - mop->inB()->as(), - mop->out()->as()}); + void handle(MmaOp* mop) override { + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = mop->inA()->as(); + pattern.B = mop->inB()->as(); + pattern.output = mop->out()->as(); } - void handle(ReductionOp* rop) { + void handle(ReductionOp* rop) override { // Check if operation is a sum. if (rop->getReductionOpType() != BinaryOpType::Add) { return; @@ -1589,7 +1590,7 @@ class MatmulPatternMatcher : IterVisitor { // These sizes should match since ops::maybeBroadcast places BroadcastOps // for implicit broadcasting. 
NVF_ERROR(lrf.size() == rrf.size()); - bool has_m=false, has_n=false; + bool has_m = false, has_n = false; for (size_t i : c10::irange(lrf.size())) { if (lrf[i]->isBroadcast() && !rrf[i]->isBroadcast()) { has_m = true; @@ -1602,8 +1603,10 @@ class MatmulPatternMatcher : IterVisitor { return; } - MatmulPattern& pattern = - patterns_.emplace_back({ltv, rtv, bop->out()->as()}); + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = ltv; + pattern.B = rtv; + pattern.output = rop->out()->as(); } } @@ -1611,17 +1614,17 @@ class MatmulPatternMatcher : IterVisitor { std::vector patterns_; }; -} +} // namespace std::vector findMatmulPatterns(Fusion* fusion) { return MatmulPatternMatcher::run(fusion); } void MatmulPattern::translateToMmaOp() { - if (dynamic_cast(output->definition())) { + if (dynamic_cast(output->definition())) { // No translation needed return; - } else if (auto mop = dynamic_cast(output->definition())) { + } else if (auto mop = dynamic_cast(output->definition())) { // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we // must transpose B then broadcast both operands before creating the final // op. @@ -1630,28 +1633,23 @@ void MatmulPattern::translateToMmaOp() { // whose matches that of the inputs. We will most commonly then also need to // cast the output of the MmaOp to produce the output TensorView. 
TensorView* Btrans = transpose(B); - std::vector bcast_dims(A.size() + 1, false); - bcast_dims[A.size() - 2] = true; - TensorView* Abcast = broadcast(A, bcast_dims); - bcast_dims[A.size() - 2] = false; - bcast_dims[A.size() - 3] = true; - TensorView* Bbcast = broadcast(Btrans, bcast_dims); + TensorView* Abcast = unsqueeze(A, -2); + TensorView* Bbcast = unsqueeze(Btrans, -3); TensorView* fms = fusedMultiplySum(Abcast, Bbcast, {-1}); // Update operands to keep the pattern minimal A = Abcast; B = Bbcast; if (output->dtype() != fms->dtype()) { // Redefine output as cast of MmaOp->out() - IrBuilder::create(UnaryOpType::Cast, output, fms->out()); + IrBuilder::create(UnaryOpType::Cast, output, fms); // Update output so that cast is part of the epilogue - output = fms->out(); + output = fms; } else { // No cast needed, for example the inputs might be Float ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); } - } else if (auto rop = dynamic_cast(output->definition())) { - Val* init = - IrBuilder::create(0.0, output->dtype()); + } else if (auto rop = dynamic_cast(output->definition())) { + Val* init = IrBuilder::create(0.0, output->dtype()); // This replaces the mul and sum by overwriting output->definition() IrBuilder::create(output, A, B, init); } From 6c7f71afc59b1717a123d47d82eea03becbaa3eb Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 13 May 2024 19:10:39 +0000 Subject: [PATCH 03/52] Strip casts from input of mul-sum patterns --- csrc/scheduler/matmul.cpp | 2 -- csrc/scheduler/mma_utils.cpp | 5 +++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 9bc897b0c1f..81c289e0c49 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -755,11 +755,9 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { NVF_ERROR( patterns.size() == 1, "Only a single matmul pattern can currently be fused"); - fusion->printMath(); for (mma_utils::MatmulPattern& 
pattern : patterns) { pattern.translateToMmaOp(); } - fusion->printMath(); auto mma_ops = ir_utils::getOpsOfType(fusion); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 2f689e80090..759ca4651b8 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1580,8 +1580,9 @@ class MatmulPatternMatcher : IterVisitor { // Check that the inputs have broadcasts that are not all in common, i.e. // that there is at least one M and at least one N dimension. - TensorView* ltv = bop->lhs()->as(); - TensorView* rtv = bop->rhs()->as(); + TensorView* ltv = getTensorviewPriorToCast(bop->lhs()->as()); + TensorView* rtv = getTensorviewPriorToCast(bop->rhs()->as()); + std::vector lrf = TensorDomain::noReductions(ltv->getMaybeRFactorDomain()); std::vector rrf = From bb399952567fb4d186f42afaf07583c5fb1b5a62 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 13 May 2024 19:58:43 +0000 Subject: [PATCH 04/52] Remove CombineMulSum --- csrc/scheduler/matmul_utils.cpp | 55 ++++++------ csrc/scheduler/mma_utils.cpp | 116 ++++--------------------- csrc/scheduler/mma_utils.h | 134 +++++++---------------------- tests/cpp/test_combine_mul_sum.cpp | 42 ++++----- 4 files changed, 92 insertions(+), 255 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index e2fa409f4a6..087e1589d79 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -138,9 +138,9 @@ inline bool initCoreHeuristics( //! A helper for getting problem shape from fusion and runtime info. 
ProblemShape getProblemShape( - const mma_utils::MulSumProperties::InputsOutputs& props, + const mma_utils::MatmulPattern& pattern, SchedulerRuntimeInfo& runtime_info) { - const auto mma_output_domains = mma_utils::getProblemIterDomains({props}); + const auto mma_output_domains = mma_utils::getProblemIterDomains(pattern); if (!mma_output_domains.isValid()) { NVF_ERROR(false, mma_output_domains.getErrorMsg()); } @@ -169,11 +169,11 @@ ProblemShape getProblemShape( std::string isMatmulFusionDefinitionSupported( Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { + const mma_utils::MatmulPattern& pattern) { const auto& fusion_inputs = fusion->inputs(); const auto& fusion_outputs = fusion->outputs(); - std::vector mma_inputs = {props.a, props.b}; - const auto mma_output = props.out; + std::vector mma_inputs = {pattern.A, pattern.B}; + const auto mma_output = pattern.output; const auto fusion_inputs_tvs = ir_utils::filterByType(fusion_inputs).vector(); @@ -182,7 +182,7 @@ std::string isMatmulFusionDefinitionSupported( constexpr size_t minimal_number_of_inputs = 2; MmaOpUtils::MmaOpDetails mma_details = - MmaOpUtils::getMmaOpDetails(props.out, props.a, props.b); + MmaOpUtils::getMmaOpDetails(pattern.output, pattern.A, pattern.B); // Quick checks - MmaOp { @@ -211,7 +211,7 @@ std::string isMatmulFusionDefinitionSupported( // Fusion topology check { - const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion, props); + const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion, pattern); if (!roles_map_opt.isValid()) { return roles_map_opt.getErrorMsg(); } @@ -298,20 +298,19 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // #1 // Initializing the machinery to check if there's a Mul-Sum pair // can be replaced by a Mma Op. 
- mma_utils::CombineMulSum combiner(fusion); - if (!combiner.isValid()) { - std::stringstream ss; - ss << "Matmul scheduler supports fusions only with a single mma op" - << "or supports a mul-sum pair which can be replaced with a mma op"; - return ss.str(); + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + if (patterns.empty()) { + return "No matmul patterns were found"; + } + if (patterns.size() > 1) { + return "Only a single matmul pattern can currently be fused"; } - const std::vector& mma_from_mul_sums = - combiner.getMulSumCanidates(); // #2 { const auto input_layout_opt = - mma_utils::getMmaLayout(fusion, mma_from_mul_sums.front().insouts); + mma_utils::getMmaLayout(fusion, patterns.front()); if (!input_layout_opt.isValid()) { return input_layout_opt.getErrorMsg(); } @@ -319,8 +318,8 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // #3 { - auto support_status = isMatmulFusionDefinitionSupported( - fusion, mma_from_mul_sums.front().insouts); + auto support_status = + isMatmulFusionDefinitionSupported(fusion, patterns.front()); if (!support_status.empty()) { return support_status; } @@ -345,16 +344,15 @@ std::shared_ptr getMatmulHeuristics( } // Check initial conditions - auto mma_exprs = ir_utils::getOpsOfType(fusion); - mma_utils::CombineMulSum combiner(fusion); + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + NVF_ERROR(!patterns.empty(), "No matmul patterns were found"); NVF_ERROR( - combiner.isValid(), - "There's no (single) mma op or mul-sum op which mma op can replace") + patterns.size() == 1, + "Only a single matmul pattern can currently be fused"); + mma_utils::MatmulPattern& pattern = patterns.front(); - const std::vector& mulSum = - combiner.getMulSumCanidates(); - const auto problem_shape = - getProblemShape(mulSum.front().insouts, runtime_info); + const auto problem_shape = getProblemShape(pattern, runtime_info); const auto device_prop = at::cuda::getCurrentDeviceProperties(); const auto mma_op 
= @@ -363,14 +361,13 @@ std::shared_ptr getMatmulHeuristics( mma_op.has_value(), "Failed to determine a MMA op for given problem."); params->mma_macro = mma_op.value(); - const auto& roles_map_opt = - mma_utils::getTensorsRoles(fusion, mulSum.front().insouts); + const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion, pattern); NVF_ERROR(roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); const auto roles_map = roles_map_opt.getData(); if (matmul_heuristic_plugin::hasPlugin()) { const mma_utils::MatmulProblemLayoutOpt layout_opt = - mma_utils::getMmaLayout(fusion, mulSum.front().insouts); + mma_utils::getMmaLayout(fusion, pattern); NVF_ERROR(layout_opt.isValid(), layout_opt.getErrorMsg()); const MmaLayout layout = layout_opt.getData(); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 759ca4651b8..6f27e406db4 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1052,14 +1052,13 @@ inline void resolveTvToMatmulDomainsMapping( } // anonymous namespace -ProblemIterDomainsOpt getProblemIterDomains( - const mma_utils::MulSumProperties::InputsOutputs& props) { +ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern) { // NOTE: the iter domains of MMA output should be [...,M,K,N] IterDomain* m = nullptr; IterDomain* n = nullptr; IterDomain* k = nullptr; - const auto& leaf_domains = props.out->getLeafDomain(); + const auto& leaf_domains = pattern.output->getLeafDomain(); const auto concrete = TensorDomain::noDevices( TensorDomain::noReductions(TensorDomain::noBroadcasts(leaf_domains))); if (concrete.size() < MIN_MATMUL_INPUTS_NUMBER) { @@ -1103,7 +1102,7 @@ ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion) { MatmulProblemLayoutOpt getMmaLayout( Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { + const MatmulPattern& pattern) { ComputeAtMap ca_map(fusion); const auto mma_input_candidates = ir_utils::filterByType(fusion->inputs()).vector(); @@ 
-1111,7 +1110,7 @@ MatmulProblemLayoutOpt getMmaLayout( return {"Failed to find any TV that is fusion input"}; } - const auto mma_output_domains = getProblemIterDomains(props); + const auto mma_output_domains = getProblemIterDomains(pattern); if (!mma_output_domains.isValid()) { return mma_output_domains.getErrorMsg(); } @@ -1196,9 +1195,7 @@ MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion) { static_cast(mma_exprs.front()->out())}); } -RolesMapOpt getTensorsRoles( - Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props) { +RolesMapOpt getTensorsRoles(Fusion* fusion, const MatmulPattern& pattern) { ComputeAtMap ca_map(fusion); const auto mma_input_candidates = ir_utils::filterByType(fusion->inputs()).vector(); @@ -1211,7 +1208,7 @@ RolesMapOpt getTensorsRoles( return {"Failed to find any TV that is fusion output"}; } - const auto mma_output_domains = getProblemIterDomains(props); + const auto mma_output_domains = getProblemIterDomains(pattern); if (!mma_output_domains.isValid()) { return mma_output_domains.getErrorMsg(); } @@ -1339,15 +1336,6 @@ RolesMapOpt getTensorsRoles(Fusion* fusion) { namespace { -void addMMAOp(Fusion* fusion_, std::vector& props) { - for (auto prop : props) { - auto* init = - IrBuilder::create(0.0, prop.insouts.out->getDataType().value()); - IrBuilder::create( - prop.insouts.out, prop.insouts.a, prop.insouts.b, init); - } -} - // Check the val (in) is the output of broadcast. // Then check the output of the broadcast is 3D (4D for bmm). bool hasValidBroadcastOp(TensorView* bcast_out) { @@ -1438,91 +1426,8 @@ TensorView* getTensorviewPriorToCast(TensorView* in) { return in; } -// Check if the Mul-Sum pair represents a matmul. If so, add the properties -// of the mma op which can be a tentatice substitue. This checks that the output -// of sum has on reduction axis, and the inputs to mul are valid broadcasts. 
-std::optional getMulSumInsOutsBcasts( - BinaryOp* mop, - ReductionOp* redop) { - auto a = getTensorviewPriorToCast(static_cast(mop->lhs())); - auto b = getTensorviewPriorToCast(static_cast(mop->rhs())); - - // Get the dimension of the reduction in the output. If not present, bail. - // Also ensure there is only only reduction axis. - auto red_axis = static_cast(redop->out())->getReductionAxis(); - auto num_reduction_dims = - static_cast(redop->out())->domain()->nDims() - - static_cast(redop->out())->domain()->noReductions().size(); - if (!red_axis.has_value() || num_reduction_dims > 1) { - return std::nullopt; - } - - if (broadcastsAreValid(a, b, *red_axis)) { - MulSumProperties props = { - {mop, redop}, - {a, b, static_cast(redop->output(0))}, - {dynamic_cast(a->definition()), - dynamic_cast(b->definition())}}; - return props; - } - return std::nullopt; -} } // namespace -void CombineMulSum::handle(ReductionOp* stmt) { - // Check if operation is a sum. - if (stmt->getReductionOpType() == BinaryOpType::Add) { - auto* inputOfSum = stmt->in(); - if (inputOfSum != nullptr) { - auto* expr = inputOfSum->definition(); - // Then check if the prodcer of the sum is a mul. - if (auto bOp = dynamic_cast(expr)) { - // If it'a mul followed by a sum, put this in a list. - if (bOp->getBinaryOpType() == BinaryOpType::Mul) { - // If the Mul-Sum is a valid representation of a matmul, - // then get the properties of the replacement Mma op. 
- auto props = getMulSumInsOutsBcasts(bOp, stmt); - if (props.has_value()) { - mul_sum_props_.push_back(*props); - } - } - } - } - } -}; - -void CombineMulSum::generateMulSumCanidates() { - auto mma_exprs = ir_utils::getOpsOfType(fusion_); - if (mma_exprs.size() == 1) { - mma_utils::MulSumProperties props; - props.insouts = { - static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}; - mul_sum_props_.push_back(props); - } else { - traverse(fusion_); - } - is_valid_ = (mul_sum_props_.size() == 1) ? true : false; -} - -const std::vector& CombineMulSum::getMulSumCanidates( - const bool refresh_data) { - if (refresh_data) { - mul_sum_props_.clear(); - generateMulSumCanidates(); - } - return mul_sum_props_; -} - -void CombineMulSum::replaceWithMmaOp() { - // Recreate the mul-sum pairs since someone - // may run this function more than once. - generateMulSumCanidates(); - addMMAOp(fusion_, mul_sum_props_); - return; -} - char dtypeToChar(const DataType& dtype) { if (dtype == DataType::Half) { return 'H'; @@ -1591,6 +1496,9 @@ class MatmulPatternMatcher : IterVisitor { // These sizes should match since ops::maybeBroadcast places BroadcastOps // for implicit broadcasting. 
NVF_ERROR(lrf.size() == rrf.size()); + const std::vector& red_root = + rop->out()->as()->getRootDomain(); + NVF_ERROR(red_root.size() == lrf.size()); bool has_m = false, has_n = false; for (size_t i : c10::irange(lrf.size())) { if (lrf[i]->isBroadcast() && !rrf[i]->isBroadcast()) { @@ -1598,6 +1506,12 @@ class MatmulPatternMatcher : IterVisitor { } else if (!lrf[i]->isBroadcast() && rrf[i]->isBroadcast()) { has_n = true; } + if (red_root[i]->isReduction()) { + // matmul must be contraction of non-broadcast dimensions + if (!lrf[i]->isIteration() || !rrf[i]->isIteration()) { + return; + } + } } if (!has_m || !has_n) { // This is an ordinary reduction, not a matmul diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index d365bb63440..2c566494822 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -242,34 +242,32 @@ class DataWrapperOpt { } }; -// This struct hold properties of a Mul and Sum pair -// which can possibly be replaced a Mma op. This struct -// can be be created (partially) from a Mma op. -struct MulSumProperties { - // The Mul amd Sum op which can be replaced by a Mma op. - struct MulAndSumOps { - BinaryOp* mop = nullptr; - ReductionOp* redop = nullptr; - }; - - // The inputs/ouputs to the possible Mma Op or the actual Mma op. - struct InputsOutputs { - TensorView* a = nullptr; - TensorView* b = nullptr; - TensorView* out = nullptr; - }; - - // The broadcasts which feed the Mma op/Mul-Sum pair. - struct Broadcasts { - BroadcastOp* bcast_a = nullptr; - BroadcastOp* bcast_b = nullptr; - }; - - MulAndSumOps mulsumops; - InputsOutputs insouts; - Broadcasts bcasts; +//! This represents a single matmul operation, without a prologue or epilogue. +//! Each matmul has two inputs which might not be fusion inputs: A and B. It +//! also has one output, which can be Float or reduced precision. For MatmulOp +//! and LinearOp, the output is the same dtype as the inputs; so output does not +//! 
necessarily correspond to the output of a translated MmaOp and it might not +//! be a fusion output. +struct MatmulPattern { + TensorView* A; + TensorView* B; + // This is not necessarily a Fusion output, but rather is the immediate output + // representing a matmul in the current Fusion. The definition of this tensor + // determines what kind of translation is needed, if any. Possible definition + // Expr types are: MmaOp, ReductionOp (for mul-sum patterns), MatmulOp, and + // LinearOp. + TensorView* output; + + //! If the pattern is not already represented by an MmaOp, for example if + //! there is a MatmulOp instead, this function modifies the fusion to insert + //! an MmaOp. TensorViews A and B are unchanged, but this->output might be + //! updated to reflect the replacement tensor. + void translateToMmaOp(); }; +//! Traverse the fusion to find supported matmul patterns +std::vector findMatmulPatterns(Fusion* fusion); + using MatmulProblemLayoutOpt = DataWrapperOpt; using ProblemIterDomainsOpt = DataWrapperOpt; using RolesMapOpt = DataWrapperOpt; @@ -290,12 +288,11 @@ using DependenciesMap = std::map; //! transposition of inputs in mma instructions, while other (e.g. Turing, //! Ampere) the only supported transposition is TN which means that mma //! instruction first input is transposed, the second input is non-transposed. -NVF_API MatmulProblemLayoutOpt getMmaLayout( - Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props); +NVF_API MatmulProblemLayoutOpt +getMmaLayout(Fusion* fusion, const MatmulPattern& pattern); //! This overloaded version is just a wrapper on the above function, where -//! the mma_utils::MulSumProperties::InputsOutputs is extracted from the fusion. +//! the MatmulPattern is extracted from the fusion. NVF_API MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion); //! Returns wrapped collection of IterDomains that can be used to get @@ -306,15 +303,12 @@ NVF_API MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion); //! 
be gathered. //! TODO: 4th domain must be added for batch gemm support. ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); -ProblemIterDomainsOpt getProblemIterDomains( - const mma_utils::MulSumProperties::InputsOutputs& props); +ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern); //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot //! be gathered. -RolesMapOpt getTensorsRoles( - Fusion* fusion, - const mma_utils::MulSumProperties::InputsOutputs& props); +RolesMapOpt getTensorsRoles(Fusion* fusion, const MatmulPattern& pattern); RolesMapOpt getTensorsRoles(Fusion* fusion); //! Return pair of whether use shared memory epilogue or not and whether to @@ -346,76 +340,6 @@ NVF_API std::pair generateSharedMemoryEpilogueHeuristics( bool smem_b_reuse_guaranteed = false, bool ignore_occupancy_drop = false); -//! Go through the fusion IR to find combinations of mul-sum -//! which can be replaced with a mma op. This class operates -//! in two phases. It can go through the graph and find the mul-sum -//! pairs which can be replaced by a mma op. This phase returns a vector -//! of properties of the mma op (MulSumAsMmaProps) which would replace the -//! the mul-sum pair. It then exposes a function to replace with mma ops. -class CombineMulSum : public IterVisitor { - public: - CombineMulSum(Fusion* fusion) : IterVisitor(), fusion_(fusion) { - generateMulSumCanidates(); - }; - - const std::vector& getMulSumCanidates( - const bool refresh_data = false); - - //! Goes through the fusion to find mul-sum pairs. - //! If user sets the caching flags and properties have been previously - //! computed, then just return cached results. - void generateMulSumCanidates(); - - //! Replaces the candidate mul-sum pairs with mma ops. - //! Please not this will run generateMulSumCandidates again. - void replaceWithMmaOp(); - - //! 
Check if the fusion has a mma-op or a mul-sum pair - //! that can be replaced by a mma op. - bool isValid() { - return is_valid_; - } - - protected: - void handle(ReductionOp* stmt) override; - - private: - Fusion* fusion_; - //! This is the list of mul-sum pairs and the properties - //! of the mma op which can replace it. This is only populated - //! if the mul-sum pair is a valid replacement candidate. - std::vector mul_sum_props_ = {}; - //! This variable tracks if the fusion has a mul-sum pair - //! than can be replaced by a mma op, or has a single mma op. - bool is_valid_ = false; -}; - -//! This represents a single matmul operation, without a prologue or epilogue. -//! Each matmul has two inputs which might not be fusion inputs: A and B. It -//! also has one output, which can be Float or reduced precision. For MatmulOp -//! and LinearOp, the output is the same dtype as the inputs; so output does not -//! necessarily correspond to the output of a translated MmaOp and it might not -//! be a fusion output. -struct MatmulPattern { - TensorView* A; - TensorView* B; - // This is not necessarily a Fusion output, but rather is the immediate output - // representing a matmul in the current Fusion. The definition of this tensor - // determines what kind of translation is needed, if any. Possible definition - // Expr types are: MmaOp, ReductionOp (for mul-sum patterns), MatmulOp, and - // LinearOp. - TensorView* output; - - //! If the pattern is not already represented by an MmaOp, for example if - //! there is a MatmulOp instead, this function modifies the fusion to insert - //! an MmaOp. TensorViews A and B are unchanged, but this->output might be - //! updated to reflect the replacement tensor. - void translateToMmaOp(); -}; - -//! Traverse the fusion to find supported matmul patterns -std::vector findMatmulPatterns(Fusion* fusion); - //! Compute the amount of shared memory we expect to need. The actual amount //! 
allocated will be determined by aliasing (see alias_memory.cpp). This //! function is useful for testing that we provide accurate information to our diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 7752af52914..6b4e99ce5bc 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -67,6 +67,24 @@ class CombineMulSumAsMmaTest : public NVFuserTest { DisableOptionsGuard opt_guard_; }; +void performSubstitution(Fusion* fusion, bool should_not_replace = false) { + EXPECT_TRUE(ir_utils::getOpsOfType(fusion).empty()); + + std::vector patterns = + mma_utils::findMatmulPatterns(fusion); + if (should_not_replace) { + EXPECT_TRUE(patterns.empty()); + return; + } + + ASSERT_FALSE(patterns.empty()); + EXPECT_EQ(patterns.size(), 1); + + patterns.front().translateToMmaOp(); + + ASSERT_FALSE(ir_utils::getOpsOfType(fusion).empty()); +} + // Test checks to see that the combiner can correctly replace // the mul-sum pair with a mma op. 
TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Pass) { @@ -86,12 +104,7 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Pass) { fusion.addOutput(tv3); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); - - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - - ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion); } } @@ -110,10 +123,7 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) { auto tv3 = sum(tv2, {-1}); fusion.addOutput(tv3); - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion, /*should_not_replace=*/true); } // This test checks to see that the mul-sum combiner does not @@ -145,12 +155,7 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2) { auto tv3 = sum(tv2, {-1}); fusion.addOutput(tv3); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); - - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion, /*should_not_replace=*/true); } // As a sanity check we test that after replacing a mul-sum @@ -174,11 +179,8 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Schedule) { auto tv2 = sum(mul(tv0, tv1), {-1}); fusion.addOutput(tv2); - ASSERT_TRUE(ir_utils::getOpsOfType(&fusion).empty()); - nvfuser::mma_utils::CombineMulSum combiner(&fusion); - combiner.replaceWithMmaOp(); - ASSERT_FALSE(ir_utils::getOpsOfType(&fusion).empty()); + performSubstitution(&fusion); MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); From 03e9bdccbfb67652b9771eda6e15ead2ab9a38d3 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 14 May 2024 12:40:59 +0000 Subject: [PATCH 05/52] Add a test --- tests/cpp/test_combine_mul_sum.cpp | 50 ++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git 
a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index f87bbf15423..c3f7e22b670 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -263,4 +263,54 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { } } +TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { + // Keep multiples of 8 to keep vectorizable. + int M = 504, N = 136, K = 248; + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = matmul(tv0, tv1); + + fusion->addOutput(tv2); + + ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + auto tref = at::matmul(t0, t1); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + if (!isOptionDisabled(DisableOption::MatmulExprEval)) { + // Ensure there's a mma op. + // If there's no mma op present, then stop the test. + ASSERT_FALSE(ir_utils::getOpsOfType( + executor_cache.getMostRecentKernelRuntime() + ->executors() + .at(0) + .kernel()) + .empty()); + // Ensure that the matmul scheduler ran. 
+ EXPECT_TRUE( + dynamic_cast( + executor_cache.getMostRecentKernelRuntime() + ->schedulerHeuristics() + ->heuristicsList() + .at(0) + .get()) != nullptr); + + EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented()); + } + + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + } // namespace nvfuser From c0d33f5e3f0ec671b4d109b1dad127a4644dba5d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 14 May 2024 13:08:59 +0000 Subject: [PATCH 06/52] Re-enable NVFUSER_DISABLE=matmul_expr_eval --- csrc/scheduler/expr_eval_sched.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index c600b2f0bea..7482c71892d 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -16,12 +16,18 @@ namespace nvfuser { // Check if the fusion has a single MatmulOp node bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { auto exprs = fusion->exprs(); - if (exprs.size() == 1 && exprs.front()->isA()) { - return true; + if (!isOptionDisabled(DisableOption::MatmulExprEval)) { + if (exprs.size() == 1 && exprs.front()->isA()) { + return true; + } + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), + "Fusion must contain a single expression of type MatmulOp"); + } else { + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), + "Matmul ATen evaluation was disabled by NVFUSER_DISABLE=matmul_expr_eval"); } - scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), - "Fusion must contain a single expression of type MatmulOp"); return false; } From a697b5544f5b8666648034e4a9a1d1bc9b90d997 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 14 May 2024 20:20:29 +0000 Subject: [PATCH 07/52] Add MatmulOp to ir_utils::isTvOp This also adds testing of exact mapping to the node tests (WIP). 
--- csrc/device_lower/utils.cpp | 1 + tests/cpp/test_matmul_aten_evaluation.cpp | 40 +++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index b82496ed623..99b22087589 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -149,6 +149,7 @@ bool isTvOp(const Expr* expr) { WelfordOp, GroupedWelfordOp, LoadStoreOp, + MatmulOp, MmaOp, BroadcastOp, SqueezeOp, diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp index e9c0b190f2c..633d45f0d01 100644 --- a/tests/cpp/test_matmul_aten_evaluation.cpp +++ b/tests/cpp/test_matmul_aten_evaluation.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -409,6 +410,41 @@ TEST_F(MatmulATenEvaluationTest, LinearWithBias) { EXPECT_TRUE(at::allclose(out[0], out_ref)); } +// Check that ID exact mapping works as expected +void checkMatmulOpIdMapping( + Fusion* fusion, + TensorView* A, + TensorView* B, + TensorView* output) { + IdModel id_model(fusion); + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + vg.validateConsistency(); + + const auto checkMapped = [&vg](IterDomain* x, IterDomain* y) -> bool { + if (!vg.hasGroup(x) || !vg.hasGroup(y)) { + return false; + } + const ValGroup& gx = vg.toGroup(x); + const ValGroup& gy = vg.toGroup(y); + return gx.get() == gy.get(); + }; + + if (A->nDims() == 2 && B->nDims() == 2) { + // [M, K] @ [K, N] = [M, N] + ASSERT_EQ(output->nDims(), 2); + EXPECT_TRUE(checkMapped(A->axis(0), output->axis(0))); // M + EXPECT_TRUE(checkMapped(B->axis(1), output->axis(1))); // N + // EXPECT_TRUE(checkMapped(A->axis(1), B->axis(0))); // K + } else if (A->nDims() == 1 && B->nDims() == 1) { + // [K] @ [K] = scalar (1-D dot product gives a 0-dim output) + EXPECT_EQ(output->nDims(), 0); + // EXPECT_TRUE(checkMapped(A->axis(0), B->axis(0))); // K + } else { + std::cout << "Unhandled set of input dimensions" << std::endl; + // EXPECT_TRUE(false); + } +} + 
TEST_P(ATenNodesParametrizedTest, MatmulNodeConcrete) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -423,6 +459,8 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeConcrete) { fusion->addInput(tv1); fusion->addOutput(tv2); + checkMatmulOpIdMapping(fusion.get(), tv0, tv1, tv2); + at::Tensor t0 = at::randn(a_shape, at::kHalf).cuda(); at::Tensor t1 = at::randn(b_shape, at::kHalf).cuda(); at::Tensor out_ref = at::matmul(t0, t1); @@ -447,6 +485,8 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeSymbolic) { fusion->addInput(tv1); fusion->addOutput(tv2); + checkMatmulOpIdMapping(fusion.get(), tv0, tv1, tv2); + at::Tensor t0 = at::randn(a_shape, at::kHalf).cuda(); at::Tensor t1 = at::randn(b_shape, at::kHalf).cuda(); at::Tensor out_ref = at::matmul(t0, t1); From d50e19b58bc5cdac10ed81a93c8141b3a758b54b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 14 May 2024 20:24:32 +0000 Subject: [PATCH 08/52] Big refactor to use IdModel and allocation domain Use IdModel to define ID roles (and hence TV roles). Also use allocation domains to determine whether each operand has K as its inner dimension. 
--- csrc/device_lower/utils.cpp | 1 + csrc/mma_type.h | 2 +- csrc/root_domain_map.cpp | 2 + csrc/scheduler/matmul.cpp | 36 +-- csrc/scheduler/matmul_utils.cpp | 70 +++--- csrc/scheduler/mma_utils.cpp | 343 +++++++++++++++++++++------- csrc/scheduler/mma_utils.h | 26 ++- tests/cpp/test_combine_mul_sum.cpp | 97 +++++--- tests/cpp/test_gpu_tensorcore.cpp | 2 +- tests/cpp/test_matmul_scheduler.cpp | 52 ++--- 10 files changed, 422 insertions(+), 209 deletions(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index b82496ed623..99b22087589 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -149,6 +149,7 @@ bool isTvOp(const Expr* expr) { WelfordOp, GroupedWelfordOp, LoadStoreOp, + MatmulOp, MmaOp, BroadcastOp, SqueezeOp, diff --git a/csrc/mma_type.h b/csrc/mma_type.h index e38619576e2..52c27a63e7e 100644 --- a/csrc/mma_type.h +++ b/csrc/mma_type.h @@ -26,7 +26,7 @@ namespace nvfuser { constexpr std::string_view MATMUL_LOG_PREFIX = "[MATMUL DEBUG] "; //! Named descriptors of domains in matmul -enum class MatmulDomain { M = 0, N, K }; +enum class MatmulDomain { M = 0, N, K, Batch }; //! Named descriptors of TensorView roles in fusion //! 
INPUT_A - a producer of MMA input A diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 4e76e255a55..371eda1dec0 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -161,6 +161,8 @@ std::unordered_map PairwiseRootDomainMap::map( } if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { + std::cout << "MAP " << map_key_id->toString() << " -> " + << map_value_id->toString() << std::endl; dom_map.insert(std::make_pair(map_key_id, map_value_id)); } }; diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 81c289e0c49..4a636a596ff 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -755,37 +755,27 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { NVF_ERROR( patterns.size() == 1, "Only a single matmul pattern can currently be fused"); + std::vector mma_ops; + mma_ops.reserve(patterns.size()); for (mma_utils::MatmulPattern& pattern : patterns) { - pattern.translateToMmaOp(); + mma_ops.push_back(pattern.translateToMmaOp()); } - auto mma_ops = ir_utils::getOpsOfType(fusion); - - NVF_ERROR( - mma_ops.size() == 1, - "scheduleMatmul supports fusion with single mma op in definition, got ", - mma_ops.size()); - - // Skip scheduling if Matmul will be expression evaluated. - if (!isOptionDisabled(DisableOption::MatmulExprEval)) { - NVF_CHECK(fusion->outputs().size() == 1) - fusion->aliasOutputToInput( - fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); - scheduler_debug_utils::log( - __FILE__, - ":", - __LINE__, - ", Matmul output to be computed through expression evaluator. 
Skipping codegen."); - return; - } - - const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion); + IdModel id_model(fusion); + std::unordered_map id_roles = + patterns.front().getDimRoles(id_model); + const auto& roles_map_opt = + mma_utils::getTensorsRoles(fusion, id_model, id_roles); // NOTE: the contents of roles_map have been already validated during // compute-time checks NVF_ERROR(roles_map_opt.isValid(), roles_map_opt.getErrorMsg()); const auto roles_map = roles_map_opt.getData(); + const mma_utils::MatmulProblemLayoutOpt fusion_layout = + mma_utils::getProblemLayout(id_model, id_roles, roles_map); + NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg()); + // Core roles: there can be only one... TV with assigned core role TensorView* a = roles_map.at(MatmulRole::INPUT_A).front(); TensorView* b = roles_map.at(MatmulRole::INPUT_B).front(); @@ -796,8 +786,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { NVF_ERROR( mma_layout_opt.has_value(), "fusion mma op has undefined input layout"); const auto mma_layout = mma_layout_opt.value(); - const auto fusion_layout = mma_utils::getMmaLayout(fusion); - NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg()); const auto& gemm_tile = params.tile_sizes; const bool has_epilogue = !mma->out()->isFusionOutput(); diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index d71055dd149..814ce2187e2 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -15,11 +15,13 @@ // 'SchedulerRuntimeInfo' #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -38,9 +40,8 @@ namespace nvfuser { namespace { -//! Access to the structure should be done with labels defined in -//! MmaOptions::MmaDomains. -using ProblemShape = std::array; +//! Access to the structure should be done with labels defined in MatmulDomain. +using ProblemShape = std::array; //! 
A helper for deciding the type of MMA op for given fusion and problem shape. inline std::optional getMmaOp( @@ -158,34 +159,26 @@ inline bool initCoreHeuristics( } //! A helper for getting problem shape from fusion and runtime info. +//! +//! For a given domain, try to find the size by evaluating the extent of an +//! IterDomain in each group of that domain type. For example, if there are +//! multiple Batch dimensions, we find all ValGroups that are mapped as +//! MatmulDomain::Batch, we evaluate the extent of each, then we multiply those +//! dimensions together to get the overall batch size. ProblemShape getProblemShape( - const mma_utils::MatmulPattern& pattern, + const std::unordered_map& group_to_domain, SchedulerRuntimeInfo& runtime_info) { - const auto mma_output_domains = mma_utils::getProblemIterDomains(pattern); - if (!mma_output_domains.isValid()) { - NVF_ERROR(false, mma_output_domains.getErrorMsg()); - } - - const auto [m, n, k] = mma_output_domains.getData(); - - auto m_extend = runtime_info.expressionEvaluator().evaluate(m->extent()); - auto n_extend = runtime_info.expressionEvaluator().evaluate(n->extent()); - auto k_extend = runtime_info.expressionEvaluator().evaluate(k->extent()); - - if (!(m_extend && n_extend && k_extend)) { + ProblemShape shape{1, 1, 1, 1}; + for (const auto& [g, dom] : group_to_domain) { + NVF_ERROR(!g->empty()); + IterDomain* id = g->front()->as(); + const PolymorphicValue extent = + runtime_info.expressionEvaluator().evaluate(id->extent()); NVF_ERROR( - false, - "Failed to acquire one of problem dimensions, M(", - m_extend.hasValue(), - "), N(", - n_extend.hasValue(), - " K(", - k_extend.hasValue(), - ")"); + extent.hasValue(), "Could not evaluate extent of ", id->toString()); + shape[(size_t)dom] *= extent.as(); } - - return ProblemShape{ - m_extend.as(), n_extend.as(), k_extend.as()}; + return shape; } std::string isMatmulFusionDefinitionSupported( @@ -471,7 +464,7 @@ std::string getMatmulCompileTimeRejectReason(Fusion* 
fusion) { // #2 { const auto input_layout_opt = - mma_utils::getMmaLayout(fusion, patterns.front()); + mma_utils::getProblemLayout(fusion, patterns.front()); if (!input_layout_opt.isValid()) { return input_layout_opt.getErrorMsg(); } @@ -531,7 +524,13 @@ std::shared_ptr getMatmulHeuristics( "Only a single matmul pattern can currently be fused"); mma_utils::MatmulPattern& pattern = patterns.front(); - const auto problem_shape = getProblemShape(pattern, runtime_info); + // IdModel is used to analyze problem shape & layout + IdModel id_model(fusion); + + const std::unordered_map id_roles = + pattern.getDimRoles(id_model); + + const auto problem_shape = getProblemShape(id_roles, runtime_info); const auto device_prop = at::cuda::getCurrentDeviceProperties(); const auto mma_op = @@ -540,7 +539,8 @@ std::shared_ptr getMatmulHeuristics( mma_op.has_value(), "Failed to determine a MMA op for given problem."); params->mma_macro = mma_op.value(); - const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion, pattern); + const auto& roles_map_opt = + mma_utils::getTensorsRoles(fusion, id_model, id_roles); NVF_ERROR(roles_map_opt.isValid(), "Tensor roles map in mma is not valid."); const auto roles_map = roles_map_opt.getData(); @@ -549,17 +549,17 @@ std::shared_ptr getMatmulHeuristics( if (matmul_heuristic_plugin::hasPlugin()) { const mma_utils::MatmulProblemLayoutOpt layout_opt = - mma_utils::getMmaLayout(fusion, pattern); + mma_utils::getProblemLayout(id_model, id_roles, roles_map); NVF_ERROR(layout_opt.isValid(), layout_opt.getErrorMsg()); const MmaLayout layout = layout_opt.getData(); // Fill in proper values using plugin matmul_heuristic_plugin::updateMatmulParams( *params, - /*M=*/problem_shape[0], - /*N=*/problem_shape[1], - /*K=*/problem_shape[2], - /*batch_size=*/1, // TODO: extract actual batch size + problem_shape[(size_t)MatmulDomain::M], + problem_shape[(size_t)MatmulDomain::N], + problem_shape[(size_t)MatmulDomain::K], + 
problem_shape[(size_t)MatmulDomain::Batch], layout, roles_map); } else { diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 6f27e406db4..2fa292ba7c6 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -9,11 +9,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include "mma_type.h" namespace nvfuser { @@ -1100,99 +1102,87 @@ ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion) { static_cast(mma_exprs.front()->out())}); } -MatmulProblemLayoutOpt getMmaLayout( - Fusion* fusion, - const MatmulPattern& pattern) { - ComputeAtMap ca_map(fusion); - const auto mma_input_candidates = - ir_utils::filterByType(fusion->inputs()).vector(); - if (mma_input_candidates.empty()) { - return {"Failed to find any TV that is fusion input"}; - } - - const auto mma_output_domains = getProblemIterDomains(pattern); - if (!mma_output_domains.isValid()) { - return mma_output_domains.getErrorMsg(); +MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion) { + const std::vector patterns = findMatmulPatterns(fusion); + if (patterns.size() != 1) { + std::stringstream ss; + ss << "Invalid number of MmaOp instances in fusion, expected 1, got " + << patterns.size(); + return ss.str(); } + return getProblemLayout(fusion, patterns.front()); +} - const auto domains_data = mma_output_domains.getData(); - const auto m = domains_data[(size_t)MatmulDomain::M]; - const auto n = domains_data[(size_t)MatmulDomain::N]; - const auto k = domains_data[(size_t)MatmulDomain::K]; - - DependenciesMap deps_map; - resolveTvToMatmulDomainsMapping( - deps_map, mma_input_candidates, m, n, k, ca_map); - - bool mk_found = false; - bool km_found = false; - bool nk_found = false; - bool kn_found = false; - const static DomainsDesc mk_desc = {MatmulDomain::M, MatmulDomain::K}; - const static DomainsDesc km_desc = {MatmulDomain::K, MatmulDomain::M}; - const static DomainsDesc nk_desc = {MatmulDomain::N, 
MatmulDomain::K}; - const static DomainsDesc kn_desc = {MatmulDomain::K, MatmulDomain::N}; - - for (const auto& item : deps_map) { - if (item.second == mk_desc) { - if (mk_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., M, ..., K, ...] iter domains"}; - } - mk_found = true; +MatmulProblemLayoutOpt getProblemLayout( + const IdModel& id_model, + const std::unordered_map& group_to_domain, + const RolesMap& roles_map) { + // Assumes the exact graph has already been built, since we've been provided + // group_to_domain + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + + using MatmulDomainOpt = DataWrapperOpt; + const auto innerDomain = [&roles_map, &group_to_domain, &exact_graph]( + MatmulRole role) -> MatmulDomainOpt { + const auto role_it = roles_map.find(role); + if (role_it == roles_map.end()) { + return {"Could not find role in roles_map"}; } - if (item.second == km_desc) { - if (km_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., K, ..., M, ...] iter domains"}; + std::optional group_inner_dom = std::nullopt; + for (TensorView* tv : role_it->second) { + IterDomain* inner_id = + TensorDomain::noReductions(tv->getMaybeAllocationDomain()).back(); + const ValGroup& g = exact_graph.toGroup(inner_id); + auto g_it = group_to_domain.find(g); + if (g_it == group_to_domain.end()) { + return {"Inner domain of tensor was not mapped to a MatmulDomain"}; } - km_found = true; - } - if (item.second == nk_desc) { - if (nk_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., N, ..., K, ...] 
iter domains"}; + if (!group_inner_dom.has_value()) { + group_inner_dom = g_it->second; + } else if (group_inner_dom.value() != g_it->second) { + return {"Group contains multiple inner dimension domains"}; } - nk_found = true; } - if (item.second == kn_desc) { - if (kn_found) { - return { - "Failed to find MMA input, more than one fusion input has [..., K, ..., N, ...] iter domains"}; - } - kn_found = true; + if (!group_inner_dom.has_value()) { + return {"No tensor found in role"}; } + MatmulDomain dom = group_inner_dom.value(); + return MatmulDomainOpt(std::move(dom)); + }; + MatmulDomainOpt a_inner_dom = innerDomain(MatmulRole::INPUT_A); + if (!a_inner_dom.isValid()) { + return a_inner_dom.getErrorMsg(); + } + MatmulDomainOpt b_inner_dom = innerDomain(MatmulRole::INPUT_B); + if (!b_inner_dom.isValid()) { + return b_inner_dom.getErrorMsg(); } - if ((mk_found && kn_found) && !(km_found || nk_found)) { + bool kinner_a = a_inner_dom.getData() == MatmulDomain::K; + bool kinner_b = b_inner_dom.getData() == MatmulDomain::K; + + if (kinner_a && kinner_b) { + return MmaLayout::TN; + } else if (kinner_a && !kinner_b) { return MmaLayout::TT; - } - if ((km_found && kn_found) && !(mk_found || nk_found)) { + } else if (!kinner_a && !kinner_b) { return MmaLayout::NT; - } - if ((mk_found && nk_found) && !(km_found || kn_found)) { - return MmaLayout::TN; - } - if ((km_found && nk_found) && !(mk_found || kn_found)) { + } else if (!kinner_a && kinner_b) { return MmaLayout::NN; } - - return {"Failed to decide fusion inputs' data layout."}; + NVF_ERROR(false, "Reached unreachable section of getProblemLayout"); } -MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion) { - auto mma_exprs = ir_utils::getOpsOfType(fusion); - if (mma_exprs.size() != 1) { - std::stringstream ss; - ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << mma_exprs.size(); - return ss.str(); +MatmulProblemLayoutOpt getProblemLayout( + Fusion* fusion, + const MatmulPattern& pattern) { + IdModel 
id_model(fusion); + const auto id_roles = pattern.getDimRoles(id_model); + const auto roles_map_opt = getTensorsRoles(fusion, id_model, id_roles); + if (!roles_map_opt.isValid()) { + return {roles_map_opt.getErrorMsg()}; } - return getMmaLayout( - fusion, - {static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}); + return getProblemLayout(id_model, id_roles, roles_map_opt.getData()); } RolesMapOpt getTensorsRoles(Fusion* fusion, const MatmulPattern& pattern) { @@ -1319,6 +1309,127 @@ RolesMapOpt getTensorsRoles(Fusion* fusion, const MatmulPattern& pattern) { return roles_map; } +RolesMapOpt getTensorsRoles( + Fusion* fusion, + const IdModel& id_model, + const std::unordered_map& group_to_domain) { + const auto mma_input_candidates = + ir_utils::filterByType(fusion->inputs()).vector(); + if (mma_input_candidates.empty()) { + return {"Failed to find any TV that is fusion input"}; + } + const auto mma_output_candidates = + ir_utils::filterByType(fusion->outputs()).vector(); + if (mma_output_candidates.empty()) { + return {"Failed to find any TV that is fusion output"}; + } + + RolesMap roles_map; + + // Assumes the exact graph has already been built, since we've been provided + // group_to_domain + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + + for (TensorView* tv : mma_input_candidates) { + bool has_m = false, has_n = false, has_k = false, has_unmapped = false; + for (IterDomain* id : + TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { + const ValGroup& g = exact_graph.toGroup(id); + auto it = group_to_domain.find(g); + if (it == group_to_domain.end()) { + // tv has an unmapped dimension + has_unmapped = true; + continue; + } + has_m |= it->second == MatmulDomain::M; + has_n |= it->second == MatmulDomain::N; + has_k |= it->second == MatmulDomain::K; + } + if (has_unmapped) { + // Don't map TVs to roles if they have unmapped dims + continue; + } + if (has_m && has_k 
&& !has_n) { + roles_map[MatmulRole::INPUT_A].push_back(tv); + continue; + } + if (has_n && has_k && !has_m) { + roles_map[MatmulRole::INPUT_B].push_back(tv); + continue; + } + // Bias vectors are assigned to INPUT_C role + if (!has_k) { + roles_map[MatmulRole::INPUT_C].push_back(tv); + continue; + } + } + + std::vector storage; + for (TensorView* tv : mma_input_candidates) { + bool has_m = false, has_n = false, has_k = false, has_unmapped = false; + for (IterDomain* id : + TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { + const ValGroup& g = exact_graph.toGroup(id); + auto it = group_to_domain.find(g); + if (it == group_to_domain.end()) { + // tv has an unmapped dimension + has_unmapped = true; + continue; + } + has_m |= it->second == MatmulDomain::M; + has_n |= it->second == MatmulDomain::N; + has_k |= it->second == MatmulDomain::K; + } + // NOTE: depending on fusion definition k domain may appear in the output: + // - for mma_output == fusion output k domain is present + // - for mma_output != fusion output (fusion with epilogue) k domain + // is not present + if (has_k || has_unmapped) { + // Don't map TVs to output roles if they have unmapped dims, or if they + // have K dimension + continue; + } + + // NOTE: the core fusion output tensors are the ones with m and n + // domains + if (has_m && has_n) { + storage.push_back(tv); + } + } + + // NOTE: sort output roles in descending order by uses() size, and + // if equal then by name() to ensure the stable ordering of tensor + // views in collections assigned to the supported roles + std::sort(storage.begin(), storage.end(), [](TensorView* a, TensorView* b) { + return (a->uses().size() == b->uses().size()) + ? 
(a->name() < b->name()) + : (a->uses().size() > b->uses().size()); + }); + + if (!storage.empty()) { + // NOTE: currently, we pick as a reference tensor one with `m` and `n` + // IterDomains and the most uses + auto pos = storage.begin(); + roles_map[MatmulRole::OUTPUT_D].push_back(*pos); + for (++pos; pos != storage.end(); ++pos) { + roles_map[MatmulRole::OUTPUT_AUX].push_back(*pos); + } + } + + for (auto& [role, tvs] : roles_map) { + // NOTE: sort input roles in descending order by uses() size, and + // if equal then by name() to ensure the stable ordering of tensor + // views in collections assigned to the supported roles + std::sort(tvs.begin(), tvs.end(), [](TensorView* a, TensorView* b) { + return (a->uses().size() == b->uses().size()) + ? (a->name() < b->name()) + : (a->uses().size() > b->uses().size()); + }); + } + + return roles_map; +} + RolesMapOpt getTensorsRoles(Fusion* fusion) { auto mma_exprs = ir_utils::getOpsOfType(fusion); if (mma_exprs.size() != 1) { @@ -1535,10 +1646,10 @@ std::vector findMatmulPatterns(Fusion* fusion) { return MatmulPatternMatcher::run(fusion); } -void MatmulPattern::translateToMmaOp() { - if (dynamic_cast(output->definition())) { +MmaOp* MatmulPattern::translateToMmaOp() { + if (auto mma_op = dynamic_cast(output->definition())) { // No translation needed - return; + return mma_op; } else if (auto mop = dynamic_cast(output->definition())) { // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we // must transpose B then broadcast both operands before creating the final @@ -1551,6 +1662,7 @@ void MatmulPattern::translateToMmaOp() { TensorView* Abcast = unsqueeze(A, -2); TensorView* Bbcast = unsqueeze(Btrans, -3); TensorView* fms = fusedMultiplySum(Abcast, Bbcast, {-1}); + auto mma_op = fms->definition()->as(); // Update operands to keep the pattern minimal A = Abcast; B = Bbcast; @@ -1563,11 +1675,80 @@ void MatmulPattern::translateToMmaOp() { // No cast needed, for example the inputs might be Float 
ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); } + return mma_op; } else if (auto rop = dynamic_cast(output->definition())) { Val* init = IrBuilder::create(0.0, output->dtype()); // This replaces the mul and sum by overwriting output->definition() - IrBuilder::create(output, A, B, init); + return IrBuilder::create(output, A, B, init); } + NVF_ERROR( + false, + "Could not translate matmul pattern with output ", + output->toString(), + " to MmaOp"); +} + +std::unordered_map MatmulPattern::getDimRoles( + IdModel& id_model) const { + id_model.maybeBuildGraph(IdMappingMode::EXACT); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + + // There are four types of ValGroup involved in a MatmulPattern: M, N, K, and + // Batch. These are enumerated in the MatmulDomain enum class. They are + // defined by their membership as follows: + // M: present in A and output, but not B + // N: present in B and output, but not A + // K: present in A and B, but not output + // Batch: present in all A, B, and Batch + // If there are other membership patterns, for example a ValGroup present in + // only A, then we should raise an exception here. 
+ + // Indicates whether a ValGroup is present in A (bit 0), B (bit 1), or output + // (bit 2) + using ValGroupPresence = std::bitset<3>; + + std::unordered_map present_flags; + const auto recordPresence = [&exact_graph, &present_flags]( + TensorView* tv, size_t tensor_num) { + std::cout << "recordPresence " << tv->toString() << std::endl; + for (IterDomain* id : tv->getMaybeRFactorDomain()) { + if (id->isReduction() || id->isBroadcast()) { + // ignore reductions and broadcasts since they don't exact map to + // problem dims + continue; + } + std::cout << " id=" << id->toString() << std::endl; + const ValGroup& g = exact_graph.toGroup(id); + std::cout << " updating bitset " << present_flags[g] << " to "; + present_flags[g].set(tensor_num); + std::cout << present_flags[g] << std::endl; + } + }; + A->fusion()->printMath(); + std::cout << id_model.toString() << std::endl; + recordPresence(A, 0); + recordPresence(B, 1); + recordPresence(output, 2); + + std::unordered_map dim_to_domain; + for (const auto& [g, flags] : present_flags) { + if (flags.all()) { + dim_to_domain[g] = MatmulDomain::Batch; + } else if (flags.test(0) && flags.test(1)) { + dim_to_domain[g] = MatmulDomain::K; + } else if (flags.test(0) && flags.test(2)) { + dim_to_domain[g] = MatmulDomain::M; + } else if (flags.test(1) && flags.test(2)) { + dim_to_domain[g] = MatmulDomain::N; + } else { + NVF_ERROR( + false, + "IterDomain ValGroup should be present in at least two of A, B, and output. flags: ", + flags); + } + } + + return dim_to_domain; } } // namespace mma_utils diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 2c566494822..841cd7549d0 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -9,8 +9,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -262,7 +264,14 @@ struct MatmulPattern { //! there is a MatmulOp instead, this function modifies the fusion to insert //! an MmaOp. 
TensorViews A and B are unchanged, but this->output might be //! updated to reflect the replacement tensor. - void translateToMmaOp(); + MmaOp* translateToMmaOp(); + + //! Given an IdModel, map groups of IterDomains to dimension roles + //! (MatmulDomain). Note that ValGroup is a shared_ptr to a + //! VectorOfUniqueEntries. We copy these as keys so that the returned + //! object can safely outlive id_model. + std::unordered_map getDimRoles( + IdModel& id_model) const; }; //! Traverse the fusion to find supported matmul patterns @@ -289,11 +298,18 @@ using DependenciesMap = std::map; //! Ampere) the only supported transposition is TN which means that mma //! instruction first input is transposed, the second input is non-transposed. NVF_API MatmulProblemLayoutOpt -getMmaLayout(Fusion* fusion, const MatmulPattern& pattern); +getProblemLayout(Fusion* fusion, const MatmulPattern& pattern); //! This overloaded version is just a wrapper on the above function, where //! the MatmulPattern is extracted from the fusion. -NVF_API MatmulProblemLayoutOpt getMmaLayout(Fusion* fusion); +NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); + +//! Determine the problem layout based on allocation domain of inputs. This is +//! called by the above overloads. +NVF_API MatmulProblemLayoutOpt getProblemLayout( + const IdModel& id_model, + const std::unordered_map& group_to_domain, + const RolesMap& roles_map); //! Returns wrapped collection of IterDomains that can be used to get //! problem shape with runtime info. @@ -309,6 +325,10 @@ ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern); //! An error message is stored in retruned object if valid data cannot //! be gathered. RolesMapOpt getTensorsRoles(Fusion* fusion, const MatmulPattern& pattern); +RolesMapOpt getTensorsRoles( + Fusion* fusion, + const IdModel& id_model, + const std::unordered_map& group_to_domain); RolesMapOpt getTensorsRoles(Fusion* fusion); //! 
Return pair of whether use shared memory epilogue or not and whether to diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index c3f7e22b670..d24a3dee59a 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -264,53 +264,74 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { } TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { - // Keep multiples of 8 to keep vectorizable. - int M = 504, N = 136, K = 248; - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); + const auto run = [&](bool expect_aten_eval) { + int M = 504, N = 136, K = 248; + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); - fusion->addInput(tv0); - fusion->addInput(tv1); - auto tv2 = matmul(tv0, tv1); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); - fusion->addOutput(tv2); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = matmul(tv0, tv1); - ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); + fusion->addOutput(tv2); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); - auto tref = at::matmul(t0, t1); + ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); - FusionExecutorCache executor_cache(std::move(fusion)); - auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + auto tref = at::matmul(t0, t1); - if (!isOptionDisabled(DisableOption::MatmulExprEval)) { - // Ensure there's a mma op. - // If there's no mma op present, then stop the test. 
- ASSERT_FALSE(ir_utils::getOpsOfType( - executor_cache.getMostRecentKernelRuntime() - ->executors() - .at(0) - .kernel()) - .empty()); - // Ensure that the matmul scheduler ran. - EXPECT_TRUE( - dynamic_cast( - executor_cache.getMostRecentKernelRuntime() - ->schedulerHeuristics() - ->heuristicsList() - .at(0) - .get()) != nullptr); + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented()); - } - testValidate( - executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); + if (expect_aten_eval) { + // Ensure that the ExprEval scheduler ran. + EXPECT_EQ( + executor_cache.getMostRecentKernelRuntime() + ->schedulerHeuristics() + ->heuristicsList() + .front() + ->heuristic(), + ScheduleHeuristic::ExprEval); + } else { + // Ensure that the matmul scheduler ran. + EXPECT_EQ( + executor_cache.getMostRecentKernelRuntime() + ->schedulerHeuristics() + ->heuristicsList() + .front() + ->heuristic(), + ScheduleHeuristic::Matmul); + // Ensure there's a mma op. + // If there's no mma op present, then stop the test. 
+ ASSERT_FALSE(ir_utils::getOpsOfType( + executor_cache.getMostRecentKernelRuntime() + ->executors() + .at(0) + .kernel()) + .empty()); + } + + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); + }; + // Run the test with and without matmul_expr_eval + { + DisableOptionsGuard dog; + DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); + run(/*expect_aten_eval=*/true); + } + { + DisableOptionsGuard dog; + DisableOptionsGuard::getCurOptions().set(DisableOption::MatmulExprEval); + run(/*expect_aten_eval=*/false); + } } } // namespace nvfuser diff --git a/tests/cpp/test_gpu_tensorcore.cpp b/tests/cpp/test_gpu_tensorcore.cpp index 44907f102c6..24dff7fe504 100644 --- a/tests/cpp/test_gpu_tensorcore.cpp +++ b/tests/cpp/test_gpu_tensorcore.cpp @@ -3160,7 +3160,7 @@ TEST_F(GPUTTensorCoreTest, MisalignedVectorization) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index ea49711aa67..63d15e21e3f 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -188,7 +188,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBias) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -289,7 +289,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueRelu) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout 
= mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -396,7 +396,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasRelu) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -499,7 +499,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueReluAux) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -614,7 +614,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasReluAux) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -717,7 +717,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueGelu) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -813,7 +813,7 @@ TEST_P(PrecisionParametrizedTest, 
EpilogueGeluAux) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -922,7 +922,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasGelu) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1039,7 +1039,7 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasGeluAux) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1163,7 +1163,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1226,7 +1226,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulRelaxedCheck) { .value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), 
"failed to get decide matmul layout through fusion definition"); @@ -1285,7 +1285,7 @@ TEST_F(MatmulSchedulerTest, BasicMatmulInputShuffledTT) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must be always TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1342,7 +1342,7 @@ TEST_F(MatmulSchedulerTest, EpilogueOutputCast) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1405,7 +1405,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlpha) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1470,7 +1470,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaOutputCast) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1544,7 +1544,7 @@ TEST_F(MatmulSchedulerTest, EpilogueBeta) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto 
fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1626,7 +1626,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1714,7 +1714,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1806,7 +1806,7 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaBias) { ir_utils::getOpsOfType(fusion.get()).front()->layout().value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1889,7 +1889,7 @@ TEST_F(MatmulSchedulerTest, StridedBatch) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -1975,7 +1975,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaBeta) { .value(), 
"the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2074,7 +2074,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaSingleBeta) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2161,7 +2161,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueBias) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2241,7 +2241,7 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueSingleBias) { .value(), "the MmaOp layout of Ampere MMA must always be TN"); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2310,7 +2310,7 @@ TEST_F(MatmulSchedulerTest, MisalignedVectorization) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); @@ -2501,7 +2501,7 @@ TEST_F(MatmulSchedulerTest, StridedInputs) { fusion->addOutput(tv2); - const auto fusion_layout = mma_utils::getMmaLayout(fusion.get()); + const auto fusion_layout = 
mma_utils::getProblemLayout(fusion.get()); NVF_CHECK( fusion_layout.isValid(), "failed to get decide matmul layout through fusion definition"); From 04522b82a8e0daa88b49608ade08e411f8c53653 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 15 May 2024 00:33:01 +0000 Subject: [PATCH 09/52] Add IterType::Reduction domain for K dim in output --- csrc/ops/composite.cpp | 18 +++++++++++------- csrc/ops/utils.cpp | 19 ++++++++++++------- csrc/root_domain_map.cpp | 5 +++-- tests/cpp/test_matmul_aten_evaluation.cpp | 20 +++++++++++++++----- 4 files changed, 41 insertions(+), 21 deletions(-) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index a151dbbb566..6180b7818fb 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -276,13 +276,13 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { auto ndims_b = orig_domain_b.size(); // Matmul output size is same as the higher dimensional input size if both A/B - // > 1D. - auto ndims_out = std::max(ndims_a, ndims_b); + // > 1D, but with 1 additional IterType::Reduction axis rK. + auto ndims_out = std::max(ndims_a, ndims_b) + 1; if (std::min(ndims_a, ndims_b) == 1) { - // If one of the inputs is 1D, the output size is 1 less than the higher - // dimensional input size, since either M/N axis will be missing in the - // output. For eg: [M, K] x [K] -> [M] - ndims_out = std::max(ndims_a, ndims_b) - 1; + // If one of the inputs is 1D, the output size is the same as the higher + // dimensional input size, since we will include a Reduction axis for K in + // the output. 
For example: [iM, iK] x [iK] -> [iM, rK] + ndims_out = std::max(ndims_a, ndims_b); } std::vector out_domain(ndims_out, nullptr); @@ -292,7 +292,7 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { const std::vector& mapping_b = ops::mapMatmulOpIterDomains( orig_domain_b, MatmulRole::INPUT_B, ndims_out); - for (auto idx : c10::irange(ndims_out)) { + for (auto idx : c10::irange(ndims_out - 1)) { std::vector input_ids; input_ids.reserve(2); if (mapping_a[idx] != nullptr) { @@ -304,6 +304,10 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { out_domain[idx] = ops::newOutputIterDomain(input_ids); } + out_domain[ndims_out - 1] = IterDomainBuilder(mapping_a.back()) + .iter_type(IterType::Reduction) + .build(); + TensorDomain* td = IrBuilder::create( out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)); diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 50d57c9179d..d26c33062a2 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -190,29 +190,34 @@ std::vector mapMatmulOpIterDomains( std::vector mapping(out_size, nullptr); auto inp_size = (int64_t)input_domain.size(); + // Input A to matmul: {*, M, K} + // Input B to matmul: {*, K, N} + auto kpos = input_role == MatmulRole::INPUT_A ? inp_size - 1 : inp_size - 2; + if (inp_size == 1) { // Only reduction axis {K} + mapping[out_size - 1] = input_domain[0]; return mapping; } - // Input A to matmul: {*, M, K} - // Input B to matmul: {*, K, N} - auto kpos = input_role == MatmulRole::INPUT_A ? inp_size - 1 : inp_size - 2; - // If A/B is 1D, out_size < inp_size. - for (auto out_idx = (int64_t)out_size - 1, inp_idx = inp_size - 1; + // Last position is a reduction dimension mapping to K + mapping[out_size - 1] = input_domain.at(kpos); + + for (auto out_idx = (int64_t)out_size - 2, inp_idx = inp_size - 1; inp_idx >= 0; inp_idx--) { if (inp_idx != kpos) { mapping[out_idx] = input_domain[inp_idx]; out_idx--; } - // Consider [M, K] x [K]: [M]. 
Since out_size < inp_size, + // Consider [iM, iK] x [iK]: [iM, rK]. Since out_size < inp_size, // input A and output are not right-aligned. In this case, the output index // pointer should not be moved when the reduction axis is encountered. - else if (inp_size <= (int64_t)out_size) { + else if (inp_size <= (int64_t)out_size - 1) { out_idx--; } } + return mapping; } diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 4e76e255a55..0f6e01ded31 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -168,8 +168,7 @@ std::unordered_map PairwiseRootDomainMap::map( // For MatmulOp, use the corresponding mapped input iterdomains. if (MatmulOp* op = dynamic_cast(consumer_tv_->definition())) { // Check if the producer is lhs/rhs input - MatmulRole input_role = - producer->sameAs(op->inA()->as()->domain()) + MatmulRole input_role = producer_tv_->sameAs(op->inA()) ? MatmulRole::INPUT_A : MatmulRole::INPUT_B; auto out_size = consumer_root.size(); @@ -185,6 +184,8 @@ std::unordered_map PairwiseRootDomainMap::map( const std::vector& aligned_producer_ids = ops::mapMatmulOpIterDomains(producer_root, input_role, out_size); + NVF_ERROR(aligned_producer_ids.size() == consumer_root.size()); + for (auto inx : c10::irange(out_size)) { IterDomain* producer_id = aligned_producer_ids.at(inx); IterDomain* consumer_id = consumer_root.at(inx); diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp index 633d45f0d01..24da5df2e8a 100644 --- a/tests/cpp/test_matmul_aten_evaluation.cpp +++ b/tests/cpp/test_matmul_aten_evaluation.cpp @@ -430,15 +430,25 @@ void checkMatmulOpIdMapping( }; if (A->nDims() == 2 && B->nDims() == 2) { - // [M, K] @ [K, N] = [M, N] - ASSERT_EQ(output->nDims(), 2); + // [iM, iK] @ [iK, iN] = [iM, iN, rK] + ASSERT_EQ(output->nDims(), 3); EXPECT_TRUE(checkMapped(A->axis(0), output->axis(0))); // M EXPECT_TRUE(checkMapped(B->axis(1), output->axis(1))); // N - // EXPECT_TRUE(checkMapped(A->axis(1), 
B->axis(0))); // K + EXPECT_TRUE(checkMapped(A->axis(1), B->axis(0))); // K + EXPECT_TRUE(checkMapped(A->axis(1), output->axis(2))); // K + } else if (A->nDims() == 2 && B->nDims() == 1) { + // [iM, iK] @ [iK] = [iM, rK] + ASSERT_EQ(output->nDims(), 2); + EXPECT_TRUE(checkMapped(A->axis(0), output->axis(0))); // M + EXPECT_TRUE(checkMapped(B->axis(0), output->axis(1))); // N + EXPECT_TRUE(checkMapped(A->axis(1), B->axis(0))); // K + EXPECT_TRUE(checkMapped(A->axis(1), output->axis(1))); // K } else if (A->nDims() == 1 && B->nDims() == 1) { - // [M, K] @ [K, N] = [M, N] + // [K] @ [K] = [] + // Note there is no IterType::Reduction dim in this case because we + // translate to a mul+sum+cast EXPECT_EQ(output->nDims(), 0); - // EXPECT_TRUE(checkMapped(A->axis(0), B->axis(0))); // K + EXPECT_TRUE(checkMapped(A->axis(0), B->axis(0))); // K } else { std::cout << "Unhandled set of input dimensions" << std::endl; // EXPECT_TRUE(false); From 7b042865ee095bf6fcd7e819f6336ef5d61161b2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 15 May 2024 12:29:55 +0000 Subject: [PATCH 10/52] Translate bcast K as simple product --- csrc/ops/composite.cpp | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 6180b7818fb..3c14cca87eb 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -335,12 +335,44 @@ TensorView* matmul(TensorView* tv_a, TensorView* tv_b) { " and ", tv_b->dtype()); + // Check for K=1 i.e. reduction of broadcast. In these cases we don't need a + // matmul so we translate it to a multiplication+cast + auto b_k_axis = tv_b->nDims() == 1 ? 
-1 : -2; + NVF_CHECK( + tv_a->axis(-1)->isBroadcast() == tv_b->axis(b_k_axis)->isBroadcast(), + "K dimension must be broadcast in both operands or none"); + if (tv_a->axis(-1)->isBroadcast()) { + TensorView* float_result = nullptr; + if (tv_a->nDims() == 1 && tv_b->nDims() == 1) { + // [1] @ [1] = [] + float_result = + mul(squeeze(tv_a, std::vector{0}), + squeeze(tv_b, std::vector{0})); + } else if (tv_a->nDims() == 1) { + // [1] @ [..., 1, N] = [..., N] + float_result = mul(tv_a, squeeze(tv_b, std::vector{-2})); + } else if (tv_b->nDims() == 1) { + // [..., M, 1] @ [1] = [..., M] + float_result = mul(squeeze(tv_a, std::vector{-1}), tv_b); + } else { + float_result = mul(tv_a, tv_b); + } + return maybeCastOp(tv_a->dtype(), float_result); + } + if (tv_a->nDims() == 1 && tv_b->nDims() == 1) { // Return the dot product instead of creating the MatmulOp. // Cast back the output if needed since torch.matmul maintains input dtype. return maybeCastOp(tv_a->dtype(), sum(mul(tv_a, tv_b), {0})); } + if (tv_b->nDims() > 1 && tv_b->axis(-2)->isBroadcast()) { + NVF_ERROR( + tv_a->axis(-1)->isBroadcast(), + "Mismatched Broadcast in K dimension of operands"); + // K dimension is broadcast so this is an outer product + } + // For all other cases, create a new MatmulOp TensorView* out = newForMatmul(tv_a, tv_b); IrBuilder::create(out, tv_a, tv_b); From dcaf349a89a321ee01a15e1e6cf42e2f09937ff0 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 15 May 2024 12:30:15 +0000 Subject: [PATCH 11/52] Finish testing all combinations of mappings Tests pass! 
--- tests/cpp/test_matmul_aten_evaluation.cpp | 96 ++++++++++++++++++----- 1 file changed, 76 insertions(+), 20 deletions(-) diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp index 24da5df2e8a..94da409b504 100644 --- a/tests/cpp/test_matmul_aten_evaluation.cpp +++ b/tests/cpp/test_matmul_aten_evaluation.cpp @@ -429,29 +429,74 @@ void checkMatmulOpIdMapping( return gx.get() == gy.get(); }; - if (A->nDims() == 2 && B->nDims() == 2) { - // [iM, iK] @ [iK, iN] = [iM, iN, rK] - ASSERT_EQ(output->nDims(), 3); - EXPECT_TRUE(checkMapped(A->axis(0), output->axis(0))); // M - EXPECT_TRUE(checkMapped(B->axis(1), output->axis(1))); // N - EXPECT_TRUE(checkMapped(A->axis(1), B->axis(0))); // K - EXPECT_TRUE(checkMapped(A->axis(1), output->axis(2))); // K - } else if (A->nDims() == 2 && B->nDims() == 1) { - // [iM, iK] @ [iK] = [iM, rK] - ASSERT_EQ(output->nDims(), 2); - EXPECT_TRUE(checkMapped(A->axis(0), output->axis(0))); // M - EXPECT_TRUE(checkMapped(B->axis(0), output->axis(1))); // N - EXPECT_TRUE(checkMapped(A->axis(1), B->axis(0))); // K - EXPECT_TRUE(checkMapped(A->axis(1), output->axis(1))); // K - } else if (A->nDims() == 1 && B->nDims() == 1) { + // If K is Broadcast then we will not have a reduction dim + bool k_bcast = A->axis(-1)->isBroadcast(); + int64_t red_dims = k_bcast ? 
0 : 1; + + if (A->nDims() == 1 && B->nDims() == 1) { // [K] @ [K] = [] - // Note there is no IterType::Reduction dim in this case because we + // Note there is no IterType::Reduction dim ever in this case because we // translate to a mul+sum+cast EXPECT_EQ(output->nDims(), 0); - EXPECT_TRUE(checkMapped(A->axis(0), B->axis(0))); // K + // When K is Broadcast, we squeeze then multiply then cast instead + if (!k_bcast) { + EXPECT_TRUE(checkMapped(A->axis(0), B->axis(0))); // K + } + } else if (A->nDims() > 1 && B->nDims() == 1) { + // [..., iM, iK] @ [iK] = [..., iM, rK] + ASSERT_EQ(output->nDims(), A->nDims() + red_dims - 1); + EXPECT_TRUE(checkMapped(A->axis(-2), output->axis(-1 - red_dims))); // M + if (!k_bcast) { + EXPECT_TRUE(checkMapped(A->axis(-1), B->axis(0))); // K + EXPECT_TRUE(checkMapped(A->axis(-1), output->axis(-1))); // K + } + // Check that batch dims are mapped + for (int64_t i : c10::irange(output->nDims() - red_dims - 1)) { + if (!A->axis(i)->isBroadcast()) { + EXPECT_TRUE(checkMapped(A->axis(i), output->axis(i))); + } + } + } else if (A->nDims() == 1 && B->nDims() > 1) { + // [iK] @ [..., iK, iN] = [..., iN, rK] + ASSERT_EQ(output->nDims(), B->nDims() + red_dims - 1); + EXPECT_TRUE(checkMapped(B->axis(-1), output->axis(-1 - red_dims))); // N + if (!k_bcast) { + EXPECT_TRUE(checkMapped(A->axis(0), B->axis(-2))); // K + EXPECT_TRUE(checkMapped(A->axis(0), output->axis(-1))); // K + } + // Check that batch dims are mapped + for (int64_t i : c10::irange(output->nDims() - red_dims - 1)) { + if (!B->axis(i)->isBroadcast()) { + EXPECT_TRUE(checkMapped(B->axis(i), output->axis(i))); + } + } + } else if (A->nDims() > 1 && B->nDims() > 1) { + // [..., iM, iK] @ [..., iK, iN] = [..., iM, iN, rK] + ASSERT_EQ(output->nDims(), std::max(A->nDims(), B->nDims()) + red_dims); + EXPECT_TRUE(checkMapped(A->axis(-2), output->axis(-2 - red_dims))); // M + EXPECT_TRUE(checkMapped(B->axis(-1), output->axis(-1 - red_dims))); // N + if (!k_bcast) { + 
EXPECT_TRUE(checkMapped(A->axis(-1), B->axis(-2))); // K + EXPECT_TRUE(checkMapped(A->axis(-1), output->axis(-1))); // K + } + // Check that batch dims are mapped + // Note that A and B can have different dimensions, so here we count + // backwards from the innermost batch dimension. Then we check that the axis + // exists (is not negative) and is not Broadcast before checking mapping. + for (int64_t i : c10::irange(output->nDims() - red_dims - 2)) { + int64_t i_a = A->nDims() - 3 - i; + int64_t i_b = B->nDims() - 3 - i; + int64_t i_out = output->nDims() - red_dims - 3 - i; + if (i_a >= 0 && !A->axis(i_a)->isBroadcast()) { + EXPECT_TRUE(checkMapped(A->axis(i_a), output->axis(i_out))); + } + if (i_b >= 0 && !B->axis(i_b)->isBroadcast()) { + EXPECT_TRUE(checkMapped(B->axis(i_b), output->axis(i_out))); + } + } } else { - std::cout << "Unhandled set of input dimensions" << std::endl; - // EXPECT_TRUE(false); + std::cerr << "Unhandled set of input dimensions" << std::endl; + EXPECT_TRUE(false); } } @@ -530,6 +575,17 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( ReductionAxisIsOne, ATenNodesParametrizedTest, - testing::Values(std::make_tuple(Sizes({m, 1}), Sizes({1, n})))); + testing::Combine( + testing::Values( + Sizes({1}), + Sizes({m, 1}), + Sizes({1, 1}), + Sizes({b, m, 1}), + Sizes({b, 1, m, 1})), + testing::Values( + Sizes({1}), + Sizes({1, n}), + Sizes({1, 1}), + Sizes({b, 1, n})))); } // namespace nvfuser From a7a9b56efb5655436b50e61e147e8a0d82d6a755 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 15 May 2024 13:52:23 +0000 Subject: [PATCH 12/52] Remove prints and fix up isMatmulFusionDefinitionSupported --- csrc/root_domain_map.cpp | 2 -- csrc/scheduler/matmul_utils.cpp | 48 ++++++++++++++++++--------------- csrc/scheduler/mma_utils.cpp | 8 +----- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 0e77f92c208..0f6e01ded31 100644 --- a/csrc/root_domain_map.cpp +++ 
b/csrc/root_domain_map.cpp @@ -161,8 +161,6 @@ std::unordered_map PairwiseRootDomainMap::map( } if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { - std::cout << "MAP " << map_key_id->toString() << " -> " - << map_value_id->toString() << std::endl; dom_map.insert(std::make_pair(map_key_id, map_value_id)); } }; diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 814ce2187e2..9cb60f8d1ff 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -183,7 +183,9 @@ ProblemShape getProblemShape( std::string isMatmulFusionDefinitionSupported( Fusion* fusion, - const mma_utils::MatmulPattern& pattern) { + const mma_utils::MatmulPattern& pattern, + const mma_utils::RolesMap& roles_map, + const std::unordered_map& id_roles) { const auto& fusion_inputs = fusion->inputs(); const auto& fusion_outputs = fusion->outputs(); std::vector mma_inputs = {pattern.A, pattern.B}; @@ -195,18 +197,20 @@ std::string isMatmulFusionDefinitionSupported( ir_utils::filterByType(fusion_outputs).vector(); constexpr size_t minimal_number_of_inputs = 2; - MmaOpUtils::MmaOpDetails mma_details = - MmaOpUtils::getMmaOpDetails(pattern.output, pattern.A, pattern.B); // Quick checks - MmaOp { // Check if MmaOp represents gemm (requires M/N/K == 1, B == 0) // or bgemm (requires M/N/K/B == 1) + std::array num_axes{}; + for (const auto& [g, dom] : id_roles) { + num_axes[(size_t)dom]++; + } constexpr size_t expected_axes_numbers = 1; - if (mma_details.m_axes.size() != expected_axes_numbers || - mma_details.n_axes.size() != expected_axes_numbers || - mma_details.k_axes.size() != expected_axes_numbers || - mma_details.batch_axes.size() > expected_axes_numbers) { + if (num_axes[(size_t)MatmulDomain::M] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::N] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::K] != expected_axes_numbers || + num_axes[(size_t)MatmulDomain::Batch] > expected_axes_numbers) { return "MmaOp has 
unsupported number of one of M/N/K/Batch axes"; } @@ -225,12 +229,6 @@ std::string isMatmulFusionDefinitionSupported( // Fusion topology check { - const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion, pattern); - if (!roles_map_opt.isValid()) { - return roles_map_opt.getErrorMsg(); - } - - const auto& roles_map = roles_map_opt.getData(); auto entry = roles_map.find(MatmulRole::INPUT_A); std::set tvs_with_roles; @@ -461,19 +459,27 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { return "Only a single matmul pattern can currently be fused"; } + // Prepare an IdModel which will be reused to check remaining conditions + IdModel id_model(fusion); + const auto id_roles = patterns.front().getDimRoles(id_model); + const mma_utils::RolesMapOpt roles_map_opt = + mma_utils::getTensorsRoles(fusion, id_model, id_roles); + if (!roles_map_opt.isValid()) { + return {roles_map_opt.getErrorMsg()}; + } + mma_utils::RolesMap roles_map = roles_map_opt.getData(); + // #2 - { - const auto input_layout_opt = - mma_utils::getProblemLayout(fusion, patterns.front()); - if (!input_layout_opt.isValid()) { - return input_layout_opt.getErrorMsg(); - } + const auto input_layout_opt = + mma_utils::getProblemLayout(id_model, id_roles, roles_map); + if (!input_layout_opt.isValid()) { + return input_layout_opt.getErrorMsg(); } // #3 { - auto support_status = - isMatmulFusionDefinitionSupported(fusion, patterns.front()); + auto support_status = isMatmulFusionDefinitionSupported( + fusion, patterns.front(), roles_map, id_roles); if (!support_status.empty()) { return support_status; } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 2fa292ba7c6..9365ad70117 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1365,7 +1365,7 @@ RolesMapOpt getTensorsRoles( } std::vector storage; - for (TensorView* tv : mma_input_candidates) { + for (TensorView* tv : mma_output_candidates) { bool has_m = false, has_n = false, has_k = 
false, has_unmapped = false; for (IterDomain* id : TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { @@ -1710,22 +1710,16 @@ std::unordered_map MatmulPattern::getDimRoles( std::unordered_map present_flags; const auto recordPresence = [&exact_graph, &present_flags]( TensorView* tv, size_t tensor_num) { - std::cout << "recordPresence " << tv->toString() << std::endl; for (IterDomain* id : tv->getMaybeRFactorDomain()) { if (id->isReduction() || id->isBroadcast()) { // ignore reductions and broadcasts since they don't exact map to // problem dims continue; } - std::cout << " id=" << id->toString() << std::endl; const ValGroup& g = exact_graph.toGroup(id); - std::cout << " updating bitset " << present_flags[g] << " to "; present_flags[g].set(tensor_num); - std::cout << present_flags[g] << std::endl; } }; - A->fusion()->printMath(); - std::cout << id_model.toString() << std::endl; recordPresence(A, 0); recordPresence(B, 1); recordPresence(output, 2); From d6487b1b47add7f833fad441858adc6645067a82 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 15 May 2024 14:42:15 +0000 Subject: [PATCH 13/52] Fix getProblemLayout --- csrc/scheduler/mma_utils.cpp | 8 ++++++-- tests/cpp/test_combine_mul_sum.cpp | 20 +++++++++++--------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 9365ad70117..8918f5f0d0d 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1147,7 +1147,7 @@ MatmulProblemLayoutOpt getProblemLayout( return {"No tensor found in role"}; } MatmulDomain dom = group_inner_dom.value(); - return MatmulDomainOpt(std::move(dom)); + return MatmulDomainOpt(dom); }; MatmulDomainOpt a_inner_dom = innerDomain(MatmulRole::INPUT_A); if (!a_inner_dom.isValid()) { @@ -1334,10 +1334,14 @@ RolesMapOpt getTensorsRoles( bool has_m = false, has_n = false, has_k = false, has_unmapped = false; for (IterDomain* id : TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { 
+      if (id->isBroadcast()) {
+        // Broadcast domains won't exact map to concrete domains so skip them
+        continue;
+      }
       const ValGroup& g = exact_graph.toGroup(id);
       auto it = group_to_domain.find(g);
       if (it == group_to_domain.end()) {
-        // tv has an unmapped dimension
+        // tv has an unmapped non-broadcast and non-reduction dimension
         has_unmapped = true;
         continue;
       }
diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp
index d24a3dee59a..0172524c6fd 100644
--- a/tests/cpp/test_combine_mul_sum.cpp
+++ b/tests/cpp/test_combine_mul_sum.cpp
@@ -67,12 +67,12 @@ class CombineMulSumAsMmaTest : public NVFuserTest {
   DisableOptionsGuard opt_guard_;
 };
 
-void performSubstitution(Fusion* fusion, bool should_not_replace = false) {
+void performSubstitution(Fusion* fusion, bool should_not_find = false) {
   EXPECT_TRUE(ir_utils::getOpsOfType(fusion).empty());
 
   std::vector patterns =
       mma_utils::findMatmulPatterns(fusion);
-  if (should_not_replace) {
+  if (should_not_find) {
     EXPECT_TRUE(patterns.empty());
     return;
   }
@@ -108,8 +108,9 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Pass) {
   }
 }
 
-// This test checks that the combiner does not incorrectly
-// replace this mul-sum pair, and the mul is not fed by broadcasts ops.
+// This test checks that the pattern matcher does not incorrectly identify
+// this mul-sum pair, as the mul is not fed by broadcast ops; i.e. it is
+// not a matmul.
 TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -123,12 +124,13 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) {
   auto tv3 = sum(tv2, {-1});
   fusion.addOutput(tv3);
 
-  performSubstitution(&fusion, /*should_not_replace=*/true);
+  performSubstitution(&fusion, /*should_not_find=*/true);
 }
 
-// This test checks to see that the mul-sum combiner does not
-// combine a mul-sum which does not have appropriate broadcasts.
-TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2) {
+// This fusion has more than one broadcasted dimension for each operand, so it
+// is currently rejected by isMatmulFusionDefinitionSupported. Still, it is
+// a valid MatmulPattern so we check that it is found.
+TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) {
   // Assumes layout is kAllSupportedMmaLayout::NT;
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -155,7 +157,7 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2) {
   auto tv3 = sum(tv2, {-1});
   fusion.addOutput(tv3);
 
-  performSubstitution(&fusion, /*should_not_replace=*/true);
+  performSubstitution(&fusion, /*should_not_find=*/false);
 }
 
 // As a sanity check we test that after replacing a mul-sum

From fd0ceb999a7919b28dfaf485d584ffa4196e385d Mon Sep 17 00:00:00 2001
From: Jacob Hinkle 
Date: Wed, 15 May 2024 15:06:34 +0000
Subject: [PATCH 14/52] Add EnableOption::FuseMatmul and check sm arch

---
 csrc/options.cpp                   |  1 +
 csrc/options.h                     |  1 +
 csrc/scheduler/matmul_utils.cpp    | 35 +++++++++++++++++++++++++-----
 csrc/scheduler/mma_utils.cpp       |  2 +-
 tests/cpp/test_combine_mul_sum.cpp |  6 ++++-
 5 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/csrc/options.cpp b/csrc/options.cpp
index 66f4224e493..e66ada47775 100644
--- a/csrc/options.cpp
+++ b/csrc/options.cpp
@@ -154,6 +154,7 @@ template <>
 std::unordered_map> Options<
     EnableOption>::getOptionsFromEnv() {
   const std::unordered_map available_options = {
+      {"fuse_matmul", EnableOption::FuseMatmul},
       {"id_model", EnableOption::IdModel},
       {"kernel_db", EnableOption::KernelDb},
       {"kernel_profile", EnableOption::KernelProfile},
diff --git a/csrc/options.h b/csrc/options.h
index 71178f83e54..8b7a72b52e0 100644
--- a/csrc/options.h
+++ b/csrc/options.h
@@ -91,6 +91,7 @@ enum class DebugDumpOption {
 //! These can be set through the `NVFUSER_ENABLE` environment variable
 //!
 enum class EnableOption {
+  FuseMatmul, //! Enable automatic fusion of matmul and linear ops
   IdModel, //!
Enable IdModel KernelDb, //! Enable Kernel Database KernelProfile, //! Enable intra-kernel performance profiling diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 9cb60f8d1ff..846ccbbab4a 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -440,14 +440,27 @@ std::string getMatmulRunTimeRejectReason( // by the analysis. std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // The plan: - // 1. Check if there is exactly one MmaOp or suitable mul sum pair + // 1. Check supported device + // 2. Check if there is exactly one MmaOp or suitable mul sum pair // defined in the fusion. - // 2. Check if inputs to the mma op or mul sum pair match any of + // 3. Check if matmul scheduler is enabled + // 4. Check if inputs to the mma op or mul sum pair match any of // supported inputs layout - // 3. Check if fusion represents expressions that are recognized by matmul + // 5. Check if fusion represents expressions that are recognized by matmul // scheduler. // #1 + // Use a dummy problem shape to determine whether this is a supported device. + { + const auto device_prop = at::cuda::getCurrentDeviceProperties(); + const auto mma_op = getMmaOp( + device_prop->major * 10 + device_prop->minor, {128, 128, 128, 1}); + if (!mma_op.has_value()) { + return "Unsupported device compute capability"; + } + } + + // #2 // Initializing the machinery to check if there's a Mul-Sum pair // can be replaced by a Mma Op. std::vector patterns = @@ -459,6 +472,18 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { return "Only a single matmul pattern can currently be fused"; } + // #1 + { + if (!isOptionEnabled(EnableOption::FuseMatmul)) { + // Check for MatmulOp or LinearOp. If found, then only fuse if option is + // specified + Expr* op = patterns.front().output->definition(); + if (op->isA() /* || op->isA()*/) { + return "Matmul fusion is disabled by default. 
Enable it using NVFUSER_ENABLE=fuse_matmul"; + } + } + } + // Prepare an IdModel which will be reused to check remaining conditions IdModel id_model(fusion); const auto id_roles = patterns.front().getDimRoles(id_model); @@ -469,14 +494,14 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { } mma_utils::RolesMap roles_map = roles_map_opt.getData(); - // #2 + // #4 const auto input_layout_opt = mma_utils::getProblemLayout(id_model, id_roles, roles_map); if (!input_layout_opt.isValid()) { return input_layout_opt.getErrorMsg(); } - // #3 + // #5 { auto support_status = isMatmulFusionDefinitionSupported( fusion, patterns.front(), roles_map, id_roles); diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 8918f5f0d0d..1d2b1af0a68 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1147,7 +1147,7 @@ MatmulProblemLayoutOpt getProblemLayout( return {"No tensor found in role"}; } MatmulDomain dom = group_inner_dom.value(); - return MatmulDomainOpt(dom); + return MatmulDomainOpt(std::move(dom)); }; MatmulDomainOpt a_inner_dom = innerDomain(MatmulRole::INPUT_A); if (!a_inner_dom.isValid()) { diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 0172524c6fd..10e2438ba85 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -323,15 +323,19 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { testValidate( executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); }; - // Run the test with and without matmul_expr_eval + // Run the test with and without the matmul scheduler { DisableOptionsGuard dog; DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); run(/*expect_aten_eval=*/true); } { + // Disable ExprEval scheduler, which takes precedence of Matmul scheduler DisableOptionsGuard dog; DisableOptionsGuard::getCurOptions().set(DisableOption::MatmulExprEval); + // Allow Matmul Scheduler to 
accept MatmulOp and LinearOp
+    EnableOptionsGuard eog;
+    EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul);
     run(/*expect_aten_eval=*/false);
   }
 }

From 6e3f0525f9f0c1ad1c14eecbdb119835ec18cc8d Mon Sep 17 00:00:00 2001
From: Jacob Hinkle 
Date: Wed, 15 May 2024 15:24:10 +0000
Subject: [PATCH 15/52] Add TODO about skipping downcast roundtrip

---
 csrc/scheduler/mma_utils.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp
index 1d2b1af0a68..e317417d6ec 100644
--- a/csrc/scheduler/mma_utils.cpp
+++ b/csrc/scheduler/mma_utils.cpp
@@ -1670,6 +1670,9 @@ MmaOp* MatmulPattern::translateToMmaOp() {
   // Update operands to keep the pattern minimal
   A = Abcast;
   B = Bbcast;
+  // TODO: skip downcasting if the only uses of `output` are casts back to
+  // higher precision in order to avoid the round trip cast in defining an
+  // epilogue that starts with MatmulOp.
   if (output->dtype() != fms->dtype()) {
     // Redefine output as cast of MmaOp->out()
     IrBuilder::create(UnaryOpType::Cast, output, fms);

From 812f9d243f0f9ddeeecbd9d6bf1fbe588b9fa83e Mon Sep 17 00:00:00 2001
From: Jacob Hinkle 
Date: Wed, 15 May 2024 15:52:25 +0000
Subject: [PATCH 16/52] Remove some unused code

---
 csrc/scheduler/matmul_utils.cpp |   4 -
 csrc/scheduler/mma_utils.cpp    | 139 --------------------------------
 csrc/scheduler/mma_utils.h      |   2 -
 3 files changed, 145 deletions(-)

diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp
index 846ccbbab4a..7ffa92b83bc 100644
--- a/csrc/scheduler/matmul_utils.cpp
+++ b/csrc/scheduler/matmul_utils.cpp
@@ -542,10 +542,6 @@ std::shared_ptr getMatmulHeuristics(
   // Set kernel index mode
   params->cparams.index_type = runtime_info.getIndexType();
 
-  if (!isOptionDisabled(DisableOption::MatmulExprEval)) {
-    return params;
-  }
-
   // Check initial conditions
   std::vector patterns =
       mma_utils::findMatmulPatterns(fusion);
diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp
index e317417d6ec..ffd7cf854d0 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1185,130 +1185,6 @@ MatmulProblemLayoutOpt getProblemLayout( return getProblemLayout(id_model, id_roles, roles_map_opt.getData()); } -RolesMapOpt getTensorsRoles(Fusion* fusion, const MatmulPattern& pattern) { - ComputeAtMap ca_map(fusion); - const auto mma_input_candidates = - ir_utils::filterByType(fusion->inputs()).vector(); - if (mma_input_candidates.empty()) { - return {"Failed to find any TV that is fusion input"}; - } - const auto mma_output_candidates = - ir_utils::filterByType(fusion->outputs()).vector(); - if (mma_output_candidates.empty()) { - return {"Failed to find any TV that is fusion output"}; - } - - const auto mma_output_domains = getProblemIterDomains(pattern); - if (!mma_output_domains.isValid()) { - return mma_output_domains.getErrorMsg(); - } - - const auto findInputRolesByDomains = [](const DependenciesMap& deps_map, - RolesMap& roles_map) { - for (const auto& entry : deps_map) { - const auto& domains = entry.second; - const auto begin = domains.begin(); - const auto end = domains.end(); - - bool has_m = (end != std::find(begin, end, MatmulDomain::M)); - bool has_n = (end != std::find(begin, end, MatmulDomain::N)); - bool has_k = (end != std::find(begin, end, MatmulDomain::K)); - - if (has_m && has_k && !has_n) { - roles_map[MatmulRole::INPUT_A].push_back(entry.first); - continue; - } - if (has_n && has_k && !has_m) { - roles_map[MatmulRole::INPUT_B].push_back(entry.first); - continue; - } - // Bias vectors are assigned to INPUT_C role - if (!has_k) { - roles_map[MatmulRole::INPUT_C].push_back(entry.first); - continue; - } - } - - for (auto& [role, tvs] : roles_map) { - // NOTE: sort input roles in descending order by uses() size, and - // if equal then by name() to ensure the stable ordering of tensor - // views in collections assigned to the supported roles - std::sort(tvs.begin(), tvs.end(), [](TensorView* a, TensorView* b) { - 
return (a->uses().size() == b->uses().size()) - ? (a->name() < b->name()) - : (a->uses().size() > b->uses().size()); - }); - } - }; - - const auto findOutputRolesByDomains = [](const DependenciesMap& deps_map, - RolesMap& roles_map) { - std::vector storage; - storage.reserve(deps_map.size()); - - for (const auto& entry : deps_map) { - const auto& domains = entry.second; - const auto begin = domains.begin(); - const auto end = domains.end(); - - bool has_m = (end != std::find(begin, end, MatmulDomain::M)); - bool has_n = (end != std::find(begin, end, MatmulDomain::N)); - - // NOTE: depending on fusion definition k domain may appear in the output: - // - for mma_output == fusion output k domain is present - // - for mma_output != fusion output (fusion with epilogue) k domain - // is not present - - // NOTE: the core fusion output tensors are the ones with m and n - // domains - if (has_m && has_n) { - storage.push_back(entry.first); - } - } - - // NOTE: sort output roles in descending order by uses() size, and - // if equal then by name() to ensure the stable ordering of tensor - // views in collections assigned to the supported roles - std::sort(storage.begin(), storage.end(), [](TensorView* a, TensorView* b) { - return (a->uses().size() == b->uses().size()) - ? 
(a->name() < b->name()) - : (a->uses().size() > b->uses().size()); - }); - - if (!storage.empty()) { - // NOTE: currently, we pick as a reference tensor one with `m` and `n` - // IterDomains and the most uses - auto pos = storage.begin(); - roles_map[MatmulRole::OUTPUT_D].push_back(*pos); - for (++pos; pos != storage.end(); ++pos) { - roles_map[MatmulRole::OUTPUT_AUX].push_back(*pos); - } - } - }; - - const auto domains_data = mma_output_domains.getData(); - const auto m = domains_data[(size_t)MatmulDomain::M]; - const auto n = domains_data[(size_t)MatmulDomain::N]; - const auto k = domains_data[(size_t)MatmulDomain::K]; - - DependenciesMap deps_map; - RolesMap roles_map; - - // Handle fusion input TensorView objects - resolveTvToMatmulDomainsMapping( - deps_map, mma_input_candidates, m, n, k, ca_map); - findInputRolesByDomains(deps_map, roles_map); - - deps_map.clear(); - - // Handle fusion output TensorView objects - resolveTvToMatmulDomainsMapping( - deps_map, mma_output_candidates, m, n, k, ca_map); - findOutputRolesByDomains(deps_map, roles_map); - - return roles_map; -} - RolesMapOpt getTensorsRoles( Fusion* fusion, const IdModel& id_model, @@ -1434,21 +1310,6 @@ RolesMapOpt getTensorsRoles( return roles_map; } -RolesMapOpt getTensorsRoles(Fusion* fusion) { - auto mma_exprs = ir_utils::getOpsOfType(fusion); - if (mma_exprs.size() != 1) { - std::stringstream ss; - ss << "Invalid number of MmaOp instances in fusion, expected 1, got " - << mma_exprs.size(); - return ss.str(); - } - return getTensorsRoles( - fusion, - {static_cast(mma_exprs.front()->inA()), - static_cast(mma_exprs.front()->inB()), - static_cast(mma_exprs.front()->out())}); -} - namespace { // Check the val (in) is the output of broadcast. 
diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 841cd7549d0..937072a7f30 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -324,12 +324,10 @@ ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern); //! Returns wrapped collection of TensorView roles in fusion. //! An error message is stored in retruned object if valid data cannot //! be gathered. -RolesMapOpt getTensorsRoles(Fusion* fusion, const MatmulPattern& pattern); RolesMapOpt getTensorsRoles( Fusion* fusion, const IdModel& id_model, const std::unordered_map& group_to_domain); -RolesMapOpt getTensorsRoles(Fusion* fusion); //! Return pair of whether use shared memory epilogue or not and whether to //! reuse shared memory for the prologue at the expense of an additional block From 6d48f35953d906c67603241d69bfd8422d1e6a28 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 15 May 2024 16:16:24 +0000 Subject: [PATCH 17/52] clang-tidy --- csrc/scheduler/mma_utils.cpp | 44 ++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index ffd7cf854d0..5014f30f194 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1121,12 +1121,17 @@ MatmulProblemLayoutOpt getProblemLayout( // group_to_domain const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - using MatmulDomainOpt = DataWrapperOpt; + // Note: using DataWrapperOpt would be preferable here. However, + // using DataWrapperOpt(std::move(dom)) leads to a clang-tidy + // warning because MatmulDomain is trivially movable. There is only a move + // constructor for DataWrapperOpt to prevent inadvertent copying. To avoid + // this complication I'm using a simple pair for the lambda's result type. 
+ using InnerDomResult = std::pair; const auto innerDomain = [&roles_map, &group_to_domain, &exact_graph]( - MatmulRole role) -> MatmulDomainOpt { + MatmulRole role) -> InnerDomResult { const auto role_it = roles_map.find(role); if (role_it == roles_map.end()) { - return {"Could not find role in roles_map"}; + return {MatmulDomain::M, "Could not find role in roles_map"}; } std::optional group_inner_dom = std::nullopt; for (TensorView* tv : role_it->second) { @@ -1135,31 +1140,36 @@ MatmulProblemLayoutOpt getProblemLayout( const ValGroup& g = exact_graph.toGroup(inner_id); auto g_it = group_to_domain.find(g); if (g_it == group_to_domain.end()) { - return {"Inner domain of tensor was not mapped to a MatmulDomain"}; + return { + MatmulDomain::M, + "Inner domain of tensor was not mapped to a MatmulDomain"}; } if (!group_inner_dom.has_value()) { group_inner_dom = g_it->second; } else if (group_inner_dom.value() != g_it->second) { - return {"Group contains multiple inner dimension domains"}; + return { + MatmulDomain::M, "Group contains multiple inner dimension domains"}; } } if (!group_inner_dom.has_value()) { - return {"No tensor found in role"}; + return {MatmulDomain::M, "No tensor found in role"}; } - MatmulDomain dom = group_inner_dom.value(); - return MatmulDomainOpt(std::move(dom)); + return {group_inner_dom.value(), ""}; }; - MatmulDomainOpt a_inner_dom = innerDomain(MatmulRole::INPUT_A); - if (!a_inner_dom.isValid()) { - return a_inner_dom.getErrorMsg(); - } - MatmulDomainOpt b_inner_dom = innerDomain(MatmulRole::INPUT_B); - if (!b_inner_dom.isValid()) { - return b_inner_dom.getErrorMsg(); + + const InnerDomResult a_dom_res = innerDomain(MatmulRole::INPUT_A); + if (!a_dom_res.second.empty()) { + std::string err = a_dom_res.second; + return err; } + const bool kinner_a = a_dom_res.first == MatmulDomain::K; - bool kinner_a = a_inner_dom.getData() == MatmulDomain::K; - bool kinner_b = b_inner_dom.getData() == MatmulDomain::K; + const InnerDomResult b_dom_res = 
innerDomain(MatmulRole::INPUT_B); + if (!b_dom_res.second.empty()) { + std::string err = b_dom_res.second; + return err; + } + const bool kinner_b = b_dom_res.first == MatmulDomain::K; if (kinner_a && kinner_b) { return MmaLayout::TN; From ffc62dd1df4060f2825e4b5dde80e4669c56e09b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 16 May 2024 16:23:09 +0000 Subject: [PATCH 18/52] Undo change to matmul --- csrc/ops/composite.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index d683e5f6f14..cdeeaecb624 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -366,13 +366,6 @@ TensorView* matmul(TensorView* tv_a, TensorView* tv_b) { return maybeCastOp(tv_a->dtype(), sum(mul(tv_a, tv_b), {0})); } - if (tv_b->nDims() > 1 && tv_b->axis(-2)->isBroadcast()) { - NVF_ERROR( - tv_a->axis(-1)->isBroadcast(), - "Mismatched Broadcast in K dimension of operands"); - // K dimension is broadcast so this is an outer product - } - // For all other cases, create a new MatmulOp TensorView* out = newForMatmul(tv_a, tv_b); IrBuilder::create(out, tv_a, tv_b); From 4839125c3dad6a7126b763e52f894d4b8d9e4b0a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 17 May 2024 13:16:30 +0000 Subject: [PATCH 19/52] Clean up test --- tests/cpp/test_combine_mul_sum.cpp | 42 ++++++++++++------------------ 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 10e2438ba85..03083ea47fa 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -265,6 +265,8 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { } } +// Test that a simple matmul fusion is picked up by the appropriate scheduler +// and the translation to MmaOp is performed properly. 
TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { const auto run = [&](bool expect_aten_eval) { int M = 504, N = 136, K = 248; @@ -276,6 +278,7 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { fusion->addInput(tv0); fusion->addInput(tv1); + auto tv2 = matmul(tv0, tv1); fusion->addOutput(tv2); @@ -290,34 +293,23 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { FusionExecutorCache executor_cache(std::move(fusion)); auto outputs = executor_cache.runFusionWithInputs({t0, t1}); - EXPECT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented()); + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + ASSERT_NE(runtime, nullptr); + + EXPECT_FALSE(runtime->isSegmented()); + ScheduleHeuristic heuristic = + runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); if (expect_aten_eval) { - // Ensure that the matmul scheduler ran. - EXPECT_EQ( - executor_cache.getMostRecentKernelRuntime() - ->schedulerHeuristics() - ->heuristicsList() - .front() - ->heuristic(), - ScheduleHeuristic::ExprEval); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } else { - // Ensure that the matmul scheduler ran. - EXPECT_EQ( - executor_cache.getMostRecentKernelRuntime() - ->schedulerHeuristics() - ->heuristicsList() - .front() - ->heuristic(), - ScheduleHeuristic::Matmul); - // Ensure there's a mma op. - // If there's no mma op present, then stop the test. - ASSERT_FALSE(ir_utils::getOpsOfType( - executor_cache.getMostRecentKernelRuntime() - ->executors() - .at(0) - .kernel()) - .empty()); + // Ensure that the Matmul scheduler ran. + EXPECT_EQ(heuristic, ScheduleHeuristic::Matmul); + // Ensure there's an MmaOp. 
+ EXPECT_FALSE( + ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) + .empty()); } testValidate( From 254ba508d0ee8d9ff0b80c93970d3924c7fc1f1a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 17 May 2024 14:13:06 +0000 Subject: [PATCH 20/52] Test that alloc domain causes ExprEval use --- csrc/ir/utils.cpp | 25 +++++++++++++ csrc/ir/utils.h | 2 ++ csrc/scheduler/matmul_utils.cpp | 15 ++++++++ tests/cpp/test_combine_mul_sum.cpp | 58 +++++++++++++++++++----------- 4 files changed, 80 insertions(+), 20 deletions(-) diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 296cff951b8..265b2b1a202 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1099,6 +1099,31 @@ int64_t getVectorizeSize(const TensorView* tv) { return 1; } +bool hasTrivialAllocationDomain(const TensorView* tv) { + if (!tv->hasAllocation()) { + return true; + } + const std::vector& alloc = tv->getMaybeAllocationDomain(); + const std::vector& rf = tv->getMaybeRFactorDomain(); + size_t i = 0, j = 0; + while (i < alloc.size() && j < rf.size()) { + if (alloc[i]->isBroadcast() || alloc[i]->isReduction()) { + i++; + continue; + } + if (rf[j]->isBroadcast() || rf[j]->isReduction()) { + j++; + continue; + } + if (!alloc[i]->sameAs(rf[j])) { + return false; + } + i++; + j++; + } + return true; +} + } // namespace nvfuser::ir_utils namespace nvfuser::MmaOpUtils { diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index a9b5c95c5e0..498a22b57c4 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -649,4 +649,6 @@ std::optional> computePermutation( return permutation; } +bool hasTrivialAllocationDomain(const TensorView* tv); + } // namespace nvfuser::ir_utils diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 7ffa92b83bc..506c339ae33 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -279,6 +279,21 @@ std::string isMatmulFusionDefinitionSupported( } } + // Check that no non-trivial allocation domains are set on inputs or 
outputs. + // TODO: Lift this requirement once we have proper allocation domain support + for (Val* inp : fusion->inputs()) { + if (auto tv = dynamic_cast(inp); + tv && !ir_utils::hasTrivialAllocationDomain(tv)) { + return "detected input TV with non-trivial allocation domain"; + } + } + for (Val* outp : fusion->outputs()) { + if (auto tv = dynamic_cast(outp); + tv && !ir_utils::hasTrivialAllocationDomain(tv)) { + return "detected output TV with non-trivial allocation domain"; + } + } + return ""; } diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 03083ea47fa..30a7bb5ca94 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -268,7 +268,7 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { // Test that a simple matmul fusion is picked up by the appropriate scheduler // and the translation to MmaOp is performed properly. TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { - const auto run = [&](bool expect_aten_eval) { + const auto run = [&](bool transpose_a_alloc, bool expect_aten_eval) { int M = 504, N = 136, K = 248; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -276,19 +276,31 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); + if (transpose_a_alloc) { + tv0->setAllocationDomain({tv0->axis(1), tv0->axis(0)}, true); + } + fusion->addInput(tv0); fusion->addInput(tv1); auto tv2 = matmul(tv0, tv1); - fusion->addOutput(tv2); + // add an epilogue + auto tv3 = sin(tv2); + auto tv4 = castOp(DataType::Half, tv3); + fusion->addOutput(tv4); + + // Verify that we no longer set up MmaOp in matmul() ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto t0 = at::randn({M, K}, options); auto t1 = at::randn({K, N}, options); - auto tref = at::matmul(t0, t1); + if 
(transpose_a_alloc) { + t0 = t0.as_strided({M, K}, {1, M}); + } + auto tref = at::matmul(t0, t1).sin().to(at::kHalf); FusionExecutorCache executor_cache(std::move(fusion)); auto outputs = executor_cache.runFusionWithInputs({t0, t1}); @@ -297,7 +309,11 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { executor_cache.getMostRecentKernelRuntime(); ASSERT_NE(runtime, nullptr); - EXPECT_FALSE(runtime->isSegmented()); + if (expect_aten_eval) { + EXPECT_TRUE(runtime->isSegmented()); + } else { + EXPECT_FALSE(runtime->isSegmented()); + } ScheduleHeuristic heuristic = runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); @@ -305,7 +321,9 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } else { // Ensure that the Matmul scheduler ran. - EXPECT_EQ(heuristic, ScheduleHeuristic::Matmul); + // Assert here since we will inspect the kernel next, which we can't do if + // ExprEval accepts the segment. + ASSERT_EQ(heuristic, ScheduleHeuristic::Matmul); // Ensure there's an MmaOp. 
EXPECT_FALSE( ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) @@ -315,21 +333,21 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { testValidate( executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); }; - // Run the test with and without the matmul scheduler - { - DisableOptionsGuard dog; - DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); - run(/*expect_aten_eval=*/true); - } - { - // Disable ExprEval scheduler, which takes precedence of Matmul scheduler - DisableOptionsGuard dog; - DisableOptionsGuard::getCurOptions().set(DisableOption::MatmulExprEval); - // Allow Matmul Scheduler to accept MatmulOp and LinearOp - EnableOptionsGuard eog; - EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); - run(/*expect_aten_eval=*/false); - } + // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it enabled + DisableOptionsGuard dog; + DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); + EnableOptionsGuard eog; + + // Run the test with and without matmul fusion enabled + EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); + run(/*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); + run(/*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); + + // Allow Matmul Scheduler to fuse MatmulOp and LinearOp + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + run(/*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + // We cannot yet handle allocation domain in matmul scheduler + run(/*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); } } // namespace nvfuser From 72da560594ba64a336aa76db6b67ec0b99a79def Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 12:08:58 +0000 Subject: [PATCH 21/52] WIP adding LinearOp support --- csrc/scheduler/matmul_utils.cpp | 21 +++-- csrc/scheduler/mma_utils.cpp | 11 +++ tests/cpp/test_combine_mul_sum.cpp | 119 +++++++++++++++++++++++++++++ 3 files changed, 145 
insertions(+), 6 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index ab073a7d24d..d0ed825ed1b 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -493,12 +493,21 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // #2 { - if (!isOptionEnabled(EnableOption::FuseMatmul)) { - // Check for MatmulOp or LinearOp. If found, then only fuse if option is - // specified - Expr* op = patterns.front().output->definition(); - if (op->isA() /* || op->isA()*/) { - return "Matmul fusion is disabled by default. Enable it using NVFUSER_ENABLE=fuse_matmul"; + for (const mma_utils::MatmulPattern& pattern : patterns) { + Expr* op = pattern.output->definition(); + if (op->isA() || op->isA()) { + if (!isOptionEnabled(EnableOption::FuseMatmul)) { + // Check for MatmulOp or LinearOp. If found, then only fuse if option + // is specified + return "MatmulOp and LinearOp fusion is disabled by default. " + "Enable it using NVFUSER_ENABLE=fuse_matmul"; + } + // Refuse patterns containing 1D inputs since these are mat-vec as + // opposed to mat-mat products. + if (op->input(0)->as()->nDims() < 2 || + op->input(1)->as()->nDims() < 2) { + return "Cannot fuse matrix-vector products"; + } } } } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index a50eb7c9bf9..32a32c1eb01 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1390,6 +1390,17 @@ class MatmulPatternMatcher : IterVisitor { // there is a transpose operation in the epilogue, then this assumption will // be violated. In such cases we should actually swap and transpose A and B. + // Match all LinearOps and MatmulOps as MatmulPatterns. This includes ops + // whose inputs are not 2D, i.e. matrix-vector products. The matmul scheduler + // will decide whether or not it can fuse a given pattern based on the + // dimensionality of its inputs. 
+ void handle(LinearOp* lop) override { + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = lop->inA()->as(); + pattern.B = lop->inB()->as(); + pattern.output = lop->out()->as(); + } + void handle(MatmulOp* mop) override { MatmulPattern& pattern = patterns_.emplace_back(); pattern.A = mop->inA()->as(); diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 10280ba6315..ab69b90f1d1 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -354,6 +354,125 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { run(/*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); } +// Test that a simple linear op fusion is picked up by the appropriate scheduler +// and the translation to MmaOp is performed properly. +TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { + const auto run = [&](int64_t A_dim, + int64_t B_dim, + int64_t bias_dim, + bool transpose_a_alloc, + bool expect_aten_eval) { + int M = 504, N = 136, K = 248; + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(A_dim, DataType::Half); + auto tv1 = makeContigTensor(B_dim, DataType::Half); + + if (transpose_a_alloc && A_dim > 1) { + std::vector alloc = tv0->getMaybeAllocationDomain(); + alloc[alloc.size() - 1] = tv0->axis(-2); + alloc[alloc.size() - 2] = tv0->axis(-1); + tv0->setAllocationDomain(alloc, true); + } + + fusion->addInput(tv0); + fusion->addInput(tv1); + + TensorView* tv2 = nullptr; + if (bias_dim >= 0) { + // bias_dim = -1 indicates we should not use any bias argument + auto bias = makeContigTensor(B_dim, DataType::Half); + fusion->addInput(bias); + tv2 = linear(tv0, tv1, bias); + } else { + tv2 = linear(tv0, tv1); + } + + // add an epilogue + auto tv3 = sin(tv2); + auto tv4 = castOp(DataType::Half, tv3); + + fusion->addOutput(tv4); + + // Verify that we no longer set up MmaOp in matmul() + ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); 
+ + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + if (transpose_a_alloc) { + t0 = t0.as_strided({M, K}, {1, M}); + } + std::vector inputs{t0, t1}; + at::Tensor tref; + if (bias_dim >= 0) { + at::Tensor bias; + if (bias_dim == 0) { + bias = at::randn({}, options); + } else if (bias_dim == 1) { + bias = at::randn({N}, options); + } else if (bias_dim == 2) { + bias = at::randn({M, N}, options); + } else { + NVF_ERROR(false, "Invalid bias dimension given:", bias_dim); + } + inputs.emplace_back(bias); + tref = at::linear(t0, t1, bias); + } else { + tref = at::linear(t0, t1); + } + tref = tref.sin().to(at::kHalf); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + ASSERT_NE(runtime, nullptr); + + if (expect_aten_eval) { + EXPECT_TRUE(runtime->isSegmented()); + } else { + EXPECT_FALSE(runtime->isSegmented()); + } + + ScheduleHeuristic heuristic = + runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); + if (expect_aten_eval) { + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); + } else { + // Ensure that the Matmul scheduler ran. + // Assert here since we will inspect the kernel next, which we can't + // do if ExprEval accepts the segment. + ASSERT_EQ(heuristic, ScheduleHeuristic::Matmul); + // Ensure there's an MmaOp. 
+ EXPECT_FALSE( + ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) + .empty()); + } + + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); + }; + // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it + // enabled + DisableOptionsGuard dog; + DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); + EnableOptionsGuard eog; + + // Run the test with and without matmul fusion enabled + EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); + run(2, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); + run(2, 2, -1, /*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); + + // Allow Matmul Scheduler to fuse MatmulOp and LinearOp + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + run(2, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + // We cannot yet handle allocation domain in matmul scheduler + run(2, 2, -1, /*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); +} + // Check that we determine A and B properly when they are swapped as inputs to // mul TEST_F(CombineMulSumAsMmaTest, SwapAandB) { From 514da40182f93042ab5569087caf14126802855d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 12:49:51 +0000 Subject: [PATCH 22/52] Translate LinearOp --- csrc/scheduler/mma_utils.cpp | 101 +++++++++++++++++++++++++---------- 1 file changed, 72 insertions(+), 29 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 32a32c1eb01..49f4f89d8d5 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1530,6 +1530,47 @@ MmaOp* MatmulPattern::translateToMmaOp() { if (auto mma_op = dynamic_cast(output->definition())) { // No translation needed return mma_op; + } else if (output->definition()->isA()) { + Val* init = IrBuilder::create(0.0, output->dtype()); + // This replaces the mul and sum by overwriting output->definition() + return IrBuilder::create(output, A, 
B, init); + } + + // This will hold the translated output from MatmulOp or LinearOp + TensorView* fms = nullptr; + MmaOp* mma_op = nullptr; + if (auto lop = dynamic_cast(output->definition())) { + // Linear takes inputs input, weight(, bias) + // - input can be any dimension > 0. We assert that it must be at least 2 + // and refuse to translate if dimension is 1. + // - weight can be one or two dimensional. We refuse to translate if + // dimension is 1. + // - bias, if present, can be zero or two dimensional. Bias can only be + // present if weight is 2D + // + // We translate by broadcasting input, weight, and bias such that the + // contracted dimension K is in the last position (this is true of the + // rfactor domains in input and weight already). Then we form an MmaOp and + // optionally add the bias tensor followed by a cast back to the input + // dtype. + NVF_ERROR( + A->nDims() > 1 && B->nDims() > 1, + "Cannot translate LinearOp with 1D input"); + std::vector bcast_dim((size_t)A->nDims() + 1, false); + bcast_dim[bcast_dim.size() - 2] = true; // N + A = broadcast(A, bcast_dim); + + bcast_dim[bcast_dim.size() - 2] = false; // reset N + std::fill(bcast_dim.begin(), bcast_dim.end() - 2, true); + B = broadcast(B, bcast_dim); + + fms = fusedMultiplySum(A, B, {-1}); + mma_op = fms->definition()->as(); + + auto* bias = dynamic_cast(lop->bias()); + if (bias != nullptr) { + fms = add(fms, bias); + } } else if (auto mop = dynamic_cast(output->definition())) { // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we // must transpose B then broadcast both operands before creating the final @@ -1538,37 +1579,39 @@ MmaOp* MatmulPattern::translateToMmaOp() { // Also note that the output of MatmulOp is a tensor of shape [..., M, N] // whose matches that of the inputs. We will most commonly then also need to // cast the output of the MmaOp to produce the output TensorView. 
+ NVF_ERROR( + A->nDims() > 1 && B->nDims() > 1, + "Cannot translate MatmulOp with 1D input"); TensorView* Btrans = transpose(B); - TensorView* Abcast = unsqueeze(A, -2); - TensorView* Bbcast = unsqueeze(Btrans, -3); - TensorView* fms = fusedMultiplySum(Abcast, Bbcast, {-1}); - auto mma_op = fms->definition()->as(); - // Update operands to keep the pattern minimal - A = Abcast; - B = Bbcast; - // TODO: skip downcasting if the only uses of `output` are casts back to - // higher precision in order avoid the round trip cast in defining an - // epilogue that starts with MatmulOp. - if (output->dtype() != fms->dtype()) { - // Redefine output as cast of MmaOp->out() - IrBuilder::create(UnaryOpType::Cast, output, fms); - // Update output so that cast is part of the epilogue - output = fms; - } else { - // No cast needed, for example the inputs might be Float - ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); - } - return mma_op; - } else if (output->definition()->isA()) { - Val* init = IrBuilder::create(0.0, output->dtype()); - // This replaces the mul and sum by overwriting output->definition() - return IrBuilder::create(output, A, B, init); + A = unsqueeze(A, -2); + B = unsqueeze(Btrans, -3); + fms = fusedMultiplySum(A, B, {-1}); + mma_op = fms->definition()->as(); + } else { + NVF_ERROR( + false, + "Could not translate matmul pattern with output ", + output->toString(), + " to MmaOp"); } - NVF_ERROR( - false, - "Could not translate matmul pattern with output ", - output->toString(), - " to MmaOp"); + NVF_ERROR(fms != nullptr); + NVF_ERROR(mma_op != nullptr); + + // The following is common to both MatmulOp and LinearOp translation + + // TODO: skip downcasting if the only uses of `output` are casts back to + // higher precision in order avoid the round trip cast in defining an + // epilogue that starts with MatmulOp. 
+ if (output->dtype() != fms->dtype()) { + // Redefine output as cast of MmaOp->out() + IrBuilder::create(UnaryOpType::Cast, output, fms); + // Update output so that cast is part of the epilogue + output = fms; + } else { + // No cast needed, for example the inputs might be Float + ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); + } + return mma_op; } std::unordered_map MatmulPattern::getDimRoles( From 4f64c5db5ecab0f01f3886e9f1cf354c5050103a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 13:27:51 +0000 Subject: [PATCH 23/52] Fixes plus add more test cases for LinearOp Some are still failing --- csrc/scheduler/matmul_utils.cpp | 3 +-- tests/cpp/test_combine_mul_sum.cpp | 38 ++++++++++++++++++++++++++---- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index d0ed825ed1b..1e92f28c99f 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -504,8 +504,7 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { } // Refuse patterns containing 1D inputs since these are mat-vec as // opposed to mat-mat products. 
- if (op->input(0)->as()->nDims() < 2 || - op->input(1)->as()->nDims() < 2) { + if (pattern.A->nDims() < 2 || pattern.B->nDims() < 2) { return "Cannot fuse matrix-vector products"; } } diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index ab69b90f1d1..f5c8aba8c73 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -362,7 +362,7 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { int64_t bias_dim, bool transpose_a_alloc, bool expect_aten_eval) { - int M = 504, N = 136, K = 248; + int batch_size = 3, M = 504, N = 136, K = 248; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -382,7 +382,7 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { TensorView* tv2 = nullptr; if (bias_dim >= 0) { // bias_dim = -1 indicates we should not use any bias argument - auto bias = makeContigTensor(B_dim, DataType::Half); + auto bias = makeContigTensor(bias_dim, DataType::Half); fusion->addInput(bias); tv2 = linear(tv0, tv1, bias); } else { @@ -399,8 +399,18 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); + std::vector A_shape(A_dim, batch_size); + A_shape[A_dim - 1] = K; + if (A_dim > 1) { + A_shape[A_dim - 2] = M; + } + at::Tensor t0 = at::randn(A_shape, options); + std::vector B_shape(B_dim, batch_size); + B_shape[B_dim - 1] = K; + if (B_dim > 1) { + B_shape[B_dim - 2] = N; + } + auto t1 = at::randn(B_shape, options); if (transpose_a_alloc) { t0 = t0.as_strided({M, K}, {1, M}); } @@ -471,6 +481,26 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { run(2, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); // We cannot yet handle allocation domain in matmul scheduler run(2, 2, -1, /*transpose_a_alloc=*/true, 
/*expect_aten_eval=*/true); + + // Don't fuse 1D inputs + run(1, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); + run(2, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); + // The following currently fails but it should not be translated to LinearOp + // to begin with + // run(1, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); + + // Multiple batch dims in input + run(3, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + run(4, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + + // Bias cases + run(2, 2, 0, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + run(2, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + // Undocumented 2D bias support + run(2, 2, 2, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + + run(3, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + run(4, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); } // Check that we determine A and B properly when they are swapped as inputs to From 8328207247cf810a0d65083173c6fd35e504e672 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 13:50:07 +0000 Subject: [PATCH 24/52] Fix up more tests. 
Not all test cases are passing yet --- tests/cpp/test_combine_mul_sum.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index f5c8aba8c73..b25148c87eb 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -463,7 +463,7 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { } testValidate( - executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); + executor_cache.fusion(), outputs, inputs, {tref}, __LINE__, __FILE__); }; // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it // enabled @@ -485,9 +485,9 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { // Don't fuse 1D inputs run(1, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); run(2, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - // The following currently fails but it should not be translated to LinearOp - // to begin with - // run(1, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); + // TODO: The following currently fails but it should not be translated to + // LinearOp to begin with run(1, 1, -1, /*transpose_a_alloc=*/false, + // /*expect_aten_eval=*/true); // Multiple batch dims in input run(3, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); @@ -497,7 +497,9 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { run(2, 2, 0, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); run(2, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); // Undocumented 2D bias support - run(2, 2, 2, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + // TODO: Currently failing in propagateBoundValuesThroughExactMaps, indicating + // possible PairwiseRootDomainMap issue? 
+ // run(2, 2, 2, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); run(3, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); run(4, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); From db32d2d8200b618f3e04ce26a187eb836392b3f2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 15:02:21 +0000 Subject: [PATCH 25/52] Special cases in getDimRoles for new nodes --- csrc/scheduler/matmul_utils.cpp | 6 +-- csrc/scheduler/mma_utils.cpp | 71 +++++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 1e92f28c99f..179df3f0300 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -205,9 +205,9 @@ std::string isMatmulFusionDefinitionSupported( constexpr size_t minimal_number_of_inputs = 2; - // Quick checks - MmaOp + // Quick checks { - // Check if MmaOp represents gemm (requires M/N/K == 1, B == 0) + // Check if matmul pattern represents gemm (requires M/N/K == 1, B == 0) // or bgemm (requires M/N/K/B == 1) std::array num_axes{}; for (const auto& [g, dom] : id_roles) { @@ -218,7 +218,7 @@ std::string isMatmulFusionDefinitionSupported( num_axes[(size_t)MatmulDomain::N] != expected_axes_numbers || num_axes[(size_t)MatmulDomain::K] != expected_axes_numbers || num_axes[(size_t)MatmulDomain::Batch] > expected_axes_numbers) { - return "MmaOp has unsupported number of one of M/N/K/Batch axes"; + return "Matmul pattern has unsupported number of one of M/N/K/Batch axes"; } if (!mma_output->hasReduction()) { diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 49f4f89d8d5..d38bc7ca0e2 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1629,6 +1629,72 @@ std::unordered_map MatmulPattern::getDimRoles( // If there are other patterns, for example a ValGroup present in only A, then // we should raise an exception here. 
+ std::unordered_map dim_roles; + + if (output->definition()->isA()) { + // Special case for MatmulOp + // torch.matmul has a single M, N, and K dimension and 0 or more batch + // dimensions. + dim_roles[exact_graph.toGroup(A->axis(-1))] = MatmulDomain::K; + NVF_ERROR(A->nDims() > 0 && B->nDims() > 0); + size_t m_and_k_dims = 0; + if (A->nDims() == 1 && B->nDims() == 1) { + NVF_ERROR( + false, "MatmulOp node should not be created when both inputs are 1D"); + } else if (A->nDims() == 1) { + // Missing M dimension + dim_roles[exact_graph.toGroup(B->axis(-1))] = MatmulDomain::N; + m_and_k_dims = 1; + } else if (B->nDims() == 1) { + // Missing N dimension + dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; + m_and_k_dims = 1; + } else { + // Both A and B are at least 2D + dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; + dim_roles[exact_graph.toGroup(B->axis(-1))] = MatmulDomain::N; + m_and_k_dims = 2; + } + // Skip one dimension for the reduction axis in the output + for (size_t i : c10::irange(output->nDims() - 1 - m_and_k_dims)) { + dim_roles[exact_graph.toGroup(output->axis((int64_t)i))] = + MatmulDomain::Batch; + } + return dim_roles; + } else if (output->definition()->isA()) { + // Special case for LinearOp + // torch.matmul has a single M, N, and K dimension and 0 or more batch + // dimensions. 
The batch dimensions are only present in A + dim_roles[exact_graph.toGroup(A->axis(-1))] = MatmulDomain::K; + NVF_ERROR(A->nDims() > 0 && B->nDims() > 0); + size_t m_and_k_dims = 0; + if (A->nDims() == 1 && B->nDims() == 1) { + NVF_ERROR( + false, "MatmulOp node should not be created when both inputs are 1D"); + } else if (A->nDims() == 1) { + // Missing M dimension + dim_roles[exact_graph.toGroup(B->axis(-2))] = MatmulDomain::N; + m_and_k_dims = 1; + } else if (B->nDims() == 1) { + // Missing N dimension + dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; + m_and_k_dims = 1; + } else { + // Both A and B are at least 2D + dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; + dim_roles[exact_graph.toGroup(B->axis(-2))] = MatmulDomain::N; + m_and_k_dims = 2; + } + // Skip one dimension for the reduction axis in the output + for (size_t i : c10::irange(output->nDims() - 1 - m_and_k_dims)) { + dim_roles[exact_graph.toGroup(output->axis((int64_t)i))] = + MatmulDomain::Batch; + } + return dim_roles; + } + + // The code below handles MmaOp or mul-sum patterns + // Indicates whether a ValGroup is present in A (bit 0), B (bit 1), or output // (bit 2) using DimPresence = std::bitset<3>; @@ -1650,15 +1716,14 @@ std::unordered_map MatmulPattern::getDimRoles( recordPresence(B, 1); recordPresence(output, 2); - std::unordered_map dim_roles; for (const auto& [g, flags] : present_flags) { if (flags.all()) { dim_roles[g] = MatmulDomain::Batch; } else if (flags.test(0) && flags.test(1)) { dim_roles[g] = MatmulDomain::K; - } else if (flags.test(0) && flags.test(2)) { + } else if (flags.test(0) && !flags.test(1) && flags.test(2)) { dim_roles[g] = MatmulDomain::M; - } else if (flags.test(1) && flags.test(2)) { + } else if (!flags.test(0) && flags.test(1) && flags.test(2)) { dim_roles[g] = MatmulDomain::N; } else { NVF_ERROR( From c5a613524346d0ae4eaee9d46af00ad5460b978f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 15:25:42 +0000 Subject: 
[PATCH 26/52] Disable allocation domain inference for test --- tests/cpp/test_combine_mul_sum.cpp | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index b25148c87eb..1355df30fca 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -27,6 +27,8 @@ #include #include #include +#include +#include #include #include #include @@ -357,6 +359,12 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { // Test that a simple linear op fusion is picked up by the appropriate scheduler // and the translation to MmaOp is performed properly. TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { + // The allocation domain propagation pass sets the output allocation domain, + // which sometimes causes the matmul scheduler to decline the whole fusion + // when it could compile it otherwise. + preseg_passes::OptimizationPassGuard + alloc_pass_guard(false); + const auto run = [&](int64_t A_dim, int64_t B_dim, int64_t bias_dim, @@ -486,23 +494,26 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { run(1, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); run(2, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); // TODO: The following currently fails but it should not be translated to - // LinearOp to begin with run(1, 1, -1, /*transpose_a_alloc=*/false, - // /*expect_aten_eval=*/true); + // LinearOp to begin with + // run(1, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - // Multiple batch dims in input + // Batch dims in input + // TODO: This is a single batch dim in the input. 
This fails currently run(3, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); - run(4, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + // TODO: We don't yet support multiple batch dims in matmul scheduler + run(4, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); // Bias cases run(2, 2, 0, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); run(2, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); // Undocumented 2D bias support // TODO: Currently failing in propagateBoundValuesThroughExactMaps, indicating - // possible PairwiseRootDomainMap issue? + // possible PairwiseRootDomainMap issue for 2D bias? // run(2, 2, 2, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); run(3, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); - run(4, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + // TODO: We don't yet support multiple batch dims in matmul scheduler + run(4, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); } // Check that we determine A and B properly when they are swapped as inputs to From 9e4f70cbe6db61de0459113b8005ebbb34d377d8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 15:47:28 +0000 Subject: [PATCH 27/52] Cover more cases in matmul node test --- tests/cpp/test_combine_mul_sum.cpp | 101 +++++++++++++++++++++++------ 1 file changed, 81 insertions(+), 20 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 1355df30fca..105fbd5fa84 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -274,16 +274,23 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { // Test that a simple matmul fusion is picked up by the appropriate scheduler // and the translation to MmaOp is performed properly. 
TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { - const auto run = [&](bool transpose_a_alloc, bool expect_aten_eval) { - int M = 504, N = 136, K = 248; + const auto run = [&](int64_t A_dim, + int64_t B_dim, + bool transpose_a_alloc, + bool expect_segmented, + ScheduleHeuristic expected_heuristic) { + int batch_size = 3, M = 504, N = 136, K = 248; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); + auto tv0 = makeContigTensor(A_dim, DataType::Half); + auto tv1 = makeContigTensor(B_dim, DataType::Half); - if (transpose_a_alloc) { - tv0->setAllocationDomain({tv0->axis(1), tv0->axis(0)}, true); + if (transpose_a_alloc && A_dim > 1) { + std::vector alloc = tv0->getMaybeAllocationDomain(); + alloc[alloc.size() - 1] = tv0->axis(-2); + alloc[alloc.size() - 2] = tv0->axis(-1); + tv0->setAllocationDomain(alloc, true); } fusion->addInput(tv0); @@ -301,8 +308,20 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); + std::vector A_shape(A_dim, batch_size); + A_shape[A_dim - 1] = K; + if (A_dim > 1) { + A_shape[A_dim - 2] = M; + } + at::Tensor t0 = at::randn(A_shape, options); + std::vector B_shape(B_dim, batch_size); + if (B_dim > 1) { + B_shape[B_dim - 2] = K; + B_shape[B_dim - 1] = N; + } else { + B_shape[B_dim - 1] = K; + } + auto t1 = at::randn(B_shape, options); if (transpose_a_alloc) { t0 = t0.as_strided({M, K}, {1, M}); } @@ -315,7 +334,7 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { executor_cache.getMostRecentKernelRuntime(); ASSERT_NE(runtime, nullptr); - if (expect_aten_eval) { + if (expect_segmented) { EXPECT_TRUE(runtime->isSegmented()); } else { EXPECT_FALSE(runtime->isSegmented()); @@ -323,13 
+342,9 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { ScheduleHeuristic heuristic = runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); - if (expect_aten_eval) { - EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); - } else { - // Ensure that the Matmul scheduler ran. - // Assert here since we will inspect the kernel next, which we can't do if - // ExprEval accepts the segment. - ASSERT_EQ(heuristic, ScheduleHeuristic::Matmul); + EXPECT_EQ(heuristic, expected_heuristic); + + if (heuristic == ScheduleHeuristic::Matmul) { // Ensure there's an MmaOp. EXPECT_FALSE( ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) @@ -346,14 +361,60 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { // Run the test with and without matmul fusion enabled EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); - run(/*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - run(/*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); + run(2, + 2, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); + run(2, + 2, + /*transpose_a_alloc=*/true, + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); // Allow Matmul Scheduler to fuse MatmulOp and LinearOp EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); - run(/*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + run(2, + 2, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/false, + ScheduleHeuristic::Matmul); // We cannot yet handle allocation domain in matmul scheduler - run(/*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); + run(2, + 2, + /*transpose_a_alloc=*/true, + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); + + // Size-1 input combinations + run(1, + 2, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); + run(2, + 1, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); + // We fuse this case 
using the Reduction scheduler + run(1, + 1, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/false, + ScheduleHeuristic::Reduction); + + // Batch dims + run(3, + 1, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); + run(3, + 2, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/false, + ScheduleHeuristic::Matmul); } // Test that a simple linear op fusion is picked up by the appropriate scheduler From 09d3945ab586e4b2f195884682063c949888d20e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 16:03:47 +0000 Subject: [PATCH 28/52] Fix up batch cases by adding outer broadcast dims --- csrc/scheduler/mma_utils.cpp | 7 +++++++ tests/cpp/test_combine_mul_sum.cpp | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index d38bc7ca0e2..db50239cebe 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -1585,6 +1586,12 @@ MmaOp* MatmulPattern::translateToMmaOp() { TensorView* Btrans = transpose(B); A = unsqueeze(A, -2); B = unsqueeze(Btrans, -3); + // A and B might have different dimensions. If so, broadcast the smaller one + // up to the size of the larger. 
+ int64_t out_dims = std::max(A->nDims(), B->nDims()); + // Add new outer broadcast dimensions if necessary + A = ops::maybe_broadcast_inner_to_rank(A, out_dims); + B = ops::maybe_broadcast_inner_to_rank(B, out_dims); fms = fusedMultiplySum(A, B, {-1}); mma_op = fms->definition()->as(); } else { diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 105fbd5fa84..06b66b85d56 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -274,6 +274,12 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { // Test that a simple matmul fusion is picked up by the appropriate scheduler // and the translation to MmaOp is performed properly. TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { + // The allocation domain propagation pass sets the output allocation domain, + // which sometimes causes the matmul scheduler to decline the whole fusion + // when it could compile it otherwise. + preseg_passes::OptimizationPassGuard + alloc_pass_guard(false); + const auto run = [&](int64_t A_dim, int64_t B_dim, bool transpose_a_alloc, @@ -410,11 +416,28 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { /*transpose_a_alloc=*/false, /*expect_segmented=*/true, ScheduleHeuristic::ExprEval); + run(3, + 3, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/false, + ScheduleHeuristic::Matmul); run(3, 2, /*transpose_a_alloc=*/false, /*expect_segmented=*/false, ScheduleHeuristic::Matmul); + run(2, + 3, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/false, + ScheduleHeuristic::Matmul); + // TODO: More than one batch dimension is not yet supported in Matmul + // scheduler + run(4, + 3, + /*transpose_a_alloc=*/false, + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); } // Test that a simple linear op fusion is picked up by the appropriate scheduler From f5c4f36fd9d7a266e4e7b6da5d55f56ef66585c9 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 16:42:17 +0000 
Subject: [PATCH 29/52] Fix some cases and add comments --- csrc/scheduler/matmul_utils.cpp | 43 +++++++++++++++++++++++------- csrc/scheduler/mma_utils.cpp | 2 +- tests/cpp/test_combine_mul_sum.cpp | 22 +++++++++------ 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 179df3f0300..1b1c71e6afa 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -236,12 +236,22 @@ std::string isMatmulFusionDefinitionSupported( // Fusion topology check { + // We will check that all operands have same dimension + int64_t operand_dim = -1; + auto entry = roles_map.find(MatmulRole::INPUT_A); std::set tvs_with_roles; if (entry != roles_map.end()) { if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { tvs_with_roles.insert(entry->second.begin(), entry->second.end()); + for (TensorView* tv : entry->second) { + if (operand_dim == -1) { + operand_dim = tv->nDims(); + } else if (tv->nDims() != operand_dim) { + return "All A operands must have the same dimension."; + } + } } else { return "There is more than a single fusion input that can be MMA first input"; } @@ -253,6 +263,21 @@ std::string isMatmulFusionDefinitionSupported( if (entry != roles_map.end()) { if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { tvs_with_roles.insert(entry->second.begin(), entry->second.end()); + for (TensorView* tv : entry->second) { + if (operand_dim == -1) { + operand_dim = tv->nDims(); + } else if (tv->nDims() != operand_dim) { + // We cannot always handle differently sized inputs, such as those + // we encounter when translating MatmulOp and LinearOp. This is + // because in those cases one of the operands will have new + // Broadcast dimensions where the other operand has Iteration + // batch dimensions, meaning these new dims are actually M or N + // dimensions. 
Multiple M and N dimension support is planned but for + // now we must reject these patterns before attempting to translate + // them. + return "All A and B operands must have the same dimension."; + } + } } else { return "There is more than a single fusion input that can be MMA second input"; } @@ -467,8 +492,8 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // 2. Check if matmul scheduler is enabled // 3. Check if inputs to the mma op or mul sum pair match any of // supported inputs layout - // 4. Check if the input layout for the matmul pattern can be determined - // 5. Check if fusion represents expressions that are recognized by matmul + // 4. Check if fusion represents expressions that are recognized by matmul + // 5. Check if the input layout for the matmul pattern can be determined // scheduler. // #0 @@ -527,13 +552,6 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { mma_utils::RolesMap roles_map = roles_map_opt.getData(); // #4 - const auto input_layout_opt = - mma_utils::getProblemLayout(id_model, id_roles, roles_map); - if (!input_layout_opt.isValid()) { - return input_layout_opt.getErrorMsg(); - } - - // #5 { auto support_status = isMatmulFusionDefinitionSupported( fusion, patterns.front(), roles_map, id_roles); @@ -542,6 +560,13 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { } } + // #5 + const auto input_layout_opt = + mma_utils::getProblemLayout(id_model, id_roles, roles_map); + if (!input_layout_opt.isValid()) { + return input_layout_opt.getErrorMsg(); + } + return ""; } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index db50239cebe..ed61404c0dc 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1583,7 +1583,7 @@ MmaOp* MatmulPattern::translateToMmaOp() { NVF_ERROR( A->nDims() > 1 && B->nDims() > 1, "Cannot translate MatmulOp with 1D input"); - TensorView* Btrans = transpose(B); + TensorView* Btrans = transpose(B, -2, -1); A = unsqueeze(A, 
-2); B = unsqueeze(Btrans, -3); // A and B might have different dimensions. If so, broadcast the smaller one diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 06b66b85d56..5c385ac1ed7 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -421,20 +421,23 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { /*transpose_a_alloc=*/false, /*expect_segmented=*/false, ScheduleHeuristic::Matmul); + // TODO: mixed length inputs via broadcasted batch dims + // We currently reject differently-sized inputs since these translate to + // multiple M or N dims run(3, 2, /*transpose_a_alloc=*/false, - /*expect_segmented=*/false, - ScheduleHeuristic::Matmul); + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); run(2, 3, /*transpose_a_alloc=*/false, - /*expect_segmented=*/false, - ScheduleHeuristic::Matmul); + /*expect_segmented=*/true, + ScheduleHeuristic::ExprEval); // TODO: More than one batch dimension is not yet supported in Matmul // scheduler run(4, - 3, + 4, /*transpose_a_alloc=*/false, /*expect_segmented=*/true, ScheduleHeuristic::ExprEval); @@ -582,8 +585,10 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { // run(1, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); // Batch dims in input - // TODO: This is a single batch dim in the input. This fails currently - run(3, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + // TODO: mixed length inputs via broadcasted batch dims + // We currently reject differently-sized inputs since these translate to + // multiple M or N dims + run(3, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); // TODO: We don't yet support multiple batch dims in matmul scheduler run(4, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); @@ -595,7 +600,8 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { // possible PairwiseRootDomainMap issue for 2D bias? 
// run(2, 2, 2, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); - run(3, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); + // TODO: Mixed-length inputs are rejected + run(3, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); // TODO: We don't yet support multiple batch dims in matmul scheduler run(4, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); } From 5c0d31fd8c508c4b7e312c5fde3b7434296f08ec Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 16:51:36 +0000 Subject: [PATCH 30/52] Fix comment --- csrc/scheduler/mma_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index ed61404c0dc..8043344bcc9 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1546,7 +1546,7 @@ MmaOp* MatmulPattern::translateToMmaOp() { // and refuse to translate if dimension is 1. // - weight can be one or two dimensional. We refuse to translate if // dimension is 1. - // - bias, if present, can be zero or two dimensional. Bias can only be + // - bias, if present, can be zero or one dimensional. Bias can only be // present if weight is 2D // // We translate by broadcasting input, weight, and bias such that the From 18c0f4634483973b23af1cb89dc6821006646756 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 16:53:26 +0000 Subject: [PATCH 31/52] Fix faulty merge --- csrc/scheduler/mma_utils.h | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index b41f5b3bc0e..68ff171eeb9 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -299,30 +299,6 @@ using DependenciesMap = std::map; //! transposition of inputs in mma instructions, while other (e.g. Turing, //! Ampere) the only supported transposition is TN which means that mma //! instruction first input is transposed, the second input is non-transposed. 
-NVF_API MatmulProblemLayoutOpt -getProblemLayout(Fusion* fusion, const MatmulPattern& pattern); - -//! This overloaded version is just a wrapper on the above function, where -//! the MatmulPattern is extracted from the fusion. -NVF_API MatmulProblemLayoutOpt getProblemLayout(Fusion* fusion); - -//! Determine the problem layout based on allocation domain of inputs. This is -//! called by the above overloads. -NVF_API MatmulProblemLayoutOpt getProblemLayout( - const IdModel& id_model, - const std::unordered_map& group_to_domain, - const RolesMap& roles_map); - -//! Returns wrapped collection of IterDomains that can be used to get -//! problem shape with runtime info. -//! Data is stored in the order in which lables are defined in MatmulDomain -//! enum class, that is in the following order: m, n, k. -//! An error message is stored in retruned object if valid data cannot -//! be gathered. -//! TODO: 4th domain must be added for batch gemm support. -ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion); -ProblemIterDomainsOpt getProblemIterDomains(const MatmulPattern& pattern); - NVF_API MatmulProblemLayoutOpt getProblemLayout( const IdModel& id_model, const std::unordered_map& dim_roles, From b52cb1edb524f2ddf522cf5f9f8de79f18d9728f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 19:07:44 +0000 Subject: [PATCH 32/52] Remove undocumented 2D bias test case --- tests/cpp/test_combine_mul_sum.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 5c385ac1ed7..dbb52e3bbb2 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -595,10 +595,6 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { // Bias cases run(2, 2, 0, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); run(2, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); - // Undocumented 2D bias support - // TODO: Currently failing in 
propagateBoundValuesThroughExactMaps, indicating - // possible PairwiseRootDomainMap issue for 2D bias? - // run(2, 2, 2, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); // TODO: Mixed-length inputs are rejected run(3, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); From bb0d619805ab8ff7e23cca5993e9be69faf1ece0 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 23 May 2024 19:11:42 +0000 Subject: [PATCH 33/52] Fix gcc build error --- csrc/scheduler/mma_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 8043344bcc9..d2b6586089b 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1572,7 +1572,7 @@ MmaOp* MatmulPattern::translateToMmaOp() { if (bias != nullptr) { fms = add(fms, bias); } - } else if (auto mop = dynamic_cast(output->definition())) { + } else if (output->definition()->isA()) { // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we // must transpose B then broadcast both operands before creating the final // op. 
From d0c5bed45ea674f2301b7893004d3c7674b491b6 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 14:55:08 +0000 Subject: [PATCH 34/52] Loop over operand roles to reduce code duplication --- csrc/scheduler/matmul_utils.cpp | 64 +++++++++++++-------------------- 1 file changed, 25 insertions(+), 39 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 6245004e0ea..96e0095d26f 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -239,53 +239,39 @@ std::string isMatmulFusionDefinitionSupported( // We will check that all operands have same dimension int64_t operand_dim = -1; - auto entry = roles_map.find(MatmulRole::INPUT_A); + // Track TensorViews with assigned roles so we can check that all inputs and + // outputs have recognized roles std::set tvs_with_roles; - if (entry != roles_map.end()) { - if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { - tvs_with_roles.insert(entry->second.begin(), entry->second.end()); - for (TensorView* tv : entry->second) { - if (operand_dim == -1) { - operand_dim = tv->nDims(); - } else if (tv->nDims() != operand_dim) { - return "All A operands must have the same dimension."; - } - } - } else { - return "There is more than a single fusion input that can be MMA first input"; - } - } else { - return "No candidate in fusion inputs for MMA first input"; - } - - entry = roles_map.find(MatmulRole::INPUT_B); - if (entry != roles_map.end()) { - if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { - tvs_with_roles.insert(entry->second.begin(), entry->second.end()); - for (TensorView* tv : entry->second) { - if (operand_dim == -1) { - operand_dim = tv->nDims(); - } else if (tv->nDims() != operand_dim) { - // We cannot always handle differently sized inputs, such as those - // we encounter when translating MatmulOp and LinearOp. 
This is - // because in those cases one of the operands will have new - // Broadcast dimensions where the other operand has Iteration - // batch dimensions, meaning these new dims are actually M or N - // dimensions. Multiple M and N dimension support is planned but for - // now we must reject these patterns before attempting to translate - // them. - return "All A and B operands must have the same dimension."; + for (MatmulRole role : {MatmulRole::INPUT_A, MatmulRole::INPUT_B}) { + auto entry = roles_map.find(role); + if (entry != roles_map.end()) { + if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { + tvs_with_roles.insert(entry->second.begin(), entry->second.end()); + for (TensorView* tv : entry->second) { + if (operand_dim == -1) { + operand_dim = tv->nDims(); + } else if (tv->nDims() != operand_dim) { + // We cannot always handle differently sized inputs, such as those + // we encounter when translating MatmulOp and LinearOp. This is + // because in those cases one of the operands will have new + // Broadcast dimensions where the other operand has Iteration + // batch dimensions, meaning these new dims are actually M or N + // dimensions. Multiple M and N dimension support is planned but + // for now we must reject these patterns before attempting to + // translate them. 
+ return "All operands must have the same dimension."; + } } + } else { + return "There is more than a single fusion input that can be MMA operand "; } } else { - return "There is more than a single fusion input that can be MMA second input"; + return "No candidate in fusion inputs for MMA operand"; } - } else { - return "No candidate in fusion inputs for MMA second input"; } - entry = roles_map.find(MatmulRole::OUTPUT_D); + auto entry = roles_map.find(MatmulRole::OUTPUT_D); if (entry != roles_map.end()) { tvs_with_roles.insert(entry->second.begin(), entry->second.end()); } else { From 384f0364f61ff18321a23d4512dca231ca485dad Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 14:58:54 +0000 Subject: [PATCH 35/52] Fix comment about enabling matmul op --- csrc/scheduler/matmul_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 96e0095d26f..7574c1ed72f 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -469,7 +469,7 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // The plan: // 0. Check if the current CUDA device is supported // 1. Check if there is exactly one matmul pattern defined in the fusion. - // 2. Check if matmul scheduler is enabled + // 2. Check if fusion of MatmulOp and LinearOp is enabled, if applicable // 3. Check if inputs to the mma op or mul sum pair match any of // supported inputs layout // 4. 
Check if fusion represents expressions that are recognized by matmul From 47cc50e8d6d69834a71c541957301994fe25b849 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 29 May 2024 10:59:25 -0400 Subject: [PATCH 36/52] Update csrc/scheduler/matmul_utils.cpp Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com> --- csrc/scheduler/matmul_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 7574c1ed72f..9ec42a16795 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -470,7 +470,7 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // 0. Check if the current CUDA device is supported // 1. Check if there is exactly one matmul pattern defined in the fusion. // 2. Check if fusion of MatmulOp and LinearOp is enabled, if applicable - // 3. Check if inputs to the mma op or mul sum pair match any of + // 3. Check if inputs to the matmul pattern match any of // supported inputs layout // 4. Check if fusion represents expressions that are recognized by matmul // 5. 
Check if the input layout for the matmul pattern can be determined From 1e55592661a495e40d696c8c99971c600562d628 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 15:15:32 +0000 Subject: [PATCH 37/52] Reject dtypes other than Half or BFloat16 --- csrc/scheduler/matmul_utils.cpp | 6 ++++ tests/cpp/test_matmul_scheduler.cpp | 49 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 9ec42a16795..ad728f1f3fb 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -512,6 +512,12 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { if (pattern.A->nDims() < 2 || pattern.B->nDims() < 2) { return "Cannot fuse matrix-vector products"; } + for (TensorView* operand : {pattern.A, pattern.B}) { + if (operand->dtype() != DataType::Half && + operand->dtype() != DataType::BFloat16) { + return "Unsupported operand type. Operands must be fp16 or bf16"; + } + } } } } diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 5ffb26d40e9..cbc5c1fe818 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -2685,6 +2685,55 @@ TEST_F(NVFuserTest, SegmentLinearOpPrologue) { testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); } +// Test that the matmul scheduler refuses to translate a matmul that is not +// Half or BFloat16 +TEST_F(NVFuserTest, SegmentMatmulOpUnsupportedDtype) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // A - tv0, B - tv1, C - tv2 + auto tv0 = makeContigTensor(2, DataType::Float); + auto tv1 = makeContigTensor(2, DataType::Float); + fusion->addInput(tv0); + fusion->addInput(tv1); + + // Prologue prevents ExprEval scheduler from accepting. If Matmul scheduler + // rejects, then Pointwise must not accept this unsegmented fusion. 
+ tv1 = castOp(DataType::Float, sin(tv1)); + + auto tv2 = matmul(tv0, tv1); + + fusion->addOutput(tv2); + + NVF_CHECK( + ir_utils::getOpsOfType(fusion.get()).size() == 1, + "matmul fusion must have at least one MmaOp"); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + + // Enable MatmulOp fusion, which should reject because float operands are not + // supported. + EnableOptionsGuard eog; + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + + EXPECT_TRUE(runtime->isSegmented()); + + testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); +} + // This test can be used to check that an external plugin has been loaded. It // is DISABLED_ so that the test suite will pass even if the user has not // provided a plugin via NVFUSER_MATMUL_HEURISTIC_PLUGIN. To check that a From 06f675e7fe0dfa647e3b31be51c1315795fdf090 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 29 May 2024 11:17:10 -0400 Subject: [PATCH 38/52] Update csrc/scheduler/mma_utils.cpp Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com> --- csrc/scheduler/mma_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 11bd7c0fd0d..30f3d20c3fc 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1569,7 +1569,7 @@ MmaOp* MatmulPattern::translateToMmaOp() { // op. // // Also note that the output of MatmulOp is a tensor of shape [..., M, N] - // whose matches that of the inputs. 
We will most commonly then also need to + // whose dtype matches that of the inputs. We will most commonly then also need to // cast the output of the MmaOp to produce the output TensorView. NVF_ERROR( A->nDims() > 1 && B->nDims() > 1, From 8fa30eccc573acf4233908fd4d118901e3e4c281 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 15:52:26 +0000 Subject: [PATCH 39/52] Use mapMatmulOpIterDomains --- csrc/scheduler/mma_utils.cpp | 55 +++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 30f3d20c3fc..ac5caf1dbb8 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1634,29 +1634,38 @@ std::unordered_map MatmulPattern::getDimRoles( // torch.matmul has a single M, N, and K dimension and 0 or more batch // dimensions. dim_roles[exact_graph.toGroup(A->axis(-1))] = MatmulDomain::K; - NVF_ERROR(A->nDims() > 0 && B->nDims() > 0); - size_t m_and_k_dims = 0; - if (A->nDims() == 1 && B->nDims() == 1) { - NVF_ERROR( - false, "MatmulOp node should not be created when both inputs are 1D"); - } else if (A->nDims() == 1) { - // Missing M dimension - dim_roles[exact_graph.toGroup(B->axis(-1))] = MatmulDomain::N; - m_and_k_dims = 1; - } else if (B->nDims() == 1) { - // Missing N dimension - dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; - m_and_k_dims = 1; - } else { - // Both A and B are at least 2D - dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; - dim_roles[exact_graph.toGroup(B->axis(-1))] = MatmulDomain::N; - m_and_k_dims = 2; - } - // Skip one dimension for the reduction axis in the output - for (size_t i : c10::irange(output->nDims() - 1 - m_and_k_dims)) { - dim_roles[exact_graph.toGroup(output->axis((int64_t)i))] = - MatmulDomain::Batch; + + // Map output dims to inputs + const std::vector& out_logical = output->getRFactorDomain(); + const std::vector& mapping_a = ops::mapMatmulOpIterDomains( 
+ A->getRFactorDomain(), MatmulRole::INPUT_A, out_logical.size()); + const std::vector& mapping_b = ops::mapMatmulOpIterDomains( + B->getRFactorDomain(), MatmulRole::INPUT_B, out_logical.size()); + + NVF_ERROR(mapping_a.size() == out_logical.size()); + NVF_ERROR(mapping_a.size() == mapping_b.size()); + for (size_t i : c10::irange(out_logical.size())) { + IterDomain* id_out = out_logical[i]; + ValGroup g = exact_graph.toGroup(id_out); + + if (id_out->isReduction()) { + dim_roles[g] = MatmulDomain::K; + continue; + } + + bool has_a = mapping_a[i] != nullptr && mapping_a[i]->isIteration(); + bool has_b = mapping_b[i] != nullptr && mapping_b[i]->isIteration(); + + NVF_ERROR(has_a || has_b); + // If both operand IterDomains are Broadcast, treat as Batch dimension + // If they mismatch, then one must be broadcast which determines M or N + if (has_a == has_b) { + dim_roles[g] = MatmulDomain::Batch; + } else if (has_a) { + dim_roles[g] = MatmulDomain::M; + } else if (has_b) { + dim_roles[g] = MatmulDomain::N; + } } return dim_roles; } else if (output->definition()->isA()) { From e19c8b9dc30db53e4e13dffaed4c0cd674fe55e6 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 16:03:32 +0000 Subject: [PATCH 40/52] Use map*OpIterDomains to simplify getDimRoles --- csrc/scheduler/mma_utils.cpp | 122 ++++++++++++++++------------------- 1 file changed, 55 insertions(+), 67 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index ac5caf1dbb8..b9caf5a2159 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1612,6 +1612,44 @@ MmaOp* MatmulPattern::translateToMmaOp() { return mma_op; } +namespace { +// Determine dim roles for either a MatmulOp or a LinearOp, given IterDomain +// mappings +std::unordered_map matmulOrLinearOpDimRoles( + const ValGraph& exact_graph, + const std::vector& out_logical, + const std::vector& mapping_a, + const std::vector& mapping_b) { + std::unordered_map dim_roles; + 
NVF_ERROR(mapping_a.size() == out_logical.size()); + NVF_ERROR(mapping_a.size() == mapping_b.size()); + for (size_t i : c10::irange(out_logical.size())) { + IterDomain* id_out = out_logical[i]; + const ValGroup& g = exact_graph.toGroup(id_out); + + if (id_out->isReduction()) { + dim_roles[g] = MatmulDomain::K; + continue; + } + + bool has_a = mapping_a[i] != nullptr && mapping_a[i]->isIteration(); + bool has_b = mapping_b[i] != nullptr && mapping_b[i]->isIteration(); + + NVF_ERROR(has_a || has_b); + // If both operand IterDomains are Broadcast, treat as Batch dimension + // If they mismatch, then one must be broadcast which determines M or N + if (has_a == has_b) { + dim_roles[g] = MatmulDomain::Batch; + } else if (has_a) { + dim_roles[g] = MatmulDomain::M; + } else if (has_b) { + dim_roles[g] = MatmulDomain::N; + } + } + return dim_roles; +} +} // namespace + std::unordered_map MatmulPattern::getDimRoles( IdModel& id_model) const { id_model.maybeBuildGraph(IdMappingMode::EXACT); @@ -1627,81 +1665,31 @@ std::unordered_map MatmulPattern::getDimRoles( // If there are other patterns, for example a ValGroup present in only A, then // we should raise an exception here. - std::unordered_map dim_roles; - if (output->definition()->isA()) { - // Special case for MatmulOp - // torch.matmul has a single M, N, and K dimension and 0 or more batch - // dimensions. 
- dim_roles[exact_graph.toGroup(A->axis(-1))] = MatmulDomain::K; - - // Map output dims to inputs const std::vector& out_logical = output->getRFactorDomain(); - const std::vector& mapping_a = ops::mapMatmulOpIterDomains( - A->getRFactorDomain(), MatmulRole::INPUT_A, out_logical.size()); - const std::vector& mapping_b = ops::mapMatmulOpIterDomains( - B->getRFactorDomain(), MatmulRole::INPUT_B, out_logical.size()); - - NVF_ERROR(mapping_a.size() == out_logical.size()); - NVF_ERROR(mapping_a.size() == mapping_b.size()); - for (size_t i : c10::irange(out_logical.size())) { - IterDomain* id_out = out_logical[i]; - ValGroup g = exact_graph.toGroup(id_out); - - if (id_out->isReduction()) { - dim_roles[g] = MatmulDomain::K; - continue; - } + return matmulOrLinearOpDimRoles( + exact_graph, + out_logical, + ops::mapMatmulOpIterDomains( + A->getRFactorDomain(), MatmulRole::INPUT_A, out_logical.size()), + ops::mapMatmulOpIterDomains( + B->getRFactorDomain(), MatmulRole::INPUT_B, out_logical.size())); - bool has_a = mapping_a[i] != nullptr && mapping_a[i]->isIteration(); - bool has_b = mapping_b[i] != nullptr && mapping_b[i]->isIteration(); - - NVF_ERROR(has_a || has_b); - // If both operand IterDomains are Broadcast, treat as Batch dimension - // If they mismatch, then one must be broadcast which determines M or N - if (has_a == has_b) { - dim_roles[g] = MatmulDomain::Batch; - } else if (has_a) { - dim_roles[g] = MatmulDomain::M; - } else if (has_b) { - dim_roles[g] = MatmulDomain::N; - } - } - return dim_roles; } else if (output->definition()->isA()) { - // Special case for LinearOp - // torch.matmul has a single M, N, and K dimension and 0 or more batch - // dimensions. 
The batch dimensions are only present in A - dim_roles[exact_graph.toGroup(A->axis(-1))] = MatmulDomain::K; - NVF_ERROR(A->nDims() > 0 && B->nDims() > 0); - size_t m_and_k_dims = 0; - if (A->nDims() == 1 && B->nDims() == 1) { - NVF_ERROR( - false, "MatmulOp node should not be created when both inputs are 1D"); - } else if (A->nDims() == 1) { - // Missing M dimension - dim_roles[exact_graph.toGroup(B->axis(-2))] = MatmulDomain::N; - m_and_k_dims = 1; - } else if (B->nDims() == 1) { - // Missing N dimension - dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; - m_and_k_dims = 1; - } else { - // Both A and B are at least 2D - dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; - dim_roles[exact_graph.toGroup(B->axis(-2))] = MatmulDomain::N; - m_and_k_dims = 2; - } - // Skip one dimension for the reduction axis in the output - for (size_t i : c10::irange(output->nDims() - 1 - m_and_k_dims)) { - dim_roles[exact_graph.toGroup(output->axis((int64_t)i))] = - MatmulDomain::Batch; - } - return dim_roles; + const std::vector& out_logical = output->getRFactorDomain(); + return matmulOrLinearOpDimRoles( + exact_graph, + out_logical, + ops::mapLinearOpIterDomains( + A->getRFactorDomain(), MatmulRole::INPUT_A, out_logical.size()), + ops::mapLinearOpIterDomains( + B->getRFactorDomain(), MatmulRole::INPUT_B, out_logical.size())); } // The code below handles MmaOp or mul-sum patterns + std::unordered_map dim_roles; + // Indicates whether a ValGroup is present in A (bit 0), B (bit 1), or output // (bit 2) using DimPresence = std::bitset<3>; From 520b709e65731caf71d92f2aa1526de47ebebb61 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 16:27:39 +0000 Subject: [PATCH 41/52] Add Llama2FFN test --- tests/cpp/test_matmul_scheduler.cpp | 67 ++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index cbc5c1fe818..671e6ec17c4 100644 --- 
a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -2625,7 +2625,7 @@ TEST_F(NVFuserTest, SegmentMatmulOpPrologue) { NVF_CHECK( ir_utils::getOpsOfType(fusion.get()).size() == 1, - "matmul fusion must have at least one MmaOp"); + "matmul fusion must have at least one MatmulOp"); FusionExecutorCache executor_cache(std::move(fusion)); @@ -2692,7 +2692,6 @@ TEST_F(NVFuserTest, SegmentMatmulOpUnsupportedDtype) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // A - tv0, B - tv1, C - tv2 auto tv0 = makeContigTensor(2, DataType::Float); auto tv1 = makeContigTensor(2, DataType::Float); fusion->addInput(tv0); @@ -2706,10 +2705,6 @@ TEST_F(NVFuserTest, SegmentMatmulOpUnsupportedDtype) { fusion->addOutput(tv2); - NVF_CHECK( - ir_utils::getOpsOfType(fusion.get()).size() == 1, - "matmul fusion must have at least one MmaOp"); - FusionExecutorCache executor_cache(std::move(fusion)); const int M = 504, N = 136, K = 248; @@ -2748,6 +2743,66 @@ TEST_F(MatmulSchedulerTest, DISABLED_RequireExternalPlugin) { MatmulParams params; } +// Test that we can segment a Fusion containing two matmuls +TEST_F(MatmulSchedulerTest, Llama2FFN) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + + for (bool enable_fusion : {false, true}) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + auto tv2 = makeContigTensor(2, DataType::Half); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + auto tv3 = matmul(tv0, tv1); + auto tv4 = matmul(tv0, tv2); + + // silu + auto tv5 = mul(sigmoid(tv3), tv3); + + auto tv6 = mul(tv5, tv4); + + fusion->addOutput(tv6); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto t0 = at::randn({M, K}, options); + auto t1 = 
at::randn({K, N}, options); + auto t2 = at::randn({K, N}, options); + std::vector inputs{t0, t1, t2}; + + EnableOptionsGuard eog; + if (enable_fusion) { + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); + } + + auto outputs = executor_cache.runFusionWithInputs(inputs); + + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + + EXPECT_TRUE(runtime->isSegmented()); + + if (enable_fusion) { + EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2); + } else { + EXPECT_EQ(runtime->fusionSegments()->groups().size(), 3); + } + } +} + #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace nvfuser From 9ae0769b1f22110b831ad45b40fbae18a4b9b2ab Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 16:27:47 +0000 Subject: [PATCH 42/52] clang-format --- csrc/scheduler/mma_utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index b9caf5a2159..d3555adb08b 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1569,8 +1569,8 @@ MmaOp* MatmulPattern::translateToMmaOp() { // op. // // Also note that the output of MatmulOp is a tensor of shape [..., M, N] - // whose dtype matches that of the inputs. We will most commonly then also need to - // cast the output of the MmaOp to produce the output TensorView. + // whose dtype matches that of the inputs. We will most commonly then also + // need to cast the output of the MmaOp to produce the output TensorView. 
NVF_ERROR( A->nDims() > 1 && B->nDims() > 1, "Cannot translate MatmulOp with 1D input"); From 0fb44e4267c7981e6d8976c06828713b50aa8ac3 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 16:31:10 +0000 Subject: [PATCH 43/52] Remove 2D bias check in linear test --- tests/cpp/test_combine_mul_sum.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 670ebebbe94..72fdcdb0773 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -517,8 +517,6 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { bias = at::randn({}, options); } else if (bias_dim == 1) { bias = at::randn({N}, options); - } else if (bias_dim == 2) { - bias = at::randn({M, N}, options); } else { NVF_ERROR(false, "Invalid bias dimension given:", bias_dim); } From fe8898ed6af11a53f6fec0ef56aaf58be7ce44dd Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 17:24:01 +0000 Subject: [PATCH 44/52] Parametrize LinearOp node translation test --- tests/cpp/test_combine_mul_sum.cpp | 285 ++++++++++++++++------------- 1 file changed, 155 insertions(+), 130 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 72fdcdb0773..fa69f32b844 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -443,163 +443,188 @@ TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { ScheduleHeuristic::ExprEval); } +using LinearNodeTranslationTestParams = + std::tuple; +using LinearNodeTranslationTest = + NVFuserFixtureParamTest; + // Test that a simple linear op fusion is picked up by the appropriate scheduler // and the translation to MmaOp is performed properly. 
-TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerLinearNode) { +TEST_P(LinearNodeTranslationTest, AutomaticSchedulerLinearNode) { // The allocation domain propagation pass sets the output allocation domain, // which sometimes causes the matmul scheduler to decline the whole fusion // when it could compile it otherwise. preseg_passes::OptimizationPassGuard alloc_pass_guard(false); + int64_t A_dim = std::get<0>(GetParam()); + int64_t B_dim = std::get<1>(GetParam()); + int64_t bias_dim = std::get<2>(GetParam()); + bool enable_fusion = std::get<3>(GetParam()); + bool transpose_a_alloc = std::get<4>(GetParam()); + bool expect_aten_eval = std::get<5>(GetParam()); - const auto run = [&](int64_t A_dim, - int64_t B_dim, - int64_t bias_dim, - bool transpose_a_alloc, - bool expect_aten_eval) { - int batch_size = 3, M = 504, N = 136, K = 248; - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeContigTensor(A_dim, DataType::Half); - auto tv1 = makeContigTensor(B_dim, DataType::Half); - - if (transpose_a_alloc && A_dim > 1) { - std::vector alloc = tv0->getMaybeAllocationDomain(); - alloc[alloc.size() - 1] = tv0->axis(-2); - alloc[alloc.size() - 2] = tv0->axis(-1); - tv0->setAllocationDomain(alloc, true); - } + // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it + // enabled + DisableOptionsGuard dog; + DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); - fusion->addInput(tv0); - fusion->addInput(tv1); + EnableOptionsGuard eog; + if (enable_fusion) { + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); + } - TensorView* tv2 = nullptr; - if (bias_dim >= 0) { - // bias_dim = -1 indicates we should not use any bias argument - auto bias = makeContigTensor(bias_dim, DataType::Half); - fusion->addInput(bias); - tv2 = linear(tv0, tv1, bias); - } else { - tv2 = linear(tv0, tv1); - } + int batch_size = 3, M = 504, N = 
136, K = 248; + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); - // add an epilogue - auto tv3 = sin(tv2); - auto tv4 = castOp(DataType::Half, tv3); + auto tv0 = makeContigTensor(A_dim, DataType::Half); + auto tv1 = makeContigTensor(B_dim, DataType::Half); - fusion->addOutput(tv4); + if (transpose_a_alloc && A_dim > 1) { + std::vector alloc = tv0->getMaybeAllocationDomain(); + alloc[alloc.size() - 1] = tv0->axis(-2); + alloc[alloc.size() - 2] = tv0->axis(-1); + tv0->setAllocationDomain(alloc, true); + } - // Verify that we no longer set up MmaOp in matmul() - ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); + fusion->addInput(tv0); + fusion->addInput(tv1); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - std::vector A_shape(A_dim, batch_size); - A_shape[A_dim - 1] = K; - if (A_dim > 1) { - A_shape[A_dim - 2] = M; - } - at::Tensor t0 = at::randn(A_shape, options); - std::vector B_shape(B_dim, batch_size); - B_shape[B_dim - 1] = K; - if (B_dim > 1) { - B_shape[B_dim - 2] = N; - } - auto t1 = at::randn(B_shape, options); - if (transpose_a_alloc) { - t0 = t0.as_strided({M, K}, {1, M}); - } - std::vector inputs{t0, t1}; - at::Tensor tref; - if (bias_dim >= 0) { - at::Tensor bias; - if (bias_dim == 0) { - bias = at::randn({}, options); - } else if (bias_dim == 1) { - bias = at::randn({N}, options); - } else { - NVF_ERROR(false, "Invalid bias dimension given:", bias_dim); - } - inputs.emplace_back(bias); - tref = at::linear(t0, t1, bias); - } else { - tref = at::linear(t0, t1); - } - tref = tref.sin().to(at::kHalf); + TensorView* tv2 = nullptr; + if (bias_dim >= 0) { + // bias_dim = -1 indicates we should not use any bias argument + auto bias = makeContigTensor(bias_dim, DataType::Half); + fusion->addInput(bias); + tv2 = linear(tv0, tv1, bias); + } else { + tv2 = linear(tv0, tv1); + } - FusionExecutorCache executor_cache(std::move(fusion)); - auto outputs = executor_cache.runFusionWithInputs(inputs); + // add an 
epilogue + auto tv3 = sin(tv2); + auto tv4 = castOp(DataType::Half, tv3); - const FusionKernelRuntime* runtime = - executor_cache.getMostRecentKernelRuntime(); - ASSERT_NE(runtime, nullptr); + fusion->addOutput(tv4); - if (expect_aten_eval) { - EXPECT_TRUE(runtime->isSegmented()); - } else { - EXPECT_FALSE(runtime->isSegmented()); - } + // Verify that we no longer set up MmaOp in matmul() + ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); - ScheduleHeuristic heuristic = - runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); - if (expect_aten_eval) { - EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + std::vector A_shape(A_dim, batch_size); + A_shape[A_dim - 1] = K; + if (A_dim > 1) { + A_shape[A_dim - 2] = M; + } + at::Tensor t0 = at::randn(A_shape, options); + std::vector B_shape(B_dim, batch_size); + B_shape[B_dim - 1] = K; + if (B_dim > 1) { + B_shape[B_dim - 2] = N; + } + auto t1 = at::randn(B_shape, options); + if (transpose_a_alloc) { + t0 = t0.as_strided({M, K}, {1, M}); + } + std::vector inputs{t0, t1}; + at::Tensor tref; + if (bias_dim >= 0) { + at::Tensor bias; + if (bias_dim == 0) { + bias = at::randn({}, options); + } else if (bias_dim == 1) { + bias = at::randn({N}, options); } else { - // Ensure that the Matmul scheduler ran. - // Assert here since we will inspect the kernel next, which we can't - // do if ExprEval accepts the segment. - ASSERT_EQ(heuristic, ScheduleHeuristic::Matmul); - // Ensure there's an MmaOp. 
- EXPECT_FALSE( - ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) - .empty()); + NVF_ERROR(false, "Invalid bias dimension given:", bias_dim); } + inputs.emplace_back(bias); + tref = at::linear(t0, t1, bias); + } else { + tref = at::linear(t0, t1); + } + tref = tref.sin().to(at::kHalf); - testValidate( - executor_cache.fusion(), outputs, inputs, {tref}, __LINE__, __FILE__); - }; - // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it - // enabled - DisableOptionsGuard dog; - DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); - EnableOptionsGuard eog; + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs(inputs); - // Run the test with and without matmul fusion enabled - EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); - run(2, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - run(2, 2, -1, /*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + ASSERT_NE(runtime, nullptr); - // Allow Matmul Scheduler to fuse MatmulOp and LinearOp - EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); - run(2, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); - // We cannot yet handle allocation domain in matmul scheduler - run(2, 2, -1, /*transpose_a_alloc=*/true, /*expect_aten_eval=*/true); + if (expect_aten_eval) { + EXPECT_TRUE(runtime->isSegmented()); + } else { + EXPECT_FALSE(runtime->isSegmented()); + } - // Don't fuse 1D inputs - run(1, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - run(2, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - // TODO: The following currently fails but it should not be translated to - // LinearOp to begin with - // run(1, 1, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); + ScheduleHeuristic heuristic = + 
runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); + if (expect_aten_eval) { + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); + } else { + // Ensure that the Matmul scheduler ran. + // Assert here since we will inspect the kernel next, which we can't + // do if ExprEval accepts the segment. + ASSERT_EQ(heuristic, ScheduleHeuristic::Matmul); + // Ensure there's an MmaOp. + EXPECT_FALSE( + ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) + .empty()); + } - // Batch dims in input - // TODO: mixed length inputs via broadcasted batch dims - // We currently reject differently-sized inputs since these translate to - // multiple M or N dims - run(3, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - // TODO: We don't yet support multiple batch dims in matmul scheduler - run(4, 2, -1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - - // Bias cases - run(2, 2, 0, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); - run(2, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/false); - - // TODO: Mixed-length inputs are rejected - run(3, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); - // TODO: We don't yet support multiple batch dims in matmul scheduler - run(4, 2, 1, /*transpose_a_alloc=*/false, /*expect_aten_eval=*/true); + testValidate( + executor_cache.fusion(), outputs, inputs, {tref}, __LINE__, __FILE__); } +INSTANTIATE_TEST_SUITE_P( + , + LinearNodeTranslationTest, + ::testing::Values( + // Tests without fusion enabled + std::make_tuple(2l, 2l, -1l, false, false, true), + std::make_tuple(2l, 2l, -1l, false, true, true), + // Enable fusion + std::make_tuple(2l, 2l, -1l, true, false, false), + // We cannot yet handle allocation domain in matmul scheduler + std::make_tuple(2l, 2l, -1l, true, true, true), + // We don't fuse 1D inputs + std::make_tuple(1l, 2l, -1l, true, false, true), + std::make_tuple(2l, 1l, -1l, true, false, true), + // TODO: The following currently fails but it should 
not be translated + // to LinearOp to begin with + // std::make_tuple(1l, 1l, -1l, true, false, false), + // Batch dims in input + // TODO: mixed length inputs via broadcasted batch dims + // We currently reject differently-sized inputs since these translate to + // multiple M or N dims + std::make_tuple(3l, 2l, -1l, true, false, true), + // TODO: We don't yet support multiple batch dims in matmul scheduler + std::make_tuple(4l, 2l, -1l, true, false, true), + // Bias cases + std::make_tuple(2l, 2l, 0l, true, false, false), + std::make_tuple(2l, 2l, 1l, true, false, false), + // TODO: Mixed-length inputs are rejected with bias also + std::make_tuple(3l, 2l, 1l, true, false, true), + // TODO: We don't yet support multiple batch dims in matmul scheduler + std::make_tuple(4l, 2l, 1l, true, false, true)), + [](const testing::TestParamInfo& info) { + std::ostringstream os; + os << std::get<0>(info.param) << "dA"; + os << "_" << std::get<1>(info.param) << "dB"; + int64_t bias_dim = std::get<2>(info.param); + if (bias_dim >= 0) { + os << "_" << bias_dim << "dBias"; + } + if (!std::get<3>(info.param)) { + os << "_nofuse"; + } + if (std::get<4>(info.param)) { + os << "_transposeA"; + } + return os.str(); + }); + // Check that we determine A and B properly when they are swapped as inputs to // mul TEST_F(CombineMulSumAsMmaTest, SwapAandB) { From 2d8d45d38dc1cb72ceb26b038a91da138814c664 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 29 May 2024 23:31:42 +0000 Subject: [PATCH 45/52] Parametrize matmul test --- tests/cpp/test_combine_mul_sum.cpp | 294 ++++++++++++++--------------- 1 file changed, 139 insertions(+), 155 deletions(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index fa69f32b844..95500d545fc 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -271,178 +271,162 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { } } -// Test that a simple matmul fusion is picked up by the 
appropriate scheduler +using MatmulNodeTranslationTestParams = + std::tuple; +using MatmulNodeTranslationTest = + NVFuserFixtureParamTest; + +// Test that a simple matmul op fusion is picked up by the appropriate scheduler // and the translation to MmaOp is performed properly. -TEST_F(CombineMulSumAsMmaTest, AutomaticSchedulerMatmulNode) { +TEST_P(MatmulNodeTranslationTest, AutomaticSchedulerMatmulNode) { + const int64_t A_dim = std::get<0>(GetParam()); + const int64_t B_dim = std::get<1>(GetParam()); + const bool enable_fusion = std::get<2>(GetParam()); + const bool transpose_a_alloc = std::get<3>(GetParam()); + const bool expect_segmented = std::get<4>(GetParam()); + const ScheduleHeuristic expected_heuristic = std::get<5>(GetParam()); + + // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it enabled + DisableOptionsGuard dog; + DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); + + EnableOptionsGuard eog; + if (enable_fusion) { + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); + } + // The allocation domain propagation pass sets the output allocation domain, // which sometimes causes the matmul scheduler to decline the whole fusion // when it could compile it otherwise. 
preseg_passes::OptimizationPassGuard alloc_pass_guard(false); - const auto run = [&](int64_t A_dim, - int64_t B_dim, - bool transpose_a_alloc, - bool expect_segmented, - ScheduleHeuristic expected_heuristic) { - int batch_size = 3, M = 504, N = 136, K = 248; - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeContigTensor(A_dim, DataType::Half); - auto tv1 = makeContigTensor(B_dim, DataType::Half); + int batch_size = 3, M = 504, N = 136, K = 248; + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); - if (transpose_a_alloc && A_dim > 1) { - std::vector alloc = tv0->getMaybeAllocationDomain(); - alloc[alloc.size() - 1] = tv0->axis(-2); - alloc[alloc.size() - 2] = tv0->axis(-1); - tv0->setAllocationDomain(alloc, true); - } + auto tv0 = makeContigTensor(A_dim, DataType::Half); + auto tv1 = makeContigTensor(B_dim, DataType::Half); - fusion->addInput(tv0); - fusion->addInput(tv1); + if (transpose_a_alloc && A_dim > 1) { + std::vector alloc = tv0->getMaybeAllocationDomain(); + alloc[alloc.size() - 1] = tv0->axis(-2); + alloc[alloc.size() - 2] = tv0->axis(-1); + tv0->setAllocationDomain(alloc, true); + } - auto tv2 = matmul(tv0, tv1); + fusion->addInput(tv0); + fusion->addInput(tv1); - // add an epilogue - auto tv3 = sin(tv2); - auto tv4 = castOp(DataType::Half, tv3); + auto tv2 = matmul(tv0, tv1); - fusion->addOutput(tv4); + // add an epilogue + auto tv3 = sin(tv2); + auto tv4 = castOp(DataType::Half, tv3); - // Verify that we no longer set up MmaOp in matmul() - ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); + fusion->addOutput(tv4); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - std::vector A_shape(A_dim, batch_size); - A_shape[A_dim - 1] = K; - if (A_dim > 1) { - A_shape[A_dim - 2] = M; - } - at::Tensor t0 = at::randn(A_shape, options); - std::vector B_shape(B_dim, batch_size); - if (B_dim > 1) { - B_shape[B_dim - 2] = K; - B_shape[B_dim - 1] = N; - } else { - B_shape[B_dim - 
1] = K; - } - auto t1 = at::randn(B_shape, options); - if (transpose_a_alloc) { - t0 = t0.as_strided({M, K}, {1, M}); - } - auto tref = at::matmul(t0, t1).sin().to(at::kHalf); + // Verify that we no longer set up MmaOp in matmul() + ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); - FusionExecutorCache executor_cache(std::move(fusion)); - auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + std::vector A_shape(A_dim, batch_size); + A_shape[A_dim - 1] = K; + if (A_dim > 1) { + A_shape[A_dim - 2] = M; + } + at::Tensor t0 = at::randn(A_shape, options); + std::vector B_shape(B_dim, batch_size); + if (B_dim > 1) { + B_shape[B_dim - 2] = K; + B_shape[B_dim - 1] = N; + } else { + B_shape[B_dim - 1] = K; + } + auto t1 = at::randn(B_shape, options); + if (transpose_a_alloc) { + t0 = t0.as_strided({M, K}, {1, M}); + } + auto tref = at::matmul(t0, t1).sin().to(at::kHalf); - const FusionKernelRuntime* runtime = - executor_cache.getMostRecentKernelRuntime(); - ASSERT_NE(runtime, nullptr); + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); - if (expect_segmented) { - EXPECT_TRUE(runtime->isSegmented()); - } else { - EXPECT_FALSE(runtime->isSegmented()); - } + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + ASSERT_NE(runtime, nullptr); - ScheduleHeuristic heuristic = - runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); - EXPECT_EQ(heuristic, expected_heuristic); + if (expect_segmented) { + EXPECT_TRUE(runtime->isSegmented()); + } else { + EXPECT_FALSE(runtime->isSegmented()); + } - if (heuristic == ScheduleHeuristic::Matmul) { - // Ensure there's an MmaOp. 
- EXPECT_FALSE( - ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) - .empty()); - } + ScheduleHeuristic heuristic = + runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); + EXPECT_EQ(heuristic, expected_heuristic); - testValidate( - executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); - }; - // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it enabled - DisableOptionsGuard dog; - DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); - EnableOptionsGuard eog; + if (heuristic == ScheduleHeuristic::Matmul) { + // Ensure there's an MmaOp. + EXPECT_FALSE( + ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) + .empty()); + } - // Run the test with and without matmul fusion enabled - EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); - run(2, - 2, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); - run(2, - 2, - /*transpose_a_alloc=*/true, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); - - // Allow Matmul Scheduler to fuse MatmulOp and LinearOp - EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); - run(2, - 2, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/false, - ScheduleHeuristic::Matmul); - // We cannot yet handle allocation domain in matmul scheduler - run(2, - 2, - /*transpose_a_alloc=*/true, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); - - // Size-1 input combinations - run(1, - 2, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); - run(2, - 1, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); - // We fuse this case using the Reduction scheduler - run(1, - 1, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/false, - ScheduleHeuristic::Reduction); - - // Batch dims - run(3, - 1, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); - run(3, 
- 3, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/false, - ScheduleHeuristic::Matmul); - // TODO: mixed length inputs via broadcasted batch dims - // We currently reject differently-sized inputs since these translate to - // multiple M or N dims - run(3, - 2, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); - run(2, - 3, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); - // TODO: More than one batch dimension is not yet supported in Matmul - // scheduler - run(4, - 4, - /*transpose_a_alloc=*/false, - /*expect_segmented=*/true, - ScheduleHeuristic::ExprEval); + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } +INSTANTIATE_TEST_SUITE_P( + , + MatmulNodeTranslationTest, + ::testing::Values( + // Tests without fusion enabled + std:: + make_tuple(2l, 2l, false, false, true, ScheduleHeuristic::ExprEval), + std::make_tuple(2l, 2l, false, true, true, ScheduleHeuristic::ExprEval), + // Tests with fusion enabled + std::make_tuple(2l, 2l, true, false, false, ScheduleHeuristic::Matmul), + // We cannot yet handle allocation domain in matmul scheduler + std::make_tuple(2l, 2l, true, true, true, ScheduleHeuristic::ExprEval), + // Size-1 input combinations + std::make_tuple(1l, 2l, true, false, true, ScheduleHeuristic::ExprEval), + std::make_tuple(2l, 1l, true, false, true, ScheduleHeuristic::ExprEval), + // We fuse this case using the Reduction scheduler + std::make_tuple( + 1l, + 1l, + true, + false, + false, + ScheduleHeuristic::Reduction), + // Batch dims + std::make_tuple(3l, 1l, true, false, true, ScheduleHeuristic::ExprEval), + std::make_tuple(3l, 3l, true, false, false, ScheduleHeuristic::Matmul), + // TODO: mixed length inputs via broadcasted batch dims + // We currently reject differently-sized inputs since these translate to + // multiple M or N dims + std::make_tuple(3l, 2l, true, false, true, ScheduleHeuristic::ExprEval), + 
std::make_tuple(2l, 3l, true, false, true, ScheduleHeuristic::ExprEval), + // TODO: More than one batch dimension is not yet supported in Matmul + // scheduler + std:: + make_tuple(4l, 4l, true, false, true, ScheduleHeuristic::ExprEval)), + [](const testing::TestParamInfo& info) { + std::ostringstream os; + os << std::get<0>(info.param) << "dA"; + os << "_" << std::get<1>(info.param) << "dB"; + if (!std::get<2>(info.param)) { + os << "_nofuse"; + } + if (std::get<3>(info.param)) { + os << "_transposeA"; + } + return os.str(); + }); + using LinearNodeTranslationTestParams = std::tuple; using LinearNodeTranslationTest = @@ -456,12 +440,12 @@ TEST_P(LinearNodeTranslationTest, AutomaticSchedulerLinearNode) { // when it could compile it otherwise. preseg_passes::OptimizationPassGuard alloc_pass_guard(false); - int64_t A_dim = std::get<0>(GetParam()); - int64_t B_dim = std::get<1>(GetParam()); - int64_t bias_dim = std::get<2>(GetParam()); - bool enable_fusion = std::get<3>(GetParam()); - bool transpose_a_alloc = std::get<4>(GetParam()); - bool expect_aten_eval = std::get<5>(GetParam()); + const int64_t A_dim = std::get<0>(GetParam()); + const int64_t B_dim = std::get<1>(GetParam()); + const int64_t bias_dim = std::get<2>(GetParam()); + const bool enable_fusion = std::get<3>(GetParam()); + const bool transpose_a_alloc = std::get<4>(GetParam()); + const bool expect_aten_eval = std::get<5>(GetParam()); // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it // enabled From 59d77cf6ababc3894c4ac1ad67d7c6709497e5d6 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 30 May 2024 12:19:53 +0000 Subject: [PATCH 46/52] Fix multidevice examples by filtering out device dims --- csrc/scheduler/matmul_utils.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index ad728f1f3fb..4f116b32db5 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ 
-249,9 +249,14 @@ std::string isMatmulFusionDefinitionSupported( if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { tvs_with_roles.insert(entry->second.begin(), entry->second.end()); for (TensorView* tv : entry->second) { + const std::vector& leaf = tv->getLeafDomain(); + int64_t ndims = (int64_t)std::count_if( + leaf.begin(), leaf.end(), [](IterDomain* id) { + return !id->isReduction() && !id->isDeviceDim(); + }); if (operand_dim == -1) { - operand_dim = tv->nDims(); - } else if (tv->nDims() != operand_dim) { + operand_dim = ndims; + } else if (ndims != operand_dim) { // We cannot always handle differently sized inputs, such as those // we encounter when translating MatmulOp and LinearOp. This is // because in those cases one of the operands will have new @@ -260,7 +265,7 @@ std::string isMatmulFusionDefinitionSupported( // dimensions. Multiple M and N dimension support is planned but // for now we must reject these patterns before attempting to // translate them. - return "All operands must have the same dimension."; + return "All operands must have the same no-devices dimension."; } } } else { From 764823fa9c8a9a45a26a3d665a7bb50fa02da837 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 30 May 2024 12:23:58 +0000 Subject: [PATCH 47/52] Clean up canScheduleCompileTime --- csrc/scheduler/expr_eval_sched.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 8a138c92026..777452079cb 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -17,8 +17,7 @@ namespace nvfuser { bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { auto exprs = fusion->exprs(); if (!isOptionDisabled(DisableOption::MatmulExprEval)) { - if (exprs.size() == 1 && - (exprs.front()->isA() || exprs.front()->isA())) { + if (exprs.size() == 1 && (exprs.front()->isOneOf())) { return true; } scheduler_debug_utils::canScheduleRejectReason( 
From 5cf7d729083c01047ca33c7fe425e692027411e4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 30 May 2024 12:33:44 +0000 Subject: [PATCH 48/52] Add link to #2241 in failing test case --- tests/cpp/test_combine_mul_sum.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 95500d545fc..45f8f1556d0 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -576,7 +576,8 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(1l, 2l, -1l, true, false, true), std::make_tuple(2l, 1l, -1l, true, false, true), // TODO: The following currently fails but it should not be translated - // to LinearOp to begin with + // to LinearOp to begin with. + // See https://github.com/NVIDIA/Fuser/issues/2241 // std::make_tuple(1l, 1l, -1l, true, false, false), // Batch dims in input // TODO: mixed length inputs via broadcasted batch dims From 5ffe5903c39499173b9e85d3f9c0910951634384 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 30 May 2024 12:50:34 +0000 Subject: [PATCH 49/52] Make NoOp scheduler avoid matmul ops And enable 1d/1d linear tests --- csrc/scheduler/no_op.cpp | 6 ++++++ tests/cpp/test_combine_mul_sum.cpp | 12 ++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/no_op.cpp b/csrc/scheduler/no_op.cpp index 01eda65e043..de48d1adb01 100644 --- a/csrc/scheduler/no_op.cpp +++ b/csrc/scheduler/no_op.cpp @@ -50,6 +50,12 @@ bool NoOpScheduler::canScheduleCompileTime(Fusion* fusion) { return true; } + if (ir_utils::hasAnyMatmulOps(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), "matmul ops are not supported"); + return false; + } + // Check there're no non-trivial reduction ops. 
for (auto reduction : ir_utils::getAllTypesOfReductionOps(fusion)) { for (auto output : diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 45f8f1556d0..a41667793f2 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -568,6 +568,12 @@ INSTANTIATE_TEST_SUITE_P( // Tests without fusion enabled std::make_tuple(2l, 2l, -1l, false, false, true), std::make_tuple(2l, 2l, -1l, false, true, true), + std::make_tuple(1l, 2l, -1l, false, false, true), + std::make_tuple(2l, 1l, -1l, false, false, true), + std::make_tuple(2l, 2l, 1l, false, false, true), + std::make_tuple(1l, 1l, -1l, false, false, true), + std::make_tuple(3l, 2l, 1l, false, false, true), + std::make_tuple(4l, 2l, 1l, false, false, true), // Enable fusion std::make_tuple(2l, 2l, -1l, true, false, false), // We cannot yet handle allocation domain in matmul scheduler @@ -575,10 +581,8 @@ INSTANTIATE_TEST_SUITE_P( // We don't fuse 1D inputs std::make_tuple(1l, 2l, -1l, true, false, true), std::make_tuple(2l, 1l, -1l, true, false, true), - // TODO: The following currently fails but it should not be translated - // to LinearOp to begin with. 
- // See https://github.com/NVIDIA/Fuser/issues/2241 - // std::make_tuple(1l, 1l, -1l, true, false, false), + // Check that zero-dim output fusion is not claimed by NoOp scheduler + std::make_tuple(1l, 1l, -1l, true, false, true), // Batch dims in input // TODO: mixed length inputs via broadcasted batch dims // We currently reject differently-sized inputs since these translate to From 6108522085cc848c425cb2598eb3979b214aacfc Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 30 May 2024 12:52:23 +0000 Subject: [PATCH 50/52] NVFuserTest -> MatmulSchedulerTest --- tests/cpp/test_matmul_scheduler.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 671e6ec17c4..3c84c0fc230 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -2604,7 +2604,7 @@ TEST_F(MatmulSchedulerPluginTest, BasicMatmul) { // TODO: Once we can control the ExprEval and Matmul schedulers via options, run // this test with all three combinations (with and without each scheduler, but // at least one enabled). 
-TEST_F(NVFuserTest, SegmentMatmulOpPrologue) { +TEST_F(MatmulSchedulerTest, SegmentMatmulOpPrologue) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -2645,7 +2645,7 @@ TEST_F(NVFuserTest, SegmentMatmulOpPrologue) { } // This is just like the above test but with LinearOp instead of MatmulOp -TEST_F(NVFuserTest, SegmentLinearOpPrologue) { +TEST_F(MatmulSchedulerTest, SegmentLinearOpPrologue) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -2687,7 +2687,7 @@ TEST_F(NVFuserTest, SegmentLinearOpPrologue) { // Test that the matmul scheduler refuses to translate a matmul that is not // Half or BFloat16 -TEST_F(NVFuserTest, SegmentMatmulOpUnsupportedDtype) { +TEST_F(MatmulSchedulerTest, SegmentMatmulOpUnsupportedDtype) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); From f34f4ad32608625f92973e4cc1ac748468d126c9 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 30 May 2024 13:35:28 +0000 Subject: [PATCH 51/52] Parametrize Llama2FFN test --- tests/cpp/test_matmul_scheduler.cpp | 121 +++++++++++++++------------- 1 file changed, 65 insertions(+), 56 deletions(-) diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 3c84c0fc230..2b2eeccfade 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -2609,7 +2609,6 @@ TEST_F(MatmulSchedulerTest, SegmentMatmulOpPrologue) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // A - tv0, B - tv1, C - tv2 auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); fusion->addInput(tv0); @@ -2650,7 +2649,6 @@ TEST_F(MatmulSchedulerTest, SegmentLinearOpPrologue) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // A - tv0, B - tv1, C - tv2 auto tv0 = makeContigTensor(2, DataType::Half); 
auto tv1 = makeContigTensor(2, DataType::Half); fusion->addInput(tv0); @@ -2729,80 +2727,91 @@ TEST_F(MatmulSchedulerTest, SegmentMatmulOpUnsupportedDtype) { testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); } -// This test can be used to check that an external plugin has been loaded. It -// is DISABLED_ so that the test suite will pass even if the user has not -// provided a plugin via NVFUSER_MATMUL_HEURISTIC_PLUGIN. To check that a -// plugin can be loaded properly, invoke the test suite like so: -// -// export NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/plugin.so -// build/test_matmul --gtest_also_run_disabled_tests -// -TEST_F(MatmulSchedulerTest, DISABLED_RequireExternalPlugin) { - EXPECT_TRUE(matmul_heuristic_plugin::hasPlugin()); +class MatmulFusionTest : public MatmulSchedulerTest, + public ::testing::WithParamInterface { + protected: + void SetUp() override { + if (fusion_enabled) { + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + } + } - MatmulParams params; -} + EnableOptionsGuard eog_; + bool fusion_enabled = GetParam(); +}; // Test that we can segment a Fusion containing two matmuls -TEST_F(MatmulSchedulerTest, Llama2FFN) { +TEST_P(MatmulFusionTest, Llama2FFN) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); - for (bool enable_fusion : {false, true}) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - auto tv2 = makeContigTensor(2, DataType::Half); - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addInput(tv2); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + auto tv2 = makeContigTensor(2, DataType::Half); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); - auto tv3 = matmul(tv0, tv1); - auto tv4 = matmul(tv0, tv2); + auto tv3 = 
matmul(tv0, tv1); + auto tv4 = matmul(tv0, tv2); - // silu - auto tv5 = mul(sigmoid(tv3), tv3); + // silu + auto tv5 = mul(sigmoid(tv3), tv3); - auto tv6 = mul(tv5, tv4); + auto tv6 = mul(tv5, tv4); - fusion->addOutput(tv6); + fusion->addOutput(tv6); - FusionExecutorCache executor_cache(std::move(fusion)); + FusionExecutorCache executor_cache(std::move(fusion)); - const int M = 504, N = 136, K = 248; + const int M = 504, N = 136, K = 248; - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({K, N}, options); - auto t2 = at::randn({K, N}, options); - std::vector inputs{t0, t1, t2}; - - EnableOptionsGuard eog; - if (enable_fusion) { - EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); - } else { - EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); - } + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + auto t2 = at::randn({K, N}, options); + std::vector inputs{t0, t1, t2}; - auto outputs = executor_cache.runFusionWithInputs(inputs); + auto outputs = executor_cache.runFusionWithInputs(inputs); - testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); - const FusionKernelRuntime* runtime = - executor_cache.getMostRecentKernelRuntime(); + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); - EXPECT_TRUE(runtime->isSegmented()); + EXPECT_TRUE(runtime->isSegmented()); - if (enable_fusion) { - EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2); - } else { - EXPECT_EQ(runtime->fusionSegments()->groups().size(), 3); - } + if (fusion_enabled) { + EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2); + } else { + EXPECT_EQ(runtime->fusionSegments()->groups().size(), 3); } } 
+INSTANTIATE_TEST_SUITE_P( + , + MatmulFusionTest, + ::testing::Bool(), + [](const testing::TestParamInfo& info) { + return info.param ? "fuse" : "dontfuse"; + }); + +// This test can be used to check that an external plugin has been loaded. It +// is DISABLED_ so that the test suite will pass even if the user has not +// provided a plugin via NVFUSER_MATMUL_HEURISTIC_PLUGIN. To check that a +// plugin can be loaded properly, invoke the test suite like so: +// +// export NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/plugin.so +// build/test_matmul --gtest_also_run_disabled_tests +// +TEST_F(MatmulSchedulerTest, DISABLED_RequireExternalPlugin) { + EXPECT_TRUE(matmul_heuristic_plugin::hasPlugin()); + + MatmulParams params; +} + #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace nvfuser From 7342ecd098da7eb2125e3982eb55072609c321f9 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 30 May 2024 12:02:37 +0000 Subject: [PATCH 52/52] Guard translation tests for cc < 7.5 --- tests/cpp/test_combine_mul_sum.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index a41667793f2..91d956f952b 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -279,6 +279,7 @@ using MatmulNodeTranslationTest = // Test that a simple matmul op fusion is picked up by the appropriate scheduler // and the translation to MmaOp is performed properly. TEST_P(MatmulNodeTranslationTest, AutomaticSchedulerMatmulNode) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); const int64_t A_dim = std::get<0>(GetParam()); const int64_t B_dim = std::get<1>(GetParam()); const bool enable_fusion = std::get<2>(GetParam()); @@ -435,6 +436,7 @@ using LinearNodeTranslationTest = // Test that a simple linear op fusion is picked up by the appropriate scheduler // and the translation to MmaOp is performed properly. 
TEST_P(LinearNodeTranslationTest, AutomaticSchedulerLinearNode) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); // The allocation domain propagation pass sets the output allocation domain, // which sometimes causes the matmul scheduler to decline the whole fusion // when it could compile it otherwise.