// Module attributes request exactly 1 partition and 1 replica.
module @reactant_kernel_... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  // Computes, for every i in [0, 2048), a in [0, 1024), b in [0, 256):
  //   out[i, a, b] = sum_k arg1[k, i] * arg0[k, a, b]
  // i.e. transpose(arg1) applied along the leading (size-2048) axis of arg0.
  //
  // NOTE(review): tf.aliasing_output = 0 on %arg0 presumably allows the
  // runtime to reuse (donate) arg0's buffer for result 0 -- confirm with the
  // caller that arg0 is not read after this call.
  // NOTE(review): enzymexla.memory_effects = [] presumably declares that the
  // arguments and the function have no side effects -- TODO confirm.
  func.func @main(%arg0: tensor<2048x1024x256xf32> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}, %arg1: tensor<2048x2048xf32> {enzymexla.memory_effects = []}) -> tensor<2048x1024x256xf32> attributes {enzymexla.memory_effects = []} {
    // Collapse the trailing 1024x256 dims into a single axis of 262144
    // elements. StableHLO reshape preserves row-major element order, so
    // flat[k, a*256 + b] == arg0[k, a, b].
    %flat = stablehlo.reshape %arg0 : (tensor<2048x1024x256xf32>) -> tensor<2048x262144xf32>

    // Contract dim 0 of %arg1 against dim 0 of %flat. Result dims are the
    // lhs non-contracting dim (dim 1 of %arg1) followed by the rhs
    // non-contracting dim (dim 1 of %flat), so this is transpose(%arg1) x %flat.
    %prod = stablehlo.dot_general %arg1, %flat, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2048x2048xf32>, tensor<2048x262144xf32>) -> tensor<2048x262144xf32>

    // Split the 262144-wide axis back into 1024x256 to match the input layout.
    %out = stablehlo.reshape %prod : (tensor<2048x262144xf32>) -> tensor<2048x1024x256xf32>

    func.return %out : tensor<2048x1024x256xf32>
  }
}