
Enable evaluation of all ternary ops in ExpressionEvaluator#895

Merged
jacobhinkle merged 6 commits into main from evaluate_ternary_ops on Sep 20, 2023
Conversation

@jacobhinkle
Collaborator

This makes clamp, lerp, threshold, and where evaluable by the ExpressionEvaluator with or without PrecomputedValues.
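For reference, here is a minimal Python sketch of the scalar semantics these four ternary ops follow (mirroring the usual ATen definitions; the actual ExpressionEvaluator is C++ inside nvFuser, so names and structure here are illustrative only):

```python
def where(cond, a, b):
    # Select a where cond holds, otherwise b.
    return a if cond else b

def clamp(x, lo, hi):
    # Clip x into the closed interval [lo, hi].
    return min(max(x, lo), hi)

def lerp(start, end, weight):
    # Linear interpolation: start + weight * (end - start).
    return start + weight * (end - start)

def threshold(x, thr, value):
    # Pass x through when it exceeds thr, otherwise replace it with value.
    return x if x > thr else value
```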

@jacobhinkle
Collaborator Author

!build

@jacobhinkle
Collaborator Author

Test failures are unrelated: a random div-op test exceeded its tolerance, and a codegen diff failed due to index non-determinism.

@jacobhinkle jacobhinkle marked this pull request as ready for review September 19, 2023 12:30
@zasdfgbnm
Collaborator

Could you check CPU overhead? I don't think it will make a difference, but just to confirm.

@jacobhinkle
Collaborator Author

Could you check CPU overhead? I don't think it will make a difference, but just to confirm.

Is there a benchmark you've used to check this for ExpressionEvaluator before?

@zasdfgbnm
Collaborator

Could you check CPU overhead? I don't think it will make a difference, but just to confirm.

Is there a benchmark you've used to check this for ExpressionEvaluator before?

I am talking about the same benchmark as #876 (comment)

@jacobhinkle
Collaborator Author

Could you check CPU overhead? I don't think it will make a difference, but just to confirm.

Is there a benchmark you've used to check this for ExpressionEvaluator before?

I am talking about the same benchmark as #876 (comment)

Looks like no change:

```
[----------------------------------------  ---------------------------------------]
           |  " 2023-09-19 5ff26b9e"  |  "evaluate_ternary_ops 2023-09-19 b2897a6e"
1 threads: ------------------------------------------------------------------------
      2    |           17.4           |                     17.5
      4    |           21.5           |                     22.3
      8    |           31.5           |                     31.7
      16   |           50.6           |                     51.3
      32   |           95.1           |                     95.6
      64   |          184.8           |                    186.2
      128  |          407.8           |                    401.6
```

Times are in microseconds (us)

@jacobhinkle jacobhinkle merged commit a26f5ec into main Sep 20, 2023
@jacobhinkle jacobhinkle deleted the evaluate_ternary_ops branch September 20, 2023 15:41
jacobhinkle added a commit that referenced this pull request Sep 27, 2023
This PR normalizes the inputs to `slice` in order to mimic the semantics
of numpy/PyTorch slicing. For an axis with extent `ext`, if we receive a
slice of `(start, stop, step)` we normalize it to `(norm_start,
norm_stop, step)` where
```
norm_start = max(0, start < 0 ? start + ext : start);
norm_stop = max(norm_start, min(ext, stop < 0 ? stop + ext : stop));
```
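As a sanity check, the normalization above can be sketched in Python and shmooed against Python's own built-in slice semantics for step 1 (hypothetical helper name; `ext` assumed positive):

```python
def normalize_slice(start, stop, ext):
    # Mirror the normalization above for step == 1 on an axis of extent ext.
    norm_start = max(0, start + ext if start < 0 else start)
    norm_stop = max(norm_start, min(ext, stop + ext if stop < 0 else stop))
    return norm_start, norm_stop

# Shmoo against Python's built-in slicing on a length-9 list,
# including out-of-range and empty-result cases.
xs = list(range(9))
for start in range(-12, 12):
    for stop in range(-12, 12):
        lo, hi = normalize_slice(start, stop, len(xs))
        assert xs[lo:hi] == xs[start:stop], (start, stop)
```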

Specific changes in this PR:
- Form the above expressions in the `slice` op.
- Add shmoo tests covering various scenarios with both constant slice bounds and bounds passed as Fusion inputs.

A simple Fusion from the input-range test prints like this:
```
Inputs:
  T0_g[ iS0{9} ], float
  i3, nvfuser_index_t
  i4, nvfuser_index_t
Outputs:
  T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ], float

%kernel_math {
b7 = i3 < 0;
i5 = i3 + 9;
i9 = where(b7, i5, i3);
i11 = fmax(0, i9);
b15 = i4 < 0;
i13 = i4 + 9;
i17 = where(b15, i13, i4);
i19 = fmin(9, i17);
i21 = fmax(i11, i19);
T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ]
   = slice( T0_g[ iS0{9} ], { {i11, i21, 1} } )
}

T0_g[ iS0{9} ]
 root domain : (iS0{9})
 contiguity: f
 leaf domain : (iS0{9})
T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ]
 root domain : (iS1{9}rf)
  Resize: iS1{9}rf by ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) and ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) -> ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf
 rfactor domain : (?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf)
 contiguity: t
 leaf domain : (?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf)
```
resulting in the following CUDA kernel:
```c++
__global__ void kernel1(Tensor<float, 1, 1> T0, nvfuser_index_t i0, nvfuser_index_t i1, Tensor<float, 1, 1> T1) {
  nvfuser_index_t i2;
  i2 = i0 + 9;
  bool b3;
  b3 = i0 < 0;
  nvfuser_index_t i4;
  i4 = b3 ? i2 : i0;
  nvfuser_index_t i5;
  i5 = max(0, i4);
  nvfuser_index_t i6;
  i6 = i1 + 9;
  bool b7;
  b7 = i1 < 0;
  nvfuser_index_t i8;
  i8 = b7 ? i6 : i1;
  nvfuser_index_t i9;
  i9 = min(9, i8);
  nvfuser_index_t i10;
  i10 = max(i5, i9);
  nvfuser_index_t i11;
  i11 = (-i5) + i10;
  nvfuser_index_t i12;
  i12 = i5 * T0.alloc_stride[0];
  #pragma unroll 1
  for(nvfuser_index_t i13 = 0; i13 < i11; ++i13) {
    T1[i13]
       = T0[(i12 + (T0.alloc_stride[0] * i13))];
  }
}
```
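To see that this kernel implements the intended semantics, its scalar preamble and copy loop can be simulated in Python and compared against Python slicing (hypothetical names; strides are ignored by using a plain list):

```python
def run_kernel(t0, i0, i1):
    # Mirror kernel1's scalar preamble; the example hardcodes extent 9,
    # which here equals len(t0).
    ext = len(t0)
    i4 = i0 + ext if i0 < 0 else i0  # b3 ? i2 : i0
    i5 = max(0, i4)                  # normalized start
    i8 = i1 + ext if i1 < 0 else i1  # b7 ? i6 : i1
    i9 = min(ext, i8)
    i10 = max(i5, i9)                # normalized stop
    # The loop copies i10 - i5 contiguous elements starting at i5.
    return [t0[i5 + k] for k in range(i10 - i5)]

# Check against Python slicing over a shmoo of start/stop values.
t0 = [float(v) for v in range(9)]
for i0 in range(-12, 12):
    for i1 in range(-12, 12):
        assert run_kernel(t0, i0, i1) == t0[i0:i1], (i0, i1)
```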

This PR does NOT simplify these expressions for non-constant inputs.
This can be done at concretization, which will be left for a follow-up
PR.

Stacked on #892 and #895.

Fixes #439. Fixes #52.

---------

Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>