# Enable evaluation of all ternary ops in ExpressionEvaluator #895

jacobhinkle merged 6 commits into main
!build
Test failures are unrelated: a random div op test violated its tolerance, and the codegen diff index is non-deterministic.
zasdfgbnm left a comment:

Could you check CPU overhead? I don't think it will make a difference, but just to confirm.
Is there a benchmark you've used to check this?
I am talking about the same benchmark as #876 (comment).

Looks like no change:
This PR normalizes the inputs to `slice` in order to mimic the semantics
of numpy/PyTorch slicing. For an axis with extent `ext`, if we receive a
slice of `(start, stop, step)` we normalize it to `(norm_start,
norm_stop, step)` where
```
norm_start = max(0, start < 0 ? start + ext : start);
norm_stop = max(norm_start, min(ext, stop < 0 ? stop + ext : stop));
```
Specific changes in this PR:
- Form the above expressions in the `slice` op.
- Add shmoo tests that test various scenarios with constant and input
size slices.
The simple Fusion in the input range test prints like this:
```
Inputs:
T0_g[ iS0{9} ], float
i3, nvfuser_index_t
i4, nvfuser_index_t
Outputs:
T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ], float
%kernel_math {
b7 = i3 < 0;
i5 = i3 + 9;
i9 = where(b7, i5, i3);
i11 = fmax(0, i9);
b15 = i4 < 0;
i13 = i4 + 9;
i17 = where(b15, i13, i4);
i19 = fmin(9, i17);
i21 = fmax(i11, i19);
T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ]
= slice( T0_g[ iS0{9} ], { {i11, i21, 1} } )
}
T0_g[ iS0{9} ]
root domain : (iS0{9})
contiguity: f
leaf domain : (iS0{9})
T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ]
root domain : (iS1{9}rf)
Resize: iS1{9}rf by ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) and ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) -> ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf
rfactor domain : (?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf)
contiguity: t
leaf domain : (?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf)
```
resulting in the following CUDA kernel:
```c++
__global__ void kernel1(Tensor<float, 1, 1> T0, nvfuser_index_t i0, nvfuser_index_t i1, Tensor<float, 1, 1> T1) {
nvfuser_index_t i2;
i2 = i0 + 9;
bool b3;
b3 = i0 < 0;
nvfuser_index_t i4;
i4 = b3 ? i2 : i0;
nvfuser_index_t i5;
i5 = max(0, i4);
nvfuser_index_t i6;
i6 = i1 + 9;
bool b7;
b7 = i1 < 0;
nvfuser_index_t i8;
i8 = b7 ? i6 : i1;
nvfuser_index_t i9;
i9 = min(9, i8);
nvfuser_index_t i10;
i10 = max(i5, i9);
nvfuser_index_t i11;
i11 = (-i5) + i10;
nvfuser_index_t i12;
i12 = i5 * T0.alloc_stride[0];
#pragma unroll 1
for(nvfuser_index_t i13 = 0; i13 < i11; ++i13) {
T1[i13]
= T0[(i12 + (T0.alloc_stride[0] * i13))];
}
}
```
This PR does NOT simplify these expressions for non-constant inputs.
This can be done at concretization, which will be left for a follow-up
PR.
Stacked on #892 and #895.
Fixes #439. Fixes #52.
---------
Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
This makes `clamp`, `lerp`, `threshold`, and `where` evaluable by the `ExpressionEvaluator`, with or without `PrecomputedValues`.