Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions benchmarks/cpp/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,6 @@ static void NvFuserScheduler_Matmul(
NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

DisableOptionsGuard dog;
DisableOptionsGuard::getCurOptions().set(DisableOption::MatmulExprEval);

// Run benchmark:
if (partitionedk) {
SingleMatmulPartitionedK(benchmark_state, layout, params, splitk_factor);
Expand Down
2 changes: 1 addition & 1 deletion csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ enum class AllocationType : int {
// pointer arithmetic of the input. In this case, aliased_io is a non-null
// tensor.
// 2. To evaluate output tensors which are not aliases. For example, default
// scheduling in matmul when DisableOption::MatmulExprEval is not set.
// scheduling for MatmulOp/LinearOp in ExprEval scheduler.
Evaluate,
};

Expand Down
3 changes: 1 addition & 2 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,7 @@ class NVF_API UnaryOp : public Expr {

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
std::unordered_map<const Val*, PolymorphicValue>& known_values)
const override;
const std::vector<PolymorphicValue>& inputs) const override;

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
Expand Down
74 changes: 2 additions & 72 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,80 +389,10 @@ UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in)

std::vector<PolymorphicValue> UnaryOp::evaluate(
const ExpressionEvaluator& ee,
std::unordered_map<const Val*, PolymorphicValue>& known_values) const {
const std::vector<PolymorphicValue>& inputs) const {
using namespace PolymorphicValue_functions;

// If the UnaryOp is CastOp, check if the preceding pattern of
// operators matches with matmul (MmaOp(Broadcast (A), Broadcast(B)) -> Cast)
// or matmul + bias (BinaryOp::Add (MmaOp(Broadcast (A), Broadcast(B),
// Broadcast(bias)) -> Cast) If not, evaluate UnaryOp::CastOp along with the
// other types by evaluating the immediate input.

// Check if the unary op is a cast from fp32 to lower precision.
auto is_downcast = [this]() -> bool {
if (getUnaryOpType() != UnaryOpType::Cast) {
return false;
}
auto in_dtype = input(0)->getDataType().value();
return (
in_dtype == DataType::Float &&
isInclusiveType(*(out()->getDataType()), in_dtype));
};

if (is_downcast() && input(0)->definition() != nullptr) {
MmaOpUtils::MatmulInputs matmul_inp;

if (MmaOpUtils::matchMatmulPatterns(this, &matmul_inp)) {
// Inputs to the pattern are of the shape [M, K] x [K, N] (matmul) / [M,
// K] x [N, K] (linear). Note: alpha, beta parameters are nullptr for
// linear.
const auto a =
ee.evaluate(matmul_inp.mma_lhs, known_values).as<at::Tensor>();
const auto b =
ee.evaluate(matmul_inp.mma_rhs, known_values).as<at::Tensor>();
const c10::Scalar alpha = matmul_inp.alpha
? toScalar(ee.evaluate(matmul_inp.alpha, known_values))
: 1;

// Matmul/Addmm: n_pos=2, k_pos=1
// Linear: n_pos=1, k_pos=2
const int k_pos =
std::get<(size_t)MatmulDomain::K>(matmul_inp.mma_dims_pos);
const int n_pos =
std::get<(size_t)MatmulDomain::N>(matmul_inp.mma_dims_pos);

if (matmul_inp.bias == nullptr) {
auto out = k_pos < n_pos ? alpha * a.matmul(b) : at::linear(a, b);
return {out};
}

auto bias = ee.evaluate(matmul_inp.bias, known_values).as<at::Tensor>();

// Linear takes 1D bias. Unsqueeze for 1D bias in matmul/addmm.
if (bias.dim() != a.dim() && (k_pos < n_pos)) {
// Unsqueeze the broadcast dimensions.
// For 2D inputs to the pattern, bias is of shape [M,1]/[1,N]
for (auto dim :
c10::irange((int64_t)matmul_inp.bias_bcast_flags.size())) {
if (matmul_inp.bias_bcast_flags[dim]) {
bias = bias.unsqueeze(dim);
}
}
}

const c10::Scalar beta = matmul_inp.beta
? toScalar(ee.evaluate(matmul_inp.beta, known_values))
: 1;

auto out = k_pos < n_pos ? at::addmm(bias, a, b, beta, alpha)
: at::linear(a, b, bias);
return {out};
}
}

// If there is not a preceding MmaOp, evaluate immediate inputs and compute
// the output for unary ops.
const auto& in = ee.evaluate(inputs().at(0), known_values);
const auto& in = inputs.at(0);
if (!in.hasValue()) {
return {std::monostate{}};
}
Expand Down
208 changes: 0 additions & 208 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1313,212 +1313,4 @@ MmaOpDetails getMmaOpDetails(
return details;
}

namespace {
// Returns the position of M,N,K axis in mma operands.
std::tuple<int, int, int> getMmaDimsPositions(MmaOp* mma) {
auto mma_domains = mma_utils::getProblemIterDomains(mma->fusion());
NVF_ERROR(mma_domains.isValid(), mma_domains.getErrorMsg());

const auto domains_data = mma_domains.getData();
const auto m_id = domains_data[(size_t)MatmulDomain::M];
const auto n_id = domains_data[(size_t)MatmulDomain::N];
const auto k_id = domains_data[(size_t)MatmulDomain::K];

int m_pos = -1;
int n_pos = -1;
int k_pos = -1;

auto out_tv = mma->out()->as<TensorView>();
int ndims = (int)out_tv->nDims();

for (auto idx : c10::irange(ndims)) {
auto id = out_tv->axis(idx);
// Categorize each original iterdomain position
if (m_id->sameAs(id)) {
m_pos = idx;
} else if (n_id->sameAs(id)) {
n_pos = idx;
} else if (k_id->sameAs(id)) {
k_pos = idx;
}
}

NVF_ERROR(
m_pos != -1 && n_pos != -1 && k_pos != -1,
"Valid index not found for all problem iterdomains.")
return {m_pos, n_pos, k_pos};
}
} // namespace

// Verifies the assumptions made when evaluating a fusion containing MmaOp:
// 1. MmaOp is preceded by a broadcast.
// 2. The inputs to MmaOp are broadcasted as the last dim for the first operand
// and the first dim for the second operand.
// The inputs of MmaOp will be [M, K, 1] x [1, K, N].
// Additionally, the inputs to the MmaOp should be of `expected_input_dtype`.
// This is the same as the output dtype of the final castOp.
void verifyMmaOpForEvaluation(
MmaOp* mma_op,
const DataType expected_input_dtype) {
const Val* in_a = mma_op->inA();
const Val* in_b = mma_op->inB();

const auto tv_a = in_a->as<TensorView>();
const auto tv_b = in_b->as<TensorView>();

NVF_ERROR(
tv_a->nDims() == tv_b->nDims(),
"Either both or none of A and B should be batch");
// Verify that the broadcasted size is 3.
NVF_ERROR(
tv_a->nDims() == 3,
"MmaOp::evaluate is not implemented for size: ",
tv_a->nDims());

NVF_ERROR(
in_a->definition() != nullptr && in_a->definition()->isA<BroadcastOp>(),
"Currently, MmaOp::evaluate assumes the preceding op to be a broadcast.");
NVF_ERROR(
in_b->definition() != nullptr && in_b->definition()->isA<BroadcastOp>(),
"Currently, MmaOp::evaluate assumes the preceding op to be a broadcast.");

NVF_ERROR(
tv_a->getRootDomain().back()->isBroadcast() ||
tv_a->getRootDomain()[1]->isBroadcast(),
"Expected middle/last dimension to be broadcasted for first operand.");

NVF_ERROR(
tv_b->getRootDomain().front()->isBroadcast(),
"Expected first dimension to be broadcasted for second operand.");

// ATen preserves the dtype of MmaOp inputs whereas MmaOp generates float
// outputs. To preserve numerical equivalence and precision, the output of
// ATen matmul should be the same as MmaOp out `eventually`.
// See https://github.com/NVIDIA/Fuser/pull/1874#discussion_r1516991574
// Supported cases:
// 1. MmaOp->out() and MmaOp->input() are the same dtype.
// 2. MmaOp->out() is followed by a CastOp() to the MmaOp->input() dtype.
// NOTE: Currently MmaOp only accepts Half and BFloat16 so case (1) does not
// occur.

NVF_ERROR(
*(tv_a->getDataType()) == *(tv_b->getDataType()),
"MmaOp inputs should be of the same dtype.")
NVF_ERROR(
*(tv_a->getDataType()) == expected_input_dtype,
"MmaOp inputs should be the same dtype as the output dtype of the final castOp.");
}

// Possible combinations:
// 1. A x B + C
// 2. alpha * A x B + C
// 3. A x B + beta * C
// 4. alpha * A x B + beta * C
// 5. A x B
// 6. alpha * A x B
// Note: We assume the first operand to be the MmaOp output
bool matchMatmulPatterns(const UnaryOp* cast_op, MatmulInputs* matmul_inp) {
// Check if there may be a bias present.
bool has_bias = true;
auto* binary = dynamic_cast<BinaryOp*>(cast_op->input(0)->definition());
if (binary == nullptr) {
// Bias is not present
has_bias = false;
} else if (binary->getBinaryOpType() != BinaryOpType::Add) {
return false;
}

// Check for alpha in first input: alpha * (MmaOp(A, B))
MmaOp* mma = nullptr;
auto* mma_branch_root_op = has_bias ? (Expr*)binary : (Expr*)cast_op;

auto* mul_alpha =
dynamic_cast<BinaryOp*>(mma_branch_root_op->input(0)->definition());

if (mul_alpha == nullptr) { // Alpha is not present
mma = dynamic_cast<MmaOp*>(mma_branch_root_op->input(0)->definition());
} else {
NVF_ERROR(
mul_alpha->getBinaryOpType() == BinaryOpType::Mul,
"Unrecognized pattern.");
matmul_inp->alpha = mul_alpha->input(0);
mma = dynamic_cast<MmaOp*>(mul_alpha->input(1)->definition());
if (!matmul_inp->alpha->isScalar()) { // Swap alpha and mma
matmul_inp->alpha = mul_alpha->input(1);
mma = dynamic_cast<MmaOp*>(mul_alpha->input(0)->definition());
}
}

if (mma == nullptr) {
return false;
}

DataType final_out_dtype = cast_op->out()->getDataType().value();

// Verify assumptions for MmaOp hold. Assign the values to Mma operands.
MmaOpUtils::verifyMmaOpForEvaluation(mma, final_out_dtype);

// Get the non-broadcasted values to avoid inferring squeeze dimensions.
matmul_inp->mma_lhs = mma->inA()->definition()->input(0);
matmul_inp->mma_rhs = mma->inB()->definition()->input(0);
matmul_inp->mma_dims_pos = getMmaDimsPositions(mma);

NVF_ERROR(
std::get<(size_t)MatmulDomain::M>(matmul_inp->mma_dims_pos) == 0,
"Expected M to be the first dimension.");

if (!has_bias) {
return true;
}

// Based on the presence of beta parameter, the expected ops are:
// CastOp(bias, fp32) -> Broadcast (Optional) -> Mul (if beta is present)
// -> Add

// Check for beta parameter
auto* mul_beta = dynamic_cast<BinaryOp*>(binary->input(1)->definition());
if (mul_beta == nullptr) { // Case 1: bias
matmul_inp->bias = binary->input(1); // Broadcasted bias tensor in fp32
} else { // Case 2: beta * bias
NVF_ERROR(
mul_beta->getBinaryOpType() == BinaryOpType::Mul,
"Unrecognized pattern.");
matmul_inp->beta = mul_beta->input(0);
matmul_inp->bias = mul_beta->input(1);
if (!matmul_inp->beta->isScalar()) {
// bias * beta
std::swap(matmul_inp->beta, matmul_inp->bias);
}
}

auto bias_ndims = matmul_inp->bias->as<TensorView>()->nDims();
auto inp_ndims = matmul_inp->mma_lhs->as<TensorView>()->nDims();

NVF_ERROR(
(bias_ndims == inp_ndims - 1) || (bias_ndims == inp_ndims),
"Bias should be 1D / 2D tensor.");

// Check if bias was broadcasted
auto* bcast = dynamic_cast<BroadcastOp*>(matmul_inp->bias->definition());
if (bcast != nullptr) {
// Bias of shape [M, 1] / [1, N]
matmul_inp->bias_bcast_flags = bcast->getBroadcastDimFlags();
matmul_inp->bias = bcast->input(0); // Bias tensor in fp32
}

auto* bias_cast = dynamic_cast<UnaryOp*>(matmul_inp->bias->definition());

// The bias tensor and matmul inputs should be of the same dtype.
NVF_ERROR(
bias_cast == nullptr || bias_cast->getUnaryOpType() == UnaryOpType::Cast,
"Expected the bias tensor to be casted to Float.");
NVF_ERROR(
*(bias_cast->input(0)->getDataType()) == final_out_dtype,
"Bias should be originally of the same type as the final output dtype.");

matmul_inp->bias = bias_cast->input(0);

return true;
}

} // namespace nvfuser::MmaOpUtils
13 changes: 0 additions & 13 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -761,19 +761,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
"scheduleMatmul supports fusion with single mma op in definition, got ",
mma_ops.size());

// Skip scheduling if Matmul will be expression evaluated.
if (!isOptionDisabled(DisableOption::MatmulExprEval)) {
NVF_CHECK(fusion->outputs().size() == 1)
fusion->aliasOutputToInput(
fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate);
scheduler_debug_utils::log(
__FILE__,
":",
__LINE__,
", Matmul output to be computed through expression evaluator. Skipping codegen.");
return;
}

const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion);

// NOTE: the contents of roles_map have been already validated during
Expand Down
4 changes: 0 additions & 4 deletions csrc/scheduler/matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,6 @@ std::shared_ptr<MatmulParams> getMatmulHeuristics(
// Set kernel index mode
params->cparams.index_type = runtime_info.getIndexType();

if (!isOptionDisabled(DisableOption::MatmulExprEval)) {
return params;
}

// Check initial conditions
auto mma_exprs = ir_utils::getOpsOfType<MmaOp>(fusion);
mma_utils::CombineMulSum combiner(fusion);
Expand Down
10 changes: 0 additions & 10 deletions tests/cpp/test_combine_mul_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@
namespace nvfuser {

class CombineMulSumAsMmaTest : public NVFuserTest {
protected:
CombineMulSumAsMmaTest() {
DisableOptionsGuard::getCurOptions().set(DisableOption::MatmulExprEval);
}

void SetUp() override {
// These test are enable for Turing and newer. Temporarily
// we are skipping Hopper since the matmul for it is under development.
Expand All @@ -60,11 +55,6 @@ class CombineMulSumAsMmaTest : public NVFuserTest {
}
NVFuserTest::SetUp();
}

private:
// RAII style options guard. This is used to disable
// (via set) options in the constructor.
DisableOptionsGuard opt_guard_;
};

// Test checks to see that the combiner can correctly replace
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/test_double_buffering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ TEST_F(DoubleBufferingTest, SmemBlockGemmCacheDoubleBuffer) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({M, K}, options);
at::Tensor t1 = at::randn({K, N}, options);
at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble));
at::Tensor aten_output = at::matmul(t0.to(at::kDouble), t1.to(at::kDouble));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this working currently? There's no using namespace at; anywhere that I can see, and those are definitely at::Tensors.

Copy link
Collaborator Author

@Priya2698 Priya2698 May 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It calls ATen matmul:

Thread 1 "nvfuser_tests" hit Breakpoint 2, at::matmul (self=..., other=...) at /usr/local/lib/python3.10/dist-packages/torch/include/ATen/ops/matmul.h:26
26      inline at::Tensor matmul(const at::Tensor & self, const at::Tensor & other) {
          return at::_ops::matmul::call(self, other);

I wasn't aware it could be used without the namespace either. I changed it to remove any confusion with our matmul API although the input types will distinguish the two.


std::vector<c10::IValue> aten_inputs = {t0, t1};

Expand Down
Loading