
Use matmul in distributed tests #2386

Merged
Priya2698 merged 5 commits into main from pm/test_dist_matmul on Jun 18, 2024

Conversation


@Priya2698 Priya2698 commented Jun 11, 2024

Issue #2372.
Modifying the tests to use matmul in place of mul-sum.
Note: matmul API requires the logical shapes [M,K] x [K,N] and the output has the same dtype as the input.
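
As a plain-PyTorch illustration of that contract (logical shapes `[M,K] x [K,N]`, output dtype matching the input; reference semantics only, not nvFuser code):

```python
import torch

M, K, N = 4, 8, 2
a = torch.randn(M, K, dtype=torch.float64)
b = torch.randn(K, N, dtype=torch.float64)

out = torch.matmul(a, b)

# Logical shapes: [M, K] x [K, N] -> [M, N]
assert out.shape == (M, N)
# The output keeps the input dtype rather than being promoted or demoted.
assert out.dtype == torch.float64
```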

@Priya2698 Priya2698 requested review from cowanmeg and wujingyue June 11, 2024 18:22

@wujingyue wujingyue left a comment


Is it worth adding a check to make sure ATen kicked in? E.g.

EXPECT_FALSE(executors.front().hasCompiledKernel());


@cowanmeg cowanmeg left a comment


Thanks! I'm glad the ATen pathway worked out of the box!
I assume the extra size 1 axes in the aten tensors are being treated as batch dimensions which is why it works out ok right?

Comment on lines 244 to 245
std::vector<int64_t> orig_size = {K, M, N};
std::vector<int64_t> new_size = {K, Mo, Mi, N};


This is actually an error from me... but `K` should be `Ko`.
I'm surprised it compiled and was correct before.

@wujingyue

Thanks! I'm glad the ATen pathway worked out of the box! I assume the extra size 1 axes in the aten tensors are being treated as batch dimensions which is why it works out ok right?

Yes, that is treated as a BatchMatmul, case 5 in https://pytorch.org/docs/stable/generated/torch.matmul.html
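
A quick PyTorch check of that behavior (the extra leading size-1 axes are treated as broadcastable batch dimensions):

```python
import torch

M, K, N = 4, 8, 2

# Batched matmul: leading axes of N-D inputs act as batch dimensions.
a = torch.randn(1, M, K)
b = torch.randn(1, K, N)
assert torch.matmul(a, b).shape == (1, M, N)

# The batch axis also broadcasts against a plain 2-D operand.
b2d = torch.randn(K, N)
assert torch.matmul(a, b2d).shape == (1, M, N)
```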

@Priya2698

Is it worth adding a check to make sure ATen kicked in? E.g.

EXPECT_FALSE(executors.front().hasCompiledKernel());

How do I access the executors from MultiDeviceCommunicator?
For these tests, since there will be multiple segments, we need to test for the second segment/executor (first is no-op for transpose).

@Priya2698

!build

@cowanmeg

cowanmeg commented Jun 12, 2024

Sorry for adding this late: can we add back the original TEST_F(DistributedMatmulTest, LayoutTN_NoComms) with mul-sum, in addition to the updated one with the matmul node? We should keep one of these tests since they call the MatmulScheduler with sharded TensorViews.

How do I access the executors from MultiDeviceCommunicator?

MultiDeviceExecutor doesn't have accessor functions for executors, so we would have to add that first. @wujingyue

Some background:
The MultiDeviceExecutor takes in a Fusion and segments it into compute and communication segments. The compute segments then go through FusionExecutorCache, where they go through the standard segmentation.
So we need to do something like get the first "coarse" segment (which will correspond to the local compute at the beginning) and then index into the proper segment/executor.

See here for the internals:
https://github.com/NVIDIA/Fuser/blob/main/csrc/multidevice/executor.h#L130
The MultiDeviceExecutor can use either FusionExecutorCache or directly use FusionExecutor for compute segments depending on flags (temporary for development purposes). For the DistributedMatmulTest I set the params to use FusionExecutorCache.

@wujingyue

How do I access the executors from MultiDeviceExecutor?

Good question. We could plumb it through, but I don't think the benefit is worth it at this moment.

@wujingyue

Can we add back the original TEST_F(DistributedMatmulTest, LayoutTN_NoComms) with mul-sum in addition to the updated one with the matmul node. We should keep one of these tests since they call the MatmulScheduler with sharded TensorViews.

Good question!

AFAIK, nvFuser offers three ways (arguably too many, 🤷) to generate a matmul.

  1. fusedMultiplySum, which takes broadcasted operands and generates an MmaOp (for MatmulScheduler)
  2. matmul, which takes non-broadcasted operands and sometimes generates a MatmulOp (for ATen).
  3. The same matmul API, which sometimes generates a decomposed broadcast+mul+sum pattern (e.g. when the inputs are vectors).

The current tests in this file exercise only (2).

I agree it's useful to exercise (1). MatmulScheduler is already DID-capable and eventually we do want nvFuser to generate matmul using the scheduler, so it's good to have tests to maintain that feature. However, I'm not aware that nvFuser can reliably turn a decomposed broadcast+mul+sum into an MmaOp/MatmulOp. Correct me if I'm wrong. So we may need to add new tests that use fusedMultiplySum instead.

I was told by @Priya2698 that (3) is a corner case that will go away. So I'll leave the decision to her whether to cover that as well.

@Priya2698

  1. fusedMultiplySum, which takes broadcasted operands and generates an MmaOp (for MatmulScheduler)
  2. matmul, which takes non-broadcasted operands and sometimes generates a MatmulOp (for ATen).
  3. The same matmul API, which sometimes generates a decomposed broadcast+mul+sum pattern (e.g. when the inputs are vectors).

The current tests in this file exercise only (2).

I agree it's useful to exercise (1). MatmulScheduler is already DID-capable and eventually we do want nvFuser to generate matmul using the scheduler, so it's good to have tests to maintain that feature. However, I'm not aware that nvFuser can reliably turn a decomposed broadcast+mul+sum into an MmaOp/MatmulOp. Correct me if I'm wrong. So we may need to add new tests that use fusedMultiplySum instead.

I was told by @Priya2698 that (3) is a corner case that will go away. So I'll leave the decision to her whether to cover that as well.

Yes, (3) will be removed; the matmul API will then always generate MatmulOp.

For now, from the Python frontend API, the only ways are to call fd.ops.matmul or, similarly, fd.ops.mul + fd.ops.sum. However, Thunder will plumb down matmuls using the former.
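
For concreteness, the two frontend patterns compute the same thing. In plain PyTorch terms (not the fd.ops API itself), the mul + sum form is a broadcast multiply reduced over K:

```python
import torch

M, K, N = 4, 8, 2
a = torch.randn(M, K)
b = torch.randn(K, N)

# matmul path: [M, K] x [K, N] -> [M, N]
ref = torch.matmul(a, b)

# decomposed path: broadcast to [M, K, N], multiply, then reduce over K
dec = (a.unsqueeze(-1) * b.unsqueeze(0)).sum(dim=1)

assert torch.allclose(ref, dec, atol=1e-5)
```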

To exercise the matmul scheduler:

  1. Create a mul-sum pattern or fusedMultiplySum which both convert to MmaOp: Generalize CombineMulSum as MatmulPatterns #2272

  2. @jacobhinkle added translation from MatmulOp to MmaOp for the matmul scheduler: Translate MatmulOp and LinearOp #2236. Use NVFUSER_ENABLE=fuse_matmul, which enables the matmul scheduler to take segments with matmul if the expression evaluator cannot (e.g., with epilogues). We can also disable the expression evaluator using NVFUSER_DISABLE=matmul_expr_eval to use only one or the other.
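
As a sketch of how those flags combine on the command line (the flag names are from the discussion above; the test binary name and gtest filter are placeholders, not actual build targets):

```shell
# Force the matmul scheduler and disable the ATen expression-evaluator
# path for matmul segments.
export NVFUSER_ENABLE=fuse_matmul
export NVFUSER_DISABLE=matmul_expr_eval

# Hypothetical invocation; substitute the real multidevice test binary.
# ./bin/test_multidevice --gtest_filter='DistributedMatmulTest.*'
echo "$NVFUSER_ENABLE $NVFUSER_DISABLE"
```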

@wujingyue

Thanks for clarifying. I missed

void handle(ReductionOp* rop) override {
-- nvFuser does convert a decomposed mul+sum to MmaOp.

So we have three ways to exercise the matmul scheduler:

  1. mul-sum
  2. fusedMultiplySum
  3. matmul with fuse_matmul enabled.

Wdyt? (3) seems to me the best option: we'll get fd.ops.matmul from Thunder and stick with that for at least O(months). Btw, when are you going to enable fd.ops.matmul for Thunder by default?

@Priya2698

Wdyt? (3) seems to me the best option: we'll get fd.ops.matmul from Thunder and stick with that for at least O(months).

Yes, we can use the NVFUSER_ENABLE/DISABLE options to control which scheduler to use. It would be good to have some way to verify the scheduler heuristic used, though. There is no need to plumb down a separate option if I can use MultiDeviceCommunicator->fec_, right?

Btw, when are you going to enable fd.ops.matmul for Thunder by default?

We are still seeing some issues such as #2354, so it is not enabled by default yet. I plan on running all the thunder benchmarks and tests with this enabled to preemptively identify existing issues first.

@cowanmeg

MultiDeviceCommunicator->fec_

We can write an accessor function to return the FusionExecutorCache(s) used by the MultiDeviceRuntime. It will be a vector because each segment has its own FusionExecutorCache. In this case, since there is only one compute segment, it will have only one entry.

Related to the translation of MatmulOp to MmaOp for the matmul scheduler: we eventually need to update the insertResharding pass to handle Matmul, Linear, and SDPA ops to insert set for non-reduction collectives. It's easy to work around by manually inserting set, which the tests were already doing.

@Priya2698

MultiDeviceCommunicator->fec_

We can write an accessor function to return the FusionExecutorCache(s) used by the MultiDeviceRuntime. It will be a vector because each segment has its own FusionExecutorCache. In this case, since there is only one compute segment, it will have only one entry.

Related to the translation of MatmulOp to MmaOp for the matmul scheduler: we eventually need to update the insertResharding pass to handle Matmul, Linear, and SDPA ops to insert set for non-reduction collectives. It's easy to work around by manually inserting set, which the tests were already doing.

So currently, what you suggest is retaining one of the original test cases?
Or do we want to switch to only using MatmulOp and use the nvfuser options to pick one of the schedulers?

@cowanmeg

Let's keep one of the original test cases.

@Priya2698

!build

@Priya2698 Priya2698 requested a review from cowanmeg June 15, 2024 00:13
cowanmeg added a commit that referenced this pull request Jun 15, 2024
Add `MultiDeviceExecutor::getFusionExecutorCaches` which returns a
vector of pointers to the multidevice executor's fusion executor caches.

Note: also renames the field `workspace` to `workspace_` to match naming
conventions.

cc @Priya2698 for #2386

@cowanmeg cowanmeg left a comment


LGTM! Thank you!

@Priya2698 Priya2698 force-pushed the pm/test_dist_matmul branch from bb469f1 to 49ebf3b on June 15, 2024
@Priya2698

!build

@Priya2698 Priya2698 deleted the pm/test_dist_matmul branch June 18, 2024 00:51
protonu pushed a commit that referenced this pull request Jun 24, 2024
Add `MultiDeviceExecutor::getFusionExecutorCaches` which returns a
vector of pointers to the multidevice executor's fusion executor caches.

Note: also renames the field `workspace` to `workspace_` to match naming
conventions.

cc @Priya2698 for #2386
protonu pushed a commit that referenced this pull request Jun 24, 2024
Issue #2372.
Modifying the tests to use `matmul` in place of `mul-sum`.
Note: `matmul` API requires the logical shapes `[M,K] x [K,N]` and the
output has the same dtype as the input.