Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
85504a0
Introduce MatmulPattern and enable it in scheduler
jacobhinkle May 13, 2024
d174a8c
Fixes
jacobhinkle May 13, 2024
6c7f71a
Strip casts from input of mul-sum patterns
jacobhinkle May 13, 2024
bb39995
Remove CombineMulSum
jacobhinkle May 13, 2024
cb99bf2
Merge remote-tracking branch 'origin/main' into translate_matmul_op
jacobhinkle May 14, 2024
03e9bdc
Add a test
jacobhinkle May 14, 2024
c0d33f5
Re-enable NVFUSER_DISABLE=matmul_expr_eval
jacobhinkle May 14, 2024
a697b55
Add MatmulOp to ir_utils::isTvOp
jacobhinkle May 14, 2024
d50e19b
Big refactor to use IdModel and allocation domain
jacobhinkle May 14, 2024
04522b8
Add IterType::Reduction domain for K dim in output
jacobhinkle May 15, 2024
7b04286
Translate bcast K as simple product
jacobhinkle May 15, 2024
dcaf349
Finish testing all combinations of mappings
jacobhinkle May 15, 2024
dd7f09c
Merge remote-tracking branch 'origin/main' into translate_matmul_op
jacobhinkle May 15, 2024
5ece10e
Merge remote-tracking branch 'origin/matmul_op_id_mapping' into trans…
jacobhinkle May 15, 2024
a7a9b56
Remove prints and fix up isMatmulFusionDefinitionSupported
jacobhinkle May 15, 2024
d6487b1
Fix getProblemLayout
jacobhinkle May 15, 2024
fd0ceb9
Add EnableOption::FuseMatmul and check sm arch
jacobhinkle May 15, 2024
6e3f052
Add TODO about skipping downcast roundtrip
jacobhinkle May 15, 2024
812f9d2
Remove some unused code
jacobhinkle May 15, 2024
6d48f35
clang-tidy
jacobhinkle May 15, 2024
68bf90e
Merge remote-tracking branch 'origin/main' into translate_matmul_op
jacobhinkle May 16, 2024
ffc62dd
Undo change to matmul
jacobhinkle May 16, 2024
7ef6dda
Merge branch 'main' into translate_matmul_op
jacobhinkle May 17, 2024
4839125
Clean up test
jacobhinkle May 17, 2024
254ba50
Test that alloc domain causes ExprEval use
jacobhinkle May 17, 2024
6d8313d
Merge remote-tracking branch 'origin/main' into translate_matmul_op
jacobhinkle May 20, 2024
d432785
Merge remote-tracking branch 'origin/main' into translate_matmul_op
jacobhinkle May 23, 2024
72da560
WIP adding LinearOp support
jacobhinkle May 23, 2024
514da40
Translate LinearOp
jacobhinkle May 23, 2024
4f64c5d
Fixes plus add more test cases for LinearOp
jacobhinkle May 23, 2024
8328207
Fix up more tests. Not all test cases are passing yet
jacobhinkle May 23, 2024
db32d2d
Special cases in getDimRoles for new nodes
jacobhinkle May 23, 2024
c5a6135
Disable allocation domain inference for test
jacobhinkle May 23, 2024
9e4f70c
Cover more cases in matmul node test
jacobhinkle May 23, 2024
09d3945
Fix up batch cases by adding outer broadcast dims
jacobhinkle May 23, 2024
f5c4f36
Fix some cases and add comments
jacobhinkle May 23, 2024
5c0d31f
Fix comment
jacobhinkle May 23, 2024
18c0f46
Fix faulty merge
jacobhinkle May 23, 2024
b52cb1e
Remove undocumented 2D bias test case
jacobhinkle May 23, 2024
bb0d619
Fix gcc build error
jacobhinkle May 23, 2024
de5eb9b
Merge branch 'main' into translate_matmul_op
jacobhinkle May 24, 2024
6326033
Merge remote-tracking branch 'origin/main' into translate_matmul_op
jacobhinkle May 29, 2024
d0c5bed
Loop over operand roles to reduce code duplication
jacobhinkle May 29, 2024
384f036
Fix comment about enabling matmul op
jacobhinkle May 29, 2024
47cc50e
Update csrc/scheduler/matmul_utils.cpp
jacobhinkle May 29, 2024
1e55592
Reject dtypes other than Half or BFloat16
jacobhinkle May 29, 2024
06f675e
Update csrc/scheduler/mma_utils.cpp
jacobhinkle May 29, 2024
8fa30ec
Use mapMatmulOpIterDomains
jacobhinkle May 29, 2024
e19c8b9
Use map*OpIterDomains to simplify getDimRoles
jacobhinkle May 29, 2024
520b709
Add Llama2FFN test
jacobhinkle May 29, 2024
9ae0769
clang-format
jacobhinkle May 29, 2024
0fb44e4
Remove 2D bias check in linear test
jacobhinkle May 29, 2024
fe8898e
Parametrize LinearOp node translation test
jacobhinkle May 29, 2024
2d8d45d
Parametrize matmul test
jacobhinkle May 29, 2024
59d77cf
Fix multidevice examples by filtering out device dims
jacobhinkle May 30, 2024
764823f
Clean up canScheduleCompileTime
jacobhinkle May 30, 2024
5cf7d72
Add link to #2241 in failing test case
jacobhinkle May 30, 2024
5ffe590
Make NoOp scheduler avoid matmul ops
jacobhinkle May 30, 2024
6108522
NVFuserTest -> MatmulSchedulerTest
jacobhinkle May 30, 2024
f34f4ad
Parametrize Llama2FFN test
jacobhinkle May 30, 2024
7342ecd
Guard translation tests for cc < 7.5
jacobhinkle May 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ template <>
std::unordered_map<EnableOption, std::vector<std::string>> Options<
EnableOption>::getOptionsFromEnv() {
const std::unordered_map<std::string, EnableOption> available_options = {
{"fuse_matmul", EnableOption::FuseMatmul},
{"id_model", EnableOption::IdModel},
{"kernel_db", EnableOption::KernelDb},
{"kernel_profile", EnableOption::KernelProfile},
Expand Down
1 change: 1 addition & 0 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ enum class DebugDumpOption {
//! These can be set through the `NVFUSER_ENABLE` environment variable
//!
enum class EnableOption {
FuseMatmul, //! Enable automatic fusion of matmul and linear ops
IdModel, //! Enable IdModel
KernelDb, //! Enable Kernel Database
KernelProfile, //! Enable intra-kernel performance profiling
Expand Down
17 changes: 11 additions & 6 deletions csrc/scheduler/expr_eval_sched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@ namespace nvfuser {
// Compile-time acceptance check for the ExprEval scheduler.
//
// Accepts a fusion only when it consists of exactly one expression and that
// expression is a MatmulOp or LinearOp (these are evaluated via ATen rather
// than a generated kernel). The whole path can be opted out of with
// NVFUSER_DISABLE=matmul_expr_eval.
//
// NOTE(review): this span was reconstructed from a unified diff whose +/-
// markers were lost; the body below is the post-merge (added) version of the
// function, with the deleted pre-change lines removed.
bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) {
  auto exprs = fusion->exprs();
  if (!isOptionDisabled(DisableOption::MatmulExprEval)) {
    // Single-expression fusions of either matmul-like op are handled by
    // ATen expression evaluation.
    if (exprs.size() == 1 && (exprs.front()->isOneOf<LinearOp, MatmulOp>())) {
      return true;
    }
    scheduler_debug_utils::canScheduleRejectReason(
        heuristicType(),
        "Fusion must contain a single expression of type MatmulOp or LinearOp");
  } else {
    // User explicitly disabled ATen evaluation of matmuls; record why we
    // rejected so scheduler debugging output stays informative.
    scheduler_debug_utils::canScheduleRejectReason(
        heuristicType(),
        "Matmul ATen evaluation was disabled by NVFUSER_DISABLE=matmul_expr_eval");
  }
  return false;
}

Expand Down
111 changes: 78 additions & 33 deletions csrc/scheduler/matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ std::string isMatmulFusionDefinitionSupported(

constexpr size_t minimal_number_of_inputs = 2;

// Quick checks - MmaOp
// Quick checks
{
// Check if MmaOp represents gemm (requires M/N/K == 1, B == 0)
// Check if matmul pattern represents gemm (requires M/N/K == 1, B == 0)
// or bgemm (requires M/N/K/B == 1)
std::array<int64_t, 4> num_axes{};
for (const auto& [g, dom] : id_roles) {
Expand All @@ -218,7 +218,7 @@ std::string isMatmulFusionDefinitionSupported(
num_axes[(size_t)MatmulDomain::N] != expected_axes_numbers ||
num_axes[(size_t)MatmulDomain::K] != expected_axes_numbers ||
num_axes[(size_t)MatmulDomain::Batch] > expected_axes_numbers) {
return "MmaOp has unsupported number of one of M/N/K/Batch axes";
return "Matmul pattern has unsupported number of one of M/N/K/Batch axes";
}

if (!mma_output->hasReduction()) {
Expand All @@ -236,31 +236,47 @@ std::string isMatmulFusionDefinitionSupported(

// Fusion topology check
{
auto entry = roles_map.find(MatmulRole::INPUT_A);
// We will check that all operands have same dimension
int64_t operand_dim = -1;

// Track TensorViews with assigned roles so we can check that all inputs and
// outputs have recognized roles
std::set<TensorView*> tvs_with_roles;

if (entry != roles_map.end()) {
if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) {
tvs_with_roles.insert(entry->second.begin(), entry->second.end());
for (MatmulRole role : {MatmulRole::INPUT_A, MatmulRole::INPUT_B}) {
auto entry = roles_map.find(role);
if (entry != roles_map.end()) {
if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) {
tvs_with_roles.insert(entry->second.begin(), entry->second.end());
for (TensorView* tv : entry->second) {
const std::vector<IterDomain*>& leaf = tv->getLeafDomain();
int64_t ndims = (int64_t)std::count_if(
leaf.begin(), leaf.end(), [](IterDomain* id) {
return !id->isReduction() && !id->isDeviceDim();
});
if (operand_dim == -1) {
operand_dim = ndims;
} else if (ndims != operand_dim) {
// We cannot always handle differently sized inputs, such as those
// we encounter when translating MatmulOp and LinearOp. This is
// because in those cases one of the operands will have new
// Broadcast dimensions where the other operand has Iteration
// batch dimensions, meaning these new dims are actually M or N
// dimensions. Multiple M and N dimension support is planned but
// for now we must reject these patterns before attempting to
// translate them.
return "All operands must have the same no-devices dimension.";
}
}
} else {
return "There is more than a single fusion input that can be MMA operand ";
}
} else {
return "There is more than a single fusion input that can be MMA first input";
return "No candidate in fusion inputs for MMA operand";
}
} else {
return "No candidate in fusion inputs for MMA first input";
}

entry = roles_map.find(MatmulRole::INPUT_B);
if (entry != roles_map.end()) {
if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) {
tvs_with_roles.insert(entry->second.begin(), entry->second.end());
} else {
return "There is more than a single fusion input that can be MMA second input";
}
} else {
return "No candidate in fusion inputs for MMA second input";
}

entry = roles_map.find(MatmulRole::OUTPUT_D);
auto entry = roles_map.find(MatmulRole::OUTPUT_D);
if (entry != roles_map.end()) {
tvs_with_roles.insert(entry->second.begin(), entry->second.end());
} else {
Expand Down Expand Up @@ -458,8 +474,11 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
// The plan:
// 0. Check if the current CUDA device is supported
// 1. Check if there is exactly one matmul pattern defined in the fusion.
// 2. Check if the input layout for the matmul pattern can be determined
// 3. Check if fusion represents expressions that are recognized by matmul
// 2. Check if fusion of MatmulOp and LinearOp is enabled, if applicable
// 3. Check if inputs to the matmul pattern match any of
// supported inputs layout
// 4. Check if fusion represents expressions that are recognized by matmul
// 5. Check if the input layout for the matmul pattern can be determined
Comment on lines +480 to +481
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reordered the last two checks since different-dimension inputs can interfere with determining the layout, which hits an NVF_CHECK instead of just returning a reason for not accepting the segment.

// scheduler.

// #0
Expand All @@ -475,13 +494,39 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
}

// #1
// Initializing the machinery to check if there's a Mul-Sum pair
// can be replaced by a Mma Op.
// Find matmul patterns
std::vector<mma_utils::MatmulPattern> patterns =
mma_utils::findMatmulPatterns(fusion);
if (patterns.empty()) {
return "No matmul patterns were found";
}

// #2
{
for (const mma_utils::MatmulPattern& pattern : patterns) {
Expr* op = pattern.output->definition();
if (op->isA<MatmulOp>() || op->isA<LinearOp>()) {
if (!isOptionEnabled(EnableOption::FuseMatmul)) {
// Check for MatmulOp or LinearOp. If found, then only fuse if option
// is specified
return "MatmulOp and LinearOp fusion is disabled by default. "
"Enable it using NVFUSER_ENABLE=fuse_matmul";
}
// Refuse patterns containing 1D inputs since these are mat-vec as
// opposed to mat-mat products.
if (pattern.A->nDims() < 2 || pattern.B->nDims() < 2) {
return "Cannot fuse matrix-vector products";
}
for (TensorView* operand : {pattern.A, pattern.B}) {
if (operand->dtype() != DataType::Half &&
operand->dtype() != DataType::BFloat16) {
return "Unsupported operand type. Operands must be fp16 or bf16";
}
}
}
}
}

if (patterns.size() > 1) {
return "Only a single matmul pattern can currently be fused";
}
Expand All @@ -498,13 +543,6 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
mma_utils::RolesMap roles_map = roles_map_opt.getData();

// #4
const auto input_layout_opt =
mma_utils::getProblemLayout(id_model, id_roles, roles_map);
if (!input_layout_opt.isValid()) {
return input_layout_opt.getErrorMsg();
}

// #5
{
auto support_status = isMatmulFusionDefinitionSupported(
fusion, patterns.front(), roles_map, id_roles);
Expand All @@ -513,6 +551,13 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
}
}

// #5
const auto input_layout_opt =
mma_utils::getProblemLayout(id_model, id_roles, roles_map);
if (!input_layout_opt.isValid()) {
return input_layout_opt.getErrorMsg();
}

return "";
}

Expand Down
Loading