
Matmul default scheduling [1]#1743

Merged
Priya2698 merged 17 commits into main from pm/mma_default
Feb 14, 2024

Conversation

@Priya2698 (Collaborator) commented Feb 8, 2024

Initial PR for Issue #1669.

  1. Adds an EnableOption::MatmulExprEval to turn on expression evaluation for matmul while the API is in progress.
  2. Currently, we only evaluate the MmaOp. The next PRs will amend this to look ahead and evaluate Mma + Cast, which is what we should see in fusion definitions (see discussion here). In the absence of this, we may have casts such as bfloat->float->bfloat.
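The cast-roundtrip issue in point 2 can be sketched with a toy Python model (illustrative only; the function name and op strings are made up for this sketch, not nvFuser's IR):

```python
# A matmul in a fusion definition typically looks like:
#   half inputs -> MmaOp (float accumulator) -> CastOp -> half output
# If the expression evaluator replaces only the MmaOp with at::matmul
# (which returns half for half inputs), the surrounding casts survive.

def ops_after_evaluation(evaluate_mma_and_cast: bool) -> list:
    if evaluate_mma_and_cast:
        # Look ahead: fold MmaOp + CastOp into one at::matmul call.
        return ["at::matmul(half,half)->half"]
    # Evaluate only the MmaOp: its half result must be cast back up to
    # float to feed the remaining CastOp, a half->float->half roundtrip.
    return ["at::matmul(half,half)->half",
            "CastOp(half)->float",
            "CastOp(float)->half"]

print(ops_after_evaluation(False))  # roundtrip casts remain
print(ops_after_evaluation(True))   # single call, no roundtrip
```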

@Priya2698 Priya2698 changed the title Matmul default scheduling [1] [WIP] Matmul default scheduling [1] Feb 8, 2024
@Priya2698 Priya2698 changed the title [WIP] Matmul default scheduling [1] Matmul default scheduling [1] Feb 9, 2024
mma_ops.size());

// Skip scheduling if Matmul will be expression evaluated.
if (isOptionEnabled(EnableOption::MatmulExprEval)) {
Collaborator

Does this mean any matmul pattern including mma + epilogue will be handled by the expression evaluator? Shouldn't it only take care of the mma part?

@wujingyue (Collaborator) Feb 13, 2024

There are two problems with taking care of just the mma:

  1. at::matmul doesn't do HH->S. We could plug in another backend that supports HH->S for EE though.
  2. MMA is never alone in these GPT models (e.g. LLaMA). It's always part of a linear layer or an SDPA. nvFuser doesn't do SDPA well and we will have to offload it to another executor for quite some time, so scratch that. A linear layer however comes with this MMA->BiasAdd pattern. In order for its performance to be on par with framework-not-giving-nvFuser-the-matmul, we have to execute MMA+epilogue in one call.

Wdyt, @naoyam? Extending EE to do MMA+epilogue is the most obvious way to me to solve the above problems. But I could definitely be wrong.
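The HH->S point (half inputs, single-precision accumulator) above is about accumulation precision. A stdlib-only Python sketch of why the accumulator dtype matters, using struct to emulate fp16 rounding (a toy model, not how ATen or nvFuser compute matmul):

```python
import struct

def to_half(x: float) -> float:
    # Round a Python float to the nearest IEEE fp16 value and back.
    return struct.unpack('e', struct.pack('e', x))[0]

def dot(a, b, accumulate_in_half: bool) -> float:
    acc = 0.0
    for x, y in zip(a, b):
        acc = acc + to_half(x) * to_half(y)
        if accumulate_in_half:
            acc = to_half(acc)  # HH->H: every partial sum rounded to fp16
    return acc

a = [0.001] * 4096
b = [1.0] * 4096
# Wide accumulation (Python's double here, standing in for fp32) vs fp16:
hh_s = dot(a, b, accumulate_in_half=False)
hh_h = dot(a, b, accumulate_in_half=True)
# The fp16 accumulator stalls once each addend is below half an ulp of
# the running sum, so hh_h undershoots hh_s.
print(hh_s, hh_h)
```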

Collaborator

Extending EE to do MMA+epilogue

How would it do that? Does ATen support matmul with some epilogue op?

Collaborator

Oh yeah, it's called at::addmm. I realized I was wrong about Relu: torch.nn.Linear doesn't do Relu, so the pattern would be matmul+biasadd, which is what at::addmm does.
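For reference, at::addmm computes beta * bias + alpha * (mat1 @ mat2) in a single call, which is exactly the matmul+biasadd pattern of a linear layer. A pure-Python sketch of those semantics (toy model with nested lists and a 1-D bias; the real at::addmm operates on tensors and broadcasts the bias):

```python
def addmm(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    # out[i][j] = beta * bias[j] + alpha * sum_p mat1[i][p] * mat2[p][j]
    m, k = len(mat1), len(mat1[0])
    n = len(mat2[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = sum(mat1[i][p] * mat2[p][j] for p in range(k))
            out[i][j] = beta * bias[j] + alpha * acc
    return out

mat1 = [[1.0, 2.0], [3.0, 4.0]]
mat2 = [[1.0, 0.0], [0.0, 1.0]]  # identity, so mat1 @ mat2 == mat1
bias = [10.0, 20.0]
print(addmm(bias, mat1, mat2))  # [[11.0, 22.0], [13.0, 24.0]]
```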

Collaborator

So, how would you handle matmul+epilogue patterns that are accepted by the nvFuser matmul scheduler but have no corresponding ATen version? Would that end up doing the mma and the epilogue op separately?

Collaborator

That makes sense. Or maybe we want a separate scheduler for EE? We try the native matmul scheduler first, then the EE matmul scheduler, and then the other schedulers. No particular preference, just my two cents.

Collaborator

Yes, that's certainly a great option to consider. It makes things more composable, at the risk of making it harder to share logic with MatmulScheduler. I hope the preference will become more obvious once we know what the heuristics look like!

Collaborator Author

Adding on to @wujingyue's comments, the plan for the next PRs is:

  1. Support Mma + Cast -> avoid roundtrip casting (half -> float -> half) by checking mma_out->uses(): if the use is either a castOp or a pointwise op with inputs of the same type (half), then skip casting the output of at::matmul back. This will not execute matmul + bias in a single call.
  2. Handle common epilogue fusions: we will need to pattern match and evaluate within the MmaOp accordingly. test_matmul_scheduler.cpp currently has a few cases that I will start with: mma + bias, mma + bias + relu/gelu, mma + relu.
  3. Epilogue fusions that are not supported: they can still be computed through EE but should ideally not be plumbed down to the matmul scheduler.
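Step 1's look-ahead can be sketched as follows (hypothetical Python pseudomodel; Op and can_skip_output_cast are illustrative names, not nvFuser's API):

```python
from dataclasses import dataclass, field

@dataclass
class Op:
    kind: str             # "MmaOp", "CastOp", "PointwiseOp", ...
    dtype: str            # output dtype
    uses: list = field(default_factory=list)

def can_skip_output_cast(mma_out: Op, input_dtype: str) -> bool:
    # Only fold when the MmaOp output has exactly one consumer.
    if len(mma_out.uses) != 1:
        return False
    use = mma_out.uses[0]
    # A CastOp narrowing back to the input dtype, or a pointwise op whose
    # operands are all the input dtype: no float intermediate is needed,
    # so at::matmul's half output can be used directly.
    if use.kind == "CastOp" and use.dtype == input_dtype:
        return True
    return use.kind == "PointwiseOp" and use.dtype == input_dtype

cast_back = Op("CastOp", "half")
mma_out = Op("MmaOp", "float", uses=[cast_back])
print(can_skip_output_cast(mma_out, "half"))  # True
```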

Collaborator

Do 2 only when it's needed. You should double-check this, but I think Llama, for example, uses linear with bias off, and none of the linear layers in our benchmarks do relu or gelu.

Collaborator Author

Llama2 has bias=False in the linear layers but some GPT configs have bias=True.

@wujingyue (Collaborator) left a comment

Nice work! The PR currently shows as a draft; I'll hold off on my review until it's ready.

@Priya2698 Priya2698 marked this pull request as ready for review February 13, 2024 05:43
@wujingyue (Collaborator) left a comment

LGTM with some comments to be resolved.

@jacobhinkle (Collaborator) left a comment

LGTM. Just some comments on the tests.

@drzejan2 (Contributor)

Thanks for preparing this change. It looks good to me after catching up on the latest discussions around the matmul scheduler.

@protonu (Collaborator) commented Feb 14, 2024

Nit: would it make sense to make this C++ test file part of the test_matmul target?
https://github.com/NVIDIA/Fuser/blob/9fac12bdd98a63fdc88dde265b8add6a0e3f41cf/CMakeLists.txt#L480C1-L489C41

@Priya2698 (Author)

Nit: would it make sense to make this C++ test file part of the test_matmul target? https://github.com/NVIDIA/Fuser/blob/9fac12bdd98a63fdc88dde265b8add6a0e3f41cf/CMakeLists.txt#L480C1-L489C41

Thanks for the suggestion! Moved the test file.

@Priya2698 Priya2698 merged commit a0cb47a into main Feb 14, 2024
@Priya2698 Priya2698 deleted the pm/mma_default branch February 14, 2024 22:19
@jjsjann123 jjsjann123 mentioned this pull request Feb 15, 2024
@jjsjann123 (Collaborator)

FYI, this PR seems to break CI on V100.

@Priya2698 (Author)

Thanks for pointing this out. Let's revert this until I identify the patch.

Priya2698 added a commit that referenced this pull request Feb 15, 2024
@jjsjann123 (Collaborator)

FYI, these are the failing tests from CI:

00:00:43 [  FAILED  ] 3 tests, listed below:
00:00:43 [  FAILED  ] MatmulATenEvaluationTest.SingleMmaOp
00:00:43 [  FAILED  ] MatmulATenEvaluationTest.MmaOpAndCast
00:00:43 [  FAILED  ] MatmulATenEvaluationTest.MatmulWithBias

@Priya2698 (Author)

Yes, these are the tests I added to check functionality.
Looking into it.

@wujingyue (Collaborator)

Thanks for pointing this out. Let's revert this until I identify the patch.

Thank you! I can't say enough great things about revert-and-debug-later!

jjsjann123 pushed a commit that referenced this pull request Feb 15, 2024
Reverts #1743.

This is breaking on V100.
@Priya2698 Priya2698 restored the pm/mma_default branch February 15, 2024 01:18
@Priya2698 (Author) commented Feb 15, 2024

Are we looking to support V100 through the default path? CC: @kevinstephano
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD can be used in the tests to fix the error. The error surfaced because the matmul scheduler does not support V100 (see inline std::optional<MmaMacro> getMmaOp).

If we wish to support V100, we can add appropriate checks in the heuristic verification to allow other architectures when we are using the expression evaluator.
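The check being proposed could look roughly like this (illustrative Python sketch; the macro names and the supported-architecture list are placeholders, not nvFuser's actual getMmaOp logic):

```python
def get_mma_macro(compute_capability: tuple):
    # Placeholder: pretend only Ampere (sm_8x) and Turing (sm_75) have a
    # supported MMA macro; V100 is sm_70 and gets None.
    major, minor = compute_capability
    if major == 8:
        return "Ampere_16_8_16"
    if (major, minor) == (7, 5):
        return "Turing_16_8_16"
    return None

def can_schedule_matmul(compute_capability, expr_eval_enabled: bool) -> bool:
    if expr_eval_enabled:
        # Expression evaluation (ATen fallback) works on any architecture,
        # so the heuristic check is skipped entirely.
        return True
    return get_mma_macro(compute_capability) is not None

print(can_schedule_matmul((7, 0), expr_eval_enabled=False))  # False: V100
print(can_schedule_matmul((7, 0), expr_eval_enabled=True))   # True
```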

@wujingyue (Collaborator)

Are we looking to support V100 through the default path?

I'd do it. This is actually something that can be supported with much less effort in fallback mode than in codegen. Sounds like a low-hanging fruit to me.

@wujingyue (Collaborator)

Are we looking to support V100 through the default path? CC: @kevinstephano NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD can be used in the tests to fix the error. The error surfaced because the matmul scheduler does not support V100 (see inline std::optional<MmaMacro> getMmaOp). If we wish to support V100, we can have appropriate checks in the heuristic verification to allow other architectures when we are using the expression evaluator.

FYI, I suspect it's not just V100. https://github.com/NVIDIA/Fuser/actions/runs/7910657123/job/21593586846 seems to be the same error but for H100.

@protonu (Collaborator) commented Feb 15, 2024

One option could be to use something like this (https://github.com/NVIDIA/Fuser/blob/88727dc828684f5a62d7f1837a610b7589f629d1/test/test_combine_mul_sum.cpp#L40C1-L57C3) to restrict the machines these tests run on, in case you don't plan on adding support for V100/H100.

@Priya2698 (Author) commented Feb 15, 2024

One option could be to use something like this (https://github.com/NVIDIA/Fuser/blob/88727dc828684f5a62d7f1837a610b7589f629d1/test/test_combine_mul_sum.cpp#L40C1-L57C3) to restrict the machines these tests run on, in case you don't plan on adding support for V100/H100.

Thanks for the suggestion; I am moving forward with supporting any architecture since it's simple enough.
For this PR, the fix will be to skip the heuristic computation when MatmulExprEval is set. In the next PRs, I'll refactor some code to always use the default scheduling for any unsupported architecture.

7 participants