Skip to content

Do not assume mul-sum pattern inputs are TensorView#2293

Merged
jacobhinkle merged 1 commit intomainfrom
matmul_pattern_match_scalar_mul_bugfix
May 23, 2024
Merged

Do not assume mul-sum pattern inputs are TensorView#2293
jacobhinkle merged 1 commit intomainfrom
matmul_pattern_match_scalar_mul_bugfix

Conversation

@jacobhinkle
Copy link
Collaborator

This was a change I made to handle casts that wound up breaking some tests and benchmarks in #2272, leading to dynamic cast errors or segfaults. The solution is to test the type of the left and right hand sides before processing the pattern matching.

This was a change I made to handle casts that wound up breaking some
tests and benchmarks in #2272, leading to dynamic cast errors or
segfaults. The solution is to test the type of the left and right hand
sides before processing the pattern matching.
@jacobhinkle jacobhinkle requested review from naoyam and xwang233 May 23, 2024 17:30
@jacobhinkle
Copy link
Collaborator Author

!build

@jacobhinkle
Copy link
Collaborator Author

jacobhinkle commented May 23, 2024

Multidevice tests are failing. The fusion looks like this

TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K)
TensorView* b = makeContigTensor(2, DataType::Half); // (N,K)
TensorView* a_b = broadcast(a, {false, false, true, false}); // (Mo,Mi,b,K)
TensorView* b_b = broadcast(b, {true, true, false, false}); // (b,b,N,K)
TensorView* ab = mul(a_b, b_b); // (Mo,Mi,N,K)
TensorView* c = sum(ab, {-1}); // (Mo,Mi,N,r)
fusion->addInput(a);
fusion->addInput(b);
fusion->addOutput(c);
// Sharding M dimension
auto all_sharded_tvs = {a, a_b, b_b, ab, c};
for (auto tv : all_sharded_tvs) {
tv->axis(0)->parallelize(ParallelType::DIDx);
tv->setDeviceMesh(mesh);
}
b->setDeviceMesh(mesh);

This creates an outer M dimension that is parallelized as DIDx. We should ignore such dimensions in pattern matching.

@jacobhinkle
Copy link
Collaborator Author

Merging for now to mostly unblock CI then I will address the multidevice problem in another PR.

@jacobhinkle jacobhinkle merged commit 6233d45 into main May 23, 2024
@jacobhinkle jacobhinkle deleted the matmul_pattern_match_scalar_mul_bugfix branch May 23, 2024 23:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants