-
Notifications
You must be signed in to change notification settings - Fork 25
Open
Description
using Reactant
# Reproducer, first variant: running accumulation written as an explicit loop
# under `@trace` (presumably Reactant's tracing macro, in scope via `using Reactant`).
# Mutates `x` in place and returns the same `x`.
function looped_accumulate(x)
# The range starts at 3, so x[1] and x[2] are read but never written.
@trace for i in 3:length(x)
# From i = 4 onward, x[i - 1] is the value written by the previous iteration,
# so the iterations form a sequential recurrence.
x[i] += x[i - 1]
end
return x
end
# Input for the reproducer: 128 random Float32 values converted with `Reactant.to_rarray`.
x = Reactant.to_rarray(rand(Float32, 128));
# Inspect the HLO generated for this call via `@code_hlo`.
@code_hlo looped_accumulate(x)
# Reproducer, second variant: same name as the earlier definition, so in one
# session this replaces it. The only difference is that the previous element is
# scaled by 2 before being added.
# Mutates `x` in place and returns the same `x`.
function looped_accumulate(x)
# The range starts at 3, so x[1] and x[2] are read but never written.
@trace for i in 3:length(x)
# From i = 4 onward, x[i - 1] is the value written by the previous iteration,
# so the iterations form a sequential recurrence.
x[i] += 2 * x[i - 1]
end
return x
end
# Input for the reproducer: 128 random Float32 values converted with `Reactant.to_rarray`.
x = Reactant.to_rarray(rand(Float32, 128));
@code_hlo looped_accumulate(x)
module @reactant_looped_... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<128xf32> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}) -> tensor<128xf32> attributes {enzymexla.memory_effects = []} {
%cst = stablehlo.constant dense<2.000000e+00> : tensor<1xf32>
%c = stablehlo.constant dense<1> : tensor<i32>
%c_0 = stablehlo.constant dense<0> : tensor<i64>
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<3> : tensor<i64>
%c_3 = stablehlo.constant dense<126> : tensor<i64>
%0:2 = stablehlo.while(%iterArg = %c_0, %iterArg_4 = %arg0) : tensor<i64>, tensor<128xf32> attributes {enzyme.disable_mincut}
cond {
%1 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %1 : tensor<i1>
} do {
%1 = stablehlo.add %c_2, %iterArg {enzymexla.bounds = [[3, 128]]} : tensor<i64>
%2 = stablehlo.add %iterArg, %c_1 {enzymexla.bounds = [[1, 126]]} : tensor<i64>
%3 = stablehlo.convert %1 {enzymexla.bounds = [[3, 128]]} : (tensor<i64>) -> tensor<i32>
%4 = stablehlo.subtract %3, %c {enzymexla.bounds = [[2, 127]]} : tensor<i32>
%5 = stablehlo.dynamic_slice %arg0, %4, sizes = [1] : (tensor<128xf32>, tensor<i32>) -> tensor<1xf32>
%6 = stablehlo.subtract %1, %c_1 {enzymexla.bounds = [[2, 127]]} : tensor<i64>
%7 = stablehlo.convert %6 {enzymexla.bounds = [[2, 127]]} : (tensor<i64>) -> tensor<i32>
%8 = stablehlo.subtract %7, %c {enzymexla.bounds = [[1, 126]]} : tensor<i32>
%9 = stablehlo.dynamic_slice %iterArg_4, %8, sizes = [1] : (tensor<128xf32>, tensor<i32>) -> tensor<1xf32>
%10 = stablehlo.multiply %cst, %9 : tensor<1xf32>
%11 = stablehlo.add %5, %10 : tensor<1xf32>
%12 = stablehlo.dynamic_update_slice %iterArg_4, %11, %4 : (tensor<128xf32>, tensor<1xf32>, tensor<i32>) -> tensor<128xf32>
stablehlo.return %2, %12 : tensor<i64>, tensor<128xf32>
}
return %0#1 : tensor<128xf32>
}
}
Metadata
Metadata
Assignees
Labels
No labels