From a992c0e24e693a6398879d34cdd4e44f7769d9c4 Mon Sep 17 00:00:00 2001 From: Andrzej Bekas <118676880+drzejan2@users.noreply.github.com> Date: Wed, 5 Apr 2023 02:20:45 -0700 Subject: [PATCH] MmaOp - definition consistency checks - move some checks from compile time checks in matmul scheduler to MmaOp constructor, - move input layout check from matmul scheduler compile time checks to MmaOp class, - extend the set of attributes associated with MmaOp, - update compile time checks in matmul scheduler, --- csrc/ir_internal_nodes.h | 38 ++++- csrc/ir_nodes.cpp | 294 ++++++++++++++++++++++++++++++-- csrc/scheduler/matmul.cpp | 2 - csrc/scheduler/matmul_utils.cpp | 123 +++---------- test/test_gpu_tensorcore.cpp | 22 +++ 5 files changed, 365 insertions(+), 114 deletions(-) diff --git a/csrc/ir_internal_nodes.h b/csrc/ir_internal_nodes.h index 5933869c03e..6e751dcbf29 100644 --- a/csrc/ir_internal_nodes.h +++ b/csrc/ir_internal_nodes.h @@ -1044,6 +1044,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { } }; + using AxesData = std::vector; + using MmaInputLayoutOpt = c10::optional; using Expr::Expr; MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init); @@ -1082,7 +1084,7 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { } const auto& options() const { - return attribute(1)->as>()->value; + return attribute(ATTR_POS_OPTS)->as>()->value; } auto accStride() const { @@ -1090,6 +1092,40 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { } void configureOptions(MmaOptions options); + + auto inputLayout() const { + return attribute(ATTR_POS_INPUT_LAYOUT) + ->as>() + ->value; + } + + const auto& mAxes() const { + return attribute(ATTR_POS_M_AXES)->as>()->value; + } + + const auto& nAxes() const { + return attribute(ATTR_POS_N_AXES)->as>()->value; + } + + const auto& kAxes() const { + return attribute(ATTR_POS_K_AXES)->as>()->value; + } + + const auto& batchAxes() const { + return attribute(ATTR_POS_BATCH_AXES)->as>()->value; + } + + private: + // Predefined indexes of 
attributes stored for this IR node, to avoid + // magic numbers, based on order in which attributes are initialized + // in constructor + static constexpr size_t ATTR_POS_INIT = 0; + static constexpr size_t ATTR_POS_OPTS = 1; + static constexpr size_t ATTR_POS_M_AXES = 2; + static constexpr size_t ATTR_POS_N_AXES = 3; + static constexpr size_t ATTR_POS_K_AXES = 4; + static constexpr size_t ATTR_POS_BATCH_AXES = 5; + static constexpr size_t ATTR_POS_INPUT_LAYOUT = 6; }; class TORCH_CUDA_CU_API ExpandOp : public Expr { diff --git a/csrc/ir_nodes.cpp b/csrc/ir_nodes.cpp index 47e379a7d4d..523e6353e78 100644 --- a/csrc/ir_nodes.cpp +++ b/csrc/ir_nodes.cpp @@ -1326,6 +1326,251 @@ Val* GroupedWelfordOp::getInitValOfOutput(Val* output_val) const { NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedWelfordOp) +//============================================================================================================================== + +// MmaOp utils +namespace MmaOpUtils { + +// The expected number of concrete domains for gemm +constexpr size_t expected_gemm_cdomains = 2; + +// A helper structure used to gather all data created during analysis +struct MmaOpDetails { + using AxesData = MmaOp::AxesData; + // Concrete axes from A that are broadcast in B and are not + // reduction in output + AxesData m_axes; + // Concrete axes from B that are broadcast in A and are not + // reduction in output + AxesData n_axes; + // Concrete axes from A that are concrete in B and are + // reduction in output + AxesData k_axes; + // Concrete or broadcast axes that are present in all inputs + // and output + AxesData batch_axes; + // A placeholder for mma input layout + c10::optional input_layout = c10::nullopt; +}; + +// A helper structure with pieces of information about TensorView +struct TensorViewDetails { + using AxesData = MmaOp::AxesData; + // Broadcast domains + AxesData bcasts; + // Reduction domains + AxesData rdomains; + // Concrete domains + AxesData cdomains; +}; + +// A helper for 
gathering details about TensorView object +TensorViewDetails getDetailsFor(const TensorView* tv) { + TensorViewDetails details; + using DimIdx = int; + for (DimIdx pos = 0; pos < static_cast(tv->nDims()); ++pos) { + const auto axis = tv->axis(pos); + if (axis->isReduction()) { + details.rdomains.push_back(pos); + continue; + } + if (axis->isBroadcast()) { + details.bcasts.push_back(pos); + continue; + } + details.cdomains.push_back(pos); + } + return details; +} + +MmaOptions::MmaInputLayout getInputLayout( + const TensorViewDetails& in_a, + const TensorViewDetails& in_b, + const MmaOp::AxesData& m_axes, + const MmaOp::AxesData& n_axes, + const MmaOp::AxesData& k_axes) { + // TT layout (b - broadcast, r - reduction): + // A = [M, K, b] + // B = [b, K, N] + // C = [M, r, N] + if ((m_axes.back() < in_a.bcasts.back()) && + (k_axes.back() < in_a.bcasts.back()) && + (in_b.bcasts.back() < k_axes.back()) && + (in_b.bcasts.back() < n_axes.back())) { + return MmaOptions::MmaInputLayout::TT; + } + // TN layout (b - broadcast, r - reduction): + // A = [M, b, K] + // B = [b, N, K] + // C = [M, N, r] + if ((m_axes.back() < in_a.bcasts.back()) && + (in_a.bcasts.back() < k_axes.back()) && + (in_b.bcasts.back() < n_axes.back()) && + (in_b.bcasts.back() < k_axes.back())) { + return MmaOptions::MmaInputLayout::TN; + } + // NT layout (b - broadcast, r - reduction): + // A = [K, M, b] + // B = [K, b, N] + // C = [r, M, N] + if ((k_axes.back() < in_a.bcasts.back()) && + (m_axes.back() < in_a.bcasts.back()) && + (k_axes.back() < in_b.bcasts.back()) && + (in_b.bcasts.back() < n_axes.back())) { + return MmaOptions::MmaInputLayout::NT; + } + + TORCH_INTERNAL_ASSERT(false, "Unsupported input layout"); +} + +MmaOpDetails getMmaOpDetails( + TensorView* out, + TensorView* in_a, + TensorView* in_b) { + const auto in_a_details = getDetailsFor(in_a); + const auto in_b_details = getDetailsFor(in_b); + const auto out_details = getDetailsFor(out); + + using AxesData = MmaOp::AxesData; + + const auto 
getMOrNaxes = [](const AxesData& cdomains, + const AxesData& bcasts, + const AxesData& rdomains) { + AxesData result; + // For all concrete domains + for (const auto& cdomain : cdomains) { + // That are in broadcast domains but are not in reduction domains + if ((std::find(bcasts.begin(), bcasts.end(), cdomain) != bcasts.end()) && + (std::find(rdomains.begin(), rdomains.end(), cdomain) == + rdomains.end())) { + result.push_back(cdomain); + } + } + return result; + }; + + const auto getKaxes = [](const AxesData& cdomains_a, + const AxesData& cdomains_b, + const AxesData& rdomains) { + AxesData result; + // For all concrete domains from in_a + for (const auto& cdomain_a : cdomains_a) { + // That are in concrete domains in in_b and are in reduction domains + if ((std::find(cdomains_b.begin(), cdomains_b.end(), cdomain_a) != + cdomains_b.end()) && + (std::find(rdomains.begin(), rdomains.end(), cdomain_a) != + rdomains.end())) { + result.push_back(cdomain_a); + } + } + return result; + }; + + const auto getBatchAxes = [](const TensorViewDetails& in_a_details, + const TensorViewDetails& in_b_details, + const TensorViewDetails& out_details) { + AxesData result; + // Batch candidates: + // concrete domains that are in all of inputs and output + for (const auto& domain : in_a_details.cdomains) { + if ((std::find( + in_b_details.cdomains.begin(), + in_b_details.cdomains.end(), + domain) != in_b_details.cdomains.end()) && + (std::find( + out_details.cdomains.begin(), + out_details.cdomains.end(), + domain) != out_details.cdomains.end())) { + result.push_back(domain); + } + } + // Batch candidates: + // broadcast domains that are in all of inputs and output + for (const auto& domain : in_a_details.bcasts) { + if ((std::find( + in_b_details.bcasts.begin(), + in_b_details.bcasts.end(), + domain) != in_b_details.bcasts.end()) && + (std::find( + out_details.bcasts.begin(), out_details.bcasts.end(), domain) != + out_details.bcasts.end())) { + result.push_back(domain); + } + } + 
std::sort(result.begin(), result.end()); + return result; + }; + + const auto validateInputDetails = [](const TensorViewDetails& details, + const std::string& desc) { + TORCH_INTERNAL_ASSERT( + !details.bcasts.empty(), desc, ": has no broadcast domains."); + TORCH_INTERNAL_ASSERT( + details.rdomains.empty(), desc, ": has reduction domains."); + TORCH_INTERNAL_ASSERT( + details.cdomains.size() >= expected_gemm_cdomains, + desc, + ": has unsupported number of concrete domains, expected at least ", + expected_gemm_cdomains, + ", got ", + details.cdomains.size()); + }; + + const auto validateOutputDetails = [](const TensorViewDetails& details, + const std::string& desc) { + // TODO: revise rules when adding support for batch gemms + TORCH_INTERNAL_ASSERT( + details.bcasts.empty(), desc, ": has broadcast domains."); + TORCH_INTERNAL_ASSERT( + !details.rdomains.empty(), desc, ": has no reduction domains."); + TORCH_INTERNAL_ASSERT( + (details.cdomains.size() >= expected_gemm_cdomains), + desc, + ": has unsupported number of concrete domains, expected at least ", + expected_gemm_cdomains, + ", got ", + details.cdomains.size()); + }; + + validateInputDetails(in_a_details, "MmaOp input A"); + validateInputDetails(in_b_details, "MmaOp input B"); + validateOutputDetails(out_details, "MmaOp output"); + + MmaOpDetails details; + + // For details, check MmaOpDetails + details.m_axes = getMOrNaxes( + in_a_details.cdomains, in_b_details.bcasts, out_details.rdomains); + details.n_axes = getMOrNaxes( + in_b_details.cdomains, in_a_details.bcasts, out_details.rdomains); + details.k_axes = getKaxes( + in_a_details.cdomains, in_b_details.cdomains, out_details.rdomains); + details.batch_axes = getBatchAxes(in_a_details, in_b_details, out_details); + + TORCH_INTERNAL_ASSERT( + !details.m_axes.empty(), + "MmaOp inputs must define at least a single M dimension"); + TORCH_INTERNAL_ASSERT( + !details.n_axes.empty(), + "MmaOp inputs must define at least a single N dimension"); + 
TORCH_INTERNAL_ASSERT( + !details.k_axes.empty(), + "MmaOp inputs must define at least a single K dimension"); + + // TODO: for tensor contraction / split-k uses of MmaOp different input layout + // rules may be needed + details.input_layout = getInputLayout( + in_a_details, + in_b_details, + details.m_axes, + details.n_axes, + details.k_axes); + + return details; +} + +}; // namespace MmaOpUtils + MmaOp::MmaOp( IrBuilderPasskey passkey, Val* out, @@ -1336,7 +1581,8 @@ MmaOp::MmaOp( // Check output type TORCH_INTERNAL_ASSERT( out->getValType().value() == ValType::TensorView || - out->getValType().value() == ValType::TensorIndex); + out->getValType().value() == ValType::TensorIndex, + out->getValType().value()); TORCH_INTERNAL_ASSERT( in_a->getValType().value() == ValType::TensorView || @@ -1348,23 +1594,44 @@ MmaOp::MmaOp( in_b->getValType().value() == ValType::TensorIndex, in_b->getValType().value()); - const auto isBroadcastIn = [](const Val* val) { - if (val->getValType().value() == ValType::TensorView) { - const auto* tv = val->as(); - return tv->hasBroadcast(); - } - return true; - }; - - TORCH_INTERNAL_ASSERT(isBroadcastIn(in_a)); - TORCH_INTERNAL_ASSERT(isBroadcastIn(in_b)); + MmaOpUtils::MmaOpDetails mma_details; + // Detailed consistency checks for use case with TensorViews as inputs/output + if (in_a->isA() && in_b->isA() && + out->isA()) { + mma_details = MmaOpUtils::getMmaOpDetails( + out->as(), in_a->as(), in_b->as()); + } addOutput(out); addInput(in_a); addInput(in_b); + // ATTR_POS_INIT addAttribute(init); + // ATTR_POS_OPTS addAttribute( IrBuilder::create>(passkey.ir_container_)); + // ATTR_POS_M_AXES + addAttribute(IrBuilder::create>(passkey.ir_container_)); + // ATTR_POS_N_AXES + addAttribute(IrBuilder::create>(passkey.ir_container_)); + // ATTR_POS_K_AXES + addAttribute(IrBuilder::create>(passkey.ir_container_)); + // ATTR_POS_BATCH_AXES + addAttribute(IrBuilder::create>(passkey.ir_container_)); + // ATTR_POS_INPUT_LAYOUT + addAttribute( + 
IrBuilder::create>(passkey.ir_container_)); + + attribute(ATTR_POS_M_AXES)->as>()->value = + std::move(mma_details.m_axes); + attribute(ATTR_POS_N_AXES)->as>()->value = + std::move(mma_details.n_axes); + attribute(ATTR_POS_K_AXES)->as>()->value = + std::move(mma_details.k_axes); + attribute(ATTR_POS_BATCH_AXES)->as>()->value = + std::move(mma_details.batch_axes); + attribute(ATTR_POS_INPUT_LAYOUT)->as>()->value = + mma_details.input_layout; } MmaOp::MmaOp( @@ -1375,7 +1642,7 @@ MmaOp::MmaOp( Val* init, OptionsInMma options) : MmaOp(passkey, out, in_a, in_b, init) { - attribute(1)->as>()->value = options; + attribute(ATTR_POS_OPTS)->as>()->value = options; } std::string MmaOp::toString(int indent_size) const { @@ -1391,7 +1658,8 @@ std::string MmaOp::toInlineString(int indent_size) const { } void MmaOp::configureOptions(MmaOptions options) { - OptionsInMma& opt = attribute(1)->as>()->value; + OptionsInMma& opt = + attribute(ATTR_POS_OPTS)->as>()->value; TORCH_INTERNAL_ASSERT( options.macro != MmaOptions::MacroType::NoMMA, "Un-configured mma type from options."); diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index f9f56a59922..c8eebe6ae55 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -14,8 +14,6 @@ // 'SchedulerRuntimeInfo' #include -#include - namespace nvfuser { namespace { diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 0d2e2fa06b6..2a2a0322042 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -125,87 +125,6 @@ inline c10::optional getMmaOp( return c10::nullopt; } -//! A helper for checking if layout of MMA op's inputs. It will return optional -//! message if check fails. 
-LayoutData getInputsLayout(const MmaOp* mma_expr) { - std::stringstream ss; - const auto& mmaExprInputs = mma_expr->inputs(); - - const auto* in_A = mmaExprInputs[0]->as(); - const auto* in_B = mmaExprInputs[1]->as(); - - // The number of IterDomains of MMA inputs must be the same - if (in_A->nDims() != in_B->nDims()) { - ss << "Mma op inputs don't have the same number of IterDomains, 1st input(" - << std::to_string(in_A->nDims()) << "), 2nd input(" - << std::to_string(in_B->nDims()) + ")"; - return {c10::nullopt, ss.str()}; - } - - // The currently supported number of IterDomains per MMA op input is 3 - constexpr size_t supportedDims = 3; - if (in_A->nDims() != supportedDims) { - ss << "Mma op inputs have unsupported number of IterDomains, got: " - << std::to_string(in_A->nDims()) << ", expected " - << std::to_string(supportedDims); - return {c10::nullopt, ss.str()}; - } - - using AxisPos = decltype(std::declval().nDims()); - constexpr AxisPos unInitPos = -1; - AxisPos bcastInApos = unInitPos; - AxisPos bcastInBpos = unInitPos; - - // The first and the second input of MMA have the same number of - // IterDomains - for (AxisPos pos = 0; pos < in_A->nDims(); ++pos) { - if (in_A->axis(static_cast(pos))->isBroadcast()) { - if (bcastInApos != unInitPos) { - ss << "Mma op first input has more than one broadcast IterDomain: " - << std::to_string(bcastInApos) << " and " << std::to_string(pos); - return {c10::nullopt, ss.str()}; - } - bcastInApos = pos; - } - if (in_B->axis(static_cast(pos))->isBroadcast()) { - if (bcastInBpos != unInitPos) { - ss << "Mma op second input has more than one broadcast IterDomain: " - << std::to_string(bcastInBpos) << " and " << std::to_string(pos); - return {c10::nullopt, ss.str()}; - } - bcastInBpos = pos; - } - } - - // MMA inputs need to have broadcast IterDomains - if (bcastInApos == unInitPos || bcastInBpos == unInitPos) { - ss << "The " << (bcastInApos == unInitPos ? 
"first" : "second") - << " mma op has no broadcast IterDomain"; - return {c10::nullopt, ss.str()}; - } - - // MMA inputs must have supported data layout, defined in MatmulLayout - // MatmulLayout::TT - if (bcastInApos == static_cast(2) && - bcastInBpos == static_cast(0)) { - return {MatmulLayout::TT, c10::nullopt}; - } - // MatmulLayout::TN - if (bcastInApos == static_cast(1) && - bcastInBpos == static_cast(0)) { - return {MatmulLayout::TN, c10::nullopt}; - } - // MatmulLayout::NT - if (bcastInApos == static_cast(2) && - bcastInBpos == static_cast(1)) { - return {MatmulLayout::NT, c10::nullopt}; - } - - ss << "Unsupported layout, broadcasts: inputA(" << bcastInApos << "), inputB(" - << bcastInBpos << ")"; - return {c10::nullopt, ss.str()}; -} - //! A wrapper for core heuristics initialization inline bool initCoreHeuristics( std::shared_ptr params, @@ -345,7 +264,7 @@ c10::optional getProblemShape( const auto getShape = [&runtime_info](const TensorView* tv) { TensorShape tv_shape; const auto concrete_domains = TensorDomain::noReductions( - TensorDomain::noBroadcasts(tv->as()->domain()->domain())); + TensorDomain::noBroadcasts(tv->domain()->domain())); for (const auto* domain : concrete_domains) { const auto domain_extend = runtime_info.expressionEvaluator().evaluate(domain->extent()); @@ -436,7 +355,19 @@ std::string checkMatmulType(Fusion* fusion, const MmaOp* mma_expr) { constexpr size_t expected_number_of_inputs = 2; constexpr size_t expected_number_of_outputs = 1; - // Quick checks + // Quick checks - MmaOp + { + // Check if MmaOp processes single gemm + constexpr size_t expected_axes_numbers = 1; + if (mma_expr->mAxes().size() != expected_axes_numbers || + mma_expr->nAxes().size() != expected_axes_numbers || + mma_expr->kAxes().size() != expected_axes_numbers || + !mma_expr->batchAxes().empty()) { + return "MmaOp has unsupported number of one of M/N/K/Batch axes"; + } + } + + // Quick checks - Fusion { // Fusion can only have two TV inputs if 
(fusion_inputs.size() != fusion_inputs_tvs.size()) { @@ -446,12 +377,9 @@ std::string checkMatmulType(Fusion* fusion, const MmaOp* mma_expr) { return "Fusion inputs contain at least one non-TensorView object"; } - // Fusion can only have TVs as outputs, and there can be only one output - if (fusion_outputs_tvs.size() != fusion_outputs.size()) { - return "Fusion has output which is not a TensorView object"; - } + // Fusion has only TVs as outputs, and we expect only one object in the list if ((expected_number_of_outputs != fusion_outputs_tvs.size())) { - return "Fusion has more than a single TensorView object in outputs"; + return "Fusion has more than a single TensorView object in its outputs"; } // Each of fusion input TVs must have: @@ -548,9 +476,9 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // #2 { for (const auto* mma_expr : mma_exprs) { - const auto layout_data = getInputsLayout(mma_expr); - if (layout_data.second) { - return layout_data.second.value(); + const auto input_layout = mma_expr->inputLayout(); + if (!input_layout) { + return "Failed to acquire inputs layout."; } } } @@ -578,16 +506,15 @@ std::shared_ptr getMatmulHeuristics( auto params = std::make_shared(); // Check initial conditions - const auto fusion_exprs = fusion->exprs(); - auto mma_exprs = ir_utils::filterByType(fusion_exprs).vector(); + auto mma_exprs = ir_utils::getMmaOps(fusion); TORCH_INTERNAL_ASSERT( mma_exprs.size() == 1, "Support only fusion with a single mma op."); - const auto layout = getInputsLayout(mma_exprs.front()); - TORCH_INTERNAL_ASSERT(!layout.second.has_value(), layout.second.value()); + const auto layout = mma_exprs.front()->inputLayout(); + TORCH_INTERNAL_ASSERT(layout.has_value(), "Failed to acquire inputs layout."); const auto problem_shape = getProblemShape( - fusion, mma_exprs[0]->as(), runtime_info, layout.first.value()); + fusion, mma_exprs[0]->as(), runtime_info, layout.value()); TORCH_INTERNAL_ASSERT( problem_shape.has_value(), "Failed to 
acquire problem shape."); @@ -599,7 +526,7 @@ std::shared_ptr getMatmulHeuristics( // Populate heuristic details auto status = initCoreHeuristics( - params, mma_op.value(), layout.first.value(), problem_shape.value()); + params, mma_op.value(), layout.value(), problem_shape.value()); TORCH_INTERNAL_ASSERT( status, "Core part of heuristics failed to initialize."); @@ -607,7 +534,7 @@ std::shared_ptr getMatmulHeuristics( TORCH_INTERNAL_ASSERT( status, "Additional part of heuristics failed to initialize."); - // set kernel index mode + // Set kernel index mode params->cparams.index_type = getIndexType(problem_shape.value()); if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) { diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 57cc8cc2ab7..9807980c312 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -3133,6 +3133,17 @@ TEST_F(NVFuserTest, FusionMatmulSegmenterBasicMatmulStrictCheckTT_CUDA) { fusion->addInput(tv1); fusion->addOutput(tv2); + TORCH_CHECK( + 1 == ir_utils::getMmaOps(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + TORCH_CHECK( + ir_utils::getMmaOps(fusion.get()).front()->inputLayout().has_value(), + "input layout has not be set for MmaOp"); + TORCH_CHECK( + layout == + ir_utils::getMmaOps(fusion.get()).front()->inputLayout().value(), + "input layout from test and MmaOp do not match"); + at::manual_seed(0); at::Tensor t0 = matmulAtInput(M, N, K, layout, TensorMatmulPos::A, at::kHalf); @@ -3174,6 +3185,17 @@ TEST_F(NVFuserTest, FusionMatmulSegmenterBasicMatmulRelaxedCheck_CUDA) { fusion->addInput(tv1); fusion->addOutput(tv2); + TORCH_CHECK( + 1 == ir_utils::getMmaOps(fusion.get()).size(), + "matmul fusion must have at least one MmaOp"); + TORCH_CHECK( + ir_utils::getMmaOps(fusion.get()).front()->inputLayout().has_value(), + "input layout has not be set for MmaOp"); + TORCH_CHECK( + layout == + ir_utils::getMmaOps(fusion.get()).front()->inputLayout().value(), + "input 
layout from test and MmaOp do not match"); + at::manual_seed(0); at::Tensor t0 =