Skip to content

Add naive cache for provers#1972

Merged
zasdfgbnm merged 3 commits into main from
naive-cache-provers
Mar 21, 2024
Merged

Add naive cache for provers#1972
zasdfgbnm merged 3 commits into main from
naive-cache-provers

Conversation

@zasdfgbnm
Copy link
Collaborator

Before:

[       OK ] GPUTTensorCoreTest.FusionAmpereMatmul_CUDA (4571 ms)
[       OK ] NVFuserTest.FusionMagicSchedulerBatchNormalization_CUDA (2089 ms)
[       OK ] GpuViewTest.FusionReshapeReductionShmoo (17323 ms)

After:

[       OK ] GPUTTensorCoreTest.FusionAmpereMatmul_CUDA (3151 ms)
[       OK ] NVFuserTest.FusionMagicSchedulerBatchNormalization_CUDA (2044 ms)
[       OK ] GpuViewTest.FusionReshapeReductionShmoo (16693 ms)

@zasdfgbnm zasdfgbnm requested a review from jacobhinkle March 20, 2024 21:27
Copy link
Collaborator

@jacobhinkle jacobhinkle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nice! It's a much simpler alternative to #1974, which aims to reuse proofs across Contexts.

@zasdfgbnm
Copy link
Collaborator Author

!build

@zasdfgbnm
Copy link
Collaborator Author

There is something wrong with Duo and I cannot start the CI. Let me wait for a few hours and retry.

@zasdfgbnm
Copy link
Collaborator Author

!build

6 similar comments
@zasdfgbnm
Copy link
Collaborator Author

!build

@zasdfgbnm
Copy link
Collaborator Author

!build

@zasdfgbnm
Copy link
Collaborator Author

!build

@zasdfgbnm
Copy link
Collaborator Author

!build

@zasdfgbnm
Copy link
Collaborator Author

!build

@zasdfgbnm
Copy link
Collaborator Author

!build

@zasdfgbnm zasdfgbnm merged commit 85c68d8 into main Mar 21, 2024
@zasdfgbnm zasdfgbnm deleted the naive-cache-provers branch March 21, 2024 06:59
jacobhinkle added a commit that referenced this pull request Mar 22, 2024
This came up when working on #1770. In a private conversation,
@zasdfgbnm noticed wisely that the problematic indexing is really a
failure of expression simplification; if we could fully simplify the
swizzling expression it could be entirely hoisted and we would be left
with a nice clean linear index for the smem buffer in the epilogue loop.

This is `NVFuserTest.FusionAmpereMatmulSmemEpilogueCast_CUDA` on `main`:
```c++
    // main loop
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
    nvfuser_index_t i124;
    i124 = 32 * i123;
    nvfuser_index_t i125;
    i125 = i56 + (2048LL * i123);
    #pragma unroll
    for(nvfuser_index_t i126 = 0; i126 < 8; ++i126) {
      nvfuser_index_t i127;
      i127 = i124 + (4 * i126);
      nvfuser_index_t i128;
      i128 = i11 + i126;
      nvfuser_index_t i129;
      i129 = (i125 + (32LL * (i128 / 4))) + (8LL * (i57 ^ (i128 % 4)));
      #pragma unroll
      for(nvfuser_index_t i130 = 0; i130 < 2; ++i130) {
        loadGeneric<float, 2>( &T8[(i129 + (1024LL * i130))],  &T3[(i127 + (2LL * i130))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i131 = 0; i131 < 16; ++i131) {
    nvfuser_index_t i132;
    i132 = i58 + (1024 * i131);
    Array<__half, 8, 8> T7;
    #pragma unroll
    for(nvfuser_index_t i133 = 0; i133 < 8; ++i133) {
      nvfuser_index_t i134;
      i134 = i59 + i133;
      nvfuser_index_t i135;
      i135 = i134 % 128;
      nvfuser_index_t i136;
      i136 = i135 / 8;
      nvfuser_index_t i137;
      i137 = i134 / 128;
      T7[i133]
         = __float2half(T8[((((i132 + (128LL * i137)) + (32LL * (i136 / 4))) + (i135 % 8)) + (8LL * ((i136 % 4) ^ ((i31 + i137) % 4))))]);
    }
    if ((b72 && (i73 < (-(8 * i131))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T4[(i62 + (i63 * i131))], &T7[0]);
    }
  }
}
```
This PR:
```c++
    // main loop
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i114 = 0; i114 < 4; ++i114) {
    nvfuser_index_t i115;
    i115 = 32 * i114;
    nvfuser_index_t i116;
    i116 = i50 + (2048LL * i114);
    #pragma unroll
    for(nvfuser_index_t i117 = 0; i117 < 8; ++i117) {
      nvfuser_index_t i118;
      i118 = i115 + (4 * i117);
      nvfuser_index_t i119;
      i119 = i12 + i117;
      nvfuser_index_t i120;
      i120 = (i116 + (32LL * (i119 / 4))) + (8LL * (i51 ^ (i119 % 4)));
      #pragma unroll
      for(nvfuser_index_t i121 = 0; i121 < 2; ++i121) {
        loadGeneric<float, 2>( &T7[(i120 + (1024LL * i121))],  &T2[(i118 + (2LL * i121))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i122 = 0; i122 < 16; ++i122) {
    nvfuser_index_t i123;
    i123 = i53 + (1024 * i122);
    Array<__half, 8, 8> T6;
    #pragma unroll
    for(nvfuser_index_t i124 = 0; i124 < 8; ++i124) {
      T6[i124]
         = __float2half(T7[(i123 + i124)]);
    }
    if ((b67 && (i68 < (-(8 * i122))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T3[(i56 + (i57 * i122))], &T6[0]);
    }
  }
}
```
~~If we can also get `i134 % 8` simplified to `i134` and `i134 / 8`
simplified to 0 then this should give a nice and efficient last loop.~~
This is done

~~Currently this PR is super slow (e.g. 101 s vs 8 s on main in debug
mode) due to the added recursion. Memoizing past results would be
beneficial, but that's a topic for another PR.~~ This PR is no longer
slow, thanks to limited recursion depth and #1972.

Fixes #1828
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants