
Missing Reduction: maxf-variant #11

@manbearian

Created from #7.

This reduction sequence is not yet supported. I'm not sure what the original Triton code was, but it looks like a variant of maxf.
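The combine region in the error output below can be decoded using the `arith.cmpf` predicate enumeration. Predicate 2 is `ogt` (ordered greater-than, false if either operand is NaN), and predicate 13 is `une` (unordered or not-equal), so `cmpf une %arg11, %arg11` is true exactly when `%arg11` is NaN. The body therefore computes `select((a > b) or isnan(a), a, b)`, which is a NaN-propagating max. What follows is a minimal Python sketch of those semantics only. It is not Triton or triton-shared code, and `combine` and `reduce_rows` are hypothetical names chosen for illustration.

```python
import math
from functools import reduce
from typing import List


def combine(a: float, b: float) -> float:
    """Mirror of the tt.reduce body: select((a ogt b) or (a une a), a, b)."""
    a_gt_b = a > b      # arith.cmpf ogt: False whenever either operand is NaN
    a_is_nan = a != a   # arith.cmpf une (a, a): True only when a is NaN
    # arith.ori + arith.select: keep a if it is larger or NaN, otherwise take b
    return a if (a_gt_b or a_is_nan) else b


def reduce_rows(x: List[List[float]]) -> List[float]:
    """Reduce along axis 1, like tensor<16x128xf32> -> tensor<16xf32>."""
    return [reduce(combine, row) for row in x]


if __name__ == "__main__":
    rows = [[1.0, 3.0, 2.0], [-1.0, math.nan, 5.0]]
    print(reduce_rows(rows))
```

A NaN in either operand position wins, so a row reduces to NaN whenever it contains a NaN. That matches `arith.maximumf`. The one difference is signed zero: on a tie between `+0.0` and `-0.0` this body returns its second operand, whereas `arith.maximumf` treats `-0.0` as less than `+0.0`.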

repros.zip

triton-shared-opt -triton-to-linalg 15.mlir
triton-shared-opt -triton-to-linalg 18.mlir
triton-shared-opt -triton-to-linalg 20.mlir
triton-shared-opt -triton-to-linalg 22.mlir
triton-shared-opt -triton-to-linalg 29.mlir
triton-shared-opt -triton-to-linalg 31.mlir
triton-shared-opt -triton-to-linalg 39.mlir
triton-shared-opt -triton-to-linalg 46.mlir
triton-shared-opt -triton-to-linalg 47.mlir
triton-shared-opt -triton-to-linalg 58.mlir
triton-shared-opt -triton-to-linalg 70.mlir
triton-shared-opt -triton-to-linalg 75.mlir

Error output:

+++/home/ianb/test/ttirs_linalg_failed/15.mlir
/home/ianb/test/ttirs_linalg_failed/15.mlir:30:11: error: Only support lowering reduction with body containing 1 max(i/f) or addf.
    %21 = "tt.reduce"(%20) <{axis = 1 : i32}> ({
          ^
/home/ianb/test/ttirs_linalg_failed/15.mlir:30:11: note: see current operation: 
%65 = "tt.reduce"(%64) <{axis = 1 : i32}> ({
^bb0(%arg11: f32, %arg12: f32):
  %80 = "arith.cmpf"(%arg11, %arg12) <{predicate = 2 : i64}> : (f32, f32) -> i1
  %81 = "arith.cmpf"(%arg11, %arg11) <{predicate = 13 : i64}> : (f32, f32) -> i1
  %82 = "arith.ori"(%80, %81) : (i1, i1) -> i1
  %83 = "arith.select"(%82, %arg11, %arg12) : (i1, f32, f32) -> f32
  "tt.reduce.return"(%83) : (f32) -> ()
}) : (tensor<16x128xf32>) -> tensor<16xf32>
/home/ianb/test/ttirs_linalg_failed/15.mlir:30:11: error: failed to legalize operation 'tt.reduce'
    %21 = "tt.reduce"(%20) <{axis = 1 : i32}> ({
          ^
/home/ianb/test/ttirs_linalg_failed/15.mlir:30:11: note: see current operation: 
%65 = "tt.reduce"(%64) <{axis = 1 : i32}> ({
^bb0(%arg11: f32, %arg12: f32):
  %80 = "arith.cmpf"(%arg11, %arg12) <{predicate = 2 : i64}> : (f32, f32) -> i1
  %81 = "arith.cmpf"(%arg11, %arg11) <{predicate = 13 : i64}> : (f32, f32) -> i1
  %82 = "arith.ori"(%80, %81) : (i1, i1) -> i1
  %83 = "arith.select"(%82, %arg11, %arg12) : (i1, f32, f32) -> f32
  "tt.reduce.return"(%83) : (f32) -> ()
}) : (tensor<16x128xf32>) -> tensor<16xf32>
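To support this case, the lowering would have to recognize the four-op body as a max. The real check lives in the C++ conversion that emits the error above. The sketch below only illustrates the matching logic over a toy, hypothetical `Op` record, which is not the MLIR or triton-shared API. It is populated with the ops from the failing `tt.reduce` in `15.mlir`.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

# arith.cmpf predicate values as they appear in the generic-form IR
CMPF_OGT = 2
CMPF_UNE = 13


@dataclass(frozen=True)
class Op:
    """Hypothetical flattened op: name, operand SSA names, result, cmpf predicate."""
    name: str
    operands: Tuple[str, ...]
    result: Optional[str] = None
    predicate: Optional[int] = None


def match_nan_propagating_max(args: Tuple[str, str], body: List[Op]) -> bool:
    """True if body is select((a ogt b) or (a une a), a, b) followed by a return."""
    if len(body) != 5:
        return False
    a, b = args
    gt, nan, ori, sel, ret = body
    return (
        gt.name == "arith.cmpf" and gt.predicate == CMPF_OGT
        and gt.operands == (a, b)
        # une(a, a) is the NaN test on the first operand
        and nan.name == "arith.cmpf" and nan.predicate == CMPF_UNE
        and nan.operands == (a, a)
        # arith.ori is commutative, so accept either operand order
        and ori.name == "arith.ori"
        and set(ori.operands) == {gt.result, nan.result}
        and sel.name == "arith.select"
        and sel.operands == (ori.result, a, b)
        and ret.name == "tt.reduce.return"
        and ret.operands == (sel.result,)
    )


# The combine region from the error output above.
FAILING_ARGS = ("%arg11", "%arg12")
FAILING_BODY = [
    Op("arith.cmpf", ("%arg11", "%arg12"), "%80", CMPF_OGT),
    Op("arith.cmpf", ("%arg11", "%arg11"), "%81", CMPF_UNE),
    Op("arith.ori", ("%80", "%81"), "%82"),
    Op("arith.select", ("%82", "%arg11", "%arg12"), "%83"),
    Op("tt.reduce.return", ("%83",)),
]

if __name__ == "__main__":
    print(match_nan_propagating_max(FAILING_ARGS, FAILING_BODY))
```

The match is deliberately strict, requiring exact operand order and exactly these ops. A production pattern might also accept commuted or `olt`-based forms, and would need to decide whether lowering to `arith.maximumf` is acceptable given the signed-zero difference.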

Metadata

Labels: enhancement (New feature or request)