Skip to content

Support argsort of floating point values with cub #35587

@Findus23

Description

@Findus23

See jax-ml/jax#27004 (comment) and jax-ml/jax#10434 (comment) for more details.

Calling jnp.argsort(some_float_array) seems like a very common operation in JAX.

However, while 2002777 made the same operation fast for integer arrays (it now uses cub), an argsort of a float array doesn't seem to trigger the SortRewriter:

$ TF_CPP_MAX_VLOG_LEVEL=2 TF_CPP_MIN_LOG_LEVEL=0 python -c "import jax;jax.jit(jax.numpy.argsort).lower(jax.numpy.zeros(int(1e7),dtype=jax.numpy.float32)).compile()" 2>&1 | grep "sort_rewriter"
I1219 18:23:57.655937  944263 sort_rewriter.cc:430] Sort instruction: sort.0
I1219 18:23:57.655951  944263 sort_rewriter.cc:456] Only simple compare computations are supported
I1219 18:23:57.656872  944263 sort_rewriter.cc:430] Sort instruction: sort.0
I1219 18:23:57.656874  944263 sort_rewriter.cc:456] Only simple compare computations are supported
$ python -c "import jax;print(jax.jit(jax.numpy.argsort).lower(jax.numpy.zeros(int(1e7),dtype=jax.numpy.float32)).compile().as_text())"
HloModule jit_argsort, is_scheduled=true, entry_computation_layout={(f32[10000000]{0})->s32[10000000]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="7a48b9adbe01fb327ce459bd8d1001e8"}

FileNames
1 "<string>"

FunctionNames
1 "<module>"

FileLocations
1 {file_name_id=1 function_name_id=1 line=1 end_line=1 column=17 end_column=100}

StackFrames
1 {file_location_id=1 parent_frame_id=1}


%region_0.1 (sort.4: f32[], sort.5: f32[], sort.6: s32[], sort.7: s32[]) -> pred[] {
  %constant_4_0 = s32[] constant(2147483647)
  %constant_5 = s32[] constant(0)
  %sort.5 = f32[] parameter(1), metadata={op_name="sort"}
  %ne.3.0 = pred[] compare(%sort.5, %sort.5), direction=NE, metadata={op_name="ne" stack_frame_id=1}
  %constant_3_0 = f32[] constant(0)
  %eq.3.0 = pred[] compare(%sort.5, %constant_3_0), direction=EQ, metadata={op_name="eq" stack_frame_id=1}
  %select_n.6.0 = f32[] select(%eq.3.0, %constant_3_0, %sort.5), metadata={op_name="select_n" stack_frame_id=1}
  %constant_2_0 = f32[] constant(nan)
  %select_n.7.0 = f32[] select(%ne.3.0, %constant_2_0, %select_n.6.0), metadata={op_name="select_n" stack_frame_id=1}
  %bitcast.1.0 = s32[] bitcast(%select_n.7.0)
  %compare.1.0 = pred[] compare(%bitcast.1.0, %constant_5), direction=LT
  %xor.1.0 = s32[] xor(%constant_4_0, %bitcast.1.0)
  %select.1.0 = s32[] select(%compare.1.0, %xor.1.0, %bitcast.1.0)
  %sort.4 = f32[] parameter(0), metadata={op_name="sort"}
  %eq.2.0 = pred[] compare(%sort.4, %constant_3_0), direction=EQ, metadata={op_name="eq" stack_frame_id=1}
  %select_n.4.0 = f32[] select(%eq.2.0, %constant_3_0, %sort.4), metadata={op_name="select_n" stack_frame_id=1}
  %ne.2.0 = pred[] compare(%sort.4, %sort.4), direction=NE, metadata={op_name="ne" stack_frame_id=1}
  %select_n.5.0 = f32[] select(%ne.2.0, %constant_2_0, %select_n.4.0), metadata={op_name="select_n" stack_frame_id=1}
  %bitcast.2 = s32[] bitcast(%select_n.5.0)
  %compare.7 = pred[] compare(%bitcast.2, %constant_5), direction=LT
  %xor.4 = s32[] xor(%constant_4_0, %bitcast.2)
  %select.5 = s32[] select(%compare.7, %xor.4, %bitcast.2)
  %lt_to.0.0 = pred[] compare(%select.1.0, %select.5), direction=LT, metadata={op_name="lt_to" stack_frame_id=1}
  %lt_to.1.0 = pred[] compare(%select.5, %select.1.0), direction=LT, metadata={op_name="lt_to" stack_frame_id=1}
  %compare.5.0 = pred[] compare(%lt_to.1.0, %lt_to.0.0), direction=EQ
  %sort.7 = s32[] parameter(3), metadata={op_name="sort"}
  %sort.6 = s32[] parameter(2), metadata={op_name="sort"}
  %compare.6.0 = pred[] compare(%sort.6, %sort.7), direction=LT
  ROOT %select.4.0 = pred[] select(%compare.5.0, %compare.6.0, %lt_to.1.0)
}

%wrapped_iota_computation () -> s32[10000000] {
  ROOT %iota.0.1 = s32[10000000]{0} iota(), iota_dimension=0, metadata={op_name="jit(argsort)/jit(argsort)/iota" stack_frame_id=1}
}

ENTRY %main.3 (a.1: f32[10000000]) -> s32[10000000] {
  %a.1 = f32[10000000]{0} parameter(0), metadata={op_name="a"}
  %wrapped_iota = s32[10000000]{0} fusion(), kind=kLoop, calls=%wrapped_iota_computation, metadata={op_name="jit(argsort)/jit(argsort)/iota" stack_frame_id=1}
  %sort.0.0 = (f32[10000000]{0}, s32[10000000]{0}) sort(%a.1, %wrapped_iota), dimensions={0}, is_stable=true, to_apply=%region_0.1, metadata={op_name="jit(argsort)/jit(argsort)/sort" stack_frame_id=1}
  ROOT %sort.2.0 = s32[10000000]{0} get-tuple-element(%sort.0.0), index=1, metadata={op_name="jit(argsort)/jit(argsort)/sort" stack_frame_id=1}
}
$ python -c "import jax;print(jax.jit(jax.numpy.argsort).lower(jax.numpy.zeros(int(1e7),dtype=jax.numpy.float32)).as_text())"
module @jit_argsort attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<10000000xf32>) -> (tensor<10000000xi32> {jax.result_info = "result"}) {
    %0 = call @argsort(%arg0) : (tensor<10000000xf32>) -> tensor<10000000xi32>
    return %0 : tensor<10000000xi32>
  }
  func.func private @argsort(%arg0: tensor<10000000xf32>) -> tensor<10000000xi32> {
    %0 = stablehlo.iota dim = 0 : tensor<10000000xi32>
    %1:2 = "stablehlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
      %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
      %2 = stablehlo.compare  EQ, %arg1, %cst,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
      %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
      %3 = stablehlo.select %2, %cst_0, %arg1 : tensor<i1>, tensor<f32>
      %4 = stablehlo.compare  NE, %arg1, %arg1,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
      %cst_1 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
      %5 = stablehlo.select %4, %cst_1, %3 : tensor<i1>, tensor<f32>
      %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
      %6 = stablehlo.compare  EQ, %arg2, %cst_2,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
      %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
      %7 = stablehlo.select %6, %cst_3, %arg2 : tensor<i1>, tensor<f32>
      %8 = stablehlo.compare  NE, %arg2, %arg2,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
      %cst_4 = stablehlo.constant dense<0x7FC00000> : tensor<f32>
      %9 = stablehlo.select %8, %cst_4, %7 : tensor<i1>, tensor<f32>
      %10 = stablehlo.compare  LT, %5, %9,  TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
      stablehlo.return %10 : tensor<i1>
    }) : (tensor<10000000xf32>, tensor<10000000xi32>) -> (tensor<10000000xf32>, tensor<10000000xi32>)
    return %1#1 : tensor<10000000xi32>
  }
}

This is using JAX 0.8.2 (so XLA c45b8fe)

Metadata

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions