diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 238163722f30..4b1bac05294b 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -35,6 +35,31 @@ def _get_thrust_func_name(tvmop):
     return tvmop_to_thrust_func_name[tvmop]
 
 
+def _can_use_scan_thrust(binop):
+    """Check whether scan_thrust can be used for the current target and binary op.
+
+    Parameters
+    ----------
+    binop : callable
+        The binary operator requested for the scan.
+
+    Returns
+    -------
+    result : bool
+        True only when a target is active, binop is tvm.tir.generic.add, and either
+        can_use_thrust or can_use_rocthrust accepts "tvm.contrib.thrust.sum_scan".
+    """
+    target = tvm.target.Target.current()
+    if target is None:
+        return False
+    # Chain the probes with `or` (not any([...])) so the rocthrust probe is skipped
+    # once the thrust probe has already succeeded.
+    return binop == tvm.tir.generic.add and bool(
+        can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
+        or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
+    )
+
+
 def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0):
     """Low level IR to do exclusive sum scan along rows of 2D input.
 
@@ -363,17 +388,9 @@ def exclusive_scan(
     """
 
     def do_scan(data, output_dtype):
-        target = tvm.target.Target.current()
 
         # TODO: add support for a prod_scan
-        if (
-            target
-            and binop == tvm.tir.generic.add
-            and (
-                can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
-                or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
-            )
-        ):
+        if _can_use_scan_thrust(binop):
             return scan_thrust(
                 data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
             )
@@ -479,6 +496,25 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add,
     output : tvm.te.Tensor
         A N-D tensor of the same rank N as the input data.
     """
+
+    # Fast path: let scan_thrust compute the inclusive scan directly.
+    if _can_use_scan_thrust(binop):
+        if output_dtype is None or output_dtype == "":
+            output_dtype = data.dtype
+        ndim = len(data.shape)
+        if axis < 0:
+            axis += ndim
+
+        # When axis is not already the last one, transpose around the scan_thrust call.
+        if axis != ndim - 1:
+            axes = swap(list(range(ndim)), axis)
+            data = transpose(data, axes)
+        output = scan_thrust(data, output_dtype, exclusive=False, binop=binop)
+        if axis != ndim - 1:
+            axes = swap(list(range(ndim)), axis)
+            output = transpose(output, axes)
+        return output
+
     ex_scan = exclusive_scan(
         data, axis, output_dtype=output_dtype, binop=binop, identity_value=identity_value
     )