
IdModel-based indexing: broadcast indexing #2353

Merged
naoyam merged 30 commits into main from idmodel_indexing_broadcast on Jun 8, 2024

Conversation

Collaborator

@naoyam naoyam commented Jun 6, 2024

Stacked on top of #2344. Adds support for broadcast indexing with loop promotion. The main change is simply the use of promoted domains in loop and allocation domains.

@naoyam naoyam added the idmodel label Jun 6, 2024
@naoyam naoyam marked this pull request as ready for review June 6, 2024 04:08
@naoyam naoyam requested a review from jacobhinkle June 6, 2024 04:10
Collaborator (Author)

naoyam commented Jun 6, 2024

CC: @zasdfgbnm

Base automatically changed from idmodel_indexing_part1 to main June 6, 2024 20:02
@jacobhinkle (Collaborator) left a comment:

LGTM

}

bool TensorIndexer::shouldUseZeroIndex(const ValGroup& loop_group) const {
// All loops in this set are non-parallel, non-concretized broadcast
Collaborator:

So if all the axes are broadcast then we should use 0 index, or if the promoted loop has extent 1 (and is not partial). How does "non-parallel" affect this check?

Collaborator (Author):

Ah, that comment should have been removed. I had some other code above this line. Thanks for catching it.

// Assume consumer-based indexing. Needs to revisit for ops like
// scatter
return ir_utils::getTvOutput(expr)->getLeafDomain();
auto loop_domains = ir_utils::getTvOutput(expr)->getLeafDomain();
Collaborator:

Another place where we assume all outputs have same domain; in this case leaf domain.

Collaborator (Author):

Yes. I think lifting that restriction would be quite challenging. We would need to change many things, including expression sorting, loop generation, etc.

const Expr* expr,
const std::vector<IterDomain*>& index_domains) const;

// Check if the loop index of a a loop group should be always
Collaborator:

Suggested change
// Check if the loop index of a a loop group should be always
// Check if the loop index of a loop group should be always

#include <ops/all_ops.h>
#include <scheduler/utils.h>

#include <functional>
Collaborator:

std::forward is defined in <utility> so <functional> might not be needed.


template <typename... Args>
Val* addExpr(Args&&... args) {
return SimplifyingIrBuilder::addExpr(std::forward<Args>(args)...);
Collaborator:

These do help readability. If you're planning to do a lot of structural checking of indices we could consider using user-defined literals like what @zasdfgbnm used in test_expr_simplifier.cpp.

Comment on lines +380 to +384
if (std::all_of(loop_group->begin(), loop_group->end(), [](Val* val) {
return val->as<IterDomain>()->isBroadcast();
})) {
return true;
}
Collaborator:

Is this sufficient?

  auto leaf_id =
      getLoopPromotion(loop_group->front()->as<IterDomain>(), id_model_);
  leaf_id->isBroadcast();

// Trivial loop
auto leaf_id =
getLoopPromotion(loop_group->front()->as<IterDomain>(), id_model_);
if (!leaf_id->maybePartial() && simplifyExpr(leaf_id->extent())->isOneInt()) {
Collaborator:

Do we support partial IterDomains?

Collaborator (Author):

No. I think I just put this mostly by following what we have in kir::ForLoop::isTrivial before the shift removal. Now that it's removed, this seems more confusing than necessary. I'll remove it.

if (!is_loop) {
continue;
}
allocation_domain = getLoopPromotion(allocation_domain, id_model);
Collaborator:

isPartitionedLoop uses the parallel type of id instead of the getParallelType of the loop group of id, can this be a problem? Similarly, in line 117 above, we are also using loop_id->getParallelType() instead of the parallel type of the loop group. IIRC, if the loop promotion id is replayed, we will not set its parallelization type.

Collaborator:

If I have

smem_tv[b, I1] ca_pos(1)
tv[TIDx, I1] = smem_tv[b, I1] + concrete_tv[I0, I1]

then should smem_tv be allocated as [I0, I1] or [I1]?

How about

smem_tv[bTIDx, I1] ca_pos(1)
tv[TIDx, I1] = smem_tv[bTIDx, I1] + concrete_tv[I0, I1]

?

Collaborator (Author):

Thank you. That's a real problem. We should either always look at a loop group or just propagate parallel types to all IDs, including promotion domains. I think the latter is a simpler solution. Will work on it.

Collaborator (Author):

If I have

smem_tv[b, I1] ca_pos(1)
tv[TIDx, I1] = smem_tv[b, I1] + concrete_tv[I0, I1]

then should smem_tv be allocated as [I0, I1] or [I1]?

How about

smem_tv[bTIDx, I1] ca_pos(1)
tv[TIDx, I1] = smem_tv[bTIDx, I1] + concrete_tv[I0, I1]

?

In both cases, the allocation size should be [I0, I1], right?

Collaborator (Author):

Maybe not. In the first case, [I1] should be enough as long as a proper predicate is added.

Collaborator (Author):

I think what we currently do is we don't inline broadcast domains like I0 of smem_tv. Let me check.

Collaborator (Author):

This is what I was referring to:

https://github.com/NVIDIA/Fuser/blob/main/csrc/tensor_view.cpp#L167-L170

But that only applies to innermost broadcast domains, so in this case it is inlined.

Collaborator (Author):

The current lowering only allocates [I1]. More specifically, when a domain is a broadcast domain, it's not allocated even when it's promoted. I'd keep this behavior as is.

Collaborator (Author):

Looking at the code again, this should already be what happens (with #2371). Pure broadcast domains should be filtered out by !mayRequireAllocation. I'll add a test.

Collaborator (Author):

Added.

Comment on lines +140 to +142
// Loop promotion may affect allocations. Promotions of intermediate
// domains may not be defined correctly. Only consider loop domains
// for now.
Collaborator:

[schedule diagram attached]
What if we have the above schedule? How should we handle it? Should we just raise an error, allocate the tv with broadcasting as I0*I1, or allocate it as 5?

Collaborator (Author):

When the allocation domain is not the loop domain, we don't have any logic other than fully allocating bxI1. If we want to just allocate 5, we could do it by setting the allocation domain as the loop domain.

I wonder what the domains and parallelization would look like with TMA.

I believe this is more of a question on what the allocation domain should be. I think that ideally getAllocationDomains here should be just a trivial function call to tv->getAllocationDomain and each tensor should always have its correct allocation domains.

Collaborator:

I believe this is more of a question on what the allocation domain should be.

I agree. I think we need a restriction that each ID in the allocation domain must be either fully inlined or fully not inlined. It cannot be an ID coming from a merge of an inlined ID with an uninlined ID, or an ID that is the parent of a split whose outer is inlined but whose inner is not.

I wonder what the domains and parallelization would look like with TMA.

In my mental model, it is similar to above: IDs in the allocation domain must be either a tile or a non-tile, it can not be a mix of both. However, in practice, for now, even if we have a mix, the code should still work (thanks to some existing hack in our system? and the same hack will also make the above example just allocate 5 today?)

Collaborator (Author):

I believe this is more of a question on what the allocation domain should be.

I agree. I think we need a restriction that, each ID in the allocation domain must be either fully inlined or fully not inlined. It can not have an ID coming from a merge of an inlined ID with an uninlined ID, or an ID who is the parent of a split whose outer is inlined but inner not.

Yeah, I think you could say that the mechanism of promotion is making partially inlined domains fully inlined. In the above case, the innermost 5 domain of the broadcast tensor is promoted to the innermost 5 domain of the non-broadcast tensor, meaning it's effectively fully inlined.

I think we have a reasonable understanding of promotions of loop domains. Can we propagate promotions to allocation domains that are not between logical and loop? I guess we also have a similar problem of propagating parallel types from loop domains to allocation domains.

I wonder what the domains and parallelization would look like with TMA.

In my mental model, it is similar to above: IDs in the allocation domain must be either a tile or a non-tile, it can not be a mix of both. However, in practice, for now, even if we have a mix, the code should still work (thanks to some existing hack in our system? and the same hack will also make the above example just allocate 5 today?)

As long as the allocation domain of the tensor is just the loop domain, indexing is trivial thanks to the loop promotion. How it's implemented in the current main branch isn't that different conceptually, but the implementation is quite convoluted since, for example, it always traverses back to logical domains, whereas in this case we can just directly index the loop (=allocation) domains.

@naoyam naoyam merged commit 8443c26 into main Jun 8, 2024
@naoyam naoyam deleted the idmodel_indexing_broadcast branch June 8, 2024 02:17
naoyam added a commit that referenced this pull request Jun 8, 2024
Stacked on top of #2353 

Small changes to allow indexing of tensors with DIDx domains.

CC: @zasdfgbnm @cowanmeg @samnordmann @wujingyue