Translate MatmulOp and LinearOp #2236
Merged
Changes from all commits (61 commits)
All 61 commits are by jacobhinkle:

85504a0 Introduce MatmulPattern and enable it in scheduler
d174a8c Fixes
6c7f71a Strip casts from input of mul-sum patterns
bb39995 Remove CombineMulSum
cb99bf2 Merge remote-tracking branch 'origin/main' into translate_matmul_op
03e9bdc Add a test
c0d33f5 Re-enable NVFUSER_DISABLE=matmul_expr_eval
a697b55 Add MatmulOp to ir_utils::isTvOp
d50e19b Big refactor to use IdModel and allocation domain
04522b8 Add IterType::Reduction domain for K dim in output
7b04286 Translate bcast K as simple product
dcaf349 Finish testing all combinations of mappings
dd7f09c Merge remote-tracking branch 'origin/main' into translate_matmul_op
5ece10e Merge remote-tracking branch 'origin/matmul_op_id_mapping' into trans…
a7a9b56 Remove prints and fix up isMatmulFusionDefinitionSupported
d6487b1 Fix getProblemLayout
fd0ceb9 Add EnableOption::FuseMatmul and check sm arch
6e3f052 Add TODO about skipping downcast roundtrip
812f9d2 Remove some unused code
6d48f35 clang-tidy
68bf90e Merge remote-tracking branch 'origin/main' into translate_matmul_op
ffc62dd Undo change to matmul
7ef6dda Merge branch 'main' into translate_matmul_op
4839125 Clean up test
254ba50 Test that alloc domain causes ExprEval use
6d8313d Merge remote-tracking branch 'origin/main' into translate_matmul_op
d432785 Merge remote-tracking branch 'origin/main' into translate_matmul_op
72da560 WIP adding LinearOp support
514da40 Translate LinearOp
4f64c5d Fixes plus add more test cases for LinearOp
8328207 Fix up more tests. Not all test cases are passing yet
db32d2d Special cases in getDimRoles for new nodes
c5a6135 Disable allocation domain inference for test
9e4f70c Cover more cases in matmul node test
09d3945 Fix up batch cases by adding outer broadcast dims
f5c4f36 Fix some cases and add comments
5c0d31f Fix comment
18c0f46 Fix faulty merge
b52cb1e Remove undocumented 2D bias test case
bb0d619 Fix gcc build error
de5eb9b Merge branch 'main' into translate_matmul_op
6326033 Merge remote-tracking branch 'origin/main' into translate_matmul_op
d0c5bed Loop over operand roles to reduce code duplication
384f036 Fix comment about enabling matmul op
47cc50e Update csrc/scheduler/matmul_utils.cpp
1e55592 Reject dtypes other than Half or BFloat16
06f675e Update csrc/scheduler/mma_utils.cpp
8fa30ec Use mapMatmulOpIterDomains
e19c8b9 Use map*OpIterDomains to simplify getDimRoles
520b709 Add Llama2FFN test
9ae0769 clang-format
0fb44e4 Remove 2D bias check in linear test
fe8898e Parametrize LinearOp node translation test
2d8d45d Parametrize matmul test
59d77cf Fix multidevice examples by filtering out device dims
764823f Clean up canScheduleCompileTime
5cf7d72 Add link to #2241 in failing test case
5ffe590 Make NoOp scheduler avoid matmul ops
6108522 NVFuserTest -> MatmulSchedulerTest
f34f4ad Parametrize Llama2FFN test
7342ecd Guard translation tests for cc < 7.5
csrc/scheduler/matmul_utils.cpp

```diff
@@ -205,9 +205,9 @@ std::string isMatmulFusionDefinitionSupported(
   constexpr size_t minimal_number_of_inputs = 2;

-  // Quick checks - MmaOp
+  // Quick checks
   {
-    // Check if MmaOp represents gemm (requires M/N/K == 1, B == 0)
+    // Check if matmul pattern represents gemm (requires M/N/K == 1, B == 0)
     // or bgemm (requires M/N/K/B == 1)
     std::array<int64_t, 4> num_axes{};
     for (const auto& [g, dom] : id_roles) {
@@ -218,7 +218,7 @@ std::string isMatmulFusionDefinitionSupported(
       num_axes[(size_t)MatmulDomain::N] != expected_axes_numbers ||
       num_axes[(size_t)MatmulDomain::K] != expected_axes_numbers ||
       num_axes[(size_t)MatmulDomain::Batch] > expected_axes_numbers) {
-    return "MmaOp has unsupported number of one of M/N/K/Batch axes";
+    return "Matmul pattern has unsupported number of one of M/N/K/Batch axes";
   }

   if (!mma_output->hasReduction()) {
@@ -236,31 +236,47 @@ std::string isMatmulFusionDefinitionSupported(
   // Fusion topology check
   {
-    auto entry = roles_map.find(MatmulRole::INPUT_A);
+    // We will check that all operands have same dimension
+    int64_t operand_dim = -1;

     // Track TensorViews with assigned roles so we can check that all inputs and
     // outputs have recognized roles
     std::set<TensorView*> tvs_with_roles;

-    if (entry != roles_map.end()) {
-      if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) {
-        tvs_with_roles.insert(entry->second.begin(), entry->second.end());
-      } else {
-        return "There is more than a single fusion input that can be MMA first input";
-      }
-    } else {
-      return "No candidate in fusion inputs for MMA first input";
-    }
-
-    entry = roles_map.find(MatmulRole::INPUT_B);
-    if (entry != roles_map.end()) {
-      if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) {
-        tvs_with_roles.insert(entry->second.begin(), entry->second.end());
-      } else {
-        return "There is more than a single fusion input that can be MMA second input";
-      }
-    } else {
-      return "No candidate in fusion inputs for MMA second input";
-    }
+    for (MatmulRole role : {MatmulRole::INPUT_A, MatmulRole::INPUT_B}) {
+      auto entry = roles_map.find(role);
+      if (entry != roles_map.end()) {
+        if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) {
+          tvs_with_roles.insert(entry->second.begin(), entry->second.end());
+          for (TensorView* tv : entry->second) {
+            const std::vector<IterDomain*>& leaf = tv->getLeafDomain();
+            int64_t ndims = (int64_t)std::count_if(
+                leaf.begin(), leaf.end(), [](IterDomain* id) {
+                  return !id->isReduction() && !id->isDeviceDim();
+                });
+            if (operand_dim == -1) {
+              operand_dim = ndims;
+            } else if (ndims != operand_dim) {
+              // We cannot always handle differently sized inputs, such as those
+              // we encounter when translating MatmulOp and LinearOp. This is
+              // because in those cases one of the operands will have new
+              // Broadcast dimensions where the other operand has Iteration
+              // batch dimensions, meaning these new dims are actually M or N
+              // dimensions. Multiple M and N dimension support is planned but
+              // for now we must reject these patterns before attempting to
+              // translate them.
+              return "All operands must have the same no-devices dimension.";
+            }
+          }
+        } else {
+          return "There is more than a single fusion input that can be MMA operand ";
+        }
+      } else {
+        return "No candidate in fusion inputs for MMA operand";
+      }
+    }

-    entry = roles_map.find(MatmulRole::OUTPUT_D);
+    auto entry = roles_map.find(MatmulRole::OUTPUT_D);
     if (entry != roles_map.end()) {
       tvs_with_roles.insert(entry->second.begin(), entry->second.end());
     } else {
@@ -458,8 +474,11 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
   // The plan:
   // 0. Check if the current CUDA device is supported
   // 1. Check if there is exactly one matmul pattern defined in the fusion.
-  // 2. Check if the input layout for the matmul pattern can be determined
-  // 3. Check if fusion represents expressions that are recognized by matmul
+  // 2. Check if fusion of MatmulOp and LinearOp is enabled, if applicable
+  // 3. Check if inputs to the matmul pattern match any of
+  //    supported inputs layout
+  // 4. Check if fusion represents expressions that are recognized by matmul
+  // 5. Check if the input layout for the matmul pattern can be determined
```
jacobhinkle (Collaborator, PR author) commented on lines +480 to +481:

> I reordered the last two checks since different-dimension inputs can interfere with determining the layout, which hits an …
```diff
   //    scheduler.

   // #0
@@ -475,13 +494,39 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
   }

   // #1
-  // Initializing the machinery to check if there's a Mul-Sum pair
-  // can be replaced by a Mma Op.
+  // Find matmul patterns
   std::vector<mma_utils::MatmulPattern> patterns =
       mma_utils::findMatmulPatterns(fusion);
   if (patterns.empty()) {
     return "No matmul patterns were found";
   }

+  // #2
+  {
+    for (const mma_utils::MatmulPattern& pattern : patterns) {
+      Expr* op = pattern.output->definition();
+      if (op->isA<MatmulOp>() || op->isA<LinearOp>()) {
+        if (!isOptionEnabled(EnableOption::FuseMatmul)) {
+          // Check for MatmulOp or LinearOp. If found, then only fuse if option
+          // is specified
+          return "MatmulOp and LinearOp fusion is disabled by default. "
+                 "Enable it using NVFUSER_ENABLE=fuse_matmul";
+        }
+        // Refuse patterns containing 1D inputs since these are mat-vec as
+        // opposed to mat-mat products.
+        if (pattern.A->nDims() < 2 || pattern.B->nDims() < 2) {
+          return "Cannot fuse matrix-vector products";
+        }
+        for (TensorView* operand : {pattern.A, pattern.B}) {
+          if (operand->dtype() != DataType::Half &&
+              operand->dtype() != DataType::BFloat16) {
+            return "Unsupported operand type. Operands must be fp16 or bf16";
+          }
+        }
+      }
+    }
+  }

   if (patterns.size() > 1) {
     return "Only a single matmul pattern can currently be fused";
   }
```
```diff
@@ -498,13 +543,6 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
   mma_utils::RolesMap roles_map = roles_map_opt.getData();

-  // #4
-  const auto input_layout_opt =
-      mma_utils::getProblemLayout(id_model, id_roles, roles_map);
-  if (!input_layout_opt.isValid()) {
-    return input_layout_opt.getErrorMsg();
-  }
-
-  // #5
+  // #4
   {
     auto support_status = isMatmulFusionDefinitionSupported(
         fusion, patterns.front(), roles_map, id_roles);
@@ -513,6 +551,13 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
     }
   }

+  // #5
+  const auto input_layout_opt =
+      mma_utils::getProblemLayout(id_model, id_roles, roles_map);
+  if (!input_layout_opt.isValid()) {
+    return input_layout_opt.getErrorMsg();
+  }
+
   return "";
 }
```