
Strengthen index simplification for cast epilogue matmul#1827

Merged
jacobhinkle merged 28 commits into main from expr_simplify_lessthan
Mar 22, 2024
Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented Feb 23, 2024

This came up while working on #1770. In a private conversation, @zasdfgbnm wisely pointed out that the problematic indexing is really a failure of expression simplification: if we could fully simplify the swizzling expression, it could be hoisted entirely, and we would be left with a nice clean linear index for the smem buffer in the epilogue loop.

This is NVFuserTest.FusionAmpereMatmulSmemEpilogueCast_CUDA on main:

    // main loop
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
    nvfuser_index_t i124;
    i124 = 32 * i123;
    nvfuser_index_t i125;
    i125 = i56 + (2048LL * i123);
    #pragma unroll
    for(nvfuser_index_t i126 = 0; i126 < 8; ++i126) {
      nvfuser_index_t i127;
      i127 = i124 + (4 * i126);
      nvfuser_index_t i128;
      i128 = i11 + i126;
      nvfuser_index_t i129;
      i129 = (i125 + (32LL * (i128 / 4))) + (8LL * (i57 ^ (i128 % 4)));
      #pragma unroll
      for(nvfuser_index_t i130 = 0; i130 < 2; ++i130) {
        loadGeneric<float, 2>( &T8[(i129 + (1024LL * i130))],  &T3[(i127 + (2LL * i130))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i131 = 0; i131 < 16; ++i131) {
    nvfuser_index_t i132;
    i132 = i58 + (1024 * i131);
    Array<__half, 8, 8> T7;
    #pragma unroll
    for(nvfuser_index_t i133 = 0; i133 < 8; ++i133) {
      nvfuser_index_t i134;
      i134 = i59 + i133;
      nvfuser_index_t i135;
      i135 = i134 % 128;
      nvfuser_index_t i136;
      i136 = i135 / 8;
      nvfuser_index_t i137;
      i137 = i134 / 128;
      T7[i133]
         = __float2half(T8[((((i132 + (128LL * i137)) + (32LL * (i136 / 4))) + (i135 % 8)) + (8LL * ((i136 % 4) ^ ((i31 + i137) % 4))))]);
    }
    if ((b72 && (i73 < (-(8 * i131))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T4[(i62 + (i63 * i131))], &T7[0]);
    }
  }
}

This PR:

    // main loop
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i114 = 0; i114 < 4; ++i114) {
    nvfuser_index_t i115;
    i115 = 32 * i114;
    nvfuser_index_t i116;
    i116 = i50 + (2048LL * i114);
    #pragma unroll
    for(nvfuser_index_t i117 = 0; i117 < 8; ++i117) {
      nvfuser_index_t i118;
      i118 = i115 + (4 * i117);
      nvfuser_index_t i119;
      i119 = i12 + i117;
      nvfuser_index_t i120;
      i120 = (i116 + (32LL * (i119 / 4))) + (8LL * (i51 ^ (i119 % 4)));
      #pragma unroll
      for(nvfuser_index_t i121 = 0; i121 < 2; ++i121) {
        loadGeneric<float, 2>( &T7[(i120 + (1024LL * i121))],  &T2[(i118 + (2LL * i121))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i122 = 0; i122 < 16; ++i122) {
    nvfuser_index_t i123;
    i123 = i53 + (1024 * i122);
    Array<__half, 8, 8> T6;
    #pragma unroll
    for(nvfuser_index_t i124 = 0; i124 < 8; ++i124) {
      T6[i124]
         = __float2half(T7[(i123 + i124)]);
    }
    if ((b67 && (i68 < (-(8 * i122))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T3[(i56 + (i57 * i122))], &T6[0]);
    }
  }
}

If we can also get i134 % 8 simplified to i134 and i134 / 8 simplified to 0, then this should give a nice, efficient last loop. This has since been done.

This PR was initially super slow (e.g. 101 s vs 8 s on main in debug mode) due to the added recursion; memoizing past results would be beneficial, but that's a topic for another PR. It is no longer slow, thanks to limited recursion depth and #1972.

Fixes #1828

Still need to add a test for the problem that's fixed by the slow
recursive approach.
@jacobhinkle
Collaborator Author

The identity x < y => x % y = x (for non-negative x) is not currently exploited in any of the simplification passes, as far as I can tell. I added a failing test to track that.

@jacobhinkle
Collaborator Author

!build --diff

@jacobhinkle jacobhinkle changed the title from [WIP] More powerful prove::lessThan to [WIP] Strengthen prove::lessThan Feb 24, 2024
@jacobhinkle
Collaborator Author

jacobhinkle commented Feb 27, 2024

Changing the size in NVFuserTest.FusionAmpereMatmulSmemEpilogueCast_CUDA to 16384, 16384, 256 (i.e. a large memory-bound problem size), we go from 398 GB/s to 554 GB/s on an A100 80GB PCIe. That's a 39% speedup! It roughly matches our internal measurements showing that output bandwidth for fp16 outputs is about half that of fp32 outputs. It seems worth figuring out the proof slowdown here, since this optimization can have a large impact when output writes are a significant portion of runtime.

I also measured after manually performing the last optimization (i134 % 8 -> i134, i134 / 8 -> 0, hoisting) and saw no effect. It's possible these are already performed by the CUDA compiler.

@zasdfgbnm
Collaborator

That's a 39% speedup!

That's a lot! Thanks for measuring this!

I also measured after manually performing the last optimization (i134 % 8 -> i134, i134 / 8 -> 0, hoisting) and saw no effect. It's possible these are already performed by the CUDA compiler.

Even if there is no visible perf improvement, I still suggest going ahead and implementing it because:

  1. It should be easy to implement
  2. It improves code readability

@jacobhinkle
Collaborator Author

jacobhinkle commented Mar 1, 2024

Summary: As of now, this PR causes a 27% slowdown in total runtime for the test mentioned in the PR description compared to main. That is an improvement over my original method, which showed 2.2x runtime, thanks to a recently pushed change. In debug mode we currently see 3.8x runtime, down from 14x. These are compile-time differences; recent testing suggests kernel performance is improved by around 3% due to the simplification.

The latest pushed change switches from the unordered_set method to directly limiting the recursion depth to 2.

Release mode test timing:

recursion depth  time (sec)  Simplified?
      0              1.5         no
      1              1.5         no
      2              1.9         yes
      3              9.3         yes
      4            133           yes

Compare to:

 main                1.5         no
 unordered_set       3.3         yes
The Release compilation mode seems to make a big difference.

Debug build timings:

recursion depth  time (sec)  Simplified?
      0              7.4         no
      1              7.5         no
      2             28           yes
      3            422           yes
      4              ?           yes

 main                7.4         no
 unordered_set     104           yes

@jacobhinkle
Collaborator Author

I'm now trying to write a good test for this PR, then I will clean up and mark it ready.

@jacobhinkle
Collaborator Author

!build --diff

@jacobhinkle jacobhinkle changed the title from [WIP] Strengthen prove::lessThan to [WIP] Strengthen index simplification for cast epilogue matmul Mar 12, 2024
@jacobhinkle jacobhinkle changed the title from [WIP] Strengthen index simplification for cast epilogue matmul to Strengthen index simplification for cast epilogue matmul Mar 12, 2024
@jacobhinkle
Collaborator Author

The recently pushed change strengthens eliminateTrivialComputation to simplify a % b and a / b when -|b| < a < |b|. This leads to a very nice kernel:

    // main loop
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i117 = 0; i117 < 4; ++i117) {
    nvfuser_index_t i118;
    i118 = 32 * i117;
    nvfuser_index_t i119;
    i119 = i53 + (2048LL * i117);
    #pragma unroll
    for(nvfuser_index_t i120 = 0; i120 < 8; ++i120) {
      nvfuser_index_t i121;
      i121 = i118 + (4 * i120);
      nvfuser_index_t i122;
      i122 = i30 + i120;
      nvfuser_index_t i123;
      i123 = (i119 + (32LL * (i122 / 4))) + (8LL * (i54 ^ (i122 % 4)));
      #pragma unroll
      for(nvfuser_index_t i124 = 0; i124 < 2; ++i124) {
        loadGeneric<float, 2>( &T8[(i123 + (1024LL * i124))],  &T3[(i121 + (2LL * i124))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i125 = 0; i125 < 16; ++i125) {
    nvfuser_index_t i126;
    i126 = i55 + (1024 * i125);
    Array<__half, 8, 8> T7;
    #pragma unroll
    for(nvfuser_index_t i127 = 0; i127 < 8; ++i127) {
      T7[i127]
         = __float2half(T8[(i126 + i127)]);
    }
    if ((b58 && (i67 < (-(8 * i125))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T4[(i57 + (i10 * i125))], &T7[0]);
    }
  }
}

However, it feels a bit slower. I'm going to evaluate the compile time with and without this change.

@jacobhinkle
Collaborator Author

it feels a bit slower.

On my machine this takes overall test time from 1.9 to 3.9 seconds in a release build and from 7.4 to 15 seconds in a debug build. I think it's best to leave the new optimization mentioned in the last comment for another PR so that I can experiment more with speeding it up. For now, I will revert it and make sure the timing makes sense; then hopefully we can merge.

@jacobhinkle
Collaborator Author

With the reverted commit, compile times are back down. If tests pass, I think this PR is ready.

@jacobhinkle
Collaborator Author

!build --diff

@jacobhinkle
Collaborator Author

!build --diff

I experimented and found that moving this to the beginning of the
function actually affects correctness due to the checks preceding this
loop. I can increase depth to work around that but then runtime
increases due to increased recursion. This change keeps functionality we
had previously instead.
@jacobhinkle
Collaborator Author

!build --diff

@jacobhinkle
Collaborator Author

After #1972 the compile time is down to 1.3 s. Even better, since it re-uses lots of proofs, re-enabling the reverted simplification doesn't appreciably change compile time!

@jacobhinkle jacobhinkle marked this pull request as ready for review March 21, 2024 16:16
@jacobhinkle
Collaborator Author

!build --diff-bench

@jacobhinkle jacobhinkle requested a review from zasdfgbnm March 21, 2024 16:53
A trivial modulus operation is simplified away now:

  float T2[1LL];
  T2[0LL]
-    = T1[(i15 % 2LL)];
+    = T1[i15];
Collaborator Author

😎

Comment on lines +1174 to +1178
// This doesn't simplify at all
// EXPECT_VALUE_TRUE(simplifyExpr("neg( 8 ) < neg( i0 )"_, {}, {"i0 < 8"_}));

// This doesn't simplify at all
// EXPECT_VALUE_TRUE(simplifyExpr("neg( i0 ) < 0"_, {}, {"0 < i0"_}));
Collaborator Author

@jacobhinkle jacobhinkle Mar 21, 2024

I think these commented out tests could be addressed by implementing a < b implies -b < -a. I don't think we need it urgently, but I plan to experiment with it.

@jacobhinkle jacobhinkle merged commit 5817da9 into main Mar 22, 2024
@jacobhinkle jacobhinkle deleted the expr_simplify_lessthan branch March 22, 2024 15:41
Successfully merging this pull request may close these issues:

Missing opportunities to remove trivial mod and div