-
Notifications
You must be signed in to change notification settings - Fork 25
Closed
Description
Enables fusing the ops below into an op(dot_general, dot_general).
module @reactant_syr2k attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
// Entry point of the syr2k reproducer. Tracing the ops below, the returned
// tensor is, in index form (summing j over the 64-element contracted axis):
//   result[k, i] = %arg0 * sum_j(%arg2[j, k] * %arg3[j, i] + %arg3[j, k] * %arg2[j, i])
//                  + %arg1 * %arg4[k, i]
// The contraction is written as broadcast + multiply + reduce rather than as
// dot_general ops.
// NOTE(review): %arg0 / %arg1 presumably play the role of the BLAS alpha /
// beta scalars and %arg2 / %arg3 / %arg4 the A / B / C matrices -- confirm
// against the code that generated this module.
// %arg4 carries tf.aliasing_output = 0 (presumably allowing result 0 to reuse
// its buffer -- confirm).
func.func @main(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f32> {enzymexla.memory_effects = []}, %arg2: tensor<64x64xf32> {enzymexla.memory_effects = []}, %arg3: tensor<64x64xf32> {enzymexla.memory_effects = []}, %arg4: tensor<64x64xf32> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}) -> tensor<64x64xf32> attributes {enzymexla.memory_effects = []} {
// Zero scalar used as the init value of the add-reduction below.
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// %2[j, k, 0, 0] = %arg0 * %arg3[j, k]: %arg3 scaled by the scalar %arg0,
// with two trailing unit dimensions appended by the reshape.
%0 = stablehlo.reshape %arg3 : (tensor<64x64xf32>) -> tensor<64x64x1x1xf32>
%1 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<64x64x1x1xf32>
%2 = stablehlo.multiply %0, %1 : tensor<64x64x1x1xf32>
// %4[j, k] = %arg0 * %arg2[j, k]: %arg2 scaled by the same scalar.
%3 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<64x64xf32>
%4 = stablehlo.multiply %arg2, %3 : tensor<64x64xf32>
// First product term, laid out over [i, j, k, 1, 1]:
//   %5[i, j, k] = %arg2[j, i]           (dims = [1, 0] swaps the operand axes)
//   %6[i, j, k] = %arg0 * %arg3[j, k]
//   %7[i, j, k] = %arg0 * %arg3[j, k] * %arg2[j, i]
%5 = stablehlo.broadcast_in_dim %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64x64x1x1xf32>
%6 = stablehlo.broadcast_in_dim %2, dims = [1, 2, 3, 4] : (tensor<64x64x1x1xf32>) -> tensor<64x64x64x1x1xf32>
%7 = stablehlo.multiply %6, %5 : tensor<64x64x64x1x1xf32>
// Second product term, same layout with the roles of %arg2 and %arg3 swapped:
//   %8[i, j, k]  = %arg0 * %arg2[j, k]
//   %9[i, j, k]  = %arg3[j, i]
//   %10[i, j, k] = %arg0 * %arg2[j, k] * %arg3[j, i]
%8 = stablehlo.broadcast_in_dim %4, dims = [1, 2] : (tensor<64x64xf32>) -> tensor<64x64x64x1x1xf32>
%9 = stablehlo.broadcast_in_dim %arg3, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64x64x1x1xf32>
%10 = stablehlo.multiply %8, %9 : tensor<64x64x64x1x1xf32>
// Sum of the two elementwise product terms, still un-reduced over j.
%11 = stablehlo.add %10, %7 : tensor<64x64x64x1x1xf32>
// Insert a unit dimension at position 1, then add-reduce over dimensions
// 1 (size 1), 2 (the j axis) and 5 (size 1). What remains is [i, k, 1]:
//   %13[i, k, 0] = sum_j %11[i, j, k, 0, 0]
%12 = stablehlo.reshape %11 : (tensor<64x64x64x1x1xf32>) -> tensor<64x1x64x64x1x1xf32>
%13 = stablehlo.reduce(%12 init: %cst) applies stablehlo.add across dimensions = [2, 1, 5] : (tensor<64x1x64x64x1x1xf32>, tensor<f32>) -> tensor<64x64x1xf32>
// %16[i, k, 0] = %arg1 * %arg4[k, i]: %arg4 is placed with its axes swapped
// (dims = [1, 0]) and scaled by the scalar %arg1.
%14 = stablehlo.broadcast_in_dim %arg4, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64x1xf32>
%15 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<64x64x1xf32>
%16 = stablehlo.multiply %14, %15 : tensor<64x64x1xf32>
// Add the reduced product terms to the scaled %arg4 term.
%17 = stablehlo.add %13, %16 : tensor<64x64x1xf32>
// Drop the trailing unit dimension. The reshape is annotated with
// enzymexla.symmetric_matrix = NOTGUARANTEED.
%18 = stablehlo.reshape %17 {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : (tensor<64x64x1xf32>) -> tensor<64x64xf32>
// Final transpose: result[k, i] = %18[i, k]. This undoes the axis swap applied
// to %arg4 above, so the %arg4 term ends up as %arg1 * %arg4[k, i].
%19 = stablehlo.transpose %18, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
return %19 : tensor<64x64xf32>
}
}
Metadata
Metadata
Assignees
Labels
No labels