-
Notifications
You must be signed in to change notification settings - Fork 79
MmaOp - consistency checks for basic matmul #131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1326,6 +1326,251 @@ Val* GroupedWelfordOp::getInitValOfOutput(Val* output_val) const { | |
|
|
||
| NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedWelfordOp) | ||
|
|
||
| //============================================================================================================================== | ||
|
|
||
| // MmaOp utils | ||
| namespace MmaOpUtils { | ||
|
|
||
| // The expected number of concrete domains for gemm | ||
| constexpr size_t expected_gemm_cdomains = 2; | ||
|
|
||
| // A helper structure used to gather all data created during analysis | ||
| struct MmaOpDetails { | ||
zasdfgbnm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| using AxesData = MmaOp::AxesData; | ||
| // Concrete axes from A that are broadcast in B and are not | ||
| // reduction in output | ||
| AxesData m_axes; | ||
| // Concrete axes from B that are broadcast in A and are not | ||
| // reduction in output | ||
| AxesData n_axes; | ||
| // Concrete axes from A that are concrete in B and are | ||
| // reduction in output | ||
| AxesData k_axes; | ||
| // Concrete or broadcast axes that are present in all inputs | ||
| // and output | ||
| AxesData batch_axes; | ||
| // A placeholder for mma input layout | ||
| c10::optional<MmaOptions::MmaInputLayout> input_layout = c10::nullopt; | ||
| }; | ||
|
|
||
| // A helper structure with pieces of information about TensorView | ||
| struct TensorViewDetails { | ||
| using AxesData = MmaOp::AxesData; | ||
| // Broadcast domains | ||
| AxesData bcasts; | ||
| // Reduction domains | ||
| AxesData rdomains; | ||
| // Concrete domains | ||
| AxesData cdomains; | ||
| }; | ||
|
|
||
| // A helper for gathering details about TensorView object | ||
| TensorViewDetails getDetailsFor(const TensorView* tv) { | ||
| TensorViewDetails details; | ||
| using DimIdx = int; | ||
| for (DimIdx pos = 0; pos < static_cast<DimIdx>(tv->nDims()); ++pos) { | ||
| const auto axis = tv->axis(pos); | ||
| if (axis->isReduction()) { | ||
| details.rdomains.push_back(pos); | ||
| continue; | ||
| } | ||
| if (axis->isBroadcast()) { | ||
| details.bcasts.push_back(pos); | ||
| continue; | ||
| } | ||
| details.cdomains.push_back(pos); | ||
| } | ||
| return details; | ||
| } | ||
|
|
||
| MmaOptions::MmaInputLayout getInputLayout( | ||
| const TensorViewDetails& in_a, | ||
| const TensorViewDetails& in_b, | ||
| const MmaOp::AxesData& m_axes, | ||
| const MmaOp::AxesData& n_axes, | ||
| const MmaOp::AxesData& k_axes) { | ||
| // TT layout (b - broadcast, r - reduction): | ||
| // A = [M, K, b] | ||
| // B = [b, K, N] | ||
| // C = [M, r, N] | ||
| if ((m_axes.back() < in_a.bcasts.back()) && | ||
| (k_axes.back() < in_a.bcasts.back()) && | ||
| (in_b.bcasts.back() < k_axes.back()) && | ||
| (in_b.bcasts.back() < n_axes.back())) { | ||
| return MmaOptions::MmaInputLayout::TT; | ||
| } | ||
| // TN layout (b - broadcast, r - reduction): | ||
| // A = [M, b, K] | ||
| // B = [b, N, K] | ||
| // C = [M, N, r] | ||
| if ((m_axes.back() < in_a.bcasts.back()) && | ||
| (in_a.bcasts.back() < k_axes.back()) && | ||
| (in_b.bcasts.back() < n_axes.back()) && | ||
| (in_b.bcasts.back() < k_axes.back())) { | ||
| return MmaOptions::MmaInputLayout::TN; | ||
| } | ||
| // NT layout (b - broadcast, r - reduction): | ||
| // A = [K, M, b] | ||
| // B = [K, b, N] | ||
| // C = [r, M, N] | ||
| if ((k_axes.back() < in_a.bcasts.back()) && | ||
| (m_axes.back() < in_a.bcasts.back()) && | ||
| (k_axes.back() < in_b.bcasts.back()) && | ||
| (in_b.bcasts.back() < n_axes.back())) { | ||
| return MmaOptions::MmaInputLayout::NT; | ||
| } | ||
|
|
||
| TORCH_INTERNAL_ASSERT(false, "Unsupported input layout"); | ||
| } | ||
|
|
||
| MmaOpDetails getMmaOpDetails( | ||
| TensorView* out, | ||
| TensorView* in_a, | ||
| TensorView* in_b) { | ||
| const auto in_a_details = getDetailsFor(in_a); | ||
| const auto in_b_details = getDetailsFor(in_b); | ||
| const auto out_details = getDetailsFor(out); | ||
|
|
||
| using AxesData = MmaOp::AxesData; | ||
|
|
||
| const auto getMOrNaxes = [](const AxesData& cdomains, | ||
| const AxesData& bcasts, | ||
| const AxesData& rdomains) { | ||
| AxesData result; | ||
| // For all concrete domains | ||
| for (const auto& cdomain : cdomains) { | ||
| // That are in broadcast domains but are not in reduction domains | ||
| if ((std::find(bcasts.begin(), bcasts.end(), cdomain) != bcasts.end()) && | ||
| (std::find(rdomains.begin(), rdomains.end(), cdomain) == | ||
| rdomains.end())) { | ||
| result.push_back(cdomain); | ||
| } | ||
| } | ||
| return result; | ||
| }; | ||
|
|
||
| const auto getKaxes = [](const AxesData& cdomains_a, | ||
| const AxesData& cdomains_b, | ||
| const AxesData& rdomains) { | ||
| AxesData result; | ||
| // For all concrete domains from in_a | ||
| for (const auto& cdomain_a : cdomains_a) { | ||
| // That are in concrete domains in in_b and are in reduction domains | ||
| if ((std::find(cdomains_b.begin(), cdomains_b.end(), cdomain_a) != | ||
| cdomains_b.end()) && | ||
| (std::find(rdomains.begin(), rdomains.end(), cdomain_a) != | ||
| rdomains.end())) { | ||
| result.push_back(cdomain_a); | ||
| } | ||
| } | ||
| return result; | ||
| }; | ||
|
|
||
| const auto getBatchAxes = [](const TensorViewDetails& in_a_details, | ||
| const TensorViewDetails& in_b_details, | ||
| const TensorViewDetails& out_details) { | ||
| AxesData result; | ||
| // Batch candidates: | ||
| // concrete domains that are in all of inputs and output | ||
| for (const auto& domain : in_a_details.cdomains) { | ||
| if ((std::find( | ||
| in_b_details.cdomains.begin(), | ||
| in_b_details.cdomains.end(), | ||
| domain) != in_b_details.cdomains.end()) && | ||
| (std::find( | ||
| out_details.cdomains.begin(), | ||
| out_details.cdomains.end(), | ||
| domain) != out_details.cdomains.end())) { | ||
| result.push_back(domain); | ||
| } | ||
| } | ||
| // Batch candidates: | ||
| // broadcast domains that are in all of inputs and output | ||
| for (const auto& domain : in_a_details.bcasts) { | ||
| if ((std::find( | ||
| in_b_details.bcasts.begin(), | ||
| in_b_details.bcasts.end(), | ||
| domain) != in_b_details.bcasts.end()) && | ||
| (std::find( | ||
| out_details.bcasts.begin(), out_details.bcasts.end(), domain) != | ||
| out_details.bcasts.end())) { | ||
| result.push_back(domain); | ||
| } | ||
| } | ||
| std::sort(result.begin(), result.end()); | ||
| return result; | ||
| }; | ||
|
|
||
| const auto validateInputDetails = [](const TensorViewDetails& details, | ||
| const std::string& desc) { | ||
| TORCH_INTERNAL_ASSERT( | ||
| !details.bcasts.empty(), desc, ": has no broadcast domains."); | ||
| TORCH_INTERNAL_ASSERT( | ||
| details.rdomains.empty(), desc, ": has reduction domains."); | ||
| TORCH_INTERNAL_ASSERT( | ||
| details.cdomains.size() >= expected_gemm_cdomains, | ||
| desc, | ||
| ": has unsupported number of concrete domains, expected at least ", | ||
| expected_gemm_cdomains, | ||
| ", got ", | ||
| details.cdomains.size()); | ||
| }; | ||
|
|
||
| const auto validateOutputDetails = [](const TensorViewDetails& details, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for the case of batched matmul, should we allowing having broadcast and more than two concrete domains?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't have (yet) support for batch, but I checked what are the domains in the MmaOp output for single test with strided batches: and it looks like this: I'm not sure if the test represents the final approach for strided batches so broadcasts could appear there. For now I will keep the current implementation but I will add comment: |
||
| const std::string& desc) { | ||
| // TODO: revise rules when add support for batch gemms | ||
| TORCH_INTERNAL_ASSERT( | ||
| details.bcasts.empty(), desc, ": has broadcast domains."); | ||
| TORCH_INTERNAL_ASSERT( | ||
| !details.rdomains.empty(), desc, ": has no reduction domains."); | ||
| TORCH_INTERNAL_ASSERT( | ||
| (details.cdomains.size() >= expected_gemm_cdomains), | ||
| desc, | ||
| ": has unsupported number of concrete domains, expected at least ", | ||
| expected_gemm_cdomains, | ||
| ", got ", | ||
| details.cdomains.size()); | ||
| }; | ||
|
|
||
| validateInputDetails(in_a_details, "MmaOp input A"); | ||
| validateInputDetails(in_b_details, "MmaOp input B"); | ||
| validateOutputDetails(out_details, "MmaOp output"); | ||
|
|
||
| MmaOpDetails details; | ||
|
|
||
| // For details, check MmaOpDetails | ||
| details.m_axes = getMOrNaxes( | ||
| in_a_details.cdomains, in_b_details.bcasts, out_details.rdomains); | ||
| details.n_axes = getMOrNaxes( | ||
| in_b_details.cdomains, in_a_details.bcasts, out_details.rdomains); | ||
| details.k_axes = getKaxes( | ||
| in_a_details.cdomains, in_b_details.cdomains, out_details.rdomains); | ||
| details.batch_axes = getBatchAxes(in_a_details, in_b_details, out_details); | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| !details.m_axes.empty(), | ||
| "MmaOp inputs must define at least a single M dimension"); | ||
| TORCH_INTERNAL_ASSERT( | ||
| !details.n_axes.empty(), | ||
| "MmaOp inputs must define at least a single N dimension"); | ||
| TORCH_INTERNAL_ASSERT( | ||
| !details.k_axes.empty(), | ||
| "MmaOp inputs must define at least a single K dimension"); | ||
|
|
||
| // TODO: for tensor contraction / split-k uses of MmaOp different input layout | ||
| // rules may be needed | ||
| details.input_layout = getInputLayout( | ||
| in_a_details, | ||
| in_b_details, | ||
| details.m_axes, | ||
| details.n_axes, | ||
| details.k_axes); | ||
|
|
||
| return details; | ||
| } | ||
|
|
||
| }; // namespace MmaOpUtils | ||
|
|
||
| MmaOp::MmaOp( | ||
| IrBuilderPasskey passkey, | ||
| Val* out, | ||
|
|
@@ -1336,7 +1581,8 @@ MmaOp::MmaOp( | |
| // Check output type | ||
| TORCH_INTERNAL_ASSERT( | ||
| out->getValType().value() == ValType::TensorView || | ||
| out->getValType().value() == ValType::TensorIndex); | ||
| out->getValType().value() == ValType::TensorIndex, | ||
| out->getValType().value()); | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| in_a->getValType().value() == ValType::TensorView || | ||
|
|
@@ -1348,23 +1594,44 @@ MmaOp::MmaOp( | |
| in_b->getValType().value() == ValType::TensorIndex, | ||
| in_b->getValType().value()); | ||
|
|
||
| const auto isBroadcastIn = [](const Val* val) { | ||
| if (val->getValType().value() == ValType::TensorView) { | ||
| const auto* tv = val->as<TensorView>(); | ||
| return tv->hasBroadcast(); | ||
| } | ||
| return true; | ||
| }; | ||
|
|
||
| TORCH_INTERNAL_ASSERT(isBroadcastIn(in_a)); | ||
| TORCH_INTERNAL_ASSERT(isBroadcastIn(in_b)); | ||
| MmaOpUtils::MmaOpDetails mma_details; | ||
| // Detailed consistency checks for use case with TensorViews as inputs/output | ||
| if (in_a->isA<TensorView>() && in_b->isA<TensorView>() && | ||
| out->isA<TensorView>()) { | ||
| mma_details = MmaOpUtils::getMmaOpDetails( | ||
| out->as<TensorView>(), in_a->as<TensorView>(), in_b->as<TensorView>()); | ||
| } | ||
|
|
||
| addOutput(out); | ||
| addInput(in_a); | ||
| addInput(in_b); | ||
| // ATTR_POS_INIT | ||
| addAttribute(init); | ||
| // ATTR_POS_OPTS | ||
| addAttribute( | ||
| IrBuilder::create<Attribute<OptionsInMma>>(passkey.ir_container_)); | ||
| // ATTR_POS_M_AXES | ||
| addAttribute(IrBuilder::create<Attribute<AxesData>>(passkey.ir_container_)); | ||
| // ATTR_POS_N_AXES | ||
| addAttribute(IrBuilder::create<Attribute<AxesData>>(passkey.ir_container_)); | ||
| // ATTR_POS_K_AXES | ||
| addAttribute(IrBuilder::create<Attribute<AxesData>>(passkey.ir_container_)); | ||
| // ATTR_POS_BATCH_AXES | ||
| addAttribute(IrBuilder::create<Attribute<AxesData>>(passkey.ir_container_)); | ||
| // ATTR_POS_INPUT_LAYOUT | ||
| addAttribute( | ||
| IrBuilder::create<Attribute<MmaInputLayoutOpt>>(passkey.ir_container_)); | ||
|
|
||
| attribute(ATTR_POS_M_AXES)->as<Attribute<AxesData>>()->value = | ||
| std::move(mma_details.m_axes); | ||
| attribute(ATTR_POS_N_AXES)->as<Attribute<AxesData>>()->value = | ||
| std::move(mma_details.n_axes); | ||
| attribute(ATTR_POS_K_AXES)->as<Attribute<AxesData>>()->value = | ||
| std::move(mma_details.k_axes); | ||
| attribute(ATTR_POS_BATCH_AXES)->as<Attribute<AxesData>>()->value = | ||
| std::move(mma_details.batch_axes); | ||
| attribute(ATTR_POS_INPUT_LAYOUT)->as<Attribute<MmaInputLayoutOpt>>()->value = | ||
| mma_details.input_layout; | ||
| } | ||
|
|
||
| MmaOp::MmaOp( | ||
|
|
@@ -1375,7 +1642,7 @@ MmaOp::MmaOp( | |
| Val* init, | ||
| OptionsInMma options) | ||
| : MmaOp(passkey, out, in_a, in_b, init) { | ||
| attribute(1)->as<Attribute<OptionsInMma>>()->value = options; | ||
| attribute(ATTR_POS_OPTS)->as<Attribute<OptionsInMma>>()->value = options; | ||
| } | ||
|
|
||
| std::string MmaOp::toString(int indent_size) const { | ||
|
|
@@ -1391,7 +1658,8 @@ std::string MmaOp::toInlineString(int indent_size) const { | |
| } | ||
|
|
||
| void MmaOp::configureOptions(MmaOptions options) { | ||
| OptionsInMma& opt = attribute(1)->as<Attribute<OptionsInMma>>()->value; | ||
| OptionsInMma& opt = | ||
| attribute(ATTR_POS_OPTS)->as<Attribute<OptionsInMma>>()->value; | ||
| TORCH_INTERNAL_ASSERT( | ||
| options.macro != MmaOptions::MacroType::NoMMA, | ||
| "Un-configured mma type from options."); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.