Skip to content

Preserve expanded broadcasts in reshape#1174

Draft
jacobhinkle wants to merge 23 commits intomainfrom
fix_1126
Draft

Preserve expanded broadcasts in reshape#1174
jacobhinkle wants to merge 23 commits intomainfrom
fix_1126

Conversation

@jacobhinkle
Copy link
Collaborator

Currently, we always "realize" expanded broadcasts (replace them with Iteration domains) during reshape if they appear in a split or merge transform. This is not always necessary, for example we never need to do this for a SplitTransform, and we can avoid it for a MergeTransform if both inner and outer IDs are Broadcast. This PR makes that a reality.

Fixes #1126.

@jacobhinkle jacobhinkle requested a review from wujingyue October 27, 2023 14:15
@jacobhinkle
Copy link
Collaborator Author

!build

@jacobhinkle
Copy link
Collaborator Author

!build --diff

@jacobhinkle jacobhinkle marked this pull request as draft October 27, 2023 14:42
@jacobhinkle
Copy link
Collaborator Author

!build

@jacobhinkle
Copy link
Collaborator Author

!build --diff

@jacobhinkle jacobhinkle marked this pull request as ready for review October 27, 2023 15:41
@jacobhinkle
Copy link
Collaborator Author

There's currently one real failure in GpuViewTest.FusionExpandView2:

C++ exception with description "producer->getMemoryType() == MemoryType::Global INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/
sync_information.cpp":765, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found b
etween TV2 (T2_l[ iblockIdx.x80{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS81{1}, iS79{1}, ithreadIdx.x77{128
} ]) and TV3(T3_l[ iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, iS71{1}, ithreadIdx.x69{128} ] ca_pos( 4 )
). Producer is required to be in Global Memory based on parallelization strategy. RAW flags: (blockIdx.x threadIdx.x)

The fusion looks correct to me (this is a static fusion: no concretization involved):

Inputs:
  T0_g[ bS0{1}, iS1{8} ], float
  T1_g[ iS2{3}, iS3{4}, iS4{8} ], float
Outputs:
  T4_g[ iS12{3}, iS13{4}, iS14{8} ], float

%kernel_math {
i13 = (nvfuser_index_t)(12);
T2_l[ bS5{( (nvfuser_index_t)(12) )}, iS6{8} ] = expand( T0_g[ bS0{1}, iS1{8} ], {i13, 8} )
T3_l[ bS10{3}rf, bS11{( ceilDiv(( (nvfuser_index_t)(12) ), 3) )}rf, iS8{8} ] = view( T2_l[ bS5{( (nvfuser_index_t)(12) )}, iS6{8} ] )
T4_g[ iS12{3}, iS13{4}, iS14{8} ]
   = T3_l[ bS10{3}rf, bS11{( ceilDiv(( (nvfuser_index_t)(12) ), 3) )}rf, iS8{8} ]
   + T1_g[ iS2{3}, iS3{4}, iS4{8} ];
}

Maybe SyncMap and some other lowering machinery is just not used to having splits of Broadcast IDs yet. Looking into it...

Copy link
Collaborator

@wujingyue wujingyue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work! LGTM other than some nit picks.

jacobhinkle and others added 7 commits October 27, 2023 13:34
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
@jacobhinkle
Copy link
Collaborator Author

one real failure in GpuViewTest.FusionExpandView2:

I am not sure if this is related to the failure, but that test introduces something strange. During scheduling (not during the reshape transform done at definition), a merge is performed between an expanded broadcast and an Iteration domain. The result is an Iteration domain with an expanded extent. I hacked IterDomain::toString() to show what I mean here:

T3_l[ iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 8 ) ), 128) ), 1) ), 1) ), EXPANDED_EXTENT=( ceilDiv(( ceilDiv(( ceilDiv(( 
3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, iS71{1}, ithreadIdx.x69{128} ] ca_pos( 4 )       
 root domain : (bS9{( (nvfuser_index_t)(12) ), extent=1}rf, iS8{8})                                                                      
  Outer split: bS9{( (nvfuser_index_t)(12) ), extent=1}rf by factor 3 -> bS10{3, extent=1}rf, bS11{( ceilDiv(( (nvfuser_index_t)(12) ), 3
) ), extent=1}rf, start offset: 0, stop offset: 0                                                                                        
 rfactor domain : (bS10{3, extent=1}rf, bS11{( ceilDiv(( (nvfuser_index_t)(12) ), 3) ), extent=1}rf, iS8{8})                             
 contiguity: n n t          
  Merge: bS11{( ceilDiv(( (nvfuser_index_t)(12) ), 3) ), extent=1}rf and iS8{8} -> iS66{( 1 * 8 ), EXPANDED_EXTENT=( ( ceilDiv(( (nvfuser
_index_t)(12) ), 3) ) * 8 )}                                        
  Merge: bS10{3, extent=1}rf and iS66{( 1 * 8 ), EXPANDED_EXTENT=( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 )} -> iS67{( 1 * ( 1 * 8
 ) ), EXPANDED_EXTENT=( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) )}                                                         
  Split: iS67{( 1 * ( 1 * 8 ) ), EXPANDED_EXTENT=( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) )} by factor 128 -> iS68{( ceilD
iv(( 1 * ( 1 * 8 ) ), 128) ), EXPANDED_EXTENT=( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) )}, ithreadIdx.x6
9{128}, start offset: 0, stop offset: 0                             
  Split: iS68{( ceilDiv(( 1 * ( 1 * 8 ) ), 128) ), EXPANDED_EXTENT=( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 1
28) )} by factor 1 -> iS70{( ceilDiv(( ceilDiv(( 1 * ( 1 * 8 ) ), 128) ), 1) ), EXPANDED_EXTENT=( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (
nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) )}, iS71{1}, start offset: 0, stop offset: 0                                            
  Split: iS70{( ceilDiv(( ceilDiv(( 1 * ( 1 * 8 ) ), 128) ), 1) ), EXPANDED_EXTENT=( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index
_t)(12) ), 3) ) * 8 ) ), 128) ), 1) )} by factor 1 -> iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 8 ) ), 128) ), 1) ), 1) ), 
EXPANDED_EXTENT=( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, s
tart offset: 0, stop offset: 0                                                                                                           
 leaf domain : (iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 8 ) ), 128) ), 1) ), 1) ), EXPANDED_EXTENT=( ceilDiv(( ceilDiv(( 
ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, iS71{1}, ithreadIdx.x69{128})

T3_l is the reshape result and this is printed from gdb where we see the error.

@jacobhinkle
Copy link
Collaborator Author

I have a fix for this test failure but I'm stacking it on #1329 first.

@jacobhinkle jacobhinkle marked this pull request as draft November 17, 2023 19:37
jacobhinkle added a commit that referenced this pull request Nov 20, 2023
…#1329)

Currently, `MergeTransform` and `SplitTransform`, which are the classes
used to perform reshape operations, do not re-use the splitting and
merging operations used in scheduling, namely the static methods
`IterDomain::{split,merge}`. This leads to needing to maintain changes
in both as I recently found out when debugging #1174. We can actually
already notice they have diverged with respect to
`IndexType::GatherScatter` handling.

This PR simply calls the static `IterDomain` method from each of the
reshape `*Transform` classes. Since those classes also need to set the
rfactor-product flag on the produced `IterDomain`s, I added a flag to
the static methods that is only used during reshape.

---------

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
@jacobhinkle jacobhinkle marked this pull request as ready for review November 20, 2023 15:46
@jacobhinkle
Copy link
Collaborator Author

!build

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for this improvement!

@jacobhinkle
Copy link
Collaborator Author

!build

@jacobhinkle
Copy link
Collaborator Author

jacobhinkle commented Dec 1, 2023

Maybe SyncMap and some other lowering machinery is just not used to having splits of Broadcast IDs yet. Looking into it...

SyncMap is failing because useSameIndex fails, because the parallelized indices between T2 and T3 (both BIDx and TIDx) are not being LOOP mapped. The transforms involved:

Inputs:
  T0_g[ iS96{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iS97{1}, iS95{1}, iS93{128} ], float
  T1_g[ iS64{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( 4 * 8 ) ), 128) ), 1) ), 1) )}, iS65{1}, iS63{1}, iS61{128} ], float
Outputs:
  T4_g[ iblockIdx.x40{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( 4 * 8 ) ), 128) ), 1) ), 1) )}, iUS41{1}, iS39{1}, ithreadIdx.x37{128} ] ca_pos( 2 ) produce_pos( 4 ), float

%kernel_math {
T5_l[ iblockIdx.x88{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS89{1}, iS87{1}, ithreadIdx.x85{128} ]
   = Set( T0_g[ iS96{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iS97{1}, iS95{1}, iS93{128} ], cache_op=AllLevels )
i13 = (nvfuser_index_t)(12);
T2_l[ iblockIdx.x80{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS81{1}, iS79{1}, ithreadIdx.x77{128} ] = expand( T5_l[ iblockIdx.x88{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS89{1}, iS87{1}, ithreadIdx.x85{128} ], {i13, 8} )
T3_l[ iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, iS71{1}, ithreadIdx.x69{128} ] ca_pos( 4 ) = view( T2_l[ iblockIdx.x80{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS81{1}, iS79{1}, ithreadIdx.x77{128} ] )
T6_l[ iblockIdx.x56{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( 4 * 8 ) ), 128) ), 1) ), 1) )}, iUS57{1}, iS55{1}, ithreadIdx.x53{128} ] ca_pos( 2 )
   = Set( T1_g[ iS64{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( 4 * 8 ) ), 128) ), 1) ), 1) )}, iS65{1}, iS63{1}, iS61{128} ], cache_op=Streaming )
T7_l[ iblockIdx.x48{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( 4 * 8 ) ), 128) ), 1) ), 1) )}, iUS49{1}, iS47{1}, ithreadIdx.x45{128} ] ca_pos( 4 ) produce_pos( 4 )
   = T3_l[ iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, iS71{1}, ithreadIdx.x69{128} ] ca_pos( 4 )
   + T6_l[ iblockIdx.x56{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( 4 * 8 ) ), 128) ), 1) ), 1) )}, iUS57{1}, iS55{1}, ithreadIdx.x53{128} ] ca_pos( 2 );
T4_g[ iblockIdx.x40{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( 4 * 8 ) ), 128) ), 1) ), 1) )}, iUS41{1}, iS39{1}, ithreadIdx.x37{128} ] ca_pos( 2 ) produce_pos( 4 )
   = Set( T7_l[ iblockIdx.x48{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( 4 * 8 ) ), 128) ), 1) ), 1) )}, iUS49{1}, iS47{1}, ithreadIdx.x45{128} ] ca_pos( 4 ) produce_pos( 4 ), cache_op=Streaming )
}

...
T5_l[ iblockIdx.x88{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS89{1}, iS87{1}, ithreadIdx.x85{128} ]
 root domain : (bS20{1}, iS21{8})
 contiguity: n t
  Outer split: bS20{1} by factor 3 -> bS30{3}, bS31{( ceilDiv(1, 3) )}, start offset: 0, stop offset: 0
  Merge: bS31{( ceilDiv(1, 3) )} and iS21{8} -> iS82{( ( ceilDiv(1, 3) ) * 8 )}
  Merge: bS30{3} and iS82{( ( ceilDiv(1, 3) ) * 8 )} -> iS83{( 3 * ( ( ceilDiv(1, 3) ) * 8 ) )}
  Split: iS83{( 3 * ( ( ceilDiv(1, 3) ) * 8 ) )} by factor 128 -> iS84{( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) )}, ithreadIdx.x85{128}, start offset: 0, stop offset: 0
  Split: iS84{( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) )} by factor 1 -> iS86{( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) )}, iS87{1}, start offset: 0, stop offset: 0
  Split: iS86{( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) )} by factor 1 -> iblockIdx.x88{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS89{1}, start offset: 0, stop offset: 0
 leaf domain : (iblockIdx.x88{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(1, 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS89{1}, iS87{1}, ithreadIdx.x85{128})
T2_l[ iblockIdx.x80{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS81{1}, iS79{1}, ithreadIdx.x77{128} ]
 root domain : (bS5{( (nvfuser_index_t)(12) )}, iS6{8})
 contiguity: n t
  Outer split: bS5{( (nvfuser_index_t)(12) )} by factor 3 -> bS28{3}, bS29{( ceilDiv(( (nvfuser_index_t)(12) ), 3) )}, start offset: 0, stop offset: 0
  Merge: bS29{( ceilDiv(( (nvfuser_index_t)(12) ), 3) )} and iS6{8} -> iS74{( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 )}
  Merge: bS28{3} and iS74{( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 )} -> iS75{( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) )}
  Split: iS75{( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) )} by factor 128 -> iS76{( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) )}, ithreadIdx.x77{128}, start offset: 0, stop offset: 0
  Split: iS76{( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) )} by factor 1 -> iS78{( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) )}, iS79{1}, start offset: 0, stop offset: 0
  Split: iS78{( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) )} by factor 1 -> iblockIdx.x80{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS81{1}, start offset: 0, stop offset: 0
 leaf domain : (iblockIdx.x80{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS81{1}, iS79{1}, ithread
Idx.x77{128})
T3_l[ iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, iS71{1}, ithreadIdx.x69{128} ] ca_pos( 4 )
 root domain : (bS9{( (nvfuser_index_t)(12) )}rf, iS8{8})
  Outer split: bS9{( (nvfuser_index_t)(12) )}rf by factor 3 -> bS10{3}rf, bS11{( ceilDiv(( (nvfuser_index_t)(12) ), 3) )}rf, start offset: 0, stop offset: 0
 rfactor domain : (bS10{3}rf, bS11{( ceilDiv(( (nvfuser_index_t)(12) ), 3) )}rf, iS8{8})
 contiguity: n n t
  Merge: bS11{( ceilDiv(( (nvfuser_index_t)(12) ), 3) )}rf and iS8{8} -> iS66{( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 )}
  Merge: bS10{3}rf and iS66{( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 )} -> iS67{( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) )}
  Split: iS67{( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) )} by factor 128 -> iS68{( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) )}, ithreadIdx.x69{128}, start offset: 0, stop offset: 0
  Split: iS68{( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) )} by factor 1 -> iS70{( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) )}, iS71{1}, start offset: 0, stop offset: 0
  Split: iS70{( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) )} by factor 1 -> iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, start offset: 0, stop offset: 0
 leaf domain : (iblockIdx.x72{( ceilDiv(( ceilDiv(( ceilDiv(( 3 * ( ( ceilDiv(( (nvfuser_index_t)(12) ), 3) ) * 8 ) ), 128) ), 1) ), 1) )}, iUS73{1}, iS71{1}, ithreadIdx.x69{128})

Before this PR, all the BIDx IDs are loop mapped together, as are most of the TIDx IDs. Now it seems the ones from T2 cannot be loop mapped to either T5 (its producer) or T3 (its consumer).

@jacobhinkle
Copy link
Collaborator Author

In this case, the reshape is splitting an expanded dimension, so we could handle this the way we handle size-1 extents in the output of reshape: by squeezing the input dimension then appending a BroadcastOp after the reshape. In this case we'd also need an ExpandOp. It's slightly nuanced in this case, since sometimes the reshape will need to merge a broadcast/expanded IterDomain with an Iteration domain, and in those cases we cannot split it off into a separate op. I will look into adding this: we might need to track expand_transforms_ as in addition to squeeze, broadcast, and view (merge/split) transforms in AnalyzeViewTransformation.

@jacobhinkle
Copy link
Collaborator Author

squeezing the input dimension then appending a BroadcastOp after the reshape.

This has the added difficulty that currently we are not able to squeeze expanded dimensions.

jacobhinkle added a commit that referenced this pull request Jan 26, 2024
See comment in arith.cpp for details.

One controversial change here is to allow squeezing expanded dimensions,
both in our IR's `SqueezeOp` and in the user-facing functions `squeeze`.
This results in actually removing those dimensions. This behavior diverges from
PyTorch, whose `squeeze` command will ignore requested squeezes if the
size is not 1 regardless of whether that dimension is expanded. I'm
happy to discuss this change and potentially take another course, but I
think we do need to be able to remove expanded axes (see
#1174 (comment) for
another case where I encountered this limitation).

Fixes #1678
jacobhinkle added a commit that referenced this pull request Feb 2, 2024
See comment in arith.cpp for details.

One controversial change here is to allow squeezing expanded dimensions,
both in our IR's `SqueezeOp` and in the user-facing functions `squeeze`.
This results in actually removing those dimensions. This behavior
diverges from PyTorch, whose `squeeze` command will ignore requested
squeezes if the size is not 1 regardless of whether that dimension is
expanded. I'm happy to discuss this change and potentially take another
course, but I think we do need to be able to remove expanded axes (see
#1174 (comment) for
another case where I encountered this limitation).

Fixes #1678
cowanmeg added a commit to samnordmann/Fuser that referenced this pull request Feb 13, 2024
* print bandwidth when perf_debug_verbose is true (NVIDIA#1689)

print bandwidth when `perf_debug_verbose` is true.

* in vectorization validation, add err msg if tv has no definition (NVIDIA#1690)

check the existence of tv definition in vectorization validation

* Accomodate Reduction IterDomains when concretizing reshape extents (NVIDIA#1692)

We register extents for concretization when we concretize reshape. In
order to do that, we line up `IterDomain`s in the symbolic reshaped TV
and the new, concretized one. In cases where the concretized reshape is
trivial, such as when the output shape is the same as the input, we do
not create a new TV. In those cases, we will have the input to the
original `ViewOp` as the concretized output. That input TV might have
reduction domains, as in the provided test, in which case we need to
filter those out when doing this alignment. This small PR just
implements that filtering.

Fixes NVIDIA#1691.

* `MmaOp::evaluate` method (NVIDIA#1675)

* Fix some typos. (NVIDIA#1700)

* `torch.compile` and `eager` benchmarks for `softmax` (NVIDIA#1670)

Adds `torch.compile` and `eager` baseline benchmarks to be used in
weekly benchmark runs.
Issue NVIDIA#1668.

* Add a test for fusions with no inputs. (NVIDIA#1709)

As a follow up to
NVIDIA#1696 (comment).

* Double the size of the fusion cache to workaround a CI issue. (NVIDIA#1702)

By just removing entries when it fills up.

* Check that the reduced axis is sharded on producer in isLowerableToCommunication (NVIDIA#1695)

Currently, a reduction is lowerable to a communication iff only one axis
is reduced and this axis is sharded across devices on the **producer**
side.
Before this patch, we would mistakenly check that the axis is sharded on
**consumer** side, which led to some runtime assert error.

* Add blank impl of isLowerableToCommunication. (NVIDIA#1698)

isLowerableToCommunication is used in a few places to print error
messages or short-circuit loops. Those places appear to be places that
are intended to largely be used behind the distributed path. It's easier
to just define the API instead of trying to conditionalize all the use
sites and invent non-USE_DISTRIBUTED behavior.

* Multidevice segmenter (NVIDIA#1696)

# What
Add an option in the segmenter to segment resharding Expr in separate
singleton segment.
To trigger it, set the segmenter's options as follows:
```
    SegmentCandidateFinderOptions options{
        .run_translate_welford = false,
        .run_combine_reductions = false,
        .run_herrmann_merge = true,
        .run_final_merge = true,
        .only_segment_resharding_exprs = true};
```
and use the segmenter as follows with any (possibly dummy) inputs:
```
KernelArgumentHolder dummy_inputs;
auto segmented_fusion = SegmentCandidateFinder::segment(std::move(fusion), dummy_inputs, options);
```
If `only_segment_resharding_exprs` is set to `false` (which is the case
by default), the behavior of the segmenter is unchanged.


We also provide a quite wide testing suite to validate our
implementation.

# Why 
Resharding Exprs need to be handled differently than other Exprs because
we want them to result in posting a network collective from the host.
Therefore those expressions cannot (for now) be fused to any kernel. For
this reason, we need those Expr to be segmented before and after.

# How
_**Remark:** For now, the segmenter is only used [at one place before
scheduling and compiling the
fusion](https://github.com/NVIDIA/Fuser/blob/1603f39bab8c1bbe12e38f2b5de53dec3b7cc373/csrc/kernel_cache.cpp#L990)._

Recall that the segmenter first creates as many segments as there are
Expr and then tries to merge the neighbour segments incrementally in an
eager manner. The method
```
bool SegmentCandidateFinder::codeGenSupportedMerge(
    SegmentedGroup* group1,
    SegmentedGroup* group2) 
```
returns whether two groups can be merged (i.e. fused into one kernel). 

With the current patch, if
`SegmentCandidateFinderOptions::only_segment_resharding_exprs` is set to
`true`, then the usual behavior of `codeGenSupportedMerge` is bypassed
and the function returns whether one Expr among the groups is
resharding.

Because this segmentation shouldn't depend on the inputs data, we use
default (aka empty) `KernelArgumentHolder`, from which it is invalid to
instantiate a `SchedulerRuntimeInfo runtime_info_`. For this reason, we
had to make the latter attribute optional.

# Future/other directions

Another way to achieve the same result is to manually add segment bounds
surrounding the resharding Exprs as was suggested by @wujingyue here
NVIDIA#1571

The current implementation looks a bit "hacky" and should be be
integrated more properly once multidevice schedulers are implemented
and/or the segmenter is refactored.

Later, we might wanna be able to fuse communications and computes and
also communications between them. This would require a more advanced
segmenter and scheduler, but hopefully this patch could serve as a good
basis

# Example:
consider the fusion:
```
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* tv0 = makeContigTensor({4});
  fusion->addInput(tv0);
  TensorView* tv1 = sum(tv0,{3});
  TensorView* tv2 = set(tv1);
  TensorView* tv3 = sum(tv2, {2});
  fusion->addOutput(tv3);
```

Manually scheduled as follows:
```
  DeviceMesh mesh ({0,1,2,3})
  for (auto tv : {tv0, tv1, tv2, tv3}) {
    tv->setDeviceMesh(mesh);
  }
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::DIDx);
```
This scheduling implies that
- `tv0` and `tv1` are fully sharded on the devices {0,1,2,3}
- `tv2` and `tv3` are fully replicated on those same devices
- consequently, the "set" operation on the line `tv2 = set(tv1)`
actually embedds an "AllGather" network collective. This Expr is
resharding while all the other exprs are not. We thus excpect this
expression to constitute an unmergeable segment.

The segmenter in this situation with the
option`SegmentCandidateFinderOptions::only_segment_resharding_exprs` set
to `true` will result in three segments:
- Compute segment 1: with the expr `tv1 = sum(tv0,{3})`
- Communication segment 1:  with the expr `tv2 = set(tv1)`
- Compute segment 2: with the expr `tv3 = sum(tv2, {2})`

* Vectorization Factor patch for computeInfoC2P with Broadcast in mapped IterDomain (NVIDIA#1625)

Fixes NVIDIA#1567

This PR patches vectorization factor in
`ContiguousInnerDimensionsMapper::computeInfoC2P`.

Handling of resolved broadcast dimension should be made on mapped
consumer tensors' from_ids, instead of the root_domain order. Added a
few tests per @zasdfgbnm 's suggestion:

```
Case 0:
T2[1024, 2, 512] = T0[1024, 2, 1] + T1[1024, 2, 512]
allocation = rfactor
--> T0 has no vectorization

Case 1:
T2[1024, 512, 2] = T0[1024, 1, 2] + T1[1024, 512, 2]
allocation = rfactor
--> T0 has vectorization 2

Case 2:
T2[1024, 512, 2] = T0[1024, 1, 2] + T1[1024, 512, 2];
T3[512, 1024, 2] = transpose(T2[1024, 512, 2])
allocation = rfactor
*except T1 has stride_order {1, 2, 0}
--> T0 has vectorization 4

Case 3:
T2[512, 1024, 2] = T0[1, 1024, 2] + T1[512, 1024, 2]
T3[1024, 512, 2] = transpose(T2[512, 1024, 2])
allocation = rfactor
--> T0 has vectorization 2
```

---------

Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* transpose scheduler fix: reduction IterDomain on input tensors (NVIDIA#1661)

Fixes NVIDIA#1659 

Reorders reduction IterDomain so it won't interfere with
scheduling tiling from transpose scheduler.

* Convert reduction of expanded dims to squeeze (NVIDIA#1679)

See comment in arith.cpp for details.

One controversial change here is to allow squeezing expanded dimensions,
both in our IR's `SqueezeOp` and in the user-facing functions `squeeze`.
This results in actually removing those dimensions. This behavior
diverges from PyTorch, whose `squeeze` command will ignore requested
squeezes if the size is not 1 regardless of whether that dimension is
expanded. I'm happy to discuss this change and potentially take another
course, but I think we do need to be able to remove expanded axes (see
NVIDIA#1174 (comment) for
another case where I encountered this limitation).

Fixes NVIDIA#1678

* Make sure ValGraphs are created deterministically (NVIDIA#1714)

While I was working on NVIDIA#32, I sometimes saw non-deterministic results.
Hope this is the only source of non-determinism.

* Fix squeeze-related errors (NVIDIA#1717)

This fixes current failures in `pytest_ops.py -k squeeze` and some
integration failues.

This restores our previous semantics for squeeze, which **do not match
PyTorch**. Namely, if squeeze is provided a dimension that cannot be
squeezed, we will always raise an error.

* NVFUSER_DISTRIBUTED instead of USE_DISTRIBUTED (NVIDIA#1711)

* Add the missing `clang-format on` and reformat. (NVIDIA#1722)

* Print a newline before the header. (NVIDIA#1720)

* Associate each fusion cache with its local rank in distributed setting. (NVIDIA#1699)

### Problem:
Currently, automatic serialization saves a single cache regardless of
the number of devices. In a distributed setting, each process restores
its fusion cache from the same common workspace. However, this workspace
only contains the CUDA kernels for a single device. The remaining
processes must recompile the kernels for their devices.

### Solution:
A separate process is created for each device with `ddp` or `fsdp` and
each process contains a separate `FusionCache`. This PR associates each
fusion cache with its local rank in a distributed setting, allowing
automatic serialization to create a separate workspace for each device.
During deserialization, each process loads the workspace associated with
its local rank.

* Vectorized serial grid reduction (NVIDIA#1528)

This change allows us to use vectorized loads/stores in
`serialReductionStep`. The generated kernel now looks like
```c++
  NVFUSER_UPDATE_MAGIC_ZERO;                                        
  grid_sync::blockSerializeWait<false, false, true>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll                                                                                                                         
  for(nvfuser_index_t i16 = 0; i16 < 4LL; ++i16) {                                                                                           nvfuser_index_t i17;                                                                                                                 
    i17 = 32LL * i16;                                                                                                                        nvfuser_index_t i18;                                                                                                                 
    i18 = 4096LL * i16;                                                                                                                  
    nvfuser_index_t i19;                                                                                                                 
    i19 = i5 + i18;                                                                                                                      
    nvfuser_index_t i20;                                                                                                                 
    i20 = -i18;                                                                                                                          
    #pragma unroll                                                                                                                       
    for(nvfuser_index_t i21 = 0; i21 < 8LL; ++i21) {                                                                                     
      nvfuser_index_t i22;                                                                                                               
      i22 = 512LL * (i21 + nvfuser_zero);                                                                                                
      Array<float, 4LL, 4> T3;                                                                                                           
      T3.set(float(0.000000000e+00f));                                                                                                   
      reduction::serialReductionStep</*vec_size=*/4>(                                                                                    
        &T3[0LL],                                                                                                                        
        &T2[(i17 + (4LL * i21))],                                                                                                        
        0.000000000e+00f,                                                                                                                
        &T6[(i19 + i22)],                                                                                                                
        [](float &a, float b) { a = a + b; },                                                                                            
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,                                                                                                                                    true);                                                                                                                           
      if ((b7 && (i6 < (i20 - i22)))) {                                                                                                  
        loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T1[(i19 + i22)], &T3[0LL]);                                    
      }                                                                                                                                  
    }                                                                                                                                    
  }                                                                                                                                      
  grid_sync::blockSerializeRelease<false, false, true>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);            
  NVFUSER_UPDATE_MAGIC_ZERO;       
```

* removing out-dated assert on python API (NVIDIA#1724)

removing out-dated asserts in python API `define_vector`;
adding a tests verifying the behavior

* make ci green again (NVIDIA#1730)

skip failing test.

Please enable it once we patch NVIDIA#1728

* Remove unnecessary `MATCHER_P`. (NVIDIA#1729)

* Fix Issue NVIDIA#1734 (NVIDIA#1735)

Closes Issue NVIDIA#1734

* Rename `AliasType` -> `AllocationType` (NVIDIA#1732)

* Skip executing a kernel if it's empty. (NVIDIA#1723)

I could change `compileFusion` to skip compilation as well. It turned
out to be more complicated than I expected, so I took the easier route
to skip just execution, which is at least an incremental improvement.

* don't cache slice input tv (NVIDIA#1705)

If the input tv is used by slice, don't cache it.
Fix NVIDIA#1697

* Make `MmaOp::evaluate` return output of the same dtype as `MmaOp` (NVIDIA#1733)

* Turing/Ampere Mma tests without `BroadcastOp` (NVIDIA#1672)

This PR renames `matmulAtInput` into `matmulAtInput2D`, explicitly
showing that it generates 2D inputs. This PR also adds a
`matmulAtInput3DTuring`, which is used to generate the 3D fusion inputs
(for example `[M, 1, K]` and `[1, K, N]`) for matmul. The `MmaTest` for
Turing and Ampere is modified to exclude the `BroadcastOp` and use the
3D version for generating fusion inputs. This is only the initial step
for making `scheduleMatmul` schedule a fusion not containing
`BroadcastOp`, I intentionally keep it small. Other changes will be
added in followup PRs.

Fixes NVIDIA#1628

* io_alias_ const update (NVIDIA#1740)

* Add benchmarks for RoPE. (NVIDIA#1739)

This PR adds two implementations of the RoPE module and benchmarks them
for NVIDIA#1597.

`rope_with_cat_fusion` mimics the Hugging Face implementation.
`rope_without_cat_fusion` implements an idea from @nikitaved to avoid
concatenation. Even though it looks difficult for the compiler to do it
all automatically, it's still useful to keep a record of the idea.

As a side change, I made `fd.define_tensor` to accept empty contiguity.

* Make nvfuser matmul benchmarks HSH instead of HSS (NVIDIA#1712)

This matches the `at::matmul` baselines.

This PR also adds a few more problem sizes, and runs each eagermode
baseline with and without FP16 reduction allowed.

* Reduce number of `MmaTest`s (NVIDIA#1738)

This PR is stacked on top of NVIDIA#1672

Turing/Ampere mma is only TN, so it makes no sense to test other layouts
in `MmaTest`s. These tests are intended to test mma instructions,
`ldmatrix` and `ldmatrix.trans` is tested separately in other unit
tests. Similar for `HopperRS` tests.

* Weekly Benchmarks Input Range (NVIDIA#1708)

* Rename axes= to dims= in frontend (NVIDIA#1741)

Currently we accept `axes=` for some ops like `fd.ops.sum` and `dims=`
for others like `fd.ops.squeeze`.

This is a small attempt to make the frontend arguments more consistent.
This change renames the `axis=` kwarg to `dim=` and the same for `axes=`
-> `dims=`.

I think we're free to set our own convention, but for reference:
- PyTorch uses `dim=` in most places and accepts either a single dim or
multiple using that same argument name, where applicable.
- Numpy uses `axis=` and, like PyTorch, accepts a list where applicable.
- `jax.lax` uses `dimensions=`

* Avoid unused smem workspace for serial grid reductions (NVIDIA#1727)

GridReduction can be lowered to either `gridReduce` or
`serialReductionStep`. `gridReduce` requires a smem workspace in order
to use multiple threads to aggregate partial sums. However,
`serialReductionStep` does not coordinate among threads and has no use
for a workspace. This change simply disables allocating that little bit
of extra shared memory if our only grid reductions are serial, which
currently only happens in split-K GEMM.

This reduces the smem allocated in a simple test from 16896 B to 16384 B
(about 97%). More importantly, this makes the computation in
`mma_utils::generateSharedMemoryEpilogueHeuristics()` more accurate.
Tests are updated to check that this computation is accurate.

The change in `kernel.cpp` is responsible for reducing actual smem usage
for split-K. The changes to `mma_utils` and `test_gpu_tensorcore.cpp`
are needed for adding testing that our expected smem usage matches the
actual usage.

* Issue NVIDIA#1748 (NVIDIA#1749)

Closes Issue NVIDIA#1748.
Apart from `c10::cuda::GetDevice`, no other functionality seems
affected.

* Rename `axes` to `dims` in benchmarks fusion definitions (NVIDIA#1751)

Changes the kwarg `axes` to `dims` following the API change in PR NVIDIA#1741.

* Bump matmul benchmark checkMatch() tolerance (NVIDIA#1747)

This is necessary due to recent switch to HSH

Fixes NVIDIA#1746

* linter

* change guard USE_DISTRIBUTED to NVFUSER_DISTRIBUTED in test/test_multidevice_sharding.cpp

* linting

* linter and cleanup

* remove allocator.h/cpp files

* Device index patch (NVIDIA#1752)

Fixes NVIDIA#1748 

guard c10::cuda::GetDevice API change on TORCH_VERSION

with this change, it ensures that we can build against stable release `<
2.2.0`, as well as TOT after
pytorch/pytorch#119142

For 2.3.0 nightly, if someone accidentally checkout a commit before the
patch, the build will still fail.

* fixing multidevice build (NVIDIA#1753)

API change coming from pytorch/pytorch#119421

* patching API GUARD (NVIDIA#1754)

patching API version guard so we'll still be able to build against older
pytorch version.

* Add a visitor for ValGraph (NVIDIA#1713)

Used in the loop promotion analysis. Extracted from NVIDIA#32

* empty commit for triggering CI

---------

Co-authored-by: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Tom Fogal <60981+tfogal@users.noreply.github.com>
Co-authored-by: jjsjann123 <jiej@nvidia.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
Co-authored-by: Meghan Cowan <mcowan@nvidia.com>
Co-authored-by: Ryan Spring <rspring@nvidia.com>
This breaks the tests in a way that is probably related to the breakage
of FusionExpandView2. Note that the kernels compile and run in this case
and the output is the proper size and stride. However, the kernel is
incorrect.
Comment on lines +2320 to +2329
TensorView* in = makeContigConcreteTensor({4, 5});
fusion.addInput(in);
TensorView* out = broadcast(in, {false, false, true});
out = expand(
out,
{IrBuilder::create<Val>(4),
IrBuilder::create<Val>(5),
IrBuilder::create<Val>(6)});
out = reshape(out, {4, 5, 6}, {4, 5, 2, 3});
fusion.addOutput(out);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what schedulePointwise gives us:

T5_l[ iblockIdx.x40{( ceilDiv(( ceilDiv(( ceilDiv(( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) ), 128) ), 1) ), 1) )}, iUS41{1}, iS39{1}, ithreadIdx.x37{128} ] ca_pos( 4 ) produce_pos( 4 )
 root domain : (iS8{4}, iS9{5}, bS11{1 ex 6}rf)
  Outer split: bS11{1 ex 6}rf by factor 2 -> bS12{1 ex 2}rf, bS13{1 ex ( ceilDiv(6, 2) )}rf, start offset: 0, stop offset: 0
 rfactor domain : (iS8{4}, iS9{5}, bS12{1 ex 2}rf, bS13{1 ex ( ceilDiv(6, 2) )}rf)
 contiguity: t t n n
  Merge: bS12{1 ex 2}rf and bS13{1 ex ( ceilDiv(6, 2) )}rf -> bS33{( 1 * 1 ) ex ( 2 * ( ceilDiv(6, 2) ) )}
  Merge: iS9{5} and bS33{( 1 * 1 ) ex ( 2 * ( ceilDiv(6, 2) ) )} -> iS34{( 5 * ( 2 * ( ceilDiv(6, 2) ) ) )}
  Merge: iS8{4} and iS34{( 5 * ( 2 * ( ceilDiv(6, 2) ) ) )} -> iS35{( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) )}
  Split: iS35{( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) )} by factor 128 -> iS36{( ceilDiv(( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) ), 128) )}, ithreadIdx.x37{128}, start offset: 0, stop offset: 0
  Split: iS36{( ceilDiv(( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) ), 128) )} by factor 1 -> iS38{( ceilDiv(( ceilDiv(( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) ), 128) ), 1) )}, iS39{1}, start offset: 0, stop offset: 0
  Split: iS38{( ceilDiv(( ceilDiv(( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) ), 128) ), 1) )} by factor 1 -> iblockIdx.x40{( ceilDiv(( ceilDiv(( ceilDiv(( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) ), 128) ), 1) ), 1) )}, iUS41{1}, start offset: 0, stop offset: 0
 leaf domain : (iblockIdx.x40{( ceilDiv(( ceilDiv(( ceilDiv(( 4 * ( 5 * ( 2 * ( ceilDiv(6, 2) ) ) ) ), 128) ), 1) ), 1) )}, iUS41{1}, iS39{1}, ithreadIdx.x37{128})

That looks OK to me but this is the CUDA kernel that's generated:

__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 4, 4> T3) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128LL * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i1;
  i1 = i0 % 30LL;
  if (((i0 < 120LL) && (i1 < 5LL))) {
    float T4[1LL];
    T4[0LL] = 0LL;
    T4[0LL]
       = T0[((5LL * (i0 / 30LL)) + i1)];
    float T1[1LL];
    T1[0LL]
       = T4[0LL];
    float T2[1LL];
    T2[0LL]
       = T1[0LL];
    float T5[1LL];
    T5[0LL]
       = T2[0LL];
    T3[i0]
       = T5[0LL];
  }
}

i0 is the index into a 120-element vector, as if the [4, 5, 6] sized output was unexpanded. It is used to index the output store. The load and store are predicated such that i0 % 30 < 5, but that's not enough to guarantee a proper store to an array of size [4, 5]; for example, i0 = 90 corresponds to i1 = i0 % 30 == 0 so we will load element T0[5 * (i0 / 30) + i1] = T0[15] and write it to T3[90], which is actually out of bounds assuming we have allocated the output properly as size 20. We don't see this prior to this PR since the output is unexpanded and the i1 < 5 predicate is absent.

I think this might show that expanded dst indexing is not working properly, as I would expect us to use the same index for T0 and T3 here.

@naoyam
Copy link
Collaborator

naoyam commented Apr 3, 2024

Could you please summarize what the current issues are?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Split/merged expanded broadcast dimensions can remain broadcast.

4 participants