Strengthen index simplification for cast epilogue matmul #1827
jacobhinkle merged 28 commits into main
Conversation
Still need to add a test for the problem that's fixed by the slow recursive approach.
!build --diff
Changing the size in … I also measured after manually performing the last optimization (…)
That's a lot! Thanks for measuring this!
Even if there is no visible perf improvement, I still suggest going ahead and implementing it because:
Summary: As of now, this PR causes a 27% slowdown in total runtime for the test mentioned in the PR description, compared to main. The latest pushed change switches from the unordered_set method to just limiting the recursion depth directly to 2. The Release compilation mode seems to make a big difference between the Release and Debug build timings.
I'm now trying to write a good test for this PR, then I will clean up and mark it ready.
This reverts commit 713ead0.
!build --diff
These are apparently not needed for the index simplification in this PR.
Recently-pushed change strengthens the index simplification. The generated epilogue now looks like:

```cpp
// main loop
}
__syncthreads();
#pragma unroll
for(nvfuser_index_t i117 = 0; i117 < 4; ++i117) {
  nvfuser_index_t i118;
  i118 = 32 * i117;
  nvfuser_index_t i119;
  i119 = i53 + (2048LL * i117);
  #pragma unroll
  for(nvfuser_index_t i120 = 0; i120 < 8; ++i120) {
    nvfuser_index_t i121;
    i121 = i118 + (4 * i120);
    nvfuser_index_t i122;
    i122 = i30 + i120;
    nvfuser_index_t i123;
    i123 = (i119 + (32LL * (i122 / 4))) + (8LL * (i54 ^ (i122 % 4)));
    #pragma unroll
    for(nvfuser_index_t i124 = 0; i124 < 2; ++i124) {
      loadGeneric<float, 2>( &T8[(i123 + (1024LL * i124))], &T3[(i121 + (2LL * i124))]);
    }
  }
}
__syncthreads();
#pragma unroll
for(nvfuser_index_t i125 = 0; i125 < 16; ++i125) {
  nvfuser_index_t i126;
  i126 = i55 + (1024 * i125);
  Array<__half, 8, 8> T7;
  #pragma unroll
  for(nvfuser_index_t i127 = 0; i127 < 8; ++i127) {
    T7[i127]
      = __float2half(T8[(i126 + i127)]);
  }
  if ((b58 && (i67 < (-(8 * i125))))) {
    loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T4[(i57 + (i10 * i125))], &T7[0]);
  }
}
```

However, it feels a bit slower. I'm going to evaluate the compile time with and without this change.
On my machine this takes overall test time from 1.9 to 3.9 seconds in a release build and from 7.4 to 15 seconds in a debug build. I think it's probably best to leave the new optimization mentioned in the last comment for another PR so that I can play more with speeding it up. For now, I will revert it and ensure the timing makes sense; then hopefully we can merge.
With the reverted commit, compile times are back down. If tests pass, I think this PR is ready.
!build --diff

!build --diff
I experimented and found that moving this to the beginning of the function actually affects correctness due to the checks preceding this loop. I could increase the depth to work around that, but then runtime increases due to the added recursion. This change instead keeps the functionality we had previously.
!build --diff
After #1972 the compile time is down to 1.3s. Even better, since it reused lots of proofs, re-enabling the reverted simplification actually doesn't appreciably change compile time!
!build --diff-bench
A trivial modulus operation is simplified away now.
```diff
  float T2[1LL];
  T2[0LL]
-    = T1[(i15 % 2LL)];
+    = T1[i15];
```
```cpp
// This doesn't simplify at all
// EXPECT_VALUE_TRUE(simplifyExpr("neg( 8 ) < neg( i0 )"_, {}, {"i0 < 8"_}));

// This doesn't simplify at all
// EXPECT_VALUE_TRUE(simplifyExpr("neg( i0 ) < 0"_, {}, {"0 < i0"_}));
```
I think these commented-out tests could be addressed by implementing the rule that `a < b` implies `-b < -a`. I don't think we need it urgently, but I plan to experiment with it.
This came up when working on #1770. In a private conversation, @zasdfgbnm noticed wisely that the problematic indexing is really a failure of expression simplification; if we could fully simplify the swizzling expression it could be entirely hoisted and we would be left with a nice clean linear index for the smem buffer in the epilogue loop.
This is NVFuserTest.FusionAmpereMatmulSmemEpilogueCast_CUDA on main:

This PR:

If we can also get `i134 % 8` simplified to `i134` and `i134 / 8` simplified to `0`, then this should give a nice and efficient last loop. (This is done.)

Currently this PR is super slow (e.g. 101 s vs 8 s on main in debug mode) due to the added recursion. Memoizing past results would be beneficial, but that's a topic for another PR. (Update: this PR is no longer slow, thanks to limited recursion depth and #1972.)

Fixes #1828