-
Notifications
You must be signed in to change notification settings - Fork 721
Open
Labels
err:performance (Performance issues)
Description
See jax-ml/jax#27004 (comment) and jax-ml/jax#10434 (comment) for more details.
Calling jnp.argsort(some_float_array) feels like a very common operation in JAX.
But while, as of 2002777, doing the same with integers is now fast (as it uses CUB), an argsort of a float array doesn't seem to trigger the SortRewriter:
$ TF_CPP_MAX_VLOG_LEVEL=2 TF_CPP_MIN_LOG_LEVEL=0 python -c "import jax;jax.jit(jax.numpy.argsort).lower(jax.numpy.zeros(int(1e7),dtype=jax.numpy.float32)).compile()" 2>&1 | grep "sort_rewriter"
I1219 18:23:57.655937 944263 sort_rewriter.cc:430] Sort instruction: sort.0
I1219 18:23:57.655951 944263 sort_rewriter.cc:456] Only simple compare computations are supported
I1219 18:23:57.656872 944263 sort_rewriter.cc:430] Sort instruction: sort.0
I1219 18:23:57.656874 944263 sort_rewriter.cc:456] Only simple compare computations are supported
$ python -c "import jax;print(jax.jit(jax.numpy.argsort).lower(jax.numpy.zeros(int(1e7),dtype=jax.numpy.float32)).compile().as_text())"
HloModule jit_argsort, is_scheduled=true, entry_computation_layout={(f32[10000000]{0})->s32[10000000]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="7a48b9adbe01fb327ce459bd8d1001e8"}
FileNames
1 "<string>"
FunctionNames
1 "<module>"
FileLocations
1 {file_name_id=1 function_name_id=1 line=1 end_line=1 column=17 end_column=100}
StackFrames
1 {file_location_id=1 parent_frame_id=1}
%region_0.1 (sort.4: f32[], sort.5: f32[], sort.6: s32[], sort.7: s32[]) -> pred[] {
%constant_4_0 = s32[] constant(2147483647)
%constant_5 = s32[] constant(0)
%sort.5 = f32[] parameter(1), metadata={op_name="sort"}
%ne.3.0 = pred[] compare(%sort.5, %sort.5), direction=NE, metadata={op_name="ne" stack_frame_id=1}
%constant_3_0 = f32[] constant(0)
%eq.3.0 = pred[] compare(%sort.5, %constant_3_0), direction=EQ, metadata={op_name="eq" stack_frame_id=1}
%select_n.6.0 = f32[] select(%eq.3.0, %constant_3_0, %sort.5), metadata={op_name="select_n" stack_frame_id=1}
%constant_2_0 = f32[] constant(nan)
%select_n.7.0 = f32[] select(%ne.3.0, %constant_2_0, %select_n.6.0), metadata={op_name="select_n" stack_frame_id=1}
%bitcast.1.0 = s32[] bitcast(%select_n.7.0)
%compare.1.0 = pred[] compare(%bitcast.1.0, %constant_5), direction=LT
%xor.1.0 = s32[] xor(%constant_4_0, %bitcast.1.0)
%select.1.0 = s32[] select(%compare.1.0, %xor.1.0, %bitcast.1.0)
%sort.4 = f32[] parameter(0), metadata={op_name="sort"}
%eq.2.0 = pred[] compare(%sort.4, %constant_3_0), direction=EQ, metadata={op_name="eq" stack_frame_id=1}
%select_n.4.0 = f32[] select(%eq.2.0, %constant_3_0, %sort.4), metadata={op_name="select_n" stack_frame_id=1}
%ne.2.0 = pred[] compare(%sort.4, %sort.4), direction=NE, metadata={op_name="ne" stack_frame_id=1}
%select_n.5.0 = f32[] select(%ne.2.0, %constant_2_0, %select_n.4.0), metadata={op_name="select_n" stack_frame_id=1}
%bitcast.2 = s32[] bitcast(%select_n.5.0)
%compare.7 = pred[] compare(%bitcast.2, %constant_5), direction=LT
%xor.4 = s32[] xor(%constant_4_0, %bitcast.2)
%select.5 = s32[] select(%compare.7, %xor.4, %bitcast.2)
%lt_to.0.0 = pred[] compare(%select.1.0, %select.5), direction=LT, metadata={op_name="lt_to" stack_frame_id=1}
%lt_to.1.0 = pred[] compare(%select.5, %select.1.0), direction=LT, metadata={op_name="lt_to" stack_frame_id=1}
%compare.5.0 = pred[] compare(%lt_to.1.0, %lt_to.0.0), direction=EQ
%sort.7 = s32[] parameter(3), metadata={op_name="sort"}
%sort.6 = s32[] parameter(2), metadata={op_name="sort"}
%compare.6.0 = pred[] compare(%sort.6, %sort.7), direction=LT
ROOT %select.4.0 = pred[] select(%compare.5.0, %compare.6.0, %lt_to.1.0)
}
%wrapped_iota_computation () -> s32[10000000] {
ROOT %iota.0.1 = s32[10000000]{0} iota(), iota_dimension=0, metadata={op_name="jit(argsort)/jit(argsort)/iota" stack_frame_id=1}
}
ENTRY %main.3 (a.1: f32[10000000]) -> s32[10000000] {
%a.1 = f32[10000000]{0} parameter(0), metadata={op_name="a"}
%wrapped_iota = s32[10000000]{0} fusion(), kind=kLoop, calls=%wrapped_iota_computation, metadata={op_name="jit(argsort)/jit(argsort)/iota" stack_frame_id=1}
%sort.0.0 = (f32[10000000]{0}, s32[10000000]{0}) sort(%a.1, %wrapped_iota), dimensions={0}, is_stable=true, to_apply=%region_0.1, metadata={op_name="jit(argsort)/jit(argsort)/sort" stack_frame_id=1}
ROOT %sort.2.0 = s32[10000000]{0} get-tuple-element(%sort.0.0), index=1, metadata={op_name="jit(argsort)/jit(argsort)/sort" stack_frame_id=1}
}
$ python -c "import jax;print(jax.jit(jax.numpy.argsort).lower(jax.numpy.zeros(int(1e7),dtype=jax.numpy.float32)).as_text())"
module @jit_argsort attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<10000000xf32>) -> (tensor<10000000xi32> {jax.result_info = "result"}) {
%0 = call @argsort(%arg0) : (tensor<10000000xf32>) -> tensor<10000000xi32>
return %0 : tensor<10000000xi32>
}
func.func private @argsort(%arg0: tensor<10000000xf32>) -> tensor<10000000xi32> {
%0 = stablehlo.iota dim = 0 : tensor<10000000xi32>
%1:2 = "stablehlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%2 = stablehlo.compare EQ, %arg1, %cst, FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%3 = stablehlo.select %2, %cst_0, %arg1 : tensor<i1>, tensor<f32>
%4 = stablehlo.compare NE, %arg1, %arg1, FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
%cst_1 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
%5 = stablehlo.select %4, %cst_1, %3 : tensor<i1>, tensor<f32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%6 = stablehlo.compare EQ, %arg2, %cst_2, FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
%cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%7 = stablehlo.select %6, %cst_3, %arg2 : tensor<i1>, tensor<f32>
%8 = stablehlo.compare NE, %arg2, %arg2, FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
%cst_4 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
%9 = stablehlo.select %8, %cst_4, %7 : tensor<i1>, tensor<f32>
%10 = stablehlo.compare LT, %5, %9, TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
stablehlo.return %10 : tensor<i1>
}) : (tensor<10000000xf32>, tensor<10000000xi32>) -> (tensor<10000000xf32>, tensor<10000000xi32>)
return %1#1 : tensor<10000000xi32>
}
}
This is using JAX 0.8.2 (so XLA c45b8fe).
Metadata
Metadata
Assignees
Labels
err:performance (Performance issues)