Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion csrc/ir_internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
}
};

using AxesData = std::vector<int>;
using MmaInputLayoutOpt = c10::optional<MmaOptions::MmaInputLayout>;
using Expr::Expr;

MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);
Expand Down Expand Up @@ -1082,14 +1084,48 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
}

// Returns a const reference to the mma options stored in the attribute at
// position ATTR_POS_OPTS of this node.
const auto& options() const {
  return attribute(ATTR_POS_OPTS)->as<Attribute<OptionsInMma>>()->value;
}

auto accStride() const {
return options().accumulator_stride;
}

void configureOptions(MmaOptions options);

// Returns (by value) the optional input layout stored in the attribute at
// position ATTR_POS_INPUT_LAYOUT.
auto inputLayout() const {
  const auto layout_attr =
      attribute(ATTR_POS_INPUT_LAYOUT)->as<Attribute<MmaInputLayoutOpt>>();
  return layout_attr->value;
}

// Returns the axes list stored in the attribute at position ATTR_POS_M_AXES.
const auto& mAxes() const {
  const auto axes_attr = attribute(ATTR_POS_M_AXES)->as<Attribute<AxesData>>();
  return axes_attr->value;
}

// Returns the axes list stored in the attribute at position ATTR_POS_N_AXES.
const auto& nAxes() const {
  const auto axes_attr = attribute(ATTR_POS_N_AXES)->as<Attribute<AxesData>>();
  return axes_attr->value;
}

// Returns the axes list stored in the attribute at position ATTR_POS_K_AXES.
const auto& kAxes() const {
  const auto axes_attr = attribute(ATTR_POS_K_AXES)->as<Attribute<AxesData>>();
  return axes_attr->value;
}

// Returns the axes list stored in the attribute at position
// ATTR_POS_BATCH_AXES.
const auto& batchAxes() const {
  const auto axes_attr =
      attribute(ATTR_POS_BATCH_AXES)->as<Attribute<AxesData>>();
  return axes_attr->value;
}

private:
// Predefined indices of the attributes stored for this IR node, used to
// avoid magic numbers; they follow the order in which the attributes are
// added in the constructor
static constexpr size_t ATTR_POS_INIT = 0;
static constexpr size_t ATTR_POS_OPTS = 1;
static constexpr size_t ATTR_POS_M_AXES = 2;
static constexpr size_t ATTR_POS_N_AXES = 3;
static constexpr size_t ATTR_POS_K_AXES = 4;
static constexpr size_t ATTR_POS_BATCH_AXES = 5;
static constexpr size_t ATTR_POS_INPUT_LAYOUT = 6;
};

class TORCH_CUDA_CU_API ExpandOp : public Expr {
Expand Down
294 changes: 281 additions & 13 deletions csrc/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,251 @@ Val* GroupedWelfordOp::getInitValOfOutput(Val* output_val) const {

NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedWelfordOp)

//==============================================================================================================================

// MmaOp utils
namespace MmaOpUtils {

// The expected number of concrete domains for gemm
constexpr size_t expected_gemm_cdomains = 2;

// Aggregates everything the MmaOp analysis produces: the axis positions
// classified as M, N, K and batch, plus the deduced input layout (if any).
struct MmaOpDetails {
  using AxesData = MmaOp::AxesData;

  // M axes: concrete in A, broadcast in B, not reduction in the output
  AxesData m_axes;

  // N axes: concrete in B, broadcast in A, not reduction in the output
  AxesData n_axes;

  // K axes: concrete in both A and B, reduction in the output
  AxesData k_axes;

  // Batch axes: concrete or broadcast axes present in all inputs and in the
  // output
  AxesData batch_axes;

  // Input layout of the mma; stays empty until one is assigned
  c10::optional<MmaOptions::MmaInputLayout> input_layout{c10::nullopt};
};

// A helper structure with pieces of information about TensorView.
// Each member is a list of axis positions; getDetailsFor() fills them so
// that every axis position of the TensorView lands in exactly one list.
struct TensorViewDetails {
using AxesData = MmaOp::AxesData;
// Broadcast domains (positions of axes reporting isBroadcast())
AxesData bcasts;
// Reduction domains (positions of axes reporting isReduction())
AxesData rdomains;
// Concrete domains (positions of axes that are neither of the above)
AxesData cdomains;
};

// Classifies every axis position of `tv` as reduction, broadcast or concrete
// and returns the three position lists. Reduction is checked first, so an
// axis is only tested for broadcast when it is not a reduction.
TensorViewDetails getDetailsFor(const TensorView* tv) {
  TensorViewDetails result;
  const int num_dims = static_cast<int>(tv->nDims());
  for (int idx = 0; idx < num_dims; ++idx) {
    const auto id = tv->axis(idx);
    // Pick the list this position belongs to
    auto& bucket = id->isReduction()
        ? result.rdomains
        : (id->isBroadcast() ? result.bcasts : result.cdomains);
    bucket.push_back(idx);
  }
  return result;
}

// Deduces the mma input layout (TT, TN or NT) by comparing the last element
// of the M, N and K axes lists with the last broadcast position of each
// input.
//
// Parameters:
//   in_a, in_b - details of the two mma inputs; only their broadcast
//                position lists are read
//   m_axes, n_axes, k_axes - axis positions classified as M, N and K
//
// Preconditions (asserted): both inputs have at least one broadcast
// position and none of the M/N/K lists is empty. Without these checks the
// back() calls below would be undefined behavior on an empty vector.
//
// Returns the matching layout; fails with an internal assert when none of
// the supported layouts matches.
MmaOptions::MmaInputLayout getInputLayout(
    const TensorViewDetails& in_a,
    const TensorViewDetails& in_b,
    const MmaOp::AxesData& m_axes,
    const MmaOp::AxesData& n_axes,
    const MmaOp::AxesData& k_axes) {
  TORCH_INTERNAL_ASSERT(
      !in_a.bcasts.empty() && !in_b.bcasts.empty(),
      "MmaOp inputs must have broadcast domains to deduce input layout");
  TORCH_INTERNAL_ASSERT(
      !m_axes.empty() && !n_axes.empty() && !k_axes.empty(),
      "M, N and K axes must be known to deduce input layout");

  const auto m_pos = m_axes.back();
  const auto n_pos = n_axes.back();
  const auto k_pos = k_axes.back();
  const auto a_bcast = in_a.bcasts.back();
  const auto b_bcast = in_b.bcasts.back();

  // TT layout (b - broadcast, r - reduction):
  // A = [M, K, b]
  // B = [b, K, N]
  // C = [M, r, N]
  if ((m_pos < a_bcast) && (k_pos < a_bcast) && (b_bcast < k_pos) &&
      (b_bcast < n_pos)) {
    return MmaOptions::MmaInputLayout::TT;
  }
  // TN layout (b - broadcast, r - reduction):
  // A = [M, b, K]
  // B = [b, N, K]
  // C = [M, N, r]
  if ((m_pos < a_bcast) && (a_bcast < k_pos) && (b_bcast < n_pos) &&
      (b_bcast < k_pos)) {
    return MmaOptions::MmaInputLayout::TN;
  }
  // NT layout (b - broadcast, r - reduction):
  // A = [K, M, b]
  // B = [K, b, N]
  // C = [r, M, N]
  if ((k_pos < a_bcast) && (m_pos < a_bcast) && (k_pos < b_bcast) &&
      (b_bcast < n_pos)) {
    return MmaOptions::MmaInputLayout::NT;
  }

  TORCH_INTERNAL_ASSERT(false, "Unsupported input layout");
}

MmaOpDetails getMmaOpDetails(
TensorView* out,
TensorView* in_a,
TensorView* in_b) {
const auto in_a_details = getDetailsFor(in_a);
const auto in_b_details = getDetailsFor(in_b);
const auto out_details = getDetailsFor(out);

using AxesData = MmaOp::AxesData;

// Selects, in their original order, the positions from `cdomains` that are
// also listed in `bcasts` and are not listed in `rdomains`.
const auto getMOrNaxes = [](const AxesData& cdomains,
                            const AxesData& bcasts,
                            const AxesData& rdomains) {
  // Membership test over a list of axis positions
  const auto isListedIn = [](const AxesData& axes, const int pos) {
    return std::find(axes.begin(), axes.end(), pos) != axes.end();
  };

  AxesData selected;
  for (const int pos : cdomains) {
    // Skip positions that are not broadcast, or that are reductions
    if (!isListedIn(bcasts, pos) || isListedIn(rdomains, pos)) {
      continue;
    }
    selected.push_back(pos);
  }
  return selected;
};

// Selects, in their original order, the positions from `cdomains_a` that are
// also listed in `cdomains_b` and in `rdomains`.
const auto getKaxes = [](const AxesData& cdomains_a,
                         const AxesData& cdomains_b,
                         const AxesData& rdomains) {
  // Membership test over a list of axis positions
  const auto isListedIn = [](const AxesData& axes, const int pos) {
    return std::find(axes.begin(), axes.end(), pos) != axes.end();
  };

  AxesData selected;
  for (const int pos : cdomains_a) {
    // Keep positions concrete in both inputs and reduced in the output
    if (isListedIn(cdomains_b, pos) && isListedIn(rdomains, pos)) {
      selected.push_back(pos);
    }
  }
  return selected;
};

// Collects batch axis candidates: positions that are concrete in both inputs
// and in the output, followed by positions that are broadcast in both inputs
// and in the output. The combined list is returned sorted in ascending order.
const auto getBatchAxes = [](const TensorViewDetails& a,
                             const TensorViewDetails& b,
                             const TensorViewDetails& out) {
  // Membership test over a list of axis positions
  const auto isListedIn = [](const AxesData& axes, const int pos) {
    return std::find(axes.begin(), axes.end(), pos) != axes.end();
  };

  AxesData batch;

  // Concrete in A, B and the output
  for (const int pos : a.cdomains) {
    if (isListedIn(b.cdomains, pos) && isListedIn(out.cdomains, pos)) {
      batch.push_back(pos);
    }
  }

  // Broadcast in A, B and the output
  for (const int pos : a.bcasts) {
    if (isListedIn(b.bcasts, pos) && isListedIn(out.bcasts, pos)) {
      batch.push_back(pos);
    }
  }

  std::sort(batch.begin(), batch.end());
  return batch;
};

// Checks the details gathered for a single MmaOp input and asserts when a
// rule is violated. `desc` is prepended to the failure message to identify
// which input failed. An input must have:
//  - at least one broadcast domain,
//  - no reduction domains,
//  - at least expected_gemm_cdomains concrete domains.
const auto validateInputDetails = [](const TensorViewDetails& details,
const std::string& desc) {
TORCH_INTERNAL_ASSERT(
!details.bcasts.empty(), desc, ": has no broadcast domains.");
TORCH_INTERNAL_ASSERT(
details.rdomains.empty(), desc, ": has reduction domains.");
TORCH_INTERNAL_ASSERT(
details.cdomains.size() >= expected_gemm_cdomains,
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());
};

const auto validateOutputDetails = [](const TensorViewDetails& details,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the case of batched matmul, should we allow having broadcast and more than two concrete domains?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have support for batch yet, but I checked what the domains in the MmaOp output are for a single test with strided batches:

NVFuserTest.FusionAmpereStridedBatchedMatmulTN_CUDA

and it looks like this:

T4_l [ iS18{i0}, iS19{i2}, iS20{i6}, iS21{i3}, rS22{i4} ]

I'm not sure if the test represents the final approach for strided batches so broadcasts could appear there.

For now I will keep the current implementation but I will add comment:

// TODO: revise rules when add support for batch gemms

const std::string& desc) {
// TODO: revise rules when add support for batch gemms
TORCH_INTERNAL_ASSERT(
details.bcasts.empty(), desc, ": has broadcast domains.");
TORCH_INTERNAL_ASSERT(
!details.rdomains.empty(), desc, ": has no reduction domains.");
TORCH_INTERNAL_ASSERT(
(details.cdomains.size() >= expected_gemm_cdomains),
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());
};

validateInputDetails(in_a_details, "MmaOp input A");
validateInputDetails(in_b_details, "MmaOp input B");
validateOutputDetails(out_details, "MmaOp output");

MmaOpDetails details;

// For details, check MmaOpDetails
details.m_axes = getMOrNaxes(
in_a_details.cdomains, in_b_details.bcasts, out_details.rdomains);
details.n_axes = getMOrNaxes(
in_b_details.cdomains, in_a_details.bcasts, out_details.rdomains);
details.k_axes = getKaxes(
in_a_details.cdomains, in_b_details.cdomains, out_details.rdomains);
details.batch_axes = getBatchAxes(in_a_details, in_b_details, out_details);

TORCH_INTERNAL_ASSERT(
!details.m_axes.empty(),
"MmaOp inputs must define at least a single M dimension");
TORCH_INTERNAL_ASSERT(
!details.n_axes.empty(),
"MmaOp inputs must define at least a single N dimension");
TORCH_INTERNAL_ASSERT(
!details.k_axes.empty(),
"MmaOp inputs must define at least a single K dimension");

// TODO: for tensor contraction / split-k uses of MmaOp different input layout
// rules may be needed
details.input_layout = getInputLayout(
in_a_details,
in_b_details,
details.m_axes,
details.n_axes,
details.k_axes);

return details;
}

}; // namespace MmaOpUtils

MmaOp::MmaOp(
IrBuilderPasskey passkey,
Val* out,
Expand All @@ -1336,7 +1581,8 @@ MmaOp::MmaOp(
// Check output type
TORCH_INTERNAL_ASSERT(
out->getValType().value() == ValType::TensorView ||
out->getValType().value() == ValType::TensorIndex);
out->getValType().value() == ValType::TensorIndex,
out->getValType().value());

TORCH_INTERNAL_ASSERT(
in_a->getValType().value() == ValType::TensorView ||
Expand All @@ -1348,23 +1594,44 @@ MmaOp::MmaOp(
in_b->getValType().value() == ValType::TensorIndex,
in_b->getValType().value());

const auto isBroadcastIn = [](const Val* val) {
if (val->getValType().value() == ValType::TensorView) {
const auto* tv = val->as<TensorView>();
return tv->hasBroadcast();
}
return true;
};

TORCH_INTERNAL_ASSERT(isBroadcastIn(in_a));
TORCH_INTERNAL_ASSERT(isBroadcastIn(in_b));
MmaOpUtils::MmaOpDetails mma_details;
// Detailed consistency checks for use case with TensorViews as inputs/output
if (in_a->isA<TensorView>() && in_b->isA<TensorView>() &&
out->isA<TensorView>()) {
mma_details = MmaOpUtils::getMmaOpDetails(
out->as<TensorView>(), in_a->as<TensorView>(), in_b->as<TensorView>());
}

addOutput(out);
addInput(in_a);
addInput(in_b);
// ATTR_POS_INIT
addAttribute(init);
// ATTR_POS_OPTS
addAttribute(
IrBuilder::create<Attribute<OptionsInMma>>(passkey.ir_container_));
// ATTR_POS_M_AXES
addAttribute(IrBuilder::create<Attribute<AxesData>>(passkey.ir_container_));
// ATTR_POS_N_AXES
addAttribute(IrBuilder::create<Attribute<AxesData>>(passkey.ir_container_));
// ATTR_POS_K_AXES
addAttribute(IrBuilder::create<Attribute<AxesData>>(passkey.ir_container_));
// ATTR_POS_BATCH_AXES
addAttribute(IrBuilder::create<Attribute<AxesData>>(passkey.ir_container_));
// ATTR_POS_INPUT_LAYOUT
addAttribute(
IrBuilder::create<Attribute<MmaInputLayoutOpt>>(passkey.ir_container_));

attribute(ATTR_POS_M_AXES)->as<Attribute<AxesData>>()->value =
std::move(mma_details.m_axes);
attribute(ATTR_POS_N_AXES)->as<Attribute<AxesData>>()->value =
std::move(mma_details.n_axes);
attribute(ATTR_POS_K_AXES)->as<Attribute<AxesData>>()->value =
std::move(mma_details.k_axes);
attribute(ATTR_POS_BATCH_AXES)->as<Attribute<AxesData>>()->value =
std::move(mma_details.batch_axes);
attribute(ATTR_POS_INPUT_LAYOUT)->as<Attribute<MmaInputLayoutOpt>>()->value =
mma_details.input_layout;
}

MmaOp::MmaOp(
Expand All @@ -1375,7 +1642,7 @@ MmaOp::MmaOp(
Val* init,
OptionsInMma options)
: MmaOp(passkey, out, in_a, in_b, init) {
attribute(1)->as<Attribute<OptionsInMma>>()->value = options;
attribute(ATTR_POS_OPTS)->as<Attribute<OptionsInMma>>()->value = options;
}

std::string MmaOp::toString(int indent_size) const {
Expand All @@ -1391,7 +1658,8 @@ std::string MmaOp::toInlineString(int indent_size) const {
}

void MmaOp::configureOptions(MmaOptions options) {
OptionsInMma& opt = attribute(1)->as<Attribute<OptionsInMma>>()->value;
OptionsInMma& opt =
attribute(ATTR_POS_OPTS)->as<Attribute<OptionsInMma>>()->value;
TORCH_INTERNAL_ASSERT(
options.macro != MmaOptions::MacroType::NoMMA,
"Un-configured mma type from options.");
Expand Down
Loading