Skip to content

Preserve symbolic reshape extents after concretization#912

Draft
jacobhinkle wants to merge 5 commits intomainfrom
preserve_symbolic_reshape_extents
Draft

Preserve symbolic reshape extents after concretization#912
jacobhinkle wants to merge 5 commits intomainfrom
preserve_symbolic_reshape_extents

Conversation

@jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Sep 20, 2023

Stacked on #840.

This PR changes how concretization behaves with respect to dynamic reshape operations. When we receive a dynamic reshape, we wind up with a symbolic fusion that looks something like this (DynamicTransform3_CUDA):

Fusion before concretization:
Inputs:
  T0_g[ iS0{i0}, iS1{i2} ], float
  T1_g[ iS2{i3}, iS3{i4} ], float
Outputs:
  T3_g[ iS8{i3}, iS9{i4} ], float

%kernel_math {
T2_l[ ?S6{( (nvfuser_index_t)(i5) )}rf, ?S7{( (nvfuser_index_t)(i6) )}rf ] = view( T0_g[ iS0{i0}, iS1{i2} ] )
T3_g[ iS8{i3}, iS9{i4} ]
   = T1_g[ iS2{i3}, iS3{i4} ]
   + T2_l[ ?S6{( (nvfuser_index_t)(i5) )}rf, ?S7{( (nvfuser_index_t)(i6) )}rf ];
}

Concretized Fusion:
Inputs:  T0_g[ iS0{i0}, iS1{i2} ], float  T1_g[ iS2{i3}, iS3{i4} ], float
Outputs:
  T3_g[ iS8{i3}, iS9{i4} ], float

%kernel_math {T4_l[ iS15{3}rf, iS16{( ceilDiv(( i0 * i2 ), 3) )}rf ] = view( T0_g[ iS0{i0}, iS1{i2} ] )
T3_g[ iS8{i3}, iS9{i4} ]
   = T1_g[ iS2{i3}, iS3{i4} ]
   + T4_l[ iS15{3}rf, iS16{( ceilDiv(( i0 * i2 ), 3) )}rf ];
}

Notice the ceilDiv expressions that are propagated downstream of the concretized reshape. After this PR, the above Fusion concretizes as:

Fusion before concretization:
Inputs:
  T0_g[ iS0{i0}, iS1{i2} ], float
  T1_g[ iS2{i3}, iS3{i4} ], float
Outputs:
  T3_g[ iS8{i3}, iS9{i4} ], float

%kernel_math {
T4_l[ iS17{( (nvfuser_index_t)(i5) )}rf, iS18{( (nvfuser_index_t)(i6) )}rf ] = view( T0_g[ iS0{i0}, iS1{i2} ] )
T3_g[ iS8{i3}, iS9{i4} ]
   = T1_g[ iS2{i3}, iS3{i4} ]
   + T4_l[ iS17{( (nvfuser_index_t)(i5) )}rf, iS18{( (nvfuser_index_t)(i6) )}rf ];
}

Concretized Fusion:
Inputs:
  T0_g[ iS0{i0}, iS1{i2} ], float
  T1_g[ iS2{i3}, iS3{i4} ], float
Outputs:
  T3_g[ iS8{i3}, iS9{i4} ], float

%kernel_math {
T4_l[ iS17{( (nvfuser_index_t)(i5) )}rf, iS18{( (nvfuser_index_t)(i6) )}rf ] = view( T0_g[ iS0{i0}, iS1{i2} ] )
T3_g[ iS8{i3}, iS9{i4} ]
   = T1_g[ iS2{i3}, iS3{i4} ]
   + T4_l[ iS17{( (nvfuser_index_t)(i5) )}rf, iS18{( (nvfuser_index_t)(i6) )}rf ];
}

New missing scalars

This change is stacked on #840 since in some cases, we can lose track of these dynamic extent scalars. For example, in DynamicTransformIssue418_CUDA, we have this symbolic Fusion and concretization:

Fusion before concretization:
Inputs:
  T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ], float
  i5, int64_t
Outputs:
  T7_g[ ?S35{i0}, ?S36{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ], float
  T6_g[ ?S30{i0}, ?S31{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ], float

%kernel_math {
T1_l[ ?S8{i0}rf, ?S9{( i2 / i5 )}rf, ?S10{( (nvfuser_index_t)(i5) )}rf, ?S11{i3}rf, ?S12{i4}rf ] = view( T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ] )
T2_l[ ?S13{i0}, ?S14{( i2 / i5 )}, rS15{( (nvfuser_index_t)(i5) )}, rS16{i3}, rS17{i4} ](Avg),
T3_l[ ?S18{i0}, ?S19{( i2 / i5 )}, rS20{( (nvfuser_index_t)(i5) )}, rS21{i3}, rS22{i4} ](Var),
T4_l[ ?S23{i0}, ?S24{( i2 / i5 )}, rS25{( (nvfuser_index_t)(i5) )}, rS26{i3}, rS27{i4} ](Count)
 = Welford ( T1_l[ ?S8{i0}rf, ?S9{( i2 / i5 )}rf, ?S10{( (nvfuser_index_t)(i5) )}rf, ?S11{i3}rf, ?S12{i4}rf ](Avg),
  allreduce = false )
T7_g[ ?S35{i0}, ?S36{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ]
   = broadcast( T2_l[ ?S13{i0}, ?S14{( i2 / i5 )}, rS15{( (nvfuser_index_t)(i5) )}, rS16{i3}, rS17{i4} ] )
i16 = (nvfuser_index_t)(i5);
d19 = (double)(i16);
d21 = double(1) * d19;
d23 = (double)(i3);
d25 = d21 * d23;
d27 = (double)(i4);
d29 = d25 * d27;
d33 = (double)(0);
d35 = d29 - d33;
d37 = (double)(0);
b39 = d35 >= d37;
d41 = (double)(0);
d43 = where(b39, d35, d41);
d49 = reciprocal(d43);
T5_l[ ?S28{i0}, ?S29{( i2 / i5 )} ]
   = T3_l[ ?S18{i0}, ?S19{( i2 / i5 )}, rS20{( (nvfuser_index_t)(i5) )}, rS21{i3}, rS22{i4} ]
   * d49;
T6_g[ ?S30{i0}, ?S31{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ]
   = broadcast( T5_l[ ?S28{i0}, ?S29{( i2 / i5 )} ] )
}

Concretized Fusion:
Inputs:
  T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ], float
  i5, int64_tOutputs:
  T7_g[ iS51{i0}, iS52{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ], float
  T6_g[ iS57{i0}, iS58{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ], float

%kernel_math {
T8_l[ iS40{i0}, iS47{( i2 / i5 )}rf, iS48{( (nvfuser_index_t)(i5) )}rf, iS42{i3}, iS43{i4} ] = view( T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ] )
T2_l[ iS49{i0}, iS50{( i2 / i5 )}, rS15{( (nvfuser_index_t)(i5) )}, rS16{i3}, rS17{i4} ](Avg),
T3_l[ iS53{i0}, iS54{( i2 / i5 )}, rS20{( (nvfuser_index_t)(i5) )}, rS21{i3}, rS22{i4} ](Var),
T4_l[ iS59{i0}, iS60{( i2 / i5 )}, rS25{( (nvfuser_index_t)(i5) )}, rS26{i3}, rS27{i4} ](Count)
 = Welford ( T8_l[ iS40{i0}, iS47{( i2 / i5 )}rf, iS48{( (nvfuser_index_t)(i5) )}rf, iS42{i3}, iS43{i4} ](Avg),
  allreduce = false )
T7_g[ iS51{i0}, iS52{( i2 / i5 )}, bS37{1}, bS38{1}, bS39{1} ]
   = broadcast( T2_l[ iS49{i0}, iS50{( i2 / i5 )}, rS15{( (nvfuser_index_t)(i5) )}, rS16{i3}, rS17{i4} ] )
i16 = (nvfuser_index_t)(i5);
d19 = (double)(i16);
d21 = double(1) * d19;
d23 = (double)(i3);
d25 = d21 * d23;
d27 = (double)(i4);
d29 = d25 * d27;
d33 = (double)(0);
d35 = d29 - d33;
d37 = (double)(0);
b39 = d35 >= d37;
d41 = (double)(0);
d43 = where(b39, d35, d41);
d49 = reciprocal(d43);
T5_l[ iS55{i0}, iS56{( i2 / i5 )} ]
   = T3_l[ iS53{i0}, iS54{( i2 / i5 )}, rS20{( (nvfuser_index_t)(i5) )}, rS21{i3}, rS22{i4} ]
   * d49;
T6_g[ iS57{i0}, iS58{( i2 / i5 )}, bS32{1}, bS33{1}, bS34{1} ]
   = broadcast( T5_l[ iS55{i0}, iS56{( i2 / i5 )} ] )
}

This is segmented at T8_l and the first segment looks like

Inputs:                                                                                                                          
  T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ], float                                                                              
Outputs:                                                                                                                         
  T8_g[ iS40{i0}, iS47{( i2 / i5 )}rf, iS48{( (nvfuser_index_t)(i5) )}rf, iS42{i3}, iS43{i4} ], float                            
                                                                                                                                 
%kernel_math {                                                                                                                   
T8_g[ iS40{i0}, iS47{( i2 / i5 )}rf, iS48{( (nvfuser_index_t)(i5) )}rf, iS42{i3}, iS43{i4} ] = view( T0_g[ iS0{i0}, iS1{i2}, iS2{
i3}, iS3{i4} ] )                                                                                                                 
} 

This fails to compile since i5 is not included as an input to this segment. In this case we need to inspect the output of the ViewOp to find missing scalars.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant