Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
58687ad
Generalize CombineMulSum as MatmulPatterns
jacobhinkle May 20, 2024
a583d40
Remove MatmulOp stuff.
jacobhinkle May 20, 2024
f5ec534
Fix multiple-broadcasts test.
jacobhinkle May 20, 2024
d74c7c0
Merge branch 'main' into matmul_patterns
jacobhinkle May 20, 2024
1a0fdb9
Remove bcast output test
jacobhinkle May 20, 2024
9e6447d
Remove canScheduleCompileTime check for ExprEval
jacobhinkle May 20, 2024
1cfac85
Fix gcc build failure due to unused variable
jacobhinkle May 21, 2024
7aad387
Fix signed/unsigned compare
jacobhinkle May 21, 2024
e4ef03e
Merge branch 'main' into matmul_patterns
jacobhinkle May 21, 2024
36a72dc
Allow multiple M, N, or K dims in pattern match
jacobhinkle May 21, 2024
1f713d3
Update csrc/scheduler/mma_utils.cpp
jacobhinkle May 21, 2024
361389a
Merge branch 'main' into matmul_patterns
jacobhinkle May 21, 2024
6937ff8
Add comment about why casts are often present
jacobhinkle May 21, 2024
f54c51b
Merge remote-tracking branch 'origin/main' into matmul_patterns
jacobhinkle May 21, 2024
345b025
Remove getProblemIterDomains
jacobhinkle May 21, 2024
142f366
Rename group_to_domain -> dim_roles
jacobhinkle May 21, 2024
8bfd292
Update csrc/scheduler/matmul.cpp
jacobhinkle May 21, 2024
57974a3
Update csrc/scheduler/matmul.cpp
jacobhinkle May 21, 2024
e0619cf
Rename most occurrences of roles_map -> tensor_roles
jacobhinkle May 21, 2024
b0fa936
Remove getProblemLayout(Fusion*, const MatmulPattern&)
jacobhinkle May 21, 2024
76eb037
Uncomment getProblemLayout
jacobhinkle May 21, 2024
e1044d9
Replace bitwise assignment ops
jacobhinkle May 21, 2024
990f2be
Rename dim_to_domain -> dim_roles
jacobhinkle May 21, 2024
4df6b11
Merge remote-tracking branch 'origin/main' into matmul_patterns
jacobhinkle May 22, 2024
c3753cb
Add comment describing isMatmulFusionDefinitionSupported
jacobhinkle May 22, 2024
4fac4fb
Use lambda to simplify getTensorsRoles
jacobhinkle May 22, 2024
9d31a90
Rename getTensorsRoles -> getTensorRoles
jacobhinkle May 22, 2024
a99c9ae
Assume alloc=rfactor to determine M, N and A, B
jacobhinkle May 22, 2024
bd6517c
Test that dim role mapping survives swap
jacobhinkle May 22, 2024
8dc6948
Rename ValGroupPresence to DimPresence
jacobhinkle May 22, 2024
71128e4
Use std::variant<std::string, UnitDim> for error
jacobhinkle May 22, 2024
71ba8a2
clang-tidy fix
jacobhinkle May 22, 2024
0deb0bf
Use noReductions/noBroadcasts to simplify hasTrivialAllocationDomain
jacobhinkle May 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,16 @@ int64_t getVectorizeSize(const TensorView* tv) {
return 1;
}

bool hasTrivialAllocationDomain(const TensorView* tv) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention of this utility is to generalize !tv->hasAllocation() to cases where an allocation domain is provided, but it actually corresponds to the no-reductions rfactor domain (ignoring broadcasts).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Would it be possible to add this as a comment in the header file?

if (!tv->hasAllocation()) {
return true;
}
const std::vector<IterDomain*>& alloc = tv->getMaybeAllocationDomain();
const std::vector<IterDomain*>& rf = tv->getMaybeRFactorDomain();
return TensorDomain::noBroadcasts(TensorDomain::noReductions(rf)) ==
TensorDomain::noBroadcasts(TensorDomain::noReductions(alloc));
}

} // namespace nvfuser::ir_utils

namespace nvfuser::MmaOpUtils {
Expand Down Expand Up @@ -1269,7 +1279,6 @@ MmaOpDetails getMmaOpDetails(
const auto validateOutputDetails = [](const TensorViewDetails& details,
const std::string& desc) {
// TODO: revise rules when add support for batch gemms
NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains.");
NVF_ERROR(!details.rdomains.empty(), desc, ": has no reduction domains.");
NVF_ERROR(
(details.cdomains.size() >= expected_gemm_cdomains),
Expand Down
2 changes: 2 additions & 0 deletions csrc/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,4 +652,6 @@ std::optional<std::vector<int64_t>> computePermutation(
return permutation;
}

// Returns true if tv has no explicit allocation domain, or if the allocation
// domain it does have matches the rfactor domain once reduction and broadcast
// IterDomains are ignored — i.e. the allocation is effectively the default
// layout.
bool hasTrivialAllocationDomain(const TensorView* tv);

} // namespace nvfuser::ir_utils
2 changes: 1 addition & 1 deletion csrc/mma_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace nvfuser {
constexpr std::string_view MATMUL_LOG_PREFIX = "[MATMUL DEBUG] ";

//! Named descriptors of domains in matmul
enum class MatmulDomain { M = 0, N, K };
enum class MatmulDomain { M = 0, N, K, Batch };

//! Named descriptors of TensorView roles in fusion
//! INPUT_A - a producer of MMA input A
Expand Down
51 changes: 29 additions & 22 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,38 +749,45 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
// Cache and fork outputs
auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true);

mma_utils::CombineMulSum combiner(fusion);
auto mma_ops = ir_utils::getOpsOfType<MmaOp>(fusion);
if (combiner.isValid() && mma_ops.empty()) {
combiner.replaceWithMmaOp();
mma_ops = ir_utils::getOpsOfType<MmaOp>(fusion);
}

std::vector<mma_utils::MatmulPattern> patterns =
mma_utils::findMatmulPatterns(fusion);
NVF_ERROR(!patterns.empty(), "No matmul patterns were found");
NVF_ERROR(
mma_ops.size() == 1,
"scheduleMatmul supports fusion with single mma op in definition, got ",
mma_ops.size());
patterns.size() == 1,
"Only a single matmul pattern can currently be fused");
std::vector<MmaOp*> mma_ops;
mma_ops.reserve(patterns.size());
for (mma_utils::MatmulPattern& pattern : patterns) {
mma_ops.push_back(pattern.translateToMmaOp());
}

const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion);
IdModel id_model(fusion);
std::unordered_map<ValGroup, MatmulDomain> id_roles =
patterns.front().getDimRoles(id_model);
const auto& tensor_roles_opt =
mma_utils::getTensorRoles(fusion, id_model, id_roles);

// NOTE: the contents of roles_map have been already validated during
// NOTE: the contents of tensor_roles have been already validated during
// compute-time checks
NVF_ERROR(roles_map_opt.isValid(), roles_map_opt.getErrorMsg());
const auto roles_map = roles_map_opt.getData();
NVF_ERROR(tensor_roles_opt.isValid(), tensor_roles_opt.getErrorMsg());
const auto tensor_roles = tensor_roles_opt.getData();

const mma_utils::MatmulProblemLayoutOpt fusion_layout =
mma_utils::getProblemLayout(id_model, id_roles, tensor_roles);
NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg());

// Core roles: there can be only one... TV with assigned core role
TensorView* a = roles_map.at(MatmulRole::INPUT_A).front();
TensorView* b = roles_map.at(MatmulRole::INPUT_B).front();
TensorView* a = tensor_roles.at(MatmulRole::INPUT_A).front();
TensorView* b = tensor_roles.at(MatmulRole::INPUT_B).front();

const auto& gemm_tile = params.tile_sizes;

// Collect mma swizzle info
auto mma = mma_ops.front();
const auto fusion_layout = mma_utils::getMmaLayout(fusion);
NVF_ERROR(fusion_layout.isValid(), fusion_layout.getErrorMsg());

const auto& gemm_tile = params.tile_sizes;
const bool has_epilogue = !mma->out()->isFusionOutput();

const bool has_fusion_c_roles = (0 != roles_map.count(MatmulRole::INPUT_C));
const bool has_fusion_c_roles =
(0 != tensor_roles.count(MatmulRole::INPUT_C));
const bool has_non_mma_input_tvs = has_epilogue && has_fusion_c_roles;

// Including current tensor naming convention for reference,
Expand Down Expand Up @@ -1227,7 +1234,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
// operations, input tvs with non-core roles
// core roles: essential for matmul, for example mma inputs' producers
if (has_non_mma_input_tvs) {
scheduleFusionInputsForEpilogue(roles_map, params.use_smem_epilogue);
scheduleFusionInputsForEpilogue(tensor_roles, params.use_smem_epilogue);
}

scheduleSplitKSum(
Expand Down
Loading