
Add TensorView::promoteReuse#739

Merged
jacobhinkle merged 35 commits into main from request_smem_reuse
Aug 23, 2023

Conversation

@jacobhinkle (Collaborator) commented Aug 17, 2023

As of #703, nvfuser is able to reuse shared memory even when the index expressions don't match. However, this requires block synchronization at the appropriate places, and that PR did not provide a mechanism for inserting synchronization. This PR addresses that gap by adding the method TensorView::promoteReuse(), which sets the promote_reuse_ flag on the tensor. The reuseMemoryAllocations lowering pass recognizes that flag and ensures a synchronization exists after the tensor's last use but before the next smem allocation is written to, so that memory reuse can occur. That pass now looks like:

  1. Find shared or local memory tensors that can be "aliased". That is, their lifetimes don't overlap and they have the equivalent index expressions, so the later one can be replaced with a reference to the first.
  2. Find shared memory tensors which we should promote for re-use, i.e. those with the promote_reuse_ flag set. These determine intervals between their last read and the next first write of another smem tensor; we check for syncing expressions within those intervals. Currently we insert a sync at the end of the interval if we don't find a pre-existing one, but we could change these to arrive/wait barriers in future work.
  3. Do shared memory allocation as introduced in Stack-based shared memory allocator #703. The new syncs introduced in step 2 are now recognized and memory is reclaimed as requested.

Note that currently promoteReuse can be called on any TensorView, but it only has an effect on shared memory tensors.
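The interval check in step 2 above can be illustrated with a toy model. This is only a sketch of the idea, not nvfuser's implementation; the event representation and all names here are made up for illustration:

```cpp
#include <string>
#include <vector>

// Toy model of step 2 of the pass: a kernel is a flat list of events.
// Between the last read of a promoted buffer and the next write to a
// different smem buffer there must be a sync; if none exists, insert one
// at the end of the interval (just before that write).
struct Event {
  std::string kind;    // "read", "write", or "sync"
  std::string buffer;  // empty for "sync"
};

// Returns the index where a sync was inserted, or -1 if one already
// existed in the interval (or no interval was found).
int ensureSyncForReuse(std::vector<Event>& events, const std::string& promoted) {
  // Find the last read of the promoted buffer.
  int last_read = -1;
  for (int i = 0; i < (int)events.size(); ++i) {
    if (events[i].kind == "read" && events[i].buffer == promoted) {
      last_read = i;
    }
  }
  if (last_read < 0) {
    return -1;
  }
  // Scan forward for the next first write of another smem buffer.
  for (int j = last_read + 1; j < (int)events.size(); ++j) {
    if (events[j].kind == "sync") {
      return -1;  // a pre-existing sync already covers the interval
    }
    if (events[j].kind == "write" && events[j].buffer != promoted) {
      // No sync found in the interval: insert one before the write.
      events.insert(events.begin() + j, Event{"sync", ""});
      return j;
    }
  }
  return -1;
}
```

The real pass works on kernel IR positions rather than a flat event list, but the reclaim condition it checks is the same shape: a sync must separate the last use from the next overlapping allocation's first write.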

Previously, we stacked every ForLoop regardless of parallelization. This
meant that when the first few dimensions were left of the compute-at
position for the whole fusion, all tensors would have the same outer live
interval even if those dimensions were parallelized. I noticed this in the
AmpereMatmulSmemEpilogue_CUDA tests: if you look at the generated CUDA,
it's clearly not true, since the outer for loops do not appear when they
are parallelized. This commit fixes that; note that it can affect all
reuse analysis, including aliasing, even for local memory.
Comment on lines +768 to +772
// Parallelized loops do not result in for loops in the CUDA kernel, so
// they should not affect liveness analysis. This means that
// current_stack_ will differ from kir::IrVisitor::for_loops_, which will
// actually hold all ForLoops regardless of parallelization.
current_stack_.push_back(loop_info);
@jacobhinkle (Collaborator, Author) commented Aug 18, 2023
I noticed this when testing with matmuls; see FusionAmpereMatmulSmemEpilogue_CUDA for example. In that case the kernel IR starts with

T9_s[ ... ] ca_pos( 2 ) = ALLOCATE(buffer=T9_s[ ... ] ca_pos( 2 ), mem_type=shared, size=8192, zero_init=false)
T8_s[ ... ] ca_pos( 2 ) = ALLOCATE(buffer=T8_s[ ... ] ca_pos( 2 ), mem_type=shared, size=4096, zero_init=false)
T7_s[ ... ] ca_pos( 2 ) produce_pos( 2 ) = ALLOCATE(buffer=T7_s[ ... ] ca_pos( 2 ) produce_pos( 2 ), mem_type=shared, size=8192, zero_init=false)
FOR blockIdx.x in iblockIdx.x71{( ceilDiv(T0.logical_size[0], 64) )}:
  FOR blockIdx.y in iblockIdx.y73{( ceilDiv(T1.logical_size[1], 128) )}:
    ...
    FOR threadIdx.z in ithreadIdx.z84{( ceilDiv(( ( ceilDiv(64, 4) ) * 4 ), 32) )}:
      FOR threadIdx.y in ithreadIdx.y86{( ceilDiv(( ( ( ceilDiv(( ceilDiv(128, 8) ), 4) ) * 4 ) * 8 ), 32) )}:
        // All writes and reads of T7, T8, T9 are in this loop
        FOR i806 in iS88{( ceilDiv(32, 16) )}:
          ...

Here we see that only the i806 loop will be present in the actual CUDA code. The parallelized loops cover the entire kernel, so without this change, the outer live interval of every tensor is just the range of the outermost BIDx loop. With this change, we properly compute the live intervals relative to loops that should occur in the kernel.

This change is related to this PR but not strictly necessary for it. We could always remove it for now and reintroduce it when we make the change to the matmul scheduler.
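The effect of the fix can be sketched with a toy model of the loop stack. This is illustrative only (not nvfuser code): loops that are parallelized emit no `for` in the generated CUDA, so they are skipped when building the stack used for liveness:

```cpp
#include <string>
#include <vector>

// Toy illustration: a loop nest where some loops are parallelized
// (mapped to blockIdx/threadIdx) and therefore emit no `for` statement.
struct LoopInfo {
  std::string name;
  bool parallelized;
};

// Returns the stack of loop names that would actually appear as `for`
// loops in the generated kernel, i.e. the loops relevant to liveness.
std::vector<std::string> materializedStack(const std::vector<LoopInfo>& loops) {
  std::vector<std::string> stack;
  for (const auto& l : loops) {
    if (!l.parallelized) {
      stack.push_back(l.name);
    }
  }
  return stack;
}
```

With the nest from the kernel IR dump above (BIDx, BIDy, TIDz, TIDy all parallelized, only i806 serial), the materialized stack contains just i806, so live intervals are computed relative to that loop rather than spanning the whole kernel.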

Collaborator

Regarding the example at

//! Find the loop level of expr that appears in the same scope as
//! the reference allocate. Eg.
//!
//! For ...
//! For ...
//! Allocate <---- reference arg
//! For ..
//! For ...
//! For ... <---- this function returns `ScopeInfo` for this loop
//! For ...
//! expr <---- current expr (implied in current_stack_ and
//! current_pos_ )
//! Assumes that expr either writes to or reads from the reference allocate.

What happens if I have:

  //!  For ...
  //!    For ...
  //!      Allocate    <---- reference arg
  //!      For ..
  //!          For ...
  //!      For blockIdx.x in blockDim.x <---- Will this function return `ScopeInfo` for this loop?
  //!          For ...
  //!             expr  <---- current expr (implied in current_stack_ and current_pos_ )

Collaborator Author

Since the For blockIdx.x in blockDim.x loop is parallelized, it will not appear in the kernel. So in that case this should look like

  //!  For ...
  //!    For ...
  //!      Allocate    <---- reference arg
  //!      For ..
  //!          For ...
  //!      // For blockIdx.x in blockDim.x <---- This loop does not appear in the CUDA code, so it is ignored
  //!      For ...  <---- This function returns `ScopeInfo` for this loop
  //!         expr  <---- current expr (implied in current_stack_ and current_pos_ )

Collaborator

@jacobhinkle Could you remind me what the problem is with:

so without this change, the outer live interval of every tensor is just the range of the outermost BIDx loop.

@jacobhinkle (Collaborator, Author) commented Aug 22, 2023

The problem is that in that case we actually could do "outer aliasing", since in the CUDA kernel the live intervals do not overlap. However, because the trivial BIDx loop surrounds the whole kernel, without this change the outer live interval of any allocation is simply the span of that BIDx loop.
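The outer-aliasing condition being discussed reduces to a lifetime-overlap test. A minimal sketch, with hypothetical names (the real analysis works on kernel IR positions):

```cpp
// Outer live interval of an allocation, expressed as positions within
// the loop that would actually appear in the generated kernel.
struct Interval {
  int first_write;
  int last_read;
};

// Two allocations can share memory (outer alias) only when their outer
// live intervals do not overlap: one must end strictly before the other
// begins.
bool canOuterAlias(const Interval& a, const Interval& b) {
  return a.last_read < b.first_write || b.last_read < a.first_write;
}
```

When a trivial BIDx loop spans the whole kernel and every allocation's outer interval is measured against it, every pair of intervals overlaps and this test always fails, which is exactly the missed-reuse problem described above.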

Collaborator

It seems this could affect many more cases (positively). Can you split this change out from this PR, and also see how the aliasing would change with the benchmarks?

Collaborator

Also, there are a couple of places where we do current_stack_.back(), and I wonder whether they are all safe. If all loops are parallelized, wouldn't the stack just be empty?

Collaborator Author

Yes I will go ahead and split this off into another PR.

Collaborator Author

Moved to #766.

@jacobhinkle (Collaborator, Author) commented:

!build

@jacobhinkle (Collaborator, Author) commented Aug 18, 2023

Adding the following after the line https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/matmul.cpp#L970 results in shared memory reuse in matmuls with smem epilogues:

smem_epilogue->requestReuse(acw_smem);
// This next line is unnecessary since both tensors have same outer live interval. Both tensors
// will be reclaimed as long as either is requested for reuse.
//smem_epilogue->requestReuse(bcw_smem);

I'm leaving that for a follow-on PR.

@jacobhinkle jacobhinkle changed the title from [WIP] Add TensorView::requestReuse to Add TensorView::requestReuse on Aug 18, 2023
jacobhinkle and others added 4 commits August 18, 2023 11:47
NeedsReorderedPush actually had lifetimes that did not quite overlap. The
new version is simpler, I think.
@jacobhinkle jacobhinkle marked this pull request as ready for review August 18, 2023 17:21
@jacobhinkle jacobhinkle marked this pull request as draft August 18, 2023 18:54
@jacobhinkle jacobhinkle changed the title from Add TensorView::requestReuse to Add TensorView::promoteReuse on Aug 21, 2023
@jacobhinkle jacobhinkle marked this pull request as ready for review August 21, 2023 16:50
return all_allocations_;
}

std::optional<AllocationInfo*> getMaybeAllocInfoFromTV(TensorView* tv) const {
Collaborator

nit: return nullptr when not found, so we don't have to use std::optional.
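The reviewer's suggestion could look like the following sketch (hypothetical, simplified types; the real class holds richer allocation records):

```cpp
#include <unordered_map>

// Stand-ins for the real nvfuser types, for illustration only.
struct AllocationInfo {};
struct TensorView {};

class AllocationInfoMap {
 public:
  // Returns nullptr when no info is recorded for tv, avoiding the
  // std::optional<AllocationInfo*> double-wrapping of a pointer.
  AllocationInfo* getAllocInfoFromTV(TensorView* tv) const {
    auto it = map_.find(tv);
    return it == map_.end() ? nullptr : it->second;
  }

  void insert(TensorView* tv, AllocationInfo* info) {
    map_[tv] = info;
  }

 private:
  std::unordered_map<TensorView*, AllocationInfo*> map_;
};
```

Since the mapped value is already a pointer, nullptr is a natural "not found" sentinel and callers can test it directly in an `if`.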

void setAlias(AllocationInfo* from, AllocationInfo* to) {
  alias_map_[from] = to;
  from->alias_to = to->alloc_expr;
  to->outer_aliased_by.push_back(from);
Collaborator

What if B aliases A, and C aliases B; will A.outer_aliased_by have both B and C?

@jacobhinkle (Collaborator, Author) commented Aug 22, 2023

Good point. Currently, two-hop aliases are not assigned, though they possibly could be in the future.

Collaborator Author

I added an assertion here that to->alias_to is null.
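The resulting invariant can be sketched as follows. This is a simplified model (here alias_to points directly at the target AllocationInfo rather than at its allocate expression, and the alias map is omitted):

```cpp
#include <cassert>
#include <vector>

struct AllocationInfo {
  AllocationInfo* alias_to = nullptr;
  std::vector<AllocationInfo*> outer_aliased_by;
};

// Refuse two-hop alias chains: the target of an alias must itself be
// un-aliased, so outer_aliased_by always lists direct aliases only.
void setAlias(AllocationInfo* from, AllocationInfo* to) {
  assert(to->alias_to == nullptr && "cannot alias to an already-aliased allocation");
  from->alias_to = to;
  to->outer_aliased_by.push_back(from);
}
```

With this assertion, the "C aliases B which aliases A" chain from the question above is rejected outright instead of producing an inconsistent outer_aliased_by list.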

void handle(kir::ForLoop* for_loop) final {
  auto loop_info = scope_map_.getLoopScopeInfo(for_loop);
  current_stack_.push_back(loop_info);
  if (!for_loop->iter_domain()->isParallelized()) {
Collaborator

Should we instead use for_loop->isTrivial()? There are other trivial loops not generated in codegen, for example, vectorization loop, and we should handle all of them equivalently.

Collaborator Author

Much better. Thanks.
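The distinction the reviewer is drawing can be modeled with a toy loop descriptor (illustrative only; nvfuser's kir::ForLoop::isTrivial covers more cases than this sketch):

```cpp
// Toy model: a loop is "trivial" when it emits no `for` statement in the
// generated CUDA. Checking only parallelization misses other trivial
// loops, e.g. vectorization loops or loops of extent one.
struct ForLoopModel {
  bool parallelized = false;
  bool vectorized = false;
  bool extent_one = false;

  bool isTrivial() const {
    return parallelized || vectorized || extent_one;
  }
};

// Liveness should skip a loop exactly when it is trivial, not merely
// when it is parallelized.
bool affectsLiveness(const ForLoopModel& loop) {
  return !loop.isTrivial();
}
```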

//! is present in the kernel to reuse memory and inserts new block
//! synchronizations if necessary.
void promoteReuse(bool b = true) {
  promote_reuse_ = b;
Collaborator

Assert here that this is a shared memory tensor?
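A sketch of the suggested guard (hypothetical simplified class; the real promoteReuse returns void, and an NVF_CHECK-style assertion would hard-fail instead of returning a status). Since the PR description notes the flag currently only has an effect on shared memory tensors, another option is to silently ignore other memory types:

```cpp
enum class MemoryType { Local, Shared, Global };

// Simplified stand-in for TensorView, for illustration only.
struct TensorViewSketch {
  MemoryType mem_type = MemoryType::Local;
  bool promote_reuse = false;

  // Returns whether the flag was actually set; a hard assertion on
  // mem_type would be the stricter variant the reviewer suggests.
  bool promoteReuse(bool b = true) {
    if (mem_type != MemoryType::Shared) {
      return false;  // no effect on non-shared tensors
    }
    promote_reuse = b;
    return true;
  }
};
```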


//! Returns whether we should insert syncs if needed in order to reuse the
//! memory of this tensor.
bool getPromoteReuse() const {
Collaborator

nit: Is this a better name? shouldPromoteReuse

std::unordered_multimap<int, int> sync_intervals_;

// Position within the traversal
int position_ = -1;
Collaborator

Does this need to be a member?

Collaborator Author

It does not. This was leftover from a refactor. Fixing..

Collaborator Author

Done

Comment on lines +1898 to +1903
if (inserted_syncs_.find(expr) != inserted_syncs_.end()) {
  if (isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo)) {
    debug() << "Skipping new sync expression " << expr->toString();
  }
  kir::ExprMutator::dispatch(expr);
}
Collaborator

IIUC, traverseAndInsert will only insert after the entire traverse is done, which means, inside here, we will never see an inserted sync? (If it did insert on the fly, then should we recompute AllocationInfoMap every time when we register an insertion?)

Collaborator Author

You're right, it should not insert until afterward, so this check is not needed. I had gotten a segfault at one point and placed this guard there to diagnose it, but I think that was a logic error that wound up calling kir::ExprMutator::dispatch. I'll verify that none of the tests hit it and remove it if not.

auto tv7 = neg(tv5); // pos = f
fusion->addOutput(tv7);

{ // This should not re-use memory
Collaborator

It seems there's something missing here. This fusion is not parallelized at all, so why does it need a syncthreads to reuse the memory?

Collaborator

Even when parallelized, no syncthreads should be necessary for fusions like:

__shared__ float X[N];
...
auto t0 = X[threadIdx.x]; // last read of X
...
X[threadIdx.x] = t1; // reuse X without syncthreads
...

@jacobhinkle (Collaborator, Author) commented Aug 22, 2023

That case would be an inner alias and is supported without syncing. An outer alias, where there are separate loops for the two allocations, is not supported and requires a sync. We could potentially handle that case without syncing too, but it might require more machinery for proving indices are equivalent.

Collaborator Author

In this test all three smem allocations have different size, so they cannot be aliased.

Collaborator

They don't need to be aliased. I think this is more about the stack-based reuse logic. Since no tensor is parallelized, we should be able to freely pop allocated tensors without a sync.

Collaborator Author

Oh sorry, I thought you meant this tensor is not parallelized. If no tensors are parallelized, you are right that we wouldn't need any syncs. Do we need to handle that case?

Collaborator

Completely serial cases would be fine to ignore as they are just a synthetic example I came up with.

But how about cases like this?

__shared__ float X[blockDim.x * 4];
...
for (int i = 0; i < 4; ++i) {
   auto t0 = X[threadIdx.x + blockDim.x * i]; // last read of X
}
...
for (int i = 0; i < 4; ++i) {
   X[threadIdx.x + blockDim.x * i] = t1; // reuse X without syncthreads
}

I actually don't remember all the details and differences of the inner and outer sharing, but isn't this case outer sharing? And if so, no reuse is allowed without a sync, right?

I think what's missing here is probably something like what we do for the RAW sync insertion. We use the CA maps to analyze if a read after a write requires a sync. See for example: https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/analysis/sync_information.cpp#L234

It's quite complicated and also it's one of those that we would be able to simplify a lot with the new ID graph. So, I think it's fine to leave this as a limitation for now.

Collaborator Author

Ah I see. Yes in this case there are non-overlapping sets of elements written/read across threads, so it's safe. It's definitely very similar analysis to what's done in SyncMap; I had originally thought we might just augment SyncMap to insert these re-use syncs even. Sounds good on leaving it for now and revisiting it after ID graphs are complete.

Collaborator

Can you file it as a TODO issue?

Collaborator Author

Drafted an issue here: #769

Comment on lines +199 to +206
for (auto alloc : gpulw.kernel()->summary().dynamic_smem_allocations) {
  EXPECT_NE(alloc->address(), nullptr);
  auto addr = ee.evaluate(alloc->address()).as<int64_t>();
  auto size = ee.evaluate(alloc->size()).as<int64_t>() *
      dataTypeSize(alloc->buffer()->dtype());
  smem_usage = std::max(smem_usage, addr + size);
}
EXPECT_EQ(smem_usage, alignInt((H + 1) * 4) + (H + 1) * 4);
Collaborator

Can we just directly validate the alias relationship of each allocation? I understand checking the total size is also fine, but asserting the alias relationships would make the intention of the test and expected behavior much more clear.

Collaborator Author

Since they are not aliased (i.e. the kir::Allocates have null alias()), we can't compare those directly. We can probably compare the addresses instead though. I will give it a shot.
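The address-based check being proposed could be sketched like this (hypothetical plain structs standing in for evaluated kir::Allocate results; the real test would use EXPECT_* macros on the evaluated addresses):

```cpp
#include <cstdint>

// Evaluated placement of a dynamic smem allocation: base address (bytes
// into the smem segment) and size in bytes.
struct AllocModel {
  int64_t address;
  int64_t size;
};

// Memory was reclaimed if a later allocation is placed at the same base
// address as an earlier, no-longer-live one.
bool reusesAddress(const AllocModel& earlier, const AllocModel& later) {
  return earlier.address == later.address;
}
```

Asserting address equality states the expected reuse relationship directly, whereas the total-usage check only shows that the sum came out right.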

Collaborator

Oh, that's true. Then maybe not worth spending much time.

@naoyam (Collaborator) left a comment

LGTM

@jacobhinkle jacobhinkle merged commit 268a63f into main Aug 23, 2023
@jacobhinkle jacobhinkle deleted the request_smem_reuse branch August 23, 2023 02:07
jacobhinkle added a commit that referenced this pull request Aug 31, 2023
This uses `promoteReuse` from #739 and inserts a syncthreads just before
the epilogue loop when smem is used for the epilogue, when possible. The
matmul heuristic attempts to predict when this will be possible in order
to more accurately estimate shared memory usage, and hence occupancy. If
we cannot guarantee re-use, we must assume in the heuristic that memory
will not be reclaimed, even though it might be when the fusion is
lowered.

Shared memory reclamation can only occur if the smem buffers have
non-overlapping lifetimes. This is difficult to guarantee before
scheduling and lowering. We use `cacheAfter` to create the `a` and `b`
smem tiles, but we use `cacheBefore` for the epilogue smem tile. This
means that smem will be used for any downstream uses of `a` and `b` but
the epilogue smem will have its lifetime restricted to the epilogue
itself, regardless of downstream uses of the matrix product.

The uses of `a` and `b` can complicate lifetime analysis. Consider a
case where both matrices are square and we wish to compute `a @ b + a`
where `@` denotes matmul. Since we used `a->cacheAfter()` to create the
smem tile, that smem may be used not only in the matmul but also in the
addition in the epilogue. In that case we cannot re-use `a` for the
epilogue smem. A conservative check that there are no other uses of `a`
or `b` is currently implemented in order to guarantee re-use. A less
conservative sufficient (but still not necessary) condition is that any
other use of `a` or `b` is a producer of `b` or `a` respectively; this
is not implemented yet.
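The conservative check described above can be sketched with toy structures (hypothetical names; the real check inspects TensorView uses in the fusion):

```cpp
#include <string>
#include <vector>

// Toy stand-in for a tensor and the names of the expressions consuming it.
struct TVModel {
  std::string name;
  std::vector<std::string> uses;
};

// Conservative sufficient condition for guaranteed smem reuse: both smem
// inputs are consumed only by the matmul itself, so neither lifetime can
// extend into the epilogue.
bool canGuaranteeReuse(const TVModel& a, const TVModel& b,
                       const std::string& matmul) {
  auto only_matmul = [&](const TVModel& tv) {
    return tv.uses.size() == 1 && tv.uses[0] == matmul;
  };
  return only_matmul(a) && only_matmul(b);
}
```

The `a @ b + a` example fails this check because `a` has a second use in the epilogue addition; the relaxed producer-based condition mentioned above would accept some of the cases this one rejects.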

---------

Co-authored-by: Andrzej Bekas <118676880+drzejan2@users.noreply.github.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
Co-authored-by: Wang, Xiao <24860335+xwang233@users.noreply.github.com>