Generalize CombineMulSum as MatmulPatterns #2272

Merged
jacobhinkle merged 33 commits into main from matmul_patterns on May 23, 2024
Conversation

@jacobhinkle
Collaborator

This replaces the CombineMulSum class with MatmulPattern in the matmul scheduler. Additionally, we use these matmul patterns to determine the problem layout, IterDomain roles, and TensorView roles. The allocation domain is used to determine the problem layout. The matmul scheduler is updated to reject segments whose input allocation domains are non-trivial (until that is supported, e.g. by #2226).

Note that this does not add handling of MatmulOp and LinearOp in the matmul scheduler. That will be done next in #2236 or similar.

This also uses IdModel to find IterDomain and TensorView roles, and checks the
allocation domain to find the problem layout. We add a guard to reject
problems that have non-trivial input allocation domains.
This will go in another PR
return 1;
}

bool hasTrivialAllocationDomain(const TensorView* tv) {
jacobhinkle (author):

The intention of this utility is to generalize !tv->hasAllocation() to cases where an allocation domain is provided, but it actually corresponds to the no-reductions rfactor domain (ignoring broadcasts).

Collaborator:

Thanks! Would it be possible to add this as a comment in the header file?

const std::string& desc) {
// TODO: revise rules when add support for batch gemms
NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains.");
// NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains.");
jacobhinkle (author):

Fixes #2273. See the test AmpereMulSumToMatmul_MultipleBroadcasts.

@jacobhinkle jacobhinkle marked this pull request as ready for review May 20, 2024 19:39
@jacobhinkle
Collaborator Author

!build

@jacobhinkle
Collaborator Author

!build

//! and LinearOp, the output is the same dtype as the inputs; so output does not
//! necessarily correspond to the output of a translated MmaOp and it might not
//! be a fusion output.
struct MatmulPattern {
Collaborator:

Can we rename this API to be more intuitive? MatmulPattern makes me think of one of MmaOp/MatmulOp/LinearOp, but this represents a description instead.

jacobhinkle (author):

In the next PR it will include MatmulOp and LinearOp and perform that translation, so it really is meant to generalize those.

MatmulRole role) -> InnerDomResult {
const auto role_it = roles_map.find(role);
if (role_it == roles_map.end()) {
return {MatmulDomain::M, "Could not find role in roles_map"};
Collaborator:

The error message here is confusing. Why is this MatmulDomain::M? Inner dimension can be M/N. Similarly for the other errors in this lambda.

jacobhinkle (author):

You're right; this InnerDomResult business is a hack to get around the design of DataWrapperOpt. I want to just return the error message in this case, but using DataWrapperOpt doesn't work properly in clang (haven't tried gcc) when the wrapped type is trivially copyable, since it balks at using std::move on such types.

Collaborator:

What about using `using InnerDomResult = std::pair<std::optional<MatmulDomain>, std::string>;` and returning std::nullopt instead of MatmulDomain::A?

jacobhinkle (author):

I just pushed a change that uses a bare `std::variant<std::string, UnitDim>`.

jacobhinkle and others added 4 commits May 21, 2024 14:01
We can still refuse to schedule, but these are valid patterns
Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com>
@zasdfgbnm
Collaborator

Could you rebase this PR? I see obsolete code removed by #2268 in this PR.

// (bit 2)
using ValGroupPresence = std::bitset<3>;

std::unordered_map<ValGroup, ValGroupPresence> present_flags;
Collaborator:

nit: is membership_flags a better name since your comment uses that terminology?

jacobhinkle (author):

Actually, I just renamed ValGroupPresence to DimPresence and updated the comment to no longer mention "membership", since I think that's a more opaque term than "presence". What we really care about is whether a dimension is present in each tensor, so that term is clearer.

if (has_m && has_n) {
storage.push_back(entry.first);
}
// NOTE: sort output roles in descending order by uses() size, and
Collaborator:

Why is sorting the tvs important -- is there a place where we rely on this ordering to be same everywhere?

Collaborator:

IIRC, the order is important because we want deterministic behavior. Otherwise there will be a slight change in the variable names in the generated code from run to run.

jacobhinkle (author):

Exactly. We sometimes iterate over role TVs. If we did not sort, then the ->name()s of introduced Vals could be ordered arbitrarily, for example. Maintaining deterministic compiled code is helpful for the codediff tool, and also for keeping our sanity when debugging problem fusions.

jacobhinkle added a commit that referenced this pull request May 22, 2024
Thanks to the suggestion by @zasdfgbnm while reviewing #2272, I found
some additional cases where we took a short-cut to updating bools using
the bitwise assignment ops. This is not ideal since its behavior is
undefined (there's no guarantee that the underlying representation of
`true` is `b1` and not `b10` or any other non-zero value). More
importantly, writing `a |= b` as `a = a || b` allows us to short-circuit
if `a == true`. Using bitwise `a |= b`, `b` will always be evaluated.
mma_utils::MatmulPattern& pattern = patterns.front();

// IdModel is used to analyze problem shape & layout
IdModel id_model(fusion);
Collaborator:

Is getMatmulHeuristics a hot path? We may want to cache the IdModel (in a separate PR), like:

auto domain_map_entry =
HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>(
data_cache,
[fusion]() { return std::make_unique<DomainMap>(fusion); });
const auto& domain_map = dynamic_cast<DomainMap&>(domain_map_entry.get());

jacobhinkle (author):

That's a good idea. Currently I build an IdModel in getMatmulHeuristics then again in scheduleMatmul, and actually I also need to rebuild it after translating MatmulPatterns to MmaOps. I think this use case is very similar to the DomainMap in the pointwise scheduler that you linked to; we just use IdModel instead of ComputeAtMap.

@jacobhinkle jacobhinkle merged commit 9153612 into main May 23, 2024
@jacobhinkle jacobhinkle deleted the matmul_patterns branch May 23, 2024 00:14
jacobhinkle added a commit that referenced this pull request May 23, 2024
This was a change I made to handle casts that wound up breaking some
tests and benchmarks in #2272, leading to dynamic cast errors or
segfaults. The solution is to test the type of the left and right hand
sides before processing the pattern matching.
jacobhinkle added a commit that referenced this pull request May 24, 2024
This fixes a bug introduced by #2272 in `test_multidevice` where we
reject a matmul segment shaped like `[iDIDxMo, iMi, bN, iK]` for having
too many M dimensions. Locally this still has a single M dimension so it
is valid. This PR ignores device dims for the purposes of computing
tensor roles and problem shape.

Further issues we should look into:
1. As mentioned in #2272 we should proceed to handle multiple M, N, K,
   and Batch dimensions, although in this case the restriction was
   useful for surfacing this bug.
2. Even if the matmul scheduler is completely broken or disabled, the
   _reduction_ scheduler should have been able to schedule this fusion.
   However, it identified the reduction tensor as `isResharding` and
   removed it from the `reduction_tvs` list, causing a failure in
   `scheduleReduction`. We should clean up that check to be able to
   schedule this type of fusion as a reduction.
3. The rfactor domain is often used for scheduling utilities to inspect
   the logical size of tensors. However, because multidevice scheduling
   modifies the leaf domain before segmentation, we should probably
   audit our schedulers to ensure they use the leaf domain and ignore
   device dims where necessary.
4. I should also not forget to rerun `!build` before merging PRs :-).
jacobhinkle added a commit that referenced this pull request May 27, 2024
This fixes a bug introduced by #2272 in `test_multidevice` where we
reject a matmul segment shaped like `[iDIDxMo, iMi, bN, iK]` for having
too many M dimensions. Locally this still has a single M dimension so it
is valid. This PR ignores device dims for the purposes of computing
tensor roles and problem shape.

Further issues we should look into:
1. As mentioned in #2272 we should proceed to handle multiple M, N, K,
and Batch dimensions, although in this case the restriction was useful
for surfacing this bug.
2. Even if the matmul scheduler is completely broken or disabled, the
_reduction_ scheduler should have been able to schedule this fusion.
However, it identified the reduction tensor as `isResharding` and
removed it from the `reduction_tvs` list, causing a failure in
`scheduleReduction`. We should clean up that check to be able to
schedule this type of fusion as a reduction.
3. Inside the matmul scheduler we call `canonicalizeMmaTvOrdering` which
I believe still uses rfactor domain to determine domain ordering.
Instead this should be updated to use dim roles that are already
computed from the `MatmulPattern`.
4. The rfactor domain is often used for scheduling utilities to inspect
the logical size of tensors. However, because multidevice scheduling
modifies the leaf domain before segmentation, we should probably audit
our schedulers to ensure they use the leaf domain and ignore device dims
where necessary.
5. I should also not forget to rerun `!build` before merging PRs
:sweat_smile: