Enable shared memory reuse in matmul epilogue#770

Merged
jacobhinkle merged 37 commits into main from smem_epilogue_request_reuse on Aug 31, 2023
Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented Aug 23, 2023

This uses promoteReuse from #739 and, when possible, inserts a __syncthreads() just before the epilogue loop when smem is used for the epilogue. The matmul heuristic attempts to predict when this will be possible in order to estimate shared memory usage, and hence occupancy, more accurately. If we cannot guarantee re-use, the heuristic must assume that memory will not be reclaimed, even though it might be when the fusion is lowered.

Shared memory reclamation can only occur if the smem buffers have non-overlapping lifetimes. This is difficult to guarantee before scheduling and lowering. We use cacheAfter to create the a and b smem tiles, but we use cacheBefore for the epilogue smem tile. This means that smem will be used for any downstream uses of a and b but the epilogue smem will have its lifetime restricted to the epilogue itself, regardless of downstream uses of the matrix product.

The uses of a and b can complicate lifetime analysis. Consider a case where both matrices are square and we wish to compute a @ b + a where @ denotes matmul. Since we used a->cacheAfter() to create the smem tile, that smem may be used not only in the matmul but also in the addition in the epilogue. In that case we cannot re-use a for the epilogue smem. A conservative check that there are no other uses of a or b is currently implemented in order to guarantee re-use. A less conservative sufficient (but still not necessary) condition is that any other use of a or b is a producer of b or a respectively; this is not implemented yet.

jacobhinkle and others added 15 commits August 22, 2023 19:30
Previously, we stacked every ForLoop regardless of parallelization. This
meant that when the first few dimensions were left of the compute-at
position in the whole fusion, all tensors would have the same outer live
interval even if those dimensions were parallelized. I noticed this in the
AmpereMatmulSmemEpilogue_CUDA tests: looking at the generated CUDA, it's
clearly not true, since the parallelized outer for loops do not appear.
This commit fixes that; note that it can affect all reuse analysis,
including aliasing even of local memory.
@mmigdal-nv mmigdal-nv requested a review from drzejan2 August 23, 2023 15:33
@csarofeen
Collaborator

The check reminds me a lot of how we check persistent buffer usage.

@jacobhinkle jacobhinkle marked this pull request as ready for review August 28, 2023 00:20
@naoyam
Collaborator

naoyam commented Aug 28, 2023

@jacobhinkle Do we have tests?

@jacobhinkle
Collaborator Author

jacobhinkle commented Aug 28, 2023

@jacobhinkle Do we have tests?

Not yet. I was thinking of modifying the Epilogue* tests to check that memory is reused in the event that params.use_smem_epilogue == true.

const auto blocks_per_sm_with_smem_epilogue = std::min(
    {shared_memory_available / (smem_a + smem_b + smem_c),
     shared_memory_available / total_with_smem_epilogue,
     (size_t)blocks_per_sm_by_register});
Collaborator

Should we add a case where, if reuse and no-reuse provide the same occupancy, we don't promote reuse? That would save a sync.

Collaborator Author

That's a good idea. In that case I'll also need to guard promoteReuse with params.use_smem_epilogue.

Collaborator Author

Actually, now I understand your comment. Yes we'll need a separate parameter indicating whether to reuse memory or not, in addition to use_smem_epilogue. I'll push something in a moment.

Collaborator

@zasdfgbnm zasdfgbnm left a comment

Generally LGTM, but will leave it to @drzejan2 for approval. Also, could you run /build/nvfuser_bench --benchmark_filter=Matmul for reporting the perf on these benchmarks, and also work with @mmigdal-nv for a more thorough perf evaluation?

@jacobhinkle
Collaborator Author

Generally LGTM, but will leave it to @drzejan2 for approval. Also, could you run /build/nvfuser_bench --benchmark_filter=Matmul for reporting the perf on these benchmarks, and also work with @mmigdal-nv for a more thorough perf evaluation?

Will do!

Contributor

@drzejan2 drzejan2 left a comment

Changes look good to me.

I will approve this MR once the use case mentioned by @zasdfgbnm is supported (link).

@jacobhinkle
Collaborator Author

I will approve this MR once the use case mentioned by @zasdfgbnm is supported (link).

Supported now. The heuristic now holds bool promote_prologue_smem_reuse in addition to use_smem_epilogue. promote_prologue_smem_reuse is only true if re-using smem would increase occupancy, since otherwise we should avoid adding a __syncthreads().

@jacobhinkle
Collaborator Author

jacobhinkle commented Aug 29, 2023

Generally LGTM, but will leave it to @drzejan2 for approval. Also, could you run /build/nvfuser_bench --benchmark_filter=Matmul for reporting the perf on these benchmarks, and also work with @mmigdal-nv for a more thorough perf evaluation?

I ran the benchmarks in the background while working on some other things. Then I realized that the benchmarks show no perf difference compared to TOT because they do not use smem for the epilogue: the heuristic is not run, since the matmul benchmarks manually set the params. Altering that should probably not go into this PR, but I will hack it to see what effect it has on some of the benchmarks.

Comment on lines +3998 to +4001
if (params.promote_prologue_smem_reuse) {
// Check prologue shared memory re-use
TORCH_CHECK(smem_allocs.at(1)->address()->isZero());
TORCH_CHECK(smem_allocs.at(2)->address()->isZero());
Collaborator

Is smem_allocs.at(0) A, smem_allocs.at(1) B, and smem_allocs.at(2) C? So B is reusing A's memory, and C is reusing B's memory?

Collaborator Author

Yes that's right. I improved the comment here a bit.

Collaborator

I see. So we allocate B first, then A. This makes sense. Thanks for the explanation!

Collaborator

Wait, is it guaranteed which of A vs B is allocated first? If A is allocated first, will smem_allocs.at(1) no longer be zero here? I think we should remove the check for smem_allocs.at(1).

Collaborator Author

Very good question, since their lifetimes end at the same time point. We break such ties by ordering by name(), so B will consistently be ordered after A, leading to this layout. However, that's not exactly clear, so maybe we could just check that C is placed at 0 and that either A or B is at 0.

Contributor

@drzejan2 drzejan2 left a comment

Functionally everything is sound; I will approve once the bug in the hashing function is fixed.

jacobhinkle and others added 8 commits August 30, 2023 07:43
Co-authored-by: Andrzej Bekas <118676880+drzejan2@users.noreply.github.com>
See note `[Struct Support in PolymorphicValue]` for description, and the
new test `PolymorphicValueTest.Struct` for examples.

This updates the root to rfactor propagation in IterType concretization
of dynamic fusions.

Previously, although we only overwrote Symbolic IterDomains in this
step, we still asserted that we could infer an IterType for each. I moved
that check so that it is only applied when we need to make a change.

Additionally, we previously propagated Broadcast-only IterDomains as
Symbolic, since we combined them with our previous estimate using
promoteIterType. As mentioned in a comment, this means Broadcast gets
propagated as Symbolic. Instead, we now only fall back to promoteIterType
when there are multiple input IterTypes to the IterDomain expression.

Fixes #798
@jacobhinkle jacobhinkle merged commit 761eea4 into main Aug 31, 2023
@jacobhinkle jacobhinkle deleted the smem_epilogue_request_reuse branch August 31, 2023 15:44