
Ignore trivial loops in memory aliasing pass#766

Merged
jacobhinkle merged 4 commits into main from alias_pass_ignore_trivial_loops on Aug 28, 2023

Conversation

jacobhinkle (Collaborator) commented Aug 22, 2023

Trivial kir::ForLoops are ones that appear in the kernel IR but do not appear in the generated CUDA kernel. This can happen for several reasons: for example, if that dimension is vectorized, or if it is parallelized with a stop value equal to the extent of the dimension. We can test for this with kir::ForLoop::isTrivial(). Consider an example:

T9_s[ ... ] ca_pos( 2 ) = ALLOCATE(buffer=T9_s[ ... ] ca_pos( 2 ), mem_type=shared, size=8192, zero_init=false)
T8_s[ ... ] ca_pos( 2 ) = ALLOCATE(buffer=T8_s[ ... ] ca_pos( 2 ), mem_type=shared, size=4096, zero_init=false)
T7_s[ ... ] ca_pos( 2 ) produce_pos( 2 ) = ALLOCATE(buffer=T7_s[ ... ] ca_pos( 2 ) produce_pos( 2 ), mem_type=shared, size=8192, zero_init=false)
FOR blockIdx.x in iblockIdx.x71{( ceilDiv(T0.logical_size[0], 64) )}:
  FOR blockIdx.y in iblockIdx.y73{( ceilDiv(T1.logical_size[1], 128) )}:
    ...
    FOR threadIdx.z in ithreadIdx.z84{( ceilDiv(( ( ceilDiv(64, 4) ) * 4 ), 32) )}:
      FOR threadIdx.y in ithreadIdx.y86{( ceilDiv(( ( ( ceilDiv(( ceilDiv(128, 8) ), 4) ) * 4 ) * 8 ), 32) )}:
        // All writes and reads of T7, T8, T9 are in this loop
        FOR i806 in iS88{( ceilDiv(32, 16) )}:
          T7_s = ...;
        ENDFOR i806
        T8_s = T7_s;
        T9_s = T8_s;
      ENDFOR threadIdx.y
    ENDFOR threadIdx.z
  ENDFOR blockIdx.y
ENDFOR blockIdx.x

In this case, all of the parallelized for loops are trivial, and only the FOR i806 loop appears in the generated code. That means the actual lifetimes of T7 and T8 overlap and those of T8 and T9 overlap, but not those of T7 and T9.

In the aliasing pass, outer live intervals are defined at the scope of the allocation. In the above case, the pass sets the outer live interval of all three allocations equal to the start and end of the blockIdx.x loop.

This PR ignores trivial loops in this analysis, so that outer live intervals are defined at the innermost scope of the Allocate expression that is actually realized in the generated CUDA kernel. In the above example, the outer live intervals for T7 and T9 no longer overlap, so they become eligible for memory re-use.
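The effect of skipping trivial loops can be illustrated with a small sketch (this is not the nvFuser implementation; loop positions and the interval-expansion rule are simplified assumptions). Each buffer's first and last use positions are widened to cover any enclosing non-trivial loop, since a loop body re-executes; trivial loops, which vanish from the generated kernel, are skipped.

```python
# Hypothetical sketch: compute "outer live intervals" by expanding each
# buffer's first/last use positions over enclosing loops, optionally
# skipping trivial loops, then test which buffers may share memory.

def outer_interval(first_use, last_use, loops, ignore_trivial=True):
    """loops: (start, end, is_trivial) tuples, ordered innermost to outermost."""
    lo, hi = first_use, last_use
    for start, end, trivial in loops:
        if ignore_trivial and trivial:
            continue  # this loop is absent from the generated kernel
        # A use inside a loop keeps the buffer live for the whole loop,
        # because the loop body re-executes.
        if start <= lo <= end or start <= hi <= end:
            lo, hi = min(lo, start), max(hi, end)
    return lo, hi

def overlaps(a, b):
    return a[0] <= b[1] and b[0] <= a[1]

# Positions loosely mimicking the IR above: the i806 loop spans [4, 6] and
# is real; the parallelized loops span [0, 10] and are all trivial.
loops = [(4, 6, False), (0, 10, True)]
t7 = outer_interval(5, 7, loops)   # written inside i806, last read at 7
t8 = outer_interval(7, 8, loops)
t9 = outer_interval(8, 9, loops)
assert overlaps(t7, t8) and overlaps(t8, t9) and not overlaps(t7, t9)

# Counting trivial loops inflates every interval to the blockIdx.x span,
# making all three buffers appear to conflict.
t7_old = outer_interval(5, 7, loops, ignore_trivial=False)
t9_old = outer_interval(8, 9, loops, ignore_trivial=False)
assert overlaps(t7_old, t9_old)
```

Under this model, T7 and T9 stop overlapping once trivial loops are ignored, which is exactly what makes them candidates for aliasing.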

naoyam (Collaborator) commented Aug 23, 2023

As I mentioned in the original PR, can you please make sure there's no invalid back() happening?

jacobhinkle (Collaborator, Author) replied:

> As I mentioned in the original PR, can you please make sure there's no invalid back() happening?

We should never get an invalid access from .back(), since we always push the top-level scope onto current_stack_. I went through the rest of the code and don't see any issues. I am currently running a codegen comparison to look for unintended consequences in the wild.
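The invariant being relied on can be sketched as follows (a hypothetical model, not the nvFuser code: the class and names are illustrative). If the top-level scope is seeded into the stack before traversal and is never popped, then back() can never be called on an empty stack.

```python
# Hypothetical sketch of the scope-stack invariant discussed above.

class ScopeStack:
    def __init__(self, top_level_scope):
        # Seed with the top-level scope so the stack is never empty.
        self._stack = [top_level_scope]

    def push(self, scope):
        self._stack.append(scope)

    def pop(self):
        # Only nested scopes may be popped; the top-level scope stays put.
        assert len(self._stack) > 1, "cannot pop the top-level scope"
        return self._stack.pop()

    def back(self):
        # Safe by construction: at least the top-level scope is present.
        return self._stack[-1]

stack = ScopeStack("kernel")
stack.push("for blockIdx.x")
assert stack.back() == "for blockIdx.x"
stack.pop()
assert stack.back() == "kernel"  # valid even with no nested scopes open
```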

@jacobhinkle jacobhinkle marked this pull request as ready for review August 25, 2023 14:19
@jacobhinkle jacobhinkle requested a review from naoyam August 25, 2023 14:20
naoyam (Collaborator) left a comment:


Feel free to merge once the check with the generated code is done

jacobhinkle (Collaborator, Author) commented Aug 28, 2023

Manually checked diffs in codegen. There is still quite a bit of non-determinism, which I believe is coming from allocateIndexVariables, so there is a lot of noise, but from what I can tell the only smem-related changes are expected and limited to a few examples:

  • FusionMatmulSoftmaxMatmulAmpere re-uses one buffer.
  • FusionHdiff re-orders a couple of allocations so they now match their actual last outer reads.
  • FusionPredicateParallelizedDomains: buffer T5 now re-uses both T1 and T2 instead of only T2; T1 and T2 are also swapped on the stack.

I think this is safe to merge.

@jacobhinkle jacobhinkle merged commit e0a22af into main Aug 28, 2023
@jacobhinkle jacobhinkle deleted the alias_pass_ignore_trivial_loops branch August 28, 2023 00:12
naoyam (Collaborator) commented Aug 28, 2023

> Manually checked diffs in codegen. [...] I think this is safe to merge.

Did you see diffs with the benchmarks or the tests, or both? Last time I checked I didn't see any significant diff with the benchmarks; I did see some minor diffs with a few benchmarks.

jacobhinkle (Collaborator, Author) replied:


> Did you see diffs with the benchmarks or the tests, or both? [...]

I see lots of diffs, but they are all unrelated. I didn't see any involving smem allocations on the benchmarks. As an example of what I'm calling unrelated:

(screenshot of an unrelated codegen diff omitted)

I think there may be non-determinism in either the allocation of index variables or in index hoisting.

naoyam (Collaborator) commented Aug 28, 2023

Hmm, could you please create an issue with a repro?
